46#ifndef MUELU_LOWPRECISIONFACTORY_DEF_HPP
47#define MUELU_LOWPRECISIONFACTORY_DEF_HPP
49#include <Xpetra_Matrix.hpp>
50#include <Xpetra_Operator.hpp>
51#include <Xpetra_TpetraOperator.hpp>
52#include <Tpetra_CrsMatrixMultiplyOp.hpp>
62 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
64 RCP<ParameterList> validParamList = rcp(
new ParameterList());
66 validParamList->set<std::string>(
"matrix key",
"A",
"");
67 validParamList->set< RCP<const FactoryBase> >(
"R", Teuchos::null,
"Generating factory of the matrix A to be converted to lower precision");
68 validParamList->set< RCP<const FactoryBase> >(
"A", Teuchos::null,
"Generating factory of the matrix A to be converted to lower precision");
69 validParamList->set< RCP<const FactoryBase> >(
"P", Teuchos::null,
"Generating factory of the matrix A to be converted to lower precision");
71 return validParamList;
74 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
78 std::string matrixKey = pL.get<std::string>(
"matrix key");
79 Input(currentLevel, matrixKey);
82 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
84 using Teuchos::ParameterList;
87 std::string matrixKey = pL.get<std::string>(
"matrix key");
89 FactoryMonitor m(*
this,
"Converting " + matrixKey +
" to half precision", currentLevel);
93 GetOStream(
Warnings) <<
"Matrix not converted to half precision. This only works for Tpetra and when both Scalar and HalfScalar have been instantiated." << std::endl;
94 Set(currentLevel, matrixKey, A);
98#if defined(HAVE_TPETRA_INST_DOUBLE) && defined(HAVE_TPETRA_INST_FLOAT)
99 template <
class LocalOrdinal,
class GlobalOrdinal,
class Node>
101 RCP<ParameterList> validParamList = rcp(
new ParameterList());
103 validParamList->set<std::string>(
"matrix key",
"A",
"");
104 validParamList->set< RCP<const FactoryBase> >(
"R", Teuchos::null,
"Generating factory of the matrix A to be converted to lower precision");
105 validParamList->set< RCP<const FactoryBase> >(
"A", Teuchos::null,
"Generating factory of the matrix A to be converted to lower precision");
106 validParamList->set< RCP<const FactoryBase> >(
"P", Teuchos::null,
"Generating factory of the matrix A to be converted to lower precision");
108 return validParamList;
111 template <
class LocalOrdinal,
class GlobalOrdinal,
class Node>
114 const ParameterList& pL = GetParameterList();
115 std::string matrixKey = pL.get<std::string>(
"matrix key");
116 Input(currentLevel, matrixKey);
119 template <
class LocalOrdinal,
class GlobalOrdinal,
class Node>
121 using Teuchos::ParameterList;
122 using HalfScalar =
typename Teuchos::ScalarTraits<Scalar>::halfPrecision;
124 const ParameterList& pL = GetParameterList();
125 std::string matrixKey = pL.get<std::string>(
"matrix key");
127 FactoryMonitor m(*
this,
"Converting " + matrixKey +
" to half precision", currentLevel);
129 RCP<Matrix> A = Get< RCP<Matrix> >(currentLevel, matrixKey);
131 if ((A->getRowMap()->lib() == Xpetra::UseTpetra) && std::is_same<Scalar, double>::value) {
132 auto tpA = rcp_dynamic_cast<TpetraCrsMatrix>(rcp_dynamic_cast<CrsMatrixWrap>(A)->getCrsMatrix(),
true)->getTpetra_CrsMatrix();
133 auto tpLowA = tpA->template convert<HalfScalar>();
134 auto tpLowOpA = rcp(
new Tpetra::CrsMatrixMultiplyOp<Scalar,HalfScalar,LocalOrdinal,GlobalOrdinal,Node>(tpLowA));
136 auto xpLowOpA = rcp_dynamic_cast<Operator>(xpTpLowOpA);
137 Set(currentLevel, matrixKey, xpLowOpA);
141 GetOStream(
Warnings) <<
"Matrix not converted to half precision. This only works for Tpetra and when both Scalar and HalfScalar have been instantiated." << std::endl;
142 Set(currentLevel, matrixKey, A);
147#if defined(HAVE_TPETRA_INST_COMPLEX_DOUBLE) && defined(HAVE_TPETRA_INST_COMPLEX_FLOAT)
148 template <
class LocalOrdinal,
class GlobalOrdinal,
class Node>
150 RCP<ParameterList> validParamList = rcp(
new ParameterList());
152 validParamList->set<std::string>(
"matrix key",
"A",
"");
153 validParamList->set< RCP<const FactoryBase> >(
"R", Teuchos::null,
"Generating factory of the matrix A to be converted to lower precision");
154 validParamList->set< RCP<const FactoryBase> >(
"A", Teuchos::null,
"Generating factory of the matrix A to be converted to lower precision");
155 validParamList->set< RCP<const FactoryBase> >(
"P", Teuchos::null,
"Generating factory of the matrix A to be converted to lower precision");
157 return validParamList;
160 template <
class LocalOrdinal,
class GlobalOrdinal,
class Node>
163 const ParameterList& pL = GetParameterList();
164 std::string matrixKey = pL.get<std::string>(
"matrix key");
165 Input(currentLevel, matrixKey);
168 template <
class LocalOrdinal,
class GlobalOrdinal,
class Node>
170 using Teuchos::ParameterList;
171 using HalfScalar =
typename Teuchos::ScalarTraits<Scalar>::halfPrecision;
173 const ParameterList& pL = GetParameterList();
174 std::string matrixKey = pL.get<std::string>(
"matrix key");
176 FactoryMonitor m(*
this,
"Converting " + matrixKey +
" to half precision", currentLevel);
178 RCP<Matrix> A = Get< RCP<Matrix> >(currentLevel, matrixKey);
180 if ((A->getRowMap()->lib() == Xpetra::UseTpetra) && std::is_same<
Scalar, std::complex<double> >::value) {
181 auto tpA = rcp_dynamic_cast<TpetraCrsMatrix>(rcp_dynamic_cast<CrsMatrixWrap>(A)->getCrsMatrix(),
true)->getTpetra_CrsMatrix();
182 auto tpLowA = tpA->template convert<HalfScalar>();
183 auto tpLowOpA = rcp(
new Tpetra::CrsMatrixMultiplyOp<Scalar,HalfScalar,LocalOrdinal,GlobalOrdinal,Node>(tpLowA));
185 auto xpLowOpA = rcp_dynamic_cast<Operator>(xpTpLowOpA);
186 Set(currentLevel, matrixKey, xpLowOpA);
190 GetOStream(
Warnings) <<
"Matrix not converted to half precision. This only works for Tpetra and when both Scalar and HalfScalar have been instantiated." << std::endl;
191 Set(currentLevel, matrixKey, A);
MueLu::DefaultLocalOrdinal LocalOrdinal
MueLu::DefaultScalar Scalar
MueLu::DefaultGlobalOrdinal GlobalOrdinal
Timer to be used in factories. Similar to Monitor but with additional timers.
void Input(Level &level, const std::string &varName) const
T Get(Level &level, const std::string &varName) const
void Set(Level &level, const std::string &varName, const T &data) const
Class that holds all level-specific information.
Factory for converting matrices to half precision operators.
RCP< const ParameterList > GetValidParameterList() const
Return a const parameter list of valid parameters that setParameterList() will accept.
void DeclareInput(Level &currentLevel) const
Input.
void Build(Level &currentLevel) const
Build method.
virtual const Teuchos::ParameterList & GetParameterList() const
Wraps an existing MueLu::Hierarchy as a Tpetra::Operator.
Teuchos::FancyOStream & GetOStream(MsgType type, int thisProcRankOnly=0) const
Get an output stream for outputting the input message type.
Namespace for MueLu classes and methods.
@ Warnings
Print all warning messages.