25 #include "../TensorFlowHelper.h" 27 #include "tensorflow/core/framework/op.h" 28 #include "tensorflow/core/framework/op_kernel.h" 29 #include "tensorflow/core/lib/core/errors.h" 35 class KnnSearchOpKernel :
public tensorflow::OpKernel {
37 explicit KnnSearchOpKernel(tensorflow::OpKernelConstruction* construction)
38 : OpKernel(construction) {
41 std::string metric_str;
42 OP_REQUIRES_OK(construction,
43 construction->GetAttr(
"metric", &metric_str));
44 if (metric_str ==
"L1")
49 OP_REQUIRES_OK(construction,
50 construction->GetAttr(
"ignore_query_point",
51 &ignore_query_point));
53 OP_REQUIRES_OK(construction, construction->GetAttr(
"return_distances",
57 void Compute(tensorflow::OpKernelContext* context)
override {
59 static_assert(
sizeof(int64) ==
sizeof(int64_t),
60 "int64 type is not compatible");
62 const Tensor&
points = context->input(0);
63 const Tensor& queries = context->input(1);
64 const Tensor& k_tensor = context->input(2);
65 const TensorShape k_shape(k_tensor.shape());
66 OP_REQUIRES(context, k_shape.dims() == 0,
67 errors::InvalidArgument(
"k must be a rank 0 tensor"));
68 const int k = k_tensor.scalar<
int32_t>()();
69 const Tensor& points_row_splits = context->input(3);
70 const Tensor& queries_row_splits = context->input(4);
74 Dim num_points(
"num_points");
75 Dim num_queries(
"num_queries");
76 Dim batch_size(
"batch_size");
79 CHECK_SHAPE(context, points_row_splits, batch_size + 1);
80 CHECK_SHAPE(context, queries_row_splits, batch_size + 1);
83 Tensor* query_neighbors_row_splits = 0;
84 TensorShape query_neighbors_row_splits_shape(
85 {queries.shape().dim_size(0) + 1});
86 OP_REQUIRES_OK(context, context->allocate_output(
87 1, query_neighbors_row_splits_shape,
88 &query_neighbors_row_splits));
90 Kernel(context, points, queries, k, points_row_splits,
91 queries_row_splits, *query_neighbors_row_splits);
94 virtual void Kernel(tensorflow::OpKernelContext* context,
95 const tensorflow::Tensor& points,
96 const tensorflow::Tensor& queries,
98 const tensorflow::Tensor& points_row_splits,
99 const tensorflow::Tensor& queries_row_splits,
100 tensorflow::Tensor& query_neighbors_row_splits) = 0;
104 bool ignore_query_point;
105 bool return_distances;
#define CHECK_SHAPE(tensor,...)
Definition: TorchHelper.h:204
Metric
Supported metrics.
Definition: NeighborSearchCommon.h:38
const char const char value recording_handle imu_sample recording_handle uint8_t size_t data_size k4a_record_configuration_t config target_format k4a_capture_t capture_handle k4a_imu_sample_t imu_sample playback_handle k4a_logging_message_cb_t void min_level device_handle k4a_imu_sample_t int32_t
Definition: K4aPlugin.cpp:398
Definition: NanoFlannIndex.h:55
Class for dimensions for which the value should be inferred.
Definition: ShapeChecking.h:69
Definition: ContinuousConv.h:35
int points
Definition: FilePCD.cpp:73
Definition: NanoFlannIndex.h:55
Definition: ShapeChecking.h:35