Headline
CVE-2021-41213: Fix the deadlock issue of recursive tf.function. · tensorflow/tensorflow@afac815
TensorFlow is an open source platform for machine learning. In affected versions, the code behind the `tf.function` API can be made to deadlock when two `tf.function`-decorated Python functions are mutually recursive. This occurs because a non-reentrant `Lock` Python object is used: when tracing one function re-enters the other, the thread tries to acquire a lock it already holds and blocks forever. Loading any model that contains mutually recursive functions is vulnerable. An attacker can cause a denial of service by causing users to load such models and call a recursive `tf.function`, although this is not a frequent scenario. The fix will be included in TensorFlow 2.7.0. We will also cherry-pick this commit onto TensorFlow 2.6.1, TensorFlow 2.5.2, and TensorFlow 2.4.4, as these are also affected and still in the supported range.
@@ -25,6 +25,7 @@ from six.moves import range
from tensorflow.python.autograph.core import converter from tensorflow.python.eager import backprop from tensorflow.python.eager import def_function from tensorflow.python.eager import lift_to_graph from tensorflow.python.framework import constant_op @@ -36,6 +37,7 @@ from tensorflow.python.framework import test_util from tensorflow.python.module import module from tensorflow.python.ops import array_ops from tensorflow.python.ops import cond_v2 from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops @@ -1261,6 +1263,117 @@ def testDouble(self, a): self.assertAllEqual(obj2.testDouble.experimental_get_tracing_count(), 3) self.assertAllEqual(obj1.testDouble.experimental_get_tracing_count(), 2)
def test_recursive_tf_function(self):
  """A tf.function that calls itself traces, runs, and yields its base case."""

  @def_function.function
  def countdown(n):
    # `n` is a plain Python int (the call below passes 5), so this is
    # ordinary Python control flow with a concrete base case.
    if n <= 0:
      return 1
    return countdown(n - 1)

  result = countdown(5)
  self.assertEqual(result.numpy(), 1)
def test_recursive_tf_function_with_gradients(self):
  """Gradients propagate through a self-recursive tf.function.

  factorial_times(5, x) evaluates to 5 * 4 * 3 * 2 * 1 * x, so the
  gradient with respect to x is 120.
  """

  @def_function.function
  def factorial_times(n, x):
    # Base case: once n reaches 0, hand back x unscaled.
    if n <= 0:
      return x
    return n * factorial_times(n - 1, x)

  x = variables.Variable(1.0)
  with backprop.GradientTape() as tape:
    product = factorial_times(5, x)
  grad = tape.gradient(product, x)
  self.assertEqual(grad.numpy(), 120)
def test_recursive_python_function(self):
  """A tf.function can wrap an ordinary recursive Python helper."""

  def py_countdown(n):
    # Plain Python recursion; only the outer wrapper is a tf.function.
    if n <= 0:
      return 1
    return py_countdown(n - 1)

  @def_function.function
  def wrapper(n):
    return py_countdown(n)

  out = wrapper(5)
  self.assertEqual(out.numpy(), 1)
def test_recursive_python_function_with_gradients(self):
  """Gradients flow through a recursive Python helper inside a tf.function.

  wrapper(5, x) evaluates to 5 * 4 * 3 * 2 * 1 * x, so d/dx is 120.
  """

  def py_factorial_times(n, x):
    if n <= 0:
      return x
    return n * py_factorial_times(n - 1, x)

  @def_function.function
  def wrapper(n, x):
    return py_factorial_times(n, x)

  x = variables.Variable(1.0)
  with backprop.GradientTape() as tape:
    product = wrapper(5, x)
  grad = tape.gradient(product, x)
  self.assertEqual(grad.numpy(), 120)
def test_recursive_tf_function_call_each_other(self):
  """Two mutually recursive tf.functions trace and run to completion.

  Mutual recursion between tf.functions is the pattern that could deadlock
  when tracing was guarded by a non-reentrant lock.
  """

  @def_function.function
  def ping(n):
    # `pong` is resolved at call time, so defining it below is fine.
    if n > 1:
      return pong(n - 1)
    return 1

  @def_function.function
  def pong(n):
    if n > 1:
      return ping(n - 1)
    return 2

  # Each hop decrements n by one and the chain stops in whichever function
  # first sees n <= 1, so the result depends on the parity of n.
  cases = ((ping, 5, 1), (ping, 6, 2), (pong, 5, 2), (pong, 6, 1))
  for fn, n, expected in cases:
    self.assertEqual(fn(n).numpy(), expected)
def test_recursive_tf_function_call_each_other_with_gradients(self):
  """Gradients flow through a pair of mutually recursive tf.functions.

  ping(5, x) == 5 * 4 * 3 * 2 * x and pong(5, x) == 5 * 4 * 3 * 2 * (2 * x),
  so the gradients with respect to x are 120 and 240 respectively.
  """

  @def_function.function
  def ping(n, x):
    if n > 1:
      return n * pong(n - 1, x)
    return x

  @def_function.function
  def pong(n, x):
    if n > 1:
      return n * ping(n - 1, x)
    return 2 * x

  x = variables.Variable(1.0)
  # A fresh tape per entry point; each gradient is checked before the next
  # function is traced, matching a sequential tape-per-call layout.
  for fn, expected_grad in ((ping, 120), (pong, 240)):
    with backprop.GradientTape() as tape:
      value = fn(5, x)
    grad = tape.gradient(value, x)
    self.assertEqual(grad.numpy(), expected_grad)
def test_recursive_tf_function_with_cond(self):
  """Unbounded recursion while tracing must raise RecursionError.

  The intent is that re-entering the same tf.function during tracing fails
  with a Python RecursionError instead of deadlocking.
  """

  # autograph=False: the body is traced as written, without AutoGraph
  # conversion.
  @def_function.function(autograph=False)
  def recursive_fn(n):
    # NOTE: `recursive_fn(n - 1)` is an argument expression, not a branch
    # callable, so Python evaluates it before cond_v2 is ever called. The
    # body has no base case that stops this call, so tracing recurses
    # until the Python recursion limit is reached. This is deliberate.
    return cond_v2.cond_v2(n > 0, recursive_fn(n - 1), 1)

  # A tensor argument, so `n > 0` cannot act as a Python-level base case.
  with self.assertRaises(RecursionError):
    recursive_fn(constant_op.constant(5))
# Script entry point: enable eager execution before running the tests.
# Fixed: the closing quote was a typographic apostrophe (U+2019), which
# left the string literal unterminated and made the module a SyntaxError.
# NOTE(review): the test-runner call (likely `test.main()`) is not visible
# in this excerpt; confirm it follows this line.
if __name__ == '__main__':
  ops.enable_eager_execution()