In this example we will first build a tree-based model using the H2O machine learning library and then save that model as a MOJO. We will then use H2O's PrintMojo tool to extract individual trees from the MOJO as GraphViz files and render them as images with the GraphViz dot utility. If you are new to H2O MOJO models, learn more here.
You can also get the full working IPython notebook for this example here.
Let's build the model first using the H2O GBM algorithm. You can also use a Distributed Random Forest (DRF) model for tree visualization.
Let's first import the key Python modules:
```python
import h2o
import subprocess
from IPython.display import Image
```
Now we will build a GBM model using the public prostate dataset:
```python
h2o.init()
df = h2o.import_file('https://raw.githubusercontent.com/h2oai/sparkling-water/master/examples/smalldata/prostate.csv')

# CAPSULE is the target; every other column is used as a predictor
y = 'CAPSULE'
x = df.col_names
x.remove(y)
df[y] = df[y].asfactor()  # treat the target as categorical (classification)

train, valid, test = df.split_frame(ratios=[.8, .1])

from h2o.estimators.gbm import H2OGradientBoostingEstimator

# nfolds=3 also trains three cross-validation models next to the main model
gbm_cv3 = H2OGradientBoostingEstimator(nfolds=3)
gbm_cv3.train(x=x, y=y, training_frame=train)

## Getting all cross validated models
all_models = gbm_cv3.cross_validation_models()
print("Total cross validation models: " + str(len(all_models)))
```
Now let's set the parameters needed to create the GraphViz tree file first, and then the tree images (in PNG format) on local disk. Make sure you have a writable path where these intermediate files can be created and saved. You also need to provide the path to a recent H2O jar (h2o.jar). It contains the PrintMojo tool (hex.genmodel.tools.PrintMojo), which converts the MOJO into a GraphViz file.
```python
mojo_file_name = "/Users/avkashchauhan/Downloads/my_gbm_mojo.zip"
h2o_jar_path = '/Users/avkashchauhan/tools/h2o-3/h2o-184.108.40.206/h2o.jar'
mojo_full_path = mojo_file_name
gv_file_path = "/Users/avkashchauhan/Downloads/my_gbm_graph.gv"
```
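Before running anything, it can help to confirm that these paths and tools are usable. The sketch below uses a hypothetical helper, check_prerequisites, which is not part of H2O. It assumes that java and dot are expected on the PATH. It checks only three things: the output directory is writable, the jar file exists, and both executables can be found.

```python
import os
import shutil


def check_prerequisites(output_dir, h2o_jar_path, tools=("java", "dot")):
    """Hypothetical helper: return a list of problems; an empty list means ready."""
    problems = []
    # The .gv and .png files are written here, so it must exist and be writable
    if not os.path.isdir(output_dir):
        problems.append("output directory does not exist: " + output_dir)
    elif not os.access(output_dir, os.W_OK):
        problems.append("output directory is not writable: " + output_dir)
    # PrintMojo is run from this jar via `java -cp`
    if not os.path.isfile(h2o_jar_path):
        problems.append("h2o.jar not found: " + h2o_jar_path)
    # Each external program must be on the PATH (or given as a full path)
    for tool in tools:
        if shutil.which(tool) is None:
            problems.append("executable not found: " + tool)
    return problems
```

For example, `print(check_prerequisites(os.path.dirname(gv_file_path), h2o_jar_path))` prints an empty list when every check passes.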
Now let's define the base name for the image files. The tree ID is appended to it, so the image for a given tree is saved as my_gbm_tree_<ID>.png:
```python
image_file_name = "/Users/avkashchauhan/Downloads/my_gbm_tree"
```
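Note that gv_file_path is a single file, so each PrintMojo run overwrites the previous tree's GraphViz output. If you want to keep one .gv file per tree, you could use a small helper such as tree_file_paths below. This helper is hypothetical and not part of the original code. It derives both file names from the tree ID, following the same _<ID> naming used for the PNG files.

```python
def tree_file_paths(base_name, tree_id):
    """Hypothetical helper: return (gv_path, png_path) for one tree ID."""
    # One .gv file per tree instead of a single shared gv_file_path
    stem = "{}_{}".format(base_name, tree_id)
    return stem + ".gv", stem + ".png"
```

For instance, `gv_path, png_path = tree_file_paths(image_file_name, 3)` gives a `..._3.gv` path to use as gv_file_path for tree 3, and a `..._3.png` path to pass to `Image()`.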
Now we will download the GBM MOJO model by saving it to disk:

```python
gbm_cv3.download_mojo(mojo_file_name)
```
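Calling `download_mojo` on gbm_cv3 saves the main model only. Each cross-validation model is a separate model with its own trees, so visualizing one of them would need that model's own MOJO. The sketch below uses a hypothetical helper, download_cv_mojos. It assumes two things: `cross_validation_models()` returns H2O model objects (or None when no folds were used), and each of those objects supports the same `download_mojo(path)` call, returning the saved file's path.

```python
def download_cv_mojos(model, output_dir):
    """Hypothetical helper: save one MOJO per cross-validation model.

    Returns the list of paths reported by each download_mojo call.
    """
    # None (no cross-validation) is treated as an empty list
    cv_models = model.cross_validation_models() or []
    saved = []
    for cv_model in cv_models:
        # Each CV model has its own trees, hence its own MOJO file
        saved.append(cv_model.download_mojo(output_dir))
    return saved
```

Any returned path can then be passed as mojo_full_path to the tree-generation step below, for example `download_cv_mojos(gbm_cv3, "/Users/avkashchauhan/Downloads")`.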
Now let's define a function that generates a GraphViz tree file from the saved MOJO model:
```python
import os
import subprocess

def generateTree(h2o_jar_path, mojo_full_path, gv_file_path, image_file_path, tree_id=0):
    # image_file_path is unused here (kept so existing calls still work)
    # Run H2O's PrintMojo tool to export tree number `tree_id` as a GraphViz (.gv) file
    result = subprocess.call(["java", "-cp", h2o_jar_path, "hex.genmodel.tools.PrintMojo",
                              "--tree", str(tree_id),
                              "-i", mojo_full_path,
                              "-o", gv_file_path], shell=False)
    # Compare the exit code with ==, not `is`, and confirm the file was written
    if result == 0 and os.path.exists(gv_file_path):
        print("Success: Graphviz file " + gv_file_path + " is generated.")
    else:
        print("Error: Graphviz file " + gv_file_path + " could not be generated.")
```
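When a run fails, it is often easier to run PrintMojo by hand in a terminal. The sketch below uses printmojo_command, a hypothetical helper. It builds the same argument list that generateTree passes to subprocess and prints a shell-quoted version you can copy. The jar and MOJO paths in the example call are placeholders.

```python
import shlex


def printmojo_command(h2o_jar_path, mojo_path, gv_path, tree_id=0):
    """Hypothetical helper: the argument list generateTree runs via subprocess."""
    return ["java", "-cp", h2o_jar_path, "hex.genmodel.tools.PrintMojo",
            "--tree", str(tree_id), "-i", mojo_path, "-o", gv_path]


# Placeholder paths; print a copy-pasteable shell command for manual debugging
cmd = printmojo_command("/opt/h2o/h2o.jar", "/tmp/my_gbm_mojo.zip", "/tmp/my_gbm_tree_3.gv", 3)
print(" ".join(shlex.quote(part) for part in cmd))
```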
Now let's define a function that renders the saved GraphViz tree as a PNG image:
```python
import os
import subprocess

def generateTreeImage(gv_file_path, image_file_path, tree_id):
    # Image name follows the pattern <image_file_path>_<tree_id>.png
    image_file_path = image_file_path + "_" + str(tree_id) + ".png"
    # Render the GraphViz file to PNG with the GraphViz `dot` command
    result = subprocess.call(["dot", "-Tpng", gv_file_path, "-o", image_file_path], shell=False)
    # Compare the exit code with ==, not `is`, and confirm the image was written
    if result == 0 and os.path.exists(image_file_path):
        print("Success: Image File " + image_file_path + " is generated.")
        print("Now you can execute the following line as-is to see the tree graph:")
        print("Image(filename='" + image_file_path + "')")
    else:
        print("Error: Image file " + image_file_path + " could not be generated.")
```
Note: I split this into the two-step process above because, when I ran everything in one step, the process hung after the GraphViz file was created.
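If either step still hangs, one option is to put a time limit on the external command instead of killing it by hand. The sketch below uses run_with_timeout, a hypothetical helper. It relies on `subprocess.run(..., timeout=...)`, which kills the child process when the limit is exceeded. This guards against a hang but does not address its cause, which the text above does not identify.

```python
import subprocess


def run_with_timeout(cmd, timeout_seconds=60):
    """Hypothetical helper: True on exit code 0, False on failure or timeout."""
    try:
        # subprocess.run kills the child and raises TimeoutExpired past the limit
        completed = subprocess.run(cmd, timeout=timeout_seconds)
    except subprocess.TimeoutExpired:
        print("Timed out after " + str(timeout_seconds) + "s: " + " ".join(cmd))
        return False
    except FileNotFoundError:
        # Raised when the executable (e.g. java or dot) cannot be found
        print("Executable not found: " + cmd[0])
        return False
    return completed.returncode == 0
```

For example, `run_with_timeout(["dot", "-Tpng", gv_file_path, "-o", image_file_name + "_3.png"], timeout_seconds=120)` renders tree 3 but gives up after two minutes.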
Now let's generate a tree by passing all the parameters defined above, with the desired tree ID as the last argument:
```python
# Just change the tree ID in the call below to pick the tree you want
generateTree(h2o_jar_path, mojo_full_path, gv_file_path, image_file_name, 3)
```
Now we will generate the PNG tree image from the saved GraphViz file:
```python
generateTreeImage(gv_file_path, image_file_name, 3)
# Note: if this step hangs, look for an active "dot" process on macOS and try killing it
```
Let's visualize the first tree (tree ID 0) of the model. This assumes you have already run the two steps above with tree ID 0:
```python
# Just pass the tree image file name for the tree you generated
Image(filename='/Users/avkashchauhan/Downloads/my_gbm_tree_0.png')
```
Let's visualize tree 1. All of these images are rendered from the MOJO of the main model, so tree IDs 1, 2 and 3 select different trees of that model, not the cross-validation models:

```python
# Just pass the tree image file name for the tree you generated
Image(filename='/Users/avkashchauhan/Downloads/my_gbm_tree_1.png')
```

Let's visualize tree 2:

```python
# Just pass the tree image file name for the tree you generated
Image(filename='/Users/avkashchauhan/Downloads/my_gbm_tree_2.png')
```

Let's visualize tree 3 (the tree generated in the steps above):

```python
# Just pass the tree image file name for the tree you generated
Image(filename='/Users/avkashchauhan/Downloads/my_gbm_tree_3.png')
```
By looking at these trees, you can see how the model makes its decision at each split.
That's it, enjoy!