Sacado Package Browser (Single Doxygen Collection)
Version of the Day
Toggle main menu visibility
Loading...
Searching...
No Matches
example
dfad_view_handle_example.cpp
Go to the documentation of this file.
1
// $Id$
2
// $Source$
3
// @HEADER
4
// ***********************************************************************
5
//
6
// Sacado Package
7
// Copyright (2006) Sandia Corporation
8
//
9
// Under the terms of Contract DE-AC04-94AL85000 with Sandia Corporation,
10
// the U.S. Government retains certain rights in this software.
11
//
12
// This library is free software; you can redistribute it and/or modify
13
// it under the terms of the GNU Lesser General Public License as
14
// published by the Free Software Foundation; either version 2.1 of the
15
// License, or (at your option) any later version.
16
//
17
// This library is distributed in the hope that it will be useful, but
18
// WITHOUT ANY WARRANTY; without even the implied warranty of
19
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
20
// Lesser General Public License for more details.
21
//
22
// You should have received a copy of the GNU Lesser General Public
23
// License along with this library; if not, write to the Free Software
24
// Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301
25
// USA
26
// Questions? Contact David M. Gay (dmgay@sandia.gov) or Eric T. Phipps
27
// (etphipp@sandia.gov).
28
//
29
// ***********************************************************************
30
// @HEADER
31
32
// dfad_view_handle_example
33
//
34
// usage:
35
// dfad_view_handle_example
36
//
37
// output:
38
// prints the results of differentiating a simple function with forward
// mode AD using the Sacado::Fad::DFad class (uses dynamic memory
// allocation for number of derivative components) and
// Sacado::Fad::ViewFad as a handle into derivative data stored
// externally in a Kokkos::View
42
43
#include <algorithm>
#include <cmath>
#include <iomanip>
#include <iostream>

#include "Sacado.hpp"
47
48
// The function to differentiate:
//
//   r(a, b, c) = c * log(b + 1) / sin(a)
//
// Written generically so the same code is evaluated with plain doubles
// (reference value) and with Sacado AD types (value + derivatives). The
// result of the arithmetic on the arguments is converted to ScalarRes on
// return.
//
// a, b    : arguments sharing the type Scalar1
// c       : argument of type Scalar2
// returns : c * log(b + 1) / sin(a) as a ScalarRes
template <typename ScalarRes, typename Scalar1, typename Scalar2>
ScalarRes func(const Scalar1& a, const Scalar1& b, const Scalar2& c)
{
  return c * std::log(b + 1.) / std::sin(a);
}
55
56
// The analytic derivative of func(a,b,c) with respect to a and b, used as
// the reference against which the AD result is checked:
//
//   dr/da = -c * log(b + 1) * cos(a) / sin(a)^2
//   dr/db =  c / ((b + 1) * sin(a))
//
// a, b, c : point at which the partial derivatives are evaluated
// drda    : receives dr/da
// drdb    : receives dr/db
void func_deriv(double a, double b, double c, double& drda, double& drdb)
{
  const double sin_a = std::sin(a);
  const double b_plus_1 = b + 1.;

  // Same operation order as the textbook formula: (c*log(b+1)) / sin(a)^2,
  // negated, then scaled by cos(a).
  const double quotient = c * std::log(b_plus_1) / std::pow(sin_a, 2.);
  drda = -quotient * std::cos(a);

  drdb = c / (b_plus_1 * sin_a);
}
62
63
int
main
(
int
argc,
char
**argv)
64
{
65
Kokkos::initialize();
66
int
ret = 0;
67
{
68
69
double
pi =
std::atan
(1.0)*4.0;
70
71
// Values of function arguments
72
double
a
= pi/4;
73
double
b = 2.0;
74
double
c
= 3.0;
75
76
// View to store derivative data
77
const
int
num_deriv = 2;
78
Kokkos::View<double**,Kokkos::LayoutLeft,Kokkos::HostSpace> v(
"v"
, 2, num_deriv );
79
80
// Initialize derivative data
81
Kokkos::deep_copy( v, 0.0 );
82
v(0,0) = 1.0;
// First (0) indep. var
83
v(1,1) = 1.0;
// Second (1) indep. var
84
85
// The Fad type
86
typedef
Sacado::Fad::DFad<double>
FadType
;
87
88
// View handle type -- first 0 is static length (e.g., SFad), second 0
89
// is static stride, which you can make 1 if you know the View will be
90
// LayoutRight (e.g., not GPU). When values are 0, they are treated
91
// dynamically
92
typedef
Sacado::Fad::ViewFad<double,0,0,FadType>
ViewFadType;
93
94
// Fad objects
95
ViewFadType afad( &v(0,0), &
a
, num_deriv, v.stride_1() );
96
ViewFadType bfad( &v(1,0), &b, num_deriv, v.stride_1() );
97
FadType
cfad(
c
);
98
FadType
rfad;
99
100
// Compute function
101
double
r =
func<double>
(
a
, b,
c
);
102
103
// Compute derivative analytically
104
double
drda, drdb;
105
func_deriv
(
a
, b,
c
, drda, drdb);
106
107
// Compute function and derivative with AD
108
rfad =
func<FadType>
(afad, bfad, cfad);
109
110
// Extract value and derivatives
111
double
r_ad = rfad.val();
// r
112
double
drda_ad = rfad.dx(0);
// dr/da
113
double
drdb_ad = rfad.dx(1);
// dr/db
114
115
// Print the results
116
int
p = 4;
117
int
w = p+7;
118
std::cout.setf(std::ios::scientific);
119
std::cout.precision(p);
120
std::cout <<
" r = "
<< r <<
" (original) == "
<< std::setw(w) << r_ad
121
<<
" (AD) Error = "
<< std::setw(w) << r - r_ad << std::endl
122
<<
"dr/da = "
<< std::setw(w) << drda <<
" (analytic) == "
123
<< std::setw(w) << drda_ad <<
" (AD) Error = "
<< std::setw(w)
124
<< drda - drda_ad << std::endl
125
<<
"dr/db = "
<< std::setw(w) << drdb <<
" (analytic) == "
126
<< std::setw(w) << drdb_ad <<
" (AD) Error = "
<< std::setw(w)
127
<< drdb - drdb_ad << std::endl;
128
129
double
tol
= 1.0e-14;
130
if
(std::fabs(r - r_ad) <
tol
&&
131
std::fabs(drda - drda_ad) <
tol
&&
132
std::fabs(drdb - drdb_ad) <
tol
) {
133
std::cout <<
"\nExample passed!"
<< std::endl;
134
ret = 0;
135
}
136
else
{
137
std::cout <<
"\nSomething is wrong, example failed!"
<< std::endl;
138
ret = 1;
139
}
140
141
}
142
Kokkos::finalize();
143
return
ret;
144
}
Sacado.hpp
a
a
Definition
Sacado_CacheFad_Ops.hpp:426
c
expr expr1 expr1 expr1 c expr2 expr1 expr2 expr1 expr2 expr1 expr1 expr1 expr1 c expr2 expr1 expr2 expr1 expr2 expr1 expr1 expr1 expr1 c *expr2 expr1 expr2 expr1 expr2 expr1 expr1 expr1 expr1 c expr2 expr1 expr2 expr1 expr2 expr1 expr1 expr1 expr2 expr1 expr2 expr1 expr1 expr1 expr2 expr1 expr2 expr1 expr1 expr1 c
Definition
Sacado_LFad_LogicalSparseOps.hpp:450
main
int main()
Definition
ad_example.cpp:191
FadType
Sacado::Fad::DFad< double > FadType
Definition
blas_example.cpp:49
Sacado::Fad::DFad
Definition
Sacado_Fad_DFadTraits.hpp:65
Sacado::Fad::ViewFad
Definition
Sacado_Fad_ViewFadTraits.hpp:65
func_deriv
void func_deriv(double a, double b, double c, double &drda, double &drdb)
Definition
dfad_view_handle_example.cpp:57
func
ScalarRes func(const Scalar1 &a, const Scalar1 &b, const Scalar2 &c)
Definition
dfad_view_handle_example.cpp:50
std::pow
PowExprType< Expr< T1 >, Expr< T2 > >::expr_type pow(const Expr< T1 > &expr1, const Expr< T2 > &expr2)
Definition
Sacado_Tay_CacheTaylorOps.hpp:1730
std::atan
ATanExprType< T >::expr_type atan(const Expr< T > &expr)
Definition
Sacado_Tay_CacheTaylorOps.hpp:1838
tol
const double tol
Definition
tradoptest_01.cpp:61
Generated by
1.17.0