Open3D (C++ API)  0.11.0
BlasWrapper.h
Go to the documentation of this file.
1 // ----------------------------------------------------------------------------
2 // - Open3D: www.open3d.org -
3 // ----------------------------------------------------------------------------
4 // The MIT License (MIT)
5 //
6 // Copyright (c) 2018 www.open3d.org
7 //
8 // Permission is hereby granted, free of charge, to any person obtaining a copy
9 // of this software and associated documentation files (the "Software"), to deal
10 // in the Software without restriction, including without limitation the rights
11 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12 // copies of the Software, and to permit persons to whom the Software is
13 // furnished to do so, subject to the following conditions:
14 //
15 // The above copyright notice and this permission notice shall be included in
16 // all copies or substantial portions of the Software.
17 //
18 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
23 // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
24 // IN THE SOFTWARE.
25 // ----------------------------------------------------------------------------
26 
27 #pragma once
28 
#include "open3d/core/linalg/LinalgHeadersCPU.h"
#include "open3d/utility/Console.h"
32 
33 namespace open3d {
34 namespace core {
35 
36 template <typename scalar_t>
37 inline void gemm_cpu(CBLAS_LAYOUT layout,
38  CBLAS_TRANSPOSE trans_A,
39  CBLAS_TRANSPOSE trans_B,
43  scalar_t alpha,
44  const scalar_t *A_data,
46  const scalar_t *B_data,
48  scalar_t beta,
49  scalar_t *C_data,
51  utility::LogError("Unsupported data type.");
52 }
53 
54 template <>
55 inline void gemm_cpu<float>(CBLAS_LAYOUT layout,
56  CBLAS_TRANSPOSE trans_A,
57  CBLAS_TRANSPOSE trans_B,
61  float alpha,
62  const float *A_data,
64  const float *B_data,
66  float beta,
67  float *C_data,
69  cblas_sgemm(layout, trans_A, trans_B, m, n, k, alpha, A_data, lda, B_data,
70  ldb, beta, C_data, ldc);
71 }
72 
73 template <>
74 inline void gemm_cpu<double>(CBLAS_LAYOUT layout,
75  CBLAS_TRANSPOSE trans_A,
76  CBLAS_TRANSPOSE trans_B,
80  double alpha,
81  const double *A_data,
83  const double *B_data,
85  double beta,
86  double *C_data,
88  cblas_dgemm(layout, trans_A, trans_B, m, n, k, alpha, A_data, lda, B_data,
89  ldb, beta, C_data, ldc);
90 }
91 
92 #ifdef BUILD_CUDA_MODULE
93 template <typename scalar_t>
94 inline cublasStatus_t gemm_cuda(cublasHandle_t handle,
95  cublasOperation_t transa,
96  cublasOperation_t transb,
97  int m,
98  int n,
99  int k,
100  const scalar_t *alpha,
101  const scalar_t *A_data,
102  int lda,
103  const scalar_t *B_data,
104  int ldb,
105  const scalar_t *beta,
106  scalar_t *C_data,
107  int ldc) {
108  utility::LogError("Unsupported data type.");
109  return CUBLAS_STATUS_NOT_SUPPORTED;
110 }
111 
112 template <typename scalar_t>
113 inline cublasStatus_t trsm_cuda(cublasHandle_t handle,
114  cublasSideMode_t side,
115  cublasFillMode_t uplo,
116  cublasOperation_t trans,
117  cublasDiagType_t diag,
118  int m,
119  int n,
120  const scalar_t *alpha,
121  const scalar_t *A,
122  int lda,
123  scalar_t *B,
124  int ldb) {
125  utility::LogError("Unsupported data type.");
126  return CUBLAS_STATUS_NOT_SUPPORTED;
127 }
128 
129 template <>
130 inline cublasStatus_t gemm_cuda<float>(cublasHandle_t handle,
131  cublasOperation_t transa,
132  cublasOperation_t transb,
133  int m,
134  int n,
135  int k,
136  const float *alpha,
137  const float *A_data,
138  int lda,
139  const float *B_data,
140  int ldb,
141  const float *beta,
142  float *C_data,
143  int ldc) {
144  return cublasSgemm(handle, transa,
145  transb, // A, B transpose flag
146  m, n, k, // dimensions
147  alpha, static_cast<const float *>(A_data), lda,
148  static_cast<const float *>(B_data),
149  ldb, // input and their leading dims
150  beta, static_cast<float *>(C_data), ldc);
151 }
152 
153 template <>
154 inline cublasStatus_t gemm_cuda<double>(cublasHandle_t handle,
155  cublasOperation_t transa,
156  cublasOperation_t transb,
157  int m,
158  int n,
159  int k,
160  const double *alpha,
161  const double *A_data,
162  int lda,
163  const double *B_data,
164  int ldb,
165  const double *beta,
166  double *C_data,
167  int ldc) {
168  return cublasDgemm(handle, transa,
169  transb, // A, B transpose flag
170  m, n, k, // dimensions
171  alpha, static_cast<const double *>(A_data), lda,
172  static_cast<const double *>(B_data),
173  ldb, // input and their leading dims
174  beta, static_cast<double *>(C_data), ldc);
175 }
176 
177 template <>
178 inline cublasStatus_t trsm_cuda<float>(cublasHandle_t handle,
179  cublasSideMode_t side,
180  cublasFillMode_t uplo,
181  cublasOperation_t trans,
182  cublasDiagType_t diag,
183  int m,
184  int n,
185  const float *alpha,
186  const float *A,
187  int lda,
188  float *B,
189  int ldb) {
190  return cublasStrsm(handle, side, uplo, trans, diag, m, n, alpha, A, lda, B,
191  ldb);
192 }
193 
194 template <>
195 inline cublasStatus_t trsm_cuda<double>(cublasHandle_t handle,
196  cublasSideMode_t side,
197  cublasFillMode_t uplo,
198  cublasOperation_t trans,
199  cublasDiagType_t diag,
200  int m,
201  int n,
202  const double *alpha,
203  const double *A,
204  int lda,
205  double *B,
206  int ldb) {
207  return cublasDtrsm(handle, side, uplo, trans, diag, m, n, alpha, A, lda, B,
208  ldb);
209 }
210 #endif
211 } // namespace core
212 } // namespace open3d
void gemm_cpu< double >(CBLAS_LAYOUT layout, CBLAS_TRANSPOSE trans_A, CBLAS_TRANSPOSE trans_B, OPEN3D_CPU_LINALG_INT m, OPEN3D_CPU_LINALG_INT n, OPEN3D_CPU_LINALG_INT k, double alpha, const double *A_data, OPEN3D_CPU_LINALG_INT lda, const double *B_data, OPEN3D_CPU_LINALG_INT ldb, double beta, double *C_data, OPEN3D_CPU_LINALG_INT ldc)
Definition: BlasWrapper.h:74
void LogError(const char *format, const Args &... args)
Definition: Console.h:176
void gemm_cpu(CBLAS_LAYOUT layout, CBLAS_TRANSPOSE trans_A, CBLAS_TRANSPOSE trans_B, OPEN3D_CPU_LINALG_INT m, OPEN3D_CPU_LINALG_INT n, OPEN3D_CPU_LINALG_INT k, scalar_t alpha, const scalar_t *A_data, OPEN3D_CPU_LINALG_INT lda, const scalar_t *B_data, OPEN3D_CPU_LINALG_INT ldb, scalar_t beta, scalar_t *C_data, OPEN3D_CPU_LINALG_INT ldc)
Definition: BlasWrapper.h:37
void gemm_cpu< float >(CBLAS_LAYOUT layout, CBLAS_TRANSPOSE trans_A, CBLAS_TRANSPOSE trans_B, OPEN3D_CPU_LINALG_INT m, OPEN3D_CPU_LINALG_INT n, OPEN3D_CPU_LINALG_INT k, float alpha, const float *A_data, OPEN3D_CPU_LINALG_INT lda, const float *B_data, OPEN3D_CPU_LINALG_INT ldb, float beta, float *C_data, OPEN3D_CPU_LINALG_INT ldc)
Definition: BlasWrapper.h:55
Definition: PinholeCameraIntrinsic.cpp:35
#define OPEN3D_CPU_LINALG_INT
Definition: LinalgHeadersCPU.h:45