10 #include "tensorflow/core/framework/op.h"
11 #include "tensorflow/core/framework/op_kernel.h"
12 #include "tensorflow/core/lib/core/errors.h"
16 namespace reduce_subarrays_sum_opkernel {
19 class ReduceSubarraysSumOpKernel :
public tensorflow::OpKernel {
21 explicit ReduceSubarraysSumOpKernel(
22 tensorflow::OpKernelConstruction* construction)
23 : OpKernel(construction) {}
25 void Compute(tensorflow::OpKernelContext*
context)
override {
26 using namespace tensorflow;
27 static_assert(
sizeof(int64) ==
sizeof(int64_t),
28 "int64 type is not compatible");
30 const Tensor& values =
context->input(0);
31 OP_REQUIRES(
context, values.shape().dims() == 1,
32 errors::InvalidArgument(
"values must be a rank 1 tensor"));
34 const Tensor& row_splits =
context->input(1);
36 context, row_splits.shape().dims() == 1,
37 errors::InvalidArgument(
"row_splits must be a rank 1 tensor"));
40 if (values.shape().dim_size(0) == 0) {
41 Tensor* sums_tensor = 0;
42 OP_REQUIRES_OK(
context,
context->allocate_output(0, values.shape(),
47 Tensor* sums_tensor = 0;
48 TensorShape sums_shape({row_splits.shape().dim_size(0) - 1});
50 context->allocate_output(0, sums_shape, &sums_tensor));
52 Kernel(
context, values, row_splits, *sums_tensor);
56 virtual void Kernel(tensorflow::OpKernelContext*
context,
57 const tensorflow::Tensor& values,
58 const tensorflow::Tensor& row_splits,
59 tensorflow::Tensor& sums) = 0;
ImGuiContext * context
Definition: Window.cpp:76