10 #include "../TensorFlowHelper.h"
11 #include "tensorflow/core/framework/op.h"
12 #include "tensorflow/core/framework/op_kernel.h"
13 #include "tensorflow/core/lib/core/errors.h"
17 namespace invert_neighbors_list_opkernel {
20 class InvertNeighborsListOpKernel :
public tensorflow::OpKernel {
/// Constructs the kernel.
///
/// Simply forwards the construction context to the tensorflow::OpKernel
/// base class. The body is empty, so no op attributes are read here.
///
/// @param construction  Kernel construction context supplied by TensorFlow.
22 explicit InvertNeighborsListOpKernel(
23 tensorflow::OpKernelConstruction* construction)
24 : OpKernel(construction) {}
26 void Compute(tensorflow::OpKernelContext*
context)
override {
27 using namespace tensorflow;
28 static_assert(
sizeof(int64) ==
sizeof(int64_t),
29 "int64 type is not compatible");
31 const Tensor& num_points_tensor =
context->input(0);
33 TensorShapeUtils::IsScalar(num_points_tensor.shape()),
34 errors::InvalidArgument(
35 "num_points must be scalar, got shape ",
36 num_points_tensor.shape().DebugString()));
37 const int64 num_points = num_points_tensor.scalar<int64>()();
39 const Tensor& inp_neighbors_index =
context->input(1);
41 const Tensor& inp_neighbors_row_splits =
context->input(2);
43 const Tensor& inp_neighbors_attributes =
context->input(3);
48 Dim num_neighbors(
"num_neighbors");
58 if (inp_neighbors_attributes.shape().dim_size(0) == 0) {
62 for (
int i = 1; i < inp_neighbors_attributes.shape().dims(); ++i)
63 num_attributes *= inp_neighbors_attributes.shape().dim_size(i);
66 Tensor* neighbors_index = 0;
67 TensorShape neighbors_index_shape(inp_neighbors_index.shape());
69 context->allocate_output(0, neighbors_index_shape,
72 Tensor* neighbors_row_splits = 0;
73 TensorShape neighbors_row_splits_shape({num_points + 1});
75 context->allocate_output(1, neighbors_row_splits_shape,
76 &neighbors_row_splits));
78 Tensor* neighbors_attributes = 0;
79 TensorShape neighbors_attributes_shape(
80 inp_neighbors_attributes.shape());
82 context->allocate_output(2, neighbors_attributes_shape,
83 &neighbors_attributes));
85 Kernel(
context, inp_neighbors_index, inp_neighbors_row_splits,
86 inp_neighbors_attributes, num_attributes, *neighbors_index,
87 *neighbors_row_splits, *neighbors_attributes);
/// Implementation hook invoked at the end of Compute().
///
/// Pure virtual: this class only validates inputs and allocates outputs in
/// Compute(); a derived kernel supplies the actual work. Presumably the
/// derived implementation writes all three output tensors; confirm against
/// the concrete subclasses.
///
/// @param context                   Op kernel context forwarded from Compute().
/// @param inp_neighbors_index       Op input 1.
/// @param inp_neighbors_row_splits  Op input 2.
/// @param inp_neighbors_attributes  Op input 3.
/// @param num_attributes            Value computed in Compute(). NOTE(review):
///                                  it appears to be the product of all dims
///                                  after the first of inp_neighbors_attributes.
///                                  The branch taken when that tensor's first
///                                  dim is 0 is not visible here; confirm. It is
///                                  passed as `int`, so if it is computed as
///                                  int64 this may narrow; verify.
/// @param neighbors_index           Output 0, allocated by Compute() with the
///                                  same shape as inp_neighbors_index.
/// @param neighbors_row_splits      Output 1, allocated by Compute() with shape
///                                  {num_points + 1}.
/// @param neighbors_attributes      Output 2, allocated by Compute() with the
///                                  same shape as inp_neighbors_attributes.
91 virtual void Kernel(tensorflow::OpKernelContext*
context,
92 const tensorflow::Tensor& inp_neighbors_index,
93 const tensorflow::Tensor& inp_neighbors_row_splits,
94 const tensorflow::Tensor& inp_neighbors_attributes,
95 const int num_attributes,
96 tensorflow::Tensor& neighbors_index,
97 tensorflow::Tensor& neighbors_row_splits,
98 tensorflow::Tensor& neighbors_attributes) = 0;
#define CHECK_SHAPE_IGNORE_LAST_DIMS(tensor,...)
Definition: TorchHelper.h:225
#define CHECK_SHAPE(tensor,...)
Definition: TorchHelper.h:186
ImGuiContext * context
Definition: Window.cpp:76
Class for dimensions for which the value should be inferred.
Definition: ShapeChecking.h:50
Definition: ShapeChecking.h:16