Open3D (C++ API)  0.18.0
BlasWrapper.h
Go to the documentation of this file.
1 // ----------------------------------------------------------------------------
2 // - Open3D: www.open3d.org -
3 // ----------------------------------------------------------------------------
4 // Copyright (c) 2018-2023 www.open3d.org
5 // SPDX-License-Identifier: MIT
6 // ----------------------------------------------------------------------------
7 
8 #pragma once
9 
#include "open3d/core/linalg/LinalgHeadersCPU.h"
#include "open3d/utility/Logging.h"
13 
14 namespace open3d {
15 namespace core {
16 
17 template <typename scalar_t>
18 inline void gemm_cpu(CBLAS_LAYOUT layout,
19  CBLAS_TRANSPOSE trans_A,
20  CBLAS_TRANSPOSE trans_B,
24  scalar_t alpha,
25  const scalar_t *A_data,
27  const scalar_t *B_data,
29  scalar_t beta,
30  scalar_t *C_data,
32  utility::LogError("Unsupported data type.");
33 }
34 
35 template <>
36 inline void gemm_cpu<float>(CBLAS_LAYOUT layout,
37  CBLAS_TRANSPOSE trans_A,
38  CBLAS_TRANSPOSE trans_B,
42  float alpha,
43  const float *A_data,
45  const float *B_data,
47  float beta,
48  float *C_data,
50  cblas_sgemm(layout, trans_A, trans_B, m, n, k, alpha, A_data, lda, B_data,
51  ldb, beta, C_data, ldc);
52 }
53 
54 template <>
55 inline void gemm_cpu<double>(CBLAS_LAYOUT layout,
56  CBLAS_TRANSPOSE trans_A,
57  CBLAS_TRANSPOSE trans_B,
61  double alpha,
62  const double *A_data,
64  const double *B_data,
66  double beta,
67  double *C_data,
69  cblas_dgemm(layout, trans_A, trans_B, m, n, k, alpha, A_data, lda, B_data,
70  ldb, beta, C_data, ldc);
71 }
72 
73 #ifdef BUILD_CUDA_MODULE
74 template <typename scalar_t>
75 inline cublasStatus_t gemm_cuda(cublasHandle_t handle,
76  cublasOperation_t transa,
77  cublasOperation_t transb,
78  int m,
79  int n,
80  int k,
81  const scalar_t *alpha,
82  const scalar_t *A_data,
83  int lda,
84  const scalar_t *B_data,
85  int ldb,
86  const scalar_t *beta,
87  scalar_t *C_data,
88  int ldc) {
89  utility::LogError("Unsupported data type.");
90  return CUBLAS_STATUS_NOT_SUPPORTED;
91 }
92 
93 template <typename scalar_t>
94 inline cublasStatus_t trsm_cuda(cublasHandle_t handle,
95  cublasSideMode_t side,
96  cublasFillMode_t uplo,
97  cublasOperation_t trans,
98  cublasDiagType_t diag,
99  int m,
100  int n,
101  const scalar_t *alpha,
102  const scalar_t *A,
103  int lda,
104  scalar_t *B,
105  int ldb) {
106  utility::LogError("Unsupported data type.");
107  return CUBLAS_STATUS_NOT_SUPPORTED;
108 }
109 
110 template <>
111 inline cublasStatus_t gemm_cuda<float>(cublasHandle_t handle,
112  cublasOperation_t transa,
113  cublasOperation_t transb,
114  int m,
115  int n,
116  int k,
117  const float *alpha,
118  const float *A_data,
119  int lda,
120  const float *B_data,
121  int ldb,
122  const float *beta,
123  float *C_data,
124  int ldc) {
125  return cublasSgemm(handle, transa,
126  transb, // A, B transpose flag
127  m, n, k, // dimensions
128  alpha, static_cast<const float *>(A_data), lda,
129  static_cast<const float *>(B_data),
130  ldb, // input and their leading dims
131  beta, static_cast<float *>(C_data), ldc);
132 }
133 
134 template <>
135 inline cublasStatus_t gemm_cuda<double>(cublasHandle_t handle,
136  cublasOperation_t transa,
137  cublasOperation_t transb,
138  int m,
139  int n,
140  int k,
141  const double *alpha,
142  const double *A_data,
143  int lda,
144  const double *B_data,
145  int ldb,
146  const double *beta,
147  double *C_data,
148  int ldc) {
149  return cublasDgemm(handle, transa,
150  transb, // A, B transpose flag
151  m, n, k, // dimensions
152  alpha, static_cast<const double *>(A_data), lda,
153  static_cast<const double *>(B_data),
154  ldb, // input and their leading dims
155  beta, static_cast<double *>(C_data), ldc);
156 }
157 
158 template <>
159 inline cublasStatus_t trsm_cuda<float>(cublasHandle_t handle,
160  cublasSideMode_t side,
161  cublasFillMode_t uplo,
162  cublasOperation_t trans,
163  cublasDiagType_t diag,
164  int m,
165  int n,
166  const float *alpha,
167  const float *A,
168  int lda,
169  float *B,
170  int ldb) {
171  return cublasStrsm(handle, side, uplo, trans, diag, m, n, alpha, A, lda, B,
172  ldb);
173 }
174 
175 template <>
176 inline cublasStatus_t trsm_cuda<double>(cublasHandle_t handle,
177  cublasSideMode_t side,
178  cublasFillMode_t uplo,
179  cublasOperation_t trans,
180  cublasDiagType_t diag,
181  int m,
182  int n,
183  const double *alpha,
184  const double *A,
185  int lda,
186  double *B,
187  int ldb) {
188  return cublasDtrsm(handle, side, uplo, trans, diag, m, n, alpha, A, lda, B,
189  ldb);
190 }
191 #endif
192 
193 } // namespace core
194 } // namespace open3d
#define OPEN3D_CPU_LINALG_INT
Definition: LinalgHeadersCPU.h:23
#define LogError(...)
Definition: Logging.h:48
Eigen::Matrix3d B
Definition: PointCloudPlanarPatchDetection.cpp:506
void gemm_cpu< double >(CBLAS_LAYOUT layout, CBLAS_TRANSPOSE trans_A, CBLAS_TRANSPOSE trans_B, OPEN3D_CPU_LINALG_INT m, OPEN3D_CPU_LINALG_INT n, OPEN3D_CPU_LINALG_INT k, double alpha, const double *A_data, OPEN3D_CPU_LINALG_INT lda, const double *B_data, OPEN3D_CPU_LINALG_INT ldb, double beta, double *C_data, OPEN3D_CPU_LINALG_INT ldc)
Definition: BlasWrapper.h:55
void gemm_cpu(CBLAS_LAYOUT layout, CBLAS_TRANSPOSE trans_A, CBLAS_TRANSPOSE trans_B, OPEN3D_CPU_LINALG_INT m, OPEN3D_CPU_LINALG_INT n, OPEN3D_CPU_LINALG_INT k, scalar_t alpha, const scalar_t *A_data, OPEN3D_CPU_LINALG_INT lda, const scalar_t *B_data, OPEN3D_CPU_LINALG_INT ldb, scalar_t beta, scalar_t *C_data, OPEN3D_CPU_LINALG_INT ldc)
Definition: BlasWrapper.h:18
void gemm_cpu< float >(CBLAS_LAYOUT layout, CBLAS_TRANSPOSE trans_A, CBLAS_TRANSPOSE trans_B, OPEN3D_CPU_LINALG_INT m, OPEN3D_CPU_LINALG_INT n, OPEN3D_CPU_LINALG_INT k, float alpha, const float *A_data, OPEN3D_CPU_LINALG_INT lda, const float *B_data, OPEN3D_CPU_LINALG_INT ldb, float beta, float *C_data, OPEN3D_CPU_LINALG_INT ldc)
Definition: BlasWrapper.h:36
Definition: PinholeCameraIntrinsic.cpp:16