// CwiseTernaryOp.h (8.1 KB)
  1. // This file is part of Eigen, a lightweight C++ template library
  2. // for linear algebra.
  3. //
  4. // Copyright (C) 2008-2014 Gael Guennebaud <gael.guennebaud@inria.fr>
  5. // Copyright (C) 2006-2008 Benoit Jacob <jacob.benoit.1@gmail.com>
  6. // Copyright (C) 2016 Eugene Brevdo <ebrevdo@gmail.com>
  7. //
  8. // This Source Code Form is subject to the terms of the Mozilla
  9. // Public License v. 2.0. If a copy of the MPL was not distributed
  10. // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
  11. #ifndef EIGEN_CWISE_TERNARY_OP_H
  12. #define EIGEN_CWISE_TERNARY_OP_H
  13. namespace Eigen {
  14. namespace internal {
  15. template <typename TernaryOp, typename Arg1, typename Arg2, typename Arg3>
  16. struct traits<CwiseTernaryOp<TernaryOp, Arg1, Arg2, Arg3> > {
  17. // we must not inherit from traits<Arg1> since it has
  18. // the potential to cause problems with MSVC
  19. typedef typename remove_all<Arg1>::type Ancestor;
  20. typedef typename traits<Ancestor>::XprKind XprKind;
  21. enum {
  22. RowsAtCompileTime = traits<Ancestor>::RowsAtCompileTime,
  23. ColsAtCompileTime = traits<Ancestor>::ColsAtCompileTime,
  24. MaxRowsAtCompileTime = traits<Ancestor>::MaxRowsAtCompileTime,
  25. MaxColsAtCompileTime = traits<Ancestor>::MaxColsAtCompileTime
  26. };
  27. // even though we require Arg1, Arg2, and Arg3 to have the same scalar type
  28. // (see CwiseTernaryOp constructor),
  29. // we still want to handle the case when the result type is different.
  30. typedef typename result_of<TernaryOp(
  31. const typename Arg1::Scalar&, const typename Arg2::Scalar&,
  32. const typename Arg3::Scalar&)>::type Scalar;
  33. typedef typename internal::traits<Arg1>::StorageKind StorageKind;
  34. typedef typename internal::traits<Arg1>::StorageIndex StorageIndex;
  35. typedef typename Arg1::Nested Arg1Nested;
  36. typedef typename Arg2::Nested Arg2Nested;
  37. typedef typename Arg3::Nested Arg3Nested;
  38. typedef typename remove_reference<Arg1Nested>::type _Arg1Nested;
  39. typedef typename remove_reference<Arg2Nested>::type _Arg2Nested;
  40. typedef typename remove_reference<Arg3Nested>::type _Arg3Nested;
  41. enum { Flags = _Arg1Nested::Flags & RowMajorBit };
  42. };
  43. } // end namespace internal
  44. template <typename TernaryOp, typename Arg1, typename Arg2, typename Arg3,
  45. typename StorageKind>
  46. class CwiseTernaryOpImpl;
  47. /** \class CwiseTernaryOp
  48. * \ingroup Core_Module
  49. *
  50. * \brief Generic expression where a coefficient-wise ternary operator is
  51. * applied to two expressions
  52. *
  53. * \tparam TernaryOp template functor implementing the operator
  54. * \tparam Arg1Type the type of the first argument
  55. * \tparam Arg2Type the type of the second argument
  56. * \tparam Arg3Type the type of the third argument
  57. *
  58. * This class represents an expression where a coefficient-wise ternary
  59. * operator is applied to three expressions.
  60. * It is the return type of ternary operators, by which we mean only those
  61. * ternary operators where
  62. * all three arguments are Eigen expressions.
  63. * For example, the return type of betainc(matrix1, matrix2, matrix3) is a
  64. * CwiseTernaryOp.
  65. *
  66. * Most of the time, this is the only way that it is used, so you typically
  67. * don't have to name
  68. * CwiseTernaryOp types explicitly.
  69. *
  70. * \sa MatrixBase::ternaryExpr(const MatrixBase<Argument2> &, const
  71. * MatrixBase<Argument3> &, const CustomTernaryOp &) const, class CwiseBinaryOp,
  72. * class CwiseUnaryOp, class CwiseNullaryOp
  73. */
  74. template <typename TernaryOp, typename Arg1Type, typename Arg2Type,
  75. typename Arg3Type>
  76. class CwiseTernaryOp : public CwiseTernaryOpImpl<
  77. TernaryOp, Arg1Type, Arg2Type, Arg3Type,
  78. typename internal::traits<Arg1Type>::StorageKind>,
  79. internal::no_assignment_operator
  80. {
  81. public:
  82. typedef typename internal::remove_all<Arg1Type>::type Arg1;
  83. typedef typename internal::remove_all<Arg2Type>::type Arg2;
  84. typedef typename internal::remove_all<Arg3Type>::type Arg3;
  85. typedef typename CwiseTernaryOpImpl<
  86. TernaryOp, Arg1Type, Arg2Type, Arg3Type,
  87. typename internal::traits<Arg1Type>::StorageKind>::Base Base;
  88. EIGEN_GENERIC_PUBLIC_INTERFACE(CwiseTernaryOp)
  89. typedef typename internal::ref_selector<Arg1Type>::type Arg1Nested;
  90. typedef typename internal::ref_selector<Arg2Type>::type Arg2Nested;
  91. typedef typename internal::ref_selector<Arg3Type>::type Arg3Nested;
  92. typedef typename internal::remove_reference<Arg1Nested>::type _Arg1Nested;
  93. typedef typename internal::remove_reference<Arg2Nested>::type _Arg2Nested;
  94. typedef typename internal::remove_reference<Arg3Nested>::type _Arg3Nested;
  95. EIGEN_DEVICE_FUNC
  96. EIGEN_STRONG_INLINE CwiseTernaryOp(const Arg1& a1, const Arg2& a2,
  97. const Arg3& a3,
  98. const TernaryOp& func = TernaryOp())
  99. : m_arg1(a1), m_arg2(a2), m_arg3(a3), m_functor(func) {
  100. // require the sizes to match
  101. EIGEN_STATIC_ASSERT_SAME_MATRIX_SIZE(Arg1, Arg2)
  102. EIGEN_STATIC_ASSERT_SAME_MATRIX_SIZE(Arg1, Arg3)
  103. // The index types should match
  104. EIGEN_STATIC_ASSERT((internal::is_same<
  105. typename internal::traits<Arg1Type>::StorageKind,
  106. typename internal::traits<Arg2Type>::StorageKind>::value),
  107. STORAGE_KIND_MUST_MATCH)
  108. EIGEN_STATIC_ASSERT((internal::is_same<
  109. typename internal::traits<Arg1Type>::StorageKind,
  110. typename internal::traits<Arg3Type>::StorageKind>::value),
  111. STORAGE_KIND_MUST_MATCH)
  112. eigen_assert(a1.rows() == a2.rows() && a1.cols() == a2.cols() &&
  113. a1.rows() == a3.rows() && a1.cols() == a3.cols());
  114. }
  115. EIGEN_DEVICE_FUNC
  116. EIGEN_STRONG_INLINE Index rows() const {
  117. // return the fixed size type if available to enable compile time
  118. // optimizations
  119. if (internal::traits<typename internal::remove_all<Arg1Nested>::type>::
  120. RowsAtCompileTime == Dynamic &&
  121. internal::traits<typename internal::remove_all<Arg2Nested>::type>::
  122. RowsAtCompileTime == Dynamic)
  123. return m_arg3.rows();
  124. else if (internal::traits<typename internal::remove_all<Arg1Nested>::type>::
  125. RowsAtCompileTime == Dynamic &&
  126. internal::traits<typename internal::remove_all<Arg3Nested>::type>::
  127. RowsAtCompileTime == Dynamic)
  128. return m_arg2.rows();
  129. else
  130. return m_arg1.rows();
  131. }
  132. EIGEN_DEVICE_FUNC
  133. EIGEN_STRONG_INLINE Index cols() const {
  134. // return the fixed size type if available to enable compile time
  135. // optimizations
  136. if (internal::traits<typename internal::remove_all<Arg1Nested>::type>::
  137. ColsAtCompileTime == Dynamic &&
  138. internal::traits<typename internal::remove_all<Arg2Nested>::type>::
  139. ColsAtCompileTime == Dynamic)
  140. return m_arg3.cols();
  141. else if (internal::traits<typename internal::remove_all<Arg1Nested>::type>::
  142. ColsAtCompileTime == Dynamic &&
  143. internal::traits<typename internal::remove_all<Arg3Nested>::type>::
  144. ColsAtCompileTime == Dynamic)
  145. return m_arg2.cols();
  146. else
  147. return m_arg1.cols();
  148. }
  149. /** \returns the first argument nested expression */
  150. EIGEN_DEVICE_FUNC
  151. const _Arg1Nested& arg1() const { return m_arg1; }
  152. /** \returns the first argument nested expression */
  153. EIGEN_DEVICE_FUNC
  154. const _Arg2Nested& arg2() const { return m_arg2; }
  155. /** \returns the third argument nested expression */
  156. EIGEN_DEVICE_FUNC
  157. const _Arg3Nested& arg3() const { return m_arg3; }
  158. /** \returns the functor representing the ternary operation */
  159. EIGEN_DEVICE_FUNC
  160. const TernaryOp& functor() const { return m_functor; }
  161. protected:
  162. Arg1Nested m_arg1;
  163. Arg2Nested m_arg2;
  164. Arg3Nested m_arg3;
  165. const TernaryOp m_functor;
  166. };
  167. // Generic API dispatcher
  168. template <typename TernaryOp, typename Arg1, typename Arg2, typename Arg3,
  169. typename StorageKind>
  170. class CwiseTernaryOpImpl
  171. : public internal::generic_xpr_base<
  172. CwiseTernaryOp<TernaryOp, Arg1, Arg2, Arg3> >::type {
  173. public:
  174. typedef typename internal::generic_xpr_base<
  175. CwiseTernaryOp<TernaryOp, Arg1, Arg2, Arg3> >::type Base;
  176. };
  177. } // end namespace Eigen
  178. #endif // EIGEN_CWISE_TERNARY_OP_H