Machine Learning with Java libraries
When I think about Machine Learning, R and Python come to mind first. Both languages offer a nice set of ML libraries and packages for analyzing and visualizing data. When it comes to Java, however, there aren’t that many ML libraries. Of course there are nice Java frameworks, but they are mostly designed in such a way that you don’t actually do the coding. So how can a Java programmer easily incorporate ML into their application? I used two libraries that allowed me to do exactly that.
Weka
Weka is data mining software written in Java. It provides many Machine Learning algorithms that can be used out of the box. Weka has a nice clickable GUI in which an end-to-end analysis can be performed, and it also provides a Java API that lets you use ML in any Java application. Many algorithms are implemented: supervised and unsupervised learning, statistics, data analysis, filtering, attribute extraction and more. It allows you to perform a simple analysis without going deep into an algorithm’s implementation, and it’s very useful for advanced users as well.
HOW TO
The Weka library can be imported with Maven:
<dependency> <groupid>nz.ac.waikato.cms.weka</groupid> <artifactid>weka-stable</artifactid> <version>3.8.2</version> </dependency> |
Weka provides a nice tutorial to get started, although it takes some time to work up to more complex analyses.
Classification tree example
Before Weka version 3.5.5 we had to start the analysis by converting our data to the ARFF format. We don’t have to do that any more; Weka handles the conversion by itself.
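For reference, ARFF is just a plain-text format with a header that declares the relation and its attributes, followed by the data rows. A tiny hand-written sketch (the attribute names here are made up for illustration, not taken from the actual dataset):

```
@relation breast-cancer
@attribute clump_thickness numeric
@attribute cell_size numeric
@attribute class {2,4}
@data
5,1,2
8,10,4
```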
The example is based on the public Breast Cancer Wisconsin dataset.
import weka.classifiers.evaluation.Evaluation;
import weka.classifiers.evaluation.Prediction;
import weka.classifiers.trees.J48;
import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.NumericToNominal;
import java.util.ArrayList;

...

// First we need to load our data into an Instances object
DataSource source = new DataSource("breast-cancer-wisconsin.csv");
Instances data = source.getDataSet();

// We can check some statistics
System.out.println("Number of data instances " + data.numInstances());
System.out.println("Number of data attributes " + data.numAttributes());

// Delete the id column
data.deleteAttributeAt(0);

// Change the numeric format of the Class attribute into actual values
NumericToNominal convert = new NumericToNominal();
String[] options = new String[2];
options[0] = "-R";  // replace
options[1] = "10";  // range of attributes, here we start counting from 1!
convert.setOptions(options);
convert.setInputFormat(data);
data = Filter.useFilter(data, convert); // run the conversion

// We need to tell Weka explicitly which column is our class column
data.setClassIndex(data.numAttributes() - 1);
System.out.println("Number of classes " + data.numClasses());

// Define train and test set sizes
int trainSize = (int) Math.round(data.numInstances() * 0.8);
int testSize = data.numInstances() - trainSize;

// Create an unpruned J48 tree classifier
// Note: "J" for Java, 48 for the C4.8 algorithm, hence the name J48
J48 tree = new J48();
tree.setUnpruned(true);

// Build the classifier
Instances trainSet = new Instances(data, 0, trainSize);
Instances testSet = new Instances(data, trainSize, testSize);
tree.buildClassifier(trainSet);

Evaluation evaluation = new Evaluation(trainSet);
evaluation.evaluateModel(tree, testSet);

// We can check some statistics and create Prediction objects
System.out.println(evaluation.toSummaryString());
System.out.println("False negative rate: " + evaluation.falseNegativeRate(1));
System.out.println("False positive rate: " + evaluation.falsePositiveRate(1));
System.out.println("True negative rate: " + evaluation.trueNegativeRate(1));
System.out.println("True positive rate: " + evaluation.truePositiveRate(1));
System.out.println(evaluation.toMatrixString()); // confusion matrix

// Prediction encapsulates a single evaluatable prediction:
// the predicted value plus the actual class value
ArrayList<Prediction> predictions = new ArrayList<>(evaluation.predictions());

// Misclassified records
for (int i = 0; i < predictions.size(); i++) {
    Prediction pred = predictions.get(i);
    if (pred.predicted() != pred.actual()) {
        System.out.println(data.get(i + trainSize));
    }
}
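To make the printed rates concrete: for a given class, the true positive rate is TP / (TP + FN) and the false positive rate is FP / (FP + TN), read straight off the confusion matrix. A minimal plain-Java sketch of that arithmetic (independent of Weka; the class and method names are my own):

```java
public class ConfusionRates {

    // True positive rate (recall): fraction of actual positives
    // that were predicted positive
    public static double truePositiveRate(int tp, int fn) {
        return (double) tp / (tp + fn);
    }

    // False positive rate: fraction of actual negatives
    // that were wrongly predicted positive
    public static double falsePositiveRate(int fp, int tn) {
        return (double) fp / (fp + tn);
    }

    public static void main(String[] args) {
        // Hypothetical confusion matrix: 50 TP, 5 FN, 3 FP, 82 TN
        System.out.println("TPR: " + truePositiveRate(50, 5));
        System.out.println("FPR: " + falsePositiveRate(3, 82));
    }
}
```

Weka computes exactly these ratios per class when you call `truePositiveRate(classIndex)` and friends on the `Evaluation` object.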
Apache Commons Math
As the name suggests, Apache Commons Math is a Java library with mathematics and statistics components. Those components can easily be added to your code and used for analysis. You will find components for geometry, linear algebra, neural networks, matrices, machine learning and more. The examples on the main page do not cover everything that can be done with Commons Math, but fortunately the docs fill that gap.
HOW TO
The Apache Commons Math library can be imported with Maven:
<dependency> <groupid>org.apache.commons</groupid> <artifactid>commons-math3</artifactid> <version>3.6.1</version> </dependency> |
k-means example
For this example I used the publicly available Wine dataset.
In order to use an Apache Commons Math clustering algorithm, we need to make our data Clusterable. Let’s create a wrapper class, WineCluster, for the Wine objects we create from the dataset.
import org.apache.commons.math3.ml.clustering.Clusterable;

public class WineCluster implements Clusterable {

    private double[] points;
    private Wine wine; // the object representing a line from the file

    public WineCluster(Wine wine) {
        this.wine = wine;
        this.points = new double[] {
            wine.getAlcohol(), wine.getMalicAcid(), wine.getAsh(),
            wine.getAlcalinityOfAsh(), wine.getMagnesium(),
            wine.getTotalPhenols(), wine.getFlavanoids(),
            wine.getNonflavanoidPhenols(), wine.getProanthocyanins(),
            wine.getColorIntensity(), wine.getHue(),
            wine.getODRatio(), wine.getProline()
        };
    }

    public Wine getWine() {
        return wine;
    }

    @Override
    public double[] getPoint() {
        return points;
    }

    public void setPoint(double[] points) {
        this.points = points;
    }

    public double getValue(int i) {
        return points[i];
    }

    public void setValue(int i, double value) {
        this.points[i] = value;
    }
}
Now Wine objects can be wrapped by WineCluster objects.
import java.io.BufferedReader;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.List;
import org.apache.commons.math3.stat.StatUtils;
import org.apache.commons.math3.ml.clustering.CentroidCluster;
import org.apache.commons.math3.ml.clustering.KMeansPlusPlusClusterer;

...

// Wine objects need to be wrapped by WineCluster objects
// so that they implement Clusterable
List<WineCluster> wineClusters = new ArrayList<>();
Path inputPath = Paths.get("wines.data");
BufferedReader br = Files.newBufferedReader(inputPath, StandardCharsets.UTF_8);

String line;
while ((line = br.readLine()) != null) {
    // Parse the line to extract the individual fields
    String[] data = line.split(",");
    wineClusters.add(new WineCluster(new Wine(data)));
}

// Data normalization - needed for clustering algorithms
double[] originalAttribute = new double[wineClusters.size()];
double[] normalizedAttribute;
for (int j = 0; j < 13; j++) {
    // read the values of the attribute
    for (int i = 0; i < wineClusters.size(); i++) {
        originalAttribute[i] = wineClusters.get(i).getValue(j);
    }
    // normalize the attribute
    normalizedAttribute = StatUtils.normalize(originalAttribute);
    // replace the attribute's values
    for (int i = 0; i < wineClusters.size(); i++) {
        wineClusters.get(i).setValue(j, normalizedAttribute[i]);
    }
}

// I use the KMeans++ algorithm with 3 clusters and 10000 iterations.
// By default the Euclidean distance is used.
KMeansPlusPlusClusterer<WineCluster> clusterer =
        new KMeansPlusPlusClusterer<>(3, 10000);
List<CentroidCluster<WineCluster>> clusterResults = clusterer.cluster(wineClusters);

// Print the centers of the centroid clusters...
for (int i = 0; i < clusterResults.size(); i++) {
    System.out.println("Cluster " + i);
    System.out.println(clusterResults.get(i).getCenter());
}

// ...and the cluster members
for (int i = 0; i < clusterResults.size(); i++) {
    System.out.println("Cluster " + i);
    for (WineCluster wineCluster : clusterResults.get(i).getPoints()) {
        System.out.println(wineCluster.getWine());
    }
}
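The normalization step above rescales each attribute to zero mean and unit standard deviation (a z-score transform), so that attributes on large scales, such as proline, don’t dominate the Euclidean distance. If you wanted to see what StatUtils.normalize does without the library, a plain-Java sketch (using the sample standard deviation, which is what Commons Math uses, and with names of my own choosing) would be:

```java
public class ZScore {

    // Rescale a sample to mean 0 and (sample) standard deviation 1
    public static double[] normalize(double[] sample) {
        double mean = 0;
        for (double v : sample) mean += v;
        mean /= sample.length;

        double variance = 0;
        for (double v : sample) variance += (v - mean) * (v - mean);
        variance /= (sample.length - 1); // sample variance (n - 1)
        double std = Math.sqrt(variance);

        double[] result = new double[sample.length];
        for (int i = 0; i < sample.length; i++) {
            result[i] = (sample[i] - mean) / std;
        }
        return result;
    }

    public static void main(String[] args) {
        // After normalization the values are centered around 0
        for (double v : normalize(new double[] {1, 2, 3, 4, 5})) {
            System.out.println(v);
        }
    }
}
```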
Both Weka and Apache Commons Math are pretty useful and a nice alternative to the clickable frameworks, although it takes more time to code an analysis than to click it all through.