Blender  V2.59
sgstrs.c
Go to the documentation of this file.
00001 
00005 /*
00006  * -- SuperLU routine (version 3.0) --
00007  * Univ. of California Berkeley, Xerox Palo Alto Research Center,
00008  * and Lawrence Berkeley National Lab.
00009  * October 15, 2003
00010  *
00011  */
00012 /*
00013   Copyright (c) 1994 by Xerox Corporation.  All rights reserved.
00014  
00015   THIS MATERIAL IS PROVIDED AS IS, WITH ABSOLUTELY NO WARRANTY
00016   EXPRESSED OR IMPLIED.  ANY USE IS AT YOUR OWN RISK.
00017  
00018   Permission is hereby granted to use or copy this program for any
00019   purpose, provided the above notices are retained on all copies.
00020   Permission to modify the code and to distribute modified code is
00021   granted, provided the above notices are retained, and a notice that
00022   the code was modified is included with the above copyright notice.
00023 */
00024 
00025 #include "ssp_defs.h"
00026 
00027 
/*
 * Forward declarations for helpers used by sgstrs().
 *
 * susolve / slsolve / smatvec are the dense triangular-solve and
 * matrix-vector kernels used when USE_VENDOR_BLAS is not defined; their
 * definitions are not in this file (presumably in the SuperLU utility
 * sources -- verify against the build). Parameter names mirror the way
 * sgstrs() passes its arguments (leading dimension first, then sizes,
 * then the matrix and vector pointers).
 *
 * sprint_soln is the debug printer defined at the bottom of this file.
 */
void susolve(int ldm, int ncol, float *mat, float *rhs);
void slsolve(int ldm, int ncol, float *mat, float *rhs);
void smatvec(int ldm, int nrow, int ncol, float *mat, float *vec, float *mxvec);
void sprint_soln(int n, float *soln);
00035 
00036 void
00037 sgstrs (trans_t trans, SuperMatrix *L, SuperMatrix *U,
00038         int *perm_c, int *perm_r, SuperMatrix *B,
00039         SuperLUStat_t *stat, int *info)
00040 {
00041 /*
00042  * Purpose
00043  * =======
00044  *
00045  * SGSTRS solves a system of linear equations A*X=B or A'*X=B
00046  * with A sparse and B dense, using the LU factorization computed by
00047  * SGSTRF.
00048  *
00049  * See supermatrix.h for the definition of 'SuperMatrix' structure.
00050  *
00051  * Arguments
00052  * =========
00053  *
00054  * trans   (input) trans_t
00055  *          Specifies the form of the system of equations:
00056  *          = NOTRANS: A * X = B  (No transpose)
00057  *          = TRANS:   A'* X = B  (Transpose)
00058  *          = CONJ:    A**H * X = B  (Conjugate transpose)
00059  *
00060  * L       (input) SuperMatrix*
00061  *         The factor L from the factorization Pr*A*Pc=L*U as computed by
00062  *         sgstrf(). Use compressed row subscripts storage for supernodes,
00063  *         i.e., L has types: Stype = SLU_SC, Dtype = SLU_S, Mtype = SLU_TRLU.
00064  *
00065  * U       (input) SuperMatrix*
00066  *         The factor U from the factorization Pr*A*Pc=L*U as computed by
00067  *         sgstrf(). Use column-wise storage scheme, i.e., U has types:
00068  *         Stype = SLU_NC, Dtype = SLU_S, Mtype = SLU_TRU.
00069  *
00070  * perm_c  (input) int*, dimension (L->ncol)
00071  *         Column permutation vector, which defines the 
00072  *         permutation matrix Pc; perm_c[i] = j means column i of A is 
00073  *         in position j in A*Pc.
00074  *
00075  * perm_r  (input) int*, dimension (L->nrow)
00076  *         Row permutation vector, which defines the permutation matrix Pr; 
00077  *         perm_r[i] = j means row i of A is in position j in Pr*A.
00078  *
00079  * B       (input/output) SuperMatrix*
00080  *         B has types: Stype = SLU_DN, Dtype = SLU_S, Mtype = SLU_GE.
00081  *         On entry, the right hand side matrix.
00082  *         On exit, the solution matrix if info = 0;
00083  *
00084  * stat     (output) SuperLUStat_t*
00085  *          Record the statistics on runtime and floating-point operation count.
00086  *          See util.h for the definition of 'SuperLUStat_t'.
00087  *
00088  * info    (output) int*
00089  *         = 0: successful exit
00090  *         < 0: if info = -i, the i-th argument had an illegal value
00091  *
00092  */
00093 #ifdef _CRAY
00094     _fcd ftcs1, ftcs2, ftcs3, ftcs4;
00095 #endif
00096 #ifdef USE_VENDOR_BLAS
00097     float   alpha = 1.0, beta = 1.0;
00098     float   *work_col;
00099 #endif
00100     DNformat *Bstore;
00101     float   *Bmat;
00102     SCformat *Lstore;
00103     NCformat *Ustore;
00104     float   *Lval, *Uval;
00105     int      fsupc, nrow, nsupr, nsupc, luptr, istart, irow;
00106     int      i, j, k, iptr, jcol, n, ldb, nrhs;
00107     float   *work, *rhs_work, *soln;
00108     flops_t  solve_ops;
00109     void sprint_soln();
00110 
00111     /* Test input parameters ... */
00112     *info = 0;
00113     Bstore = B->Store;
00114     ldb = Bstore->lda;
00115     nrhs = B->ncol;
00116     if ( trans != NOTRANS && trans != TRANS && trans != CONJ ) *info = -1;
00117     else if ( L->nrow != L->ncol || L->nrow < 0 ||
00118               L->Stype != SLU_SC || L->Dtype != SLU_S || L->Mtype != SLU_TRLU )
00119         *info = -2;
00120     else if ( U->nrow != U->ncol || U->nrow < 0 ||
00121               U->Stype != SLU_NC || U->Dtype != SLU_S || U->Mtype != SLU_TRU )
00122         *info = -3;
00123     else if ( ldb < SUPERLU_MAX(0, L->nrow) ||
00124               B->Stype != SLU_DN || B->Dtype != SLU_S || B->Mtype != SLU_GE )
00125         *info = -6;
00126     if ( *info ) {
00127         i = -(*info);
00128         xerbla_("sgstrs", &i);
00129         return;
00130     }
00131 
00132     n = L->nrow;
00133     work = floatCalloc(n * nrhs);
00134     if ( !work ) ABORT("Malloc fails for local work[].");
00135     soln = floatMalloc(n);
00136     if ( !soln ) ABORT("Malloc fails for local soln[].");
00137 
00138     Bmat = Bstore->nzval;
00139     Lstore = L->Store;
00140     Lval = Lstore->nzval;
00141     Ustore = U->Store;
00142     Uval = Ustore->nzval;
00143     solve_ops = 0;
00144     
00145     if ( trans == NOTRANS ) {
00146         /* Permute right hand sides to form Pr*B */
00147         for (i = 0; i < nrhs; i++) {
00148             rhs_work = &Bmat[i*ldb];
00149             for (k = 0; k < n; k++) soln[perm_r[k]] = rhs_work[k];
00150             for (k = 0; k < n; k++) rhs_work[k] = soln[k];
00151         }
00152         
00153         /* Forward solve PLy=Pb. */
00154         for (k = 0; k <= Lstore->nsuper; k++) {
00155             fsupc = L_FST_SUPC(k);
00156             istart = L_SUB_START(fsupc);
00157             nsupr = L_SUB_START(fsupc+1) - istart;
00158             nsupc = L_FST_SUPC(k+1) - fsupc;
00159             nrow = nsupr - nsupc;
00160 
00161             solve_ops += nsupc * (nsupc - 1) * nrhs;
00162             solve_ops += 2 * nrow * nsupc * nrhs;
00163             
00164             if ( nsupc == 1 ) {
00165                 for (j = 0; j < nrhs; j++) {
00166                     rhs_work = &Bmat[j*ldb];
00167                     luptr = L_NZ_START(fsupc);
00168                     for (iptr=istart+1; iptr < L_SUB_START(fsupc+1); iptr++){
00169                         irow = L_SUB(iptr);
00170                         ++luptr;
00171                         rhs_work[irow] -= rhs_work[fsupc] * Lval[luptr];
00172                     }
00173                 }
00174             } else {
00175                 luptr = L_NZ_START(fsupc);
00176 #ifdef USE_VENDOR_BLAS
00177 #ifdef _CRAY
00178                 ftcs1 = _cptofcd("L", strlen("L"));
00179                 ftcs2 = _cptofcd("N", strlen("N"));
00180                 ftcs3 = _cptofcd("U", strlen("U"));
00181                 STRSM( ftcs1, ftcs1, ftcs2, ftcs3, &nsupc, &nrhs, &alpha,
00182                        &Lval[luptr], &nsupr, &Bmat[fsupc], &ldb);
00183                 
00184                 SGEMM( ftcs2, ftcs2, &nrow, &nrhs, &nsupc, &alpha, 
00185                         &Lval[luptr+nsupc], &nsupr, &Bmat[fsupc], &ldb, 
00186                         &beta, &work[0], &n );
00187 #else
00188                 strsm_("L", "L", "N", "U", &nsupc, &nrhs, &alpha,
00189                        &Lval[luptr], &nsupr, &Bmat[fsupc], &ldb);
00190                 
00191                 sgemm_( "N", "N", &nrow, &nrhs, &nsupc, &alpha, 
00192                         &Lval[luptr+nsupc], &nsupr, &Bmat[fsupc], &ldb, 
00193                         &beta, &work[0], &n );
00194 #endif
00195                 for (j = 0; j < nrhs; j++) {
00196                     rhs_work = &Bmat[j*ldb];
00197                     work_col = &work[j*n];
00198                     iptr = istart + nsupc;
00199                     for (i = 0; i < nrow; i++) {
00200                         irow = L_SUB(iptr);
00201                         rhs_work[irow] -= work_col[i]; /* Scatter */
00202                         work_col[i] = 0.0;
00203                         iptr++;
00204                     }
00205                 }
00206 #else           
00207                 for (j = 0; j < nrhs; j++) {
00208                     rhs_work = &Bmat[j*ldb];
00209                     slsolve (nsupr, nsupc, &Lval[luptr], &rhs_work[fsupc]);
00210                     smatvec (nsupr, nrow, nsupc, &Lval[luptr+nsupc],
00211                             &rhs_work[fsupc], &work[0] );
00212 
00213                     iptr = istart + nsupc;
00214                     for (i = 0; i < nrow; i++) {
00215                         irow = L_SUB(iptr);
00216                         rhs_work[irow] -= work[i];
00217                         work[i] = 0.0;
00218                         iptr++;
00219                     }
00220                 }
00221 #endif              
00222             } /* else ... */
00223         } /* for L-solve */
00224 
00225 #ifdef DEBUG
00226         printf("After L-solve: y=\n");
00227         sprint_soln(n, Bmat);
00228 #endif
00229 
00230         /*
00231          * Back solve Ux=y.
00232          */
00233         for (k = Lstore->nsuper; k >= 0; k--) {
00234             fsupc = L_FST_SUPC(k);
00235             istart = L_SUB_START(fsupc);
00236             nsupr = L_SUB_START(fsupc+1) - istart;
00237             nsupc = L_FST_SUPC(k+1) - fsupc;
00238             luptr = L_NZ_START(fsupc);
00239 
00240             solve_ops += nsupc * (nsupc + 1) * nrhs;
00241 
00242             if ( nsupc == 1 ) {
00243                 rhs_work = &Bmat[0];
00244                 for (j = 0; j < nrhs; j++) {
00245                     rhs_work[fsupc] /= Lval[luptr];
00246                     rhs_work += ldb;
00247                 }
00248             } else {
00249 #ifdef USE_VENDOR_BLAS
00250 #ifdef _CRAY
00251                 ftcs1 = _cptofcd("L", strlen("L"));
00252                 ftcs2 = _cptofcd("U", strlen("U"));
00253                 ftcs3 = _cptofcd("N", strlen("N"));
00254                 STRSM( ftcs1, ftcs2, ftcs3, ftcs3, &nsupc, &nrhs, &alpha,
00255                        &Lval[luptr], &nsupr, &Bmat[fsupc], &ldb);
00256 #else
00257                 strsm_("L", "U", "N", "N", &nsupc, &nrhs, &alpha,
00258                        &Lval[luptr], &nsupr, &Bmat[fsupc], &ldb);
00259 #endif
00260 #else           
00261                 for (j = 0; j < nrhs; j++)
00262                     susolve ( nsupr, nsupc, &Lval[luptr], &Bmat[fsupc+j*ldb] );
00263 #endif          
00264             }
00265 
00266             for (j = 0; j < nrhs; ++j) {
00267                 rhs_work = &Bmat[j*ldb];
00268                 for (jcol = fsupc; jcol < fsupc + nsupc; jcol++) {
00269                     solve_ops += 2*(U_NZ_START(jcol+1) - U_NZ_START(jcol));
00270                     for (i = U_NZ_START(jcol); i < U_NZ_START(jcol+1); i++ ){
00271                         irow = U_SUB(i);
00272                         rhs_work[irow] -= rhs_work[jcol] * Uval[i];
00273                     }
00274                 }
00275             }
00276             
00277         } /* for U-solve */
00278 
00279 #ifdef DEBUG
00280         printf("After U-solve: x=\n");
00281         sprint_soln(n, Bmat);
00282 #endif
00283 
00284         /* Compute the final solution X := Pc*X. */
00285         for (i = 0; i < nrhs; i++) {
00286             rhs_work = &Bmat[i*ldb];
00287             for (k = 0; k < n; k++) soln[k] = rhs_work[perm_c[k]];
00288             for (k = 0; k < n; k++) rhs_work[k] = soln[k];
00289         }
00290         
00291         stat->ops[SOLVE] = solve_ops;
00292 
00293     } else { /* Solve A'*X=B or CONJ(A)*X=B */
00294         /* Permute right hand sides to form Pc'*B. */
00295         for (i = 0; i < nrhs; i++) {
00296             rhs_work = &Bmat[i*ldb];
00297             for (k = 0; k < n; k++) soln[perm_c[k]] = rhs_work[k];
00298             for (k = 0; k < n; k++) rhs_work[k] = soln[k];
00299         }
00300 
00301         stat->ops[SOLVE] = 0;
00302         for (k = 0; k < nrhs; ++k) {
00303             
00304             /* Multiply by inv(U'). */
00305             sp_strsv("U", "T", "N", L, U, &Bmat[k*ldb], stat, info);
00306             
00307             /* Multiply by inv(L'). */
00308             sp_strsv("L", "T", "U", L, U, &Bmat[k*ldb], stat, info);
00309             
00310         }
00311         /* Compute the final solution X := Pr'*X (=inv(Pr)*X) */
00312         for (i = 0; i < nrhs; i++) {
00313             rhs_work = &Bmat[i*ldb];
00314             for (k = 0; k < n; k++) soln[k] = rhs_work[perm_r[k]];
00315             for (k = 0; k < n; k++) rhs_work[k] = soln[k];
00316         }
00317 
00318     }
00319 
00320     SUPERLU_FREE(work);
00321     SUPERLU_FREE(soln);
00322 }
00323 
/*
 * Debug helper: write the first n entries of soln to stdout, one per
 * line, formatted as a tab, the zero-based index, a colon, and the value
 * with four digits after the decimal point.
 */
void
sprint_soln(int n, float *soln)
{
    int idx = 0;

    while (idx < n) {
        printf("\t%d: %.4f\n", idx, soln[idx]);
        ++idx;
    }
}