12 #include "tensorflow/core/framework/op.h"
13 #include "tensorflow/core/framework/op_kernel.h"
14 #include "tensorflow/core/lib/core/errors.h"
18 namespace nms_opkernel {
20 class OutputAllocator {
24 void AllocKeepIndices(int64_t** ptr, int64_t num) {
25 using namespace tensorflow;
28 TensorShape shape({num});
29 OP_REQUIRES_OK(
context,
context->allocate_output(0, shape, &tensor));
30 auto flat_tensor = tensor->flat<int64>();
31 *ptr = (int64_t*)flat_tensor.data();
35 tensorflow::OpKernelContext*
context;
39 class NmsOpKernel :
public tensorflow::OpKernel {
41 explicit NmsOpKernel(tensorflow::OpKernelConstruction* construction)
42 : OpKernel(construction) {
43 OP_REQUIRES_OK(construction,
44 construction->GetAttr(
"nms_overlap_thresh",
45 &nms_overlap_thresh));
48 void Compute(tensorflow::OpKernelContext*
context)
override {
49 using namespace tensorflow;
50 const Tensor& boxes =
context->input(0);
51 const Tensor& scores =
context->input(1);
55 Dim num_points(
"num_points");
65 virtual void Kernel(tensorflow::OpKernelContext*
context,
66 const tensorflow::Tensor& boxes,
67 const tensorflow::Tensor& scores) = 0;
70 float nms_overlap_thresh;
#define CHECK_SHAPE(tensor,...)
Definition: TorchHelper.h:186
ImGuiContext * context
Definition: Window.cpp:76
Class for dimensions for which the value should be inferred.
Definition: ShapeChecking.h:50
Definition: ShapeChecking.h:16