HermesCommon  2.0
superlu_solver.h
Go to the documentation of this file.
1 // This file is part of HermesCommon
2 //
3 // Copyright (c) 2009 hp-FEM group at the University of Nevada, Reno (UNR).
4 // Email: hpfem-group@unr.edu, home page: http://hpfem.org/.
5 //
6 // Hermes2D is free software; you can redistribute it and/or modify
7 // it under the terms of the GNU General Public License as published
8 // by the Free Software Foundation; either version 2 of the License,
9 // or (at your option) any later version.
10 //
11 // Hermes2D is distributed in the hope that it will be useful,
12 // but WITHOUT ANY WARRANTY; without even the implied warranty of
13 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 // GNU General Public License for more details.
15 //
16 // You should have received a copy of the GNU General Public License
17 // along with Hermes2D; if not, write to the Free Software
18 // Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
22 #ifndef __HERMES_COMMON_SUPERLU_SOLVER_H_
23 #define __HERMES_COMMON_SUPERLU_SOLVER_H_
24 
25 #include "config.h"
26 #ifdef WITH_SUPERLU
typedef int int_t;
#include <complex>
#include <cstdio>

#include "linear_matrix_solver.h"
#include "matrix.h"

#include <supermatrix.h>
#include <slu_util.h>
33 namespace Hermes
34 {
35  namespace Solvers
36  {
37  template <typename Scalar> class SuperLUSolver;
38 #ifdef SLU_MT
39  template <typename Scalar>
40  class SuperLu
41  {
42  public:
43  void gsequ (SuperMatrix *A, double *r, double *c, double *rowcnd, double *colcnd, double *amax, int *info);
44  void laqgs (SuperMatrix *A, float *r, float *c, float rowcnd, float colcnd, float amax, char *equed);
45  int gstrf (superlu_options_t *options, int m, int n, double anorm, LUstruct_t *LUstruct, gridinfo_t *grid, SuperLUStat_t *stat, int *info);
46  float pivotGrowth (int ncols, SuperMatrix *A, int *perm_c, SuperMatrix *L, SuperMatrix *U);
47  float langs (char *norm, SuperMatrix *A);
48  void gscon (char *norm, SuperMatrix *L, SuperMatrix *U, float anorm, float *rcond, SuperLUStat_t *stat, int *info);
49  void gstrs (trans_t trans, SuperMatrix *L, SuperMatrix *U, int *perm_c, int *perm_r, SuperMatrix *B, SuperLUStat_t *stat, int *info);
50  double lamch_ (char *cmach);
51  int querySpace (SuperMatrix *, SuperMatrix *, slu_memusage_t *);
52  }
53 #else //SLU_MT
54 
55  typedef superlu_options_t slu_options_t;
56  typedef SuperLUStat_t slu_stat_t;
57  typedef struct
58  {
59  float for_lu;
60  float total_needed;
61  } slu_memusage_t;
62 #define SLU_DESTROY_L Destroy_SuperNode_Matrix
63 #define SLU_DESTROY_U Destroy_CompCol_Matrix
64 #define SLU_INIT_STAT(stat_ptr) StatInit(stat_ptr)
65 #define SLU_PRINT_STAT(stat_ptr) StatPrint(stat_ptr)
66 
67 #define SLU_DTYPE SLU_D
68 
69 #define SLU_PRINT_CSC_MATRIX zPrint_CompCol_Matrix
70 #define Scalar_MALLOC doublecomplexMalloc
71 
73  template<typename Scalar> struct SuperLuType;
74 
76  template<>
77  struct SuperLuType<double>
78  {
80  typedef double Scalar;
81  };
82 
84  template<>
85  struct SuperLuType<std::complex<double> >
86  {
88  typedef struct { double r, i; } Scalar;
89  };
90 #endif //SLU_MT
91  }
92 }
93 
94 namespace Hermes
95 {
96  namespace Algebra
97  {
99  template <typename Scalar>
100  class SuperLUMatrix : public SparseMatrix<Scalar>
101  {
102  public:
103  SuperLUMatrix();
104  virtual ~SuperLUMatrix();
105 
106  virtual void alloc();
107  virtual void free();
108  virtual Scalar get(unsigned int m, unsigned int n);
109  virtual void zero();
110  virtual void add(unsigned int m, unsigned int n, Scalar v);
111  virtual void add_to_diagonal(Scalar v);
112  virtual void add(unsigned int m, unsigned int n, Scalar **mat, int *rows, int *cols);
113  virtual bool dump(FILE *file, const char *var_name, EMatrixDumpFormat fmt = DF_MATLAB_SPARSE, char* number_format = "%lf");
114  virtual unsigned int get_matrix_size() const;
115  virtual unsigned int get_nnz() const;
116  virtual double get_fill_in() const;
119  virtual void add_matrix(SuperLUMatrix* mat);
123  virtual void add_to_diagonal_blocks(int num_stages, SuperLUMatrix* mat);
124  virtual void add_sparse_to_diagonal_blocks(int num_stages, SparseMatrix<Scalar>* mat);
129  virtual void add_as_block(unsigned int i, unsigned int j, SuperLUMatrix* mat);
130 
131  // Applies the matrix to vector_in and saves result to vector_out.
132  void multiply_with_vector(Scalar* vector_in, Scalar* vector_out);
133  // Multiplies matrix with a Scalar.
134  void multiply_with_Scalar(Scalar value);
141  void create(unsigned int size, unsigned int nnz, int* ap, int* ai, Scalar* ax);
142  // Duplicates a matrix (including allocation).
143  SuperLUMatrix<Scalar>* duplicate();
144 
145  protected:
146  // SUPERLU specific data structures for storing the matrix (CSC format).
148  Scalar *Ax;
150  int *Ai;
152  unsigned int *Ap;
154  unsigned int nnz;
155 
156  friend class Solvers::SuperLUSolver<Scalar>;
157  template<typename T> friend SparseMatrix<T>* create_matrix();
158  };
159 
161  template <typename Scalar>
162  class SuperLUVector : public Vector<Scalar>
163  {
164  public:
165  SuperLUVector();
166  virtual ~SuperLUVector();
167 
168  virtual void alloc(unsigned int ndofs);
169  virtual void free();
170  virtual Scalar get(unsigned int idx);
171  virtual void extract(Scalar *v) const;
172  virtual void zero();
173  virtual void change_sign();
174  virtual void set(unsigned int idx, Scalar y);
175  virtual void add(unsigned int idx, Scalar y);
176  virtual void add(unsigned int n, unsigned int *idx, Scalar *y);
177  virtual void add_vector(Vector<Scalar>* vec);
178  virtual void add_vector(Scalar* vec);
179  virtual bool dump(FILE *file, const char *var_name, EMatrixDumpFormat fmt = DF_MATLAB_SPARSE, char* number_format = "%lf");
180 
181  protected:
183  Scalar *v; // Vector entries.
184 
185  friend class Solvers::SuperLUSolver<Scalar>;
186  };
187  }
188  namespace Solvers
189  {
193  template <typename Scalar>
194  class HERMES_API SuperLUSolver : public DirectSolver<Scalar>
195  {
196  public:
200  SuperLUSolver(SuperLUMatrix<Scalar> *m, SuperLUVector<Scalar> *rhs);
201  virtual ~SuperLUSolver();
202 
203  virtual bool solve();
204  virtual int get_matrix_size();
205 
206  protected:
208  SuperLUMatrix<Scalar> *m;
210  SuperLUVector<Scalar> *rhs;
211 
212  bool has_A, has_B;
213  bool inited;
214  bool A_changed;
215  // internally during factorization or externally by
216  // the user.
217 
220  bool check_status(unsigned int info);
221 
224  int *local_Ai, *local_Ap;
225  typename SuperLuType<Scalar>::Scalar *local_Ax, *local_rhs;
226 
227  bool setup_factorization();
228  void free_factorization_data();
229  void free_matrix();
230  void free_rhs();
231 
232  SuperMatrix A, B;
233  SuperMatrix L, U;
234  double *R, *C;
235  int *perm_r;
236  int *perm_c;
237  int *etree;
238  slu_options_t options;
239 
240  private:
241 #ifndef SLU_MT
242  void create_csc_matrix (SuperMatrix *A, int m, int n, int nnz, typename SuperLuType<Scalar>::Scalar *nzval, int *rowind, int *colptr,
243  Stype_t stype, Dtype_t dtype, Mtype_t mtype);
244  void solver_driver (superlu_options_t *options, SuperMatrix *A, int *perm_c, int *perm_r, int *etree, char *equed, double *R,
245  double *C, SuperMatrix *L, SuperMatrix *U, void *work, int lwork, SuperMatrix *B, SuperMatrix *X, double *recip_pivot_growth,
246  double *rcond, double *ferr, double *berr, slu_memusage_t *mem_usage, SuperLUStat_t *stat, int *info);
247  void create_dense_matrix (SuperMatrix *X, int m, int n, typename SuperLuType<Scalar>::Scalar *x, int ldx, Stype_t stype, Dtype_t dtype, Mtype_t mtype);
248 #endif //SLU_MT
249 
250 #ifndef SLU_MT
251  char equed[1];
252 #else
253  equed_t equed;
254  SuperMatrix AC;
255 #endif //SLU_MT
256  template<typename T> friend LinearMatrixSolver<T>* create_linear_solver(Matrix<T>* matrix, Vector<T>* rhs);
257  };
258  }
259 }
260 #endif //SUPER_LU
261 #endif