9 #include <tensorflow/core/framework/op_kernel.h>
10 #include <tensorflow/core/framework/shape_inference.h>
11 #include <tensorflow/core/framework/tensor.h>
12 #include <tensorflow/core/lib/core/errors.h>
17 ::tensorflow::shape_inference::InferenceContext* c,
18 ::tensorflow::shape_inference::ShapeHandle shape_handle) {
20 if (!c->RankKnown(shape_handle)) {
21 return std::vector<DimValue>();
24 std::vector<DimValue> shape;
25 const int rank = c->Rank(shape_handle);
26 for (
int i = 0; i < rank; ++i) {
27 auto d = c->DimKnownRank(shape_handle, i);
28 if (c->ValueKnown(d)) {
29 shape.push_back(c->Value(d));
41 ::tensorflow::shape_inference::InferenceContext* c,
42 ::tensorflow::shape_inference::ShapeHandle shape_handle,
45 if (!c->RankKnown(shape_handle)) {
47 return std::make_tuple(
true, std::string());
49 return open3d::ml::op_util::CheckShape<Opt>(
GetShapeVector(c, shape_handle),
50 std::forward<TDimX>(dimex),
51 std::forward<TArgs>(args)...);
55 const tensorflow::Tensor& tensor) {
58 std::vector<DimValue> shape;
59 for (
int i = 0; i < tensor.dims(); ++i) {
60 shape.push_back(tensor.dim_size(i));
68 std::tuple<bool, std::string>
CheckShape(
const tensorflow::Tensor& tensor,
71 return open3d::ml::op_util::CheckShape<Opt>(
GetShapeVector(tensor),
72 std::forward<TDimX>(dimex),
73 std::forward<TArgs>(args)...);
104 template <
class TDimX,
class... TArgs>
106 ::tensorflow::shape_inference::InferenceContext* ctx,
109 using namespace tensorflow::shape_inference;
112 int64_t(InferenceContext::kUnknownDim), dimex, args...);
113 std::vector<DimensionHandle> dims;
114 for (int64_t d : shape) {
115 dims.push_back(ctx->MakeDim(d));
117 return ctx->MakeShape(dims);
143 #define CHECK_SHAPE_HANDLE(ctx, shape_handle, ...) \
146 std::string cs_errstr_; \
147 std::tie(cs_success_, cs_errstr_) = \
148 CheckShape(ctx, shape_handle, __VA_ARGS__); \
149 if (TF_PREDICT_FALSE(!cs_success_)) { \
150 return tensorflow::errors::InvalidArgument( \
151 "invalid shape for '" #shape_handle "', " + cs_errstr_); \
155 #define CHECK_SHAPE_HANDLE_COMBINE_FIRST_DIMS(ctx, shape_handle, ...) \
158 std::string cs_errstr_; \
159 std::tie(cs_success_, cs_errstr_) = \
160 CheckShape<CSOpt::COMBINE_FIRST_DIMS>(ctx, shape_handle, \
162 if (TF_PREDICT_FALSE(!cs_success_)) { \
163 return tensorflow::errors::InvalidArgument( \
164 "invalid shape for '" #shape_handle "', " + cs_errstr_); \
168 #define CHECK_SHAPE_HANDLE_IGNORE_FIRST_DIMS(ctx, shape_handle, ...) \
171 std::string cs_errstr_; \
172 std::tie(cs_success_, cs_errstr_) = \
173 CheckShape<CSOpt::IGNORE_FIRST_DIMS>(ctx, shape_handle, \
175 if (TF_PREDICT_FALSE(!cs_success_)) { \
176 return tensorflow::errors::InvalidArgument( \
177 "invalid shape for '" #shape_handle "', " + cs_errstr_); \
181 #define CHECK_SHAPE_HANDLE_COMBINE_LAST_DIMS(ctx, shape_handle, ...) \
184 std::string cs_errstr_; \
185 std::tie(cs_success_, cs_errstr_) = \
186 CheckShape<CSOpt::COMBINE_LAST_DIMS>(ctx, shape_handle, \
188 if (TF_PREDICT_FALSE(!cs_success_)) { \
189 return tensorflow::errors::InvalidArgument( \
190 "invalid shape for '" #shape_handle "', " + cs_errstr_); \
194 #define CHECK_SHAPE_HANDLE_IGNORE_LAST_DIMS(ctx, shape_handle, ...) \
197 std::string cs_errstr_; \
198 std::tie(cs_success_, cs_errstr_) = \
199 CheckShape<CSOpt::IGNORE_LAST_DIMS>(ctx, shape_handle, \
201 if (TF_PREDICT_FALSE(!cs_success_)) { \
202 return tensorflow::errors::InvalidArgument( \
203 "invalid shape for '" #shape_handle "', " + cs_errstr_); \
228 #define CHECK_SHAPE(ctx, tensor, ...) \
231 std::string cs_errstr_; \
232 std::tie(cs_success_, cs_errstr_) = CheckShape(tensor, __VA_ARGS__); \
235 tensorflow::errors::InvalidArgument( \
236 "invalid shape for '" #tensor "', " + cs_errstr_)); \
239 #define CHECK_SHAPE_COMBINE_FIRST_DIMS(ctx, tensor, ...) \
242 std::string cs_errstr_; \
243 std::tie(cs_success_, cs_errstr_) = \
244 CheckShape<CSOpt::COMBINE_FIRST_DIMS>(tensor, __VA_ARGS__); \
247 tensorflow::errors::InvalidArgument( \
248 "invalid shape for '" #tensor "', " + cs_errstr_)); \
251 #define CHECK_SHAPE_IGNORE_FIRST_DIMS(ctx, tensor, ...) \
254 std::string cs_errstr_; \
255 std::tie(cs_success_, cs_errstr_) = \
256 CheckShape<CSOpt::IGNORE_FIRST_DIMS>(tensor, __VA_ARGS__); \
259 tensorflow::errors::InvalidArgument( \
260 "invalid shape for '" #tensor "', " + cs_errstr_)); \
263 #define CHECK_SHAPE_COMBINE_LAST_DIMS(ctx, tensor, ...) \
266 std::string cs_errstr_; \
267 std::tie(cs_success_, cs_errstr_) = \
268 CheckShape<CSOpt::COMBINE_LAST_DIMS>(tensor, __VA_ARGS__); \
271 tensorflow::errors::InvalidArgument( \
272 "invalid shape for '" #tensor "', " + cs_errstr_)); \
275 #define CHECK_SHAPE_IGNORE_LAST_DIMS(ctx, tensor, ...) \
278 std::string cs_errstr_; \
279 std::tie(cs_success_, cs_errstr_) = \
280 CheckShape<CSOpt::IGNORE_LAST_DIMS>(tensor, __VA_ARGS__); \
283 tensorflow::errors::InvalidArgument( \
284 "invalid shape for '" #tensor "', " + cs_errstr_)); \
std::tuple< bool, std::string > CheckShape(::tensorflow::shape_inference::InferenceContext *c, ::tensorflow::shape_inference::ShapeHandle shape_handle, TDimX &&dimex, TArgs &&... args)
Definition: TensorFlowHelper.h:40
::tensorflow::shape_inference::ShapeHandle MakeShapeHandle(::tensorflow::shape_inference::InferenceContext *ctx, TDimX &&dimex, TArgs &&... args)
Definition: TensorFlowHelper.h:105
std::vector< open3d::ml::op_util::DimValue > GetShapeVector(::tensorflow::shape_inference::InferenceContext *c, ::tensorflow::shape_inference::ShapeHandle shape_handle)
Definition: TensorFlowHelper.h:16
Class for representing a possibly unknown dimension value.
Definition: ShapeChecking.h:19
Definition: ShapeChecking.h:16
CSOpt
Check shape options.
Definition: ShapeChecking.h:405
void CreateDimVector(std::vector< int64_t > &out, int64_t unknown_dim_value, TDimX dimex)
Definition: ShapeChecking.h:358