Open3D (C++ API)  0.17.0
ShapeChecking.h
Go to the documentation of this file.
1 // ----------------------------------------------------------------------------
2 // - Open3D: www.open3d.org -
3 // ----------------------------------------------------------------------------
4 // Copyright (c) 2018-2023 www.open3d.org
5 // SPDX-License-Identifier: MIT
6 // ----------------------------------------------------------------------------
7 
8 #pragma once
9 #include <iostream>
10 #include <string>
11 #include <tuple>
12 #include <vector>
13 
14 namespace open3d {
15 namespace ml {
16 namespace op_util {
17 
/// Class for representing a possibly unknown dimension value.
///
/// A default-constructed DimValue is unknown (not constant); a DimValue built
/// from an integer is constant and stores that integer.
class DimValue {
public:
    /// Creates an unknown value.
    DimValue() : value_(0), constant_(false) {}

    /// Creates a known (constant) value \p v.
    DimValue(int64_t v) : value_(v), constant_(true) {}

    /// Multiplies this value by \p b in place.
    /// The product is only computed when both operands are constant; in every
    /// other case the result is marked as unknown.
    DimValue& operator*=(const DimValue& b) {
        if (constant_ && b.constant_)
            value_ *= b.value_;
        else
            constant_ = false;
        return *this;
    }

    /// Returns the value as decimal string or "?" if the value is unknown.
    std::string ToString() const {
        if (constant_)
            return std::to_string(value_);
        else
            return "?";
    }

    /// Returns a reference to the stored value.
    /// \throws std::runtime_error if the value is not constant.
    int64_t& value() {
        if (!constant_) throw std::runtime_error("DimValue is not constant");
        return value_;
    }

    /// Returns a reference to the flag telling whether the value is known.
    bool& constant() { return constant_; }

private:
    int64_t value_;
    bool constant_;
};
46 
47 inline DimValue UnknownValue() { return DimValue(); }
48 
/// Class for dimensions for which the value should be inferred.
///
/// A Dim created without a value is unknown and registers itself as its own
/// origin. Copies keep pointing to that origin, so assigning a value through
/// any copy is visible through all of them. The origin object therefore has to
/// outlive its copies. A Dim created from an integer is constant and has no
/// origin.
class Dim {
public:
    /// Creates an unnamed, unknown dimension.
    explicit Dim() : value_(0), constant_(false), origin_(this) {}

    /// Creates an unknown dimension called \p name.
    explicit Dim(const std::string& name)
        : value_(0), constant_(false), origin_(this), name_(name) {}

    /// Creates a constant dimension with the given \p value and optional
    /// \p name.
    Dim(int64_t value, const std::string& name = "")
        : value_(value), constant_(true), origin_(nullptr), name_(name) {}

    /// Member-wise copy; the copy shares the origin of \p other.
    Dim(const Dim& other) = default;

    ~Dim() = default;

    Dim& operator=(const Dim&) = delete;

    /// Returns the value stored in the origin if there is one, otherwise the
    /// value stored in this object.
    int64_t& value() { return origin_ ? origin_->value_ : value_; }

    /// Returns the constant flag of the origin if there is one, otherwise the
    /// flag of this object.
    bool& constant() { return origin_ ? origin_->constant_ : constant_; }

    /// Assigns \p a if the dimension is still unknown.
    /// \return true if the dimension now equals \p a, false if it already had
    /// a different value.
    bool assign(int64_t a) {
        bool& known = constant();
        int64_t& stored = value();
        if (!known) {
            stored = a;
            known = true;
        }
        return stored == a;
    }

    /// Returns a string for the dimension: "name(value)", "name" when
    /// \p show_value is false, or just the value for unnamed dimensions.
    /// Unknown values are printed as "?".
    std::string ToString(bool show_value = true) {
        const std::string shown = constant() ? std::to_string(value()) : "?";
        if (name_.empty()) return shown;
        if (!show_value) return name_;
        return name_ + "(" + shown + ")";
    }

private:
    int64_t value_;
    bool constant_;
    Dim* origin_;
    std::string name_;
};
114 
115 //
116 // Dim expression operator classes
117 //
118 
/// Operator class for the sum of two dim expressions.
struct DimXPlus {
    /// The result is constant whenever both operands are constant.
    static bool constant() { return true; }

    static int64_t apply(int64_t a, int64_t b) { return a + b; }

    /// Infers the unknown operand from the result \p ans of a + b.
    /// \return the result of assigning the inferred value to that operand.
    /// \throws std::runtime_error if both operands are unknown.
    template <class T1, class T2>
    static bool backprop(int64_t ans, T1 a, T2 b) {
        const bool a_known = a.constant();
        const bool b_known = b.constant();
        if (!a_known && !b_known) {
            // one equation cannot determine two unknowns
            std::string exstr =
                    GetString(a, false) + ToString() + GetString(b, false);
            throw std::runtime_error("Illegal dim expression: " + exstr);
        }
        if (!a_known) {
            return a.assign(ans - b.value());
        }
        return b.assign(ans - a.value());
    }

    static std::string ToString() { return "+"; }
};
139 
/// Operator class for the difference of two dim expressions.
struct DimXMinus {
    /// The result is constant whenever both operands are constant.
    static bool constant() { return true; }

    static int64_t apply(int64_t a, int64_t b) { return a - b; }

    /// Infers the unknown operand from the result \p ans of a - b.
    /// \return the result of assigning the inferred value to that operand.
    /// \throws std::runtime_error if both operands are unknown.
    template <class T1, class T2>
    static bool backprop(int64_t ans, T1 a, T2 b) {
        const bool a_known = a.constant();
        const bool b_known = b.constant();
        if (!a_known && !b_known) {
            // one equation cannot determine two unknowns
            std::string exstr =
                    GetString(a, false) + ToString() + GetString(b, false);
            throw std::runtime_error("Illegal dim expression: " + exstr);
        }
        if (!a_known) {
            return a.assign(ans + b.value());
        }
        return b.assign(a.value() - ans);
    }

    static std::string ToString() { return "-"; }
};
160 
/// Operator class for the product of two dim expressions.
struct DimXMultiply {
    /// The result is constant whenever both operands are constant.
    static bool constant() { return true; }

    static int64_t apply(int64_t a, int64_t b) { return a * b; }

    /// Inferring an operand of a product is not supported.
    /// \throws std::runtime_error always.
    template <class T1, class T2>
    static bool backprop(int64_t ans, T1 a, T2 b) {
        const std::string lhs = GetString(a, false);
        const std::string rhs = GetString(b, false);
        throw std::runtime_error("Illegal dim expression: " + lhs +
                                 ToString() + rhs);
    }

    static std::string ToString() { return "*"; }
};
175 
/// Operator class for the integer quotient of two dim expressions.
struct DimXDivide {
    /// The result is constant whenever both operands are constant.
    static bool constant() { return true; }

    /// Returns the integer quotient a / b.
    /// \throws std::runtime_error if \p b is zero. Integer division by zero
    /// is undefined behaviour, so it is reported like the other expression
    /// errors in this file instead.
    static int64_t apply(int64_t a, int64_t b) {
        if (b == 0) {
            throw std::runtime_error("Division by zero in dim expression");
        }
        return a / b;
    }

    /// Inferring an operand of a quotient is not supported.
    /// \throws std::runtime_error always.
    template <class T1, class T2>
    static bool backprop(int64_t ans, T1 a, T2 b) {
        std::string exstr =
                GetString(a, false) + ToString() + GetString(b, false);
        throw std::runtime_error("Illegal dim expression: " + exstr);
    }

    static std::string ToString() { return "/"; }
};
190 
/// Operator class for alternatives: the value has to match the left or the
/// right operand.
struct DimXOr {
    /// An OR expression is never treated as constant.
    static bool constant() { return false; }

    /// An OR expression has no single value.
    /// \throws std::runtime_error always.
    static int64_t apply(int64_t a, int64_t b) {
        (void)a;
        (void)b;
        throw std::runtime_error("Cannot evaluate OR expression");
    }

    /// Tries to assign \p ans to the left operand and, only if that fails, to
    /// the right operand.
    template <class T1, class T2>
    static bool backprop(int64_t ans, T1 a, T2 b) {
        if (a.assign(ans)) {
            return true;
        }
        return b.assign(ans);
    }

    static std::string ToString() { return "||"; }
};
204 
/// Dim expression class.
///
/// Combines a left and a right operand (Dim or nested DimX) with the operator
/// class \p TOp. Operands are stored by value.
template <class TLeft, class TRight, class TOp>
class DimX {
public:
    /// Builds the expression node for \p left and \p right.
    static DimX<TLeft, TRight, TOp> Create(TLeft left, TRight right) {
        return DimX(left, right);
    }

    /// Evaluates the expression. Returns 0 if the expression was not constant
    /// when it was created.
    int64_t value() {
        return constant_ ? TOp::apply(left_.value(), right_.value()) : 0;
    }

    /// Returns the constant flag computed at construction time.
    bool& constant() { return constant_; }

    /// assigns a value to the expression
    /// A constant expression is only compared against \p a; otherwise the
    /// operator class tries to infer the unknown operand.
    bool assign(int64_t a) {
        if (!constant_) {
            return TOp::backprop(a, left_, right_);
        }
        return value() == a;
    }

    /// Returns "<left><op><right>" using the operands' ToString().
    std::string ToString(bool show_value = true) {
        std::string text = left_.ToString(show_value);
        text += TOp::ToString();
        text += right_.ToString(show_value);
        return text;
    }

private:
    DimX(TLeft left, TRight right)
        : left_(left),
          right_(right),
          constant_(left.constant() && right.constant() && TOp::constant()) {}

    TLeft left_;
    TRight right_;
    bool constant_;
};
244 
245 //
246 // define operators for dim expressions
247 //
248 
249 #define DEFINE_DIMX_OPERATOR(opclass, symbol) \
250  inline DimX<Dim, Dim, opclass> operator symbol(Dim a, Dim b) { \
251  return DimX<Dim, Dim, opclass>::Create(a, b); \
252  } \
253  \
254  template <class TL, class TR, class TOp> \
255  inline DimX<Dim, DimX<TL, TR, TOp>, opclass> operator symbol( \
256  Dim a, DimX<TL, TR, TOp>&& b) { \
257  return DimX<Dim, DimX<TL, TR, TOp>, opclass>::Create(a, b); \
258  } \
259  \
260  template <class TL, class TR, class TOp> \
261  inline DimX<DimX<TL, TR, TOp>, Dim, opclass> operator symbol( \
262  DimX<TL, TR, TOp>&& a, Dim b) { \
263  return DimX<DimX<TL, TR, TOp>, Dim, opclass>::Create(a, b); \
264  } \
265  \
266  template <class TL1, class TR1, class TOp1, class TL2, class TR2, \
267  class TOp2> \
268  inline DimX<DimX<TL1, TR1, TOp1>, DimX<TL2, TR2, TOp2>, opclass> \
269  operator symbol(DimX<TL1, TR1, TOp1>&& a, DimX<TL2, TR2, TOp2>&& b) { \
270  return DimX<DimX<TL1, TR1, TOp1>, DimX<TL2, TR2, TOp2>, \
271  opclass>::Create(a, b); \
272  }
273 
274 DEFINE_DIMX_OPERATOR(DimXPlus, +)
275 DEFINE_DIMX_OPERATOR(DimXMinus, -)
276 DEFINE_DIMX_OPERATOR(DimXMultiply, *)
277 DEFINE_DIMX_OPERATOR(DimXDivide, /)
278 DEFINE_DIMX_OPERATOR(DimXOr, ||)
279 #undef DEFINE_DIMX_OPERATOR
280 
281 //
282 // define operators for comparing DimValue to dim expressions.
283 // Using these operators will try to assign the dim value to the expression.
284 //
285 
286 template <class TLeft, class TRight, class TOp>
288  if (a.constant()) {
289  auto b_copy(b);
290  return b_copy.assign(a.value());
291  } else
292  return true;
293 }
294 
295 inline bool operator==(DimValue a, Dim b) {
296  if (a.constant())
297  return b.assign(a.value());
298  else
299  return true;
300 }
301 
302 //
303 // some helper classes
304 //
305 
/// Helper that exposes the number of template arguments as `value`.
template <class... args>
struct CountArgs {
    static const size_t value = sizeof...(args);
};
310 
311 template <class TLeft, class TRight, class TOp>
312 std::string GetString(DimX<TLeft, TRight, TOp> a, bool show_value = true) {
313  return a.ToString(show_value);
314 }
315 
316 inline std::string GetString(Dim a, bool show_value = true) {
317  return a.ToString(show_value);
318 }
319 
320 template <class TLeft, class TRight, class TOp>
322  return a.value();
323 }
324 
325 template <class TLeft, class TRight, class TOp>
326 int64_t GetValue(DimX<TLeft, TRight, TOp> a, int64_t unknown_dim_value) {
327  if (a.constant()) {
328  return a.value();
329  } else {
330  return unknown_dim_value;
331  }
332  return a.value();
333 }
334 
335 inline int64_t GetValue(Dim a) { return a.value(); }
336 
337 inline int64_t GetValue(Dim a, int64_t unknown_dim_value) {
338  if (a.constant()) {
339  return a.value();
340  } else {
341  return unknown_dim_value;
342  }
343 }
344 
/// Overload for an empty argument list; returns an empty string.
inline std::string CreateDimXString() {
    return std::string{};
}
346 
/// Returns the string for a single dim expression, as produced by GetString().
template <class TDimX>
std::string CreateDimXString(TDimX dimex) {
    std::string text = GetString(dimex);
    return text;
}
351 
/// Returns the strings of all given dim expressions joined with ", ".
template <class TDimX, class... TArgs>
std::string CreateDimXString(TDimX dimex, TArgs... args) {
    std::string joined = GetString(dimex);
    joined += ", ";
    joined += CreateDimXString(args...);
    return joined;
}
356 
/// Appends the value of \p dimex to \p out. \p unknown_dim_value is appended
/// instead if the expression is not constant.
template <class TDimX>
void CreateDimVector(std::vector<int64_t>& out,
                     int64_t unknown_dim_value,
                     TDimX dimex) {
    const int64_t dim = GetValue(dimex, unknown_dim_value);
    out.push_back(dim);
}
363 
/// Appends the values of \p dimex and all following expressions to \p out, in
/// argument order. \p unknown_dim_value is used for non-constant expressions.
template <class TDimX, class... TArgs>
void CreateDimVector(std::vector<int64_t>& out,
                     int64_t unknown_dim_value,
                     TDimX dimex,
                     TArgs... args) {
    const int64_t dim = GetValue(dimex, unknown_dim_value);
    out.push_back(dim);
    // the remaining expressions are handled by the next overload
    CreateDimVector(out, unknown_dim_value, args...);
}
372 
/// Returns a vector holding the value of \p dimex, or \p unknown_dim_value if
/// the expression is not constant.
template <class TDimX>
std::vector<int64_t> CreateDimVector(int64_t unknown_dim_value, TDimX dimex) {
    std::vector<int64_t> dims;
    dims.push_back(GetValue(dimex, unknown_dim_value));
    return dims;
}
379 
/// Returns a vector with the values of all given dim expressions, in argument
/// order. \p unknown_dim_value is used for non-constant expressions.
template <class TDimX, class... TArgs>
std::vector<int64_t> CreateDimVector(int64_t unknown_dim_value,
                                     TDimX dimex,
                                     TArgs... args) {
    std::vector<int64_t> dims;
    CreateDimVector(dims, unknown_dim_value, dimex, args...);
    return dims;
}
388 
//
// functions which check if the dim value is compatible with the expression
//
392 
393 template <class TLeft, class TRight, class TOp>
394 bool CheckDim(const DimValue& lhs, DimX<TLeft, TRight, TOp>&& rhs) {
395  bool status = (lhs == std::forward<DimX<TLeft, TRight, TOp>>(rhs));
396  return status;
397 }
398 
399 inline bool CheckDim(const DimValue& lhs, Dim d) {
400  bool status = lhs == d;
401  return status;
402 }
403 
/// Check shape options
enum class CSOpt {
    /// The rank of the shape has to match the number of dim expressions.
    NONE = 0,
    /// Surplus leading dimensions are multiplied into the first expression.
    COMBINE_FIRST_DIMS,
    /// Surplus leading dimensions are skipped.
    IGNORE_FIRST_DIMS,
    /// Surplus trailing dimensions are multiplied into the last expression.
    COMBINE_LAST_DIMS,
    /// Surplus trailing dimensions are skipped.
    IGNORE_LAST_DIMS,
};
412 
413 template <CSOpt Opt = CSOpt::NONE, class TDimX>
414 bool _CheckShape(const std::vector<DimValue>& shape, TDimX&& dimex) {
415  // check rank
416  const int rank_diff = shape.size() - 1;
417  if (Opt != CSOpt::NONE) {
418  if (rank_diff < 0) {
419  return false;
420  }
421  } else {
422  if (rank_diff != 0) {
423  return false;
424  }
425  }
426 
427  // check dim
428  bool status;
429  if (Opt == CSOpt::COMBINE_FIRST_DIMS) {
430  DimValue s(1);
431  for (int i = 0; i < rank_diff + 1; ++i) s *= shape[i];
432  status = CheckDim(s, std::forward<TDimX>(dimex));
433  } else if (Opt == CSOpt::IGNORE_FIRST_DIMS) {
434  status = CheckDim(shape[rank_diff], std::forward<TDimX>(dimex));
435  } else if (Opt == CSOpt::COMBINE_LAST_DIMS) {
436  DimValue s(1);
437  for (DimValue x : shape) s *= x;
438  status = CheckDim(s, std::forward<TDimX>(dimex));
439  } else {
440  status = CheckDim(shape[0], std::forward<TDimX>(dimex));
441  }
442  return status;
443 }
444 
445 template <CSOpt Opt = CSOpt::NONE, class TDimX, class... TArgs>
446 bool _CheckShape(const std::vector<DimValue>& shape,
447  TDimX&& dimex,
448  TArgs&&... args) {
449  // check rank
450  const int rank_diff = shape.size() - (CountArgs<TArgs...>::value + 1);
451  if (Opt != CSOpt::NONE) {
452  if (rank_diff < 0) {
453  return false;
454  }
455  } else {
456  if (rank_diff != 0) {
457  return false;
458  }
459  }
460 
461  // check dim
462  bool status;
463  if (Opt == CSOpt::COMBINE_FIRST_DIMS) {
464  DimValue s(1);
465  for (int i = 0; i < rank_diff + 1; ++i) s *= shape[i];
466  status = CheckDim(s, std::forward<TDimX>(dimex));
467  } else if (Opt == CSOpt::IGNORE_FIRST_DIMS) {
468  status = CheckDim(shape[rank_diff], std::forward<TDimX>(dimex));
469  } else {
470  status = CheckDim(shape[0], std::forward<TDimX>(dimex));
471  }
472 
473  const int offset = 1 + (Opt == CSOpt::COMBINE_FIRST_DIMS ||
475  ? rank_diff
476  : 0);
477  std::vector<DimValue> shape2(shape.begin() + offset, shape.end());
478  bool status2 = _CheckShape<Opt>(shape2, std::forward<TArgs>(args)...);
479 
480  return status && status2;
481 }
482 
573 template <CSOpt Opt = CSOpt::NONE, class TDimX, class... TArgs>
574 std::tuple<bool, std::string> CheckShape(const std::vector<DimValue>& shape,
575  TDimX&& dimex,
576  TArgs&&... args) {
577  const bool status = _CheckShape<Opt>(shape, std::forward<TDimX>(dimex),
578  std::forward<TArgs>(args)...);
579  if (status) {
580  return std::make_tuple(status, std::string());
581  } else {
582  const int rank_diff = shape.size() - (CountArgs<TArgs...>::value + 1);
583 
584  // generate string for the actual shape. This is a bit involved because
585  // of the many options.
586  std::string shape_str;
587  if (rank_diff <= 0) {
588  shape_str = "[";
589  for (int i = 0; i < int(shape.size()); ++i) {
590  shape_str += shape[i].ToString();
591  if (i + 1 < int(shape.size())) shape_str += ", ";
592  }
593  shape_str += "]";
594  } else {
595  if (Opt == CSOpt::COMBINE_FIRST_DIMS) {
596  shape_str += "[";
597  for (int i = 0; i < rank_diff; ++i) {
598  shape_str += shape[i].ToString();
599  if (i + 1 < int(shape.size())) shape_str += "*";
600  }
601  } else if (Opt == CSOpt::IGNORE_FIRST_DIMS) {
602  shape_str += "(";
603  for (int i = 0; i < rank_diff; ++i) {
604  shape_str += shape[i].ToString();
605  if (i + 1 < rank_diff) shape_str += ", ";
606  }
607  shape_str += ")[";
608  } else {
609  shape_str = "[";
610  }
611  int start = 0;
612  if (Opt == CSOpt::COMBINE_FIRST_DIMS ||
613  Opt == CSOpt::IGNORE_FIRST_DIMS) {
614  start = rank_diff;
615  }
616 
617  int end = shape.size();
618  if (Opt == CSOpt::COMBINE_LAST_DIMS) {
619  end -= rank_diff + 1;
620  } else if (Opt == CSOpt::IGNORE_LAST_DIMS) {
621  end -= rank_diff;
622  }
623  for (int i = start; i < end; ++i) {
624  shape_str += shape[i].ToString();
625  if (i + 1 < end) shape_str += ", ";
626  }
627  if (Opt == CSOpt::COMBINE_LAST_DIMS) {
628  shape_str += ", ";
629  for (int i = std::max<int>(0, shape.size() - rank_diff - 1);
630  i < int(shape.size()); ++i) {
631  shape_str += shape[i].ToString();
632  if (i + 1 < int(shape.size())) shape_str += "*";
633  }
634  shape_str += "]";
635  } else if (Opt == CSOpt::IGNORE_LAST_DIMS) {
636  shape_str += "](";
637  for (int i = std::max<int>(0, shape.size() - rank_diff);
638  i < int(shape.size()); ++i) {
639  shape_str += shape[i].ToString();
640  if (i + 1 < int(shape.size())) shape_str += ", ";
641  }
642  shape_str += ")";
643  } else {
644  shape_str += "]";
645  }
646  }
647 
648  // generate string for the expected shape with the dim expressions
649  std::string expected_shape;
650  if ((CountArgs<TArgs...>::value + 1) == 1) {
651  expected_shape = "[" + GetString(dimex) + "]";
652 
653  } else {
654  expected_shape = "[" + GetString(dimex) + ", " +
655  CreateDimXString(args...) + "]";
656  }
657 
658  std::string errstr;
659  // print rank information if there is a problem with the rank
660  if ((Opt != CSOpt::NONE && rank_diff < 0) ||
661  (Opt == CSOpt::NONE && rank_diff != 0)) {
662  errstr = "got rank " + std::to_string(shape.size()) + " " +
663  shape_str + ", expected rank " +
664  std::to_string(CountArgs<TArgs...>::value + 1) + " " +
665  expected_shape;
666  } else { // rank is OK print just the shapes
667  errstr = "got " + shape_str + ", expected " + expected_shape;
668  }
669  return std::make_tuple(status, errstr);
670  }
671 }
672 
673 } // namespace op_util
674 } // namespace ml
675 } // namespace open3d
#define DEFINE_DIMX_OPERATOR(opclass, symbol)
Definition: ShapeChecking.h:249
Class for dimensions for which the value should be inferred.
Definition: ShapeChecking.h:50
bool assign(int64_t a)
Definition: ShapeChecking.h:86
Dim(const Dim &other)
Definition: ShapeChecking.h:60
Dim()
Definition: ShapeChecking.h:52
Dim(const std::string &name)
Definition: ShapeChecking.h:54
bool & constant()
Definition: ShapeChecking.h:77
~Dim()
Definition: ShapeChecking.h:66
int64_t & value()
Definition: ShapeChecking.h:70
Dim & operator=(const Dim &)=delete
std::string ToString(bool show_value=true)
Definition: ShapeChecking.h:94
Dim(int64_t value, const std::string &name="")
Definition: ShapeChecking.h:57
Class for representing a possibly unknown dimension value.
Definition: ShapeChecking.h:19
DimValue(int64_t v)
Definition: ShapeChecking.h:22
DimValue & operator*=(const DimValue &b)
Definition: ShapeChecking.h:23
int64_t & value()
Definition: ShapeChecking.h:36
DimValue()
Definition: ShapeChecking.h:21
bool & constant()
Definition: ShapeChecking.h:40
std::string ToString() const
Definition: ShapeChecking.h:30
Dim expression class.
Definition: ShapeChecking.h:207
static DimX< TLeft, TRight, TOp > Create(TLeft left, TRight right)
Definition: ShapeChecking.h:209
std::string ToString(bool show_value=true)
Definition: ShapeChecking.h:231
bool assign(int64_t a)
assigns a value to the expression
Definition: ShapeChecking.h:223
bool & constant()
Definition: ShapeChecking.h:220
int64_t value()
Definition: ShapeChecking.h:213
std::string name
Definition: FilePCD.cpp:39
int offset
Definition: FilePCD.cpp:45
const char const char value recording_handle imu_sample recording_handle uint8_t size_t data_size k4a_record_configuration_t config target_format k4a_capture_t capture_handle k4a_imu_sample_t imu_sample playback_handle k4a_logging_message_cb_t void min_level device_handle k4a_imu_sample_t timeout_in_ms capture_handle capture_handle capture_handle image_handle temperature_c int
Definition: K4aPlugin.cpp:474
std::tuple< bool, std::string > CheckShape(const std::vector< DimValue > &shape, TDimX &&dimex, TArgs &&... args)
Definition: ShapeChecking.h:574
bool operator==(DimValue a, DimX< TLeft, TRight, TOp > &&b)
Definition: ShapeChecking.h:287
std::string GetString(DimX< TLeft, TRight, TOp > a, bool show_value=true)
Definition: ShapeChecking.h:312
DimValue UnknownValue()
Definition: ShapeChecking.h:47
CSOpt
Check shape options.
Definition: ShapeChecking.h:405
void CreateDimVector(std::vector< int64_t > &out, int64_t unknown_dim_value, TDimX dimex)
Definition: ShapeChecking.h:358
bool CheckDim(const DimValue &lhs, DimX< TLeft, TRight, TOp > &&rhs)
Definition: ShapeChecking.h:394
bool _CheckShape(const std::vector< DimValue > &shape, TDimX &&dimex)
Definition: ShapeChecking.h:414
int64_t GetValue(DimX< TLeft, TRight, TOp > a)
Definition: ShapeChecking.h:321
std::string CreateDimXString()
Definition: ShapeChecking.h:345
Definition: PinholeCameraIntrinsic.cpp:16
Definition: ShapeChecking.h:307
static const size_t value
Definition: ShapeChecking.h:308
Definition: ShapeChecking.h:176
static bool backprop(int64_t ans, T1 a, T2 b)
Definition: ShapeChecking.h:181
static bool constant()
Definition: ShapeChecking.h:177
static int64_t apply(int64_t a, int64_t b)
Definition: ShapeChecking.h:178
static std::string ToString()
Definition: ShapeChecking.h:188
Definition: ShapeChecking.h:140
static std::string ToString()
Definition: ShapeChecking.h:158
static bool backprop(int64_t ans, T1 a, T2 b)
Definition: ShapeChecking.h:145
static int64_t apply(int64_t a, int64_t b)
Definition: ShapeChecking.h:142
static bool constant()
Definition: ShapeChecking.h:141
Definition: ShapeChecking.h:161
static bool constant()
Definition: ShapeChecking.h:162
static std::string ToString()
Definition: ShapeChecking.h:173
static int64_t apply(int64_t a, int64_t b)
Definition: ShapeChecking.h:163
static bool backprop(int64_t ans, T1 a, T2 b)
Definition: ShapeChecking.h:166
Definition: ShapeChecking.h:191
static bool constant()
Definition: ShapeChecking.h:192
static int64_t apply(int64_t a, int64_t b)
Definition: ShapeChecking.h:193
static bool backprop(int64_t ans, T1 a, T2 b)
Definition: ShapeChecking.h:198
static std::string ToString()
Definition: ShapeChecking.h:202
Definition: ShapeChecking.h:119
static bool backprop(int64_t ans, T1 a, T2 b)
Definition: ShapeChecking.h:124
static bool constant()
Definition: ShapeChecking.h:120
static std::string ToString()
Definition: ShapeChecking.h:137
static int64_t apply(int64_t a, int64_t b)
Definition: ShapeChecking.h:121