Multinomial classification example in Scala and Deep Learning with H2O

Here is a sample that solves a multinomial classification problem in Scala, using the H2O Deep Learning algorithm and the iris data set.

The sample was created with Spark 2.1.0 and Sparkling Water 2.1.4. It assumes a shell session in which the SparkContext is already available as sc.

import org.apache.spark.h2o._
import water.support.SparkContextSupport.addFiles
import org.apache.spark.SparkFiles
import java.io.File
import water.support.{H2OFrameSupport, SparkContextSupport, ModelMetricsSupport}
import water.Key
import _root_.hex.deeplearning.DeepLearningModel
import _root_.hex.ModelMetricsMultinomial


val hc = H2OContext.getOrCreate(sc)
import hc._
import hc.implicits._

addFiles(sc, "/Users/avkashchauhan/smalldata/iris/iris.csv")
val irisData = new H2OFrame(new File(SparkFiles.get("iris.csv")))

val ratios = Array[Double](0.8)
val keys = Array[String]("train.hex", "valid.hex")
val frs = H2OFrameSupport.split(irisData, keys, ratios)
val (train, valid) = (frs(0), frs(1))
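
A note on the split parameters: a single ratio of 0.8 produces two frames, and the second one receives the remaining 0.2, which is why keys has one more entry than ratios. The plain-Scala sketch below only does that arithmetic on row counts. splitSizes is a hypothetical helper, not part of H2O or Sparkling Water, and the 150-row total assumes the standard iris data set; the row counts H2O actually produces may differ slightly.

```scala
// Hypothetical helper (not an H2O API): expected row counts per split.
// The last split receives whatever rows the listed ratios leave over.
def splitSizes(totalRows: Long, ratios: Array[Double]): Array[Long] = {
  require(ratios.forall(r => r > 0.0 && r < 1.0) && ratios.sum < 1.0,
    "ratios must be in (0, 1) and sum to less than 1")
  val listed = ratios.map(r => math.round(totalRows * r))
  listed :+ (totalRows - listed.sum)
}

// Assumption: iris has 150 rows; one ratio gives two splits.
val sizes = splitSizes(150L, Array(0.8))
println(sizes.mkString(", ")) // prints 120, 30
```
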

// Train a Deep Learning model on `train`, using `valid` as the validation frame
def buildDLModel(train: Frame, valid: Frame, response: String,
                 epochs: Int = 10, l1: Double = 0.001, l2: Double = 0.0,
                 hidden: Array[Int] = Array[Int](200, 200))
                (implicit h2oContext: H2OContext): DeepLearningModel = {
  import h2oContext.implicits._
  import _root_.hex.deeplearning.DeepLearning
  import _root_.hex.deeplearning.DeepLearningModel.DeepLearningParameters

  // Set the model parameters
  val dlParams = new DeepLearningParameters()
  dlParams._train = train
  dlParams._valid = valid
  dlParams._response_column = response
  dlParams._epochs = epochs
  dlParams._l1 = l1
  dlParams._l2 = l2 // previously accepted as an argument but never applied
  dlParams._hidden = hidden

  // Create the training job, run it, and block until the model is ready
  val dl = new DeepLearning(dlParams, Key.make("dlModel.hex"))
  dl.trainModel.get
}


// Note: the response column in this frame is named C5, so pass it as the response:
val dlModel = buildDLModel(train, valid, 'C5)(hc)

// Collect model metrics and evaluate model quality
val trainMetrics = ModelMetricsSupport.modelMetrics[ModelMetricsMultinomial](dlModel, train)
val validMetrics = ModelMetricsSupport.modelMetrics[ModelMetricsMultinomial](dlModel, valid)
println(trainMetrics.rmse)
println(validMetrics.rmse)
println(trainMetrics.mse)
println(validMetrics.mse)
println(trainMetrics.r2)
println(validMetrics.r2)
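
To help read these numbers: as far as I understand H2O's classification metrics, the per-row error is one minus the probability the model assigned to the true class, MSE is the mean of those squared errors, and RMSE is its square root. The sketch below reproduces that arithmetic in plain Scala. Scored, mse, and rmse are hypothetical names for illustration, and the probabilities are made up, not output from the model above.

```scala
// Hypothetical container: the true class index and the predicted
// per-class probabilities for one row.
final case class Scored(actual: Int, probs: Array[Double])

// Mean of squared (1 - probability assigned to the true class).
def mse(rows: Seq[Scored]): Double = {
  require(rows.nonEmpty, "need at least one scored row")
  rows.map { r =>
    val err = 1.0 - r.probs(r.actual)
    err * err
  }.sum / rows.size
}

// RMSE is the square root of MSE.
def rmse(rows: Seq[Scored]): Double = math.sqrt(mse(rows))

// Made-up predictions for three rows over three classes.
val scored = Seq(
  Scored(0, Array(0.9, 0.05, 0.05)),
  Scored(1, Array(0.2, 0.7, 0.1)),
  Scored(2, Array(0.1, 0.4, 0.5))
)
println(s"mse=${mse(scored)} rmse=${rmse(scored)}")
```

Lower values are better for both metrics, and a validation RMSE well above the training RMSE is a sign of overfitting.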

That's it, enjoy!
