25 #include "../TensorFlowHelper.h" 27 #include "tensorflow/core/framework/op.h" 28 #include "tensorflow/core/framework/op_kernel.h" 29 #include "tensorflow/core/lib/core/errors.h" 37 class OutputAllocator {
// OutputAllocator allocates op outputs on demand through the
// tensorflow::OpKernelContext it was constructed with and hands the raw data
// pointers back to the caller.
// NOTE(review): the allocation methods of this class use a type `T` that is
// not declared in the visible lines -- presumably a class template parameter
// declared on a line missing from this listing; confirm.

// Stores the kernel context pointer; the pointer is kept as passed and is
// used later for every allocate_output() call made by this object.
39 OutputAllocator(tensorflow::OpKernelContext* context) : context(context) {}
// Allocates op output 0 as a 1-D tensor with `num` elements and writes a
// pointer to its int32 data into *ptr.
//
// ptr: receives the address of the output tensor's flat int32 buffer.
// num: number of elements to allocate (converted to int64_t for the shape).
//
// If allocate_output() fails, OP_REQUIRES_OK reports the error on the context
// and returns from this function before the assignment to *ptr shown below.
41 void AllocIndices(tensorflow::int32** ptr,
size_t num) {
// NOTE(review): `tensor` is not declared in the visible lines -- presumably a
// tensorflow::Tensor* declared on a line missing from this listing; confirm.
45 TensorShape shape({int64_t(num)});
46 OP_REQUIRES_OK(context, context->allocate_output(0, shape, &tensor));
// flat<int32>() views the freshly allocated tensor as a 1-D int32 array.
47 auto flat_tensor = tensor->flat<int32>();
48 *ptr = flat_tensor.data();
// Allocates op output 2 as a 1-D tensor with `num` elements of type T and
// writes a pointer to its data into *ptr.
//
// ptr: receives the address of the output tensor's flat buffer of T.
// num: number of elements to allocate (converted to int64_t for the shape).
//
// If allocate_output() fails, OP_REQUIRES_OK reports the error on the context
// and returns from this function before the assignment to *ptr shown below.
// NOTE(review): `T` is not declared in the visible lines -- presumably the
// element type comes from an enclosing template; confirm.
51 void AllocDistances(T** ptr,
size_t num) {
// NOTE(review): `tensor` is not declared in the visible lines -- presumably a
// tensorflow::Tensor* declared on a line missing from this listing; confirm.
55 TensorShape shape({int64_t(num)});
56 OP_REQUIRES_OK(context, context->allocate_output(2, shape, &tensor));
// flat<T>() views the freshly allocated tensor as a 1-D array of T.
57 auto flat_tensor = tensor->flat<T>();
58 *ptr = flat_tensor.data();
// Kernel context captured by the constructor; every allocate_output() call in
// this class goes through it.
62 tensorflow::OpKernelContext* context;
// Abstract TensorFlow op kernel for the fixed radius search op. The
// constructor reads the op attributes, Compute() fetches and validates the
// inputs, and the pure virtual Kernel() is where a subclass supplies the
// actual computation.
65 class FixedRadiusSearchOpKernel :
public tensorflow::OpKernel {
// Reads the op attributes "metric", "ignore_query_point" and
// "return_distances" from the construction context. A failing GetAttr() is
// reported through OP_REQUIRES_OK on `construction`.
67 explicit FixedRadiusSearchOpKernel(
68 tensorflow::OpKernelConstruction* construction)
69 : OpKernel(construction) {
// The metric arrives as a string attribute and is compared against the
// literals "L1" and "L2" below.
73 std::string metric_str;
74 OP_REQUIRES_OK(construction,
75 construction->GetAttr(
"metric", &metric_str));
// NOTE(review): the statements selected by these comparisons (and any final
// else branch) are missing from this listing -- presumably each one stores
// the matching metric value in a member; confirm against the full file.
76 if (metric_str ==
"L1")
78 else if (metric_str ==
"L2")
// Attribute "ignore_query_point" is read into the member of the same name.
83 OP_REQUIRES_OK(construction,
84 construction->GetAttr(
"ignore_query_point",
85 &ignore_query_point));
// NOTE(review): the destination argument of this GetAttr call is cut off in
// this listing -- presumably &return_distances; confirm.
87 OP_REQUIRES_OK(construction, construction->GetAttr(
"return_distances",
// OpKernel entry point. Fetches the eight op inputs, validates the radius and
// the shapes of the row-split and hash-table tensors, allocates output 1
// (query_neighbors_row_splits) and then forwards everything to Kernel().
//
// context: kernel context that supplies the inputs, receives the outputs and
//          collects any error status raised by the checks below.
91 void Compute(tensorflow::OpKernelContext* context)
override {
// Compile-time guard: the `int64` type used in this file must have the same
// size as the standard int64_t.
93 static_assert(
sizeof(int64) ==
sizeof(int64_t),
94 "int64 type is not compatible");
// Inputs 0 and 1: the points and the queries.
96 const Tensor&
points = context->input(0);
97 const Tensor& queries = context->input(1);
// Input 2: the radius. It must be a scalar; otherwise OP_REQUIRES reports an
// InvalidArgument error that includes the offending shape.
99 const Tensor& radius = context->input(2);
100 OP_REQUIRES(context, TensorShapeUtils::IsScalar(radius.shape()),
101 errors::InvalidArgument(
"radius must be scalar, got shape ",
102 radius.shape().DebugString()));
// Inputs 3 and 4: row splits for the points and the queries.
104 const Tensor& points_row_splits = context->input(3);
105 const Tensor& queries_row_splits = context->input(4);
// Inputs 5 to 7: the hash table tensors.
107 const Tensor& hash_table_splits = context->input(5);
108 const Tensor& hash_table_index = context->input(6);
109 const Tensor& hash_table_cell_splits = context->input(7);
// Named dimensions for the shape checks below. Only the names are given
// here, no sizes.
// NOTE(review): num_queries is declared but not used by any check in the
// visible lines.
114 Dim num_points(
"num_points");
115 Dim num_queries(
"num_queries");
116 Dim batch_size(
"batch_size");
117 Dim num_cells(
"num_cells");
// hash_table_index is checked against num_points.
119 CHECK_SHAPE(context, hash_table_index, num_points);
// The three *_splits tensors are all checked against the same expression,
// batch_size + 1 -- presumably this also forces them to agree with each
// other; confirm against the CHECK_SHAPE/Dim implementation.
121 CHECK_SHAPE(context, points_row_splits, batch_size + 1);
122 CHECK_SHAPE(context, queries_row_splits, batch_size + 1);
123 CHECK_SHAPE(context, hash_table_splits, batch_size + 1);
// hash_table_cell_splits is checked against num_cells + 1.
124 CHECK_SHAPE(context, hash_table_cell_splits, num_cells + 1);
// Output 1 is allocated here with one more element than the first dimension
// of `queries`; a failed allocation is reported through OP_REQUIRES_OK.
126 Tensor* query_neighbors_row_splits = 0;
127 TensorShape query_neighbors_row_splits_shape(
128 {queries.shape().dim_size(0) + 1});
129 OP_REQUIRES_OK(context, context->allocate_output(
130 1, query_neighbors_row_splits_shape,
131 &query_neighbors_row_splits));
// Hand the inputs and the pre-allocated output tensor (by reference) to the
// subclass implementation.
133 Kernel(context, points, queries, radius, points_row_splits,
134 queries_row_splits, hash_table_splits, hash_table_index,
135 hash_table_cell_splits, *query_neighbors_row_splits);
// Pure virtual hook called at the end of Compute(); subclasses must override
// it to perform the actual computation.
//
// context:                    kernel context passed through from Compute().
// points, queries:            op inputs 0 and 1.
// radius:                     op input 2; Compute() has verified it is scalar.
// points_row_splits,
// queries_row_splits:         op inputs 3 and 4.
// hash_table_splits,
// hash_table_index,
// hash_table_cell_splits:     op inputs 5, 6 and 7.
// query_neighbors_row_splits: output 1, already allocated by Compute();
//                             passed as a non-const reference so the
//                             implementation can write into it.
138 virtual void Kernel(tensorflow::OpKernelContext* context,
139 const tensorflow::Tensor& points,
140 const tensorflow::Tensor& queries,
141 const tensorflow::Tensor& radius,
142 const tensorflow::Tensor& points_row_splits,
143 const tensorflow::Tensor& queries_row_splits,
144 const tensorflow::Tensor& hash_table_splits,
145 const tensorflow::Tensor& hash_table_index,
146 const tensorflow::Tensor& hash_table_cell_splits,
147 tensorflow::Tensor& query_neighbors_row_splits) = 0;
// Value of the "ignore_query_point" op attribute, read in the constructor.
151 bool ignore_query_point;
// NOTE(review): presumably holds the "return_distances" op attribute; the
// assignment is cut off in the visible constructor lines -- confirm.
152 bool return_distances;
Definition: NanoFlannIndex.h:53
#define CHECK_SHAPE(tensor,...)
Definition: TorchHelper.h:204
Metric
Supported metrics.
Definition: NeighborSearchCommon.h:38
Class for dimensions for which the value should be inferred.
Definition: ShapeChecking.h:69
Definition: NanoFlannIndex.h:53
Definition: ContinuousConv.h:35
int points
Definition: FilePCD.cpp:73
Definition: NanoFlannIndex.h:53
Definition: ShapeChecking.h:35