NumCpp  2.16.0
A Templatized Header Only C++ Implementation of the Python NumPy Library
Loading...
Searching...
No Matches
lstsq.hpp
Go to the documentation of this file.
1
28#pragma once
29
32#include "NumCpp/NdArray.hpp"
33
34namespace nc::linalg
35{
36 //============================================================================
37 // Method Description:
56 template<typename dtype>
58 {
60
61 const auto& aShape = inA.shape();
62 const auto& bShape = inB.shape();
63
64 const auto bIsFlat = inB.isflat();
65 if (bIsFlat && bShape.size() != aShape.rows)
66 {
67 THROW_INVALID_ARGUMENT_ERROR("Invalid matrix dimensions");
68 }
69 else if (!bIsFlat && inA.shape().rows != bShape.rows)
70 {
71 THROW_INVALID_ARGUMENT_ERROR("Invalid matrix dimensions");
72 }
73
74 SVD svd(inA.template astype<double>());
75
76 if (bIsFlat)
77 {
78 return svd.lstsq(inB.template astype<double>());
79 }
80
81 const auto bCast = inB.template astype<double>();
82 const auto bRowSlice = bCast.rSlice();
83
84 auto result = NdArray<double>(aShape.cols, bShape.cols);
85 const auto resultRowSlice = result.rSlice();
86
87 for (uint32 col = 0; col < bShape.cols; ++col)
88 {
89 result.put(resultRowSlice, col, svd.lstsq(bCast(bRowSlice, col)));
90 }
91
92 return result;
93 }
94} // namespace nc::linalg
#define THROW_INVALID_ARGUMENT_ERROR(msg)
Definition Error.hpp:37
#define STATIC_ASSERT_ARITHMETIC(dtype)
Definition StaticAsserts.hpp:39
Holds 1D and 2D arrays, the main workhorse of the NumCpp library.
Definition NdArrayCore.hpp:139
Performs the singular value decomposition of a general matrix.
Definition svd/svd.hpp:50
Definition cholesky.hpp:41
void svd(const NdArray< dtype > &inArray, NdArray< double > &outU, NdArray< double > &outS, NdArray< double > &outVT)
Definition svd.hpp:51
NdArray< double > lstsq(const NdArray< dtype > &inA, const NdArray< dtype > &inB)
Definition lstsq.hpp:57
NdArray< dtype > arange(dtype inStart, dtype inStop, dtype inStep=1)
Definition arange.hpp:59
std::uint32_t uint32
Definition Types.hpp:40