Hammer  1.0.0
Helicity Amplitude Module for Matrix Element Reweighting
Dot.cc
///
/// @file Dot.cc
/// @brief Tensor dot product algorithm
///

//**** This file is a part of the HAMMER library
//**** Copyright (C) 2016 - 2020 The HAMMER Collaboration
//**** HAMMER is licensed under version 3 of the GPL; see COPYING for details
//**** Please note the MCnet academic guidelines; see GUIDELINES for details

// -*- C++ -*-
#include <set>
#include <numeric>
#include <type_traits>
#include <tuple>

#include <boost/functional/hash.hpp>

// (further Hammer-internal includes were collapsed by the documentation generator)
#include "Hammer/Exceptions.hh"
#include "Hammer/Tools/Utils.hh"
#include "Hammer/Math/Utils.hh"

#include <iostream>
using namespace std;

namespace Hammer {

    namespace MultiDimensional {

        using VTensor = VectorContainer;
        using STensor = SparseContainer;
        using OTensor = OuterContainer;
        using Base = IContainer;

        namespace Ops {

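            /// Builds the contraction map: each IndexPair contracts index `first`
            /// of the left operand with index `second` of the right one, and the
            /// two flags in shouldHC request hermitian conjugation of the
            /// corresponding operand. A hypothetical sketch (tensor names are
            /// placeholders), contracting index 1 of `a` with index 0 of `b`:
            ///
            ///     // c_{ik} = sum_j a_{ij} * b_{jk}
            ///     auto c = Dot::calcSharedDot(aShared, bTensor, {{1, 0}}, {false, false});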
            Dot::Dot(const IndexPairList& indices, pair<bool, bool> shouldHC) : _indices{indices}, _hc{shouldHC} {
                _idxLeft.clear();
                _idxRight.clear();
                for (auto& elem : _indices) {
                    _idxLeft.insert(elem.first);
                    _idxRight.insert(elem.second);
                }
            }

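            /// Dense x dense contraction: walks the flat storage of `a` once,
            /// splits each position into (outer, inner) parts via the precomputed
            /// strides, and accumulates against the matching positions of `b`; a
            /// fully contracted result collapses to a scalar.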
            Base* Dot::operator()(VTensor& a, const VTensor& b) {
                auto newdimlabs = getNewIndexLabels(a, b);
                auto stridesA = a.getIndexing().getInnerOuterStrides(_indices, b.getIndexing().strides());
                if (newdimlabs.first.size() == 0) {
                    auto newscal = makeEmptyScalar();
                    for (size_t i = 0; i < a.numValues(); ++i) {
                        if (isZero(a[i])) {
                            continue;
                        }
                        Base::ElementType firstTerm = _hc.first ? conj(a[i]) : a[i];
                        auto pospairs = a.getIndexing().splitPosition(i, stridesA);
                        Base::ElementType secondTerm = _hc.second ? conj(b[pospairs.second]) : b[pospairs.second];
                        newscal->element({}) += firstTerm * secondTerm;
                    }
                    return newscal.release();
                } else {
                    PositionType reduced = b.getIndexing().reducedNumValues(_indices);
                    auto stridesB = b.getIndexing().getOuterStrides2nd(_indices);
                    auto newvect = makeEmptyVector(newdimlabs.first, newdimlabs.second);
                    VTensor* result = static_cast<VTensor*>(newvect.release());
                    for (size_t i = 0; i < a.numValues(); ++i) {
                        if (isZero(a[i])) {
                            continue;
                        }
                        Base::ElementType firstTerm = _hc.first ? conj(a[i]) : a[i];
                        auto pospairs = a.getIndexing().splitPosition(i, stridesA);
                        for (size_t j = 0; j < reduced; ++j) {
                            auto posB = b.getIndexing().build2ndPosition(j, pospairs.second, stridesB);
                            Base::ElementType secondTerm = _hc.second ? conj(b[posB]) : b[posB];
                            (*result)[pospairs.first * reduced + j] += firstTerm * secondTerm;
                        }
                    }
                    return static_cast<Base*>(result);
                }
            }

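            /// Sparse x sparse contraction: for every stored element of `a` the
            /// inner coordinates are recorded, and only the stored elements of `b`
            /// that agree on those coordinates contribute (splitPosition signals a
            /// mismatch by returning numeric_limits<size_t>::max()).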
            Base* Dot::operator()(STensor& a, const STensor& b) {
                auto newdimlabs = getNewIndexLabels(a, b);
                auto leftinfo = a.getIndexing().processShifts(_indices, IndexPairMember::Left);
                auto rightinfo = b.getIndexing().processShifts(_indices, IndexPairMember::Right);
                IndexList inners((a.dims().size() + b.dims().size() - newdimlabs.first.size()) / 2);
                vector<bool> innerAdds(inners.size(), false);
                if (newdimlabs.first.size() == 0) {
                    auto newscal = makeEmptyScalar();
                    for (auto& elemL : a) {
                        a.getIndexing().splitPosition(elemL.first, get<0>(leftinfo), get<1>(leftinfo),
                                                      inners, innerAdds);
                        Base::ElementType firstTerm = _hc.first ? conj(elemL.second) : elemL.second;
                        for (auto& elemR : b) {
                            auto tmpRight = b.getIndexing().splitPosition(elemR.first, get<0>(rightinfo),
                                                                          get<1>(rightinfo), inners, innerAdds, true);
                            if (tmpRight == numeric_limits<size_t>::max())
                                continue;
                            Base::ElementType secondTerm = _hc.second ? conj(elemR.second) : elemR.second;
                            newscal->element({}) += firstTerm * secondTerm;
                        }
                    }
                    return newscal.release();
                } else {
                    auto newsparse = makeEmptySparse(newdimlabs.first, newdimlabs.second);
                    STensor* result = static_cast<STensor*>(newsparse.release());
                    for (auto& elemL : a) {
                        auto tmpLeft = a.getIndexing().splitPosition(elemL.first, get<0>(leftinfo), get<1>(leftinfo),
                                                                     inners, innerAdds);
                        Base::ElementType firstTerm = _hc.first ? conj(elemL.second) : elemL.second;
                        for (auto& elemR : b) {
                            auto tmpRight = b.getIndexing().splitPosition(elemR.first, get<0>(rightinfo),
                                                                          get<1>(rightinfo), inners, innerAdds, true);
                            if (tmpRight == numeric_limits<size_t>::max())
                                continue;
                            Base::ElementType secondTerm = _hc.second ? conj(elemR.second) : elemR.second;
                            (*result)[tmpLeft * get<2>(rightinfo) + tmpRight] += firstTerm * secondTerm;
                        }
                    }
                    return static_cast<Base*>(result);
                }
            }

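            /// Sparse x dense contraction: when `b` is fully contracted its value
            /// is looked up directly at the recorded inner coordinates; otherwise
            /// its non-zero entries are scanned through an aligned-indexing view.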
            Base* Dot::operator()(STensor& a, const VTensor& b) {
                auto newdimlabs = getNewIndexLabels(a, b);
                IndexList inners((a.dims().size() + b.dims().size() - newdimlabs.first.size()) / 2);
                vector<bool> innerAdds(inners.size(), false);
                auto leftinfo = a.getIndexing().processShifts(_indices, IndexPairMember::Left);
                if (newdimlabs.first.size() == 0) {
                    auto newscal = makeEmptyScalar();
                    if (b.rank() > 1) {
                        for (auto& elemL : a) {
                            a.getIndexing().splitPosition(elemL.first, get<0>(leftinfo),
                                                          get<1>(leftinfo), inners, innerAdds);
                            Base::ElementType firstTerm = _hc.first ? conj(elemL.second) : elemL.second;
                            Base::ElementType secondTerm = _hc.second ? conj(b.element(inners)) : b.element(inners);
                            newscal->element({}) += firstTerm * secondTerm;
                        }
                    } else {
                        for (auto& elemL : a) {
                            a.getIndexing().splitPosition(elemL.first, get<0>(leftinfo),
                                                          get<1>(leftinfo), inners, innerAdds);
                            Base::ElementType firstTerm = _hc.first ? conj(elemL.second) : elemL.second;
                            Base::ElementType secondTerm = _hc.second ? conj(b[inners[0]]) : b[inners[0]];
                            newscal->element({}) += firstTerm * secondTerm;
                        }
                    }
                    return newscal.release();
                } else {
                    auto newsparse = makeEmptySparse(newdimlabs.first, newdimlabs.second);
                    STensor* result = static_cast<STensor*>(newsparse.release());
                    if (b.rank() == _indices.size()) {
                        if (b.rank() > 1) {
                            for (auto& elemL : a) {
                                auto tmpLeft = a.getIndexing().splitPosition(elemL.first, get<0>(leftinfo),
                                                                             get<1>(leftinfo), inners, innerAdds);
                                Base::ElementType firstTerm = _hc.first ? conj(elemL.second) : elemL.second;
                                Base::ElementType secondTerm =
                                    _hc.second ? conj(b.element(inners)) : b.element(inners);
                                (*result)[tmpLeft] += firstTerm * secondTerm;
                            }
                        } else {
                            for (auto& elemL : a) {
                                auto tmpLeft = a.getIndexing().splitPosition(elemL.first, get<0>(leftinfo),
                                                                             get<1>(leftinfo), inners, innerAdds);
                                Base::ElementType firstTerm = _hc.first ? conj(elemL.second) : elemL.second;
                                Base::ElementType secondTerm = _hc.second ? conj(b[inners[0]]) : b[inners[0]];
                                (*result)[tmpLeft] += firstTerm * secondTerm;
                            }
                        }
                    }
                    else {
                        auto itBe = b.endNonZero();
                        // the definition of fakeB was collapsed in the generated listing;
                        // reconstructed here (assumption) as an aligned-indexing view of b
                        LabeledIndexing<AlignedIndexing> fakeB{b.getIndexing().dims(), b.getIndexing().labels()};
                        auto rightinfo = fakeB.processShifts(_indices, IndexPairMember::Right);
                        for (auto& elemL : a) {
                            auto tmpLeft = a.getIndexing().splitPosition(elemL.first, get<0>(leftinfo),
                                                                         get<1>(leftinfo), inners, innerAdds);
                            Base::ElementType firstTerm = _hc.first ? conj(elemL.second) : elemL.second;
                            auto itB = b.firstNonZero();
                            for (; *itB != *itBe; itB->next()) {
                                PositionType alPos = fakeB.posToAlignedPos(itB->position());
                                auto tmpRight = fakeB.splitPosition(
                                    alPos, get<0>(rightinfo), get<1>(rightinfo), inners, innerAdds, true);
                                if (tmpRight == numeric_limits<size_t>::max())
                                    continue;
                                Base::ElementType secondTerm = _hc.second ? conj(itB->value()) : itB->value();
                                (*result)[tmpLeft * get<2>(rightinfo) + tmpRight] += firstTerm * secondTerm;
                            }
                        }
                    }
                    return static_cast<Base*>(result);
                }
            }

            static pair<bool, bool> isSameDot(const OuterElemIterator::EntryType& a, const OuterElemIterator::EntryType& b,
                                              const DotGroupType& info, const DotGroupType& infoOther) {
                UNUSED(a);
                UNUSED(b);
                UNUSED(info);
                UNUSED(infoOther);
                /// @todo IMPLEMENT
                return {false, false};
            }

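            /// Outer x outer contraction: the index pairs are first partitioned
            /// into connected chunks of sub-tensors; for each pair of addends the
            /// chunks are dotted independently, repeated identical chunks (once
            /// isSameDot is implemented) are folded in as powers of a scalar
            /// weight, and untouched sub-tensors pass through into the result.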
            Base* Dot::operator()(OTensor& a, const OTensor& b) {
                // get the chunks
                DotGroupList chunks = partitionContractions(a.getIndexing(), b.getIndexing());
                TensorData fullResult;
                OTensor* oResult = nullptr;
                auto leftinfo = a.getIndexing().processShifts(chunks, IndexPairMember::Left);
                auto rightinfo = b.getIndexing().processShifts(chunks, IndexPairMember::Right);
                for (auto& elemA : a) {
                    for (auto& elemB : b) {
                        // check for repetitions
                        vector<pair<size_t, size_t>> multiplicities(leftinfo.size());
                        vector<bool> used(leftinfo.size(), false);
                        for (size_t i = 0; i < leftinfo.size(); ++i) {
                            if (used[i]) continue;
                            used[i] = true;
                            size_t count = 1;
                            size_t countHc = 0;
                            for (size_t j = i + 1; j < leftinfo.size(); ++j) {
                                auto tmp = isSameDot(elemA, elemB, chunks[i + 1], chunks[j + 1]);
                                if (tmp.first) {
                                    if (tmp.second) {
                                        ++countHc;
                                    } else {
                                        ++count;
                                    }
                                    used[j] = true;
                                }
                            }
                            multiplicities[i] = {count, countHc};
                        }
                        Base::ElementType currentWeight = 1.;
                        OuterElemIterator::EntryType currentTerm;
                        for (size_t i = 0; i < leftinfo.size(); ++i) {
                            if (multiplicities[i].first + multiplicities[i].second == 0)
                                continue;
                            // do the dot
                            OuterElemIterator::EntryType leftTensors;
                            leftTensors.reserve(get<0>(chunks[i + 1]).size());
                            transform(get<0>(chunks[i + 1]).begin(), get<0>(chunks[i + 1]).end(), back_inserter(leftTensors),
                                      [&](IndexType idx) -> const pair<SharedTensorData, bool>& { return elemA[idx]; });
                            OuterElemIterator::EntryType rightTensors;
                            rightTensors.reserve(get<1>(chunks[i + 1]).size());
                            transform(get<1>(chunks[i + 1]).begin(), get<1>(chunks[i + 1]).end(), back_inserter(rightTensors),
                                      [&](IndexType idx) -> const pair<SharedTensorData, bool>& { return elemB[idx]; });
                            IndexList inners(get<2>(chunks[i + 1]).size());
                            vector<bool> innerAdds(inners.size(), false);
                            auto newdimlabs = getNewIndexLabels(a.getIndexing(), b.getIndexing(), chunks[i + 1]);
                            OuterElemIterator itA{leftTensors};
                            OuterElemIterator itAEnd = itA.end();
                            size_t totalRankB = accumulate(rightTensors.begin(), rightTensors.end(), 0ul,
                                                           [](PositionType tot, const pair<SharedTensorData, bool>& elem) -> PositionType {
                                                               return tot + elem.first->rank();
                                                           });
                            if (totalRankB == inners.size()) {
                                IndexList::iterator itP1, itP2;
                                if (newdimlabs.first.size() == 0) {
                                    Base::ElementType newscal;
                                    for (; itA != itAEnd; ++itA) {
                                        itP1 = inners.begin();
                                        a.getIndexing().splitPosition(itA, chunks[i + 1], get<0>(leftinfo[i]),
                                                                      get<1>(leftinfo[i]), inners, innerAdds);
                                        Base::ElementType firstTerm = _hc.first ? conj(*itA) : *itA;
                                        for (auto& entry : rightTensors) {
                                            itP2 = itP1 + static_cast<ptrdiff_t>(entry.first->rank());
                                            Base::ElementType secondTerm = (!_hc.second != !entry.second)
                                                                               ? conj(entry.first->element(itP1, itP2))
                                                                               : entry.first->element(itP1, itP2);
                                            firstTerm *= secondTerm;
                                            if (isZero(secondTerm)) {
                                                break;
                                            }
                                            itP1 = itP2;
                                        }
                                        newscal += firstTerm;
                                    }
                                    currentWeight *=
                                        pow(newscal, multiplicities[i].first) * pow(conj(newscal), multiplicities[i].second);
                                } else {
                                    auto newsparse = makeEmptySparse(newdimlabs.first, newdimlabs.second);
                                    STensor* result = static_cast<STensor*>(newsparse.get());
                                    for (; itA != itAEnd; ++itA) {
                                        itP1 = inners.begin();
                                        PositionType tmpLeft =
                                            a.getIndexing().splitPosition(itA, chunks[i + 1], get<0>(leftinfo[i]),
                                                                          get<1>(leftinfo[i]), inners, innerAdds);
                                        Base::ElementType firstTerm = _hc.first ? conj(*itA) : *itA;
                                        for (auto& entry : rightTensors) {
                                            itP2 = itP1 + static_cast<ptrdiff_t>(entry.first->rank());
                                            Base::ElementType secondTerm = (!_hc.second != !entry.second)
                                                                               ? conj(entry.first->element(itP1, itP2))
                                                                               : entry.first->element(itP1, itP2);
                                            firstTerm *= secondTerm;
                                            if (isZero(secondTerm)) {
                                                break;
                                            }
                                            itP1 = itP2;
                                        }
                                        (*result)[tmpLeft] += firstTerm;
                                    }
                                    SharedTensorData tmpShared{newsparse.release()};
                                    currentTerm.insert(currentTerm.end(), multiplicities[i].first, {tmpShared, false});
                                    currentTerm.insert(currentTerm.end(), multiplicities[i].second, {tmpShared, true});
                                }
                            } else {
                                OuterElemIterator itBEnd = OuterElemIterator{rightTensors}.end();
                                if (newdimlabs.first.size() == 0) {
                                    Base::ElementType newscal;
                                    for (; itA != itAEnd; ++itA) {
                                        a.getIndexing().splitPosition(itA, chunks[i + 1], get<0>(leftinfo[i]),
                                                                      get<1>(leftinfo[i]), inners, innerAdds);
                                        Base::ElementType firstTerm = _hc.first ? conj(*itA) : *itA;
                                        OuterElemIterator itB{rightTensors};
                                        for (; itB != itBEnd; ++itB) {
                                            PositionType tmpRight = b.getIndexing().splitPosition(
                                                itB, chunks[i + 1], get<0>(rightinfo[i]), get<1>(rightinfo[i]), inners,
                                                innerAdds, true);
                                            if (tmpRight == numeric_limits<size_t>::max())
                                                continue;
                                            Base::ElementType secondTerm = _hc.second ? conj(*itB) : *itB;
                                            newscal += firstTerm * secondTerm;
                                        }
                                    }
                                    currentWeight *= pow(newscal, multiplicities[i].first) *
                                                     pow(conj(newscal), multiplicities[i].second);
                                } else {
                                    auto newsparse = makeEmptySparse(newdimlabs.first, newdimlabs.second);
                                    STensor* result = static_cast<STensor*>(newsparse.get());
                                    for (; itA != itAEnd; ++itA) {
                                        PositionType tmpLeft =
                                            a.getIndexing().splitPosition(itA, chunks[i + 1], get<0>(leftinfo[i]),
                                                                          get<1>(leftinfo[i]), inners, innerAdds);
                                        Base::ElementType firstTerm = _hc.first ? conj(*itA) : *itA;
                                        OuterElemIterator itB{rightTensors};
                                        for (; itB != itBEnd; ++itB) {
                                            PositionType tmpRight = b.getIndexing().splitPosition(
                                                itB, chunks[i + 1], get<0>(rightinfo[i]), get<1>(rightinfo[i]), inners,
                                                innerAdds, true);
                                            if (tmpRight == numeric_limits<size_t>::max())
                                                continue;
                                            Base::ElementType secondTerm = _hc.second ? conj(*itB) : *itB;
                                            (*result)[tmpLeft * get<2>(rightinfo[i]) + tmpRight] +=
                                                firstTerm * secondTerm;
                                        }
                                    }
                                    SharedTensorData tmpShared{newsparse.release()};
                                    currentTerm.insert(currentTerm.end(), multiplicities[i].first, {tmpShared, false});
                                    currentTerm.insert(currentTerm.end(), multiplicities[i].second, {tmpShared, true});
                                }
                            }
                        }
                        // now add those untouched
                        for (auto elem : get<0>(chunks[0])) {
                            currentTerm.insert(currentTerm.end(), {elemA[elem].first, !elemA[elem].second != !_hc.first});
                        }
                        for (auto elem : get<1>(chunks[0])) {
                            currentTerm.insert(currentTerm.end(), {elemB[elem].first, !elemB[elem].second != !_hc.second});
                        }
                        if (currentTerm.size() == 0) {
                            if (fullResult.get() == nullptr) {
                                fullResult = makeScalar(currentWeight.real());
                            } else {
                                fullResult->element({}) += (currentWeight.real());
                            }
                        } else if (currentTerm.size() == 1) {
                            TensorData out = currentTerm[0].first->clone();
                            if (currentTerm[0].second) {
                                out->conjugate();
                            }
                            if (!isZero(currentWeight - 1.)) {
                                out->operator*=(currentWeight);
                            }
                            if (fullResult.get() == nullptr) {
                                fullResult = move(out);
                            } else {
                                Ops::Sum summer{};
                                fullResult = calc2(move(fullResult), *out, summer, "sum_outerdot");
                            }
                        } else {
                            if (!isZero(currentWeight - 1.)) {
                                auto temp = currentTerm.back();
                                currentTerm.pop_back();
                                auto candNew = temp.first->clone();
                                candNew->operator*=(temp.second ? conj(currentWeight) : currentWeight);
                                currentTerm.push_back({SharedTensorData{candNew.release()}, temp.second});
                            }
                            if (oResult != nullptr) {
                                oResult->addTerm(currentTerm);
                            } else {
                                fullResult = combineSharedTensors(move(currentTerm));
                                oResult = static_cast<OTensor*>(fullResult.get());
                            }
                        }
                    }
                }
                return static_cast<Base*>(fullResult.release());
            }

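            /// Outer x sparse contraction: the sparse operand forms a single
            /// chunk (the asserts check that nothing else is contracted on the
            /// right), so each addend of the outer product is dotted against it in
            /// one pass and the untouched left sub-tensors are carried over.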
            Base* Dot::operator()(OTensor& a, const STensor& b) {
                DotGroupList chunks = partitionContractions(a.getIndexing(), b.getIndexing());
                ASSERT(get<1>(chunks[0]).size() == 0);
                ASSERT(get<2>(chunks[0]).size() == 0);
                ASSERT(get<1>(chunks[1]).size() == 0);
                TensorData fullResult;
                OTensor* oResult = nullptr;
                auto leftinfo = a.getIndexing().processShifts(chunks, IndexPairMember::Left);
                auto rightinfo = b.getIndexing().processShifts(_indices, IndexPairMember::Right);
                ASSERT(leftinfo.size() == 1);
                for (auto& elemA : a) {
                    Base::ElementType currentWeight = 1.;
                    OuterElemIterator::EntryType currentTerm;
                    OuterElemIterator::EntryType leftTensors;
                    leftTensors.reserve(get<0>(chunks[1]).size());
                    transform(get<0>(chunks[1]).begin(), get<0>(chunks[1]).end(), back_inserter(leftTensors),
                              [&](IndexType idx) -> const pair<SharedTensorData, bool>& { return elemA[idx]; });
                    IndexList inners(get<2>(chunks[1]).size());
                    vector<bool> innerAdds(inners.size(), false);
                    auto newdimlabs = getNewIndexLabels(a.getIndexing(), b.getIndexing(), chunks[1]);
                    OuterElemIterator itA{leftTensors};
                    OuterElemIterator itAEnd = itA.end();
                    if (newdimlabs.first.size() == 0) {
                        Base::ElementType newscal;
                        for (; itA != itAEnd; ++itA) {
                            a.getIndexing().splitPosition(itA, chunks[1], get<0>(leftinfo[0]),
                                                          get<1>(leftinfo[0]), inners, innerAdds);
                            Base::ElementType firstTerm = _hc.first ? conj(*itA) : *itA;
                            for (auto& elemR : b) {
                                auto tmpRight = b.getIndexing().splitPosition(
                                    elemR.first, get<0>(rightinfo), get<1>(rightinfo), inners, innerAdds, true);
                                if (tmpRight == numeric_limits<size_t>::max())
                                    continue;
                                Base::ElementType secondTerm = _hc.second ? conj(elemR.second) : elemR.second;
                                newscal += firstTerm * secondTerm;
                            }
                        }
                        currentWeight *= newscal;
                    } else {
                        auto newsparse = makeEmptySparse(newdimlabs.first, newdimlabs.second);
                        STensor* result = static_cast<STensor*>(newsparse.get());
                        for (; itA != itAEnd; ++itA) {
                            PositionType tmpLeft = a.getIndexing().splitPosition(
                                itA, chunks[1], get<0>(leftinfo[0]), get<1>(leftinfo[0]), inners, innerAdds);
                            Base::ElementType firstTerm = _hc.first ? conj(*itA) : *itA;
                            for (auto& elemR : b) {
                                auto tmpRight = b.getIndexing().splitPosition(
                                    elemR.first, get<0>(rightinfo), get<1>(rightinfo), inners, innerAdds, true);
                                if (tmpRight == numeric_limits<size_t>::max())
                                    continue;
                                Base::ElementType secondTerm = _hc.second ? conj(elemR.second) : elemR.second;
                                (*result)[tmpLeft * get<2>(rightinfo) + tmpRight] += firstTerm * secondTerm;
                            }
                        }
                        SharedTensorData tmpShared{newsparse.release()};
                        currentTerm.push_back({tmpShared, false});
                    }
                    for (auto elem : get<0>(chunks[0])) {
                        currentTerm.insert(currentTerm.end(), {elemA[elem].first, !elemA[elem].second != !_hc.first});
                    }
                    if (currentTerm.size() == 0) {
                        if (fullResult.get() == nullptr) {
                            fullResult = makeScalar(currentWeight.real());
                        } else {
                            fullResult->element({}) += (currentWeight.real());
                        }
                    } else if (currentTerm.size() == 1) {
                        TensorData out = currentTerm[0].first->clone();
                        if (currentTerm[0].second) {
                            out->conjugate();
                        }
                        if (!isZero(currentWeight - 1.)) {
                            out->operator*=(currentWeight);
                        }
                        if (fullResult.get() == nullptr) {
                            fullResult = move(out);
                        } else {
                            Ops::Sum summer{};
                            fullResult = calc2(move(fullResult), *out, summer, "sum_outerdot");
                        }
                    } else {
                        if (!isZero(currentWeight - 1.)) {
                            auto temp = currentTerm.back();
                            currentTerm.pop_back();
                            auto candNew = temp.first->clone();
                            candNew->operator*=(temp.second ? conj(currentWeight) : currentWeight);
                            currentTerm.push_back({SharedTensorData{candNew.release()}, temp.second});
                        }
                        if (oResult != nullptr) {
                            oResult->addTerm(currentTerm);
                        } else {
                            fullResult = combineSharedTensors(move(currentTerm));
                            oResult = static_cast<OTensor*>(fullResult.get());
                        }
                    }
                }
                return static_cast<Base*>(fullResult.release());
            }

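            /// Outer x generic contraction: the contractions are grouped per left
            /// sub-tensor ("edges"); if every sub-tensor is touched, the outer
            /// product collapses through repeated pairwise dots, otherwise the
            /// dotted entries replace the touched sub-tensors and `a` itself is
            /// rewired with a fresh BlockIndexing.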
            Base* Dot::operator()(OTensor& a, const Base& b) {
                // inverse ordering: starting the dots from the rightmost sub-tensor
                // ensures compatibility with getNewIndexLabels() for fully collapsed outers
                map<IndexType, IndexPairList, greater<IndexType>> edges;
                for (auto& elem : _indices) {
                    auto keyvalA = a.getIndexing().getElementIndex(elem.first);
                    edges[keyvalA.first].push_back({keyvalA.second, elem.second});
                }
                IndexType tmpshift = 0ul;
                for (auto& elem : edges) {
                    for (auto& elem2 : elem.second) {
                        elem2.second = static_cast<IndexType>(elem2.second + tmpshift);
                    }
                    tmpshift = static_cast<IndexType>(tmpshift + a.getIndexing().getSubIndexing(elem.first).rank() - elem.second.size());
                }
                if (edges.size() == a.getIndexing().numSubIndexing()) {
                    auto newdimlabs = getNewIndexLabels(a, b);
                    TensorData result;
                    if (newdimlabs.first.size() == 0) {
                        result = makeEmptyScalar();
                    } else {
                        result = makeEmptySparse(newdimlabs.first, newdimlabs.second);
                    }
                    // Outer is going to be collapsed
                    SharedTensorData current_entry;
                    Ops::Sum summer{};
                    for (auto& elem : a) {
                        bool first = true;
                        for (auto& elem2 : edges) {
                            // dot the subtensor
                            Ops::Dot dotter{elem2.second, {elem[elem2.first].second, false}};
                            if (!first) {
                                current_entry = calc2(elem[elem2.first].first, *current_entry, dotter, "dot_outerdot");
                            } else {
                                first = false;
                                current_entry = calc2(elem[elem2.first].first, b, dotter, "dot_outerdot");
                            }
                        }
                        if (result->rank() == 0) {
                            result->element({}) += current_entry->element({});
                        } else {
                            result = calc2(move(result), *current_entry, summer, "sum_outerdot");
                        }
                    }
                    return result.release();
                } else {
                    OuterContainer::DataType newdata;
                    newdata.reserve(a.numAddends());
                    vector<pair<SharedTensorData, bool>> current_entries(a.getIndexing().numSubIndexing() - edges.size() + 1);
                    for (auto& elem : a) {
                        size_t dotIdx = current_entries.size();
                        size_t curInIdx = elem.size() - 1;
                        size_t curOutIdx = current_entries.size() - 1;
                        for (auto& elem2 : edges) {
                            if (elem2.first != curInIdx) {
                                // copy the pass throughs
                                for (; curInIdx != elem2.first; --curInIdx, --curOutIdx) {
                                    current_entries[curOutIdx] = elem[curInIdx];
                                }
                            }
                            // dot the subtensor
                            Ops::Dot dotter{elem2.second, {elem[curInIdx].second, false}};
                            if (dotIdx < current_entries.size()) {
                                current_entries[dotIdx].first =
                                    calc2(elem[curInIdx].first, *(current_entries[dotIdx].first), dotter, "dot_outerdot");
                            } else {
                                dotIdx = curOutIdx--;
                                current_entries[dotIdx].first = calc2(elem[curInIdx].first, b, dotter, "dot_outerdot");
                                current_entries[dotIdx].second = false;
                            }
                        }
                        newdata.push_back(current_entries);
                    }
                    vector<IndexList> dimlist;
                    vector<LabelsList> lablist;
                    for (auto elem : newdata[0]) {
                        dimlist.push_back(elem.first->dims());
                        lablist.push_back(elem.first->labels());
                    }
                    a.swap(newdata);
                    a.swapIndexing(BlockIndexing{dimlist, lablist});
                    return static_cast<Base*>(&a);
                }
            }

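            /// Sparse x outer contraction: mirror image of the outer x sparse
            /// case; `a` is dotted against the single chunk formed by the touched
            /// sub-tensors of each addend of `b`, and the untouched right
            /// sub-tensors are appended to the resulting terms.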
            Base* Dot::operator()(STensor& a, const OTensor& b) {
                DotGroupList chunks = partitionContractions(a.getIndexing(), b.getIndexing());
                TensorData fullResult;
                OTensor* oResult = nullptr;
                auto leftinfo = a.getIndexing().processShifts(_indices, IndexPairMember::Left);
                auto rightinfo = b.getIndexing().processShifts(chunks, IndexPairMember::Right);
                ASSERT(rightinfo.size() == 1);
                for (auto& elemB : b) {
                    Base::ElementType currentWeight = 1.;
                    OuterElemIterator::EntryType currentTerm;
                    OuterElemIterator::EntryType rightTensors;
                    rightTensors.reserve(get<1>(chunks[1]).size());
                    transform(get<1>(chunks[1]).begin(), get<1>(chunks[1]).end(), back_inserter(rightTensors),
                              [&](IndexType idx) -> const pair<SharedTensorData, bool>& { return elemB[idx]; });
                    IndexList inners(get<2>(chunks[1]).size());
                    vector<bool> innerAdds(inners.size(), false);
                    auto newdimlabs = getNewIndexLabels(a.getIndexing(), b.getIndexing(), chunks[1]);
                    size_t totalRankB = accumulate(rightTensors.begin(), rightTensors.end(), 0ul,
                                                   [](PositionType tot, const pair<SharedTensorData, bool>& elem) -> PositionType {
                                                       return tot + elem.first->rank();
                                                   });
                    if (totalRankB == inners.size()) {
                        IndexList::iterator itP1, itP2;
                        if (newdimlabs.first.size() == 0) {
                            Base::ElementType newscal;
                            for (auto& elemL : a) {
                                itP1 = inners.begin();
                                a.getIndexing().splitPosition(elemL.first, get<0>(leftinfo), get<1>(leftinfo),
                                                              inners, innerAdds);
                                Base::ElementType firstTerm = _hc.first ? conj(elemL.second) : elemL.second;
                                for (auto& entry : rightTensors) {
                                    itP2 = itP1 + static_cast<ptrdiff_t>(entry.first->rank());
                                    Base::ElementType secondTerm = (!_hc.second != !entry.second)
                                                                       ? conj(entry.first->element(itP1, itP2))
                                                                       : entry.first->element(itP1, itP2);
                                    firstTerm *= secondTerm;
                                    if (isZero(secondTerm)) {
                                        break;
                                    }
                                    itP1 = itP2;
                                }
                                newscal += firstTerm;
                            }
                            currentWeight *= newscal;
                        } else {
                            auto newsparse = makeEmptySparse(newdimlabs.first, newdimlabs.second);
                            STensor* result = static_cast<STensor*>(newsparse.get());
                            for (auto& elemL : a) {
                                itP1 = inners.begin();
                                auto tmpLeft = a.getIndexing().splitPosition(elemL.first, get<0>(leftinfo),
                                                                             get<1>(leftinfo), inners, innerAdds);
                                Base::ElementType firstTerm = _hc.first ? conj(elemL.second) : elemL.second;
                                for (auto& entry : rightTensors) {
                                    itP2 = itP1 + static_cast<ptrdiff_t>(entry.first->rank());
                                    Base::ElementType secondTerm = (!_hc.second != !entry.second)
                                                                       ? conj(entry.first->element(itP1, itP2))
                                                                       : entry.first->element(itP1, itP2);
                                    firstTerm *= secondTerm;
                                    if (isZero(secondTerm)) {
                                        break;
                                    }
                                    itP1 = itP2;
                                }
                                (*result)[tmpLeft] += firstTerm;
                            }
                            SharedTensorData tmpShared{newsparse.release()};
                            currentTerm.push_back({tmpShared, false});
                        }
                    } else {
                        OuterElemIterator itBEnd = OuterElemIterator{rightTensors}.end();
                        if (newdimlabs.first.size() == 0) {
                            Base::ElementType newscal;
                            for (auto& elemL : a) {
                                a.getIndexing().splitPosition(elemL.first, get<0>(leftinfo), get<1>(leftinfo),
                                                              inners, innerAdds);
                                Base::ElementType firstTerm = _hc.first ? conj(elemL.second) : elemL.second;
                                OuterElemIterator itB{rightTensors};
                                for (; itB != itBEnd; ++itB) {
                                    PositionType tmpRight = b.getIndexing().splitPosition(
                                        itB, chunks[1], get<0>(rightinfo[0]), get<1>(rightinfo[0]), inners,
                                        innerAdds, true);
                                    if (tmpRight == numeric_limits<size_t>::max())
                                        continue;
                                    Base::ElementType secondTerm = _hc.second ? conj(*itB) : *itB;
                                    newscal += firstTerm * secondTerm;
                                }
                            }
                            currentWeight *= newscal;
                        } else {
                            auto newsparse = makeEmptySparse(newdimlabs.first, newdimlabs.second);
                            STensor* result = static_cast<STensor*>(newsparse.get());
                            for (auto& elemL : a) {
                                auto tmpLeft = a.getIndexing().splitPosition(elemL.first, get<0>(leftinfo),
                                                                             get<1>(leftinfo), inners, innerAdds);
                                Base::ElementType firstTerm = _hc.first ? conj(elemL.second) : elemL.second;
                                OuterElemIterator itB{rightTensors};
                                for (; itB != itBEnd; ++itB) {
                                    PositionType tmpRight = b.getIndexing().splitPosition(
                                        itB, chunks[1], get<0>(rightinfo[0]), get<1>(rightinfo[0]), inners,
                                        innerAdds, true);
                                    if (tmpRight == numeric_limits<size_t>::max())
                                        continue;
                                    Base::ElementType secondTerm = _hc.second ? conj(*itB) : *itB;
                                    (*result)[tmpLeft * get<2>(rightinfo[0]) + tmpRight] += firstTerm * secondTerm;
                                }
                            }
                            SharedTensorData tmpShared{newsparse.release()};
                            currentTerm.push_back({tmpShared, false});
                        }
                    }
                    for (auto elem : get<1>(chunks[0])) {
                        currentTerm.insert(currentTerm.end(), {elemB[elem].first, !elemB[elem].second != !_hc.second});
                    }
                    if (currentTerm.size() == 0) {
                        if (fullResult.get() == nullptr) {
                            fullResult = makeScalar(currentWeight.real());
                        } else {
                            fullResult->element({}) += (currentWeight.real());
                        }
                    } else if (currentTerm.size() == 1) {
                        TensorData out = currentTerm[0].first->clone();
                        if (currentTerm[0].second) {
                            out->conjugate();
                        }
                        if (!isZero(currentWeight - 1.)) {
                            out->operator*=(currentWeight);
                        }
                        if (fullResult.get() == nullptr) {
                            fullResult = move(out);
                        } else {
                            Ops::Sum summer{};
                            fullResult = calc2(move(fullResult), *out, summer, "sum_outerdot");
                        }
                    } else {
                        if (!isZero(currentWeight - 1.)) {
                            auto temp = currentTerm.back();
                            currentTerm.pop_back();
                            auto candNew = temp.first->clone();
                            candNew->operator*=(temp.second ? conj(currentWeight) : currentWeight);
                            currentTerm.push_back({SharedTensorData{candNew.release()}, temp.second});
                        }
                        if (oResult != nullptr) {
                            oResult->addTerm(currentTerm);
                        } else {
                            fullResult = combineSharedTensors(move(currentTerm));
                            oResult = static_cast<OTensor*>(fullResult.get());
                        }
                    }
                }
                return static_cast<Base*>(fullResult.release());
            }

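            /// Generic fallback for any container combination: a brute-force loop
            /// over all coordinates of `a`; for each non-zero element the
            /// contracted coordinates of `b` are pinned to the matching values and
            /// the remaining free coordinates of `b` are scanned.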
            Base* Dot::operator()(Base& a, const Base& b) {
                auto newdimlabs = getNewIndexLabels(a, b);
                if (newdimlabs.first.size() == 0) {
                    auto newscalar = makeEmptyScalar();
                    IContainer::ElementType res = 0.;
                    BruteForceIterator bf{a.dims()};
                    for (auto elem : bf) {
                        auto aVal = a.element(elem);
                        if (isZero(aVal))
                            continue;
                        if (_hc.first) {
                            aVal = conj(aVal);
                        }
                        IndexList fixed = b.dims();
                        for (auto idx : _indices) {
                            fixed[idx.second] = elem[idx.first];
                        }
                        auto bVal = b.element(fixed);
                        if (isZero(bVal))
                            continue;
                        if (_hc.second) {
                            bVal = conj(bVal);
                        }
                        res += aVal * bVal;
                    }
                    newscalar->element({}) = res;
                    return newscalar.release();
                } else {
                    auto tmp = makeEmptySparse(newdimlabs.first, newdimlabs.second);
                    Base* results = tmp.release();
                    BruteForceIterator bf{a.dims()};
                    for (auto elem : bf) {
                        auto aVal = a.element(elem);
                        if (isZero(aVal))
                            continue;
                        if (_hc.first) {
                            aVal = conj(aVal);
                        }
                        IndexList fixed = b.dims();
                        for (auto idx : _indices) {
                            fixed[idx.second] = elem[idx.first];
                        }
                        BruteForceIterator bf2{b.dims(), fixed};
                        for (auto elem2 : bf2) {
                            auto bVal = b.element(elem2);
                            if (isZero(bVal))
                                continue;
                            if (_hc.second) {
                                bVal = conj(bVal);
                            }
                            auto idxRes = combineIndex(elem, elem2);
                            results->element(idxRes) += aVal * bVal;
                        }
                    }
                    return results;
                }
            }

            Base* Dot::error(Base&, const Base&) {
                throw Error("Invalid data types for tensor Dot");
            }

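            /// Merges the coordinates of the two operands into the coordinates of
            /// the result: contracted slots are removed from each side and the
            /// survivors concatenated left-then-right. E.g. contracting index 1 of
            /// `a` with index 0 of `b` maps coordinates (x0, x1) and (y0, y1) to (x0, y1).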
            IndexList Dot::combineIndex(const IndexList& a, const IndexList& b) const {
                IndexList result = a;
                for (auto elem : reverse_range(_idxLeft)) {
                    result.erase(result.begin() + elem);
                }
                size_t base = result.size();
                result.insert(result.end(), b.begin(), b.end());
                for (auto elem : reverse_range(_idxRight)) {
                    result.erase(result.begin() + static_cast<ptrdiff_t>(base + elem));
                }
                return result;
            }

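            /// Dimensions and labels of the dot result: contracted slots are
            /// erased from both operands' lists and, when hermitian conjugation is
            /// requested, the surviving labels are flipped to their conjugates.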
            pair<IndexList, LabelsList> Dot::getNewIndexLabels(const Base& first, const Base& second) const {
                IndexList resultD = first.dims();
                LabelsList resultL = _hc.first ? flipListOfLabels(first.labels()) : first.labels();
                for (auto elem : reverse_range(_idxLeft)) {
                    resultD.erase(resultD.begin() + elem);
                    resultL.erase(resultL.begin() + elem);
                }
                size_t base = resultD.size();
                const IndexList& dim2 = second.dims();
                const LabelsList& lab2 = _hc.second ? flipListOfLabels(second.labels()) : second.labels();
                resultD.insert(resultD.end(), dim2.begin(), dim2.end());
                resultL.insert(resultL.end(), lab2.begin(), lab2.end());
                for (auto elem : reverse_range(_idxRight)) {
                    resultD.erase(resultD.begin() + static_cast<ptrdiff_t>(base + elem));
                    resultL.erase(resultL.begin() + static_cast<ptrdiff_t>(base + elem));
                }
                return make_pair(resultD, resultL);
            }

            SharedTensorData Dot::calcSharedDot(SharedTensorData origin, const IContainer& other,
                                                const IndexPairList& indices, pair<bool, bool> shouldHC) {
                Ops::Dot dotter{indices, shouldHC};
                return calc2(move(origin), other, dotter, "dot");
            }

            unsigned long long Dot::dotSignature(const IContainer& a, const IContainer& b,
                                                 const std::string& type) const {
                UNUSED(a);
                UNUSED(b);
                UNUSED(type);
                /// @todo IMPLEMENT
                return 0;
            }

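            // Helpers over the DotGroupType tuples used below: test whether an
            // entry already appears in the N-th tuple field, add it only if
            // absent, and append one partition's field onto another when two
            // partitions are merged.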
            template <size_t N, typename U, typename... Types>
            typename enable_if<is_convertible<vector<U>, typename tuple_element<N, tuple<Types...>>::type>::value, bool>::type
            matchPartitions(const tuple<Types...>& data, U value) {
                return find(get<N>(data).begin(), get<N>(data).end(), value) != get<N>(data).end();
            }

            template <size_t N, typename U, typename... Types>
            typename enable_if<is_convertible<vector<U>, typename tuple_element<N, tuple<Types...>>::type>::value,
                               void>::type
            addPartitionEntry(tuple<Types...>& data, U value) {
                if (find(get<N>(data).begin(), get<N>(data).end(), value) == get<N>(data).end())
                    get<N>(data).push_back(value);
            }

            template <size_t N, typename... Types>
            typename enable_if<(N < sizeof...(Types)), void>::type appendPartitionEntries(const tuple<Types...>& from,
                                                                                          tuple<Types...>& to) {
                get<N>(to).insert(get<N>(to).end(), get<N>(from).begin(), get<N>(from).end());
            }

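            /// Partitions the requested contractions into connected groups: two
            /// pairs belong to the same group when they touch the same left or
            /// right sub-tensor, and overlapping groups are merged as the pairs
            /// are processed; partitions[0] collects the untouched sub-tensors of
            /// both operands.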
            DotGroupList Dot::partitionContractions(const BlockIndexing& lhs,
                                                    const BlockIndexing& rhs) const {
                DotGroupList partitions;
                vector<bool> validPartitions;
                IndexList lfree(lhs.numSubIndexing());
                IndexList rfree(rhs.numSubIndexing());
                iota(lfree.begin(), lfree.end(), 0);
                iota(rfree.begin(), rfree.end(), 0);
                // placeholder for the untouched
                auto frontelem = make_tuple<IndexList, IndexList, IndexPairList>({}, {}, {});
                partitions.push_back(frontelem);
                validPartitions.push_back(true);
                for (auto& elem : _indices) {
                    auto lloc = lhs.getElementIndex(elem.first).first;
                    auto rloc = rhs.getElementIndex(elem.second).first;
                    lfree.erase(remove(lfree.begin(), lfree.end(), lloc), lfree.end());
                    rfree.erase(remove(rfree.begin(), rfree.end(), rloc), rfree.end());
                    vector<size_t> finds{};
                    auto match = [&](size_t pos, IndexType valL, IndexType valR) -> bool {
                        const auto& data = partitions[pos];
                        return matchPartitions<0>(data, valL) || matchPartitions<1>(data, valR);
                    };
                    for (size_t i = 0; i < partitions.size(); ++i) { // find all matches on the left or the right
                        if (validPartitions[i] && match(i, lloc, rloc)) {
                            finds.push_back(i);
                        }
                    }
                    auto merge = [&](IndexType lidx, IndexType ridx, IndexPair contraction) -> void {
                        auto& data = partitions[finds[0]];
                        addPartitionEntry<0>(data, lidx);
                        addPartitionEntry<1>(data, ridx);
                        addPartitionEntry<2>(data, contraction);
                    };
                    auto merge_range = [&](size_t other) -> void {
                        auto& from = partitions[finds[other]];
                        auto& to = partitions[finds[0]];
                        appendPartitionEntries<0>(from, to);
                        appendPartitionEntries<1>(from, to);
                        appendPartitionEntries<2>(from, to);
                    };
                    if (finds.size() > 0) { // insert elem into the first found match
                        merge(lloc, rloc, elem);
                        for (size_t idx = 1; idx < finds.size(); ++idx) { // merge the other finds into the first match
                            if (validPartitions[finds[idx]]) {
                                merge_range(idx);
                                validPartitions[finds[idx]] = false;
                            }
                        }
                    } else {
                        auto newelem = make_tuple<IndexList, IndexList, IndexPairList>({lloc}, {rloc}, {elem});
                        partitions.push_back(newelem);
                        validPartitions.push_back(true);
                    }
                }
                // eliminate already merged partitions
                for (size_t i = partitions.size(); 0 < i--;) {
                    if (!validPartitions[i]) partitions.erase(partitions.begin() + static_cast<ptrdiff_t>(i));
                }
                // sort contractions
                for (auto& elem : partitions) {
                    std::sort(get<2>(elem).begin(), get<2>(elem).end(),
                              [](const IndexPair& a, const IndexPair& b) -> bool { return a.second < b.second; });
                }
                // add in the untouched tensors
                partitions[0] = DotGroupType{lfree, rfree, {}};
                return partitions;
            }

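            /// Result dims/labels for one chunk of a block x block dot:
            /// concatenates the index ranges of the participating sub-tensors and
            /// then deletes the contracted slots, located via per-sub-tensor
            /// offsets.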
            pair<IndexList, LabelsList> Dot::getNewIndexLabels(const BlockIndexing& lhs, const BlockIndexing& rhs,
                                                               const DotGroupType& chunk) const {
                IndexList dims;
                LabelsList labels;
                map<IndexType, IndexType> leftPosMaps;
                map<IndexType, IndexType> rightPosMaps;
                IndexType offset = 0;
                for (auto elem : get<0>(chunk)) {
                    leftPosMaps.insert({elem, offset});
                    dims.insert(dims.end(), lhs.getSubIndexing(elem).dims().begin(), lhs.getSubIndexing(elem).dims().end());
                    labels.insert(labels.end(), lhs.getSubIndexing(elem).labels().begin(), lhs.getSubIndexing(elem).labels().end());
                    offset = static_cast<IndexType>(offset + lhs.getSubIndexing(elem).rank());
                }
                for (auto elem : get<1>(chunk)) {
                    rightPosMaps.insert({elem, offset});
                    dims.insert(dims.end(), rhs.getSubIndexing(elem).dims().begin(), rhs.getSubIndexing(elem).dims().end());
                    labels.insert(labels.end(), rhs.getSubIndexing(elem).labels().begin(),
                                  rhs.getSubIndexing(elem).labels().end());
                    offset = static_cast<IndexType>(offset + rhs.getSubIndexing(elem).rank());
                }
                set<size_t> deletes;
                for (auto& elem : get<2>(chunk)) {
                    auto left = lhs.getElementIndex(elem.first);
                    auto right = rhs.getElementIndex(elem.second);
                    deletes.insert(leftPosMaps[left.first] + left.second);
                    deletes.insert(rightPosMaps[right.first] + right.second);
                }
                for (auto elem : reverse_range(deletes)) {
                    dims.erase(dims.begin() + static_cast<ptrdiff_t>(elem));
                    labels.erase(labels.begin() + static_cast<ptrdiff_t>(elem));
                }
                return {dims, labels};
            }

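            // Degenerate partitionings for the mixed flat/block cases: all
            // contractions form a single chunk, and partitions[0] keeps the list
            // of untouched sub-tensors on the block side.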
            DotGroupList Dot::partitionContractions(const LabeledIndexing<AlignedIndexing>&, const BlockIndexing& rhs) const {
                DotGroupList partitions(2);
                IndexList rfree(rhs.numSubIndexing());
                iota(rfree.begin(), rfree.end(), 0);
                get<2>(partitions[1]) = _indices;
                for (auto& elem : _indices) {
                    auto rloc = rhs.getElementIndex(elem.second).first;
                    rfree.erase(remove(rfree.begin(), rfree.end(), rloc), rfree.end());
                    get<1>(partitions[1]).push_back(rloc);
                }
                // sort contractions
                for (auto& elem : partitions) {
                    std::sort(get<2>(elem).begin(), get<2>(elem).end(),
                              [](const IndexPair& a, const IndexPair& b) -> bool { return a.second < b.second; });
                }
                // add in the untouched tensors
                partitions[0] = DotGroupType{IndexList{}, rfree, {}};
                return partitions;
            }

            DotGroupList Dot::partitionContractions(const BlockIndexing& lhs,
                                                    const LabeledIndexing<AlignedIndexing>&) const {
                DotGroupList partitions(2);
                IndexList lfree(lhs.numSubIndexing());
                iota(lfree.begin(), lfree.end(), 0);
                get<2>(partitions[1]) = _indices;
                for (auto& elem : _indices) {
                    auto lloc = lhs.getElementIndex(elem.first).first;
                    lfree.erase(remove(lfree.begin(), lfree.end(), lloc), lfree.end());
                    get<0>(partitions[1]).push_back(lloc);
                }
                // sort contractions
                for (auto& elem : partitions) {
                    std::sort(get<2>(elem).begin(), get<2>(elem).end(),
                              [](const IndexPair& a, const IndexPair& b) -> bool { return a.second < b.second; });
                }
                // add in the untouched tensors
                partitions[0] = DotGroupType{lfree, IndexList{}, {}};
                return partitions;
            }

            pair<IndexList, LabelsList> Dot::getNewIndexLabels(const LabeledIndexing<AlignedIndexing>& lhs,
                                                               const BlockIndexing& rhs,
                                                               const DotGroupType& chunk) const {
                IndexList dims;
                LabelsList labels;
                map<IndexType, IndexType> rightPosMaps;
                IndexType offset = 0;
                dims.insert(dims.end(), lhs.dims().begin(), lhs.dims().end());
                labels.insert(labels.end(), lhs.labels().begin(), lhs.labels().end());
                offset = static_cast<IndexType>(offset + lhs.rank());
                for (auto elem : get<1>(chunk)) {
                    rightPosMaps.insert({elem, offset});
                    dims.insert(dims.end(), rhs.getSubIndexing(elem).dims().begin(),
                                rhs.getSubIndexing(elem).dims().end());
                    labels.insert(labels.end(), rhs.getSubIndexing(elem).labels().begin(),
                                  rhs.getSubIndexing(elem).labels().end());
                    offset = static_cast<IndexType>(offset + rhs.getSubIndexing(elem).rank());
                }
                set<size_t> deletes;
                for (auto& elem : get<2>(chunk)) {
                    auto right = rhs.getElementIndex(elem.second);
                    deletes.insert(elem.first);
                    deletes.insert(rightPosMaps[right.first] + right.second);
                }
                for (auto elem : reverse_range(deletes)) {
                    dims.erase(dims.begin() + static_cast<ptrdiff_t>(elem));
                    labels.erase(labels.begin() + static_cast<ptrdiff_t>(elem));
                }
                return {dims, labels};
            }

            pair<IndexList, LabelsList> Dot::getNewIndexLabels(const BlockIndexing& lhs, const LabeledIndexing<AlignedIndexing>& rhs,
                                                               const DotGroupType& chunk) const {
                IndexList dims;
                LabelsList labels;
                map<IndexType, IndexType> leftPosMaps;
                IndexType offset = 0;
                for (auto elem : get<0>(chunk)) {
                    leftPosMaps.insert({elem, offset});
                    dims.insert(dims.end(), lhs.getSubIndexing(elem).dims().begin(),
                                lhs.getSubIndexing(elem).dims().end());
                    labels.insert(labels.end(), lhs.getSubIndexing(elem).labels().begin(),
                                  lhs.getSubIndexing(elem).labels().end());
                    offset = static_cast<IndexType>(offset + lhs.getSubIndexing(elem).rank());
                }
                dims.insert(dims.end(), rhs.dims().begin(), rhs.dims().end());
                labels.insert(labels.end(), rhs.labels().begin(), rhs.labels().end());
                set<size_t> deletes;
                for (auto& elem : get<2>(chunk)) {
                    auto left = lhs.getElementIndex(elem.first);
                    deletes.insert(leftPosMaps[left.first] + left.second);
                    deletes.insert(offset + elem.second);
                }
                for (auto elem : reverse_range(deletes)) {
                    dims.erase(dims.begin() + static_cast<ptrdiff_t>(elem));
                    labels.erase(labels.begin() + static_cast<ptrdiff_t>(elem));
                }
                return {dims, labels};
            }

        } // namespace Ops

    } // namespace MultiDimensional

} // namespace Hammer