4. Introduction to graphs and tf.function


This guide goes beneath the surface of TensorFlow and Keras to demonstrate how TensorFlow works.

In this guide, you'll learn:

  • How TensorFlow allows you to make simple changes to your code to get graphs
  • How graphs are stored and represented
  • How you can use them to accelerate your models

This is a big-picture overview that covers how tf.function allows you to switch from eager execution to graph execution.

import tensorflow as tf

Graphs

Eager execution means TensorFlow operations are executed by Python, operation by operation, with results returned back to Python.

Graph execution means that tensor computations are executed as a TensorFlow graph (tf.Graph, or simply a "graph").

Graphs are data structures that contain (1) a set of tf.Operation objects, which represent units of computation, and (2) tf.Tensor objects, which represent the units of data that flow between operations.

TensorFlow uses graphs as the format for saved models when it exports them from Python.

Graphs are extremely useful and let your TensorFlow models run fast, run in parallel, and run efficiently on multiple devices.
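As a quick illustration, the same computation can run either way: call the Python function directly for eager execution, or wrap it with tf.function for graph execution. Below is a minimal sketch (the function and values are invented for illustration):

import tensorflow as tf

def square_sum(x, y):
    return tf.reduce_sum(tf.square(x) + tf.square(y))

x = tf.constant([1.0, 2.0])
y = tf.constant([3.0, 4.0])

# Eager execution: each operation runs immediately in Python
print(square_sum(x, y))               # tf.Tensor(30.0, shape=(), dtype=float32)

# Graph execution: the same function, traced into a tf.Graph and run as one unit
graph_square_sum = tf.function(square_sum)
print(graph_square_sum(x, y))         # same result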

How to build and run graphs

1. tf.function

<Function> = tf.function(<function>)
  1. Convert the eager function (Python code) into a graph function (graph-generating code) with AutoGraph:
    Any function may contain a mixture of built-in TF operations and Python logic. While TensorFlow operations are easily captured by a tf.Graph, Python-specific logic needs an extra step to become part of the graph, because it isn't made of Tensor operations.
    • Conditionals
      A Python conditional executes during tracing, so exactly one branch of the conditional is added to the graph.
      tf.cond traces and adds both branches of the conditional to the graph, dynamically selecting a branch at execution time (see the sketch after step 3 below).
    • Loops
      A Python loop executes during tracing, adding additional ops to the tf.Graph for every iteration of the loop.
      A TensorFlow loop traces the body of the loop, and dynamically selects how many iterations to run at execution time.
    • <graph_function> = tf.autograph.to_graph(<function>)
  2. Return a Function, which builds TensorFlow graphs from the graph function.
  3. <Function> = tf.function(autograph=False)(<graph_function>)
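To make the conditional rules concrete, here is a minimal sketch (the function names are invented for illustration): a Python bool is resolved while tracing, whereas a Tensor condition is converted by AutoGraph into tf.cond and resolved at execution time.

import tensorflow as tf

@tf.function
def python_branch(x, use_relu):
    # use_relu is a Python bool: this `if` runs during tracing, so only one
    # branch is baked into the graph (and each bool value gets its own trace)
    if use_relu:
        return tf.nn.relu(x)
    return x

@tf.function
def tensor_branch(x):
    # the condition is a Tensor: AutoGraph converts this `if` into tf.cond,
    # so both branches are added to the graph and one is chosen at run time
    if tf.reduce_sum(x) > 0:
        return tf.nn.relu(x)
    return x

x = tf.constant([-1.0, 2.0])
print(python_branch(x, True))   # tf.Tensor([0. 2.], shape=(2,), dtype=float32)
print(tensor_branch(x))         # tf.Tensor([0. 2.], shape=(2,), dtype=float32)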

2. Function (creates and runs the tf.Graph)

<Function>(<input>)
  1. Tracing: Function creates a new tf.Graph. From top to bottom, Python code runs normally, but all TensorFlow operations are deferred: they are captured by the tf.Graph and not run.
     # create ConcreteFunction
     <concrete_function> = <Function>.get_concrete_function(<input>)
    
     # Observe graph
     <graph> = <concrete_function>.graph
     for node in <graph>.as_graph_def().node:
         print(f'{node.input} -> {node.name}')
    Rules of tracing
    A Function determines whether to reuse a traced ConcreteFunction by computing a cache key from an input's args and kwargs. A cache key (the signature) identifies a ConcreteFunction based on the input args and kwargs of the Function call, according to the following rules (demonstrated in the sketch after step 4):
    • The key generated for a tf.Tensor is its shape and dtype
    • The key generated for a tf.Variable is a unique variable id
    • The key generated for a Python primitive is its value
    • The key generated for nested dicts, lists, tuples, namedtuples, and attrs is the flattened tuple of leaf keys
    • For all other Python types, the key is unique to the object
    • <Function>.pretty_printed_concrete_signatures()
  2. If a Python primitive (e.g. an integer or boolean) is used in an operation together with a Tensor, the tf.Graph captures that operation, because it is a Tensor operation. If a Python primitive is a return value, the tf.Graph captures it as well.
  3. The tf.Graph, which contains everything that was deferred, is run.
       <concrete_function>(<input>)
  4. Non-strict execution
    Graph execution only executes the operations necessary to produce the observable effects, which include:
    • The return value of the function
    • Documented well-known side effects, such as:
      • Input/output operations, like tf.print
      • Debugging operations, such as the assert functions in tf.debugging
      • Mutations of tf.Variable
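The tracing rules above can be observed directly. A minimal sketch with a toy doubling function:

import tensorflow as tf

@tf.function
def double(a):
    print('Tracing with', a)    # Python print: runs only while tracing
    return a + a

double(tf.constant(1))      # traces
double(tf.constant(2))      # same shape and dtype: reuses the trace, no print
double(tf.constant(1.5))    # new dtype is a new cache key: retraces
double(3)                   # Python primitives are keyed by value: traces
double(4)                   # new value: retraces again

# inspect every signature this Function has been traced with
print(double.pretty_printed_concrete_signatures())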

Terminology

  • tf.function wraps a Python function, returning a Function object.
  • Function encapsulates several tf.Graphs behind one API.
  • Function manages a cache of ConcreteFunctions and picks the right one for your inputs.
  • tf.Graph is the raw, language-agnostic, portable representation of a TensorFlow computation.
  • ConcreteFunction wraps a tf.Graph.
# WAY TO CREATE AND RUN GRAPH

# define Python function
def func(x, y, b):
    x = tf.matmul(x, y)
    x = x + b
    return x

# 1. direct call
direct_func = tf.function(func)

x1 = tf.constant([[1.0, 2.0]])
y1 = tf.constant([[1.0], [2.0]])
b1 = tf.constant(4.0)

print('1. Direct call: ', direct_func(x1, y1, b1))

# 2. decorator
@tf.function
def decorate_func(x, y, b):
    x = tf.matmul(x, y)
    x = x + b
    return x

print('2. Decorator: ', decorate_func(x1, y1, b1))
1. Direct call:  tf.Tensor([[9.]], shape=(1, 1), dtype=float32)
2. Decorator:  tf.Tensor([[9.]], shape=(1, 1), dtype=float32)

Performance and trade-offs

Graphs can speed up your code, but the process of creating them has some overhead. For some functions, the creation of the graph takes more time than the execution of the graph. This investment is usually quickly paid back with the performance boost of subsequent executions, but it's important to be aware that the first few steps of any large model training can be slower due to tracing. If you find you are getting unusually poor performance, it's a good idea to check whether you are retracing accidentally.
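A rough way to see both the tracing cost and the payoff is to time the first call against later calls. A minimal sketch (the workload is invented and the numbers depend on your hardware):

import timeit
import tensorflow as tf

def matmul_chain(x):
    for _ in range(10):
        x = tf.matmul(x, x) / 1000.0    # scaled down to keep values bounded
    return x

x = tf.random.uniform((1000, 1000))

eager_time = timeit.timeit(lambda: matmul_chain(x), number=10)

graph_fn = tf.function(matmul_chain)
first_call = timeit.timeit(lambda: graph_fn(x), number=1)   # pays the tracing cost
graph_time = timeit.timeit(lambda: graph_fn(x), number=10)  # reuses the trace

print('Eager:', eager_time, 'First graph call:', first_call, 'Graph:', graph_time)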

Controlling retracing

To control the tracing behavior, you can use the following techniques:

  • Specify input_signature in tf.function to limit tracing
  • Specify a [None] dimension in tf.TensorSpec to allow for flexibility in trace reuse
  • Cast Python arguments to Tensors to reduce retracing
@tf.function(input_signature=(tf.TensorSpec(shape=[2, 2], dtype=tf.int32), ))
def example1(x):
    print('Tracing with: ', x)
    return x + 3

@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32), ))
def example2(x):
    print('Tracing with: ', x)
    return x + 100

@tf.function
def example3(x):
    print('Tracing with: ', x)
    for _ in tf.range(x):
        pass

print('Example 1: ')
print('input is [[1, 2], [3, 4]]: ', example1(tf.constant([[1, 2], [3, 4]])))
print('input is [[10, 20], [30, 40]]: ', example1(tf.constant([[10, 20], [30, 40]])))
print('-'*100)
print('Example 2: ')
print('input is [1, 2, 3]: ', example2(tf.constant([1, 2, 3])))
print('input is [10, 20, 30, 40]: ', example2(tf.constant([10, 20, 30, 40])))
print('-'*100)
print('Example 3: ')
print('input is tf.constant(10): ', example3(tf.constant(10)))
print('input is tf.constant(15): ', example3(tf.constant(15)))
Example 1: 
Tracing with:  Tensor("x:0", shape=(2, 2), dtype=int32)
input is [[1, 2], [3, 4]]:  tf.Tensor(
[[4 5]
 [6 7]], shape=(2, 2), dtype=int32)
input is [[10, 20], [30, 40]]:  tf.Tensor(
[[13 23]
 [33 43]], shape=(2, 2), dtype=int32)
----------------------------------------------------------------------------------------------------
Example 2: 
Tracing with:  Tensor("x:0", shape=(None,), dtype=int32)
input is [1, 2, 3]:  tf.Tensor([101 102 103], shape=(3,), dtype=int32)
input is [10, 20, 30, 40]:  tf.Tensor([110 120 130 140], shape=(4,), dtype=int32)
----------------------------------------------------------------------------------------------------
Example 3: 
Tracing with:  Tensor("x:0", shape=(), dtype=int32)
input is tf.constant(10):  None
input is tf.constant(15):  None

Debugging

You should ensure that your code executes error-free in eager mode before decorating it with tf.function. To assist in debugging, you can call tf.config.run_functions_eagerly(True) to globally disable tf.function's graph creation, and tf.config.run_functions_eagerly(False) to re-enable it.

When tracking down issues that only appear within tf.function, here are some tips:

  • Plain old Python print calls only execute during tracing, helping you track down when your function gets (re)traced
  • tf.print calls will execute every time, and can help you track down intermediate values during execution
  • tf.debugging.enable_check_numerics is an easy way to track down where NaNs and Infs are created
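A minimal sketch of these tips in action (the function is invented for illustration):

import tensorflow as tf

@tf.function
def f(x):
    print('print: runs only during tracing')
    tf.print('tf.print: runs on every call, x =', x)
    return x * 2

f(tf.constant(1))   # both lines print (first trace)
f(tf.constant(2))   # only tf.print runs: the trace is reused

# temporarily disable graph execution to debug with ordinary Python tools
tf.config.run_functions_eagerly(True)
f(tf.constant(3))   # the body runs as plain Python, so both lines print again
tf.config.run_functions_eagerly(False)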

Limitations

TensorFlow Function has a few limitations by design that you should be aware of when converting a Python function to a Function.

# Executing Python side effects

## 1. Changing Python global and free variables
external_list = []

@tf.function
def side_effect(x):
    print('Python side effect')
    external_list.append(x)

side_effect(1)
side_effect(1)
side_effect(1)
print('length of list: ', len(external_list))
print('-'*50)

## 2. Using Python iterators and generators
@tf.function
def buggy_consume_next(iterator):
    tf.print('value;', next(iterator))

iterator = iter([1, 2, 3])
buggy_consume_next(iterator)
buggy_consume_next(iterator)
buggy_consume_next(iterator)
print('-'*50)

# All outputs of a tf.function must be return values
x = None

@tf.function
def leaky_function(a):
    global x 
    x = a + 1
    return a + 2

correct_a = leaky_function(tf.constant(1))

print('Good: ', correct_a)
try:
    x.numpy()
except AttributeError:
    # x leaked out of the Function as a symbolic graph tensor, so it has no .numpy()
    print('Bad way')
print('-'*50)

# Recursive tf.functions are not supported
@tf.function
def recursive_fn(n):
    if n > 0:
        return recursive_fn(n - 1)
    else:
        return 1

try:
    recursive_fn(tf.constant(5))
except Exception:
    print('Do not support recursive tf.function')
Python side effect
length of list:  1
--------------------------------------------------
value; 1
value; 1
value; 1
--------------------------------------------------
Good:  tf.Tensor(3, shape=(), dtype=int32)
Bad way
--------------------------------------------------
Do not support recursive tf.function