Headline
CVE-2022-21732: [tf.data] Set limit on number of threads used in threadpool_dataset. · tensorflow/tensorflow@e3749a6
TensorFlow is an open-source machine learning framework. The implementation of `ThreadPoolHandle`
can be used to trigger a denial-of-service attack by allocating too much memory. This is because the `num_threads`
argument is only checked to be non-negative; there is no upper bound on its value. The fix will be included in TensorFlow 2.8.0. We will also cherry-pick this commit onto TensorFlow 2.7.1, TensorFlow 2.6.3, and TensorFlow 2.5.3, as these are also affected and still in the supported range.
@@ -39,6 +39,22 @@ namespace experimental { PrivateThreadPoolDatasetOp::kDatasetType; /* static */ constexpr const char* const PrivateThreadPoolDatasetOp::kDatasetOp;
namespace {

// Exclusive upper bound on the number of threads a tf.data thread pool may be
// configured with. This prevents integer overflow issues when allocating
// threadpool memory for an unreasonable number of threads (CVE-2022-21732).
constexpr int kThreadLimit = 65536;

// Validates a user-supplied thread count.
//
// Returns Status::OK() when 0 <= num_threads < kThreadLimit, and an
// InvalidArgument error naming the violated bound otherwise. Zero is accepted;
// this function does not enforce a minimum of one thread.
//
// The parameter is 64-bit on purpose: some callers hold the count as an
// int64_t (e.g. a value obtained via ParseScalarArgument<int64_t>). With a
// 32-bit parameter that value would be implicitly narrowed (typically modulo
// 2^32) *before* the range check, so an out-of-range request such as
// 2^32 + 8 could be accepted as 8 instead of being rejected. int32_t callers
// are unaffected: their argument widens losslessly.
Status ValidateNumThreads(int64_t num_threads) {
  if (num_threads < 0) {
    return errors::InvalidArgument("`num_threads` must be >= 0");
  }
  if (num_threads >= kThreadLimit) {
    return errors::InvalidArgument("`num_threads` must be < ", kThreadLimit);
  }
  return Status::OK();
}

}  // namespace
class ThreadPoolResource : public ResourceBase { public: ThreadPoolResource(Env* env, const ThreadOptions& thread_options, @@ -83,9 +99,7 @@ class ThreadPoolHandleOp : public OpKernel { OP_REQUIRES_OK(ctx, ctx->GetAttr("num_threads", &num_threads_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("max_intra_op_parallelism", &max_intra_op_parallelism_)); OP_REQUIRES( ctx, num_threads_ > 0, errors::InvalidArgument(“`num_threads` must be greater than zero.”)); OP_REQUIRES_OK(ctx, ValidateNumThreads(num_threads_)); }
// The resource is deleted from the resource manager only when it is private @@ -531,8 +545,7 @@ void PrivateThreadPoolDatasetOp::MakeDatasetFromOptions(OpKernelContext* ctx, DatasetBase* input, int32_t num_threads, DatasetBase** output) { OP_REQUIRES(ctx, num_threads >= 0, errors::InvalidArgument(“`num_threads` must be >= 0”)); OP_REQUIRES_OK(ctx, ValidateNumThreads(num_threads)); *output = new Dataset(ctx, DatasetContext(DatasetContext::Params( {PrivateThreadPoolDatasetOp::kDatasetType, @@ -546,8 +559,7 @@ void PrivateThreadPoolDatasetOp::MakeDataset(OpKernelContext* ctx, int64_t num_threads = 0; OP_REQUIRES_OK( ctx, ParseScalarArgument<int64_t>(ctx, "num_threads", &num_threads)); OP_REQUIRES(ctx, num_threads >= 0, errors::InvalidArgument(“`num_threads` must be >= 0”)); OP_REQUIRES_OK(ctx, ValidateNumThreads(num_threads)); *output = new Dataset(ctx, input, num_threads); }