Here is a sample that solves a multinomial classification problem using the H2O Deep Learning algorithm and the Iris data set, written in Scala.
The following sample addresses a multinomial classification problem. It was created using Spark 2.1.0 with Sparkling Water 2.1.4.
import org.apache.spark.h2o._
import water.support.SparkContextSupport.addFiles
import org.apache.spark.SparkFiles
import java.io.File
import water.support.{H2OFrameSupport, SparkContextSupport, ModelMetricsSupport}
import water.Key
import _root_.hex.deeplearning.DeepLearningModel
import _root_.hex.ModelMetricsMultinomial

// Obtain (or start) the H2O context on top of the existing SparkContext `sc`,
// which is expected to be provided by the shell this script is run in.
val hc = H2OContext.getOrCreate(sc)
import hc._
import hc.implicits._

// Register the iris CSV with Spark and load it into an H2O frame.
// NOTE: this is an absolute local path; adjust it to where iris.csv lives on your machine.
addFiles(sc, "/Users/avkashchauhan/smalldata/iris/iris.csv")
val irisData = new H2OFrame(new File(SparkFiles.get("iris.csv")))

// Split the data into two frames keyed "train.hex" and "valid.hex" using a single 0.8 ratio
// (presumably 80% training and the remaining 20% validation -- confirm against H2OFrameSupport.split).
val ratios = Array[Double](0.8)
val keys = Array[String]("train.hex", "valid.hex")
val frs = H2OFrameSupport.split(irisData, keys, ratios)
val (train, valid) = (frs(0), frs(1))

/**
 * Trains an H2O Deep Learning model and returns it once training has finished.
 *
 * @param train      frame used for training
 * @param valid      frame used for validation during training
 * @param response   name of the response column
 * @param epochs     value assigned to the `_epochs` parameter
 * @param l1         value assigned to the `_l1` parameter
 * @param l2         value assigned to the `_l2` parameter
 * @param hidden     value assigned to the `_hidden` parameter (hidden layer sizes)
 * @param h2oContext H2O context whose implicits are imported while filling in the parameters
 * @return the trained model, stored under the key "dlModel.hex"
 */
def buildDLModel(train: Frame, valid: Frame, response: String,
                 epochs: Int = 10, l1: Double = 0.001, l2: Double = 0.0,
                 hidden: Array[Int] = Array[Int](200, 200))(
                 implicit h2oContext: H2OContext): DeepLearningModel = {
  import h2oContext.implicits._
  // Build a model
  import _root_.hex.deeplearning.DeepLearning
  import _root_.hex.deeplearning.DeepLearningModel.DeepLearningParameters

  val dlParams = new DeepLearningParameters()
  dlParams._train = train
  dlParams._valid = valid
  dlParams._response_column = response
  dlParams._epochs = epochs
  dlParams._l1 = l1
  // Previously the l2 argument was accepted but never applied to the parameters.
  dlParams._l2 = l2
  dlParams._hidden = hidden

  // Create a job and wait for the trained model.
  val dl = new DeepLearning(dlParams, Key.make("dlModel.hex"))
  dl.trainModel().get()
}

// Note: The response column name is C5 here, so pass it as the response:
val dlModel = buildDLModel(train, valid, "C5")(hc)

// Collect model metrics and evaluate model quality
val trainMetrics = ModelMetricsSupport.modelMetrics[ModelMetricsMultinomial](dlModel, train)
val validMetrics = ModelMetricsSupport.modelMetrics[ModelMetricsMultinomial](dlModel, valid)

println(trainMetrics.rmse)
println(validMetrics.rmse)
println(trainMetrics.mse)
println(validMetrics.mse)
println(trainMetrics.r2)
println(validMetrics.r2)
That's it, enjoy!