40 #include "../common/c71/DSPLIB_inlines.h"
52 #define SE_PARAM_BASE (0x0000)
53 #define SE0_PARAM_OFFSET (SE_PARAM_BASE)
54 #define SE1_PARAM_OFFSET (SE0_PARAM_OFFSET + SE_PARAM_SIZE)
55 #define SA0_PARAM_OFFSET (SE1_PARAM_OFFSET + SE_PARAM_SIZE)
56 #define SA1_PARAM_OFFSET (SA0_PARAM_OFFSET + SE_PARAM_SIZE)
57 #define SA2_PARAM_OFFSET (SA1_PARAM_OFFSET + SE_PARAM_SIZE)
59 #define DSPLIB_MATMUL_DOUBLE_UNROLL_FACTOR (8)
60 #define DSPLIB_MATMUL_SE_DOUBLE_UNROLL_FACTOR (4)
62 #define DSPLIB_MATMUL_FLOAT_UNROLL_FACTOR (16)
63 #define DSPLIB_MATMUL_SE_FLOAT_UNROLL_FACTOR (8)
65 template <
typename dataType>
inline void setUnrollFactors(int32_t *unrollFactor, int32_t *seUnrollFactor);
78 template <
typename dataType>
86 __SE_TEMPLATE_v1 se0Params;
87 __SE_TEMPLATE_v1 se1Params;
88 __SA_TEMPLATE_v1 sa0Params;
89 __SA_TEMPLATE_v1 sa1Params;
90 __SA_TEMPLATE_v1 sa2Params;
92 __SE_ELETYPE SE_ELETYPE;
93 __SE_VECLEN SE_VECLEN;
94 __SA_VECLEN SA_VECLEN;
98 uint8_t *pBlock = pKerPrivArgs->
bufPblock;
100 int32_t M = pKerPrivArgs->
M;
101 int32_t K = pKerPrivArgs->
K;
102 int32_t N = pKerPrivArgs->
N;
107 int32_t unrollFactor = 0;
108 int32_t seUnrollFactor = 0;
110 setUnrollFactors<dataType>(&unrollFactor, &seUnrollFactor);
112 typedef typename c7x::make_full_vector<dataType>::type vec;
114 int32_t elementCount = c7x::element_count_of<vec>::value;
115 SE_VECLEN = c7x::se_veclen<vec>::value;
116 SA_VECLEN = c7x::sa_veclen<vec>::value;
117 SE_ELETYPE = c7x::se_eletype<vec>::value;
118 int32_t KBlocks = ((K + unrollFactor - 1)) / (unrollFactor);
119 int32_t NBlocks = ((N + elementCount - 1)) / (elementCount);
121 pKerPrivArgs->
KBlocks = KBlocks;
122 pKerPrivArgs->
NBlocks = NBlocks;
128 sa0Params = __gen_SA_TEMPLATE_v1();
129 sa0Params.VECLEN = SA_VECLEN;
130 sa0Params.DIMFMT = __SA_DIMFMT_5D;
131 sa0Params.DECDIM1 = __SA_DECDIM_DIM2;
132 sa0Params.DECDIM1SD = __SA_DECDIMSD_DIM1;
135 sa0Params.ICNT1 = seUnrollFactor;
137 sa0Params.ICNT2 = KBlocks;
138 sa0Params.DIM2 = (int32_t) ((uint32_t) seUnrollFactor << (uint32_t) 1);
139 sa0Params.DECDIM1_WIDTH = K;
140 sa0Params.ICNT3 = NBlocks;
143 sa0Params.DIM4 = strideIn0;
149 sa2Params = __gen_SA_TEMPLATE_v1();
150 sa2Params.VECLEN = SA_VECLEN;
151 sa2Params.DIMFMT = __SA_DIMFMT_5D;
152 sa2Params.DECDIM1 = __SA_DECDIM_DIM2;
153 sa2Params.DECDIM1SD = __SA_DECDIMSD_DIM1;
156 sa2Params.ICNT1 = seUnrollFactor;
158 sa2Params.ICNT2 = KBlocks;
159 sa2Params.DIM2 = (int32_t) ((uint32_t) seUnrollFactor << (uint32_t) 1);
160 sa2Params.DECDIM1_WIDTH = (uint32_t) (K % 2 == 0 ? K : K - 1);
161 sa2Params.ICNT3 = NBlocks;
164 sa2Params.DIM4 = strideIn0;
170 se0Params = __gen_SE_TEMPLATE_v1();
171 se0Params.ELETYPE = SE_ELETYPE;
172 se0Params.VECLEN = SE_VECLEN;
173 se0Params.DIMFMT = __SE_DIMFMT_5D;
174 se0Params.DECDIM1 = __SE_DECDIM_DIM2;
175 se0Params.DECDIM2 = __SE_DECDIM_DIM3;
176 se0Params.DECDIM1SD = __SE_DECDIMSD_DIM1;
177 se0Params.DECDIM2SD = __SE_DECDIMSD_DIM0;
179 se0Params.ICNT0 = elementCount;
180 se0Params.ICNT1 = seUnrollFactor;
181 se0Params.DIM1 = (int32_t) ((uint32_t) strideIn1 << (uint32_t) 1);
182 se0Params.ICNT2 = KBlocks;
183 se0Params.DIM2 = seUnrollFactor * (int32_t) ((uint32_t) strideIn1 << (uint32_t) 1);
184 se0Params.DECDIM1_WIDTH = (uint32_t) K * strideIn1;
185 se0Params.ICNT3 = NBlocks;
186 se0Params.DIM3 = elementCount;
187 se0Params.DECDIM2_WIDTH = (uint32_t) N;
195 se1Params = __gen_SE_TEMPLATE_v1();
196 se1Params.ELETYPE = SE_ELETYPE;
197 se1Params.VECLEN = SE_VECLEN;
198 se1Params.DIMFMT = __SE_DIMFMT_5D;
199 se1Params.DECDIM1 = __SE_DECDIM_DIM2;
200 se1Params.DECDIM2 = __SE_DECDIM_DIM3;
201 se1Params.DECDIM1SD = __SE_DECDIMSD_DIM1;
202 se1Params.DECDIM2SD = __SE_DECDIMSD_DIM0;
204 se1Params.ICNT0 = elementCount;
205 se1Params.ICNT1 = seUnrollFactor;
206 se1Params.DIM1 = (int32_t) ((uint32_t) strideIn1 << (uint32_t) 1);
207 se1Params.ICNT2 = KBlocks;
208 se1Params.DIM2 = seUnrollFactor * (int32_t) ((uint32_t) strideIn1 << (uint32_t) 1);
209 se1Params.DECDIM1_WIDTH = (uint32_t) (K % 2 == 0 ? K : K - 1) * strideIn1;
210 se1Params.ICNT3 = NBlocks;
211 se1Params.DIM3 = elementCount;
212 se1Params.DECDIM2_WIDTH = (uint32_t) N;
220 sa1Params = __gen_SA_TEMPLATE_v1();
221 sa1Params.VECLEN = SA_VECLEN;
222 sa1Params.DIMFMT = __SA_DIMFMT_3D;
223 sa1Params.DECDIM1 = __SA_DECDIM_DIM1;
224 sa1Params.DECDIM1SD = __SA_DECDIMSD_DIM0;
226 sa1Params.ICNT0 = elementCount;
227 sa1Params.ICNT1 = NBlocks;
228 sa1Params.DIM1 = elementCount;
229 sa1Params.DECDIM1_WIDTH = N;
231 sa1Params.DIM2 = strideOut;
254 template <
typename T,
typename vec, u
int32_t
id>
static inline vec
loadAMatSA(__vpred tmp,
void *pIn)
259 tmp = c7x::strm_agen<id, T>::get_vpred();
260 out = __vload_pred_dup(tmp, (c7x::strm_agen<id, T>::get_adv(pIn)));
262 DSPLIB_debugPrintVector(out);
267 template <
typename T,
typename vec>
static inline void writeOutSA1(__vpred tmp, vec *addr, T pOut, vec out)
271 tmp = c7x::strm_agen<1, vec>::get_vpred();
272 addr = c7x::strm_agen<1, vec>::get_adv(pOut);
273 __vstore_pred(tmp, addr, out);
277 template <
typename dataType>
281 void *restrict pOut);
291 int32_t M = pKerPrivArgs->
M;
292 int32_t K = pKerPrivArgs->
K;
293 int32_t KBlocks = pKerPrivArgs->
KBlocks;
294 int32_t NBlocks = pKerPrivArgs->
NBlocks;
296 __SE_TEMPLATE_v1 se0Params;
297 __SE_TEMPLATE_v1 se1Params;
298 __SA_TEMPLATE_v1 sa0Params;
299 __SA_TEMPLATE_v1 sa1Params;
300 __SA_TEMPLATE_v1 sa2Params;
302 #if DSPLIB_DEBUGPRINT
303 printf(
"Enter DSPLIB_matMul_exec_ci\n");
306 typedef typename c7x::make_full_vector<double>::type vec;
308 uint8_t *pBlock = pKerPrivArgs->
bufPblock;
320 __SE0_OPEN(pIn1, se0Params);
321 __SE1_OPEN(((
double *) pIn1 + strideIn1), se1Params);
324 __SA0_OPEN(sa0Params);
325 __SA2_OPEN(sa2Params);
326 __SA1_OPEN(sa1Params);
331 vec r00, r01, r03, r02, r04, r05, r06, r07;
341 DSPLIB_DEBUGPRINTFN(1,
"\nIn _ci.cpp M = %d, NBlocks = %d, KBlocks = %d: \n", M, NBlocks, KBlocks);
343 vec a00, a01, a02, a03;
345 double *pIn0Local = ((
double *) pIn0 + 1);
348 a00 = loadAMatSA<double, vec, 2>(predA, pIn0Local);
349 a01 = loadAMatSA<double, vec, 2>(predA, pIn0Local);
350 a02 = loadAMatSA<double, vec, 2>(predA, pIn0Local);
351 a03 = loadAMatSA<double, vec, 2>(predA, pIn0Local);
355 for (int32_t mn = 0; mn < M * NBlocks; mn++) {
366 for (int32_t k = 0; k < KBlocks; k++) {
368 a = loadAMatSA<double, vec, 0>(predA, pIn0);
369 b = c7x::strm_eng<0, vec>::get_adv();
372 b = c7x::strm_eng<1, vec>::get_adv();
375 a = loadAMatSA<double, vec, 0>(predA, pIn0);
376 b = c7x::strm_eng<0, vec>::get_adv();
379 b = c7x::strm_eng<1, vec>::get_adv();
382 a = loadAMatSA<double, vec, 0>(predA, pIn0);
383 b = c7x::strm_eng<0, vec>::get_adv();
386 b = c7x::strm_eng<1, vec>::get_adv();
389 a = loadAMatSA<double, vec, 0>(predA, pIn0);
390 b = c7x::strm_eng<0, vec>::get_adv();
393 b = c7x::strm_eng<1, vec>::get_adv();
396 a00 = loadAMatSA<double, vec, 2>(predA, pIn0Local);
397 a01 = loadAMatSA<double, vec, 2>(predA, pIn0Local);
398 a02 = loadAMatSA<double, vec, 2>(predA, pIn0Local);
399 a03 = loadAMatSA<double, vec, 2>(predA, pIn0Local);
419 se0Params = *(__SE_TEMPLATE_v1 *) ((uint8_t *) pBlock);
421 sa0Params = *(__SA_TEMPLATE_v1 *) ((uint8_t *) pBlock + (2 * SE_PARAM_SIZE));
422 sa1Params = *(__SA_TEMPLATE_v1 *) ((uint8_t *) pBlock + (3 * SE_PARAM_SIZE));
425 __SE0_OPEN(pIn1, se0Params);
426 __SA0_OPEN(sa0Params);
429 __SA1_OPEN(sa1Params);
434 vec r00, r02, r04, r06;
444 DSPLIB_DEBUGPRINTFN(1,
"\nIn _ci.cpp M = %d, NBlocks = %d, KBlocks = %d: \n", M, NBlocks, KBlocks);
448 for (int32_t mn = 0; mn < M * NBlocks; mn++) {
455 for (int32_t k = 0; k < KBlocks; k++) {
457 a = loadAMatSA<double, vec, 0>(predA, pIn0);
458 b = c7x::strm_eng<0, vec>::get_adv();
461 a = loadAMatSA<double, vec, 0>(predA, pIn0);
462 b = c7x::strm_eng<0, vec>::get_adv();
465 a = loadAMatSA<double, vec, 0>(predA, pIn0);
466 b = c7x::strm_eng<0, vec>::get_adv();
469 a = loadAMatSA<double, vec, 0>(predA, pIn0);
470 b = c7x::strm_eng<0, vec>::get_adv();
501 int32_t M = pKerPrivArgs->
M;
502 int32_t K = pKerPrivArgs->
K;
503 int32_t KBlocks = pKerPrivArgs->
KBlocks;
504 int32_t NBlocks = pKerPrivArgs->
NBlocks;
506 __SE_TEMPLATE_v1 se0Params;
507 __SE_TEMPLATE_v1 se1Params;
508 __SA_TEMPLATE_v1 sa0Params;
509 __SA_TEMPLATE_v1 sa1Params;
510 __SA_TEMPLATE_v1 sa2Params;
512 #if DSPLIB_DEBUGPRINT
513 printf(
"Enter DSPLIB_matMul_exec_ci\n");
516 typedef typename c7x::make_full_vector<float>::type vec;
518 uint8_t *pBlock = pKerPrivArgs->
bufPblock;
530 __SE0_OPEN(pIn1, se0Params);
531 __SE1_OPEN(((
float *) pIn1 + strideIn1), se1Params);
534 __SA0_OPEN(sa0Params);
535 __SA2_OPEN(sa2Params);
536 __SA1_OPEN(sa1Params);
541 vec r00, r01, r03, r02, r04, r05, r06, r07;
542 vec r08, r09, r0a, r0b, r0c, r0d, r0e, r0f;
553 DSPLIB_DEBUGPRINTFN(1,
"\nIn _ci.cpp M = %d, NBlocks = %d, KBlocks = %d: \n", M, NBlocks, KBlocks);
555 vec a00, a01, a02, a03;
556 vec a04, a05, a06, a07;
558 float *pIn0Local = ((
float *) pIn0 + 1);
561 a00 = loadAMatSA<float, vec, 2>(predA, pIn0Local);
562 a01 = loadAMatSA<float, vec, 2>(predA, pIn0Local);
563 a02 = loadAMatSA<float, vec, 2>(predA, pIn0Local);
564 a03 = loadAMatSA<float, vec, 2>(predA, pIn0Local);
565 a04 = loadAMatSA<float, vec, 2>(predA, pIn0Local);
566 a05 = loadAMatSA<float, vec, 2>(predA, pIn0Local);
567 a06 = loadAMatSA<float, vec, 2>(predA, pIn0Local);
568 a07 = loadAMatSA<float, vec, 2>(predA, pIn0Local);
578 for (int32_t mn = 0; mn < M * NBlocks; mn++) {
599 for (int32_t k = 0; k < KBlocks; k++) {
601 a = loadAMatSA<float, vec, 0>(predA, pIn0);
604 b = c7x::strm_eng<0, vec>::get_adv();
611 b = c7x::strm_eng<1, vec>::get_adv();
616 a = loadAMatSA<float, vec, 0>(predA, pIn0);
619 b = c7x::strm_eng<0, vec>::get_adv();
626 b = c7x::strm_eng<1, vec>::get_adv();
631 a = loadAMatSA<float, vec, 0>(predA, pIn0);
632 b = c7x::strm_eng<0, vec>::get_adv();
635 b = c7x::strm_eng<1, vec>::get_adv();
638 a = loadAMatSA<float, vec, 0>(predA, pIn0);
639 b = c7x::strm_eng<0, vec>::get_adv();
642 b = c7x::strm_eng<1, vec>::get_adv();
645 a = loadAMatSA<float, vec, 0>(predA, pIn0);
646 b = c7x::strm_eng<0, vec>::get_adv();
649 b = c7x::strm_eng<1, vec>::get_adv();
652 a = loadAMatSA<float, vec, 0>(predA, pIn0);
653 b = c7x::strm_eng<0, vec>::get_adv();
656 b = c7x::strm_eng<1, vec>::get_adv();
659 a = loadAMatSA<float, vec, 0>(predA, pIn0);
660 b = c7x::strm_eng<0, vec>::get_adv();
663 b = c7x::strm_eng<1, vec>::get_adv();
666 a = loadAMatSA<float, vec, 0>(predA, pIn0);
667 b = c7x::strm_eng<0, vec>::get_adv();
670 b = c7x::strm_eng<1, vec>::get_adv();
673 a00 = loadAMatSA<float, vec, 2>(predA, pIn0Local);
674 a01 = loadAMatSA<float, vec, 2>(predA, pIn0Local);
675 a02 = loadAMatSA<float, vec, 2>(predA, pIn0Local);
676 a03 = loadAMatSA<float, vec, 2>(predA, pIn0Local);
678 a04 = loadAMatSA<float, vec, 2>(predA, pIn0Local);
679 a05 = loadAMatSA<float, vec, 2>(predA, pIn0Local);
680 a06 = loadAMatSA<float, vec, 2>(predA, pIn0Local);
681 a07 = loadAMatSA<float, vec, 2>(predA, pIn0Local);
717 se0Params = *(__SE_TEMPLATE_v1 *) ((uint8_t *) pBlock);
719 sa0Params = *(__SA_TEMPLATE_v1 *) ((uint8_t *) pBlock + (2 * SE_PARAM_SIZE));
720 sa1Params = *(__SA_TEMPLATE_v1 *) ((uint8_t *) pBlock + (3 * SE_PARAM_SIZE));
723 __SE0_OPEN(pIn1, se0Params);
724 __SA0_OPEN(sa0Params);
727 __SA1_OPEN(sa1Params);
732 vec r00, r02, r04, r06;
733 vec r08, r0a, r0c, r0e;
744 DSPLIB_DEBUGPRINTFN(1,
"\nIn _ci.cpp M = %d, NBlocks = %d, KBlocks = %d: \n", M, NBlocks, KBlocks);
748 for (int32_t mn = 0; mn < M * NBlocks; mn++) {
759 for (int32_t k = 0; k < KBlocks; k++) {
761 a = loadAMatSA<float, vec, 0>(predA, pIn0);
762 b = c7x::strm_eng<0, vec>::get_adv();
765 a = loadAMatSA<float, vec, 0>(predA, pIn0);
766 b = c7x::strm_eng<0, vec>::get_adv();
769 a = loadAMatSA<float, vec, 0>(predA, pIn0);
770 b = c7x::strm_eng<0, vec>::get_adv();
773 a = loadAMatSA<float, vec, 0>(predA, pIn0);
774 b = c7x::strm_eng<0, vec>::get_adv();
777 a = loadAMatSA<float, vec, 0>(predA, pIn0);
778 b = c7x::strm_eng<0, vec>::get_adv();
781 a = loadAMatSA<float, vec, 0>(predA, pIn0);
782 b = c7x::strm_eng<0, vec>::get_adv();
785 a = loadAMatSA<float, vec, 0>(predA, pIn0);
786 b = c7x::strm_eng<0, vec>::get_adv();
789 a = loadAMatSA<float, vec, 0>(predA, pIn0);
790 b = c7x::strm_eng<0, vec>::get_adv();
824 template <
typename dataType>
828 DSPLIB_matMul_generic_core_ci<dataType>(handle, pIn0, pIn1, pOut);
835 void *restrict pOut);
839 void *restrict pOut);
template DSPLIB_STATUS DSPLIB_matMul_generic_init_ci< double >(DSPLIB_kernelHandle handle, const DSPLIB_bufParams2D_t *bufParamsIn0, const DSPLIB_bufParams2D_t *bufParamsIn1, const DSPLIB_bufParams2D_t *bufParamsOut, const DSPLIB_matMul_InitArgs *pKerInitArgs)
DSPLIB_STATUS DSPLIB_matMul_generic_core_ci< float >(DSPLIB_kernelHandle handle, void *restrict pIn0, void *restrict pIn1, void *restrict pOut)
#define DSPLIB_MATMUL_FLOAT_UNROLL_FACTOR
template DSPLIB_STATUS DSPLIB_matMul_generic_init_ci< float >(DSPLIB_kernelHandle handle, const DSPLIB_bufParams2D_t *bufParamsIn0, const DSPLIB_bufParams2D_t *bufParamsIn1, const DSPLIB_bufParams2D_t *bufParamsOut, const DSPLIB_matMul_InitArgs *pKerInitArgs)
DSPLIB_STATUS DSPLIB_matMul_generic_init_ci(DSPLIB_kernelHandle handle, const DSPLIB_bufParams2D_t *bufParamsIn0, const DSPLIB_bufParams2D_t *bufParamsIn1, const DSPLIB_bufParams2D_t *bufParamsOut, const DSPLIB_matMul_InitArgs *pKerInitArgs)
DSPLIB_STATUS DSPLIB_matMul_generic_core_ci< double >(DSPLIB_kernelHandle handle, void *restrict pIn0, void *restrict pIn1, void *restrict pOut)
void setUnrollFactors< float >(int32_t *unrollFactor, int32_t *seUnrollFactor)
#define DSPLIB_MATMUL_DOUBLE_UNROLL_FACTOR
template DSPLIB_STATUS DSPLIB_matMul_generic_exec_ci< float >(DSPLIB_kernelHandle handle, void *restrict pIn0, void *restrict pIn1, void *restrict pOut)
static void writeOutSA1(__vpred tmp, vec *addr, T pOut, vec out)
void setUnrollFactors(int32_t *unrollFactor, int32_t *seUnrollFactor)
DSPLIB_STATUS DSPLIB_matMul_generic_exec_ci(DSPLIB_kernelHandle handle, void *restrict pIn0, void *restrict pIn1, void *restrict pOut)
DSPLIB_STATUS DSPLIB_matMul_generic_core_ci(DSPLIB_kernelHandle handle, void *restrict pIn0, void *restrict pIn1, void *restrict pOut)
#define DSPLIB_MATMUL_SE_FLOAT_UNROLL_FACTOR
template DSPLIB_STATUS DSPLIB_matMul_generic_exec_ci< double >(DSPLIB_kernelHandle handle, void *restrict pIn0, void *restrict pIn1, void *restrict pOut)
#define DSPLIB_MATMUL_SE_DOUBLE_UNROLL_FACTOR
void setUnrollFactors< double >(int32_t *unrollFactor, int32_t *seUnrollFactor)
static vec loadAMatSA(__vpred tmp, void *pIn)
Header file for kernel's internal use. For the kernel's interface, please see DSPLIB_matMul.
#define DSPLIB_DEBUGPRINTFN(N, fmt,...)
DSPLIB_STATUS_NAME
The enumeration of all status codes.
void * DSPLIB_kernelHandle
Handle type for DSPLIB operations.
A structure for a 2 dimensional buffer descriptor.
Structure containing the parameters to initialize the kernel.
Structure that is reserved for internal use by the kernel.
uint8_t bufPblock[DSPLIB_MATMUL_IXX_IXX_OXX_PBLOCK_SIZE]
int32_t strideIn1Elements
int32_t strideIn0Elements
int32_t strideOutElements