35#ifndef MADNESS_TENSOR_MXM_H__INCLUDED
36#define MADNESS_TENSOR_MXM_H__INCLUDED
67 template <
typename T,
typename Q,
typename S>
69 T* MADNESS_RESTRICT
c,
const Q* MADNESS_RESTRICT
a,
70 const S* MADNESS_RESTRICT
b) {
78 for (
long i=0; i<dimi; ++i) {
79 for (
long k=0;
k<dimk; ++
k) {
80 for (
long j=0; j<dimj; ++j) {
81 c[i*dimj+j] +=
a[i*dimk+
k]*
b[
k*dimj+j];
89 template <
typename T,
typename Q,
typename S>
92 T* MADNESS_RESTRICT
c,
const Q* MADNESS_RESTRICT
a,
93 const S* MADNESS_RESTRICT
b) {
103 for (
long k=0;
k<dimk; ++
k) {
104 for (
long j=0; j<dimj; ++j) {
105 for (
long i=0; i<dimi; ++i) {
106 c[i*dimj+j] +=
a[
k*dimi+i]*
b[
k*dimj+j];
113 template <
typename T,
typename Q,
typename S>
115 T* MADNESS_RESTRICT
c,
const Q* MADNESS_RESTRICT
a,
116 const S* MADNESS_RESTRICT
b) {
126 for (
long i=0; i<dimi; ++i) {
127 for (
long j=0; j<dimj; ++j) {
129 for (
long k=0;
k<dimk; ++
k) {
130 sum +=
a[i*dimk+
k]*
b[j*dimk+
k];
138 template <
typename T,
typename Q,
typename S>
140 T* MADNESS_RESTRICT
c,
const Q* MADNESS_RESTRICT
a,
141 const S* MADNESS_RESTRICT
b) {
149 for (
long i=0; i<dimi; ++i) {
150 for (
long j=0; j<dimj; ++j) {
151 for (
long k=0;
k<dimk; ++
k) {
152 c[i*dimj+j] +=
a[
k*dimi+i]*
b[j*dimk+
k];
170 template <
typename aT,
typename bT,
typename cT>
172 cT* MADNESS_RESTRICT
c,
const aT*
a,
const bT*
b,
long ldb=-1) {
173 if (ldb == -1) ldb=dimj;
176 for (
long i=0; i<dimi; ++i,
c+=dimj,++
a) {
177 for (
long j=0; j<dimj; ++j)
c[j] = 0.0;
178 const aT *aik_ptr =
a;
179 for (
long k=0;
k<dimk; ++
k,aik_ptr+=dimi) {
181 for (
long j=0; j<dimj; ++j) {
182 c[j] += aki*
b[
k*ldb+j];
189#if defined(HAVE_FAST_BLAS) && !defined(HAVE_INTEL_MKL)
198 template <
typename T>
199 void mxm(
long dimi,
long dimj,
long dimk,
200 T* MADNESS_RESTRICT
c,
const T*
a,
const T*
b) {
202 cblas::gemm(
cblas::NoTrans,
cblas::NoTrans,dimj,dimi,dimk,one,
b,dimj,
a,dimk,one,
c,dimj);
211 template <
typename T>
212 void mTxm(
long dimi,
long dimj,
long dimk,
213 T* MADNESS_RESTRICT
c,
const T*
a,
const T*
b) {
215 cblas::gemm(
cblas::NoTrans,
cblas::Trans,dimj,dimi,dimk,one,
b,dimj,
a,dimi,one,
c,dimj);
224 template <
typename T>
225 void mxmT(
long dimi,
long dimj,
long dimk,
226 T* MADNESS_RESTRICT
c,
const T*
a,
const T*
b) {
228 cblas::gemm(
cblas::Trans,
cblas::NoTrans,dimj,dimi,dimk,one,
b,dimk,
a,dimk,one,
c,dimj);
237 template <
typename T>
238 void mTxmT(
long dimi,
long dimj,
long dimk,
239 T* MADNESS_RESTRICT
c,
const T*
a,
const T*
b) {
241 cblas::gemm(
cblas::Trans,
cblas::Trans,dimj,dimi,dimk,one,
b,dimk,
a,dimi,one,
c,dimj);
256 template <
typename T>
257 void mTxmq(
long dimi,
long dimj,
long dimk,
258 T* MADNESS_RESTRICT
c,
const T*
a,
const T*
b,
long ldb=-1) {
259 if (ldb == -1) ldb=dimj;
262 if (dimi==0 || dimj==0)
return;
264 for (
long i=0; i<dimi*dimj; i++)
c[i] = 0.0;
269 cblas::gemm(
cblas::NoTrans,
cblas::Trans,dimj,dimi,dimk,one,
b,ldb,
a,dimi,zero,
c,dimj);
274 void mTxmq(
long dimi,
long dimj,
long dimk,
double* MADNESS_RESTRICT
c,
const double*
a,
const double*
b,
long ldb);
277 template <
typename T>
278 void mTxmq(
long dimi,
long dimj,
long dimk, std::complex<T>* MADNESS_RESTRICT
c,
const std::complex<T>*
a,
const T*
b,
long ldb) {
279 T* Rc =
new T[dimi*dimj];
280 T* Ic =
new T[dimi*dimj];
281 T* Ra =
new T[dimi*dimk];
282 T* Ia =
new T[dimi*dimk];
284 for (
long i=0; i<dimi*dimk; i++) {
288 mTxmq(dimi,dimj,dimk,Rc,Ra,
b,ldb);
289 mTxmq(dimi,dimj,dimk,Ic,Ia,
b,ldb);
290 for (
long i=0; i<dimi*dimj; i++) c[i] = std::complex<T>(Rc[i],Ic[i]);
309 template <
typename aT,
typename bT,
typename cT>
310 void mxm(
long dimi,
long dimj,
long dimk,
311 cT* MADNESS_RESTRICT
c,
const aT*
a,
const bT*
b) {
313 cblas::gemm(
cblas::NoTrans,
cblas::NoTrans,dimj,dimi,dimk,one,
b,dimj,
a,dimk,one,
c,dimj);
322 template <
typename aT,
typename bT,
typename cT>
323 void mTxm(
long dimi,
long dimj,
long dimk,
324 cT* MADNESS_RESTRICT
c,
const aT*
a,
const bT*
b) {
326 cblas::gemm(
cblas::NoTrans,
cblas::Trans,dimj,dimi,dimk,one,
b,dimj,
a,dimi,one,
c,dimj);
335 template <
typename aT,
typename bT,
typename cT>
336 void mxmT(
long dimi,
long dimj,
long dimk,
337 cT* MADNESS_RESTRICT
c,
const aT*
a,
const bT*
b) {
339 cblas::gemm(
cblas::Trans,
cblas::NoTrans,dimj,dimi,dimk,one,
b,dimk,
a,dimk,one,
c,dimj);
348 template <
typename aT,
typename bT,
typename cT>
349 void mTxmT(
long dimi,
long dimj,
long dimk,
350 cT* MADNESS_RESTRICT
c,
const aT*
a,
const bT*
b) {
352 cblas::gemm(
cblas::Trans,
cblas::Trans,dimj,dimi,dimk,one,
b,dimk,
a,dimi,one,
c,dimj);
367 template <
typename aT,
typename bT,
typename cT>
368 void mTxmq(
long dimi,
long dimj,
long dimk,
369 cT* MADNESS_RESTRICT
c,
const aT*
a,
const bT*
b,
long ldb=-1) {
370 if (ldb == -1) ldb=dimj;
373 if (dimi==0 || dimj==0)
return;
375 for (
long i=0; i<dimi*dimj; i++)
c[i] = 0.0;
380 cblas::gemm(
cblas::NoTrans,
cblas::Trans,dimj,dimi,dimk,one,
b,ldb,
a,dimi,zero,
c,dimj);
385void mTxmq(
long dimi,
long dimj,
long dimk,
double* MADNESS_RESTRICT
c,
const double*
a,
const double*
b,
long ldb);
392 template <
typename T,
typename Q,
typename S>
393 static inline void mxm(
long dimi,
long dimj,
long dimk,
394 T* MADNESS_RESTRICT
c,
const Q* MADNESS_RESTRICT
a,
395 const S* MADNESS_RESTRICT
b) {
399 template <
typename T,
typename Q,
typename S>
401 void mTxm(
long dimi,
long dimj,
long dimk,
402 T* MADNESS_RESTRICT
c,
const Q* MADNESS_RESTRICT
a,
403 const S* MADNESS_RESTRICT
b) {
407 template <
typename T,
typename Q,
typename S>
408 static inline void mxmT(
long dimi,
long dimj,
long dimk,
409 T* MADNESS_RESTRICT
c,
const Q* MADNESS_RESTRICT
a,
410 const S* MADNESS_RESTRICT
b) {
414 template <
typename T,
typename Q,
typename S>
415 static inline void mTxmT(
long dimi,
long dimj,
long dimk,
416 T* MADNESS_RESTRICT
c,
const Q* MADNESS_RESTRICT
a,
417 const S* MADNESS_RESTRICT
b) {
421 template <
typename aT,
typename bT,
typename cT>
422 void mTxmq(
long dimi,
long dimj,
long dimk,
423 cT* MADNESS_RESTRICT
c,
const aT*
a,
const bT*
b,
long ldb=-1) {
432 inline void mTxm(
long dimi,
long dimj,
long dimk,
433 double* MADNESS_RESTRICT
c,
const double* MADNESS_RESTRICT
a,
434 const double* MADNESS_RESTRICT
b) {
447 long dimk4 = (dimk/4)*4;
448 for (
long i=0; i<dimi; ++i,
c+=dimj) {
449 const double* ai =
a+i;
451 for (
long k=0;
k<dimk4;
k+=4,ai+=4*dimi,
p+=4*dimj) {
452 double ak0i = ai[0 ];
453 double ak1i = ai[dimi];
454 double ak2i = ai[dimi+dimi];
455 double ak3i = ai[dimi+dimi+dimi];
456 const double* bk0 =
p;
457 const double* bk1 =
p+dimj;
458 const double* bk2 =
p+dimj+dimj;
459 const double* bk3 =
p+dimj+dimj+dimj;
460 for (
long j=0; j<dimj; ++j) {
461 c[j] += ak0i*bk0[j] + ak1i*bk1[j] + ak2i*bk2[j] + ak3i*bk3[j];
464 for (
long k=dimk4;
k<dimk; ++
k) {
465 double aki =
a[
k*dimi+i];
466 const double* bk =
b+
k*dimj;
467 for (
long j=0; j<dimj; ++j) {
479 inline void mxmT(
long dimi,
long dimj,
long dimk,
480 double* MADNESS_RESTRICT
c,
481 const double* MADNESS_RESTRICT
a,
const double* MADNESS_RESTRICT
b) {
494 long dimi2 = (dimi/2)*2;
495 for (
long i=0; i<dimi2; i+=2) {
496 const double* ai0 =
a+i*dimk;
497 const double* ai1 =
a+i*dimk+dimk;
498 double* MADNESS_RESTRICT ci0 =
c+i*dimj;
499 double* MADNESS_RESTRICT ci1 =
c+i*dimj+dimj;
500 for (
long j=0; j<dimj; ++j) {
503 const double* bj =
b + j*dimk;
504 for (
long k=0;
k<dimk; ++
k) {
505 sum0 += ai0[
k]*bj[
k];
506 sum1 += ai1[
k]*bj[
k];
512 for (
long i=dimi2; i<dimi; ++i) {
513 const double* ai =
a+i*dimk;
514 double* MADNESS_RESTRICT ci =
c+i*dimj;
515 for (
long j=0; j<dimj; ++j) {
517 const double* bj =
b+j*dimk;
518 for (
long k=0;
k<dimk; ++
k) {
528 inline void mxm(
long dimi,
long dimj,
long dimk,
529 double* MADNESS_RESTRICT
c,
const double* MADNESS_RESTRICT
a,
const double* MADNESS_RESTRICT
b) {
540 long dimk4 = (dimk/4)*4;
541 for (
long i=0; i<dimi; ++i,
c+=dimj,
a+=dimk) {
543 for (
long k=0;
k<dimk4;
k+=4,
p+=4*dimj) {
545 double aik1 =
a[
k+1];
546 double aik2 =
a[
k+2];
547 double aik3 =
a[
k+3];
548 const double* bk0 =
p;
549 const double* bk1 = bk0+dimj;
550 const double* bk2 = bk1+dimj;
551 const double* bk3 = bk2+dimj;
552 for (
long j=0; j<dimj; ++j) {
553 c[j] += aik0*bk0[j] + aik1*bk1[j] + aik2*bk2[j] + aik3*bk3[j];
556 for (
long k=dimk4;
k<dimk; ++
k) {
558 for (
long j=0; j<dimj; ++j) {
559 c[j] += aik*
b[
k*dimj+j];
567 inline void mTxmT(
long dimi,
long dimj,
long dimk,
568 double* MADNESS_RESTRICT csave,
const double* MADNESS_RESTRICT asave,
const double* MADNESS_RESTRICT
b) {
580 long dimj2 = (dimj/2)*2;
582 for (
long klo=0; klo<dimk; klo+=ktile, asave+=ktile*dimi,
b+=ktile) {
583 long khi = klo+ktile;
584 if (khi > dimk) khi = dimk;
587 const double * MADNESS_RESTRICT
a = asave;
588 double * MADNESS_RESTRICT
c = csave;
589 for (
long i=0; i<dimi; ++i,
c+=dimj,++
a) {
591 for (
long k=0;
k<nk; ++
k,
q+=dimi) ai[
k] = *
q;
593 const double* bj0 =
b;
594 for (
long j=0; j<dimj2; j+=2,bj0+=2*dimk) {
595 const double* bj1 = bj0+dimk;
598 for (
long k=0;
k<nk; ++
k) {
599 sum0 += ai[
k]*bj0[
k];
600 sum1 += ai[
k]*bj1[
k];
606 for (
long j=dimj2; j<dimj; ++j,bj0+=dimk) {
608 for (
long k=0;
k<nk; ++
k) {
622 template <
typename aT,
typename bT,
typename cT>
623 void mTxmq_padding(
long dimi,
long dimj,
long dimk,
long ext_b,
624 cT*
c,
const aT*
a,
const bT*
b) {
625 const int alignment = 4;
631 if (dimj%alignment) {
632 effj = (dimj | 3) + 1;
633 c_buf = (cT*)
malloc(
sizeof(cT)*dimi*effj);
637 if (ext_b%alignment) {
639 bT* b_buf = (bT*)
malloc(
sizeof(bT)*dimk*effj);
642 for (
long k=0;
k<dimk;
k++, bp += effj,
b += ext_b)
643 memcpy(bp,
b,
sizeof(bT)*dimj);
651 for (
long i=0; i<dimi; ++i,c_work+=effj,++
a) {
652 for (
long j=0; j<dimj; ++j) c_work[j] = 0.0;
653 const aT *aik_ptr =
a;
654 for (
long k=0;
k<dimk; ++
k,aik_ptr+=dimi) {
656 for (
long j=0; j<dimj; ++j) {
657 c_work[j] += aki*
b[
k*ext_b+j];
663 if (dimj%alignment) {
665 for (
long i=0; i<dimi; i++, ct += effj,
c += dimj)
666 memcpy(
c, ct,
sizeof(cT)*dimj);
672 if (free_b) free((bT*)
b);
677 double*
c,
const double*
a,
const double*
b);
679 __complex__
double*
c,
const __complex__
double*
a,
const __complex__
double*
b);
681 __complex__
double*
c,
const double*
a,
const __complex__
double*
b);
683 __complex__
double*
c,
const __complex__
double*
a,
const double*
b);
686 inline void mTxmq_padding(
long ni,
long nj,
long nk,
long ej,
687 double*
c,
const double*
a,
const double*
b) {
692 inline void mTxmq_padding(
long ni,
long nj,
long nk,
long ej,
693 __complex__
double*
c,
const __complex__
double*
a,
const __complex__
double*
b) {
698 inline void mTxmq_padding(
long ni,
long nj,
long nk,
long ej,
699 __complex__
double*
c,
const double*
a,
const __complex__
double*
b) {
704 inline void mTxmq_padding(
long ni,
long nj,
long nk,
long ej,
705 __complex__
double*
c,
const __complex__
double*
a,
const double*
b) {
708#elif defined(HAVE_IBMBGP)
709 extern void bgpmTxmq(
long ni,
long nj,
long nk,
double* MADNESS_RESTRICT
c,
710 const double*
a,
const double*
b);
711 extern void bgpmTxmq(
long ni,
long nj,
long nk,
double_complex* MADNESS_RESTRICT
c,
715 inline void mTxmq(
long ni,
long nj,
long nk,
double* MADNESS_RESTRICT
c,
const double*
a,
const double*
b) {
716 bgpmTxmq(ni, nj, nk,
c,
a,
b);
721 bgpmTxmq(ni, nj, nk,
c,
a,
b);
double q(double t)
Definition DKops.h:18
Define BLAS like functions.
std::complex< double > double_complex
Definition cfft.h:14
char * p(char *buf, const char *name, int k, int initial_level, double thresh, int order)
Definition derivatives.cc:72
auto T(World &world, response_space &f) -> response_space
Definition global_functions.cc:34
Macros and tools pertaining to the configuration of MADNESS.
#define MADNESS_ASSERT(condition)
Assert a condition that should be free of side-effects since in release builds this might be a no-op.
Definition madness_exception.h:134
void gemm(const CBLAS_TRANSPOSE OpA, const CBLAS_TRANSPOSE OpB, const integer m, const integer n, const integer k, const float alpha, const float *a, const integer lda, const float *b, const integer ldb, const float beta, float *c, const integer ldc)
Multiplies a matrix by a vector.
Definition cblas.h:352
@ NoTrans
Definition cblas_types.h:78
@ Trans
Definition cblas_types.h:79
Namespace for all elements and tools of MADNESS.
Definition DFParameters.h:10
static void mTxm_reference(long dimi, long dimj, long dimk, T *MADNESS_RESTRICT c, const Q *MADNESS_RESTRICT a, const S *MADNESS_RESTRICT b)
Matrix += Matrix transpose * matrix ... reference implementation (slow but correct)
Definition mxm.h:91
void mTxmT(long dimi, long dimj, long dimk, T *MADNESS_RESTRICT c, const T *a, const T *b)
Matrix += Matrix transpose * matrix transpose ... MKL interface version.
Definition mxm.h:238
void mxm(long dimi, long dimj, long dimk, T *MADNESS_RESTRICT c, const T *a, const T *b)
Matrix += Matrix * matrix ... BLAS/MKL interface version.
Definition mxm.h:199
static void mxm_reference(long dimi, long dimj, long dimk, T *MADNESS_RESTRICT c, const Q *MADNESS_RESTRICT a, const S *MADNESS_RESTRICT b)
Matrix += Matrix * matrix reference implementation (slow but correct)
Definition mxm.h:68
void mTxm(long dimi, long dimj, long dimk, T *MADNESS_RESTRICT c, const T *a, const T *b)
Matrix += Matrix transpose * matrix ... MKL interface version.
Definition mxm.h:212
static void mxmT_reference(long dimi, long dimj, long dimk, T *MADNESS_RESTRICT c, const Q *MADNESS_RESTRICT a, const S *MADNESS_RESTRICT b)
Matrix += Matrix * matrix transpose ... reference implementation (slow but correct)
Definition mxm.h:114
void mTxmq_reference(long dimi, long dimj, long dimk, cT *MADNESS_RESTRICT c, const aT *a, const bT *b, long ldb=-1)
Matrix = Matrix transpose * matrix ... slow reference implementation.
Definition mxm.h:171
void mTxmq_padding(long dimi, long dimj, long dimk, long ext_b, cT *c, const aT *a, const bT *b)
Definition mtxmq.h:96
static void mTxmT_reference(long dimi, long dimj, long dimk, T *MADNESS_RESTRICT c, const Q *MADNESS_RESTRICT a, const S *MADNESS_RESTRICT b)
Matrix += Matrix transpose * matrix transpose reference implementation (slow but correct)
Definition mxm.h:139
void bgq_mtxmq_padded(long dimi, long dimj, long dimk, long extb, __complex__ double *c_x, const __complex__ double *a_x, const __complex__ double *b_x)
Definition bgq_mtxm.cc:10
void mxmT(long dimi, long dimj, long dimk, T *MADNESS_RESTRICT c, const T *a, const T *b)
Matrix += Matrix * matrix transpose ... MKL interface version.
Definition mxm.h:225
void mTxmq(long dimi, long dimj, long dimk, T *MADNESS_RESTRICT c, const T *a, const T *b, long ldb=-1)
Matrix = Matrix transpose * matrix ... MKL interface version.
Definition mxm.h:257
static const double b
Definition nonlinschro.cc:119
static const double a
Definition nonlinschro.cc:118
double Q(double a)
Definition relops.cc:20
static const double c
Definition relops.cc:10
static const long k
Definition rk.cc:44
AtomicInt sum
Definition test_atomicint.cc:46