Open3D (C++ API)  0.12.0
TensorFlowHelper.h
Go to the documentation of this file.
1 // ----------------------------------------------------------------------------
2 // - Open3D: www.open3d.org -
3 // ----------------------------------------------------------------------------
4 // The MIT License (MIT)
5 //
6 // Copyright (c) 2020 www.open3d.org
7 //
8 // Permission is hereby granted, free of charge, to any person obtaining a copy
9 // of this software and associated documentation files (the "Software"), to deal
10 // in the Software without restriction, including without limitation the rights
11 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12 // copies of the Software, and to permit persons to whom the Software is
13 // furnished to do so, subject to the following conditions:
14 //
15 // The above copyright notice and this permission notice shall be included in
16 // all copies or substantial portions of the Software.
17 //
18 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
23 // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
24 // IN THE SOFTWARE.
25 // ----------------------------------------------------------------------------
26 
27 #pragma once
29 #include "tensorflow/core/framework/op_kernel.h"
30 #include "tensorflow/core/framework/shape_inference.h"
31 #include "tensorflow/core/framework/tensor.h"
32 #include "tensorflow/core/lib/core/errors.h"
33 
34 inline std::vector<open3d::ml::op_util::DimValue> GetShapeVector(
35  ::tensorflow::shape_inference::InferenceContext* c,
36  ::tensorflow::shape_inference::ShapeHandle shape_handle) {
37  using namespace open3d::ml::op_util;
38  if (!c->RankKnown(shape_handle)) {
39  return std::vector<DimValue>();
40  }
41 
42  std::vector<DimValue> shape;
43  const int rank = c->Rank(shape_handle);
44  for (int i = 0; i < rank; ++i) {
45  auto d = c->DimKnownRank(shape_handle, i);
46  if (c->ValueKnown(d)) {
47  shape.push_back(c->Value(d));
48  } else {
49  shape.push_back(DimValue());
50  }
51  }
52  return shape;
53 }
54 
56  class TDimX,
57  class... TArgs>
58 std::tuple<bool, std::string> CheckShape(
59  ::tensorflow::shape_inference::InferenceContext* c,
60  ::tensorflow::shape_inference::ShapeHandle shape_handle,
61  TDimX&& dimex,
62  TArgs&&... args) {
63  if (!c->RankKnown(shape_handle)) {
64  // without rank we cannot check
65  return std::make_tuple(true, std::string());
66  }
67  return open3d::ml::op_util::CheckShape<Opt>(GetShapeVector(c, shape_handle),
68  std::forward<TDimX>(dimex),
69  std::forward<TArgs>(args)...);
70 }
71 
72 inline std::vector<open3d::ml::op_util::DimValue> GetShapeVector(
73  const tensorflow::Tensor& tensor) {
74  using namespace open3d::ml::op_util;
75 
76  std::vector<DimValue> shape;
77  for (int i = 0; i < tensor.dims(); ++i) {
78  shape.push_back(tensor.dim_size(i));
79  }
80  return shape;
81 }
82 
84  class TDimX,
85  class... TArgs>
86 std::tuple<bool, std::string> CheckShape(const tensorflow::Tensor& tensor,
87  TDimX&& dimex,
88  TArgs&&... args) {
89  return open3d::ml::op_util::CheckShape<Opt>(GetShapeVector(tensor),
90  std::forward<TDimX>(dimex),
91  std::forward<TArgs>(args)...);
92 }
93 
94 //
95 // Helper function for creating a ShapeHandle from dim expressions.
96 // Dim expressions which are not constant will translate to unknown dims in
97 // the returned shape handle.
98 //
99 // Usage:
100 // // ctx is of type tensorflow::shape_inference::InferenceContext*
101 // {
102 // using namespace open3d::ml::op_util;
103 // Dim w("w");
104 // Dim h("h");
105 // CHECK_SHAPE_HANDLE(ctx, handle1, 10, w, h); // checks if the first dim is
106 // // 10 and assigns w and h
107 // // based on the shape of
108 // // handle1
109 //
110 // CHECK_SHAPE_HANDLE(ctx, handle2, 10, 20, h); // this checks if the the
111 // // last dim of handle2 matches the
112 // // last dim of handle1. The first
113 // // two dims must match 10, 20.
114 //
115 // ShapeHandle out_shape = MakeShapeHandle(ctx, Dim(), h, w);
116 // ctx->set_output(0, out_shape);
117 // }
118 //
119 //
120 // See "../ShapeChecking.h" for more info and limitations.
121 //
122 template <class TDimX, class... TArgs>
123 ::tensorflow::shape_inference::ShapeHandle MakeShapeHandle(
124  ::tensorflow::shape_inference::InferenceContext* ctx,
125  TDimX&& dimex,
126  TArgs&&... args) {
127  using namespace tensorflow::shape_inference;
128  using namespace open3d::ml::op_util;
129  std::vector<int64_t> shape = CreateDimVector(
130  int64_t(InferenceContext::kUnknownDim), dimex, args...);
131  std::vector<DimensionHandle> dims;
132  for (int64_t d : shape) {
133  dims.push_back(ctx->MakeDim(d));
134  }
135  return ctx->MakeShape(dims);
136 }
137 
//
// Macros for checking the shape of a ShapeHandle during shape inference.
//
// Each macro calls CheckShape(ctx, shape_handle, <dim expressions>...). If
// the check fails, it returns
//   tensorflow::errors::InvalidArgument(
//       "invalid shape for '<shape_handle>', " + <error message>)
// from the enclosing function. The macros can therefore only be used in
// functions whose return type accepts that value, presumably a shape
// function returning tensorflow::Status.
//
// The _COMBINE_FIRST_DIMS, _IGNORE_FIRST_DIMS, _COMBINE_LAST_DIMS and
// _IGNORE_LAST_DIMS variants pass the corresponding
// open3d::ml::op_util::CSOpt option to CheckShape; see "../ShapeChecking.h"
// for what each option does. The option is fully qualified here, so the
// macros themselves do not require a using-directive for
// open3d::ml::op_util.
//
// Usage:
//   // ctx is of type tensorflow::shape_inference::InferenceContext*
//   {
//     using namespace open3d::ml::op_util;
//     Dim w("w");
//     Dim h("h");
//     CHECK_SHAPE_HANDLE(ctx, handle1, 10, w, h);  // checks that the first
//                                                  // dim is 10 and assigns
//                                                  // w and h from the shape
//                                                  // of handle1
//
//     CHECK_SHAPE_HANDLE(ctx, handle2, 10, 20, h); // checks that the last
//                                                  // dim of handle2 matches
//                                                  // the last dim of handle1
//                                                  // and that the first two
//                                                  // dims are 10 and 20
//   }
//
// See "../ShapeChecking.h" for more info and limitations.
//

// Internal helper shared by the CHECK_SHAPE_HANDLE* macros; not for direct use.
// `name_str` is the stringized shape handle expression (a string literal, so
// it concatenates with the surrounding literals). `check_call` is a
// parenthesized CheckShape(...) call returning std::tuple<bool, std::string>.
// The parentheses keep its commas from splitting the macro argument.
#define OPEN3D_TF_CHECK_SHAPE_HANDLE_IMPL_(name_str, check_call)             \
    do {                                                                   \
        bool cs_success_ = false;                                          \
        std::string cs_errstr_;                                            \
        std::tie(cs_success_, cs_errstr_) = check_call;                    \
        if (TF_PREDICT_FALSE(!cs_success_)) {                              \
            return tensorflow::errors::InvalidArgument(                    \
                    "invalid shape for '" name_str "', " + cs_errstr_);    \
        }                                                                  \
    } while (0)

#define CHECK_SHAPE_HANDLE(ctx, shape_handle, ...) \
    OPEN3D_TF_CHECK_SHAPE_HANDLE_IMPL_(            \
            #shape_handle, (CheckShape(ctx, shape_handle, __VA_ARGS__)))

#define CHECK_SHAPE_HANDLE_COMBINE_FIRST_DIMS(ctx, shape_handle, ...)        \
    OPEN3D_TF_CHECK_SHAPE_HANDLE_IMPL_(                                      \
            #shape_handle,                                                   \
            (CheckShape<open3d::ml::op_util::CSOpt::COMBINE_FIRST_DIMS>(     \
                    ctx, shape_handle, __VA_ARGS__)))

#define CHECK_SHAPE_HANDLE_IGNORE_FIRST_DIMS(ctx, shape_handle, ...)         \
    OPEN3D_TF_CHECK_SHAPE_HANDLE_IMPL_(                                      \
            #shape_handle,                                                   \
            (CheckShape<open3d::ml::op_util::CSOpt::IGNORE_FIRST_DIMS>(      \
                    ctx, shape_handle, __VA_ARGS__)))

#define CHECK_SHAPE_HANDLE_COMBINE_LAST_DIMS(ctx, shape_handle, ...)         \
    OPEN3D_TF_CHECK_SHAPE_HANDLE_IMPL_(                                      \
            #shape_handle,                                                   \
            (CheckShape<open3d::ml::op_util::CSOpt::COMBINE_LAST_DIMS>(      \
                    ctx, shape_handle, __VA_ARGS__)))

#define CHECK_SHAPE_HANDLE_IGNORE_LAST_DIMS(ctx, shape_handle, ...)          \
    OPEN3D_TF_CHECK_SHAPE_HANDLE_IMPL_(                                      \
            #shape_handle,                                                   \
            (CheckShape<open3d::ml::op_util::CSOpt::IGNORE_LAST_DIMS>(       \
                    ctx, shape_handle, __VA_ARGS__)))
224 
//
// Macros for checking the shape of Tensors inside an op kernel.
//
// Each macro calls CheckShape(tensor, <dim expressions>...) and passes the
// result to OP_REQUIRES(ctx, ...). On failure, the status
//   tensorflow::errors::InvalidArgument(
//       "invalid shape for '<tensor>', " + <error message>)
// is handed to OP_REQUIRES, which reports it through ctx and returns from the
// enclosing function.
//
// The _COMBINE_FIRST_DIMS, _IGNORE_FIRST_DIMS, _COMBINE_LAST_DIMS and
// _IGNORE_LAST_DIMS variants pass the corresponding
// open3d::ml::op_util::CSOpt option to CheckShape; see "../ShapeChecking.h"
// for what each option does. The option is fully qualified here, so the
// macros themselves do not require a using-directive for
// open3d::ml::op_util.
//
// Usage:
//   // ctx is of type tensorflow::OpKernelContext*
//   {
//     using namespace open3d::ml::op_util;
//     Dim w("w");
//     Dim h("h");
//     CHECK_SHAPE(ctx, tensor1, 10, w, h);  // checks that the first dim is
//                                           // 10 and assigns w and h from
//                                           // the shape of tensor1
//
//     CHECK_SHAPE(ctx, tensor2, 10, 20, h); // checks that the last dim of
//                                           // tensor2 matches the last dim
//                                           // of tensor1 and that the first
//                                           // two dims are 10 and 20
//   }
//
// See "../ShapeChecking.h" for more info and limitations.
//

// Internal helper shared by the CHECK_SHAPE* macros; not for direct use.
// `name_str` is the stringized tensor expression (a string literal, so it
// concatenates with the surrounding literals). `check_call` is a
// parenthesized CheckShape(...) call returning std::tuple<bool, std::string>.
// The parentheses keep its commas from splitting the macro argument.
#define OPEN3D_TF_CHECK_SHAPE_IMPL_(ctx, name_str, check_call)               \
    do {                                                                   \
        bool cs_success_ = false;                                          \
        std::string cs_errstr_;                                            \
        std::tie(cs_success_, cs_errstr_) = check_call;                    \
        OP_REQUIRES(ctx, cs_success_,                                      \
                    tensorflow::errors::InvalidArgument(                   \
                            "invalid shape for '" name_str "', " +         \
                            cs_errstr_));                                  \
    } while (0)

#define CHECK_SHAPE(ctx, tensor, ...) \
    OPEN3D_TF_CHECK_SHAPE_IMPL_(ctx, #tensor, (CheckShape(tensor, __VA_ARGS__)))

#define CHECK_SHAPE_COMBINE_FIRST_DIMS(ctx, tensor, ...)                     \
    OPEN3D_TF_CHECK_SHAPE_IMPL_(                                             \
            ctx, #tensor,                                                    \
            (CheckShape<open3d::ml::op_util::CSOpt::COMBINE_FIRST_DIMS>(     \
                    tensor, __VA_ARGS__)))

#define CHECK_SHAPE_IGNORE_FIRST_DIMS(ctx, tensor, ...)                      \
    OPEN3D_TF_CHECK_SHAPE_IMPL_(                                             \
            ctx, #tensor,                                                    \
            (CheckShape<open3d::ml::op_util::CSOpt::IGNORE_FIRST_DIMS>(      \
                    tensor, __VA_ARGS__)))

#define CHECK_SHAPE_COMBINE_LAST_DIMS(ctx, tensor, ...)                      \
    OPEN3D_TF_CHECK_SHAPE_IMPL_(                                             \
            ctx, #tensor,                                                    \
            (CheckShape<open3d::ml::op_util::CSOpt::COMBINE_LAST_DIMS>(      \
                    tensor, __VA_ARGS__)))

#define CHECK_SHAPE_IGNORE_LAST_DIMS(ctx, tensor, ...)                       \
    OPEN3D_TF_CHECK_SHAPE_IMPL_(                                             \
            ctx, #tensor,                                                    \
            (CheckShape<open3d::ml::op_util::CSOpt::IGNORE_LAST_DIMS>(       \
                    tensor, __VA_ARGS__)))
Class for representing a possibly unknown dimension value.
Definition: ShapeChecking.h:38
CSOpt
Check shape options.
Definition: ShapeChecking.h:424
void CreateDimVector(std::vector< int64_t > &out, int64_t unknown_dim_value, TDimX dimex)
Definition: ShapeChecking.h:377
std::vector< open3d::ml::op_util::DimValue > GetShapeVector(::tensorflow::shape_inference::InferenceContext *c, ::tensorflow::shape_inference::ShapeHandle shape_handle)
Definition: TensorFlowHelper.h:34
::tensorflow::shape_inference::ShapeHandle MakeShapeHandle(::tensorflow::shape_inference::InferenceContext *ctx, TDimX &&dimex, TArgs &&... args)
Definition: TensorFlowHelper.h:123
std::tuple< bool, std::string > CheckShape(const std::vector< DimValue > &shape, TDimX &&dimex, TArgs &&... args)
Definition: ShapeChecking.h:593
Definition: ShapeChecking.h:35