SparseProduct.h 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169
  1. // This file is part of Eigen, a lightweight C++ template library
  2. // for linear algebra.
  3. //
  4. // Copyright (C) 2008-2015 Gael Guennebaud <gael.guennebaud@inria.fr>
  5. //
  6. // This Source Code Form is subject to the terms of the Mozilla
  7. // Public License v. 2.0. If a copy of the MPL was not distributed
  8. // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
  9. #ifndef EIGEN_SPARSEPRODUCT_H
  10. #define EIGEN_SPARSEPRODUCT_H
  11. namespace Eigen {
  12. /** \returns an expression of the product of two sparse matrices.
  13. * By default a conservative product preserving the symbolic non zeros is performed.
  14. * The automatic pruning of the small values can be achieved by calling the pruned() function
  15. * in which case a totally different product algorithm is employed:
  16. * \code
  17. * C = (A*B).pruned(); // supress numerical zeros (exact)
  18. * C = (A*B).pruned(ref);
  19. * C = (A*B).pruned(ref,epsilon);
  20. * \endcode
  21. * where \c ref is a meaningful non zero reference value.
  22. * */
  23. template<typename Derived>
  24. template<typename OtherDerived>
  25. inline const Product<Derived,OtherDerived,AliasFreeProduct>
  26. SparseMatrixBase<Derived>::operator*(const SparseMatrixBase<OtherDerived> &other) const
  27. {
  28. return Product<Derived,OtherDerived,AliasFreeProduct>(derived(), other.derived());
  29. }
  30. namespace internal {
  31. // sparse * sparse
  32. template<typename Lhs, typename Rhs, int ProductType>
  33. struct generic_product_impl<Lhs, Rhs, SparseShape, SparseShape, ProductType>
  34. {
  35. template<typename Dest>
  36. static void evalTo(Dest& dst, const Lhs& lhs, const Rhs& rhs)
  37. {
  38. evalTo(dst, lhs, rhs, typename evaluator_traits<Dest>::Shape());
  39. }
  40. // dense += sparse * sparse
  41. template<typename Dest,typename ActualLhs>
  42. static void addTo(Dest& dst, const ActualLhs& lhs, const Rhs& rhs, typename enable_if<is_same<typename evaluator_traits<Dest>::Shape,DenseShape>::value,int*>::type* = 0)
  43. {
  44. typedef typename nested_eval<ActualLhs,Dynamic>::type LhsNested;
  45. typedef typename nested_eval<Rhs,Dynamic>::type RhsNested;
  46. LhsNested lhsNested(lhs);
  47. RhsNested rhsNested(rhs);
  48. internal::sparse_sparse_to_dense_product_selector<typename remove_all<LhsNested>::type,
  49. typename remove_all<RhsNested>::type, Dest>::run(lhsNested,rhsNested,dst);
  50. }
  51. // dense -= sparse * sparse
  52. template<typename Dest>
  53. static void subTo(Dest& dst, const Lhs& lhs, const Rhs& rhs, typename enable_if<is_same<typename evaluator_traits<Dest>::Shape,DenseShape>::value,int*>::type* = 0)
  54. {
  55. addTo(dst, -lhs, rhs);
  56. }
  57. protected:
  58. // sparse = sparse * sparse
  59. template<typename Dest>
  60. static void evalTo(Dest& dst, const Lhs& lhs, const Rhs& rhs, SparseShape)
  61. {
  62. typedef typename nested_eval<Lhs,Dynamic>::type LhsNested;
  63. typedef typename nested_eval<Rhs,Dynamic>::type RhsNested;
  64. LhsNested lhsNested(lhs);
  65. RhsNested rhsNested(rhs);
  66. internal::conservative_sparse_sparse_product_selector<typename remove_all<LhsNested>::type,
  67. typename remove_all<RhsNested>::type, Dest>::run(lhsNested,rhsNested,dst);
  68. }
  69. // dense = sparse * sparse
  70. template<typename Dest>
  71. static void evalTo(Dest& dst, const Lhs& lhs, const Rhs& rhs, DenseShape)
  72. {
  73. dst.setZero();
  74. addTo(dst, lhs, rhs);
  75. }
  76. };
  77. // sparse * sparse-triangular
  78. template<typename Lhs, typename Rhs, int ProductType>
  79. struct generic_product_impl<Lhs, Rhs, SparseShape, SparseTriangularShape, ProductType>
  80. : public generic_product_impl<Lhs, Rhs, SparseShape, SparseShape, ProductType>
  81. {};
  82. // sparse-triangular * sparse
  83. template<typename Lhs, typename Rhs, int ProductType>
  84. struct generic_product_impl<Lhs, Rhs, SparseTriangularShape, SparseShape, ProductType>
  85. : public generic_product_impl<Lhs, Rhs, SparseShape, SparseShape, ProductType>
  86. {};
  87. // dense = sparse-product (can be sparse*sparse, sparse*perm, etc.)
  88. template< typename DstXprType, typename Lhs, typename Rhs>
  89. struct Assignment<DstXprType, Product<Lhs,Rhs,AliasFreeProduct>, internal::assign_op<typename DstXprType::Scalar,typename Product<Lhs,Rhs,AliasFreeProduct>::Scalar>, Sparse2Dense>
  90. {
  91. typedef Product<Lhs,Rhs,AliasFreeProduct> SrcXprType;
  92. static void run(DstXprType &dst, const SrcXprType &src, const internal::assign_op<typename DstXprType::Scalar,typename SrcXprType::Scalar> &)
  93. {
  94. Index dstRows = src.rows();
  95. Index dstCols = src.cols();
  96. if((dst.rows()!=dstRows) || (dst.cols()!=dstCols))
  97. dst.resize(dstRows, dstCols);
  98. generic_product_impl<Lhs, Rhs>::evalTo(dst,src.lhs(),src.rhs());
  99. }
  100. };
  101. // dense += sparse-product (can be sparse*sparse, sparse*perm, etc.)
  102. template< typename DstXprType, typename Lhs, typename Rhs>
  103. struct Assignment<DstXprType, Product<Lhs,Rhs,AliasFreeProduct>, internal::add_assign_op<typename DstXprType::Scalar,typename Product<Lhs,Rhs,AliasFreeProduct>::Scalar>, Sparse2Dense>
  104. {
  105. typedef Product<Lhs,Rhs,AliasFreeProduct> SrcXprType;
  106. static void run(DstXprType &dst, const SrcXprType &src, const internal::add_assign_op<typename DstXprType::Scalar,typename SrcXprType::Scalar> &)
  107. {
  108. generic_product_impl<Lhs, Rhs>::addTo(dst,src.lhs(),src.rhs());
  109. }
  110. };
  111. // dense -= sparse-product (can be sparse*sparse, sparse*perm, etc.)
  112. template< typename DstXprType, typename Lhs, typename Rhs>
  113. struct Assignment<DstXprType, Product<Lhs,Rhs,AliasFreeProduct>, internal::sub_assign_op<typename DstXprType::Scalar,typename Product<Lhs,Rhs,AliasFreeProduct>::Scalar>, Sparse2Dense>
  114. {
  115. typedef Product<Lhs,Rhs,AliasFreeProduct> SrcXprType;
  116. static void run(DstXprType &dst, const SrcXprType &src, const internal::sub_assign_op<typename DstXprType::Scalar,typename SrcXprType::Scalar> &)
  117. {
  118. generic_product_impl<Lhs, Rhs>::subTo(dst,src.lhs(),src.rhs());
  119. }
  120. };
  121. template<typename Lhs, typename Rhs, int Options>
  122. struct unary_evaluator<SparseView<Product<Lhs, Rhs, Options> >, IteratorBased>
  123. : public evaluator<typename Product<Lhs, Rhs, DefaultProduct>::PlainObject>
  124. {
  125. typedef SparseView<Product<Lhs, Rhs, Options> > XprType;
  126. typedef typename XprType::PlainObject PlainObject;
  127. typedef evaluator<PlainObject> Base;
  128. explicit unary_evaluator(const XprType& xpr)
  129. : m_result(xpr.rows(), xpr.cols())
  130. {
  131. using std::abs;
  132. ::new (static_cast<Base*>(this)) Base(m_result);
  133. typedef typename nested_eval<Lhs,Dynamic>::type LhsNested;
  134. typedef typename nested_eval<Rhs,Dynamic>::type RhsNested;
  135. LhsNested lhsNested(xpr.nestedExpression().lhs());
  136. RhsNested rhsNested(xpr.nestedExpression().rhs());
  137. internal::sparse_sparse_product_with_pruning_selector<typename remove_all<LhsNested>::type,
  138. typename remove_all<RhsNested>::type, PlainObject>::run(lhsNested,rhsNested,m_result,
  139. abs(xpr.reference())*xpr.epsilon());
  140. }
  141. protected:
  142. PlainObject m_result;
  143. };
  144. } // end namespace internal
  145. } // end namespace Eigen
  146. #endif // EIGEN_SPARSEPRODUCT_H