Visualizing H2O GBM and Random Forest MOJO Model Trees in Python

In this example we will first build a tree-based model using the H2O machine learning library and then save that model as a MOJO. Using the Graphviz/dot tools, we will extract individual trees/cross-validated model trees from the MOJO and visualize them. If you are new to H2O MOJO models, learn here.

You can also get the full working IPython notebook for this example from here.

Let's build the model first using the H2O GBM algorithm. You can also use a Distributed Random Forest model for tree visualization.

Let's first import the key Python modules:

import h2o
import subprocess
from IPython.display import Image

Now we will build a GBM model using the public PROSTATE dataset:

h2o.init()
df = h2o.import_file('https://raw.githubusercontent.com/h2oai/sparkling-water/master/examples/smalldata/prostate.csv')
y = 'CAPSULE'
x = df.col_names
x.remove(y)
df[y] = df[y].asfactor()
train, valid, test = df.split_frame(ratios=[.8,.1])
from h2o.estimators.gbm import H2OGradientBoostingEstimator
gbm_cv3 = H2OGradientBoostingEstimator(nfolds=3)
gbm_cv3.train(x=x, y=y, training_frame=train)

## Getting all cross validated models 
all_models = gbm_cv3.cross_validation_models()
print("Total cross validation models: " + str(len(all_models)))
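Note that `cross_validation_models()` returns separate model objects. If you want to inspect the trees of an individual fold model, one option is to save one MOJO per fold and run the same steps on each file. The following is a minimal sketch, assuming each model object exposes `download_mojo(path)` the way it is called later in this post; the helper name `download_cv_mojos` and the file naming pattern are hypothetical:

```python
import os

def download_cv_mojos(models, out_dir, prefix="my_gbm_cv"):
    """Save one MOJO zip per model and return the list of target paths.

    `models` is any list of objects with a download_mojo(path) method,
    for example the list returned by gbm_cv3.cross_validation_models().
    The file naming scheme is an illustrative choice, not an H2O convention.
    """
    paths = []
    for i, model in enumerate(models, start=1):
        path = os.path.join(out_dir, "%s_%d_mojo.zip" % (prefix, i))
        model.download_mojo(path)
        paths.append(path)
    return paths
```

With the models above, this would be called as `download_cv_mojos(all_models, some_writable_directory)`.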

Now let's set all the default parameters needed to create first the Graphviz tree file and then the tree images (in PNG format) on the local disk. Make sure you have a writable path where you can create and save these intermediate files. You also need to provide the path to the latest H2O jar (h2o.jar), which contains the PrintMojo tool used to convert the MOJO into Graphviz format.

mojo_file_name = "/Users/avkashchauhan/Downloads/my_gbm_mojo.zip"
h2o_jar_path= '/Users/avkashchauhan/tools/h2o-3/h2o-3.14.0.3/h2o.jar'
mojo_full_path = mojo_file_name
gv_file_path = "/Users/avkashchauhan/Downloads/my_gbm_graph.gv"

Now let's define the image file name prefix, to which the Tree ID is appended. For a given Tree ID, the image file will be named my_gbm_tree_ID.png:

image_file_name = "/Users/avkashchauhan/Downloads/my_gbm_tree"

Now we will download the GBM MOJO model by saving it to disk:

gbm_cv3.download_mojo(mojo_file_name)
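The hard-coded paths above are specific to one machine. As a minimal sketch, the same three locations could be derived from a single working directory, with a check that it is writable. The function name `build_paths` and the dictionary keys are hypothetical; the file names mirror the ones used in this post:

```python
import os

def build_paths(work_dir, model_name="my_gbm"):
    """Derive the MOJO, Graphviz, and image-prefix paths from one directory."""
    if not os.path.isdir(work_dir) or not os.access(work_dir, os.W_OK):
        raise ValueError("Not a writable directory: " + work_dir)
    return {
        "mojo": os.path.join(work_dir, model_name + "_mojo.zip"),
        "gv": os.path.join(work_dir, model_name + "_graph.gv"),
        "image_prefix": os.path.join(work_dir, model_name + "_tree"),
    }
```

Failing early on an unwritable directory is easier to diagnose than a missing output file several steps later.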

Now let's define the function that generates the Graphviz tree from the saved MOJO model:

def generateTree(h2o_jar_path, mojo_full_path, gv_file_path, image_file_path, tree_id = 0):
    # Image name for this tree (not used in this step; the PNG is written by generateTreeImage)
    image_file_path = image_file_path + "_" + str(tree_id) + ".png"
    # Run PrintMojo from h2o.jar to write the selected tree as a Graphviz file
    result = subprocess.call(["java", "-cp", h2o_jar_path, "hex.genmodel.tools.PrintMojo", "--tree", str(tree_id), "-i", mojo_full_path, "-o", gv_file_path], shell=False)
    # Check that the Graphviz file exists on disk
    result = subprocess.call(["ls", gv_file_path], shell=False)
    if result == 0:
        print("Success: Graphviz file " + gv_file_path + " is generated.")
    else:
        print("Error: Graphviz file " + gv_file_path + " could not be generated.")
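The `ls` check only works on systems that have an `ls` command. A minimal sketch of a portable alternative is shown here: one helper builds the same PrintMojo argument list, and another checks the output file with `os.path`. Both helper names are hypothetical. Note that neither this check nor `ls` can tell a fresh file from a stale one left by an earlier run.

```python
import os

def build_print_mojo_command(h2o_jar_path, mojo_full_path, gv_file_path, tree_id=0):
    """Return the argument list for the PrintMojo call used in generateTree."""
    return ["java", "-cp", h2o_jar_path, "hex.genmodel.tools.PrintMojo",
            "--tree", str(tree_id), "-i", mojo_full_path, "-o", gv_file_path]

def output_was_written(path):
    """True if the file exists and is not empty (a portable stand-in for `ls`)."""
    return os.path.isfile(path) and os.path.getsize(path) > 0
```

The list returned by `build_print_mojo_command` can be passed directly to `subprocess.call`, and `output_was_written(gv_file_path)` can replace the second `subprocess.call`.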

Now let's define the function that generates the tree image as a PNG from the saved Graphviz tree:

def generateTreeImage(gv_file_path, image_file_path, tree_id):
    # Build the image name from the prefix and the Tree ID
    image_file_path = image_file_path + "_" + str(tree_id) + ".png"
    # Render the Graphviz file to PNG with the dot command
    result = subprocess.call(["dot", "-Tpng", gv_file_path, "-o", image_file_path], shell=False)
    # Check that the image file exists on disk
    result = subprocess.call(["ls", image_file_path], shell=False)
    if result == 0:
        print("Success: Image file " + image_file_path + " is generated.")
        print("Now you can execute the following line as-is to see the tree graph:")
        print("Image(filename='" + image_file_path + "')")
    else:
        print("Error: Image file " + image_file_path + " could not be generated.")

Note: I had to write this as a two-step process because when I put everything in one step, the process hung after the Graphviz file was created.
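One way to guard against such a hang is to give the external command a time limit. This is a minimal sketch using the standard library's `subprocess.run` with its `timeout` argument; the helper name `run_with_timeout` and the 60-second default are arbitrary choices, and I have not verified that this resolves the hang described above:

```python
import subprocess
import sys

def run_with_timeout(cmd, timeout_seconds=60):
    """Run cmd and return its exit code, or None if it exceeded the timeout."""
    try:
        completed = subprocess.run(cmd, timeout=timeout_seconds)
        return completed.returncode
    except subprocess.TimeoutExpired:
        # subprocess.run kills the child process before raising this exception
        return None

# Demo with a harmless command: the current Python interpreter doing nothing
print(run_with_timeout([sys.executable, "-c", "pass"]))
```

For the image step, the call would look like `run_with_timeout(["dot", "-Tpng", gv_file_path, "-o", image_file_path])`, where a return value of `None` means the command was stopped.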

Now let's generate the tree by passing all the parameters defined above, with the desired Tree ID as the last parameter.

#Just change the tree id in the function below to get which particular tree you want
generateTree(h2o_jar_path, mojo_full_path, gv_file_path, image_file_name, 3)

Now we will generate the PNG tree image from the saved Graphviz content.

generateTreeImage(gv_file_path, image_file_name, 3)
# Note: If this step hangs, look for an active "dot" process on macOS and try killing it
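The two calls above produce the image for Tree ID 3 only. Each of the other images displayed below (IDs 0, 1 and 2) needs both steps run with its own ID. A minimal sketch of a loop over several IDs follows. The helper names are hypothetical, and the two steps are passed in as callables so the loop does not depend on how they are implemented:

```python
def image_path_for(image_file_prefix, tree_id):
    """Mirror the naming rule in this post: <prefix>_<tree_id>.png"""
    return image_file_prefix + "_" + str(tree_id) + ".png"

def render_trees(tree_ids, image_file_prefix, gv_step, png_step):
    """Run both steps for every Tree ID and return the expected image paths.

    gv_step(tree_id) and png_step(tree_id) are supplied by the caller, for
    example thin wrappers around generateTree and generateTreeImage.
    """
    paths = []
    for tree_id in tree_ids:
        gv_step(tree_id)   # overwrites the same .gv file on every pass
        png_step(tree_id)  # so the PNG is rendered before the next pass
        paths.append(image_path_for(image_file_prefix, tree_id))
    return paths

# Possible usage with the functions and variables defined in this post:
# render_trees(range(4), image_file_name,
#              lambda t: generateTree(h2o_jar_path, mojo_full_path, gv_file_path, image_file_name, t),
#              lambda t: generateTreeImage(gv_file_path, image_file_name, t))
```

Because every pass reuses the same Graphviz file, the two steps have to stay paired inside the loop rather than being run as two separate loops.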

Let's visualize the main model tree (Tree ID 0):

# Just pass the Tree Image file name depending on your tree
Image(filename='/Users/avkashchauhan/Downloads/my_gbm_tree_0.png')

tree-0

Let's visualize the tree generated for Tree ID 1 (Cross Validation ID 1):

# Just pass the Tree Image file name depending on your tree
Image(filename='/Users/avkashchauhan/Downloads/my_gbm_tree_1.png')

tree-1

Let's visualize the tree generated for Tree ID 2 (Cross Validation ID 2):

# Just pass the Tree Image file name depending on your tree
Image(filename='/Users/avkashchauhan/Downloads/my_gbm_tree_2.png')

tree-2

Let's visualize the tree generated for Tree ID 3 (Cross Validation ID 3):

# Just pass the Tree Image file name depending on your tree
Image(filename='/Users/avkashchauhan/Downloads/my_gbm_tree_3.png')

tree-3

After looking at these trees, you can see how the decisions are made at each split.

That's it, enjoy!
