/*
* This class is for artificial neural network prediction.
* You can load Visual Gene Developer's trained neural network file (*.vgn)
* then use it to predict output values for any given input values.
* Currently, this class does not include modules for training networks.
* In addition, only the hyperbolic tangent transfer function is supported.
*
* THIS PROGRAM IS DISTRIBUTED "AS IS".
* NO WARRANTY OF ANY KIND IS EXPRESSED OR IMPLIED.
* YOU USE THE PROGRAM AT YOUR OWN RISK.
* THE AUTHOR WILL NOT BE LIABLE FOR DATA LOSS, DAMAGES,
* LOSS OF PROFITS OR ANY OTHER KIND OF LOSS WHILE USING
* OR MISUSING THIS SOFTWARE.
* ANYONE CAN USE AND MODIFY CODES WITHOUT CHARGE.
*/
import java.io.File;
import java.util.List;
import javax.swing.JOptionPane;
/**
 * Feed-forward neural network that evaluates a network trained by Visual Gene Developer.
 *
 * <p>Load a trained network with {@link #openTrainedNetworkFile(String)}, then call
 * {@link #predict(double[])}. Training is not supported, and only the hyperbolic tangent
 * transfer function is implemented.
 *
 * <p>Layer indexing used by the arrays below: layer 0 holds the raw input values, layer 1 is
 * the input layer, layers 2 .. {@code hiddenLayerCount + 1} are the hidden layers and layer
 * {@code hiddenLayerCount + 2} is the output layer.
 *
 * <p>{@link #predict(double[])} writes into the shared node arrays, so an instance must not be
 * used for concurrent predictions.
 *
 * @author SangKyu
 */
public class NeuralNet {
    /** Largest layer index the arrays can hold; the output layer index must not exceed it. */
    public int maxLayerCount = 5;
    /** Maximum number of nodes in any layer. */
    public int maxNodeCount = 200;
    /** Number of hidden layers ("Total layer" in the file minus 2). */
    public int hiddenLayerCount;
    /** Number of input values expected by {@link #predict(double[])}. */
    public int inputCount;
    /** Number of output values returned by {@link #predict(double[])}. */
    public int outputCount;
    /** Node count of every layer, indexed by layer (despite the name, includes input/output). */
    public int[] nodeCountInHiddenlayers = new int[maxLayerCount + 1];
    /** Activated value of every node: [layer][node]. */
    public double[][] nodeValue = new double[maxLayerCount + 1][maxNodeCount];
    /** Threshold added to every node's input sum before activation: [layer][node]. */
    public double[][] nodeThreshold = new double[maxLayerCount + 1][maxNodeCount];
    /** Input sum before activation, recorded for layers 2 and above: [layer][node]. */
    public double[][] nodeTotAct = new double[maxLayerCount + 1][maxNodeCount];
    /** Connection weights: [layer n][node in layer n-1][node in layer n]. */
    public double[][][] connectWeightFactor
            = new double[maxLayerCount + 1][maxNodeCount][maxNodeCount];

    /**
     * Demonstrates loading a trained network and predicting one sample, showing the result
     * (or the load error) in a dialog. Edit the path to point at a local copy of the sample file.
     */
    public void test() {
        // Use your own path
        String trainedNetworkFile = "D:\\....."
                + "\\Sample SinCos - Trained network.vgn";
        String retMsg = openTrainedNetworkFile(trainedNetworkFile);
        if (!retMsg.isEmpty()) {
            JOptionPane.showMessageDialog(null, retMsg);
            return;
        }
        double[] outputValues = predict(new double[]{0.37146, 0.88627, 0.41384});
        // Guard the index: a different network file may have fewer than two outputs.
        if (outputValues != null && outputValues.length > 1) {
            JOptionPane.showMessageDialog(null, "Output2= " + outputValues[1]);
            // Correct result: 0.6740054...
        }
    }

    /**
     * Evaluates the transfer (activation) function, the hyperbolic tangent.
     *
     * <p>{@link Math#tanh(double)} is used rather than the explicit
     * {@code (e^x - e^-x) / (e^x + e^-x)} form, which overflows to
     * {@code Infinity / Infinity = NaN} once {@code |inX|} exceeds about 710.
     *
     * @param inX the input sum of a node
     * @return {@code tanh(inX)}, in the range [-1, 1]
     */
    public double computeTransferFunction(double inX) {
        return Math.tanh(inX);
    }

    /**
     * Predicts output values for the given input values using the loaded network.
     *
     * @param inputValues one value per network input
     * @return a new array of {@code outputCount} output values, or {@code null} when
     *     {@code inputValues} is {@code null} or its length differs from {@code inputCount}
     */
    public double[] predict(double[] inputValues) {
        if (inputValues == null || inputValues.length != inputCount) {
            return null;
        }
        // Layer 0 holds the raw inputs; propagation starts from there.
        System.arraycopy(inputValues, 0, nodeValue[0], 0, inputCount);
        propagateNetwork();
        // The output layer sits right after the last hidden layer.
        double[] outputValues = new double[outputCount];
        System.arraycopy(nodeValue[hiddenLayerCount + 2], 0, outputValues, 0, outputCount);
        return outputValues;
    }

    /**
     * Propagates the raw inputs in layer 0 forward through every layer, filling
     * {@link #nodeValue} and, for layers 2 and above, {@link #nodeTotAct}.
     */
    private void propagateNetwork() {
        // Input layer: each node sees only its own raw input plus its threshold (no weights).
        for (int node = 0; node < nodeCountInHiddenlayers[1]; node++) {
            nodeValue[1][node] =
                    computeTransferFunction(nodeValue[0][node] + nodeThreshold[1][node]);
        }
        int outputLayer = hiddenLayerCount + 2;
        for (int layer = 2; layer <= outputLayer; layer++) {
            double[] previousValues = nodeValue[layer - 1];
            int previousCount = nodeCountInHiddenlayers[layer - 1];
            double[][] weights = connectWeightFactor[layer];
            for (int node = 0; node < nodeCountInHiddenlayers[layer]; node++) {
                double sum = 0;
                for (int prev = 0; prev < previousCount; prev++) {
                    sum += weights[prev][node] * previousValues[prev];
                }
                sum += nodeThreshold[layer][node];
                nodeValue[layer][node] = computeTransferFunction(sum);
                nodeTotAct[layer][node] = sum;
            }
        }
    }

    /**
     * Loads a network trained and saved by Visual Gene Developer (*.vgn).
     *
     * <p>Header keys ({@code Total input=}, {@code Total output=}, {@code Total layer=},
     * {@code Transfer function=}) set the network shape. The sections {@code layer=total node},
     * {@code layer-node=threshold value} and
     * {@code layer-node(layer n-1)-node(layer n)=weight factor} supply the hidden-layer sizes,
     * thresholds and weights. Threshold and weight sections run until the next empty line.
     * Values that do not fit the arrays ({@link #maxLayerCount}, {@link #maxNodeCount}) are
     * rejected with an error message instead of being stored.
     *
     * @param fileName path of the trained network file
     * @return {@code ""} if everything is OK, otherwise an error message
     */
    public String openTrainedNetworkFile(String fileName) {
        List<String> linesList = Utilities.getLinesFromFile(new File(fileName));
        if (linesList == null || linesList.isEmpty()) {
            return "Empty file";
        }
        String[] strLines = linesList.toArray(new String[0]);
        if (strLines.length < 2
                || !"Name=Visual Gene Developer - Neural Network".equals(strLines[1])) {
            return "Not valid trained neuralnet file";
        }
        try {
            for (int curLine = 0; curLine < strLines.length; curLine++) {
                String line = strLines[curLine];
                if (line.startsWith("Total input=")) {
                    inputCount = requireNodeCount(getIntegerFromString(line, "="), line);
                } else if (line.startsWith("Total output=")) {
                    outputCount = requireNodeCount(getIntegerFromString(line, "="), line);
                } else if (line.startsWith("Total layer=")) {
                    // Total layers = input layer + hidden layers + output layer; the output
                    // layer index (== total) must fit the arrays.
                    int totalLayers = getIntegerFromString(line, "=");
                    if (totalLayers < 2 || totalLayers > maxLayerCount) {
                        throw new InvalidNetworkFileException(
                                "unsupported layer count in \"" + line + "\"");
                    }
                    hiddenLayerCount = totalLayers - 2;
                } else if (line.startsWith("Transfer function=")) {
                    if (!line.equals("Transfer function=Hyperbolic tangent")) {
                        return "Transfer function is not hyperbolic tangent";
                    }
                } else if (line.startsWith("layer=total node")) {
                    // One "layer=count" line per hidden layer follows the header.
                    for (int i = 0; i < hiddenLayerCount; i++) {
                        curLine++;
                        String countLine = lineAt(strLines, curLine);
                        nodeCountInHiddenlayers[i + 2] =
                                requireNodeCount(getIntegerFromString(countLine, "="), countLine);
                    }
                } else if (line.startsWith("layer-node=threshold value")) {
                    curLine = readThresholds(strLines, curLine + 1);
                } else if (line.startsWith(
                        "layer-node(layer n-1)-node(layer n)=weight factor")) {
                    curLine = readWeights(strLines, curLine + 1);
                }
            }
        } catch (InvalidNetworkFileException e) {
            return "Not valid trained neuralnet file: " + e.getMessage();
        }
        nodeCountInHiddenlayers[0] = inputCount;
        nodeCountInHiddenlayers[1] = inputCount;
        nodeCountInHiddenlayers[hiddenLayerCount + 2] = outputCount;
        return "";
    }

    /**
     * Reads fixed-width "LL-NN=threshold" lines from {@code firstLine} until an empty line or
     * the end of the file, storing each value in {@link #nodeThreshold}.
     *
     * @return index of the terminating empty line, or {@code lines.length}
     */
    private int readThresholds(String[] lines, int firstLine) throws InvalidNetworkFileException {
        int curLine = firstLine;
        while (curLine < lines.length && !lines[curLine].trim().isEmpty()) {
            String line = lines[curLine];
            int layer = requireLayerIndex(parseField(line, 0, 2), line);
            int node = requireNodeIndex(parseField(line, 3, 5), line);
            nodeThreshold[layer][node - 1] = getDoubleFromString(line, "=");
            curLine++;
        }
        return curLine;
    }

    /**
     * Reads fixed-width "LL-AA-BB=weight" lines (layer, node in layer n-1, node in layer n)
     * from {@code firstLine} until an empty line or the end of the file, storing each value in
     * {@link #connectWeightFactor}.
     *
     * @return index of the terminating empty line, or {@code lines.length}
     */
    private int readWeights(String[] lines, int firstLine) throws InvalidNetworkFileException {
        int curLine = firstLine;
        while (curLine < lines.length && !lines[curLine].trim().isEmpty()) {
            String line = lines[curLine];
            int layer = requireLayerIndex(parseField(line, 0, 2), line);
            int fromNode = requireNodeIndex(parseField(line, 3, 5), line);
            int toNode = requireNodeIndex(parseField(line, 6, 8), line);
            connectWeightFactor[layer][fromNode - 1][toNode - 1] = getDoubleFromString(line, "=");
            curLine++;
        }
        return curLine;
    }

    /** Returns line {@code index}, failing if the file ends before it. */
    private static String lineAt(String[] lines, int index) throws InvalidNetworkFileException {
        if (index >= lines.length) {
            throw new InvalidNetworkFileException("unexpected end of file");
        }
        return lines[index];
    }

    /** Parses the fixed-width integer field {@code [begin, end)} of {@code line}. */
    private int parseField(String line, int begin, int end) throws InvalidNetworkFileException {
        if (line.length() < end) {
            throw new InvalidNetworkFileException("line too short: \"" + line + "\"");
        }
        return requireInt(line.substring(begin, end), line);
    }

    /** Checks that a layer index read from the file is a valid index into the arrays. */
    private int requireLayerIndex(int layer, String line) throws InvalidNetworkFileException {
        if (layer < 0 || layer > maxLayerCount) {
            throw new InvalidNetworkFileException("layer index out of range in \"" + line + "\"");
        }
        return layer;
    }

    /** Checks that a 1-based node number read from the file fits the arrays. */
    private int requireNodeIndex(int node, String line) throws InvalidNetworkFileException {
        if (node < 1 || node > maxNodeCount) {
            throw new InvalidNetworkFileException("node index out of range in \"" + line + "\"");
        }
        return node;
    }

    /** Checks that a node count read from the file is between 0 and {@link #maxNodeCount}. */
    private int requireNodeCount(int count, String line) throws InvalidNetworkFileException {
        if (count < 0 || count > maxNodeCount) {
            throw new InvalidNetworkFileException("node count out of range in \"" + line + "\"");
        }
        return count;
    }

    /** Parses {@code text} as an integer, failing with the offending {@code line} if invalid. */
    private int requireInt(String text, String line) throws InvalidNetworkFileException {
        Integer value = getInteger(text);
        if (value == null) {
            throw new InvalidNetworkFileException("invalid integer in \"" + line + "\"");
        }
        return value;
    }

    /** Parses the integer after the first {@code valueSeparator} (or the whole string). */
    private int getIntegerFromString(String srcStr, String valueSeparator)
            throws InvalidNetworkFileException {
        return requireInt(srcStr.substring(srcStr.indexOf(valueSeparator) + 1), srcStr);
    }

    /** Parses the double after the first {@code valueSeparator} (or the whole string). */
    private double getDoubleFromString(String srcStr, String valueSeparator)
            throws InvalidNetworkFileException {
        Double value = getDouble(srcStr.substring(srcStr.indexOf(valueSeparator) + 1));
        if (value == null) {
            throw new InvalidNetworkFileException("invalid number in \"" + srcStr + "\"");
        }
        return value;
    }

    /**
     * Parses an integer, ignoring surrounding whitespace.
     *
     * @param str text to parse; may be {@code null}
     * @return the parsed value, or {@code null} if {@code str} is {@code null} or not an integer
     */
    public Integer getInteger(String str) {
        if (str == null) {
            return null;
        }
        try {
            return Integer.valueOf(str.trim());
        } catch (NumberFormatException e) {
            return null;
        }
    }

    /**
     * Parses a double, ignoring surrounding whitespace.
     *
     * @param str text to parse; may be {@code null}
     * @return the parsed value, or {@code null} if {@code str} is {@code null} or not a number
     */
    public Double getDouble(String str) {
        if (str == null) {
            return null;
        }
        try {
            return Double.valueOf(str.trim());
        } catch (NumberFormatException e) {
            return null;
        }
    }

    /** Signals malformed or unsupported content in a trained network file. */
    private static final class InvalidNetworkFileException extends Exception {
        private static final long serialVersionUID = 1L;

        InvalidNetworkFileException(String message) {
            super(message);
        }
    }
}