MXNet made simple: Clojure Symbol Visualization API


In this post, we will look at the MXNet Symbol Visualization API and learn how to visualize both pretrained and user-defined models.

Before we begin, we need to require the following namespaces:

(require '[org.apache.clojure-mxnet.module :as m])
(require '[org.apache.clojure-mxnet.visualization :as viz])

Pretrained Models

The MXNet Model Zoo is a central place for downloading state-of-the-art pretrained models. It provides both the computation graph of each model and its trained parameters, so one can start making predictions right away.

We are going to download VGG16 and ResNet18, two common state-of-the-art models for computer vision tasks such as classification and segmentation.

Below is the bash script for downloading VGG16:

#!/bin/bash

set -evx

mkdir -p model
cd model
wget http://data.mxnet.io/models/imagenet/vgg/vgg16-symbol.json
wget http://data.mxnet.io/models/imagenet/vgg/vgg16-0000.params
cd ..

Save the script as download_vgg16.sh, then execute it:

$ chmod a+x download_vgg16.sh
$ sh download_vgg16.sh

And below is the bash script for downloading ResNet18:

#!/bin/bash

set -evx

mkdir -p model
cd model
wget http://data.mxnet.io/models/imagenet/resnet/18-layers/resnet-18-symbol.json
wget http://data.mxnet.io/models/imagenet/resnet/18-layers/resnet-18-0000.params
cd ..

Save the script as download_resnet18.sh, then execute it:

$ chmod a+x download_resnet18.sh
$ sh download_resnet18.sh

Make sure that the models were downloaded properly:

$ cd model
$ ls
resnet-18-0000.params  resnet-18-symbol.json
vgg16-0000.params      vgg16-symbol.json
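The same check can be done from the REPL. The sketch below is a hypothetical helper (not part of clojure-mxnet) that only uses `clojure.java.io`; it lists which of the four expected files are missing from a directory, so an empty result means the downloads are in place. The file names are the ones produced by the two scripts above.

```clojure
(require '[clojure.java.io :as io])

(def expected-model-files
  "Files the two download scripts are expected to create (hypothetical helper data)."
  ["vgg16-symbol.json" "vgg16-0000.params"
   "resnet-18-symbol.json" "resnet-18-0000.params"])

(defn missing-model-files
  "Returns the names in `files` that do not exist under `dir`.
  An empty seq means every expected file is present."
  [dir files]
  (remove #(.exists (io/file dir %)) files))

;; Usage (after running the download scripts):
;; (missing-model-files "model" expected-model-files)
```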

One can load each pretrained model (its computation graph and its parameters for a given epoch) using the Module API:

(def model-dir "model")

(def vgg16-mod
  "VGG16 Module"
  (m/load-checkpoint {:prefix (str model-dir "/vgg16") :epoch 0}))

(def resnet18-mod
  "Resnet18 Module"
  (m/load-checkpoint {:prefix (str model-dir "/resnet-18") :epoch 0}))
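`m/load-checkpoint` relies on a file-naming convention: for a given `:prefix` and `:epoch`, it expects `<prefix>-symbol.json` and `<prefix>-<epoch>.params`, with the epoch zero-padded to four digits. This matches the files downloaded above. The `checkpoint-files` function below is a hypothetical helper (not part of the clojure-mxnet API) that simply spells out that convention:

```clojure
(defn checkpoint-files
  "Hypothetical helper: returns the file names an MXNet checkpoint with
  `prefix` and `epoch` is expected to consist of."
  [{:keys [prefix epoch]}]
  {:symbol (str prefix "-symbol.json")                 ; computation graph
   :params (format "%s-%04d.params" prefix epoch)})    ; weights, epoch padded to 4 digits

;; (checkpoint-files {:prefix "model/vgg16" :epoch 0})
;; => {:symbol "model/vgg16-symbol.json", :params "model/vgg16-0000.params"}
```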

The visualization API uses Graphviz under the hood to render computation graphs. We can write a small function that takes the model name, the symbol to render, the input data shape, and the directory where the rendered graph should be saved. By default, the output is a PDF file.

(defn render-model!
  "Renders `model-sym` and saves it as a pdf file in `path/model-name.pdf`"
  [{:keys [model-name model-sym input-data-shape path]}]
  (let [dot (viz/plot-network
              model-sym
              {"data" input-data-shape}
              {:title model-name
               :node-attrs {:shape "oval" :fixedsize "false"}})]
    (viz/render dot model-name path)))
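To give a feel for what Graphviz consumes, here is a toy sketch of a graph description in the DOT language for a purely sequential chain of layers. `chain->dot` is a hypothetical helper written for illustration only: it is not part of clojure-mxnet and is much simpler than what `viz/plot-network` actually produces (which also covers branching graphs and shape information).

```clojure
(require '[clojure.string :as str])

(defn chain->dot
  "Hypothetical helper: renders a linear chain of layer names as a
  Graphviz DOT digraph, with oval nodes like in `render-model!`."
  [title layers]
  (str "digraph " (pr-str title) " {\n"
       "  node [shape=oval];\n"
       ;; one edge per consecutive pair of layers
       (str/join (for [[from to] (partition 2 1 layers)]
                   (str "  " (pr-str from) " -> " (pr-str to) ";\n")))
       "}\n"))

;; (print (chain->dot "lenet" ["data" "conv1" "tanh1"]))
;; digraph "lenet" {
;;   node [shape=oval];
;;   "data" -> "conv1";
;;   "conv1" -> "tanh1";
;; }
```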

Now we can visualize the pretrained models by calling this function:

(def model-render-dir "model_render")

;; Rendering pretrained VGG16
(render-model! {:model-name "vgg16"
                :model-sym (m/symbol vgg16-mod)
                :input-data-shape [1 3 224 224]
                :path model-render-dir})

;; Rendering pretrained Resnet18
(render-model! {:model-name "resnet18"
                :model-sym (m/symbol resnet18-mod)
                :input-data-shape [1 3 224 224]
                :path model-render-dir})
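The `:input-data-shape` vector follows MXNet's NCHW layout: batch size, channels, height and width. As a quick hand check of how the spatial dimensions shrink through VGG16, the sketch below assumes the standard VGG16 layout (3x3 convolutions with padding 1, which preserve height and width, and five 2x2 max-pooling layers with stride 2, which halve them, rounding down). `vgg16-feature-map-size` is a hypothetical helper, not part of any MXNet API.

```clojure
(defn vgg16-feature-map-size
  "Hypothetical helper: side length of VGG16's last feature map for a
  square input of side `h`, assuming five stride-2 poolings and
  size-preserving convolutions."
  [h]
  ;; iterate yields h, h/2, h/4, ...; index 5 is after the fifth pooling layer
  (nth (iterate #(quot % 2) h) 5))

;; (vgg16-feature-map-size 224) ;; => 7
```

For a 224x224 input, the resolution commonly used for ImageNet models, this gives a 7x7 feature map before the fully connected layers.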

User Defined Model

We can also visualize our own models with the same approach. We will define the LeNet model and visualize it with the Symbol Visualization API.

(require '[org.apache.clojure-mxnet.symbol :as sym])

(defn get-symbol
  "Return LeNet Symbol

  Input data shape [`batch-size` `channels` 28 28]
  Output data shape [`batch-size` 10]"
  []
  (as-> (sym/variable "data") data

    ;; First `convolution` layer
    (sym/convolution "conv1" {:data data :kernel [5 5] :num-filter 20})
    (sym/activation "tanh1" {:data data :act-type "tanh"})
    (sym/pooling "pool1" {:data data :pool-type "max" :kernel [2 2] :stride [2 2]})

    ;; Second `convolution` layer
    (sym/convolution "conv2" {:data data :kernel [5 5] :num-filter 50})
    (sym/activation "tanh2" {:data data :act-type "tanh"})
    (sym/pooling "pool2" {:data data :pool-type "max" :kernel [2 2] :stride [2 2]})

    ;; Flattening before the Fully Connected Layers
    (sym/flatten "flatten" {:data data})

    ;; First `fully-connected` layer
    (sym/fully-connected "fc1" {:data data :num-hidden 500})
    (sym/activation "tanh3" {:data data :act-type "tanh"})

    ;; Second `fully-connected` layer
    (sym/fully-connected "fc2" {:data data :num-hidden 10})

    ;; Softmax Loss
    (sym/softmax-output "softmax" {:data data})))
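Before rendering, it can help to know what shapes to expect in the graph. The sketch below computes them by hand for a `[batch channels 28 28]` input, assuming the convolutions use the default stride 1 and no padding, and that pooling rounds down. Activation layers keep the shape of their input, so they are omitted. `conv-out` and `lenet-shapes` are hypothetical helpers, not part of clojure-mxnet.

```clojure
(defn conv-out
  "Output side length of a square convolution or pooling window with
  no padding: floor((in - kernel) / stride) + 1."
  [in kernel stride]
  (inc (quot (- in kernel) stride)))

(defn lenet-shapes
  "Hypothetical helper: expected output shape of each LeNet layer
  defined in `get-symbol`, for a [batch channels 28 28] input."
  [batch]
  (let [c1 (conv-out 28 5 1)    ; conv1: 5x5 kernel    -> 24
        p1 (conv-out c1 2 2)    ; pool1: 2x2, stride 2 -> 12
        c2 (conv-out p1 5 1)    ; conv2: 5x5 kernel    -> 8
        p2 (conv-out c2 2 2)]   ; pool2: 2x2, stride 2 -> 4
    [["conv1"   [batch 20 c1 c1]]
     ["pool1"   [batch 20 p1 p1]]
     ["conv2"   [batch 50 c2 c2]]
     ["pool2"   [batch 50 p2 p2]]
     ["flatten" [batch (* 50 p2 p2)]]   ; 50 * 4 * 4 = 800
     ["fc1"     [batch 500]]
     ["fc2"     [batch 10]]]))
```

The last entry, `[batch 10]`, matches the output shape stated in the docstring of `get-symbol`.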

Now we can render it the same way as the pretrained models:

;; Rendering user defined LeNet
(render-model! {:model-name "lenet"
                :model-sym (get-symbol)
                :input-data-shape [1 3 28 28]
                :path model-render-dir})

Rendered Models: VGG16, ResNet18 and LeNet

Here is a summary of the models we rendered in this tutorial:

[Figure: rendered topologies of VGG16, ResNet18 and LeNet]

Conclusion

The Symbol Visualization API makes it simple to visualize any model, whether pretrained or user-defined. It is good practice to check that a model's topology makes sense before training it or using it for predictions.
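Another cheap sanity check to pair with the rendered graph is counting a model's trainable parameters by hand. The sketch below does this for the LeNet defined above, assuming every convolution and fully connected layer has a bias and that activation, pooling, flatten and softmax layers have no parameters. `lenet-param-count` is a hypothetical helper, not part of clojure-mxnet.

```clojure
(defn lenet-param-count
  "Hypothetical helper: number of weights and biases in the LeNet from
  `get-symbol`, for 28x28 input images with `channels` channels."
  [channels]
  (let [conv1 (+ (* 5 5 channels 20) 20)   ; 20 filters of 5x5xC, plus 20 biases
        conv2 (+ (* 5 5 20 50) 50)         ; 50 filters of 5x5x20, plus 50 biases
        fc1   (+ (* 800 500) 500)          ; 800 flattened inputs -> 500 units
        fc2   (+ (* 500 10) 10)]           ; 500 -> 10 units
    (+ conv1 conv2 fc1 fc2)))

;; (lenet-param-count 3) ;; => 432080
;; (lenet-param-count 1) ;; => 431080
```

The count shows that `fc1` holds the vast majority of the weights, which is typical of LeNet-style architectures.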

References and Resources

Here is the complete code used in this post (also available in this repository):

(ns mxnet-clj-tutorials.lenet
  (:require [org.apache.clojure-mxnet.symbol :as sym]))

(defn get-symbol
  "Return LeNet Symbol

  Input data shape [`batch-size` `channels` 28 28]
  Output data shape [`batch-size` 10]"
  []
  (as-> (sym/variable "data") data

    ;; First `convolution` layer
    (sym/convolution "conv1" {:data data :kernel [5 5] :num-filter 20})
    (sym/activation "tanh1" {:data data :act-type "tanh"})
    (sym/pooling "pool1" {:data data :pool-type "max" :kernel [2 2] :stride [2 2]})

    ;; Second `convolution` layer
    (sym/convolution "conv2" {:data data :kernel [5 5] :num-filter 50})
    (sym/activation "tanh2" {:data data :act-type "tanh"})
    (sym/pooling "pool2" {:data data :pool-type "max" :kernel [2 2] :stride [2 2]})

    ;; Flattening before the Fully Connected Layers
    (sym/flatten "flatten" {:data data})

    ;; First `fully-connected` layer
    (sym/fully-connected "fc1" {:data data :num-hidden 500})
    (sym/activation "tanh3" {:data data :act-type "tanh"})

    ;; Second `fully-connected` layer
    (sym/fully-connected "fc2" {:data data :num-hidden 10})

    ;; Softmax Loss
    (sym/softmax-output "softmax" {:data data})))

(ns mxnet-clj-tutorials.visualization
  "Functions and utils to render pretrained and user defined models."
  (:require
    [org.apache.clojure-mxnet.module :as m]
    [org.apache.clojure-mxnet.visualization :as viz]

    [mxnet-clj-tutorials.lenet :as lenet]))

;; Run the `download_vgg16.sh` and `download_resnet18.sh`
;; prior to running the following code

(def model-dir "model")
(def model-render-dir "model_render")

;; Loading pretrained models

(def vgg16-mod
  "VGG16 Module"
  (m/load-checkpoint {:prefix (str model-dir "/vgg16") :epoch 0}))

(def resnet18-mod
  "Resnet18 Module"
  (m/load-checkpoint {:prefix (str model-dir "/resnet-18") :epoch 0}))

(defn render-model!
  "Renders `model-sym` and saves it as a pdf file in `path/model-name.pdf`"
  [{:keys [model-name model-sym input-data-shape path]}]
  (let [dot (viz/plot-network
              model-sym
              {"data" input-data-shape}
              {:title model-name
               :node-attrs {:shape "oval" :fixedsize "false"}})]
    (viz/render dot model-name path)))

(comment
  ;; Run the following function calls to render the models in `model-render-dir`

  ;; Rendering pretrained VGG16
  (render-model! {:model-name "vgg16"
                  :model-sym (m/symbol vgg16-mod)
                  :input-data-shape [1 3 224 224]
                  :path model-render-dir})

  ;; Rendering pretrained Resnet18
  (render-model! {:model-name "resnet18"
                  :model-sym (m/symbol resnet18-mod)
                  :input-data-shape [1 3 224 224]
                  :path model-render-dir})

  ;; Rendering user defined LeNet
  (render-model! {:model-name "lenet"
                  :model-sym (lenet/get-symbol)
                  :input-data-shape [1 3 28 28]
                  :path model-render-dir}))

Arthur Caillau

A man who eats parentheses for breakfast
