NumCpp  2.12.1
A Templatized Header Only C++ Implementation of the Python NumPy Library
SVDClass.hpp
Go to the documentation of this file.
1
29#pragma once
30
#include <algorithm>
#include <cmath>
#include <limits>
#include <string>

#include "NumCpp/Core/Types.hpp"
#include "NumCpp/NdArray.hpp"
39
40namespace nc::linalg
41{
42 // =============================================================================
43 // Class Description:
46 class SVD
47 {
48 public:
49 // =============================================================================
50 // Description:
55 explicit SVD(const NdArray<double>& inMatrix) :
56 m_(inMatrix.shape().rows),
57 n_(inMatrix.shape().cols),
58 u_(inMatrix),
59 v_(n_, n_),
60 s_(1, n_),
61 eps_(std::numeric_limits<double>::epsilon())
62 {
63 decompose();
64 reorder();
65 tsh_ = 0.5 * std::sqrt(m_ + n_ + 1.) * s_.front() * eps_;
66 }
67
68 // =============================================================================
69 // Description:
74 const NdArray<double>& u() noexcept
75 {
76 return u_;
77 }
78
79 // =============================================================================
80 // Description:
85 const NdArray<double>& v() noexcept
86 {
87 return v_;
88 }
89
90 // =============================================================================
91 // Description:
96 const NdArray<double>& s() noexcept
97 {
98 return s_;
99 }
100
101 // =============================================================================
102 // Description:
110 NdArray<double> solve(const NdArray<double>& inInput, double inThresh = -1.)
111 {
112 double ss{};
113
114 if (inInput.size() != m_)
115 {
116 THROW_INVALID_ARGUMENT_ERROR("bad sizes.");
117 }
118
119 NdArray<double> returnArray(1, n_);
120
121 NdArray<double> tmp(1, n_);
122
123 tsh_ = (inThresh >= 0. ? inThresh : 0.5 * sqrt(m_ + n_ + 1.) * s_.front() * eps_);
124
125 for (uint32 j = 0; j < n_; j++)
126 {
127 ss = 0.;
128 if (s_[j] > tsh_)
129 {
130 for (uint32 i = 0; i < m_; i++)
131 {
132 ss += u_(i, j) * inInput[i];
133 }
134 ss /= s_[j];
135 }
136 tmp[j] = ss;
137 }
138
139 for (uint32 j = 0; j < n_; j++)
140 {
141 ss = 0.;
142 for (uint32 jj = 0; jj < n_; jj++)
143 {
144 ss += v_(j, jj) * tmp[jj];
145 }
146
147 returnArray[j] = ss;
148 }
149
150 return returnArray;
151 }
152
153 private:
154 // =============================================================================
155 // Description:
163 static double SIGN(double inA, double inB) noexcept
164 {
165 return inB >= 0 ? (inA >= 0 ? inA : -inA) : (inA >= 0 ? -inA : inA);
166 }
167
168 // =============================================================================
169 // Description:
172 void decompose()
173 {
174 bool flag{};
175 uint32 i{};
176 uint32 its{};
177 uint32 j{};
178 uint32 jj{};
179 uint32 k{};
180 uint32 l{};
181 uint32 nm{};
182
183 double anorm{};
184 double c{};
185 double f{};
186 double g{};
187 double h{};
188 double ss{};
189 double scale{};
190 double x{};
191 double y{};
192 double z{};
193
194 NdArray<double> rv1(n_, 1);
195
196 for (i = 0; i < n_; ++i)
197 {
198 l = i + 2;
199 rv1[i] = scale * g;
200 g = ss = scale = 0.;
201
202 if (i < m_)
203 {
204 for (k = i; k < m_; ++k)
205 {
206 scale += std::abs(u_(k, i));
207 }
208
209 if (!utils::essentiallyEqual(scale, 0.))
210 {
211 for (k = i; k < m_; ++k)
212 {
213 u_(k, i) /= scale;
214 ss += u_(k, i) * u_(k, i);
215 }
216
217 f = u_(i, i);
218 g = -SIGN(std::sqrt(ss), f);
219 h = f * g - ss;
220 u_(i, i) = f - g;
221
222 for (j = l - 1; j < n_; ++j)
223 {
224 for (ss = 0., k = i; k < m_; ++k)
225 {
226 ss += u_(k, i) * u_(k, j);
227 }
228
229 f = ss / h;
230
231 for (k = i; k < m_; ++k)
232 {
233 u_(k, j) += f * u_(k, i);
234 }
235 }
236
237 for (k = i; k < m_; ++k)
238 {
239 u_(k, i) *= scale;
240 }
241 }
242 }
243
244 s_[i] = scale * g;
245 g = ss = scale = 0.;
246
247 if (i + 1 <= m_ && i + 1 != n_)
248 {
249 for (k = l - 1; k < n_; ++k)
250 {
251 scale += std::abs(u_(i, k));
252 }
253
254 if (!utils::essentiallyEqual(scale, 0.))
255 {
256 for (k = l - 1; k < n_; ++k)
257 {
258 u_(i, k) /= scale;
259 ss += u_(i, k) * u_(i, k);
260 }
261
262 f = u_(i, l - 1);
263 g = -SIGN(std::sqrt(ss), f);
264 h = f * g - ss;
265 u_(i, l - 1) = f - g;
266
267 for (k = l - 1; k < n_; ++k)
268 {
269 rv1[k] = u_(i, k) / h;
270 }
271
272 for (j = l - 1; j < m_; ++j)
273 {
274 for (ss = 0., k = l - 1; k < n_; ++k)
275 {
276 ss += u_(j, k) * u_(i, k);
277 }
278
279 for (k = l - 1; k < n_; ++k)
280 {
281 u_(j, k) += ss * rv1[k];
282 }
283 }
284
285 for (k = l - 1; k < n_; ++k)
286 {
287 u_(i, k) *= scale;
288 }
289 }
290 }
291
292 anorm = std::max(anorm, (std::abs(s_[i]) + std::abs(rv1[i])));
293 }
294
295 for (i = n_ - 1; i != static_cast<uint32>(-1); --i)
296 {
297 if (i < n_ - 1)
298 {
299 if (!utils::essentiallyEqual(g, 0.))
300 {
301 for (j = l; j < n_; ++j)
302 {
303 v_(j, i) = (u_(i, j) / u_(i, l)) / g;
304 }
305
306 for (j = l; j < n_; ++j)
307 {
308 for (ss = 0., k = l; k < n_; ++k)
309 {
310 ss += u_(i, k) * v_(k, j);
311 }
312
313 for (k = l; k < n_; ++k)
314 {
315 v_(k, j) += ss * v_(k, i);
316 }
317 }
318 }
319
320 for (j = l; j < n_; ++j)
321 {
322 v_(i, j) = v_(j, i) = 0.;
323 }
324 }
325
326 v_(i, i) = 1.;
327 g = rv1[i];
328 l = i;
329 }
330
331 for (i = std::min(m_, n_) - 1; i != static_cast<uint32>(-1); --i)
332 {
333 l = i + 1;
334 g = s_[i];
335
336 for (j = l; j < n_; ++j)
337 {
338 u_(i, j) = 0.;
339 }
340
341 if (!utils::essentiallyEqual(g, 0.))
342 {
343 g = 1. / g;
344
345 for (j = l; j < n_; ++j)
346 {
347 for (ss = 0., k = l; k < m_; ++k)
348 {
349 ss += u_(k, i) * u_(k, j);
350 }
351
352 f = (ss / u_(i, i)) * g;
353
354 for (k = i; k < m_; ++k)
355 {
356 u_(k, j) += f * u_(k, i);
357 }
358 }
359
360 for (j = i; j < m_; ++j)
361 {
362 u_(j, i) *= g;
363 }
364 }
365 else
366 {
367 for (j = i; j < m_; ++j)
368 {
369 u_(j, i) = 0.;
370 }
371 }
372
373 ++u_(i, i);
374 }
375
376 for (k = n_ - 1; k != static_cast<uint32>(-1); --k)
377 {
378 for (its = 0; its < 30; ++its)
379 {
380 flag = true;
381 for (l = k; l != static_cast<uint32>(-1); --l)
382 {
383 nm = l - 1;
384 if (l == 0 || std::abs(rv1[l]) <= eps_ * anorm)
385 {
386 flag = false;
387 break;
388 }
389
390 if (std::abs(s_[nm]) <= eps_ * anorm)
391 {
392 break;
393 }
394 }
395
396 if (flag)
397 {
398 c = 0.;
399 ss = 1.;
400 for (i = l; i < k + 1; ++i)
401 {
402 f = ss * rv1[i];
403 rv1[i] = c * rv1[i];
404
405 if (std::abs(f) <= eps_ * anorm)
406 {
407 break;
408 }
409
410 g = s_[i];
411 h = pythag(f, g);
412 s_[i] = h;
413 h = 1. / h;
414 c = g * h;
415 ss = -f * h;
416
417 for (j = 0; j < m_; ++j)
418 {
419 y = u_(j, nm);
420 z = u_(j, i);
421 u_(j, nm) = y * c + z * ss;
422 u_(j, i) = z * c - y * ss;
423 }
424 }
425 }
426
427 z = s_[k];
428 if (l == k)
429 {
430 if (z < 0.)
431 {
432 s_[k] = -z;
433 for (j = 0; j < n_; ++j)
434 {
435 v_(j, k) = -v_(j, k);
436 }
437 }
438 break;
439 }
440
441 if (its == 29)
442 {
443 THROW_INVALID_ARGUMENT_ERROR("no convergence in 30 svdcmp iterations");
444 }
445
446 x = s_[l];
447 nm = k - 1;
448 y = s_[nm];
449 g = rv1[nm];
450 h = rv1[k];
451 f = ((y - z) * (y + z) + (g - h) * (g + h)) / (2. * h * y);
452 g = pythag(f, 1.);
453 f = ((x - z) * (x + z) + h * ((y / (f + SIGN(g, f))) - h)) / x;
454 c = ss = 1.;
455
456 for (j = l; j <= nm; j++)
457 {
458 i = j + 1;
459 g = rv1[i];
460 y = s_[i];
461 h = ss * g;
462 g = c * g;
463 z = pythag(f, h);
464 rv1[j] = z;
465 c = f / z;
466 ss = h / z;
467 f = x * c + g * ss;
468 g = g * c - x * ss;
469 h = y * ss;
470 y *= c;
471
472 for (jj = 0; jj < n_; ++jj)
473 {
474 x = v_(jj, j);
475 z = v_(jj, i);
476 v_(jj, j) = x * c + z * ss;
477 v_(jj, i) = z * c - x * ss;
478 }
479
480 z = pythag(f, h);
481 s_[j] = z;
482
483 if (!utils::essentiallyEqual(z, 0.))
484 {
485 z = 1. / z;
486 c = f * z;
487 ss = h * z;
488 }
489
490 f = c * g + ss * y;
491 x = c * y - ss * g;
492
493 for (jj = 0; jj < m_; ++jj)
494 {
495 y = u_(jj, j);
496 z = u_(jj, i);
497 u_(jj, j) = y * c + z * ss;
498 u_(jj, i) = z * c - y * ss;
499 }
500 }
501 rv1[l] = 0.;
502 rv1[k] = f;
503 s_[k] = x;
504 }
505 }
506 }
507
508 // =============================================================================
509 // Description:
512 void reorder()
513 {
514 uint32 i = 0;
515 uint32 j = 0;
516 uint32 k = 0;
517 uint32 ss = 0;
518 uint32 inc = 1;
519
520 double sw{};
521 NdArray<double> su(m_, 1);
522 NdArray<double> sv(n_, 1);
523
524 do
525 {
526 inc *= 3;
527 ++inc;
528 } while (inc <= n_);
529
530 do
531 {
532 inc /= 3;
533 for (i = inc; i < n_; ++i)
534 {
535 sw = s_[i];
536 for (k = 0; k < m_; ++k)
537 {
538 su[k] = u_(k, i);
539 }
540
541 for (k = 0; k < n_; ++k)
542 {
543 sv[k] = v_(k, i);
544 }
545
546 j = i;
547 while (s_[j - inc] < sw)
548 {
549 s_[j] = s_[j - inc];
550
551 for (k = 0; k < m_; ++k)
552 {
553 u_(k, j) = u_(k, j - inc);
554 }
555
556 for (k = 0; k < n_; ++k)
557 {
558 v_(k, j) = v_(k, j - inc);
559 }
560
561 j -= inc;
562
563 if (j < inc)
564 {
565 break;
566 }
567 }
568
569 s_[j] = sw;
570
571 for (k = 0; k < m_; ++k)
572 {
573 u_(k, j) = su[k];
574 }
575
576 for (k = 0; k < n_; ++k)
577 {
578 v_(k, j) = sv[k];
579 }
580 }
581 } while (inc > 1);
582
583 for (k = 0; k < n_; ++k)
584 {
585 ss = 0;
586
587 for (i = 0; i < m_; i++)
588 {
589 if (u_(i, k) < 0.)
590 {
591 ss++;
592 }
593 }
594
595 for (j = 0; j < n_; ++j)
596 {
597 if (v_(j, k) < 0.)
598 {
599 ss++;
600 }
601 }
602
603 if (ss > (m_ + n_) / 2)
604 {
605 for (i = 0; i < m_; ++i)
606 {
607 u_(i, k) = -u_(i, k);
608 }
609
610 for (j = 0; j < n_; ++j)
611 {
612 v_(j, k) = -v_(j, k);
613 }
614 }
615 }
616 }
617
618 // =============================================================================
619 // Description:
627 static double pythag(double inA, double inB) noexcept
628 {
629 const double absa = std::abs(inA);
630 const double absb = std::abs(inB);
631 return (absa > absb
632 ? absa * std::sqrt(1. + utils::sqr(absb / absa))
633 : (utils::essentiallyEqual(absb, 0.) ? 0. : absb * std::sqrt(1. + utils::sqr(absa / absb))));
634 }
635
636 private:
637 // ===============================Attributes====================================
638 const uint32 m_{};
639 const uint32 n_{};
640 NdArray<double> u_{};
641 NdArray<double> v_{};
642 NdArray<double> s_{};
643 double eps_{};
644 double tsh_{};
645 };
646} // namespace nc::linalg
#define THROW_INVALID_ARGUMENT_ERROR(msg)
Definition: Error.hpp:37
size_type size() const noexcept
Definition: NdArrayCore.hpp:4524
const_reference front() const noexcept
Definition: NdArrayCore.hpp:2860
Definition: SVDClass.hpp:47
NdArray< double > solve(const NdArray< double > &inInput, double inThresh=-1.)
Definition: SVDClass.hpp:110
const NdArray< double > & s() noexcept
Definition: SVDClass.hpp:96
const NdArray< double > & v() noexcept
Definition: SVDClass.hpp:85
SVD(const NdArray< double > &inMatrix)
Definition: SVDClass.hpp:55
const NdArray< double > & u() noexcept
Definition: SVDClass.hpp:74
constexpr auto j
Definition: Core/Constants.hpp:42
constexpr double c
speed of light
Definition: Core/Constants.hpp:36
Definition: cholesky.hpp:41
dtype f(GeneratorType &generator, dtype inDofN, dtype inDofD)
Definition: f.hpp:56
bool essentiallyEqual(dtype inValue1, dtype inValue2) noexcept
Definition: essentiallyEqual.hpp:49
constexpr dtype sqr(dtype inValue) noexcept
Definition: sqr.hpp:42
NdArray< dtype > min(const NdArray< dtype > &inArray, Axis inAxis=Axis::NONE)
Definition: min.hpp:44
auto abs(dtype inValue) noexcept
Definition: abs.hpp:49
auto sqrt(dtype inValue) noexcept
Definition: sqrt.hpp:48
Shape shape(const NdArray< dtype > &inArray) noexcept
Definition: Functions/Shape.hpp:42
std::uint32_t uint32
Definition: Types.hpp:40
NdArray< dtype > max(const NdArray< dtype > &inArray, Axis inAxis=Axis::NONE)
Definition: max.hpp:44