/* Maximum number of implicit-shift QR iterations allowed per singular value */
#define MAX_ITERATION_COUNT 30
template <typename dataType>
int DSPF_sp_convert_to_bidiag_cn(const int Nrows, const int Ncols, dataType *U, dataType *V,
                                 dataType *diag, dataType *superdiag, const int colUStride,
                                 const int colVStride, uint32_t enableReducedForm)
{
    int i, j, k;
    dataType s, s2, si, scale, half_norm_squared;
    /* Golub-Kahan Householder reduction of U to upper-bidiagonal form.
       Column step: build a reflector that zeroes U(i+1..Nrows-1, i) and
       apply it to the trailing columns; diag[]/superdiag[] collect the result. */
    for (i = 0; i < Ncols; i++) {
        superdiag[i] = scale * s;
        /* ... */
        for (j = i; j < Nrows; j++) {
            scale += fabs(U[i + j * colUStride]);
        }
        /* ... */
        for (j = i; j < Nrows; j++) {
            U[i + j * colUStride] = U[i + j * colUStride] / scale;
            s2 += U[i + j * colUStride] * U[i + j * colUStride];
        }
        /* s = ±sqrt(s2), signed to avoid cancellation in U(i, i) */
        if (U[i + i * colUStride] < 0) {
            /* ... */
        }
        half_norm_squared = U[i + i * colUStride] * s - s2;
        U[i + i * colUStride] -= s;
        /* Apply the reflector to the remaining columns. */
        for (j = i + 1; j < Ncols; j++) {
            for (k = i; k < Nrows; k++) {
                si += U[i + k * colUStride] * U[j + k * colUStride];
            }
            si = si / half_norm_squared;
            for (k = i; k < Nrows; k++) {
                U[j + k * colUStride] += si * U[i + k * colUStride];
            }
        }
        /* Undo the column scaling. */
        for (j = i; j < Nrows; j++) {
            U[i + j * colUStride] *= scale;
        }
        /* Row step: build a reflector that zeroes U(i, i+2..Ncols-1) and apply
           it to the trailing rows (skipped for the last column). */
        if (i != Ncols - 1) {
            for (j = i + 1; j < Ncols; j++) {
                scale += fabs(U[j + i * colUStride]);
            }
            /* ... */
            for (j = i + 1; j < Ncols; j++) {
                U[j + i * colUStride] = U[j + i * colUStride] / scale;
                s2 += U[j + i * colUStride] * U[j + i * colUStride];
            }
            /* s = ±sqrt(s2), signed to avoid cancellation in U(i, i+1) */
            if (U[i + 1 + i * colUStride] < 0) {
                /* ... */
            }
            half_norm_squared = U[i + 1 + i * colUStride] * s - s2;
            U[i + 1 + i * colUStride] -= s;
            /* Cache the scaled reflector in superdiag[] for reuse below. */
            for (k = i + 1; k < Ncols; k++) {
                superdiag[k] = U[k + i * colUStride] / half_norm_squared;
            }
            /* Apply the reflector to the remaining rows. */
            for (j = i + 1; j < Nrows; j++) {
                for (k = i + 1; k < Ncols; k++) {
                    si += U[k + i * colUStride] * U[k + j * colUStride];
                }
                for (k = i + 1; k < Ncols; k++) {
                    U[k + j * colUStride] += si * superdiag[k];
                }
            }
            /* Undo the row scaling. */
            for (k = i + 1; k < Ncols; k++) {
                U[k + i * colUStride] *= scale;
            }
        /* ... */
    /* Accumulate the right Householder reflectors into V
       (back accumulation, from the last column towards the first). */
    V[(Ncols - 1) + (Ncols - 1) * colVStride] = 1;
    s = superdiag[Ncols - 1];
    for (i = Ncols - 2; i >= 0; i--) {
        /* ... */
        for (j = i + 1; j < Ncols; j++) {
            V[i + j * colVStride] = U[j + i * colUStride] / (U[i + 1 + i * colUStride] * s);
        }
        for (j = i + 1; j < Ncols; j++) {
            for (k = i + 1; k < Ncols; k++) {
                si += U[k + i * colUStride] * V[j + k * colVStride];
            }
            for (k = i + 1; k < Ncols; k++) {
                V[j + k * colVStride] += si * V[i + k * colVStride];
            }
        }
        /* Zero out row i and column i of V, then set the unit diagonal entry. */
        for (j = i + 1; j < Ncols; j++) {
            V[j + i * colVStride] = 0;
            V[i + j * colVStride] = 0;
        }
        V[i + i * colVStride] = 1;
        /* ... */
    /* Accumulate the left reflectors into U.
       Full form (enableReducedForm == 0): U is expanded to Nrows x Nrows;
       reduced form: only the leading Ncols columns of U are formed. */
    if (enableReducedForm == 0u) {
        /* Expand U in place: keep the leading Ncols columns, zero the rest. */
        for (i = Nrows - 1; i >= 0; i--) {
            for (j = Nrows - 1; j >= 0; j--) {
                if (j <= Ncols - 1) {
                    U[j + i * colUStride] = U[j + i * colUStride];
                }
                else {
                    U[j + i * colUStride] = 0;
                }
            }
        }
        for (i = Ncols - 1; i >= 0; i--) {
            for (j = i + 1; j < Ncols; j++) {
                U[j + i * colUStride] = 0;
            }
            /* ... */
            for (j = i + 1; j < Nrows; j++) {
                for (k = i + 1; k < Nrows; k++) {
                    si += U[i + k * colUStride] * U[j + k * colUStride];
                }
                si = si / (U[i + i * colUStride] * s);
                for (k = i; k < Nrows; k++) {
                    U[j + k * colUStride] += si * U[i + k * colUStride];
                }
            }
            /* Columns Ncols..Nrows-1 exist only in the full form and are
               handled when the last reflector is reached. */
            if (i == Ncols - 1) {
                for (j = i; j < Nrows; j++) {
                    for (k = Nrows - 1; k >= i + 1; k--) {
                        U[k + j * colUStride] =
                            U[i + j * colUStride] * U[i + k * colUStride] / (U[i + i * colUStride] * s);
                        /* ... */
                        U[k + j * colUStride] += 1;
                    }
                }
            }
            for (j = i; j < Nrows; j++) {
                U[i + j * colUStride] = U[i + j * colUStride] / s;
            }
            /* ... */
            if (i == Ncols - 1) {
                for (k = 1; k <= Nrows - Ncols; k++) {
                    U[i + k + (i + k) * colUStride] = 1;
                }
            }
            /* ... */
            for (j = i; j < Nrows; j++) {
                U[i + j * colUStride] = 0;
            }
            /* ... */
            U[i + i * colUStride] += 1;
        }
    }
    else {
        /* Reduced form: accumulate only the leading Ncols columns of U. */
        for (i = Ncols - 1; i >= 0; i--) {
            for (j = i + 1; j < Ncols; j++) {
                U[j + i * colUStride] = 0;
            }
            /* ... */
            for (j = i + 1; j < Ncols; j++) {
                for (k = i + 1; k < Nrows; k++) {
                    si += U[i + k * colUStride] * U[j + k * colUStride];
                }
                si = si / (U[i + i * colUStride] * s);
                for (k = i; k < Nrows; k++) {
                    U[j + k * colUStride] += si * U[i + k * colUStride];
                }
            }
            for (j = i; j < Nrows; j++) {
                U[i + j * colUStride] = U[i + j * colUStride] / s;
            }
            /* ... */
            for (j = i; j < Nrows; j++) {
                U[i + j * colUStride] = 0;
            }
            /* ... */
            U[i + i * colUStride] += 1;
        }
    }
    /* ... */
}

/* Explicit instantiations. */
template int DSPF_sp_convert_to_bidiag_cn<float>(const int Nrows, const int Ncols, float *U, float *V,
    float *diag, float *superdiag, const int colUStride, const int colVStride, uint32_t enableReducedForm);
template int DSPF_sp_convert_to_bidiag_cn<double>(const int Nrows, const int Ncols, double *U, double *V,
    double *diag, double *superdiag, const int colUStride, const int colVStride, uint32_t enableReducedForm);
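/* The column step above is the textbook Householder reflector H = I - 2*v*v^T/(v^T*v),
   applied without ever forming H: the unnormalized reflector v is kept in column i of U
   and the factor -2/(v^T*v) is folded into half_norm_squared (= U(i,i)*s - s2). The
   following standalone sketch mirrors that update on plain arrays; it is illustrative
   only and is not part of the DSPLIB sources. */
#include <cmath>

/* Reflect `other` by the Householder transform that would zero x[1..n-1].
   x is overwritten with the (unnormalized) reflector vector v = x - s*e0. */
static void householder_apply_sketch(double *x, double *other, int n)
{
    double s2 = 0.0;
    for (int k = 0; k < n; k++) {
        s2 += x[k] * x[k];
    }
    /* The sign of s is chosen opposite to x[0] so that x[0] - s does not cancel. */
    double s = (x[0] < 0.0) ? std::sqrt(s2) : -std::sqrt(s2);
    double half_norm_squared = x[0] * s - s2;   /* equals -(v^T*v)/2 */
    x[0] -= s;                                  /* x now holds v */

    /* other <- H*other = other + (v^T*other / half_norm_squared) * v */
    double si = 0.0;
    for (int k = 0; k < n; k++) {
        si += x[k] * other[k];
    }
    si = si / half_norm_squared;
    for (int k = 0; k < n; k++) {
        other[k] += si * x[k];
    }
    /* Applying H to the original x itself would yield (s, 0, ..., 0). */
}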
/* Golub-Reinsch implicit-shift QR iteration: diagonalize the bidiagonal matrix
   held in diag[]/superdiag[], accumulating the rotations into U and V. */
template <typename dataType>
int DSPF_sp_bidiag_to_diag_cn(const int Nrows, const int Ncols, dataType *U, dataType *V,
                              dataType *diag, dataType *superdiag, const int colUStride,
                              const int colVStride, uint32_t enableReducedForm)
{
    int row, i, k, m, rotation_test, iter, total_iter;
    dataType x, y, z, epsilon;
    dataType c, s, f, g, h;
    /* Convergence threshold: machine epsilon scaled by the largest |diag| + |superdiag|. */
    for (i = 0; i < Ncols; i++) {
        y = fabs(diag[i]) + fabs(superdiag[i]);
        /* ... track the maximum in x ... */
    }
    if (sizeof(dataType) == 4) {
        epsilon = FLT_EPSILON * x;
    }
    else {
        epsilon = DBL_EPSILON * x;
    }
    /* Process the singular values from the last to the first. */
    for (k = Ncols - 1; k >= 0; k--) {
        /* ... */
        /* Splitting test: find the largest m with a negligible superdiag[m];
           a negligible diag[m-1] instead triggers a cancellation sweep. */
        for (m = k; m > 0; m--) {
            if (fabs(superdiag[m]) <= epsilon) {
                /* ... */
            }
            if (fabs(diag[m - 1]) <= epsilon) {
                /* ... */
            }
        }
        /* Cancellation of superdiag[m..k] by Givens rotations when diag[m-1] is negligible. */
        for (i = m; i <= k; i++) {
            f = s * superdiag[i];
            superdiag[i] = c * superdiag[i];
#if !defined(ENABLE_LDRA_COVERAGE)
            if (fabs(f) <= epsilon) {
                /* ... */
            }
            /* ... */
            h = sqrt(f * f + g * g);
            /* Apply the cancellation rotation to columns m-1 and i of U. */
            if (enableReducedForm == 0u) {
                for (row = 0; row < Nrows; row++) {
                    y = U[m - 1 + row * colUStride];
                    z = U[i + row * colUStride];
                    U[m - 1 + row * colUStride] = y * c + z * s;
                    U[i + row * colUStride] = -y * s + z * c;
                }
            }
            else {
                for (row = 0; row < Nrows; row++) {
                    y = U[m - 1 + row * colUStride];
                    z = U[i + row * colUStride];
                    U[m - 1 + row * colUStride] = y * c + z * s;
                    U[i + row * colUStride] = -y * s + z * c;
                }
            }
        /* ... */
        /* If the converged singular value is negative, flip its sign and
           negate the corresponding column of V. */
        for (row = 0; row < Ncols; row++) {
            V[k + row * colVStride] = -V[k + row * colVStride];
        }
#if !defined(ENABLE_LDRA_COVERAGE)
        /* ... */
        /* Implicit shift computed from the trailing 2x2 block of B^T*B (Golub-Reinsch). */
        g = superdiag[k - 1];
        /* ... */
        f = ((y - z) * (y + z) + (g - h) * (g + h)) / (2 * h * y);
        /* ... */
        f = ((x - z) * (x + z) + h * (y / (f + g) - h)) / x;
        /* ... */
        /* Chase the bulge down the bidiagonal with Givens rotations. */
        for (i = m + 1; i <= k; i++) {
            /* ... */
            z = sqrt(f * f + h * h);
            superdiag[i - 1] = z;
            /* ... */
            for (row = 0; row < Ncols; row++) {
                x = V[i - 1 + row * colVStride];
                z = V[i + row * colVStride];
                V[i - 1 + row * colVStride] = x * c + z * s;
                V[i + row * colVStride] = -x * s + z * c;
            }
            z = sqrt(f * f + h * h);
            /* ... */
#if !defined(ENABLE_LDRA_COVERAGE)
            /* ... */
            /* Apply the left rotation to columns i-1 and i of U. */
            if (enableReducedForm == 0u) {
                for (row = 0; row < Nrows; row++) {
                    y = U[i - 1 + row * colUStride];
                    z = U[i + row * colUStride];
                    U[i - 1 + row * colUStride] = c * y + s * z;
                    U[i + row * colUStride] = -s * y + c * z;
                }
            }
            else {
                for (row = 0; row < Nrows; row++) {
                    y = U[i - 1 + row * colUStride];
                    z = U[i + row * colUStride];
                    U[i - 1 + row * colUStride] = c * y + s * z;
                    U[i + row * colUStride] = -s * y + c * z;
                }
            }
    /* ... */
}

/* Explicit instantiations. */
template int DSPF_sp_bidiag_to_diag_cn<float>(const int Nrows, const int Ncols, float *U, float *V,
    float *diag, float *superdiag, const int colUStride, const int colVStride, uint32_t enableReducedForm);
template int DSPF_sp_bidiag_to_diag_cn<double>(const int Nrows, const int Ncols, double *U, double *V,
    double *diag, double *superdiag, const int colUStride, const int colVStride, uint32_t enableReducedForm);
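/* Each pass of the loop above works with 2x2 Givens rotations: given a pair (f, g),
   choose c and s with c^2 + s^2 = 1 so that the rotated pair becomes (r, 0), where
   r = sqrt(f*f + g*g); the same (c, s) is then applied to the corresponding column
   pairs of U and V. A minimal sketch of that construction (illustrative only, not
   the DSPLIB code; callers must ensure f and g are not both zero): */
#include <cmath>

static void givens_sketch(double f, double g, double *c, double *s)
{
    double r = std::sqrt(f * f + g * g);
    *c = f / r;    /* cosine */
    *s = g / r;    /* sine   */
    /* Check: c*f + s*g == r and -s*f + c*g == 0, matching the
       "y*c + z*s" / "-y*s + z*c" update pattern used above. */
}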
/* Sort the singular values into descending order, permuting the columns of U and V to match. */
template <typename dataType>
int DSPF_sp_sort_singular_values_cn(const int Nrows, const int Ncols, dataType *U, dataType *V,
                                    dataType *singular_values, const int colUStride,
                                    const int colVStride, uint32_t enableReducedForm)
{
    int i, j, row, max_index;
    dataType temp;
    /* Selection sort: find the largest remaining singular value and swap it,
       together with its U and V columns, into position i. */
    for (i = 0; i < Ncols - 1; i++) {
        max_index = i;
        for (j = i + 1; j < Ncols; j++) {
            if (singular_values[j] > singular_values[max_index]) {
                max_index = j;
            }
        }
        if (max_index != i) {
            temp = singular_values[i];
            singular_values[i] = singular_values[max_index];
            singular_values[max_index] = temp;
            /* Swap the corresponding columns of U ... */
            if (enableReducedForm == 0u) {
                for (row = 0; row < Nrows; row++) {
                    temp = U[max_index + row * colUStride];
                    U[max_index + row * colUStride] = U[i + row * colUStride];
                    U[i + row * colUStride] = temp;
                }
            }
            else {
                for (row = 0; row < Nrows; row++) {
                    temp = U[max_index + row * colUStride];
                    U[max_index + row * colUStride] = U[i + row * colUStride];
                    U[i + row * colUStride] = temp;
                }
            }
            /* ... and of V. */
            for (row = 0; row < Ncols; row++) {
                temp = V[max_index + row * colVStride];
                V[max_index + row * colVStride] = V[i + row * colVStride];
                V[i + row * colVStride] = temp;
            }
        }
    }
    /* ... */
}
/* Explicit instantiations. */
template int DSPF_sp_sort_singular_values_cn<float>(const int Nrows, const int Ncols, float *U, float *V,
    float *singular_values, const int colUStride, const int colVStride, uint32_t enableReducedForm);
template int DSPF_sp_sort_singular_values_cn<double>(const int Nrows, const int Ncols, double *U, double *V,
    double *singular_values, const int colUStride, const int colVStride, uint32_t enableReducedForm);
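/* On return from the three phases, the factors are expected to satisfy
   A ~= U * diag(sigma) * V^T with sigma sorted in descending order. Below is a
   small hypothetical check one might run on reduced-form float outputs, assuming
   row-major storage with element strides equal to the column counts
   (U: Nrows x Ncols, V: Ncols x Ncols); this helper is not part of DSPLIB. */
#include <cmath>

static float svd_recon_error_sketch(const float *A, const float *U, const float *V,
                                    const float *sigma, int Nrows, int Ncols)
{
    float maxErr = 0.0f;
    for (int r = 0; r < Nrows; r++) {
        for (int c = 0; c < Ncols; c++) {
            float acc = 0.0f;
            for (int k = 0; k < Ncols; k++) {
                /* U(r, k) * sigma[k] * V(c, k); V holds the right singular vectors in its columns. */
                acc += U[k + r * Ncols] * sigma[k] * V[k + c * Ncols];
            }
            float err = std::fabs(acc - A[c + r * Ncols]);
            maxErr = (err > maxErr) ? err : maxErr;
        }
    }
    return maxErr;
}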
/* Top-level SVD driver: copy A (or its transpose) into U, bidiagonalize,
   run the QR iteration, and sort the singular values. */
template <typename dataType>
int DSPF_sp_svd_cn(DSPLIB_svd_PrivArgs *pKerPrivArgs, const int Nrows, const int Ncols,
                   dataType *A, dataType *U, dataType *V, dataType *U1, dataType *V1,
                   dataType *diag, dataType *superdiag, const int32_t strideIn,
                   const int32_t strideU, const int32_t strideV, uint32_t enableReducedForm)
{
    int row, col, Nrows1, Ncols1, status;
    /* Solve on the tall orientation: Nrows1/Ncols1 are set so Nrows1 >= Ncols1;
       if the input is wide, the SVD of A^T is computed and U/V are swapped later. */
    if (Nrows >= Ncols) {
        /* ... */
    }
    /* Convert byte strides to element strides. */
    int32_t dataSize = sizeof(dataType);
    int32_t colUStride = strideU / dataSize;
    int32_t colVStride = strideV / dataSize;
    int32_t colAStride = strideIn / dataSize;
    if (Nrows >= Ncols) {
        /* Copy A into U directly. */
        for (row = 0; row < Nrows1; row++) {
            for (col = 0; col < Ncols1; col++) {
                U[col + row * colUStride] = A[col + row * colAStride];
            }
        }
    }
    else {
        /* Copy A transposed into U. */
        for (row = 0; row < Nrows1; row++) {
            for (col = 0; col < Ncols1; col++) {
                U[col + row * colUStride] = A[row + col * colAStride];
            }
        }
    }
    /* Three phases: bidiagonalize, diagonalize, sort. */
    DSPF_sp_convert_to_bidiag_cn<dataType>(Nrows1, Ncols1, U, V, diag, superdiag, colUStride, colVStride,
                                           enableReducedForm);
    status = DSPF_sp_bidiag_to_diag_cn<dataType>(Nrows1, Ncols1, U, V, diag, superdiag, colUStride, colVStride,
                                                 enableReducedForm);
    DSPF_sp_sort_singular_values_cn<dataType>(Nrows1, Ncols1, U, V, diag, colUStride, colVStride, enableReducedForm);
    /* If the SVD was computed on A^T (wide input), exchange U and V through the
       scratch buffer U1: A = U*S*V^T implies A^T = V*S*U^T. */
    if (enableReducedForm == 0u) {
        memcpy(U1, V, sizeof(dataType) * Nrows * colVStride);
        memcpy(V, U, sizeof(dataType) * Ncols * colUStride);
        memcpy(U, U1, sizeof(dataType) * Nrows * colUStride);
    }
    else {
        memcpy(U1, V, sizeof(dataType) * Ncols * colVStride);
        memcpy(V, U, sizeof(dataType) * Ncols * colUStride);
        memcpy(U, U1, sizeof(dataType) * Nrows * colUStride);
    }
    /* ... */
    return status;
}
/* Explicit instantiations. */
template int DSPF_sp_svd_cn<float>(DSPLIB_svd_PrivArgs *pKerPrivArgs, const int Nrows, const int Ncols,
    float *A, float *U, float *V, float *U1, float *V1, float *diag, float *superdiag,
    const int32_t strideIn, const int32_t strideU, const int32_t strideV, uint32_t enableReducedForm);
template int DSPF_sp_svd_cn<double>(DSPLIB_svd_PrivArgs *pKerPrivArgs, const int Nrows, const int Ncols,
    double *A, double *U, double *V, double *U1, double *V1, double *diag, double *superdiag,
    const int32_t strideIn, const int32_t strideU, const int32_t strideV, uint32_t enableReducedForm);
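/* When the input is wide (Nrows < Ncols) the driver factors A^T, because
   A = U*S*V^T is equivalent to A^T = V*S*U^T; the U and V buffers then have to
   trade places before returning, which is what the three-memcpy block above does
   through the U1 scratch buffer. A toy illustration of that exchange follows
   (hypothetical buffers and helper name, not DSPLIB code): */
#include <cstring>

static void swap_factor_buffers_sketch(float *U, float *V, float *scratch, int elems)
{
    std::memcpy(scratch, V, elems * sizeof(float));   /* scratch <- V          */
    std::memcpy(V, U, elems * sizeof(float));         /* V       <- U          */
    std::memcpy(U, scratch, elems * sizeof(float));   /* U       <- original V */
}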
/* Natural C execution entry point: unpack the kernel's private arguments and
   dispatch to DSPF_sp_svd_cn at the configured data type. */
template <typename dataType>
DSPLIB_STATUS DSPLIB_svd_exec_cn(DSPLIB_kernelHandle handle, void *restrict pA, void *restrict pU,
                                 void *restrict pV, void *restrict pDiag, void *restrict pSuperDiag,
                                 void *restrict pU1, void *restrict pV1, void *restrict pScratch)
{
    /* ... */
    uint32_t heightIn = pKerPrivArgs->heightIn;
    uint32_t widthIn = pKerPrivArgs->widthIn;
    int32_t strideIn = pKerPrivArgs->strideIn;
    int32_t strideU = pKerPrivArgs->strideU;
    int32_t strideV = pKerPrivArgs->strideV;
    /* ... */
    dataType *pALocal = (dataType *) pA;
    dataType *pULocal = (dataType *) pU;
    dataType *pVLocal = (dataType *) pV;
    dataType *pDiagLocal = (dataType *) pDiag;
    dataType *pSuperDiagLocal = (dataType *) pSuperDiag;
    dataType *pU1Local = (dataType *) pU1;
    dataType *pV1Local = (dataType *) pV1;

    DSPLIB_DEBUGPRINTFN(0, "pALocal: %p pOutLocal: %p widthIn: %d heightIn: %d\n",
                        pALocal, pULocal, widthIn, heightIn);
#if !defined(ENABLE_LDRA_COVERAGE)
    int svd_status = DSPF_sp_svd_cn<dataType>(pKerPrivArgs, heightIn, widthIn, pALocal, pULocal, pVLocal,
                                              pU1Local, pV1Local, pDiagLocal, pSuperDiagLocal,
                                              strideIn, strideU, strideV, enableReducedForm);
    /* ... */
#else
    DSPF_sp_svd_cn<dataType>(pKerPrivArgs, heightIn, widthIn, pALocal, pULocal, pVLocal,
                             pU1Local, pV1Local, pDiagLocal, pSuperDiagLocal,
                             strideIn, strideU, strideV, enableReducedForm);
#endif
    /* ... */
}

/* Explicit instantiations. */
template DSPLIB_STATUS DSPLIB_svd_exec_cn<float>(DSPLIB_kernelHandle handle, void *restrict pA,
    void *restrict pU, void *restrict pV, void *restrict pDiag, void *restrict pSuperDiag,
    void *restrict pU1, void *restrict pV1, void *restrict pScratch);
template DSPLIB_STATUS DSPLIB_svd_exec_cn<double>(DSPLIB_kernelHandle handle, void *restrict pA,
    void *restrict pU, void *restrict pV, void *restrict pDiag, void *restrict pSuperDiag,
    void *restrict pU1, void *restrict pV1, void *restrict pScratch);
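/* Sketch of how the natural C exec path above might be invoked for a 4 x 3
   single-precision problem. Everything here is illustrative and assumes the
   DSPLIB headers are included and that `handle` was already created and
   configured for these dimensions, strides, and full-form output through the
   library's init sequence (see DSPLIB_svd for the public interface); the buffer
   names, sizes, and the NULL scratch argument are assumptions, not prescribed usage. */
static DSPLIB_STATUS run_svd_example(DSPLIB_kernelHandle handle, float A[4 * 3])
{
    float U[4 * 4];          /* left singular vectors (full form: Nrows x Nrows) */
    float V[3 * 3];          /* right singular vectors (Ncols x Ncols)           */
    float diag[3];           /* singular values, descending on return            */
    float superdiag[3];      /* workspace for the bidiagonal superdiagonal       */
    float U1[4 * 4];         /* scratch used if U and V must be exchanged        */
    float V1[3 * 3];         /* scratch                                          */

    return DSPLIB_svd_exec_cn<float>(handle, A, U, V, diag, superdiag, U1, V1, NULL);
}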