// NOTE(review): stray template header (original listing line 10). The member
// definition it introduces is not present in this listing; the lines that
// followed it appear to have been dropped during extraction.
10 template <
typename FieldLHS,
typename FieldRHS>
// Setup fragment (original listing lines 16-49; the signature, braces and some
// declarations such as hComplex are missing from this listing).
// Builds the complex ("Fourier-space") side of the solver from the real RHS
// layout:
//  - the complex domain equals domain_m in every dimension except the one
//    selected by the "r2c_direction" parameter, which is shrunk to
//    length / 2 + 1 entries;
//  - a complex layout is created on the same communicator and with the same
//    per-dimension parallel flags as the real layout;
//  - fieldComplex_m is initialised on that layout, and tempFieldComplex_m as
//    well when "output_type" is Base::GRAD;
//  - the FFT object is constructed from both layouts and warmed up.
16 template <
typename FieldLHS,
typename FieldRHS>
18 const Layout_t& layout_r = this->rhs_mp->getLayout();
// Reuse the real layout's parallel-decomposition flags for the complex layout.
26 std::array<bool, Dim> isParallel = layout_r.
isParallel();
27 for (
unsigned d = 0; d <
Dim; ++d) {
29 originComplex[d] = 0.0;
// Along the real-to-complex transform direction only length/2 + 1 modes are
// kept (the non-redundant half of an r2c transform's output); every other
// dimension keeps its full length.
31 if (this->params_m.template get<int>(
"r2c_direction") == (
int)d) {
32 domainComplex[d] =
Index(domain_m[d].length() / 2 + 1);
34 domainComplex[d] =
Index(domain_m[d].length());
// Complex layout: same communicator and parallel flags as the real layout.
38 layoutComplex_mp = std::make_shared<Layout_t>(layout_r.
comm, domainComplex, isParallel);
40 mesh_type meshComplex(domainComplex, hComplex, originComplex);
42 fieldComplex_m.initialize(meshComplex, *layoutComplex_mp);
// A second complex field is only needed for gradient output, which applies a
// different spectral multiplier per component.
44 if (this->params_m.template get<int>(
"output_type") == Base::GRAD) {
45 tempFieldComplex_m.initialize(meshComplex, *layoutComplex_mp);
48 fft_mp = std::make_shared<FFT_t>(layout_r, *layoutComplex_mp, this->params_m);
// Presumably runs a throwaway transform so one-time FFT setup cost is not
// charged to the first solve; confirm against FFT_t::warmup.
49 fft_mp->warmup(*this->rhs_mp, fieldComplex_m);
// FFTPeriodicPoissonSolver::solve fragment (original listing lines 52-167;
// several source lines are missing, including the signature, case labels,
// parallel_for heads, closing braces and the declarations of iVec, kVec, Dr,
// Len, N, rmax, origin, hx, pi and imag).
// Forward-transforms *this->rhs_mp into fieldComplex_m, then branches on the
// "output_type" parameter. Every branch shown writes results back through a
// backward FFT into *this->rhs_mp, so the RHS field is overwritten (it is
// used as scratch). An unrecognized "output_type" throws IpplException.
52 template <
typename FieldLHS,
typename FieldRHS>
// Real -> complex forward transform of the right-hand side.
54 fft_mp->transform(
FORWARD, *this->rhs_mp, fieldComplex_m);
56 auto view = fieldComplex_m.getView();
57 const int nghost = fieldComplex_m.getNghost();
60 const mesh_type& mesh = this->rhs_mp->get_mesh();
// Local (per-rank) index range of the complex field; its first() values are
// used below to turn local loop indices into global mode indices.
61 const auto& lDomComplex = layoutComplex_mp->getLocalNDIndex();
62 using vector_type =
typename mesh_type::vector_type;
// N[d]: number of grid points along d; rmax[d] = origin[d] + N[d] * hx[d].
68 for (
size_t d = 0; d <
Dim; ++d) {
69 N[d] = domain_m[d].length();
70 rmax[d] = origin[d] + (N[d] * hx[d]);
77 switch (this->params_m.template get<int>(
"output_type")) {
// First branch shown (its case label is not visible in this listing): scale
// every Fourier mode by 1 / |k|^2, with the k = 0 (mean) mode set to zero.
// Assuming kVec uses the standard 2*pi/L scaling, this inverts -Laplacian for
// a periodic, zero-mean solution.
81 KOKKOS_LAMBDA(
const index_array_type& args) {
// Shift local indices by the local domain's first index to get global mode
// indices.
83 for (
unsigned d = 0; d <
Dim; ++d) {
84 iVec[d] += lDomComplex[d].first();
// Wave vector: indices above N/2 are mapped to negative frequencies
// (index - N). Len is presumably the domain length along d; its definition
// is not visible here.
89 for (
size_t d = 0; d <
Dim; ++d) {
91 bool shift = (iVec[d] > (N[d] / 2));
92 kVec[d] = 2 *
pi / Len * (iVec[d] - shift * N[d]);
// Dr = |k|^2
96 for (
unsigned d = 0; d <
Dim; ++d) {
97 Dr += kVec[d] * kVec[d];
// Branch-free 1 / |k|^2 that avoids dividing by zero: when Dr == 0 the
// denominator becomes 1 and the result is multiplied by isNotZero == false,
// giving 0 for the mean mode.
100 bool isNotZero = (Dr != 0.0);
101 scalar_type factor = isNotZero * (1.0 / (Dr + ((!isNotZero) * 1.0)));
103 apply(view, args) *= factor;
// Complex -> real backward transform; the result is written into
// *this->rhs_mp.
106 fft_mp->transform(
BACKWARD, *this->rhs_mp, fieldComplex_m);
// Gradient branch (case label not visible). Uses the scratch complex field
// tempFieldComplex_m. For each component gd, every mode is multiplied by
// -i * k_gd / |k|^2, backward-transformed into *this->rhs_mp, and copied into
// component gd of *this->lhs_mp.
// Assumptions, not visible here: tempview holds the forward-transformed RHS
// at that point, and the forward FFT uses the usual e^{-ikx} sign convention.
// Under them, the output is -grad(phi), where phi is the 1/|k|^2 solution.
115 auto tempview = tempFieldComplex_m.getView();
116 auto viewRhs = this->rhs_mp->getView();
117 auto viewLhs = this->lhs_mp->getView();
118 const int nghostL = this->lhs_mp->getNghost();
// One spectral multiply + backward FFT + copy per gradient component.
120 for (
size_t gd = 0; gd <
Dim; ++gd) {
122 "Gradient FFTPeriodicPoissonSolver",
getRangePolicy(view, nghost),
123 KOKKOS_LAMBDA(
const index_array_type& args) {
// Local -> global mode indices, as in the first branch.
125 for (
unsigned d = 0; d <
Dim; ++d) {
126 iVec[d] += lDomComplex[d].first();
131 for (
size_t d = 0; d <
Dim; ++d) {
133 bool shift = (iVec[d] > (N[d] / 2));
// The index N/2 (the Nyquist mode for even N) gets k_d = 0. Its +/- frequency
// is ambiguous, so a nonzero derivative would not be real-valued.
// NOTE(review): for odd N, N/2 rounds down and is not a Nyquist mode, yet it
// is still zeroed here; confirm this is intended.
134 bool notMid = (iVec[d] != (N[d] / 2));
137 kVec[d] = notMid * 2 *
pi / Len * (iVec[d] - shift * N[d]);
// NOTE(review): Dr is built from the Nyquist-zeroed kVec, so a mode on the
// Nyquist plane of some d != gd uses 1/(|k|^2 - k_d^2) instead of the
// 1/|k|^2 used by the first branch. Confirm whether only the derivative
// factor kVec[gd] should be zeroed.
141 for (
unsigned d = 0; d <
Dim; ++d) {
142 Dr += kVec[d] * kVec[d];
// Same branch-free, zero-mode-safe 1 / |k|^2 as in the first branch.
147 bool isNotZero = (Dr != 0.0);
148 scalar_type factor = isNotZero * (1.0 / (Dr + ((!isNotZero) * 1.0)));
150 apply(tempview, args) *= -(imag * kVec[gd] * factor);
// Backward transform of component gd into *this->rhs_mp (scratch).
153 fft_mp->transform(
BACKWARD, *this->rhs_mp, tempFieldComplex_m);
156 "Assign Gradient FFTPeriodicPoissonSolver",
158 KOKKOS_LAMBDA(
const index_array_type& args) {
// Copy the real-space result into component gd of the vector-valued LHS.
159 apply(viewLhs, args)[gd] =
apply(viewRhs, args);
// Reached for "output_type" values not handled by the branches above.
167 throw IpplException(
"FFTPeriodicPoissonSolver::solve",
"Unrecognized output_type");
Implementations for FFT constructor/destructor and transforms.
void initialize(int &argc, char *argv[], MPI_Comm comm)
KOKKOS_INLINE_FUNCTION constexpr decltype(auto) apply(const View &view, const Coords &coords)
RangePolicy< View::rank, typename View::execution_space, PolicyArgs... >::policy_type getRangePolicy(const View &view, int shift=0)
void parallel_for(const std::string &name, const ExecPolicy &policy, const FunctorType &functor)
const NDIndex< Dim > & getDomain() const
std::array< bool, Dim > isParallel() const
typename FieldLHS::Mesh_t::vector_type vector_type
typename FFT_t::Complex_t Complex_t
void setRhs(rhs_type &rhs) override
typename FieldRHS::Mesh_t mesh_type
typename FieldLHS::Mesh_t::value_type scalar_type