NumCpp  2.12.1
A Templatized Header Only C++ Implementation of the Python NumPy Library
pivotLU_decomposition.hpp
Go to the documentation of this file.
1
33#pragma once
34
#include <cmath>
#include <tuple>
#include <utility>

#include "NumCpp/Core/Internal/Error.hpp"
#include "NumCpp/Core/Internal/StaticAsserts.hpp"
#include "NumCpp/Core/Types.hpp"
#include "NumCpp/Functions/eye.hpp"
#include "NumCpp/Functions/zeros_like.hpp"
#include "NumCpp/NdArray.hpp"
#include "NumCpp/Utils/essentiallyEqual.hpp"
45
46namespace nc::linalg
47{
48 //============================================================================
49 // Method Description:
56 template<typename dtype>
57 std::tuple<NdArray<double>, NdArray<double>, NdArray<double>> pivotLU_decomposition(const NdArray<dtype>& inMatrix)
58 {
60
61 const auto shape = inMatrix.shape();
62
63 if (!shape.issquare())
64 {
65 THROW_RUNTIME_ERROR("Input matrix should be square.");
66 }
67
68 NdArray<double> lMatrix = zeros_like<double>(inMatrix);
69 NdArray<double> uMatrix = inMatrix.template astype<double>();
70 NdArray<double> pMatrix = eye<double>(shape.rows);
71
72 for (uint32 k = 0; k < shape.rows; ++k)
73 {
74 double max = 0.;
75 uint32 pk = 0;
76 for (uint32 i = k; i < shape.rows; ++i)
77 {
78 double s = 0.;
79 for (uint32 j = k; j < shape.cols; ++j)
80 {
81 s += std::fabs(uMatrix(i, j));
82 }
83
84 const double q = std::fabs(uMatrix(i, k)) / s;
85 if (q > max)
86 {
87 max = q;
88 pk = i;
89 }
90 }
91
92 if (utils::essentiallyEqual(max, double{ 0. }))
93 {
94 THROW_RUNTIME_ERROR("Division by 0.");
95 }
96
97 if (pk != k)
98 {
99 for (uint32 j = 0; j < shape.cols; ++j)
100 {
101 std::swap(pMatrix(k, j), pMatrix(pk, j));
102 std::swap(lMatrix(k, j), lMatrix(pk, j));
103 std::swap(uMatrix(k, j), uMatrix(pk, j));
104 }
105 }
106
107 for (uint32 i = k + 1; i < shape.rows; ++i)
108 {
109 lMatrix(i, k) = uMatrix(i, k) / uMatrix(k, k);
110
111 for (uint32 j = k; j < shape.cols; ++j)
112 {
113 uMatrix(i, j) = uMatrix(i, j) - lMatrix(i, k) * uMatrix(k, j);
114 }
115 }
116 }
117
118 for (uint32 k = 0; k < shape.rows; ++k)
119 {
120 lMatrix(k, k) = 1.;
121 }
122
123 return std::make_tuple(lMatrix, uMatrix, pMatrix);
124 }
125} // namespace nc::linalg
#define THROW_RUNTIME_ERROR(msg)
Definition: Error.hpp:40
#define STATIC_ASSERT_ARITHMETIC(dtype)
Definition: StaticAsserts.hpp:39
const Shape & shape() const noexcept
Definition: NdArrayCore.hpp:4511
uint32 rows
Definition: Core/Shape.hpp:44
bool issquare() const noexcept
Definition: Core/Shape.hpp:125
uint32 cols
Definition: Core/Shape.hpp:45
constexpr auto j
Definition: Core/Constants.hpp:42
Definition: cholesky.hpp:41
std::tuple< NdArray< double >, NdArray< double >, NdArray< double > > pivotLU_decomposition(const NdArray< dtype > &inMatrix)
Definition: pivotLU_decomposition.hpp:57
bool essentiallyEqual(dtype inValue1, dtype inValue2) noexcept
Definition: essentiallyEqual.hpp:49
void swap(NdArray< dtype > &inArray1, NdArray< dtype > &inArray2) noexcept
Definition: swap.hpp:42
Shape shape(const NdArray< dtype > &inArray) noexcept
Definition: Functions/Shape.hpp:42
std::uint32_t uint32
Definition: Types.hpp:40
NdArray< dtype > max(const NdArray< dtype > &inArray, Axis inAxis=Axis::NONE)
Definition: max.hpp:44