22 #ifndef __HERMES_COMMON_SUPERLU_SOLVER_H_
23 #define __HERMES_COMMON_SUPERLU_SOLVER_H_
28 #include "linear_matrix_solver.h"
31 #include <supermatrix.h>
/// Forward declaration of the SuperLU-backed solver, so that SuperLUMatrix and
/// SuperLUVector further down can grant it friendship before it is defined.
// NOTE(review): the friend declarations spell it Solvers::SuperLUSolver, so this
// declaration presumably sits inside namespace Solvers (namespace lines are not
// visible in this view) — confirm.
37 template <
typename Scalar>
class SuperLUSolver;
// NOTE(review): this template head is followed directly (original lines 40-42
// are absent) by plain function prototypes; whatever it actually applies to is
// not visible here — confirm against the full header.
39 template <
typename Scalar>
/// Prototypes named after SuperLU's precision-prefixed computational routines
/// (dgsequ/zgsequ, dlaqgs/zlaqgs, dgscon/zgscon, dgstrs/zgstrs, ...), presumably
/// so a single code path can dispatch to the real or complex variant — TODO confirm.
///
/// Equilibration: row scale factors in r, column scale factors in c, plus
/// rowcnd/colcnd/amax summaries; status in *info (cf. SuperLU ?gsequ).
// NOTE(review): gsequ takes double* for r, c, rowcnd, colcnd and amax, but laqgs
// (which consumes those same quantities) takes float, as do pivotGrowth, langs and
// gscon's anorm/rcond. If these forward to double-precision SuperLU routines, the
// float types lose precision or mismatch the callee — verify.
43 void gsequ (SuperMatrix *A,
double *r,
double *c,
double *rowcnd,
double *colcnd,
double *amax,
int *info);
/// Applies the scale factors to A in place; *equed reports which scaling was
/// applied (cf. SuperLU ?laqgs).
44 void laqgs (SuperMatrix *A,
float *r,
float *c,
float rowcnd,
float colcnd,
float amax,
char *equed);
/// LU factorization.
// NOTE(review): this parameter list (anorm, LUstruct_t*, gridinfo_t*) matches
// SuperLU_DIST's pdgstrf, not sequential SuperLU's ?gstrf, which takes
// (options, A, relax, panel_size, etree, work, lwork, perm_c, perm_r, L, U, stat, info).
// Confirm which SuperLU variant this header is built against.
45 int gstrf (superlu_options_t *options,
int m,
int n,
double anorm, LUstruct_t *LUstruct, gridinfo_t *grid, SuperLUStat_t *stat,
int *info);
/// Reciprocal pivot growth of the factorization of the first ncols columns of A
/// (cf. SuperLU ?PivotGrowth).
46 float pivotGrowth (
int ncols, SuperMatrix *A,
int *perm_c, SuperMatrix *L, SuperMatrix *U);
/// Norm of A, kind selected by `norm` (cf. SuperLU ?langs).
47 float langs (
char *norm, SuperMatrix *A);
/// Estimate of the reciprocal condition number from the L/U factors and
/// anorm (cf. SuperLU ?gscon).
48 void gscon (
char *norm, SuperMatrix *L, SuperMatrix *U,
float anorm,
float *rcond, SuperLUStat_t *stat,
int *info);
/// Triangular solves with the L/U factors and permutations perm_c/perm_r,
/// overwriting B; `trans` selects the (transposed) system (cf. SuperLU ?gstrs).
49 void gstrs (trans_t trans, SuperMatrix *L, SuperMatrix *U,
int *perm_c,
int *perm_r, SuperMatrix *B, SuperLUStat_t *stat,
int *info);
/// Machine floating-point parameters (cf. LAPACK ?lamch).
50 double lamch_ (
char *cmach);
/// Memory used by the L and U factors (cf. SuperLU ?QuerySpace).
51 int querySpace (SuperMatrix *, SuperMatrix *, slu_memusage_t *);
/// Short local aliases for SuperLU's option and statistics structures.
55 typedef superlu_options_t slu_options_t;
56 typedef SuperLUStat_t slu_stat_t;
/// Macros mapping generic SLU_* / Scalar_* names onto concrete SuperLU routines.
// NOTE(review): SLU_DTYPE is SLU_D (real double), while SLU_PRINT_CSC_MATRIX and
// Scalar_MALLOC name double-complex routines (zPrint_CompCol_Matrix,
// doublecomplexMalloc). The original numbering skips lines 57-61, 66, 68 and
// 71-72, so these may have lived in separate #if branches that are not visible
// here — verify before relying on this combination.
/// Frees the L factor (SuperLU's supernodal storage).
62 #define SLU_DESTROY_L Destroy_SuperNode_Matrix
/// Frees the U factor (SuperLU's compressed-column storage).
63 #define SLU_DESTROY_U Destroy_CompCol_Matrix
/// Initialise / print a SuperLUStat_t statistics record.
64 #define SLU_INIT_STAT(stat_ptr) StatInit(stat_ptr)
65 #define SLU_PRINT_STAT(stat_ptr) StatPrint(stat_ptr)
/// SuperLU element data type tag passed when creating SuperMatrix objects.
67 #define SLU_DTYPE SLU_D
/// Debug printer for a compressed-column matrix.
69 #define SLU_PRINT_CSC_MATRIX zPrint_CompCol_Matrix
/// Allocator for arrays of the SuperLU element type.
70 #define Scalar_MALLOC doublecomplexMalloc
/// Type trait mapping a Hermes scalar type to the element type handed to SuperLU
/// (used by SuperLUSolver for local_Ax, local_rhs and the SuperMatrix builders).
73 template<
typename Scalar>
struct SuperLuType;
// NOTE(review): the `template<>` prefixes and braces of the specializations below
// are not visible in this view (original lines 74-76, 78-79, 81-84 and 86-87 are
// absent).
/// Real case: SuperLU operates directly on double.
77 struct SuperLuType<double>
80 typedef double Scalar;
/// Complex case: a plain {r, i} pair of doubles, presumably mirroring SuperLU's
/// `doublecomplex` so std::complex<double> data can be copied into it — confirm
/// the layout matches.
85 struct SuperLuType<std::complex<double> >
88 typedef struct {
double r, i; } Scalar;
/// Sparse matrix (a SparseMatrix<Scalar>) intended for use with SuperLUSolver,
/// which is declared a friend below so it can read the internal storage.
// NOTE(review): constructor, access specifiers and data members are not visible
// in this view (gaps in the original numbering).
99 template <
typename Scalar>
100 class SuperLUMatrix :
public SparseMatrix<Scalar>
104 virtual ~SuperLUMatrix();
/// Allocates storage for the matrix.
106 virtual void alloc();
/// Returns the entry at row m, column n.
108 virtual Scalar
get(
unsigned int m,
unsigned int n);
/// Adds v to the entry at row m, column n.
110 virtual void add(
unsigned int m,
unsigned int n, Scalar v);
/// Adds v to every diagonal entry (presumably; confirm in the definition).
111 virtual void add_to_diagonal(Scalar v);
/// Adds an m x n dense block `mat` at the row/column indices given by rows/cols
/// (presumably; confirm index conventions in the definition).
112 virtual void add(
unsigned int m,
unsigned int n, Scalar **mat,
int *rows,
int *cols);
/// Matrix dimension.
114 virtual unsigned int get_matrix_size()
const;
/// Number of stored nonzero entries.
115 virtual unsigned int get_nnz()
const;
/// Fill-in measure (presumably stored nonzeros relative to total entries — confirm).
116 virtual double get_fill_in()
const;
/// Adds another SuperLUMatrix to this one.
119 virtual void add_matrix(SuperLUMatrix* mat);
/// Adds `mat` into the diagonal blocks of this matrix; num_stages presumably
/// gives the number of such blocks — confirm.
123 virtual void add_to_diagonal_blocks(
int num_stages, SuperLUMatrix* mat);
/// As above, but accepting any SparseMatrix<Scalar>.
124 virtual void add_sparse_to_diagonal_blocks(
int num_stages, SparseMatrix<Scalar>* mat);
/// Adds `mat` as the block at block position (i, j).
129 virtual void add_as_block(
unsigned int i,
unsigned int j, SuperLUMatrix* mat);
/// Matrix-vector product written into vector_out (raw arrays; sizes assumed to
/// equal get_matrix_size()).
132 void multiply_with_vector(Scalar* vector_in, Scalar* vector_out);
/// Scales every entry by `value`.
134 void multiply_with_Scalar(Scalar value);
/// Builds the matrix from raw arrays. The ap/ai/ax naming suggests
/// compressed-sparse-column input (column pointers, row indices, values),
/// consistent with SuperLUSolver::create_csc_matrix — confirm, and confirm whether
/// the arrays are copied or adopted.
141 void create(
unsigned int size,
unsigned int nnz,
int* ap,
int* ai, Scalar* ax);
/// Returns a newly created copy of this matrix (ownership presumably passes to
/// the caller — confirm).
143 SuperLUMatrix<Scalar>* duplicate();
/// The solver and the create_matrix<T>() factory need access to internals.
156 friend class Solvers::SuperLUSolver<Scalar>;
157 template<
typename T>
friend SparseMatrix<T>*
create_matrix();
/// Dense vector (a Vector<Scalar>) used as the right-hand side / solution
/// container for SuperLUSolver, which is declared a friend below.
// NOTE(review): constructor, access specifiers and data members are not visible
// in this view (gaps in the original numbering).
161 template <
typename Scalar>
162 class SuperLUVector :
public Vector<Scalar>
166 virtual ~SuperLUVector();
/// Allocates storage for ndofs entries.
168 virtual void alloc(
unsigned int ndofs);
/// Returns the entry at idx.
170 virtual Scalar
get(
unsigned int idx);
/// Copies the entries into the caller-provided array v (assumed large enough).
171 virtual void extract(Scalar *v)
const;
/// Negates every entry.
173 virtual void change_sign();
/// Overwrites the entry at idx with y.
174 virtual void set(
unsigned int idx, Scalar y);
/// Adds y to the entry at idx.
175 virtual void add(
unsigned int idx, Scalar y);
/// Adds y[k] to the entry at idx[k] for k in [0, n) (presumably; confirm).
176 virtual void add(
unsigned int n,
unsigned int *idx, Scalar *y);
/// Element-wise addition of another vector.
177 virtual void add_vector(Vector<Scalar>* vec);
/// Element-wise addition of a raw array (assumed to have this vector's length).
178 virtual void add_vector(Scalar* vec);
/// The solver reads/writes the internal storage directly.
185 friend class Solvers::SuperLUSolver<Scalar>;
/// Direct sparse solver (a DirectSolver<Scalar>) backed by SuperLU, operating on
/// a SuperLUMatrix and a SuperLUVector right-hand side.
// NOTE(review): access specifiers and the closing part of the class lie outside
// this view; the last visible original line is 256.
193 template <
typename Scalar>
194 class HERMES_API SuperLUSolver :
public DirectSolver<Scalar>
/// Creates a solver for system matrix m and right-hand side rhs (ownership of
/// the arguments is not visible here — confirm).
200 SuperLUSolver(SuperLUMatrix<Scalar> *m, SuperLUVector<Scalar> *rhs);
201 virtual ~SuperLUSolver();
/// Solves the system; returns true on success.
203 virtual bool solve();
/// Dimension of the system matrix.
204 virtual int get_matrix_size();
/// System matrix and right-hand side this solver works on.
208 SuperLUMatrix<Scalar> *m;
210 SuperLUVector<Scalar> *rhs;
/// Interprets a SuperLU `info` status code; returns whether it indicates success.
// NOTE(review): SuperLU reports an illegal argument as a negative `info`
// (-i means argument i was invalid). Taking it as `unsigned int` wraps such
// values to large positives — confirm the definition accounts for this, or
// change to `int` together with the .cpp definition.
220 bool check_status(
unsigned int info);
/// Solver-side copies of the matrix index arrays and of the matrix values / RHS in
/// SuperLU's element type (presumably filled from m and rhs before factorization — confirm).
224 int *local_Ai, *local_Ap;
225 typename SuperLuType<Scalar>::Scalar *local_Ax, *local_rhs;
/// Prepares data for factorization; returns true on success.
227 bool setup_factorization();
/// Releases data held for the factorization.
228 void free_factorization_data();
/// SuperLU options used for this solver's factorization/solve.
238 slu_options_t options;
/// Wraps creation of a compressed-column SuperMatrix; the parameter list mirrors
/// SuperLU's ?Create_CompCol_Matrix.
242 void create_csc_matrix (SuperMatrix *A,
int m,
int n,
int nnz,
typename SuperLuType<Scalar>::Scalar *nzval,
int *rowind,
int *colptr,
243 Stype_t stype, Dtype_t dtype, Mtype_t mtype);
/// Wraps the SuperLU expert driver; the parameter list mirrors ?gssvx
/// (equilibration R/C, factors L/U, solution X, error bounds ferr/berr, rcond,
/// reciprocal pivot growth, memory usage, statistics, info).
244 void solver_driver (superlu_options_t *options, SuperMatrix *A,
int *perm_c,
int *perm_r,
int *etree,
char *equed,
double *R,
245 double *C, SuperMatrix *L, SuperMatrix *U,
void *work,
int lwork, SuperMatrix *B, SuperMatrix *X,
double *recip_pivot_growth,
246 double *rcond,
double *ferr,
double *berr, slu_memusage_t *mem_usage, SuperLUStat_t *stat,
int *info);
/// Wraps creation of a dense SuperMatrix (leading dimension ldx); the parameter
/// list mirrors SuperLU's ?Create_Dense_Matrix.
247 void create_dense_matrix (SuperMatrix *X,
int m,
int n,
typename SuperLuType<Scalar>::Scalar *x,
int ldx, Stype_t stype, Dtype_t dtype, Mtype_t mtype);
/// The create_linear_solver<T>() factory is allowed to reach internals.
256 template<
typename T>
friend LinearMatrixSolver<T>*
create_linear_solver(Matrix<T>* matrix, Vector<T>* rhs);