NumCpp  2.12.1
A Templatized Header Only C++ Implementation of the Python NumPy Library
Brent.hpp
Go to the documentation of this file.
1
33#pragma once
34
35#include <cmath>
36#include <functional>
37#include <utility>
38
40#include "NumCpp/Core/Types.hpp"
43
44namespace nc::roots
45{
46 //================================================================================
47 // Class Description:
50 class Brent : public Iteration
51 {
52 public:
53 //============================================================================
54 // Method Description:
60 Brent(const double epsilon, std::function<double(double)> f) noexcept :
61 Iteration(epsilon),
62 f_(std::move(f))
63 {
64 }
65
66 //============================================================================
67 // Method Description:
74 Brent(const double epsilon, const uint32 maxNumIterations, std::function<double(double)> f) noexcept :
75 Iteration(epsilon, maxNumIterations),
76 f_(std::move(f))
77 {
78 }
79
80 //============================================================================
81 // Method Description:
84 ~Brent() override = default;
85
86 //============================================================================
87 // Method Description:
94 double solve(double a, double b)
95 {
97
98 double fa = f_(a);
99 double fb = f_(b);
100
101 checkAndFixAlgorithmCriteria(a, b, fa, fb);
102
103 double lastB = a; // b_{k-1}
104 double lastFb = fa;
105 double s = DtypeInfo<double>::max();
106 double fs = DtypeInfo<double>::max();
107 double penultimateB = a; // b_{k-2}
108
109 bool bisection = true;
110 while (std::fabs(fb) > epsilon_ && std::fabs(fs) > epsilon_ && std::fabs(b - a) > epsilon_)
111 {
112 if (useInverseQuadraticInterpolation(fa, fb, lastFb))
113 {
114 s = calculateInverseQuadraticInterpolation(a, b, lastB, fa, fb, lastFb);
115 }
116 else
117 {
118 s = calculateSecant(a, b, fa, fb);
119 }
120
121 if (useBisection(bisection, b, lastB, penultimateB, s))
122 {
123 s = calculateBisection(a, b);
124 bisection = true;
125 }
126 else
127 {
128 bisection = false;
129 }
130
131 fs = f_(s);
132 penultimateB = lastB;
133 lastB = b;
134
135 if (fa * fs < 0)
136 {
137 b = s;
138 }
139 else
140 {
141 a = s;
142 }
143
144 fa = f_(a);
145 lastFb = fb;
146 fb = f_(b);
147 checkAndFixAlgorithmCriteria(a, b, fa, fb);
148
150 }
151
152 return fb < fs ? b : s;
153 }
154
155 private:
156 //============================================================================
157 const std::function<double(double)> f_;
158
159 //============================================================================
160 // Method Description:
167 static double calculateBisection(const double a, const double b) noexcept
168 {
169 return 0.5 * (a + b);
170 }
171
172 //============================================================================
173 // Method Description:
182 static double calculateSecant(const double a, const double b, const double fa, const double fb) noexcept
183 {
184 // No need to check division by 0, in this case the method returns NAN which is taken care by
185 // useSecantMethod method
186 return b - fb * (b - a) / (fb - fa);
187 }
188
189 //============================================================================
190 // Method Description:
201 static double calculateInverseQuadraticInterpolation(const double a,
202 const double b,
203 const double lastB,
204 const double fa,
205 const double fb,
206 const double lastFb) noexcept
207 {
208 return a * fb * lastFb / ((fa - fb) * (fa - lastFb)) + b * fa * lastFb / ((fb - fa) * (fb - lastFb)) +
209 lastB * fa * fb / ((lastFb - fa) * (lastFb - fb));
210 }
211
212 //============================================================================
213 // Method Description:
221 static bool useInverseQuadraticInterpolation(const double fa, const double fb, const double lastFb) noexcept
222 {
223 return !utils::essentiallyEqual(fa, lastFb) && utils::essentiallyEqual(fb, lastFb);
224 }
225
226 //============================================================================
227 // Method Description:
235 static void checkAndFixAlgorithmCriteria(double &a, double &b, double &fa, double &fb) noexcept
236 {
237 // Algorithm works in range [a,b] if criteria f(a)*f(b) < 0 and f(a) > f(b) is fulfilled
238 if (std::fabs(fa) < std::fabs(fb))
239 {
240 std::swap(a, b);
241 std::swap(fa, fb);
242 }
243 }
244
245 //============================================================================
246 // Method Description:
256 [[nodiscard]] bool useBisection(const bool bisection,
257 const double b,
258 const double lastB,
259 const double penultimateB,
260 const double s) const noexcept
261 {
262 const double DELTA = epsilon_ + std::numeric_limits<double>::min();
263
264 return (bisection &&
265 std::fabs(s - b) >=
266 0.5 * std::fabs(b - lastB)) || // Bisection was used in last step but |s-b|>=|b-lastB|/2 <-
267 // Interpolation step would be to rough, so still use bisection
268 (!bisection && std::fabs(s - b) >=
269 0.5 * std::fabs(lastB - penultimateB)) || // Interpolation was used in last step
270 // but |s-b|>=|lastB-penultimateB|/2 <-
271 // Interpolation step would be to small
272 (bisection &&
273 std::fabs(b - lastB) < DELTA) || // If last iteration was using bisection and difference between
274 // b and lastB is < delta use bisection for next iteration
275 (!bisection && std::fabs(lastB - penultimateB) <
276 DELTA); // If last iteration was using interpolation but difference between
277 // lastB ond penultimateB is < delta use biscetion for next iteration
278 }
279 };
280} // namespace nc::roots
static constexpr dtype max() noexcept
Definition: DtypeInfo.hpp:110
Definition: Brent.hpp:51
Brent(const double epsilon, const uint32 maxNumIterations, std::function< double(double)> f) noexcept
Definition: Brent.hpp:74
double solve(double a, double b)
Definition: Brent.hpp:94
~Brent() override=default
Brent(const double epsilon, std::function< double(double)> f) noexcept
Definition: Brent.hpp:60
ABC for iteration classes to derive from.
Definition: Iteration.hpp:46
Iteration(double epsilon) noexcept
Definition: Iteration.hpp:54
const double epsilon_
Definition: Iteration.hpp:114
void resetNumberOfIterations() noexcept
Definition: Iteration.hpp:94
void incrementNumberOfIterations()
Definition: Iteration.hpp:103
dtype f(GeneratorType &generator, dtype inDofN, dtype inDofD)
Definition: f.hpp:56
Definition: Bisection.hpp:43
bool essentiallyEqual(dtype inValue1, dtype inValue2) noexcept
Definition: essentiallyEqual.hpp:49
NdArray< dtype > min(const NdArray< dtype > &inArray, Axis inAxis=Axis::NONE)
Definition: min.hpp:44
void swap(NdArray< dtype > &inArray1, NdArray< dtype > &inArray2) noexcept
Definition: swap.hpp:42
std::uint32_t uint32
Definition: Types.hpp:40