matrix.cpp 1.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108
  1. #include "matrix.hpp"
  2. void
  3. Matrix::init(size_t nrow,size_t ncol){
  4. if(data!=nullptr) delete[] data;
  5. nr=nrow;
  6. nc=ncol;
  7. nc_avx=(nc-1)/4+1;
  8. nc_full=4*nc_avx;
  9. data=(double*)new __m256d[nr*nc_full];
  10. }
  11. //---------------
  12. // Matrix::clear
  13. //---------------
  14. void
  15. Matrix::clear(){
  16. __m256d* avx=(__m256d*)data;
  17. for(size_t i=0;i<nr*nc_avx;++i){
  18. avx[i]=zeros;
  19. }
  20. }
  21. void
  22. Matrix::display() const{
  23. for(size_t i=0;i<nr;++i){
  24. for(size_t j=0;j<nc;++j){
  25. cout<<get(i,j)<<'\t';
  26. }
  27. cout<<endl;
  28. }
  29. }
  30. void
  31. Matrix::swap_lines(size_t i,size_t j){
  32. __m256d* avx_i=get_avx_row(i);
  33. __m256d* avx_j=get_avx_row(j);
  34. for(size_t k=0;k<nc_avx;++k){
  35. __m256d a=*avx_i;
  36. *avx_i=*avx_j;
  37. *avx_j=a;
  38. ++avx_i;
  39. ++avx_j;
  40. }
  41. }
  42. void
  43. Matrix::mul_line(size_t i,double a){
  44. __m256d b=_mm256_set1_pd(a);
  45. __m256d* avx=get_avx_row(i);
  46. for(size_t k=0;k<nc_avx;++k){
  47. *avx=_mm256_mul_pd(*avx,b);
  48. ++avx;
  49. }
  50. }
  51. void
  52. Matrix::add_mul_line(size_t i,size_t j,double a){
  53. __m256d b=_mm256_set1_pd(a);
  54. __m256d* avx_i=get_avx_row(i);
  55. __m256d* avx_j=get_avx_row(j);
  56. for(size_t k=0;k<nc_avx;++k){
  57. *avx_i=_mm256_fmadd_pd(*avx_j,b,*avx_i);
  58. ++avx_i;
  59. ++avx_j;
  60. }
  61. }
  62. double
  63. Matrix::Gauss(){
  64. double det=1;
  65. size_t np=0; //np=0
  66. for(size_t j=0;j<nc;++j){
  67. for(size_t p=np;p<nr;++p){
  68. double c=get(p,j);
  69. if(c!=0){
  70. det*=c;
  71. mul_line(p,1.0/c);
  72. for(size_t k=0;k<nr;++k){
  73. if(k!=p){
  74. add_mul_line(k,p,-get(k,j));
  75. }
  76. }
  77. if(p!=np){
  78. swap_lines(np,p);
  79. det*=-1;
  80. }
  81. ++np;
  82. break;
  83. }
  84. }
  85. }
  86. return det;
  87. }
  88. double
  89. Matrix::get_diag_square_sym(size_t i) const{
  90. AvxBlock b;
  91. b.avx=zeros;
  92. const __m256d* avx=get_avx_row(i);
  93. for(size_t k=0;k<nc_avx;++k){
  94. b.avx=_mm256_fmadd_pd(*avx,*avx,b.avx);
  95. ++avx;
  96. }
  97. return b.data[0]+b.data[1]+b.data[2]+b.data[3];
  98. }