SparseTriangularView.h 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189
  1. // This file is part of Eigen, a lightweight C++ template library
  2. // for linear algebra.
  3. //
  4. // Copyright (C) 2009-2015 Gael Guennebaud <gael.guennebaud@inria.fr>
  5. // Copyright (C) 2012 Désiré Nuentsa-Wakam <desire.nuentsa_wakam@inria.fr>
  6. //
  7. // This Source Code Form is subject to the terms of the Mozilla
  8. // Public License v. 2.0. If a copy of the MPL was not distributed
  9. // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
  10. #ifndef EIGEN_SPARSE_TRIANGULARVIEW_H
  11. #define EIGEN_SPARSE_TRIANGULARVIEW_H
  12. namespace Eigen {
  13. /** \ingroup SparseCore_Module
  14. *
  15. * \brief Base class for a triangular part in a \b sparse matrix
  16. *
  17. * This class is an abstract base class of class TriangularView, and objects of type TriangularViewImpl cannot be instantiated.
  18. * It extends class TriangularView with additional methods which are available for sparse expressions only.
  19. *
  20. * \sa class TriangularView, SparseMatrixBase::triangularView()
  21. */
  22. template<typename MatrixType, unsigned int Mode> class TriangularViewImpl<MatrixType,Mode,Sparse>
  23. : public SparseMatrixBase<TriangularView<MatrixType,Mode> >
  24. {
  25. enum { SkipFirst = ((Mode&Lower) && !(MatrixType::Flags&RowMajorBit))
  26. || ((Mode&Upper) && (MatrixType::Flags&RowMajorBit)),
  27. SkipLast = !SkipFirst,
  28. SkipDiag = (Mode&ZeroDiag) ? 1 : 0,
  29. HasUnitDiag = (Mode&UnitDiag) ? 1 : 0
  30. };
  31. typedef TriangularView<MatrixType,Mode> TriangularViewType;
  32. protected:
  33. // dummy solve function to make TriangularView happy.
  34. void solve() const;
  35. typedef SparseMatrixBase<TriangularViewType> Base;
  36. public:
  37. EIGEN_SPARSE_PUBLIC_INTERFACE(TriangularViewType)
  38. typedef typename MatrixType::Nested MatrixTypeNested;
  39. typedef typename internal::remove_reference<MatrixTypeNested>::type MatrixTypeNestedNonRef;
  40. typedef typename internal::remove_all<MatrixTypeNested>::type MatrixTypeNestedCleaned;
  41. template<typename RhsType, typename DstType>
  42. EIGEN_DEVICE_FUNC
  43. EIGEN_STRONG_INLINE void _solve_impl(const RhsType &rhs, DstType &dst) const {
  44. if(!(internal::is_same<RhsType,DstType>::value && internal::extract_data(dst) == internal::extract_data(rhs)))
  45. dst = rhs;
  46. this->solveInPlace(dst);
  47. }
  48. /** Applies the inverse of \c *this to the dense vector or matrix \a other, "in-place" */
  49. template<typename OtherDerived> void solveInPlace(MatrixBase<OtherDerived>& other) const;
  50. /** Applies the inverse of \c *this to the sparse vector or matrix \a other, "in-place" */
  51. template<typename OtherDerived> void solveInPlace(SparseMatrixBase<OtherDerived>& other) const;
  52. };
  53. namespace internal {
  54. template<typename ArgType, unsigned int Mode>
  55. struct unary_evaluator<TriangularView<ArgType,Mode>, IteratorBased>
  56. : evaluator_base<TriangularView<ArgType,Mode> >
  57. {
  58. typedef TriangularView<ArgType,Mode> XprType;
  59. protected:
  60. typedef typename XprType::Scalar Scalar;
  61. typedef typename XprType::StorageIndex StorageIndex;
  62. typedef typename evaluator<ArgType>::InnerIterator EvalIterator;
  63. enum { SkipFirst = ((Mode&Lower) && !(ArgType::Flags&RowMajorBit))
  64. || ((Mode&Upper) && (ArgType::Flags&RowMajorBit)),
  65. SkipLast = !SkipFirst,
  66. SkipDiag = (Mode&ZeroDiag) ? 1 : 0,
  67. HasUnitDiag = (Mode&UnitDiag) ? 1 : 0
  68. };
  69. public:
  70. enum {
  71. CoeffReadCost = evaluator<ArgType>::CoeffReadCost,
  72. Flags = XprType::Flags
  73. };
  74. explicit unary_evaluator(const XprType &xpr) : m_argImpl(xpr.nestedExpression()), m_arg(xpr.nestedExpression()) {}
  75. inline Index nonZerosEstimate() const {
  76. return m_argImpl.nonZerosEstimate();
  77. }
  78. class InnerIterator : public EvalIterator
  79. {
  80. typedef EvalIterator Base;
  81. public:
  82. EIGEN_STRONG_INLINE InnerIterator(const unary_evaluator& xprEval, Index outer)
  83. : Base(xprEval.m_argImpl,outer), m_returnOne(false), m_containsDiag(Base::outer()<xprEval.m_arg.innerSize())
  84. {
  85. if(SkipFirst)
  86. {
  87. while((*this) && ((HasUnitDiag||SkipDiag) ? this->index()<=outer : this->index()<outer))
  88. Base::operator++();
  89. if(HasUnitDiag)
  90. m_returnOne = m_containsDiag;
  91. }
  92. else if(HasUnitDiag && ((!Base::operator bool()) || Base::index()>=Base::outer()))
  93. {
  94. if((!SkipFirst) && Base::operator bool())
  95. Base::operator++();
  96. m_returnOne = m_containsDiag;
  97. }
  98. }
  99. EIGEN_STRONG_INLINE InnerIterator& operator++()
  100. {
  101. if(HasUnitDiag && m_returnOne)
  102. m_returnOne = false;
  103. else
  104. {
  105. Base::operator++();
  106. if(HasUnitDiag && (!SkipFirst) && ((!Base::operator bool()) || Base::index()>=Base::outer()))
  107. {
  108. if((!SkipFirst) && Base::operator bool())
  109. Base::operator++();
  110. m_returnOne = m_containsDiag;
  111. }
  112. }
  113. return *this;
  114. }
  115. EIGEN_STRONG_INLINE operator bool() const
  116. {
  117. if(HasUnitDiag && m_returnOne)
  118. return true;
  119. if(SkipFirst) return Base::operator bool();
  120. else
  121. {
  122. if (SkipDiag) return (Base::operator bool() && this->index() < this->outer());
  123. else return (Base::operator bool() && this->index() <= this->outer());
  124. }
  125. }
  126. // inline Index row() const { return (ArgType::Flags&RowMajorBit ? Base::outer() : this->index()); }
  127. // inline Index col() const { return (ArgType::Flags&RowMajorBit ? this->index() : Base::outer()); }
  128. inline StorageIndex index() const
  129. {
  130. if(HasUnitDiag && m_returnOne) return internal::convert_index<StorageIndex>(Base::outer());
  131. else return Base::index();
  132. }
  133. inline Scalar value() const
  134. {
  135. if(HasUnitDiag && m_returnOne) return Scalar(1);
  136. else return Base::value();
  137. }
  138. protected:
  139. bool m_returnOne;
  140. bool m_containsDiag;
  141. private:
  142. Scalar& valueRef();
  143. };
  144. protected:
  145. evaluator<ArgType> m_argImpl;
  146. const ArgType& m_arg;
  147. };
  148. } // end namespace internal
  149. template<typename Derived>
  150. template<int Mode>
  151. inline const TriangularView<const Derived, Mode>
  152. SparseMatrixBase<Derived>::triangularView() const
  153. {
  154. return TriangularView<const Derived, Mode>(derived());
  155. }
  156. } // end namespace Eigen
  157. #endif // EIGEN_SPARSE_TRIANGULARVIEW_H