Open3D (C++ API)  0.17.0
TrilinearDevoxelizeKernel.h
Go to the documentation of this file.
1 // ----------------------------------------------------------------------------
2 // - Open3D: www.open3d.org -
3 // ----------------------------------------------------------------------------
4 // Copyright (c) 2018-2023 www.open3d.org
5 // SPDX-License-Identifier: MIT
6 // ----------------------------------------------------------------------------
7 
8 #pragma once
9 
10 #include "../TensorFlowHelper.h"
11 #include "tensorflow/core/framework/op.h"
12 #include "tensorflow/core/framework/op_kernel.h"
13 #include "tensorflow/core/lib/core/errors.h"
14 
15 class TrilinearDevoxelizeOpKernel : public tensorflow::OpKernel {
16 public:
18  tensorflow::OpKernelConstruction* context)
19  : tensorflow::OpKernel(context) {
20  using namespace tensorflow;
21  OP_REQUIRES_OK(context, context->GetAttr("resolution", &r));
22  OP_REQUIRES_OK(context, context->GetAttr("is_training", &is_training));
23  OP_REQUIRES(context, r > 0,
24  errors::InvalidArgument(
25  "TrilinearDevoxelize expects positive resolution"));
26  }
27 
28  void Compute(tensorflow::OpKernelContext* context) override {
29  using namespace tensorflow;
30  const Tensor& coords = context->input(0);
31  OP_REQUIRES(
32  context, coords.dims() == 3 && coords.shape().dim_size(1) == 3,
33  errors::InvalidArgument("TrilinearDevoxelize expects "
34  "(batch_size, 3, N) coordinate shape"));
35  const Tensor& feat = context->input(1);
36  OP_REQUIRES(context, feat.dims() == 5,
37  errors::InvalidArgument("TrilinearDevoxelize expects "
38  "5 dimensions for features"));
39 
40  int batch_size = coords.shape().dim_size(0);
41  int num_points = coords.shape().dim_size(2);
42  int feat_dim = feat.shape().dim_size(1);
43 
44  auto coords_flat = coords.flat<float>();
45  auto feat_flat = feat.flat<float>();
46 
47  const float* inp_coords = &(coords_flat(0));
48  const float* inp_feat = &(feat_flat(0));
49 
50  Tensor* out_tensor_0;
51  OP_REQUIRES_OK(context,
52  context->allocate_output(
53  0, TensorShape{batch_size, feat_dim, num_points},
54  &out_tensor_0));
55  Tensor* out_tensor_1;
56  OP_REQUIRES_OK(context,
57  context->allocate_output(
58  1, TensorShape{batch_size, 8, num_points},
59  &out_tensor_1));
60  Tensor* out_tensor_2;
61  OP_REQUIRES_OK(context,
62  context->allocate_output(
63  2, TensorShape{batch_size, 8, num_points},
64  &out_tensor_2));
65  auto flat_0 = out_tensor_0->flat<float>();
66  auto flat_1 = out_tensor_1->flat<int>();
67  auto flat_2 = out_tensor_2->flat<float>();
68 
69  float* out_0 = &(flat_0(0));
70  int* out_1 = &(flat_1(0));
71  float* out_2 = &(flat_2(0));
72 
73  if (is_training)
74  Kernel(context, batch_size, feat_dim, num_points, r, r * r,
75  r * r * r, true, inp_coords, inp_feat, out_1, out_2, out_0);
76  else
77  Kernel(context, batch_size, feat_dim, num_points, r, r * r,
78  r * r * r, false, inp_coords, inp_feat, out_1, out_2, out_0);
79  }
80 
81  virtual void Kernel(tensorflow::OpKernelContext* context,
82  int b,
83  int c,
84  int n,
85  int r,
86  int r2,
87  int r3,
88  bool training,
89  const float* coords,
90  const float* feat,
91  int* inds,
92  float* wgts,
93  float* outs) = 0;
94 
95 protected:
96  int r;
98 };
99 
100 class TrilinearDevoxelizeGradOpKernel : public tensorflow::OpKernel {
101 public:
103  tensorflow::OpKernelConstruction* context)
104  : tensorflow::OpKernel(context) {
105  using namespace tensorflow;
106  OP_REQUIRES_OK(context, context->GetAttr("resolution", &r));
107  OP_REQUIRES(
108  context, r > 0,
109  errors::InvalidArgument(
110  "TrilinearDevoxelizeGrad expects positive resolution"));
111  }
112 
113  void Compute(tensorflow::OpKernelContext* context) override {
114  using namespace tensorflow;
115  const Tensor& grad_y = context->input(0);
116  OP_REQUIRES(
117  context, grad_y.dims() == 3,
118  errors::InvalidArgument("TrilinearDevoxelizeGrad expects "
119  "(batch_size, C, N) gradient shape"));
120  const Tensor& inds = context->input(1);
121  OP_REQUIRES(
122  context, inds.dims() == 3 && inds.shape().dim_size(1) == 8,
123  errors::InvalidArgument("TrilinearDevoxelizeGrad expects "
124  "(batch_size, 8, N) indices shape"));
125  const Tensor& wgts = context->input(2);
126  OP_REQUIRES(
127  context, wgts.dims() == 3 && wgts.shape().dim_size(1) == 8,
128  errors::InvalidArgument("TrilinearDevoxelizeGrad expects "
129  "(batch_size, 8, N) weights shape"));
130 
131  int batch_size = grad_y.shape().dim_size(0);
132  int num_points = grad_y.shape().dim_size(2);
133  int feat_dim = grad_y.shape().dim_size(1);
134 
135  auto grad_y_flat = grad_y.flat<float>();
136  auto inds_flat = inds.flat<int>();
137  auto wgts_flat = wgts.flat<float>();
138 
139  const float* inp_grad_y = &(grad_y_flat(0));
140  const int* inp_inds = &(inds_flat(0));
141  const float* inp_wgts = &(wgts_flat(0));
142 
143  Tensor* out_tensor;
144  OP_REQUIRES_OK(context,
145  context->allocate_output(
146  0, TensorShape{batch_size, feat_dim, r, r, r},
147  &out_tensor));
148  auto flat_tensor = out_tensor->flat<float>();
149 
150  float* out = &(flat_tensor(0));
151 
152  Kernel(context, batch_size, feat_dim, num_points, r * r * r, inp_inds,
153  inp_wgts, inp_grad_y, out);
154  }
155 
156  virtual void Kernel(tensorflow::OpKernelContext* context,
157  int b,
158  int c,
159  int n,
160  int r3,
161  const int* inds,
162  const float* wgts,
163  const float* grad_y,
164  float* grad_x) = 0;
165 
166 protected:
167  int r;
168 };
ImGuiContext * context
Definition: Window.cpp:76
Definition: TrilinearDevoxelizeKernel.h:100
int r
Definition: TrilinearDevoxelizeKernel.h:167
void Compute(tensorflow::OpKernelContext *context) override
Definition: TrilinearDevoxelizeKernel.h:113
virtual void Kernel(tensorflow::OpKernelContext *context, int b, int c, int n, int r3, const int *inds, const float *wgts, const float *grad_y, float *grad_x)=0
TrilinearDevoxelizeGradOpKernel(tensorflow::OpKernelConstruction *context)
Definition: TrilinearDevoxelizeKernel.h:102
Definition: TrilinearDevoxelizeKernel.h:15
virtual void Kernel(tensorflow::OpKernelContext *context, int b, int c, int n, int r, int r2, int r3, bool training, const float *coords, const float *feat, int *inds, float *wgts, float *outs)=0
TrilinearDevoxelizeOpKernel(tensorflow::OpKernelConstruction *context)
Definition: TrilinearDevoxelizeKernel.h:17
void Compute(tensorflow::OpKernelContext *context) override
Definition: TrilinearDevoxelizeKernel.h:28
int r
Definition: TrilinearDevoxelizeKernel.h:96
bool is_training
Definition: TrilinearDevoxelizeKernel.h:97