22 DimValue(int64_t v) : value_(v), constant_(true) {}
24 if (constant_ && b.constant_)
32 return std::to_string(value_);
37 if (!constant_)
throw std::runtime_error(
"DimValue is not constant");
52 explicit Dim() : value_(0), constant_(false), origin_(this) {}
55 : value_(0), constant_(false), origin_(this), name_(
name) {}
58 : value_(
value), constant_(true), origin_(nullptr), name_(
name) {}
61 : value_(other.value_),
62 constant_(other.constant_),
63 origin_(other.origin_),
72 return origin_->value_;
79 return origin_->constant_;
103 return std::to_string(
value());
121 static int64_t
apply(int64_t a, int64_t b) {
return a + b; }
123 template <
class T1,
class T2>
125 if (!a.constant() && a.constant() == b.constant()) {
128 throw std::runtime_error(
"Illegal dim expression: " + exstr);
130 }
else if (!a.constant()) {
131 return a.assign(ans - b.value());
133 return b.assign(ans - a.value());
142 static int64_t
apply(int64_t a, int64_t b) {
return a - b; }
144 template <
class T1,
class T2>
146 if (!a.constant() && a.constant() == b.constant()) {
149 throw std::runtime_error(
"Illegal dim expression: " + exstr);
151 }
else if (!a.constant()) {
152 return a.assign(ans + b.value());
154 return b.assign(a.value() - ans);
163 static int64_t
apply(int64_t a, int64_t b) {
return a * b; }
165 template <
class T1,
class T2>
169 throw std::runtime_error(
"Illegal dim expression: " + exstr);
178 static int64_t
apply(int64_t a, int64_t b) {
return a / b; }
180 template <
class T1,
class T2>
184 throw std::runtime_error(
"Illegal dim expression: " + exstr);
193 static int64_t
apply(int64_t a, int64_t b) {
194 throw std::runtime_error(
"Cannot evaluate OR expression");
197 template <
class T1,
class T2>
199 return a.assign(ans) || b.assign(ans);
206 template <
class TLeft,
class TRight,
class TOp>
210 return DimX(left, right);
215 return TOp::apply(left_.value(), right_.value());
227 return TOp::backprop(a, left_, right_);
232 return left_.ToString(show_value) + TOp::ToString() +
233 right_.ToString(show_value);
237 DimX(TLeft left, TRight right) : left_(left), right_(right) {
238 constant_ = left.constant() && right.constant() && TOp::constant();
249 #define DEFINE_DIMX_OPERATOR(opclass, symbol) \
250 inline DimX<Dim, Dim, opclass> operator symbol(Dim a, Dim b) { \
251 return DimX<Dim, Dim, opclass>::Create(a, b); \
254 template <class TL, class TR, class TOp> \
255 inline DimX<Dim, DimX<TL, TR, TOp>, opclass> operator symbol( \
256 Dim a, DimX<TL, TR, TOp>&& b) { \
257 return DimX<Dim, DimX<TL, TR, TOp>, opclass>::Create(a, b); \
260 template <class TL, class TR, class TOp> \
261 inline DimX<DimX<TL, TR, TOp>, Dim, opclass> operator symbol( \
262 DimX<TL, TR, TOp>&& a, Dim b) { \
263 return DimX<DimX<TL, TR, TOp>, Dim, opclass>::Create(a, b); \
266 template <class TL1, class TR1, class TOp1, class TL2, class TR2, \
268 inline DimX<DimX<TL1, TR1, TOp1>, DimX<TL2, TR2, TOp2>, opclass> \
269 operator symbol(DimX<TL1, TR1, TOp1>&& a, DimX<TL2, TR2, TOp2>&& b) { \
270 return DimX<DimX<TL1, TR1, TOp1>, DimX<TL2, TR2, TOp2>, \
271 opclass>::Create(a, b); \
279 #undef DEFINE_DIMX_OPERATOR
286 template <
class TLeft,
class TRight,
class TOp>
290 return b_copy.assign(a.
value());
306 template <
class... args>
308 static const size_t value =
sizeof...(args);
311 template <
class TLeft,
class TRight,
class TOp>
320 template <
class TLeft,
class TRight,
class TOp>
325 template <
class TLeft,
class TRight,
class TOp>
330 return unknown_dim_value;
341 return unknown_dim_value;
347 template <
class TDimX>
352 template <
class TDimX,
class... TArgs>
357 template <
class TDimX>
359 int64_t unknown_dim_value,
361 out.push_back(
GetValue(dimex, unknown_dim_value));
364 template <
class TDimX,
class... TArgs>
366 int64_t unknown_dim_value,
369 out.push_back(
GetValue(dimex, unknown_dim_value));
373 template <
class TDimX>
375 std::vector<int64_t> out;
380 template <
class TDimX,
class... TArgs>
384 std::vector<int64_t> out;
393 template <
class TLeft,
class TRight,
class TOp>
395 bool status = (lhs == std::forward<DimX<TLeft, TRight, TOp>>(rhs));
400 bool status = lhs == d;
413 template <CSOpt Opt = CSOpt::NONE,
class TDimX>
414 bool _CheckShape(
const std::vector<DimValue>& shape, TDimX&& dimex) {
416 const int rank_diff = shape.size() - 1;
422 if (rank_diff != 0) {
431 for (
int i = 0; i < rank_diff + 1; ++i) s *= shape[i];
432 status =
CheckDim(s, std::forward<TDimX>(dimex));
434 status =
CheckDim(shape[rank_diff], std::forward<TDimX>(dimex));
438 status =
CheckDim(s, std::forward<TDimX>(dimex));
440 status =
CheckDim(shape[0], std::forward<TDimX>(dimex));
450 const int rank_diff = shape.size() - (
CountArgs<TArgs...>::value + 1);
456 if (rank_diff != 0) {
465 for (
int i = 0; i < rank_diff + 1; ++i) s *= shape[i];
466 status =
CheckDim(s, std::forward<TDimX>(dimex));
468 status =
CheckDim(shape[rank_diff], std::forward<TDimX>(dimex));
470 status =
CheckDim(shape[0], std::forward<TDimX>(dimex));
477 std::vector<DimValue> shape2(shape.begin() +
offset, shape.end());
478 bool status2 = _CheckShape<Opt>(shape2, std::forward<TArgs>(args)...);
480 return status && status2;
574 std::tuple<bool, std::string>
CheckShape(
const std::vector<DimValue>& shape,
577 const bool status = _CheckShape<Opt>(shape, std::forward<TDimX>(dimex),
578 std::forward<TArgs>(args)...);
580 return std::make_tuple(status, std::string());
582 const int rank_diff = shape.size() - (
CountArgs<TArgs...>::value + 1);
586 std::string shape_str;
587 if (rank_diff <= 0) {
589 for (
int i = 0; i <
int(shape.size()); ++i) {
590 shape_str += shape[i].ToString();
591 if (i + 1 <
int(shape.size())) shape_str +=
", ";
597 for (
int i = 0; i < rank_diff; ++i) {
598 shape_str += shape[i].ToString();
599 if (i + 1 <
int(shape.size())) shape_str +=
"*";
603 for (
int i = 0; i < rank_diff; ++i) {
604 shape_str += shape[i].ToString();
605 if (i + 1 < rank_diff) shape_str +=
", ";
617 int end = shape.size();
619 end -= rank_diff + 1;
623 for (
int i = start; i < end; ++i) {
624 shape_str += shape[i].ToString();
625 if (i + 1 < end) shape_str +=
", ";
629 for (
int i = std::max<int>(0, shape.size() - rank_diff - 1);
630 i <
int(shape.size()); ++i) {
631 shape_str += shape[i].ToString();
632 if (i + 1 <
int(shape.size())) shape_str +=
"*";
637 for (
int i = std::max<int>(0, shape.size() - rank_diff);
638 i <
int(shape.size()); ++i) {
639 shape_str += shape[i].ToString();
640 if (i + 1 <
int(shape.size())) shape_str +=
", ";
649 std::string expected_shape;
651 expected_shape =
"[" +
GetString(dimex) +
"]";
654 expected_shape =
"[" +
GetString(dimex) +
", " +
662 errstr =
"got rank " + std::to_string(shape.size()) +
" " +
663 shape_str +
", expected rank " +
667 errstr =
"got " + shape_str +
", expected " + expected_shape;
669 return std::make_tuple(status, errstr);
#define DEFINE_DIMX_OPERATOR(opclass, symbol)
Definition: ShapeChecking.h:249
Class for dimensions for which the value should be inferred.
Definition: ShapeChecking.h:50
bool assign(int64_t a)
Definition: ShapeChecking.h:86
Dim(const Dim &other)
Definition: ShapeChecking.h:60
Dim()
Definition: ShapeChecking.h:52
Dim(const std::string &name)
Definition: ShapeChecking.h:54
bool & constant()
Definition: ShapeChecking.h:77
~Dim()
Definition: ShapeChecking.h:66
int64_t & value()
Definition: ShapeChecking.h:70
Dim & operator=(const Dim &)=delete
std::string ToString(bool show_value=true)
Definition: ShapeChecking.h:94
Dim(int64_t value, const std::string &name="")
Definition: ShapeChecking.h:57
Class for representing a possibly unknown dimension value.
Definition: ShapeChecking.h:19
DimValue(int64_t v)
Definition: ShapeChecking.h:22
DimValue & operator*=(const DimValue &b)
Definition: ShapeChecking.h:23
int64_t & value()
Definition: ShapeChecking.h:36
DimValue()
Definition: ShapeChecking.h:21
bool & constant()
Definition: ShapeChecking.h:40
std::string ToString() const
Definition: ShapeChecking.h:30
Dim expression class.
Definition: ShapeChecking.h:207
static DimX< TLeft, TRight, TOp > Create(TLeft left, TRight right)
Definition: ShapeChecking.h:209
std::string ToString(bool show_value=true)
Definition: ShapeChecking.h:231
bool assign(int64_t a)
assigns a value to the expression
Definition: ShapeChecking.h:223
bool & constant()
Definition: ShapeChecking.h:220
int64_t value()
Definition: ShapeChecking.h:213
const char const char value recording_handle imu_sample recording_handle uint8_t size_t data_size k4a_record_configuration_t config target_format k4a_capture_t capture_handle k4a_imu_sample_t imu_sample playback_handle k4a_logging_message_cb_t void min_level device_handle k4a_imu_sample_t timeout_in_ms capture_handle capture_handle capture_handle image_handle temperature_c int
Definition: K4aPlugin.cpp:474
std::tuple< bool, std::string > CheckShape(const std::vector< DimValue > &shape, TDimX &&dimex, TArgs &&... args)
Definition: ShapeChecking.h:574
bool operator==(DimValue a, DimX< TLeft, TRight, TOp > &&b)
Definition: ShapeChecking.h:287
std::string GetString(DimX< TLeft, TRight, TOp > a, bool show_value=true)
Definition: ShapeChecking.h:312
DimValue UnknownValue()
Definition: ShapeChecking.h:47
CSOpt
Check shape options.
Definition: ShapeChecking.h:405
void CreateDimVector(std::vector< int64_t > &out, int64_t unknown_dim_value, TDimX dimex)
Definition: ShapeChecking.h:358
bool CheckDim(const DimValue &lhs, DimX< TLeft, TRight, TOp > &&rhs)
Definition: ShapeChecking.h:394
bool _CheckShape(const std::vector< DimValue > &shape, TDimX &&dimex)
Definition: ShapeChecking.h:414
int64_t GetValue(DimX< TLeft, TRight, TOp > a)
Definition: ShapeChecking.h:321
std::string CreateDimXString()
Definition: ShapeChecking.h:345
Definition: PinholeCameraIntrinsic.cpp:16
Definition: ShapeChecking.h:307
static const size_t value
Definition: ShapeChecking.h:308
Definition: ShapeChecking.h:176
static bool backprop(int64_t ans, T1 a, T2 b)
Definition: ShapeChecking.h:181
static bool constant()
Definition: ShapeChecking.h:177
static int64_t apply(int64_t a, int64_t b)
Definition: ShapeChecking.h:178
static std::string ToString()
Definition: ShapeChecking.h:188
Definition: ShapeChecking.h:140
static std::string ToString()
Definition: ShapeChecking.h:158
static bool backprop(int64_t ans, T1 a, T2 b)
Definition: ShapeChecking.h:145
static int64_t apply(int64_t a, int64_t b)
Definition: ShapeChecking.h:142
static bool constant()
Definition: ShapeChecking.h:141
Definition: ShapeChecking.h:161
static bool constant()
Definition: ShapeChecking.h:162
static std::string ToString()
Definition: ShapeChecking.h:173
static int64_t apply(int64_t a, int64_t b)
Definition: ShapeChecking.h:163
static bool backprop(int64_t ans, T1 a, T2 b)
Definition: ShapeChecking.h:166
Definition: ShapeChecking.h:191
static bool constant()
Definition: ShapeChecking.h:192
static int64_t apply(int64_t a, int64_t b)
Definition: ShapeChecking.h:193
static bool backprop(int64_t ans, T1 a, T2 b)
Definition: ShapeChecking.h:198
static std::string ToString()
Definition: ShapeChecking.h:202
Definition: ShapeChecking.h:119
static bool backprop(int64_t ans, T1 a, T2 b)
Definition: ShapeChecking.h:124
static bool constant()
Definition: ShapeChecking.h:120
static std::string ToString()
Definition: ShapeChecking.h:137
static int64_t apply(int64_t a, int64_t b)
Definition: ShapeChecking.h:121