Sacado Package Browser (Single Doxygen Collection)
Version of the Day
Toggle main menu visibility
Loading...
Searching...
No Matches
example
trad_dfad_example.cpp
Go to the documentation of this file.
1
// $Id$
2
// $Source$
3
// @HEADER
4
// ***********************************************************************
5
//
6
// Sacado Package
7
// Copyright (2006) Sandia Corporation
8
//
9
// Under the terms of Contract DE-AC04-94AL85000 with Sandia Corporation,
10
// the U.S. Government retains certain rights in this software.
11
//
12
// This library is free software; you can redistribute it and/or modify
13
// it under the terms of the GNU Lesser General Public License as
14
// published by the Free Software Foundation; either version 2.1 of the
15
// License, or (at your option) any later version.
16
//
17
// This library is distributed in the hope that it will be useful, but
18
// WITHOUT ANY WARRANTY; without even the implied warranty of
19
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
20
// Lesser General Public License for more details.
21
//
22
// You should have received a copy of the GNU Lesser General Public
23
// License along with this library; if not, write to the Free Software
24
// Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301
25
// USA
26
// Questions? Contact David M. Gay (dmgay@sandia.gov) or Eric T. Phipps
27
// (etphipp@sandia.gov).
28
//
29
// ***********************************************************************
30
// @HEADER
31
32
// trad_dfad_example
33
//
34
// usage:
35
// trad_dfad_example
36
//
37
// output:
38
// prints the results of computing the second derivatives of a simple function
// with nested forward and reverse mode AD using the
39
// Sacado::Fad::DFad and Sacado::Rad::ADvar classes.
40
41
#include <iostream>
42
#include <iomanip>
43
44
#include "
Sacado_No_Kokkos.hpp
"
45
46
// The function to differentiate:
//
//   func(a, b, c) = c * log(b + 1) / sin(a)
//
// It is templated on the scalar type so the same expression can be
// evaluated with plain doubles (see the call in main) and with the AD
// scalar types used for arad/brad/crad.  The value is formed in a single
// expression, exactly as written above.
template <class ScalarT>
ScalarT func(const ScalarT& a, const ScalarT& b, const ScalarT& c)
{
  ScalarT result = c * std::log(b + 1.) / std::sin(a);
  return result;
}
52
53
// The analytic first derivatives of func(a,b,c) = c*log(b+1)/sin(a) with
// respect to a and b.
//
//   drda  <-  dr/da = -c*log(b+1)*cos(a) / sin(a)^2
//   drdb  <-  dr/db =  c / ((b+1)*sin(a))
//
// No derivative with respect to c is produced.  drda is written first and
// drdb second.
void func_deriv(double a, double b, double c, double& drda, double& drdb)
{
  const double sin_a = std::sin(a);
  const double b_plus_one = b + 1.;

  drda = -(c * std::log(b_plus_one) / std::pow(sin_a, 2.)) * std::cos(a);
  drdb = c / (b_plus_one * sin_a);
}
59
60
// The analytic second derivatives of func(a,b,c) = c*log(b+1)/sin(a) with
// respect to a and b.
//
//   d2rda2   <-  d^2r/da^2 = c*log(b+1)/sin(a)
//                            + 2*c*log(b+1)*cos(a)^2 / sin(a)^3
//   d2rdb2   <-  d^2r/db^2 = -c / ((b+1)^2 * sin(a))
//   d2rdadb  <-  d^2r/dadb = -c*cos(a) / ((b+1) * sin(a)^2)
//
// Outputs are written in the order d2rda2, d2rdb2, d2rdadb.
void func_deriv2(double a, double b, double c, double& d2rda2, double& d2rdb2,
                 double& d2rdadb)
{
  const double sin_a = std::sin(a);
  const double cos_a = std::cos(a);
  const double b_plus_one = b + 1.;
  const double c_log = c * std::log(b_plus_one);

  d2rda2 = c_log / sin_a + 2. * (c_log / std::pow(sin_a, 3.)) * std::pow(cos_a, 2.);
  d2rdb2 = -c / (std::pow(b_plus_one, 2.) * sin_a);
  d2rdadb = -c / (b_plus_one * std::pow(sin_a, 2.)) * cos_a;
}
68
69
int
main
(
int
argc,
char
**argv)
70
{
71
double
pi =
std::atan
(1.0)*4.0;
72
73
// Values of function arguments
74
double
a
= pi/4;
75
double
b = 2.0;
76
double
c
= 3.0;
77
78
// Number of independent variables
79
int
num_deriv = 2;
80
81
// Fad objects
82
Sacado::Rad::ADvar< Sacado::Fad::DFad<double>
> arad =
83
Sacado::Fad::DFad<double>
(num_deriv, 0,
a
);
84
Sacado::Rad::ADvar< Sacado::Fad::DFad<double>
> brad =
85
Sacado::Fad::DFad<double>
(num_deriv, 1, b);
86
Sacado::Rad::ADvar< Sacado::Fad::DFad<double>
> crad =
c
;
87
Sacado::Rad::ADvar< Sacado::Fad::DFad<double>
> rrad;
88
89
// Compute function
90
double
r =
func
(
a
, b,
c
);
91
92
// Compute derivative analytically
93
double
drda, drdb;
94
func_deriv
(
a
, b,
c
, drda, drdb);
95
96
// Compute second derivative analytically
97
double
d2rda2, d2rdb2, d2rdadb;
98
func_deriv2
(
a
, b,
c
, d2rda2, d2rdb2, d2rdadb);
99
100
// Compute function and derivative with AD
101
rrad =
func
(arad, brad, crad);
102
103
Sacado::Rad::ADvar< Sacado::Fad::DFad<double>
>::Gradcomp();
104
105
// Extract value and derivatives
106
double
r_ad = rrad.
val
().val();
// r
107
double
drda_ad = arad.
adj
().val();
// dr/da
108
double
drdb_ad = brad.
adj
().val();
// dr/db
109
double
d2rda2_ad = arad.
adj
().dx(0);
// d^2r/da^2
110
double
d2rdadb_ad = arad.
adj
().dx(1);
// d^2r/dadb
111
double
d2rdbda_ad = brad.
adj
().dx(0);
// d^2r/dbda
112
double
d2rdb2_ad = brad.
adj
().dx(1);
// d^2/db^2
113
114
// Print the results
115
int
p = 4;
116
int
w = p+7;
117
std::cout.setf(std::ios::scientific);
118
std::cout.precision(p);
119
std::cout <<
" r = "
<< std::setw(w) << r <<
" (original) == "
120
<< std::setw(w) << r_ad <<
" (AD) Error = "
<< std::setw(w)
121
<< r - r_ad << std::endl
122
<<
" dr/da = "
<< std::setw(w) << drda <<
" (analytic) == "
123
<< std::setw(w) << drda_ad <<
" (AD) Error = "
<< std::setw(w)
124
<< drda - drda_ad << std::endl
125
<<
" dr/db = "
<< std::setw(w) << drdb <<
" (analytic) == "
126
<< std::setw(w) << drdb_ad <<
" (AD) Error = "
<< std::setw(w)
127
<< drdb - drdb_ad << std::endl
128
<<
"d^2r/da^2 = "
<< std::setw(w) << d2rda2 <<
" (analytic) == "
129
<< std::setw(w) << d2rda2_ad <<
" (AD) Error = "
<< std::setw(w)
130
<< d2rda2 - d2rda2_ad << std::endl
131
<<
"d^2r/db^2 = "
<< std::setw(w) << d2rdb2 <<
" (analytic) == "
132
<< std::setw(w) << d2rdb2_ad <<
" (AD) Error = "
<< std::setw(w)
133
<< d2rdb2 - d2rdb2_ad << std::endl
134
<<
"d^2r/dadb = "
<< std::setw(w) << d2rdadb <<
" (analytic) == "
135
<< std::setw(w) << d2rdadb_ad <<
" (AD) Error = "
<< std::setw(w)
136
<< d2rdadb - d2rdadb_ad << std::endl
137
<<
"d^2r/dbda = "
<< std::setw(w) << d2rdadb <<
" (analytic) == "
138
<< std::setw(w) << d2rdbda_ad <<
" (AD) Error = "
<< std::setw(w)
139
<< d2rdadb - d2rdbda_ad << std::endl;
140
141
// Free Rad's memory to avoid memory leaks. The zero_out() call is
142
// necessary to destroy dynamically allocated DFad arrays (which are
143
// stored outside of Rad's memory management).
144
Sacado::Rad::ADcontext< Sacado::Fad::DFad<double>
>::zero_out();
145
Sacado::Rad::ADcontext< Sacado::Fad::DFad<double>
>::free_all();
146
147
double
tol
= 1.0e-14;
148
if
(std::fabs(r - r_ad) <
tol
&&
149
std::fabs(drda - drda_ad) <
tol
&&
150
std::fabs(drdb - drdb_ad) <
tol
&&
151
std::fabs(d2rda2 - d2rda2_ad) <
tol
&&
152
std::fabs(d2rdb2 - d2rdb2_ad) <
tol
&&
153
std::fabs(d2rdadb - d2rdadb_ad) <
tol
) {
154
std::cout <<
"\nExample passed!"
<< std::endl;
155
return
0;
156
}
157
else
{
158
std::cout <<
"\nSomething is wrong, example failed!"
<< std::endl;
159
return
1;
160
}
161
}
a
a
Definition
Sacado_CacheFad_Ops.hpp:426
c
expr expr1 expr1 expr1 c expr2 expr1 expr2 expr1 expr2 expr1 expr1 expr1 expr1 c expr2 expr1 expr2 expr1 expr2 expr1 expr1 expr1 expr1 c *expr2 expr1 expr2 expr1 expr2 expr1 expr1 expr1 expr1 c expr2 expr1 expr2 expr1 expr2 expr1 expr1 expr1 expr2 expr1 expr2 expr1 expr1 expr1 expr2 expr1 expr2 expr1 expr1 expr1 c
Definition
Sacado_LFad_LogicalSparseOps.hpp:450
Sacado_No_Kokkos.hpp
main
int main()
Definition
ad_example.cpp:191
Sacado::Fad::DFad
Definition
Sacado_Fad_DFadTraits.hpp:65
Sacado::Rad::ADcontext
Definition
Sacado_trad.hpp:231
Sacado::Rad::ADvar
Definition
Sacado_trad.hpp:851
Sacado::Rad::IndepADvar::val
Double val() const
Definition
Sacado_trad.hpp:755
Sacado::Rad::IndepADvar::adj
Double adj() const
Definition
Sacado_trad.hpp:762
std::pow
PowExprType< Expr< T1 >, Expr< T2 > >::expr_type pow(const Expr< T1 > &expr1, const Expr< T2 > &expr2)
Definition
Sacado_Tay_CacheTaylorOps.hpp:1730
std::atan
ATanExprType< T >::expr_type atan(const Expr< T > &expr)
Definition
Sacado_Tay_CacheTaylorOps.hpp:1838
func_deriv
void func_deriv(double a, double b, double c, double &drda, double &drdb)
Definition
trad_dfad_example.cpp:54
func_deriv2
void func_deriv2(double a, double b, double c, double &d2rda2, double &d2rdb2, double &d2rdadb)
Definition
trad_dfad_example.cpp:61
func
ScalarT func(const ScalarT &a, const ScalarT &b, const ScalarT &c)
Definition
trad_dfad_example.cpp:48
tol
const double tol
Definition
tradoptest_01.cpp:61
Generated by
1.17.0