#ifndef EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_H
#define EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_H

namespace Eigen {

namespace internal {

template<typename Dimensions, typename LhsXprType, typename RhsXprType>
struct traits<TensorContractionOp<Dimensions, LhsXprType, RhsXprType> >
{
  // Type promotion to handle the case where the types of the lhs and the rhs
  // are different.
  typedef typename gebp_traits<typename remove_const<typename LhsXprType::Scalar>::type,
                               typename remove_const<typename RhsXprType::Scalar>::type>::ResScalar Scalar;

  typedef typename promote_storage_type<typename traits<LhsXprType>::StorageKind,
                                        typename traits<RhsXprType>::StorageKind>::ret StorageKind;
  typedef typename promote_index_type<typename traits<LhsXprType>::Index,
                                      typename traits<RhsXprType>::Index>::type Index;
  typedef typename LhsXprType::Nested LhsNested;
  typedef typename RhsXprType::Nested RhsNested;
  typedef typename remove_reference<LhsNested>::type _LhsNested;
  typedef typename remove_reference<RhsNested>::type _RhsNested;

  // Each contracted index pair removes one dimension from the lhs and one
  // from the rhs, hence the "- 2 * array_size<Dimensions>::value".
  static const int NumDimensions = traits<LhsXprType>::NumDimensions + traits<RhsXprType>::NumDimensions - 2 * array_size<Dimensions>::value;
  static const int Layout = traits<LhsXprType>::Layout;
};

template<typename Dimensions, typename LhsXprType, typename RhsXprType>
struct eval<TensorContractionOp<Dimensions, LhsXprType, RhsXprType>, Eigen::Dense>
{
  typedef const TensorContractionOp<Dimensions, LhsXprType, RhsXprType>& type;
};

template<typename Dimensions, typename LhsXprType, typename RhsXprType>
struct nested<TensorContractionOp<Dimensions, LhsXprType, RhsXprType>, 1, typename eval<TensorContractionOp<Dimensions, LhsXprType, RhsXprType> >::type>
{
  typedef TensorContractionOp<Dimensions, LhsXprType, RhsXprType> type;
};

template<typename Indices_, typename LeftArgType_, typename RightArgType_, typename Device_>
struct traits<TensorEvaluator<const TensorContractionOp<Indices_, LeftArgType_, RightArgType_>, Device_> > {
  typedef Indices_ Indices;
  typedef LeftArgType_ LeftArgType;
  typedef RightArgType_ RightArgType;
  typedef Device_ Device;

  static const int NumDimensions = traits<LeftArgType_>::NumDimensions + traits<RightArgType_>::NumDimensions - 2 * array_size<Indices_>::value;
};

}  // end namespace internal

template<typename Indices, typename LhsXprType, typename RhsXprType>
class TensorContractionOp : public TensorBase<TensorContractionOp<Indices, LhsXprType, RhsXprType>, ReadOnlyAccessors>
{
  public:
  typedef typename Eigen::internal::traits<TensorContractionOp>::Scalar Scalar;
  typedef typename internal::gebp_traits<typename LhsXprType::CoeffReturnType,
                                         typename RhsXprType::CoeffReturnType>::ResScalar CoeffReturnType;
  typedef typename Eigen::internal::nested<TensorContractionOp>::type Nested;
  typedef typename Eigen::internal::traits<TensorContractionOp>::StorageKind StorageKind;
  typedef typename Eigen::internal::traits<TensorContractionOp>::Index Index;

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionOp(
      const LhsXprType& lhs, const RhsXprType& rhs, const Indices& dims)
      : m_lhs_xpr(lhs), m_rhs_xpr(rhs), m_indices(dims) {}

  // The pairs of contracted indices: IndexPair(i, j) contracts the i-th
  // dimension of the lhs with the j-th dimension of the rhs.
  EIGEN_DEVICE_FUNC
  const Indices& indices() const { return m_indices; }

  /** \returns the nested expressions */
  EIGEN_DEVICE_FUNC
  const typename internal::remove_all<typename LhsXprType::Nested>::type&
  lhsExpression() const { return m_lhs_xpr; }

  EIGEN_DEVICE_FUNC
  const typename internal::remove_all<typename RhsXprType::Nested>::type&
  rhsExpression() const { return m_rhs_xpr; }

  protected:
    typename LhsXprType::Nested m_lhs_xpr;
    typename RhsXprType::Nested m_rhs_xpr;
    const Indices m_indices;
};
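
// Illustrative usage (an added sketch, not part of the original header).
// Contracting the second dimension of a rank-2 tensor with the first
// dimension of another reproduces an ordinary matrix product:
//
//   Eigen::Tensor<float, 2> a(3, 4), b(4, 5);
//   a.setRandom(); b.setRandom();
//   Eigen::array<Eigen::IndexPair<int>, 1> dims = { Eigen::IndexPair<int>(1, 0) };
//   Eigen::Tensor<float, 2> c = a.contract(b, dims);  // c has dimensions (3, 5)
//
// TensorContractionOp is rarely constructed directly; the contract() method
// of TensorBase builds it.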


template<typename Derived>
struct TensorContractionEvaluatorBase
{
  typedef typename internal::traits<Derived>::Indices Indices;
  typedef typename internal::traits<Derived>::LeftArgType LeftArgType;
  typedef typename internal::traits<Derived>::RightArgType RightArgType;
  typedef typename internal::traits<Derived>::Device Device;

  typedef TensorContractionOp<Indices, LeftArgType, RightArgType> XprType;
  typedef typename internal::remove_const<typename XprType::Scalar>::type Scalar;
  typedef typename XprType::Index Index;
  typedef typename XprType::CoeffReturnType CoeffReturnType;
  typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;

  enum {
    PacketAccess = (internal::unpacket_traits<PacketReturnType>::size > 1),
    Layout = TensorEvaluator<LeftArgType, Device>::Layout
  };

  // Most of the code below assumes that both input tensors are ColMajor. If
  // the inputs are RowMajor, we "cheat" by swapping the LHS and RHS: to
  // compute A * B = C, the code pretends B is the LHS and A is the RHS.
  typedef typename internal::conditional<
    static_cast<int>(Layout) == static_cast<int>(ColMajor), LeftArgType, RightArgType>::type EvalLeftArgType;
  typedef typename internal::conditional<
    static_cast<int>(Layout) == static_cast<int>(ColMajor), RightArgType, LeftArgType>::type EvalRightArgType;

  static const int LDims =
      internal::array_size<typename TensorEvaluator<EvalLeftArgType, Device>::Dimensions>::value;
  static const int RDims =
      internal::array_size<typename TensorEvaluator<EvalRightArgType, Device>::Dimensions>::value;
  static const int ContractDims = internal::array_size<Indices>::value;
  static const int NumDims = LDims + RDims - 2 * ContractDims;

  typedef array<Index, ContractDims> contract_t;
  typedef array<Index, LDims - ContractDims> left_nocontract_t;
  typedef array<Index, RDims - ContractDims> right_nocontract_t;

  typedef DSizes<Index, NumDims> Dimensions;
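
  // Reading guide (added commentary): the evaluator flattens the contraction
  // into a single matrix product. The lhs non-contracting dimensions fold
  // into m_i_size rows, the rhs non-contracting dimensions into m_j_size
  // columns, and the contracted dimensions into a shared depth m_k_size.
  // Worked example: contracting a (2, 3, 4) lhs with a (4, 5) rhs over the
  // pair (2, 0) gives m_i_size = 2*3 = 6, m_j_size = 5, m_k_size = 4, i.e. a
  // 6x4 times 4x5 product reshaped back to the output dimensions (2, 3, 5).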

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
  TensorContractionEvaluatorBase(const XprType& op, const Device& device)
      : m_leftImpl(choose(Cond<static_cast<int>(Layout) == static_cast<int>(ColMajor)>(),
                          op.lhsExpression(), op.rhsExpression()), device),
        m_rightImpl(choose(Cond<static_cast<int>(Layout) == static_cast<int>(ColMajor)>(),
                           op.rhsExpression(), op.lhsExpression()), device),
        m_device(device),
        m_result(NULL) {
    EIGEN_STATIC_ASSERT((static_cast<int>(TensorEvaluator<LeftArgType, Device>::Layout) ==
                         static_cast<int>(TensorEvaluator<RightArgType, Device>::Layout)),
                        YOU_MADE_A_PROGRAMMING_MISTAKE);

    DSizes<Index, LDims> eval_left_dims;
    DSizes<Index, RDims> eval_right_dims;
    array<IndexPair<Index>, ContractDims> eval_op_indices;
    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      // For ColMajor, keep the existing dimension order.
      for (int i = 0; i < LDims; i++) {
        eval_left_dims[i] = m_leftImpl.dimensions()[i];
      }
      for (int i = 0; i < RDims; i++) {
        eval_right_dims[i] = m_rightImpl.dimensions()[i];
      }
      // Keep the pairs of contracting indices as given.
      for (int i = 0; i < ContractDims; i++) {
        eval_op_indices[i].first = op.indices()[i].first;
        eval_op_indices[i].second = op.indices()[i].second;
      }
    } else {
      // For RowMajor, reverse the dimension order.
      for (int i = 0; i < LDims; i++) {
        eval_left_dims[i] = m_leftImpl.dimensions()[LDims - i - 1];
      }
      for (int i = 0; i < RDims; i++) {
        eval_right_dims[i] = m_rightImpl.dimensions()[RDims - i - 1];
      }
      // Flip the pairs of contracting indices to match the reversed dimensions.
      for (int i = 0; i < ContractDims; i++) {
        eval_op_indices[i].first = LDims - 1 - op.indices()[ContractDims - 1 - i].second;
        eval_op_indices[i].second = RDims - 1 - op.indices()[ContractDims - 1 - i].first;
      }
    }

    // Check for duplicate axes and make sure the first index in
    // eval_op_indices is increasing. O(n^2) sorting is fine since
    // ContractDims is small.
    for (int i = 0; i < ContractDims; i++) {
      for (int j = i + 1; j < ContractDims; j++) {
        eigen_assert(eval_op_indices[j].first != eval_op_indices[i].first &&
                     eval_op_indices[j].second != eval_op_indices[i].second &&
                     "contraction axes should be unique");
        if (eval_op_indices[j].first < eval_op_indices[i].first) {
          numext::swap(eval_op_indices[j], eval_op_indices[i]);
        }
      }
    }
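
    // For instance (an added illustrative note): with pairs {(2,0), (0,1)}
    // the loop above reorders them to {(0,1), (2,0)} so that lhs contraction
    // axes are visited in increasing order, while the assert rejects
    // contracting the same axis twice, as in {(0,0), (0,1)}.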

    // Column-major strides of the (ColMajor-ordered) inputs: the first
    // dimension has stride 1, each following stride multiplies in the
    // previous dimension's size.
    array<Index, LDims> lhs_strides;
    lhs_strides[0] = 1;
    for (int i = 0; i < LDims-1; ++i) {
      lhs_strides[i+1] = lhs_strides[i] * eval_left_dims[i];
    }

    array<Index, RDims> rhs_strides;
    rhs_strides[0] = 1;
    for (int i = 0; i < RDims-1; ++i) {
      rhs_strides[i+1] = rhs_strides[i] * eval_right_dims[i];
    }
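
    // Worked example (added, illustrative): for eval_left_dims = (2, 3, 4)
    // the loop yields lhs_strides = (1, 2, 6), i.e. element (i0, i1, i2)
    // lives at offset i0*1 + i1*2 + i2*6 in column-major storage.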

    if (m_i_strides.size() > 0) m_i_strides[0] = 1;
    if (m_j_strides.size() > 0) m_j_strides[0] = 1;
    if (m_k_strides.size() > 0) m_k_strides[0] = 1;

    // Compute the output dimensions by concatenating the non-contracting
    // dimensions of the lhs and then of the rhs, along with the
    // corresponding strides.
    m_lhs_inner_dim_contiguous = true;
    int dim_idx = 0;
    unsigned int nocontract_idx = 0;

    for (int i = 0; i < LDims; i++) {
      // Check whether dimension i of the lhs is contracted.
      bool contracting = false;
      for (int j = 0; j < ContractDims; j++) {
        if (eval_op_indices[j].first == i) {
          contracting = true;
          break;
        }
      }
      if (!contracting) {
        // Add this dimension's size to the output dimensions.
        m_dimensions[dim_idx] = eval_left_dims[i];
        m_left_nocontract_strides[nocontract_idx] = lhs_strides[i];
        if (dim_idx != i) {
          m_lhs_inner_dim_contiguous = false;
        }
        if (nocontract_idx+1 < internal::array_size<left_nocontract_t>::value) {
          m_i_strides[nocontract_idx+1] =
              m_i_strides[nocontract_idx] * eval_left_dims[i];
        } else {
          m_i_size = m_i_strides[nocontract_idx] * eval_left_dims[i];
        }
        dim_idx++;
        nocontract_idx++;
      }
    }

    nocontract_idx = 0;
    for (int i = 0; i < RDims; i++) {
      // Check whether dimension i of the rhs is contracted.
      bool contracting = false;
      for (int j = 0; j < ContractDims; j++) {
        if (eval_op_indices[j].second == i) {
          contracting = true;
          break;
        }
      }
      if (!contracting) {
        m_dimensions[dim_idx] = eval_right_dims[i];
        if (nocontract_idx+1 < internal::array_size<right_nocontract_t>::value) {
          m_j_strides[nocontract_idx+1] =
              m_j_strides[nocontract_idx] * eval_right_dims[i];
        } else {
          m_j_size = m_j_strides[nocontract_idx] * eval_right_dims[i];
        }
        m_right_nocontract_strides[nocontract_idx] = rhs_strides[i];
        dim_idx++;
        nocontract_idx++;
      }
    }

    // Now compute the strides of the contracting dimensions. The contracted
    // axes must have the same size in both tensors.
    m_rhs_inner_dim_contiguous = true;
    m_rhs_inner_dim_reordered = false;
    for (int i = 0; i < ContractDims; i++) {
      Index left = eval_op_indices[i].first;
      Index right = eval_op_indices[i].second;

      Index size = eval_left_dims[left];
      eigen_assert(size == eval_right_dims[right] &&
                   "Contraction axes must be same size");

      if (i+1 < static_cast<int>(internal::array_size<contract_t>::value)) {
        m_k_strides[i+1] = m_k_strides[i] * size;
      } else {
        m_k_size = m_k_strides[i] * size;
      }
      m_left_contracting_strides[i] = lhs_strides[left];
      m_right_contracting_strides[i] = rhs_strides[right];

      if (i > 0 && right < eval_op_indices[i-1].second) {
        m_rhs_inner_dim_reordered = true;
      }
      if (right != i) {
        m_rhs_inner_dim_contiguous = false;
      }
    }

    // Everything above was computed in ColMajor order; for RowMajor
    // expressions the output dimensions must be reversed back.
    if (static_cast<int>(Layout) == static_cast<int>(RowMajor)) {
      for (int i = 0, j = NumDims - 1; i < j; i++, j--) {
        numext::swap(m_dimensions[i], m_dimensions[j]);
      }
    }
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar* data) {
    m_leftImpl.evalSubExprsIfNeeded(NULL);
    m_rightImpl.evalSubExprsIfNeeded(NULL);
    if (data) {
      // The caller provided a destination buffer: evaluate directly into it.
      evalTo(data);
      return false;
    } else {
      // Otherwise allocate a result buffer owned by this evaluator.
      m_result = static_cast<Scalar *>(m_device.allocate(dimensions().TotalSize() * sizeof(Scalar)));
      evalTo(m_result);
      return true;
    }
  }

  // Dispatch to the evalProduct instantiation matching the three runtime
  // layout flags, so the inner kernels are specialized at compile time.
  EIGEN_DEVICE_FUNC void evalTo(Scalar* buffer) const {
    if (this->m_lhs_inner_dim_contiguous) {
      if (this->m_rhs_inner_dim_contiguous) {
        if (this->m_rhs_inner_dim_reordered) {
          static_cast<const Derived*>(this)->template evalProduct<true, true, true, Unaligned>(buffer);
        } else {
          static_cast<const Derived*>(this)->template evalProduct<true, true, false, Unaligned>(buffer);
        }
      } else {
        if (this->m_rhs_inner_dim_reordered) {
          static_cast<const Derived*>(this)->template evalProduct<true, false, true, Unaligned>(buffer);
        } else {
          static_cast<const Derived*>(this)->template evalProduct<true, false, false, Unaligned>(buffer);
        }
      }
    } else {
      if (this->m_rhs_inner_dim_contiguous) {
        if (this->m_rhs_inner_dim_reordered) {
          static_cast<const Derived*>(this)->template evalProduct<false, true, true, Unaligned>(buffer);
        } else {
          static_cast<const Derived*>(this)->template evalProduct<false, true, false, Unaligned>(buffer);
        }
      } else {
        if (this->m_rhs_inner_dim_reordered) {
          static_cast<const Derived*>(this)->template evalProduct<false, false, true, Unaligned>(buffer);
        } else {
          static_cast<const Derived*>(this)->template evalProduct<false, false, false, Unaligned>(buffer);
        }
      }
    }
  }
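
  // Added note: the three booleans describe how the tensor's storage maps
  // onto the matrix view. For example, contracting the first dimension of a
  // ColMajor lhs leaves the remaining dimensions out of their storage order
  // (dim_idx != i above), so lhs_inner_dim_contiguous is false and the input
  // mapper cannot assume contiguous packet loads on that side.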

  template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous,
            bool rhs_inner_dim_reordered, int Alignment>
  EIGEN_DEVICE_FUNC void evalGemv(Scalar* buffer) const {
    const Index rows = m_i_size;
    const Index cols = m_k_size;

    typedef typename internal::remove_const<typename EvalLeftArgType::Scalar>::type LhsScalar;
    typedef typename internal::remove_const<typename EvalRightArgType::Scalar>::type RhsScalar;
    typedef TensorEvaluator<EvalLeftArgType, Device> LeftEvaluator;
    typedef TensorEvaluator<EvalRightArgType, Device> RightEvaluator;
    const Index lhs_packet_size = internal::unpacket_traits<typename LeftEvaluator::PacketReturnType>::size;
    const Index rhs_packet_size = internal::unpacket_traits<typename RightEvaluator::PacketReturnType>::size;
    const int lhs_alignment = LeftEvaluator::IsAligned ? Aligned : Unaligned;
    const int rhs_alignment = RightEvaluator::IsAligned ? Aligned : Unaligned;
    typedef internal::TensorContractionInputMapper<LhsScalar, Index, internal::Lhs,
                                                   LeftEvaluator, left_nocontract_t,
                                                   contract_t, lhs_packet_size,
                                                   lhs_inner_dim_contiguous,
                                                   false, lhs_alignment> LhsMapper;

    typedef internal::TensorContractionInputMapper<RhsScalar, Index, internal::Rhs,
                                                   RightEvaluator, right_nocontract_t,
                                                   contract_t, rhs_packet_size,
                                                   rhs_inner_dim_contiguous,
                                                   rhs_inner_dim_reordered, rhs_alignment> RhsMapper;

    LhsMapper lhs(m_leftImpl, m_left_nocontract_strides, m_i_strides,
                  m_left_contracting_strides, m_k_strides);
    RhsMapper rhs(m_rightImpl, m_right_nocontract_strides, m_j_strides,
                  m_right_contracting_strides, m_k_strides);

    const Scalar alpha(1);
    const Index resIncr(1);

    // Zero out the result buffer, which must hold at least rows * sizeof(Scalar).
    m_device.memset(buffer, 0, rows * sizeof(Scalar));

    internal::general_matrix_vector_product<Index,LhsScalar,LhsMapper,ColMajor,false,RhsScalar,RhsMapper,false>::run(
        rows, cols, lhs, rhs,
        buffer, resIncr, alpha);
  }
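
  // Added note: evalGemv handles the degenerate case where the rhs
  // contributes no output columns (m_j_size == 1), e.g. contracting a (3, 4)
  // tensor with a rank-1 tensor of size 4 over the pair (1, 0). The
  // contraction then reduces to a single matrix-vector product y = A * x
  // with A of size m_i_size x m_k_size.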

  template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous,
            bool rhs_inner_dim_reordered, int Alignment>
  EIGEN_DEVICE_FUNC void evalGemm(Scalar* buffer) const {
    // columns in left side, rows in right side
    const Index k = this->m_k_size;
    // rows in left side
    const Index m = this->m_i_size;
    // columns in right side
    const Index n = this->m_j_size;

    // Zero out the result buffer, which must hold at least m * n * sizeof(Scalar).
    this->m_device.memset(buffer, 0, m * n * sizeof(Scalar));

    typedef typename internal::remove_const<typename EvalLeftArgType::Scalar>::type LhsScalar;
    typedef typename internal::remove_const<typename EvalRightArgType::Scalar>::type RhsScalar;
    typedef typename internal::gebp_traits<LhsScalar, RhsScalar> Traits;

    const Index nr = Traits::nr;
    const Index mr = Traits::mr;

    typedef TensorEvaluator<EvalLeftArgType, Device> LeftEvaluator;
    typedef TensorEvaluator<EvalRightArgType, Device> RightEvaluator;

    const Index lhs_packet_size = internal::unpacket_traits<typename LeftEvaluator::PacketReturnType>::size;
    const Index rhs_packet_size = internal::unpacket_traits<typename RightEvaluator::PacketReturnType>::size;

    typedef internal::TensorContractionInputMapper<LhsScalar, Index, internal::Lhs,
                                                   LeftEvaluator, left_nocontract_t,
                                                   contract_t, lhs_packet_size,
                                                   lhs_inner_dim_contiguous,
                                                   false, Unaligned> LhsMapper;

    typedef internal::TensorContractionInputMapper<RhsScalar, Index, internal::Rhs,
                                                   RightEvaluator, right_nocontract_t,
                                                   contract_t, rhs_packet_size,
                                                   rhs_inner_dim_contiguous,
                                                   rhs_inner_dim_reordered, Unaligned> RhsMapper;

    typedef internal::blas_data_mapper<Scalar, Index, ColMajor> OutputMapper;

    // Declare the GEBP packing and kernel structs.
    internal::gemm_pack_lhs<LhsScalar, Index, typename LhsMapper::SubMapper, mr, Traits::LhsProgress, ColMajor> pack_lhs;
    internal::gemm_pack_rhs<RhsScalar, Index, typename RhsMapper::SubMapper, nr, ColMajor> pack_rhs;
    internal::gebp_kernel<LhsScalar, RhsScalar, Index, OutputMapper, mr, nr, false, false> gebp;

    // Initialize the data mappers.
    LhsMapper lhs(this->m_leftImpl, this->m_left_nocontract_strides, this->m_i_strides,
                  this->m_left_contracting_strides, this->m_k_strides);
    RhsMapper rhs(this->m_rightImpl, this->m_right_nocontract_strides, this->m_j_strides,
                  this->m_right_contracting_strides, this->m_k_strides);

    OutputMapper output(buffer, m);

    // Compute the block sizes for the three loop dimensions.
    internal::TensorContractionBlocking<LhsMapper, RhsMapper, Index, internal::ShardByCol> blocking(k, m, n, 1);
    const Index kc = blocking.kc();
    const Index mc = numext::mini(m, blocking.mc());
    const Index nc = numext::mini(n, blocking.nc());
    const Index sizeA = mc * kc;
    const Index sizeB = kc * nc;

    LhsScalar* blockA = static_cast<LhsScalar *>(this->m_device.allocate(sizeA * sizeof(LhsScalar)));
    RhsScalar* blockB = static_cast<RhsScalar *>(this->m_device.allocate(sizeB * sizeof(RhsScalar)));

    for(Index i2=0; i2<m; i2+=mc)
    {
      const Index actual_mc = numext::mini(i2+mc,m)-i2;
      for (Index k2 = 0; k2 < k; k2 += kc) {
        // Pack the current lhs panel into blockA.
        const Index actual_kc = numext::mini(k2 + kc, k) - k2;
        pack_lhs(blockA, lhs.getSubMapper(i2, k2), actual_kc, actual_mc, 0, 0);

        for (Index j2 = 0; j2 < n; j2 += nc) {
          // Pack the current rhs panel into blockB.
          const Index actual_nc = numext::mini(j2 + nc, n) - j2;
          pack_rhs(blockB, rhs.getSubMapper(k2, j2), actual_kc, actual_nc, 0, 0);

          // Invoke the GEBP micro-kernel; results accumulate into the buffer.
          gebp(output.getSubMapper(i2, j2), blockA, blockB, actual_mc, actual_kc, actual_nc, Scalar(1), -1, -1, 0, 0);
        }
      }
    }

    this->m_device.deallocate(blockA);
    this->m_device.deallocate(blockB);
  }
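
  // Added commentary on the blocking above: the three nested loops tile the
  // m x n output into mc x nc panels and the shared depth into kc slices.
  // Each lhs panel is packed once per (i2, k2) pair and reused across all j2
  // iterations. For m = n = k = 1024 with, say, mc = nc = 128 and kc = 256
  // (values chosen purely for illustration), blockA holds a 128 x 256 slice
  // and blockB a 256 x 128 slice, both sized to stay cache-resident.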

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {
    m_leftImpl.cleanup();
    m_rightImpl.cleanup();

    if (m_result != NULL) {
      m_device.deallocate(m_result);
      m_result = NULL;
    }
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const {
    return m_result[index];
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost costPerCoeff(bool) const {
    return TensorOpCost(sizeof(CoeffReturnType), 0, 0);
  }

  template<int LoadMode>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packet(Index index) const {
    return internal::ploadt<PacketReturnType, LoadMode>(m_result + index);
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar* data() const { return m_result; }

  protected:
  // Prevent assignment.
  TensorContractionEvaluatorBase& operator = (const TensorContractionEvaluatorBase&);

  Dimensions m_dimensions;

  contract_t m_k_strides;
  contract_t m_left_contracting_strides;
  contract_t m_right_contracting_strides;

  bool m_lhs_inner_dim_contiguous;
  bool m_rhs_inner_dim_contiguous;
  bool m_rhs_inner_dim_reordered;

  left_nocontract_t m_i_strides;
  right_nocontract_t m_j_strides;
  left_nocontract_t m_left_nocontract_strides;
  right_nocontract_t m_right_nocontract_strides;

  Index m_i_size;
  Index m_j_size;
  Index m_k_size;

  TensorEvaluator<EvalLeftArgType, Device> m_leftImpl;
  TensorEvaluator<EvalRightArgType, Device> m_rightImpl;
  const Device& m_device;
  Scalar* m_result;
};


// Evaluator for the default (single-threaded) device.
template<typename Indices, typename LeftArgType, typename RightArgType, typename Device>
struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, Device> :
    public TensorContractionEvaluatorBase<
      TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, Device> > {
  typedef TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, Device> Self;
  typedef TensorContractionEvaluatorBase<Self> Base;

  typedef TensorContractionOp<Indices, LeftArgType, RightArgType> XprType;
  typedef typename internal::remove_const<typename XprType::Scalar>::type Scalar;
  typedef typename XprType::Index Index;
  typedef typename XprType::CoeffReturnType CoeffReturnType;
  typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;

  enum {
    Layout = TensorEvaluator<LeftArgType, Device>::Layout
  };

  typedef typename internal::conditional<
    static_cast<int>(Layout) == static_cast<int>(ColMajor), LeftArgType, RightArgType>::type EvalLeftArgType;
  typedef typename internal::conditional<
    static_cast<int>(Layout) == static_cast<int>(ColMajor), RightArgType, LeftArgType>::type EvalRightArgType;

  static const int LDims =
      internal::array_size<typename TensorEvaluator<EvalLeftArgType, Device>::Dimensions>::value;
  static const int RDims =
      internal::array_size<typename TensorEvaluator<EvalRightArgType, Device>::Dimensions>::value;
  static const int ContractDims = internal::array_size<Indices>::value;

  typedef array<Index, ContractDims> contract_t;
  typedef array<Index, LDims - ContractDims> left_nocontract_t;
  typedef array<Index, RDims - ContractDims> right_nocontract_t;

  static const int NumDims = LDims + RDims - 2 * ContractDims;

  typedef DSizes<Index, NumDims> Dimensions;

  EIGEN_DEVICE_FUNC TensorEvaluator(const XprType& op, const Device& device) :
      Base(op, device) { }

  template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous,
            bool rhs_inner_dim_reordered, int Alignment>
  EIGEN_DEVICE_FUNC void evalProduct(Scalar* buffer) const {
    if (this->m_j_size == 1) {
      // Degenerate case: the output has a single column, so a matrix-vector
      // product suffices.
      this->template evalGemv<lhs_inner_dim_contiguous, rhs_inner_dim_contiguous, rhs_inner_dim_reordered, Alignment>(buffer);
      return;
    }
    this->template evalGemm<lhs_inner_dim_contiguous, rhs_inner_dim_contiguous, rhs_inner_dim_reordered, Alignment>(buffer);
  }
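
  // End-to-end sketch (added, illustrative): assigning a contraction
  // expression to a tensor drives this evaluator. Roughly, for
  //
  //   Eigen::Tensor<float, 2> c = a.contract(b, dims);
  //
  // the assignment constructs a TensorEvaluator for the TensorContractionOp,
  // calls evalSubExprsIfNeeded(), which runs evalTo()/evalProduct() into a
  // result buffer, after which coeff()/packet() simply read precomputed values.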
};

} // end namespace Eigen

#endif // EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_H