Intrepid2
Intrepid2_DataCombiners.hpp
Go to the documentation of this file.
1//
2// Intrepid2_DataCombiners.hpp
3// Trilinos
4//
5// Created by Roberts, Nathan V on 5/31/23.
6//
7
8#ifndef Intrepid2_DataCombiners_hpp
9#define Intrepid2_DataCombiners_hpp
10
15
17#include "Intrepid2_Data.hpp"
20#include "Intrepid2_ScalarView.hpp"
21
22namespace Intrepid2 {
// Forward declaration of the Data container.  The full definition comes from Intrepid2_Data.hpp,
// which this header includes.
template<class DataScalar, typename DeviceType>
class Data;
25
26 template<class BinaryOperator, class ThisUnderlyingViewType, class AUnderlyingViewType, class BUnderlyingViewType,
27 class ArgExtractorThis, class ArgExtractorA, class ArgExtractorB, bool includeInnerLoop=false>
28 struct InPlaceCombinationFunctor
29 {
30 private:
31 ThisUnderlyingViewType this_underlying_;
32 AUnderlyingViewType A_underlying_;
33 BUnderlyingViewType B_underlying_;
34 BinaryOperator binaryOperator_;
35 int innerLoopSize_;
36 public:
37 InPlaceCombinationFunctor(ThisUnderlyingViewType this_underlying, AUnderlyingViewType A_underlying, BUnderlyingViewType B_underlying,
38 BinaryOperator binaryOperator)
39 :
40 this_underlying_(this_underlying),
41 A_underlying_(A_underlying),
42 B_underlying_(B_underlying),
43 binaryOperator_(binaryOperator)
44 {
45 INTREPID2_TEST_FOR_EXCEPTION(includeInnerLoop,std::invalid_argument,"If includeInnerLoop is true, must specify the size of the inner loop");
46 }
47
48 InPlaceCombinationFunctor(ThisUnderlyingViewType this_underlying, AUnderlyingViewType A_underlying, BUnderlyingViewType B_underlying,
49 BinaryOperator binaryOperator, int innerLoopSize)
50 :
51 this_underlying_(this_underlying),
52 A_underlying_(A_underlying),
53 B_underlying_(B_underlying),
54 binaryOperator_(binaryOperator),
55 innerLoopSize_(innerLoopSize)
56 {
57 INTREPID2_TEST_FOR_EXCEPTION(includeInnerLoop,std::invalid_argument,"If includeInnerLoop is true, must specify the size of the inner loop");
58 }
59
60 template<class ...IntArgs, bool M=includeInnerLoop>
61 KOKKOS_INLINE_FUNCTION
62 enable_if_t<!M, void>
63 operator()(const IntArgs&... args) const
64 {
65 auto & result = ArgExtractorThis::get( this_underlying_, args... );
66 const auto & A_val = ArgExtractorA::get( A_underlying_, args... );
67 const auto & B_val = ArgExtractorB::get( B_underlying_, args... );
68
69 result = binaryOperator_(A_val,B_val);
70 }
71
72 template<class ...IntArgs, bool M=includeInnerLoop>
73 KOKKOS_INLINE_FUNCTION
74 enable_if_t<M, void>
75 operator()(const IntArgs&... args) const
76 {
77 using int_type = std::tuple_element_t<0, std::tuple<IntArgs...>>;
78 for (int_type iFinal=0; iFinal<static_cast<int_type>(innerLoopSize_); iFinal++)
79 {
80 auto & result = ArgExtractorThis::get( this_underlying_, args..., iFinal );
81 const auto & A_val = ArgExtractorA::get( A_underlying_, args..., iFinal );
82 const auto & B_val = ArgExtractorB::get( B_underlying_, args..., iFinal );
83
84 result = binaryOperator_(A_val,B_val);
85 }
86 }
87 };
88
90 template<class BinaryOperator, class ThisUnderlyingViewType, class AUnderlyingViewType, class BUnderlyingViewType>
91 struct InPlaceCombinationFunctorConstantCase
92 {
93 private:
94 ThisUnderlyingViewType this_underlying_;
95 AUnderlyingViewType A_underlying_;
96 BUnderlyingViewType B_underlying_;
97 BinaryOperator binaryOperator_;
98 public:
99 InPlaceCombinationFunctorConstantCase(ThisUnderlyingViewType this_underlying,
100 AUnderlyingViewType A_underlying,
101 BUnderlyingViewType B_underlying,
102 BinaryOperator binaryOperator)
103 :
104 this_underlying_(this_underlying),
105 A_underlying_(A_underlying),
106 B_underlying_(B_underlying),
107 binaryOperator_(binaryOperator)
108 {
109 INTREPID2_TEST_FOR_EXCEPTION(this_underlying.extent(0) != 1,std::invalid_argument,"all views for InPlaceCombinationFunctorConstantCase should have rank 1 and extent 1");
110 INTREPID2_TEST_FOR_EXCEPTION(A_underlying.extent(0) != 1,std::invalid_argument,"all views for InPlaceCombinationFunctorConstantCase should have rank 1 and extent 1");
111 INTREPID2_TEST_FOR_EXCEPTION(B_underlying.extent(0) != 1,std::invalid_argument,"all views for InPlaceCombinationFunctorConstantCase should have rank 1 and extent 1");
112 }
113
114 KOKKOS_INLINE_FUNCTION
115 void operator()(const int arg0) const
116 {
117 auto & result = this_underlying_(0);
118 const auto & A_val = A_underlying_(0);
119 const auto & B_val = B_underlying_(0);
120
121 result = binaryOperator_(A_val,B_val);
122 }
123 };
124
126 template<bool passThroughBlockDiagonalArgs>
128 {
129 template<class ViewType, class ...IntArgs>
130 static KOKKOS_INLINE_FUNCTION typename ViewType::reference_type get(const ViewType &view, const IntArgs&... intArgs)
131 {
132 return view.getWritableEntryWithPassThroughOption(passThroughBlockDiagonalArgs, intArgs...);
133 }
134 };
135
137 template<bool passThroughBlockDiagonalArgs>
139 {
140 template<class ViewType, class ...IntArgs>
141 static KOKKOS_INLINE_FUNCTION typename ViewType::const_reference_type get(const ViewType &view, const IntArgs&... intArgs)
142 {
143 return view.getEntryWithPassThroughOption(passThroughBlockDiagonalArgs, intArgs...);
144 }
145 };
146
147// static class for combining two Data objects using a specified binary operator
// Static class for combining two Data objects, A and B, into a third using a specified binary operator.
// Its visible members are static storeInPlaceCombination() overloads.
// NOTE(review): the class-head line (source line 149, with the class keyword and name) is absent from this
// listing; presumably `class DataCombiner`, per the file name -- confirm against the original header.
148 template <class DataScalar,typename DeviceType, class BinaryOperator>
150{
// Entry reference types of ScalarView<DataScalar,DeviceType> (writable) and
// ScalarView<const DataScalar,DeviceType> (read-only).  They precede the `public:` label, so they are
// private if the class-head uses `class`.
151 using reference_type = typename ScalarView<DataScalar,DeviceType>::reference_type;
152 using const_reference_type = typename ScalarView<const DataScalar,DeviceType>::reference_type;
153public:
// storeInPlaceCombination implementation used for rank < 7, with compile-time underlying containers and
// argument extractors: constructs a Functor from this_underlying, A_underlying, B_underlying, and
// binaryOperator, and launches it with Kokkos::parallel_for over policy.
// argThis/argA/argB are not read in the body; they are taken by value so that ArgExtractorThis,
// ArgExtractorA, and ArgExtractorB can be deduced at the call site.
// NOTE(review): the declaration of Functor (source line 161) is absent from this listing; presumably an
// InPlaceCombinationFunctor instantiated with these template arguments -- confirm.
155 template<class PolicyType, class ThisUnderlyingViewType, class AUnderlyingViewType, class BUnderlyingViewType,
156 class ArgExtractorThis, class ArgExtractorA, class ArgExtractorB>
157 static void storeInPlaceCombination(PolicyType &policy, ThisUnderlyingViewType &this_underlying,
158 AUnderlyingViewType &A_underlying, BUnderlyingViewType &B_underlying,
159 BinaryOperator &binaryOperator, ArgExtractorThis argThis, ArgExtractorA argA, ArgExtractorB argB)
160 {
162 Functor functor(this_underlying, A_underlying, B_underlying, binaryOperator);
163 Kokkos::parallel_for("compute in-place", policy, functor);
164 }
165
// storeInPlaceCombination with compile-time rank -- implementation for rank < 7.
// Classifies the storage of thisData, A, and B (constant / rank-1 underlying view with a GENERAL dimension /
// underlying view matching the logical shape / other).  It picks, for each, either its underlying View
// with a specialized ArgExtractor or the Data object itself.  It then forwards to the View-based
// storeInPlaceCombination() overload.
// NOTE(review): the line carrying the function name and parameter list (source line 170) is absent from
// this listing; the cross-reference entry for this overload gives
//   storeInPlaceCombination(Data<DataScalar,DeviceType> &thisData, const Data<DataScalar,DeviceType> &A,
//                           const Data<DataScalar,DeviceType> &B, BinaryOperator binaryOperator)
167 template<int rank>
168 static
169 enable_if_t<rank != 7, void>
171 {
// Iteration policy obtained from thisData; presumably spans its data (not logical) extents, per the name.
172 auto policy = thisData.template dataExtentRangePolicy<rank>();
173
// Storage classification:
//   *_1D       : underlying view has rank 1
//   *_constant : underlying view has rank 1 and a single entry
//   *_full     : underlying view has the same rank and extents as the logical container
174 const bool A_1D = A.getUnderlyingViewRank() == 1;
175 const bool B_1D = B.getUnderlyingViewRank() == 1;
176 const bool this_1D = thisData.getUnderlyingViewRank() == 1;
177 const bool A_constant = A_1D && (A.getUnderlyingViewSize() == 1);
178 const bool B_constant = B_1D && (B.getUnderlyingViewSize() == 1);
179 const bool this_constant = this_1D && (thisData.getUnderlyingViewSize() == 1);
180 const bool A_full = A.underlyingMatchesLogical();
181 const bool B_full = B.underlyingMatchesLogical();
182 const bool this_full = thisData.underlyingMatchesLogical();
183
185
187 const FullArgExtractorData<true> fullArgsData; // true: pass through block diagonal args. This is due to the behavior of dataExtentRangePolicy() for block diagonal args.
188 const FullArgExtractorWritableData<true> fullArgsWritable; // true: pass through block diagonal args. This is due to the behavior of dataExtentRangePolicy() for block diagonal args.
// NOTE(review): constArg, fullArgs, and arg0 ... arg5, used below, are declared on source lines 184-195,
// which are absent from this listing; their exact types are not visible here.
189
196
197 // this lambda returns -1 if there is not a rank-1 underlying view whose data extent matches the logical extent in the corresponding dimension;
198 // otherwise, it returns the logical index of the corresponding dimension.
// NOTE(review): as written, the lambda returns the first logical dimension whose variation type is GENERAL
// (or -1); it does not itself examine the underlying view's rank.  Callers pair it with the *_1D flags.
199 auto get1DArgIndex = [](const Data<DataScalar,DeviceType> &data) -> int
200 {
201 const auto & variationTypes = data.getVariationTypes();
202 for (int d=0; d<rank; d++)
203 {
204 if (variationTypes[d] == GENERAL)
205 {
206 return d;
207 }
208 }
209 return -1;
210 };
// Case 1: "this" is constant -- all three containers are accessed through their single entry.
211 if (this_constant)
212 {
213 // then A, B are constant, too
214 auto thisAE = constArg;
215 auto AAE = constArg;
216 auto BAE = constArg;
217 auto & this_underlying = thisData.template getUnderlyingView<1>();
218 auto & A_underlying = A.template getUnderlyingView<1>();
219 auto & B_underlying = B.template getUnderlyingView<1>();
220 storeInPlaceCombination(policy, this_underlying, A_underlying, B_underlying, binaryOperator, thisAE, AAE, BAE);
221 }
// Case 2: all three underlying views match the logical shape -- pass every argument straight through.
222 else if (this_full && A_full && B_full)
223 {
224 auto thisAE = fullArgs;
225 auto AAE = fullArgs;
226 auto BAE = fullArgs;
227
228 auto & this_underlying = thisData.template getUnderlyingView<rank>();
229 auto & A_underlying = A.template getUnderlyingView<rank>();
230 auto & B_underlying = B.template getUnderlyingView<rank>();
231
232 storeInPlaceCombination(policy, this_underlying, A_underlying, B_underlying, binaryOperator, thisAE, AAE, BAE);
233 }
// Case 3: A is constant.
234 else if (A_constant)
235 {
236 auto AAE = constArg;
237 auto & A_underlying = A.template getUnderlyingView<1>();
238 if (this_full)
239 {
240 auto thisAE = fullArgs;
241 auto & this_underlying = thisData.template getUnderlyingView<rank>();
242
243 if (B_full)
244 {
245 auto BAE = fullArgs;
246 auto & B_underlying = B.template getUnderlyingView<rank>();
247 storeInPlaceCombination(policy, this_underlying, A_underlying, B_underlying, binaryOperator, thisAE, AAE, BAE);
248 }
249 else // this_full, not B_full: B may have modular data, etc.
250 {
251 auto BAE = fullArgsData;
252 storeInPlaceCombination(policy, this_underlying, A_underlying, B, binaryOperator, thisAE, AAE, BAE);
253 }
254 }
255 else // this is not full
256 {
257 // below, we optimize for the case of 1D data in B, when A is constant. Still need to handle other cases…
258 if (B_1D && (get1DArgIndex(B) != -1) )
259 {
260 // since A is constant, that implies that this_1D is true, and has the same 1DArgIndex
// NOTE(review): this_1D is not checked here; getUnderlyingView<1>() on thisData relies on the reasoning above.
261 const int argIndex = get1DArgIndex(B);
262 auto & B_underlying = B.template getUnderlyingView<1>();
263 auto & this_underlying = thisData.template getUnderlyingView<1>();
// Map the runtime dimension index onto the compile-time single-argument extractor for that dimension.
264 switch (argIndex)
265 {
266 case 0: storeInPlaceCombination(policy, this_underlying, A_underlying, B_underlying, binaryOperator, arg0, AAE, arg0); break;
267 case 1: storeInPlaceCombination(policy, this_underlying, A_underlying, B_underlying, binaryOperator, arg1, AAE, arg1); break;
268 case 2: storeInPlaceCombination(policy, this_underlying, A_underlying, B_underlying, binaryOperator, arg2, AAE, arg2); break;
269 case 3: storeInPlaceCombination(policy, this_underlying, A_underlying, B_underlying, binaryOperator, arg3, AAE, arg3); break;
270 case 4: storeInPlaceCombination(policy, this_underlying, A_underlying, B_underlying, binaryOperator, arg4, AAE, arg4); break;
271 case 5: storeInPlaceCombination(policy, this_underlying, A_underlying, B_underlying, binaryOperator, arg5, AAE, arg5); break;
272 default: INTREPID2_TEST_FOR_EXCEPTION(true, std::invalid_argument, "Invalid/unexpected arg index");
273 }
274 }
275 else
276 {
277 // since storing to Data object requires a call to getWritableEntry(), we use FullArgExtractorWritableData
278 auto thisAE = fullArgsWritable;
279 auto BAE = fullArgsData;
280 storeInPlaceCombination(policy, thisData, A_underlying, B, binaryOperator, thisAE, AAE, BAE);
281 }
282 }
283 }
// Case 4: B is constant (mirror image of case 3).
284 else if (B_constant)
285 {
286 auto BAE = constArg;
287 auto & B_underlying = B.template getUnderlyingView<1>();
288 if (this_full)
289 {
290 auto thisAE = fullArgs;
291 auto & this_underlying = thisData.template getUnderlyingView<rank>();
292 if (A_full)
293 {
294 auto AAE = fullArgs;
295 auto & A_underlying = A.template getUnderlyingView<rank>();
296
297 storeInPlaceCombination(policy, this_underlying, A_underlying, B_underlying, binaryOperator, thisAE, AAE, BAE);
298 }
299 else // this_full, not A_full: A may have modular data, etc.
300 {
301 // use A (the Data object). This could be further optimized by using A's underlying View and an appropriately-defined ArgExtractor.
302 auto AAE = fullArgsData;
303 storeInPlaceCombination(policy, this_underlying, A, B_underlying, binaryOperator, thisAE, AAE, BAE);
304 }
305 }
306 else // this is not full
307 {
308 // below, we optimize for the case of 1D data in A, when B is constant. Still need to handle other cases…
309 if (A_1D && (get1DArgIndex(A) != -1) )
310 {
311 // since B is constant, that implies that this_1D is true, and has the same 1DArgIndex as A
// NOTE(review): this_1D is not checked here; getUnderlyingView<1>() on thisData relies on the reasoning above.
312 const int argIndex = get1DArgIndex(A);
313 auto & A_underlying = A.template getUnderlyingView<1>();
314 auto & this_underlying = thisData.template getUnderlyingView<1>();
315 switch (argIndex)
316 {
317 case 0: storeInPlaceCombination(policy, this_underlying, A_underlying, B_underlying, binaryOperator, arg0, arg0, BAE); break;
318 case 1: storeInPlaceCombination(policy, this_underlying, A_underlying, B_underlying, binaryOperator, arg1, arg1, BAE); break;
319 case 2: storeInPlaceCombination(policy, this_underlying, A_underlying, B_underlying, binaryOperator, arg2, arg2, BAE); break;
320 case 3: storeInPlaceCombination(policy, this_underlying, A_underlying, B_underlying, binaryOperator, arg3, arg3, BAE); break;
321 case 4: storeInPlaceCombination(policy, this_underlying, A_underlying, B_underlying, binaryOperator, arg4, arg4, BAE); break;
322 case 5: storeInPlaceCombination(policy, this_underlying, A_underlying, B_underlying, binaryOperator, arg5, arg5, BAE); break;
323 default: INTREPID2_TEST_FOR_EXCEPTION(true, std::invalid_argument, "Invalid/unexpected arg index");
324 }
325 }
326 else
327 {
328 // since storing to Data object requires a call to getWritableEntry(), we use FullArgExtractorWritableData
329 auto thisAE = fullArgsWritable;
330 auto AAE = fullArgsData;
331 storeInPlaceCombination(policy, thisData, A, B_underlying, binaryOperator, thisAE, AAE, BAE);
332 }
333 }
334 }
// Case 5: neither A nor B is constant.
335 else // neither A nor B constant
336 {
337 if (this_1D && (get1DArgIndex(thisData) != -1))
338 {
339 // possible ways that "this" could have full-extent, 1D data
340 // 1. A constant, B 1D
341 // 2. A 1D, B constant
342 // 3. A 1D, B 1D
343 // The constant possibilities are already addressed above, leaving us with (3). Note that A and B don't have to be full-extent, however
344 const int argThis = get1DArgIndex(thisData);
345 const int argA = get1DArgIndex(A); // if not full-extent, will be -1
346 const int argB = get1DArgIndex(B); // ditto
347
// NOTE(review): rank-1 underlying views of A and B are fetched unconditionally here, relying on (3) above.
348 auto & A_underlying = A.template getUnderlyingView<1>();
349 auto & B_underlying = B.template getUnderlyingView<1>();
350 auto & this_underlying = thisData.template getUnderlyingView<1>();
351 if ((argA != -1) && (argB != -1))
352 {
353#ifdef INTREPID2_HAVE_DEBUG
354 INTREPID2_TEST_FOR_EXCEPTION(argA != argThis, std::logic_error, "Unexpected 1D arg combination.");
355 INTREPID2_TEST_FOR_EXCEPTION(argB != argThis, std::logic_error, "Unexpected 1D arg combination.");
356#endif
357 switch (argThis)
358 {
359 case 0: storeInPlaceCombination(policy, this_underlying, A_underlying, B_underlying, binaryOperator, arg0, arg0, arg0); break;
360 case 1: storeInPlaceCombination(policy, this_underlying, A_underlying, B_underlying, binaryOperator, arg1, arg1, arg1); break;
361 case 2: storeInPlaceCombination(policy, this_underlying, A_underlying, B_underlying, binaryOperator, arg2, arg2, arg2); break;
362 case 3: storeInPlaceCombination(policy, this_underlying, A_underlying, B_underlying, binaryOperator, arg3, arg3, arg3); break;
363 case 4: storeInPlaceCombination(policy, this_underlying, A_underlying, B_underlying, binaryOperator, arg4, arg4, arg4); break;
364 case 5: storeInPlaceCombination(policy, this_underlying, A_underlying, B_underlying, binaryOperator, arg5, arg5, arg5); break;
365 default: INTREPID2_TEST_FOR_EXCEPTION(true, std::invalid_argument, "Invalid/unexpected arg index");
366 }
367 }
368 else if (argA != -1)
369 {
370 // B is not full-extent in dimension argThis; use the Data object
371 switch (argThis)
372 {
373 case 0: storeInPlaceCombination(policy, this_underlying, A_underlying, B, binaryOperator, arg0, arg0, fullArgsData); break;
374 case 1: storeInPlaceCombination(policy, this_underlying, A_underlying, B, binaryOperator, arg1, arg1, fullArgsData); break;
375 case 2: storeInPlaceCombination(policy, this_underlying, A_underlying, B, binaryOperator, arg2, arg2, fullArgsData); break;
376 case 3: storeInPlaceCombination(policy, this_underlying, A_underlying, B, binaryOperator, arg3, arg3, fullArgsData); break;
377 case 4: storeInPlaceCombination(policy, this_underlying, A_underlying, B, binaryOperator, arg4, arg4, fullArgsData); break;
378 case 5: storeInPlaceCombination(policy, this_underlying, A_underlying, B, binaryOperator, arg5, arg5, fullArgsData); break;
379 default: INTREPID2_TEST_FOR_EXCEPTION(true, std::invalid_argument, "Invalid/unexpected arg index");
380 }
381 }
382 else
383 {
384 // A is not full-extent in dimension argThis; use the Data object
// NOTE(review): this branch also runs when argB == -1, yet B_underlying is still indexed through the
// single-argument extractor for argThis.  It appears to assume argB != -1 here -- confirm that both A and B
// cannot lack a GENERAL dimension when "this" has one.
385 switch (argThis)
386 {
387 case 0: storeInPlaceCombination(policy, this_underlying, A, B_underlying, binaryOperator, arg0, fullArgsData, arg0); break;
388 case 1: storeInPlaceCombination(policy, this_underlying, A, B_underlying, binaryOperator, arg1, fullArgsData, arg1); break;
389 case 2: storeInPlaceCombination(policy, this_underlying, A, B_underlying, binaryOperator, arg2, fullArgsData, arg2); break;
390 case 3: storeInPlaceCombination(policy, this_underlying, A, B_underlying, binaryOperator, arg3, fullArgsData, arg3); break;
391 case 4: storeInPlaceCombination(policy, this_underlying, A, B_underlying, binaryOperator, arg4, fullArgsData, arg4); break;
392 case 5: storeInPlaceCombination(policy, this_underlying, A, B_underlying, binaryOperator, arg5, fullArgsData, arg5); break;
393 default: INTREPID2_TEST_FOR_EXCEPTION(true, std::invalid_argument, "Invalid/unexpected arg index");
394 }
395 }
396 }
397 else if (this_full)
398 {
399 // This case uses A,B Data objects; could be optimized by dividing into subcases and using underlying Views with appropriate ArgExtractors.
400 auto & this_underlying = thisData.template getUnderlyingView<rank>();
401 auto thisAE = fullArgs;
402
403 if (A_full)
404 {
405 auto & A_underlying = A.template getUnderlyingView<rank>();
406 auto AAE = fullArgs;
407
408 if (B_1D && (get1DArgIndex(B) != -1))
409 {
410 const int argIndex = get1DArgIndex(B);
411 auto & B_underlying = B.template getUnderlyingView<1>();
412 switch (argIndex)
413 {
414 case 0: storeInPlaceCombination(policy, this_underlying, A_underlying, B_underlying, binaryOperator, thisAE, AAE, arg0); break;
415 case 1: storeInPlaceCombination(policy, this_underlying, A_underlying, B_underlying, binaryOperator, thisAE, AAE, arg1); break;
416 case 2: storeInPlaceCombination(policy, this_underlying, A_underlying, B_underlying, binaryOperator, thisAE, AAE, arg2); break;
417 case 3: storeInPlaceCombination(policy, this_underlying, A_underlying, B_underlying, binaryOperator, thisAE, AAE, arg3); break;
418 case 4: storeInPlaceCombination(policy, this_underlying, A_underlying, B_underlying, binaryOperator, thisAE, AAE, arg4); break;
419 case 5: storeInPlaceCombination(policy, this_underlying, A_underlying, B_underlying, binaryOperator, thisAE, AAE, arg5); break;
420 default: INTREPID2_TEST_FOR_EXCEPTION(true, std::invalid_argument, "Invalid/unexpected arg index");
421 }
422 }
423 else
424 {
425 // A is full; B is not full, but not constant or full-extent 1D
426 // unoptimized in B access:
// NOTE(review): BAE is not declared on any visible line here (source line 427 is absent from this listing);
// presumably `auto BAE = fullArgsData;` -- confirm.
428 storeInPlaceCombination(policy, this_underlying, A_underlying, B, binaryOperator, thisAE, AAE, BAE);
429 }
430 }
431 else // A is not full
432 {
433 if (A_1D && (get1DArgIndex(A) != -1))
434 {
435 const int argIndex = get1DArgIndex(A);
436 auto & A_underlying = A.template getUnderlyingView<1>();
437 if (B_full)
438 {
439 auto & B_underlying = B.template getUnderlyingView<rank>();
440 auto BAE = fullArgs;
441 switch (argIndex)
442 {
443 case 0: storeInPlaceCombination(policy, this_underlying, A_underlying, B_underlying, binaryOperator, thisAE, arg0, BAE); break;
444 case 1: storeInPlaceCombination(policy, this_underlying, A_underlying, B_underlying, binaryOperator, thisAE, arg1, BAE); break;
445 case 2: storeInPlaceCombination(policy, this_underlying, A_underlying, B_underlying, binaryOperator, thisAE, arg2, BAE); break;
446 case 3: storeInPlaceCombination(policy, this_underlying, A_underlying, B_underlying, binaryOperator, thisAE, arg3, BAE); break;
447 case 4: storeInPlaceCombination(policy, this_underlying, A_underlying, B_underlying, binaryOperator, thisAE, arg4, BAE); break;
448 case 5: storeInPlaceCombination(policy, this_underlying, A_underlying, B_underlying, binaryOperator, thisAE, arg5, BAE); break;
449 default: INTREPID2_TEST_FOR_EXCEPTION(true, std::invalid_argument, "Invalid/unexpected arg index");
450 }
451 }
452 else
453 {
454 auto BAE = fullArgsData;
455 switch (argIndex)
456 {
457 case 0: storeInPlaceCombination(policy, this_underlying, A_underlying, B, binaryOperator, thisAE, arg0, BAE); break;
458 case 1: storeInPlaceCombination(policy, this_underlying, A_underlying, B, binaryOperator, thisAE, arg1, BAE); break;
459 case 2: storeInPlaceCombination(policy, this_underlying, A_underlying, B, binaryOperator, thisAE, arg2, BAE); break;
460 case 3: storeInPlaceCombination(policy, this_underlying, A_underlying, B, binaryOperator, thisAE, arg3, BAE); break;
461 case 4: storeInPlaceCombination(policy, this_underlying, A_underlying, B, binaryOperator, thisAE, arg4, BAE); break;
462 case 5: storeInPlaceCombination(policy, this_underlying, A_underlying, B, binaryOperator, thisAE, arg5, BAE); break;
463 default: INTREPID2_TEST_FOR_EXCEPTION(true, std::invalid_argument, "Invalid/unexpected arg index");
464 }
465 }
466 }
467 else // A not full, and not full-extent 1D
468 {
469 // unoptimized in A, B accesses.
470 auto AAE = fullArgsData;
471 auto BAE = fullArgsData;
472 storeInPlaceCombination(policy, this_underlying, A, B, binaryOperator, thisAE, AAE, BAE);
473 }
474 }
475 }
476 else
477 {
478 // completely un-optimized case: we use Data objects for this, A, B.
479 auto thisAE = fullArgsWritable;
480 auto AAE = fullArgsData;
481 auto BAE = fullArgsData;
482 storeInPlaceCombination(policy, thisData, A, B, binaryOperator, thisAE, AAE, BAE);
483 }
484 }
485 }
486
// storeInPlaceCombination with compile-time rank -- implementation for rank 7.  Not optimized: thisData,
// A, and B are all accessed through the Data objects themselves.
// NOTE(review): the line carrying the function name and parameter list (source line 491) is absent from
// this listing; its cross-reference entry gives the same parameter list as the rank < 7 overload:
//   storeInPlaceCombination(Data<DataScalar,DeviceType> &thisData, const Data<DataScalar,DeviceType> &A,
//                           const Data<DataScalar,DeviceType> &B, BinaryOperator binaryOperator)
488 template<int rank>
489 static
490 enable_if_t<rank == 7, void>
492 {
493 auto policy = thisData.template dataExtentRangePolicy<rank>();
494
495 using DataType = Data<DataScalar,DeviceType>;
// NOTE(review): source lines 496-498 (presumably the ArgExtractor aliases for this, A, and B) are absent
// from this listing -- confirm.
499
// The data extent in logical dimension 6 is handed to the functor as its inner-loop size, so the functor
// iterates over the final argument itself.  Presumably this is because the range policy covers at most six
// dimensions -- confirm.
500 const ordinal_type dim6 = thisData.getDataExtent(6);
501 const bool includeInnerLoop = true;
// NOTE(review): the Functor alias (source line 502) is absent from this listing; presumably an
// InPlaceCombinationFunctor over DataType with includeInnerLoop as its last template argument.  Confirm
// that its innerLoopSize constructor accepts includeInnerLoop == true; a constructor check that throws
// whenever includeInnerLoop is true would make this path always throw.
503 Functor functor(thisData, A, B, binaryOperator, dim6);
504 Kokkos::parallel_for("compute in-place", policy, functor);
505 }
506
507 static void storeInPlaceCombination(Data<DataScalar,DeviceType> &thisData, const Data<DataScalar,DeviceType> &A, const Data<DataScalar,DeviceType> &B, BinaryOperator binaryOperator)
508 {
509 using ExecutionSpace = typename DeviceType::execution_space;
510
511#ifdef INTREPID2_HAVE_DEBUG
512 // check logical extents
513 for (int d=0; d<rank_; d++)
514 {
515 INTREPID2_TEST_FOR_EXCEPTION(A.extent_int(d) != thisData.extent_int(d), std::invalid_argument, "A, B, and this must agree on all logical extents");
516 INTREPID2_TEST_FOR_EXCEPTION(B.extent_int(d) != thisData.extent_int(d), std::invalid_argument, "A, B, and this must agree on all logical extents");
517 }
518 // TODO: add some checks that data extent of this suffices to accept combined A + B data.
519#endif
520
521 const bool this_constant = (thisData.getUnderlyingViewRank() == 1) && (thisData.getUnderlyingViewSize() == 1);
522
523 // we special-case for constant output here; since the constant case is essentially all overhead, we want to avoid as much of the overhead of storeInPlaceCombination() as possible…
524 if (this_constant)
525 {
526 // constant data
527 Kokkos::RangePolicy<ExecutionSpace> policy(ExecutionSpace(),0,1); // just 1 entry
528
529 auto this_underlying = thisData.template getUnderlyingView<1>();
530 auto A_underlying = A.template getUnderlyingView<1>();
531 auto B_underlying = B.template getUnderlyingView<1>();
532
533 using ConstantCaseFunctor = InPlaceCombinationFunctorConstantCase<decltype(binaryOperator), decltype(this_underlying),
534 decltype(A_underlying), decltype(B_underlying)>;
535
536 ConstantCaseFunctor functor(this_underlying, A_underlying, B_underlying, binaryOperator);
537 Kokkos::parallel_for("compute in-place", policy,functor);
538 }
539 else
540 {
541 switch (thisData.rank())
542 {
543 case 1: storeInPlaceCombination<1>(thisData, A, B, binaryOperator); break;
544 case 2: storeInPlaceCombination<2>(thisData, A, B, binaryOperator); break;
545 case 3: storeInPlaceCombination<3>(thisData, A, B, binaryOperator); break;
546 case 4: storeInPlaceCombination<4>(thisData, A, B, binaryOperator); break;
547 case 5: storeInPlaceCombination<5>(thisData, A, B, binaryOperator); break;
548 case 6: storeInPlaceCombination<6>(thisData, A, B, binaryOperator); break;
549 case 7: storeInPlaceCombination<7>(thisData, A, B, binaryOperator); break;
550 default:
551 INTREPID2_TEST_FOR_EXCEPTION_DEVICE_SAFE(true, std::logic_error, "unhandled rank in switch");
552 }
553 }
554 }
555};
556
557} // end namespace Intrepid2
558
559// We do ETI for basic double arithmetic on default device.
560//template<class Scalar> struct ScalarSumFunctor;
561//template<class Scalar> struct ScalarDifferenceFunctor;
562//template<class Scalar> struct ScalarProductFunctor;
563//template<class Scalar> struct ScalarQuotientFunctor;
564
569
570#endif /* Intrepid2_DataCombiners_hpp */
Header file with various static argument-extractor classes. These are useful for writing efficient,...
Defines functors for use with Data objects: so far, we include simple arithmetical functors for sum,...
Defines DataVariationType enum that specifies the types of variation possible within a Data object.
@ GENERAL
arbitrary variation
Defines the Data class, a wrapper around a Kokkos::View that allows data that is constant or repeatin...
#define INTREPID2_TEST_FOR_EXCEPTION_DEVICE_SAFE(test, x, msg)
static enable_if_t< rank==7, void > storeInPlaceCombination(Data< DataScalar, DeviceType > &thisData, const Data< DataScalar, DeviceType > &A, const Data< DataScalar, DeviceType > &B, BinaryOperator binaryOperator)
storeInPlaceCombination with compile-time rank – implementation for rank of 7. (Not optimized; expect...
static void storeInPlaceCombination(PolicyType &policy, ThisUnderlyingViewType &this_underlying, AUnderlyingViewType &A_underlying, BUnderlyingViewType &B_underlying, BinaryOperator &binaryOperator, ArgExtractorThis argThis, ArgExtractorA argA, ArgExtractorB argB)
storeInPlaceCombination implementation for rank < 7, with compile-time underlying views and argument ...
static enable_if_t< rank !=7, void > storeInPlaceCombination(Data< DataScalar, DeviceType > &thisData, const Data< DataScalar, DeviceType > &A, const Data< DataScalar, DeviceType > &B, BinaryOperator binaryOperator)
storeInPlaceCombination with compile-time rank – implementation for rank < 7.
Wrapper around a Kokkos::View that allows data that is constant or repeating in various logical dimen...
KOKKOS_INLINE_FUNCTION int extent_int(const int &r) const
Returns the logical extent in the specified dimension.
KOKKOS_INLINE_FUNCTION ordinal_type getUnderlyingViewSize() const
returns the number of entries in the View that stores the unique data
KOKKOS_INLINE_FUNCTION int getDataExtent(const ordinal_type &d) const
returns the true extent of the data corresponding to the logical dimension provided; if the data does...
KOKKOS_INLINE_FUNCTION bool underlyingMatchesLogical() const
Returns true if the underlying container has exactly the same rank and extents as the logical contain...
KOKKOS_INLINE_FUNCTION ordinal_type getUnderlyingViewRank() const
returns the rank of the View that stores the unique data
KOKKOS_INLINE_FUNCTION unsigned rank() const
Returns the logical rank of the Data container.
Argument extractor class which ignores the input arguments in favor of passing a single 0 argument to...
For use with Data object into which a value will be stored. We use passThroughBlockDiagonalArgs = tru...
For use with Data object into which a value will be stored. We use passThroughBlockDiagonalArgs = tru...
Argument extractor class which passes all arguments to the provided container.
Argument extractor class which passes a single argument, indicated by the template parameter whichArg...