29 #include "tensorflow/core/framework/op_kernel.h" 30 #include "tensorflow/core/framework/shape_inference.h" 31 #include "tensorflow/core/framework/tensor.h" 32 #include "tensorflow/core/lib/core/errors.h" 35 ::tensorflow::shape_inference::InferenceContext* c,
36 ::tensorflow::shape_inference::ShapeHandle shape_handle) {
38 if (!c->RankKnown(shape_handle)) {
39 return std::vector<DimValue>();
42 std::vector<DimValue> shape;
43 const int rank = c->Rank(shape_handle);
44 for (
int i = 0; i < rank; ++i) {
45 auto d = c->DimKnownRank(shape_handle, i);
46 if (c->ValueKnown(d)) {
47 shape.push_back(c->Value(d));
59 ::tensorflow::shape_inference::InferenceContext* c,
60 ::tensorflow::shape_inference::ShapeHandle shape_handle,
63 if (!c->RankKnown(shape_handle)) {
65 return std::make_tuple(
true, std::string());
67 return open3d::ml::op_util::CheckShape<Opt>(
GetShapeVector(c, shape_handle),
68 std::forward<TDimX>(dimex),
69 std::forward<TArgs>(args)...);
73 const tensorflow::Tensor& tensor) {
76 std::vector<DimValue> shape;
77 for (
int i = 0; i < tensor.dims(); ++i) {
78 shape.push_back(tensor.dim_size(i));
86 std::tuple<bool, std::string>
CheckShape(
const tensorflow::Tensor& tensor,
89 return open3d::ml::op_util::CheckShape<Opt>(
GetShapeVector(tensor),
90 std::forward<TDimX>(dimex),
91 std::forward<TArgs>(args)...);
122 template <
class TDimX,
class... TArgs>
124 ::tensorflow::shape_inference::InferenceContext* ctx,
127 using namespace tensorflow::shape_inference;
130 int64_t(InferenceContext::kUnknownDim), dimex, args...);
131 std::vector<DimensionHandle> dims;
132 for (int64_t d : shape) {
133 dims.push_back(ctx->MakeDim(d));
135 return ctx->MakeShape(dims);
161 #define CHECK_SHAPE_HANDLE(ctx, shape_handle, ...) \ 164 std::string cs_errstr_; \ 165 std::tie(cs_success_, cs_errstr_) = \ 166 CheckShape(ctx, shape_handle, __VA_ARGS__); \ 167 if (TF_PREDICT_FALSE(!cs_success_)) { \ 168 return tensorflow::errors::InvalidArgument( \ 169 "invalid shape for '" #shape_handle "', " + cs_errstr_); \ 173 #define CHECK_SHAPE_HANDLE_COMBINE_FIRST_DIMS(ctx, shape_handle, ...) \ 176 std::string cs_errstr_; \ 177 std::tie(cs_success_, cs_errstr_) = \ 178 CheckShape<CSOpt::COMBINE_FIRST_DIMS>(ctx, shape_handle, \ 180 if (TF_PREDICT_FALSE(!cs_success_)) { \ 181 return tensorflow::errors::InvalidArgument( \ 182 "invalid shape for '" #shape_handle "', " + cs_errstr_); \ 186 #define CHECK_SHAPE_HANDLE_IGNORE_FIRST_DIMS(ctx, shape_handle, ...) \ 189 std::string cs_errstr_; \ 190 std::tie(cs_success_, cs_errstr_) = \ 191 CheckShape<CSOpt::IGNORE_FIRST_DIMS>(ctx, shape_handle, \ 193 if (TF_PREDICT_FALSE(!cs_success_)) { \ 194 return tensorflow::errors::InvalidArgument( \ 195 "invalid shape for '" #shape_handle "', " + cs_errstr_); \ 199 #define CHECK_SHAPE_HANDLE_COMBINE_LAST_DIMS(ctx, shape_handle, ...) \ 202 std::string cs_errstr_; \ 203 std::tie(cs_success_, cs_errstr_) = \ 204 CheckShape<CSOpt::COMBINE_LAST_DIMS>(ctx, shape_handle, \ 206 if (TF_PREDICT_FALSE(!cs_success_)) { \ 207 return tensorflow::errors::InvalidArgument( \ 208 "invalid shape for '" #shape_handle "', " + cs_errstr_); \ 212 #define CHECK_SHAPE_HANDLE_IGNORE_LAST_DIMS(ctx, shape_handle, ...) \ 215 std::string cs_errstr_; \ 216 std::tie(cs_success_, cs_errstr_) = \ 217 CheckShape<CSOpt::IGNORE_LAST_DIMS>(ctx, shape_handle, \ 219 if (TF_PREDICT_FALSE(!cs_success_)) { \ 220 return tensorflow::errors::InvalidArgument( \ 221 "invalid shape for '" #shape_handle "', " + cs_errstr_); \ 246 #define CHECK_SHAPE(ctx, tensor, ...) \ 249 std::string cs_errstr_; \ 250 std::tie(cs_success_, cs_errstr_) = CheckShape(tensor, __VA_ARGS__); \ 253 tensorflow::errors::InvalidArgument( \ 254 "invalid shape for '" #tensor "', " + cs_errstr_)); \ 257 #define CHECK_SHAPE_COMBINE_FIRST_DIMS(ctx, tensor, ...) \ 260 std::string cs_errstr_; \ 261 std::tie(cs_success_, cs_errstr_) = \ 262 CheckShape<CSOpt::COMBINE_FIRST_DIMS>(tensor, __VA_ARGS__); \ 265 tensorflow::errors::InvalidArgument( \ 266 "invalid shape for '" #tensor "', " + cs_errstr_)); \ 269 #define CHECK_SHAPE_IGNORE_FIRST_DIMS(ctx, tensor, ...) \ 272 std::string cs_errstr_; \ 273 std::tie(cs_success_, cs_errstr_) = \ 274 CheckShape<CSOpt::IGNORE_FIRST_DIMS>(tensor, __VA_ARGS__); \ 277 tensorflow::errors::InvalidArgument( \ 278 "invalid shape for '" #tensor "', " + cs_errstr_)); \ 281 #define CHECK_SHAPE_COMBINE_LAST_DIMS(ctx, tensor, ...) \ 284 std::string cs_errstr_; \ 285 std::tie(cs_success_, cs_errstr_) = \ 286 CheckShape<CSOpt::COMBINE_LAST_DIMS>(tensor, __VA_ARGS__); \ 289 tensorflow::errors::InvalidArgument( \ 290 "invalid shape for '" #tensor "', " + cs_errstr_)); \ 293 #define CHECK_SHAPE_IGNORE_LAST_DIMS(ctx, tensor, ...) \ 296 std::string cs_errstr_; \ 297 std::tie(cs_success_, cs_errstr_) = \ 298 CheckShape<CSOpt::IGNORE_LAST_DIMS>(tensor, __VA_ARGS__); \ 301 tensorflow::errors::InvalidArgument( \ 302 "invalid shape for '" #tensor "', " + cs_errstr_)); \ Class for representing a possibly unknown dimension value.
Definition: ShapeChecking.h:38
CSOpt
Check shape options.
Definition: ShapeChecking.h:424
void CreateDimVector(std::vector< int64_t > &out, int64_t unknown_dim_value, TDimX dimex)
Definition: ShapeChecking.h:377
std::vector< open3d::ml::op_util::DimValue > GetShapeVector(::tensorflow::shape_inference::InferenceContext *c, ::tensorflow::shape_inference::ShapeHandle shape_handle)
Definition: TensorFlowHelper.h:34
::tensorflow::shape_inference::ShapeHandle MakeShapeHandle(::tensorflow::shape_inference::InferenceContext *ctx, TDimX &&dimex, TArgs &&... args)
Definition: TensorFlowHelper.h:123
std::tuple< bool, std::string > CheckShape(const std::vector< DimValue > &shape, TDimX &&dimex, TArgs &&... args)
Definition: ShapeChecking.h:593
Definition: ShapeChecking.h:35