NumCpp  2.16.0
A Templatized Header Only C++ Implementation of the Python NumPy Library
Loading...
Searching...
No Matches
svd/svd.hpp
Go to the documentation of this file.
1
29#pragma once
30
#include <algorithm>
#include <cmath>
#include <limits>
#include <string>
#include <utility>

#include "NumCpp/Core/Types.hpp"
#include "NumCpp/Linalg/eig.hpp"
#include "NumCpp/NdArray.hpp"
42
43namespace nc::linalg
44{
45 // =============================================================================
46 // Class Description:
48 template<typename dtype>
49 class SVD
50 {
51 public:
53
54 static constexpr auto TOLERANCE = 1e-12;
55
56 // =============================================================================
57 // Description:
62 explicit SVD(const NdArray<dtype>& inMatrix) :
63 m_{ inMatrix.shape().rows },
64 n_{ inMatrix.shape().cols },
65 s_(1, m_)
66 {
67 compute(inMatrix.template astype<double>());
68 }
69
70 // =============================================================================
71 // Description:
77 {
78 return u_;
79 }
80
81 // =============================================================================
82 // Description:
88 {
89 return v_;
90 }
91
92 // =============================================================================
93 // Description:
99 {
100 return s_;
101 }
102
103 // =============================================================================
104 // Description:
110 {
111 // lazy evaluation
112 if (pinv_.isempty())
113 {
114 auto sInverse = nc::zeros<double>(n_, m_); // transpose
115 for (auto i = 0u; i < std::min(m_, n_); ++i)
116 {
117 if (s_[i] > TOLERANCE)
118 {
119 sInverse(i, i) = 1. / s_[i];
120 }
121 }
122
123 pinv_ = dot(v_, dot(sInverse, u_.transpose()));
124 }
125
126 return pinv_;
127 }
128
129 // =============================================================================
130 // Description:
138 {
139 if (inInput.size() != m_)
140 {
141 THROW_INVALID_ARGUMENT_ERROR("Invalid matrix dimensions");
142 }
143
144 if (inInput.numCols() == 1)
145 {
146 return dot(pinv(), inInput);
147 }
148 else
149 {
150 const auto input = inInput.copy().reshape(inInput.size(), 1);
151 return dot(pinv(), input);
152 }
153 }
154
155 private:
156 // =============================================================================
157 // Description:
162 void compute(const NdArray<double>& A)
163 {
164 const auto At = A.transpose();
165 const auto AtA = dot(At, A);
166 const auto AAt = dot(A, At);
167
168 const auto& [sigmaSquaredU, U] = eig(AAt);
169 const auto& [sigmaSquaredV, V] = eig(AtA);
170
171 auto rank = 0u;
172 for (auto i = 0u; i < std::min(m_, n_); ++i)
173 {
174 if (sigmaSquaredV[i] > TOLERANCE)
175 {
176 s_[i] = std::sqrt(sigmaSquaredV[i]);
177 rank++;
178 }
179 }
180
181 // std::cout << U.front() << ' ' << U.back() << '\n';
182 // std::cout << V.front() << ' ' << V.back() << '\n';
183 // std::cout << "hello world\n";
184
185 u_ = std::move(U);
186 v_ = std::move(V);
187
188 auto Av = NdArray<double>(m_, 1);
189 for (auto i = 0u; i < rank; ++i)
190 {
191 for (auto j = 0u; j < m_; ++j)
192 {
193 auto sum = 0.;
194 for (auto k = 0u; k < n_; ++k)
195 {
196 sum += A(j, k) * v_(k, i);
197 }
198 Av[j] = sum;
199 }
200
201 const auto normalization = norm(Av).item();
202
204 {
205 for (auto j = 0u; j < m_; ++j)
206 {
207 u_(j, i) = Av[j] / normalization;
208 }
209 }
210 }
211 }
212
213 private:
214 // ===============================Attributes====================================
215 const uint32 m_{};
216 const uint32 n_{};
217 NdArray<double> u_{};
218 NdArray<double> v_{};
219 NdArray<double> s_{};
220 NdArray<double> pinv_{};
221 };
222} // namespace nc::linalg
#define THROW_INVALID_ARGUMENT_ERROR(msg)
Definition Error.hpp:37
Holds 1D and 2D arrays, the main work horse of the NumCpp library.
Definition NdArrayCore.hpp:139
self_type transpose() const
Definition NdArrayCore.hpp:4959
bool isempty() const noexcept
Definition NdArrayCore.hpp:3008
value_type item() const
Definition NdArrayCore.hpp:3098
Performs the singular value decomposition of a general matrix.
Definition svd/svd.hpp:50
const NdArray< double > & u() const noexcept
Definition svd/svd.hpp:76
const NdArray< double > & s() const noexcept
Definition svd/svd.hpp:98
STATIC_ASSERT_ARITHMETIC(dtype)
static constexpr auto TOLERANCE
Definition svd/svd.hpp:54
NdArray< double > pinv()
Definition svd/svd.hpp:109
SVD(const NdArray< dtype > &inMatrix)
Definition svd/svd.hpp:62
NdArray< double > lstsq(const NdArray< double > &inInput)
Definition svd/svd.hpp:137
const NdArray< double > & v() const noexcept
Definition svd/svd.hpp:87
constexpr auto j
Definition Core/Constants.hpp:42
Definition cholesky.hpp:41
std::pair< NdArray< double >, NdArray< double > > eig(const NdArray< dtype > &inA, double inTolerance=1e-12)
Definition eig.hpp:53
NdArray< double > norm(const NdArray< dtype > &inArray, Axis inAxis=Axis::NONE)
Definition norm.hpp:51
NdArray< dtype > dot(const NdArray< dtype > &inArray1, const NdArray< dtype > &inArray2)
Definition dot.hpp:47
NdArray< dtype > arange(dtype inStart, dtype inStop, dtype inStep=1)
Definition arange.hpp:59
NdArray< dtype > sum(const NdArray< dtype > &inArray, Axis inAxis=Axis::NONE)
Definition sum.hpp:46
Shape shape(const NdArray< dtype > &inArray) noexcept
Definition Functions/shape.hpp:42
std::uint32_t uint32
Definition Types.hpp:40