Visualising the PyTorch Compute Graph for Bug Fixing

Static vs. Dynamic graphs

In both TensorFlow and PyTorch, a lot is made of the compute graph and autograd. In a nutshell, all of your operations are recorded in a big graph; your tensors then flow through this graph and pop out at the other end. Fairly simple really. The point of doing this is that we end up with a set of operations for which derivatives can be computed, and those derivatives are what gradient-based training runs on. The main difference is that TensorFlow classically builds a static graph ahead of time, while PyTorch builds a dynamic one on the fly during the forward pass. This has been covered in many places elsewhere, so I won't repeat it here.
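To make that concrete, here is a tiny sketch (my own, not from the original tool) showing autograd recording an operation on a tensor and then producing a derivative from it:

import torch

x = torch.tensor([2.0, 3.0], requires_grad=True)
y = (x * x).sum()
print(y.grad_fn)   # <SumBackward0 object at 0x...> - the recorded operation
y.backward()
print(x.grad)      # tensor([4., 6.]), i.e. the derivative 2x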

Tool time

The first step is complaining and getting angry! The next step is to ask why things are this way. The final step is to build a tool and fix it! Afterwards, maybe you share it with the world, because other folks may be struggling with the same issues.

  • During the forward pass, the graph is constructed. Each operation records a backward function, which is held on the resulting tensor in its grad_fn attribute.
  • Each of these function objects holds a list of the functions it passes its result on to, in the next_functions attribute.
  • Some of these function objects have a variable attribute. These will most likely be the leaf nodes that PyTorch bangs on about all the time, and the places where we are likely to see the results.
  • To make this point more strongly, these variables have an is_leaf attribute that confirms the leafy-ness of the node.
  • Sometimes the same node might reappear. It's easy to follow that node again and end up in some sort of loop, printing extra nodes that aren't needed. We'll need to keep a list of things we've seen.
  • PyTorch doesn't have a labelling or naming system for tensors, so what we get when we print variables is a memory location. We'll need a list of the objects we are interested in. (A minimal sketch of walking the graph with these attributes follows this list.)
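Putting those observations together, walking the graph by hand might look something like this. It is a minimal sketch of my own: the helper name walk and the toy tensors are not part of the final tool.

import torch

def walk(fn, seen, depth=0):
    ''' Follow grad_fn.next_functions backwards, skipping nodes we have seen. '''
    if fn is None or fn in seen:
        return
    seen.append(fn)
    label = type(fn).__name__
    # Leaf tensors hang off AccumulateGrad nodes via the 'variable' attribute.
    if hasattr(fn, "variable") and fn.variable.is_leaf:
        label += " (leaf, shape " + str(list(fn.variable.shape)) + ")"
    print("." * depth + label)
    for next_fn, _ in fn.next_functions:
        walk(next_fn, seen, depth + 1)

x = torch.tensor([1.0, 2.0], requires_grad=True)
loss = (x * 3.0).sum()
walk(loss.grad_fn, seen=[])
# SumBackward0
# .MulBackward0
# ..AccumulateGrad (leaf, shape [2])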

What was the bug?

Once I'd built this tool, I set to work bug hunting. I noticed that there was a leaf node being output in the right place, right next to another tensor with a name I recognised. But why was this tensor not being picked up properly? It was as if the result I wanted was being thrown into the void rather than ending up in the right place.

The prediction was shaped like the first line below, while the target was the flat second line; the two didn't match, so the loss wasn't comparing the values I thought it was. The fix was a single reshape:

[[x], [y], [z], [w]]
[x, y, z, w]

target = target.reshape(-1, 4, 1)
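To see why that reshape matters, here is a hedged illustration of the kind of silent mismatch involved; the shapes below are my own invention, only the reshape mirrors the fix above:

import torch
import torch.nn as nn

pred = torch.randn(2, 4, 1, requires_grad=True)   # shaped like [[x], [y], [z], [w]]
target = torch.randn(8)                           # flat, like [x, y, z, w]

loss_fn = nn.MSELoss()
bad = loss_fn(pred, target)                       # shapes broadcast; PyTorch only warns
good = loss_fn(pred, target.reshape(-1, 4, 1))    # shapes line up as intended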

The moral?

TensorFlow has a good set of visualisation tools in its TensorBoard package. While they are quite heavyweight, it's clear there is a need for them. We have debuggers and visual tools for a reason, and I think PyTorch is missing one or two of these; in fact, I have another post in mind about visualising results from PyTorch.

# Our drawing graph functions. We rely on / have borrowed from the following
# python libraries:
# https://github.com/szagoruyko/pytorchviz/blob/master/torchviz/dot.py
# https://github.com/willmcgugan/rich
# https://graphviz.readthedocs.io/en/stable/


def draw_graph(start, watch=[]):
    from graphviz import Digraph

    node_attr = dict(style='filled',
                     shape='box',
                     align='left',
                     fontsize='12',
                     ranksep='0.1',
                     height='0.2')
    graph = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))

    assert hasattr(start, "grad_fn")
    if start.grad_fn is not None:
        _draw_graph(start.grad_fn, graph, watch=watch)

    # Scale the rendered image roughly with the number of nodes and edges.
    size_per_element = 0.15
    min_size = 12
    num_rows = len(graph.body)
    content_size = num_rows * size_per_element
    size = max(min_size, content_size)
    size_str = str(size) + "," + str(size)
    graph.graph_attr.update(size=size_str)
    graph.render(filename='net_graph', format='jpg')


def _draw_graph(var, graph, watch=[], seen=None, indent="", pobj=None):
    ''' Recursive function going through the hierarchical graph, printing off
    what we need to see what autograd is doing. '''
    from rich import print  # rich lets us use the [b]/[u]/colour markup below

    if seen is None:
        seen = []

    if hasattr(var, "next_functions"):
        for fun in var.next_functions:
            joy = fun[0]
            if joy is not None:
                if joy not in seen:
                    label = str(type(joy)).replace(
                        "class", "").replace("'", "").replace(" ", "")
                    label_graph = label
                    colour_graph = ""
                    seen.append(joy)

                    if hasattr(joy, 'variable'):
                        happy = joy.variable
                        if happy.is_leaf:
                            label += " \U0001F343"
                            colour_graph = "green"

                            # Highlight any leaf that is one of the tensors
                            # we asked to watch.
                            for (name, obj) in watch:
                                if obj is happy:
                                    label += " \U000023E9 " + \
                                        "[b][u][#FF00FF]" + name + \
                                        "[/#FF00FF][/u][/b]"
                                    label_graph += name
                                    colour_graph = "blue"
                                    break

                            # Append the leaf's shape, plus its variance as a
                            # quick summary of the values it holds.
                            vv = [str(happy.shape[x])
                                  for x in range(len(happy.shape))]
                            label += " [[" + ', '.join(vv) + "]]"
                            label += " " + str(happy.var())

                    graph.node(str(joy), label_graph, fillcolor=colour_graph)
                    print(indent + label)
                    _draw_graph(joy, graph, watch, seen, indent + ".", joy)
                    if pobj is not None:
                        graph.edge(str(pobj), str(joy))
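For completeness, here is a usage sketch of my own (the model and tensor names are made up): watch is a list of (name, tensor) pairs so the graph can highlight and label the leaves we care about.

import torch

model = torch.nn.Linear(4, 1)
x = torch.randn(8, 4)
target = torch.randn(8, 1)
loss = torch.nn.functional.mse_loss(model(x), target)

# Prints the graph to the terminal and writes net_graph.jpg to disk.
draw_graph(loss, watch=[("weight", model.weight), ("bias", model.bias)])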
