Open3D (C++ API)  0.18.0
TensorFlowHelper.h
Go to the documentation of this file.
1 // ----------------------------------------------------------------------------
2 // - Open3D: www.open3d.org -
3 // ----------------------------------------------------------------------------
4 // Copyright (c) 2018-2023 www.open3d.org
5 // SPDX-License-Identifier: MIT
6 // ----------------------------------------------------------------------------
7 
8 #pragma once
9 #include <tensorflow/core/framework/op_kernel.h>
10 #include <tensorflow/core/framework/shape_inference.h>
11 #include <tensorflow/core/framework/tensor.h>
12 #include <tensorflow/core/lib/core/errors.h>
13 
15 
16 inline std::vector<open3d::ml::op_util::DimValue> GetShapeVector(
17  ::tensorflow::shape_inference::InferenceContext* c,
18  ::tensorflow::shape_inference::ShapeHandle shape_handle) {
19  using namespace open3d::ml::op_util;
20  if (!c->RankKnown(shape_handle)) {
21  return std::vector<DimValue>();
22  }
23 
24  std::vector<DimValue> shape;
25  const int rank = c->Rank(shape_handle);
26  for (int i = 0; i < rank; ++i) {
27  auto d = c->DimKnownRank(shape_handle, i);
28  if (c->ValueKnown(d)) {
29  shape.push_back(c->Value(d));
30  } else {
31  shape.push_back(DimValue());
32  }
33  }
34  return shape;
35 }
36 
38  class TDimX,
39  class... TArgs>
40 std::tuple<bool, std::string> CheckShape(
41  ::tensorflow::shape_inference::InferenceContext* c,
42  ::tensorflow::shape_inference::ShapeHandle shape_handle,
43  TDimX&& dimex,
44  TArgs&&... args) {
45  if (!c->RankKnown(shape_handle)) {
46  // without rank we cannot check
47  return std::make_tuple(true, std::string());
48  }
49  return open3d::ml::op_util::CheckShape<Opt>(GetShapeVector(c, shape_handle),
50  std::forward<TDimX>(dimex),
51  std::forward<TArgs>(args)...);
52 }
53 
54 inline std::vector<open3d::ml::op_util::DimValue> GetShapeVector(
55  const tensorflow::Tensor& tensor) {
56  using namespace open3d::ml::op_util;
57 
58  std::vector<DimValue> shape;
59  for (int i = 0; i < tensor.dims(); ++i) {
60  shape.push_back(tensor.dim_size(i));
61  }
62  return shape;
63 }
64 
66  class TDimX,
67  class... TArgs>
68 std::tuple<bool, std::string> CheckShape(const tensorflow::Tensor& tensor,
69  TDimX&& dimex,
70  TArgs&&... args) {
71  return open3d::ml::op_util::CheckShape<Opt>(GetShapeVector(tensor),
72  std::forward<TDimX>(dimex),
73  std::forward<TArgs>(args)...);
74 }
75 
76 //
77 // Helper function for creating a ShapeHandle from dim expressions.
78 // Dim expressions which are not constant will translate to unknown dims in
79 // the returned shape handle.
80 //
81 // Usage:
82 // // ctx is of type tensorflow::shape_inference::InferenceContext*
83 // {
84 // using namespace open3d::ml::op_util;
85 // Dim w("w");
86 // Dim h("h");
87 // CHECK_SHAPE_HANDLE(ctx, handle1, 10, w, h); // checks if the first dim is
88 // // 10 and assigns w and h
89 // // based on the shape of
90 // // handle1
91 //
92 // CHECK_SHAPE_HANDLE(ctx, handle2, 10, 20, h); // this checks if the
93 // // last dim of handle2 matches the
94 // // last dim of handle1. The first
95 // // two dims must match 10, 20.
96 //
97 // ShapeHandle out_shape = MakeShapeHandle(ctx, Dim(), h, w);
98 // ctx->set_output(0, out_shape);
99 // }
100 //
101 //
102 // See "../ShapeChecking.h" for more info and limitations.
103 //
104 template <class TDimX, class... TArgs>
105 ::tensorflow::shape_inference::ShapeHandle MakeShapeHandle(
106  ::tensorflow::shape_inference::InferenceContext* ctx,
107  TDimX&& dimex,
108  TArgs&&... args) {
109  using namespace tensorflow::shape_inference;
110  using namespace open3d::ml::op_util;
111  std::vector<int64_t> shape = CreateDimVector(
112  int64_t(InferenceContext::kUnknownDim), dimex, args...);
113  std::vector<DimensionHandle> dims;
114  for (int64_t d : shape) {
115  dims.push_back(ctx->MakeDim(d));
116  }
117  return ctx->MakeShape(dims);
118 }
119 
120 //
121 // Macros for checking the shape of ShapeHandle during shape inference.
122 //
123 // Usage:
124 // // ctx is of type tensorflow::shape_inference::InferenceContext*
125 // {
126 // using namespace open3d::ml::op_util;
127 // Dim w("w");
128 // Dim h("h");
129 // CHECK_SHAPE_HANDLE(ctx, handle1, 10, w, h); // checks if the first dim is
130 // // 10 and assigns w and h
131 // // based on the shape of
132 // // handle1
133 //
134 // CHECK_SHAPE_HANDLE(ctx, handle2, 10, 20, h); // this checks if the
135 // // last dim of handle2 matches the
136 // // last dim of handle1. The first
137 // // two dims must match 10, 20.
138 // }
139 //
140 //
141 // See "../ShapeChecking.h" for more info and limitations.
142 //
// Runs CheckShape(ctx, shape_handle, ...) and, if the check fails, returns an
// InvalidArgument status from the enclosing function. The message names the
// stringified `shape_handle` argument followed by the error text.
#define CHECK_SHAPE_HANDLE(ctx, shape_handle, ...)                        \
    do {                                                                  \
        const std::tuple<bool, std::string> cs_result_ =                  \
                CheckShape(ctx, shape_handle, __VA_ARGS__);               \
        const bool cs_success_ = std::get<0>(cs_result_);                 \
        const std::string& cs_errstr_ = std::get<1>(cs_result_);          \
        if (TF_PREDICT_FALSE(!cs_success_)) {                             \
            return tensorflow::errors::InvalidArgument(                   \
                    "invalid shape for '" #shape_handle "', " +           \
                    cs_errstr_);                                          \
        }                                                                 \
    } while (0)
154 
// Variant of the shape-handle check that calls
// CheckShape<CSOpt::COMBINE_FIRST_DIMS>. On failure it returns an
// InvalidArgument status from the enclosing function. `CSOpt` is used
// unqualified, so it must be visible at the point of use.
#define CHECK_SHAPE_HANDLE_COMBINE_FIRST_DIMS(ctx, shape_handle, ...)     \
    do {                                                                  \
        const std::tuple<bool, std::string> cs_result_ =                  \
                CheckShape<CSOpt::COMBINE_FIRST_DIMS>(ctx, shape_handle,  \
                                                      __VA_ARGS__);       \
        const bool cs_success_ = std::get<0>(cs_result_);                 \
        const std::string& cs_errstr_ = std::get<1>(cs_result_);          \
        if (TF_PREDICT_FALSE(!cs_success_)) {                             \
            return tensorflow::errors::InvalidArgument(                   \
                    "invalid shape for '" #shape_handle "', " +           \
                    cs_errstr_);                                          \
        }                                                                 \
    } while (0)
167 
// Variant of the shape-handle check that calls
// CheckShape<CSOpt::IGNORE_FIRST_DIMS>. On failure it returns an
// InvalidArgument status from the enclosing function. `CSOpt` is used
// unqualified, so it must be visible at the point of use.
#define CHECK_SHAPE_HANDLE_IGNORE_FIRST_DIMS(ctx, shape_handle, ...)      \
    do {                                                                  \
        const std::tuple<bool, std::string> cs_result_ =                  \
                CheckShape<CSOpt::IGNORE_FIRST_DIMS>(ctx, shape_handle,   \
                                                     __VA_ARGS__);        \
        const bool cs_success_ = std::get<0>(cs_result_);                 \
        const std::string& cs_errstr_ = std::get<1>(cs_result_);          \
        if (TF_PREDICT_FALSE(!cs_success_)) {                             \
            return tensorflow::errors::InvalidArgument(                   \
                    "invalid shape for '" #shape_handle "', " +           \
                    cs_errstr_);                                          \
        }                                                                 \
    } while (0)
180 
// Variant of the shape-handle check that calls
// CheckShape<CSOpt::COMBINE_LAST_DIMS>. On failure it returns an
// InvalidArgument status from the enclosing function. `CSOpt` is used
// unqualified, so it must be visible at the point of use.
#define CHECK_SHAPE_HANDLE_COMBINE_LAST_DIMS(ctx, shape_handle, ...)      \
    do {                                                                  \
        const std::tuple<bool, std::string> cs_result_ =                  \
                CheckShape<CSOpt::COMBINE_LAST_DIMS>(ctx, shape_handle,   \
                                                     __VA_ARGS__);        \
        const bool cs_success_ = std::get<0>(cs_result_);                 \
        const std::string& cs_errstr_ = std::get<1>(cs_result_);          \
        if (TF_PREDICT_FALSE(!cs_success_)) {                             \
            return tensorflow::errors::InvalidArgument(                   \
                    "invalid shape for '" #shape_handle "', " +           \
                    cs_errstr_);                                          \
        }                                                                 \
    } while (0)
193 
// Variant of the shape-handle check that calls
// CheckShape<CSOpt::IGNORE_LAST_DIMS>. On failure it returns an
// InvalidArgument status from the enclosing function. `CSOpt` is used
// unqualified, so it must be visible at the point of use.
#define CHECK_SHAPE_HANDLE_IGNORE_LAST_DIMS(ctx, shape_handle, ...)       \
    do {                                                                  \
        const std::tuple<bool, std::string> cs_result_ =                  \
                CheckShape<CSOpt::IGNORE_LAST_DIMS>(ctx, shape_handle,    \
                                                    __VA_ARGS__);         \
        const bool cs_success_ = std::get<0>(cs_result_);                 \
        const std::string& cs_errstr_ = std::get<1>(cs_result_);          \
        if (TF_PREDICT_FALSE(!cs_success_)) {                             \
            return tensorflow::errors::InvalidArgument(                   \
                    "invalid shape for '" #shape_handle "', " +           \
                    cs_errstr_);                                          \
        }                                                                 \
    } while (0)
206 
207 //
208 // Macros for checking the shape of Tensors.
209 // Usage:
210 // // ctx is of type tensorflow::OpKernelContext*
211 // {
212 // using namespace open3d::ml::op_util;
213 // Dim w("w");
214 // Dim h("h");
215 // CHECK_SHAPE(ctx, tensor1, 10, w, h); // checks if the first dim is 10
216 // // and assigns w and h based on
217 // // the shape of tensor1
218 //
219 // CHECK_SHAPE(ctx, tensor2, 10, 20, h); // this checks if the last dim
220 // // of tensor2 matches the last dim
221 // // of tensor1. The first two dims
222 // // must match 10, 20.
223 // }
224 //
225 //
226 // See "../ShapeChecking.h" for more info and limitations.
227 //
// Runs CheckShape(tensor, ...) and hands the outcome to OP_REQUIRES together
// with an InvalidArgument status. The message names the stringified `tensor`
// argument followed by the error text.
#define CHECK_SHAPE(ctx, tensor, ...)                                     \
    do {                                                                  \
        const std::tuple<bool, std::string> cs_result_ =                  \
                CheckShape(tensor, __VA_ARGS__);                          \
        const bool cs_success_ = std::get<0>(cs_result_);                 \
        const std::string& cs_errstr_ = std::get<1>(cs_result_);          \
        OP_REQUIRES(ctx, cs_success_,                                     \
                    tensorflow::errors::InvalidArgument(                  \
                            "invalid shape for '" #tensor "', " +         \
                            cs_errstr_));                                 \
    } while (0)
238 
// Variant of the tensor shape check that calls
// CheckShape<CSOpt::COMBINE_FIRST_DIMS> and hands the outcome to OP_REQUIRES
// with an InvalidArgument status. `CSOpt` is used unqualified, so it must be
// visible at the point of use.
#define CHECK_SHAPE_COMBINE_FIRST_DIMS(ctx, tensor, ...)                  \
    do {                                                                  \
        const std::tuple<bool, std::string> cs_result_ =                  \
                CheckShape<CSOpt::COMBINE_FIRST_DIMS>(tensor,             \
                                                      __VA_ARGS__);       \
        const bool cs_success_ = std::get<0>(cs_result_);                 \
        const std::string& cs_errstr_ = std::get<1>(cs_result_);          \
        OP_REQUIRES(ctx, cs_success_,                                     \
                    tensorflow::errors::InvalidArgument(                  \
                            "invalid shape for '" #tensor "', " +         \
                            cs_errstr_));                                 \
    } while (0)
250 
// Variant of the tensor shape check that calls
// CheckShape<CSOpt::IGNORE_FIRST_DIMS> and hands the outcome to OP_REQUIRES
// with an InvalidArgument status. `CSOpt` is used unqualified, so it must be
// visible at the point of use.
#define CHECK_SHAPE_IGNORE_FIRST_DIMS(ctx, tensor, ...)                   \
    do {                                                                  \
        const std::tuple<bool, std::string> cs_result_ =                  \
                CheckShape<CSOpt::IGNORE_FIRST_DIMS>(tensor,              \
                                                     __VA_ARGS__);        \
        const bool cs_success_ = std::get<0>(cs_result_);                 \
        const std::string& cs_errstr_ = std::get<1>(cs_result_);          \
        OP_REQUIRES(ctx, cs_success_,                                     \
                    tensorflow::errors::InvalidArgument(                  \
                            "invalid shape for '" #tensor "', " +         \
                            cs_errstr_));                                 \
    } while (0)
262 
// Variant of the tensor shape check that calls
// CheckShape<CSOpt::COMBINE_LAST_DIMS> and hands the outcome to OP_REQUIRES
// with an InvalidArgument status. `CSOpt` is used unqualified, so it must be
// visible at the point of use.
#define CHECK_SHAPE_COMBINE_LAST_DIMS(ctx, tensor, ...)                   \
    do {                                                                  \
        const std::tuple<bool, std::string> cs_result_ =                  \
                CheckShape<CSOpt::COMBINE_LAST_DIMS>(tensor,              \
                                                     __VA_ARGS__);        \
        const bool cs_success_ = std::get<0>(cs_result_);                 \
        const std::string& cs_errstr_ = std::get<1>(cs_result_);          \
        OP_REQUIRES(ctx, cs_success_,                                     \
                    tensorflow::errors::InvalidArgument(                  \
                            "invalid shape for '" #tensor "', " +         \
                            cs_errstr_));                                 \
    } while (0)
274 
// Variant of the tensor shape check that calls
// CheckShape<CSOpt::IGNORE_LAST_DIMS> and hands the outcome to OP_REQUIRES
// with an InvalidArgument status. `CSOpt` is used unqualified, so it must be
// visible at the point of use.
#define CHECK_SHAPE_IGNORE_LAST_DIMS(ctx, tensor, ...)                    \
    do {                                                                  \
        const std::tuple<bool, std::string> cs_result_ =                  \
                CheckShape<CSOpt::IGNORE_LAST_DIMS>(tensor,               \
                                                    __VA_ARGS__);         \
        const bool cs_success_ = std::get<0>(cs_result_);                 \
        const std::string& cs_errstr_ = std::get<1>(cs_result_);          \
        OP_REQUIRES(ctx, cs_success_,                                     \
                    tensorflow::errors::InvalidArgument(                  \
                            "invalid shape for '" #tensor "', " +         \
                            cs_errstr_));                                 \
    } while (0)
std::tuple< bool, std::string > CheckShape(::tensorflow::shape_inference::InferenceContext *c, ::tensorflow::shape_inference::ShapeHandle shape_handle, TDimX &&dimex, TArgs &&... args)
Definition: TensorFlowHelper.h:40
::tensorflow::shape_inference::ShapeHandle MakeShapeHandle(::tensorflow::shape_inference::InferenceContext *ctx, TDimX &&dimex, TArgs &&... args)
Definition: TensorFlowHelper.h:105
std::vector< open3d::ml::op_util::DimValue > GetShapeVector(::tensorflow::shape_inference::InferenceContext *c, ::tensorflow::shape_inference::ShapeHandle shape_handle)
Definition: TensorFlowHelper.h:16
Class for representing a possibly unknown dimension value.
Definition: ShapeChecking.h:19
Definition: ShapeChecking.h:16
CSOpt
Check shape options.
Definition: ShapeChecking.h:405
void CreateDimVector(std::vector< int64_t > &out, int64_t unknown_dim_value, TDimX dimex)
Definition: ShapeChecking.h:358