NumCpp  2.16.0
A templatized, header-only C++ implementation of the Python NumPy library
Loading...
Searching...
No Matches
eig.hpp
Go to the documentation of this file.
1
28#pragma once
29
30#include <utility>
31
34#include "NumCpp/NdArray.hpp"
35#include "NumCpp/Utils/sqr.hpp"
36
37namespace nc::linalg
38{
39 //============================================================================
40 // Method Description:
52 template<typename dtype>
53 std::pair<NdArray<double>, NdArray<double>> eig(const NdArray<dtype>& inA, double inTolerance = 1e-12)
54 {
56
57 if (!inA.issquare())
58 {
59 THROW_INVALID_ARGUMENT_ERROR("Input array must be square.");
60 }
61
62 const auto n = inA.numRows();
63 auto b = inA.template astype<double>();
65 auto eigenVals = NdArray<double>(1, n);
66
67 constexpr auto MAX_ITERATIONS = 10000;
68 for (auto iter = 0u; iter < MAX_ITERATIONS; ++iter)
69 {
70 auto max_off_diag = 0.;
71 auto p = 0u;
72 auto q = 1u;
73
74 for (auto i = 0u; i < n; i++)
75 {
76 for (auto j = i + 1; j < n; j++)
77 {
78 const auto val = std::fabs(b(i, j));
79 if (val > max_off_diag)
80 {
82 p = i;
83 q = j;
84 }
85 }
86 }
87
89 {
90 break;
91 }
92
93 const auto app = b(p, p);
94 const auto aqq = b(q, q);
95 const auto apq = b(p, q);
96
97 const auto theta = (aqq - app) / (2. * apq);
98 const auto onePlusThetaSqr = std::sqrt(1. + utils::sqr(theta));
99 const auto t = (theta >= 0.) ? 1. / (theta + onePlusThetaSqr) : 1. / (theta - onePlusThetaSqr);
100 const auto c = 1.0 / std::sqrt(1. + utils::sqr(t));
101 const auto s = t * c;
102
103 for (auto i = 0u; i < n; ++i)
104 {
105 if (i != p && i != q)
106 {
107 const auto bip = b(i, p);
108 const auto biq = b(i, q);
109 b(i, p) = c * bip - s * biq;
110 b(p, i) = b(i, p);
111 b(i, q) = s * bip + c * biq;
112 b(q, i) = b(i, q);
113 }
114 }
115
116 b(p, p) = c * c * app + s * s * aqq - 2. * c * s * apq;
117 b(q, q) = s * s * app + c * c * aqq + 2. * c * s * apq;
118 b(p, q) = 0.;
119 b(q, p) = 0.;
120
121 for (auto i = 0u; i < n; ++i)
122 {
123 const auto vip = eigenVectors(i, p);
124 const auto viq = eigenVectors(i, q);
125 eigenVectors(i, p) = c * vip - s * viq;
126 eigenVectors(i, q) = s * vip + c * viq;
127 }
128 }
129
130 for (auto i = 0u; i < n; ++i)
131 {
132 eigenVals[i] = b(i, i);
133 }
134
135 for (auto i = 0u; i < n - 1; ++i)
136 {
137 for (auto j = i + 1; j < n; ++j)
138 {
139 if (eigenVals[i] < eigenVals[j])
140 {
141 std::swap(eigenVals[i], eigenVals[j]);
142
143 for (auto k = 0u; k < n; ++k)
144 {
145 std::swap(eigenVectors(k, i), eigenVectors(k, j));
146 }
147 }
148 }
149 }
150
151 return std::make_pair(eigenVals, eigenVectors);
152 }
153} // namespace nc::linalg
#define THROW_INVALID_ARGUMENT_ERROR(msg)
Definition Error.hpp:37
#define STATIC_ASSERT_ARITHMETIC(dtype)
Definition StaticAsserts.hpp:39
Holds 1D and 2D arrays, the main work horse of the NumCpp library.
Definition NdArrayCore.hpp:139
Definition cholesky.hpp:41
std::pair< NdArray< double >, NdArray< double > > eig(const NdArray< dtype > &inA, double inTolerance=1e-12)
Definition eig.hpp:53
constexpr dtype sqr(dtype inValue) noexcept
Definition sqr.hpp:42
NdArray< dtype > arange(dtype inStart, dtype inStop, dtype inStep=1)
Definition arange.hpp:59