vector.hpp 1.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100
  1. #ifndef VECTOR_HPP
  2. #define VECTOR_HPP
  #include <cassert>
  #include <cstddef>
  #include <cstdlib>
  #include <iostream>
  #include <immintrin.h>
  #include "debug.hpp"
  // NOTE(review): `using namespace std` in a header leaks std into every includer; kept because
  // the code below (cout/endl) and possibly callers rely on unqualified std names.
  using namespace std;
  // Scalar element type stored in a Vector.
  using Real = float;
  // Number of 256-bit AVX registers (8 floats each) needed to hold n Reals, i.e. ceil(n/8).
  // Fully parenthesized so that uses such as 32*AVX_SIZE(n) or AVX_SIZE(a+b) expand correctly.
  // Without the outer parentheses, 32*AVX_SIZE(n) became 32*(n+7)/8, which is generally not a
  // multiple of 32 and so not a valid size for aligned_alloc(32, ...).
  #define AVX_SIZE(n) (((n)+7)/8)
  9. struct Vector{
  10. #ifdef DEBUG
  11. size_t n;
  12. #endif
  13. Real* data;
  14. Real& operator[](size_t i);
  15. };
  16. inline Real&
  17. Vector::operator[](size_t i){
  18. assert(i<n);
  19. return data[i];
  20. }
  21. inline Vector
  22. init_vector(size_t n){
  23. Vector v;
  24. #if DEBUG
  25. v.n=n;
  26. #endif
  27. v.data=static_cast<Real*>(std::aligned_alloc(32,32*AVX_SIZE(n))); //256 bits -> 32 octets
  28. #if DEGUB
  29. assert(v.data!=nullptr);
  30. #endif
  31. return v;
  32. }
  33. inline void
  34. delete_vector(Vector v){
  35. free(v.data);
  36. }
  37. inline bool
  38. is_null(Vector v){
  39. return v.data==nullptr;
  40. }
  41. #if DEBUG
  42. static const Vector NullVector={0,nullptr};
  43. #else
  44. static const Vector NullVector={nullptr};
  45. #endif
  46. inline void
  47. display(Vector v,size_t n){
  48. if(n==0){
  49. cout<<"[]"<<endl;
  50. return;
  51. }
  52. cout<<'['<<v[0];
  53. for(size_t i=1;i<n;++i){
  54. cout<<','<<v[i];
  55. }
  56. cout<<']'<<endl;
  57. }
  58. inline size_t
  59. argmax(Vector v,size_t n){
  60. assert(n>0);
  61. size_t imax=0;
  62. Real vmax=v[0];
  63. for(size_t i=1;i<n;++i){
  64. if(v[i]>vmax){
  65. vmax=v[i];
  66. imax=i;
  67. }
  68. }
  69. return imax;
  70. }
  // Row-major flat offset of element (i, j) in a 2-D array whose second dimension has extent nj.
  inline size_t
  indice2(size_t i,size_t j,size_t nj){
    const size_t row_start=nj*i;
    return row_start+j;
  }
  // Row-major flat offset of element (i, j, k) in a 3-D array with trailing extents nj, nk.
  // Horner-style accumulation: fold in one coordinate per dimension.
  inline size_t
  indice3(size_t i,size_t j,size_t k,size_t nj,size_t nk){
    size_t offset=i;
    offset=offset*nj+j;
    offset=offset*nk+k;
    return offset;
  }
  // Row-major flat offset of element (i, j, k, l) in a 4-D array with trailing extents nj, nk, nl.
  inline size_t
  indice4(size_t i,size_t j,size_t k,size_t l,size_t nj,size_t nk,size_t nl){
    const size_t extents[3]={nj,nk,nl};
    const size_t coords[3]={j,k,l};
    size_t flat=i;
    for(int d=0;d<3;++d){
      flat=flat*extents[d]+coords[d];
    }
    return flat;
  }
  83. #endif