NumCpp  2.16.1
A Templatized Header Only C++ Implementation of the Python NumPy Library
Loading...
Searching...
No Matches
dot.hpp
Go to the documentation of this file.
1
28#pragma once
29
30#include <complex>
31
33#include "NumCpp/NdArray.hpp"
34
35namespace nc
36{
37 //============================================================================
38 // Method Description:
// NOTE(review): listing lines 48-51 (the signature and body that follow this template
// header) are not present in this extract. The cross-reference index at the end of this
// page places `NdArray<dtype> dot(const NdArray<dtype>& inArray1, const NdArray<dtype>& inArray2)`
// at dot.hpp:48 — confirm the definition against the original header.
47 template<typename dtype>
52
53 //============================================================================
54 // Method Description:
// Dot-product overload that reads inArray1 and inArray2 and returns an
// NdArray<std::complex<dtype>>.
// NOTE(review): the signature (listing line 67) is missing from this extract, so the exact
// parameter types are not visible here — confirm against the original header.
66 template<typename dtype>
68 {
// NOTE(review): listing line 69 is missing from this extract.
70
71 const auto shape1 = inArray1.shape();
72 const auto shape2 = inArray2.shape();
73
// Vector case: both operands have the same shape, and that shape is a single row or a
// single column.
74 if (shape1 == shape2 && (shape1.rows == 1 || shape1.cols == 1))
75 {
// Combine the two element sequences into one complex value, starting from complex zero.
// Presumably an inner-product style reduction — verify against stl_algorithms::transform_reduce.
76 const std::complex<dtype> dotProduct = stl_algorithms::transform_reduce(inArray1.cbegin(),
77 inArray1.cend(),
78 inArray2.cbegin(),
79 std::complex<dtype>{ 0 });
// NOTE(review): listing line 80 is missing; it presumably builds returnArray from
// dotProduct — confirm.
81 return returnArray;
82 }
// Matrix case: the column count of the first operand equals the row count of the second.
83 if (shape1.cols == shape2.rows)
84 {
85 // 2D array, use matrix multiplication
// NOTE(review): listing line 86 is missing; it presumably declares returnArray — confirm.
// The transpose of inArray2 is taken once, outside the loops, and indexed by j below.
87 auto array2T = inArray2.transpose();
88
// One iteration per (i, j) pair: i ranges over shape1.rows, j over shape2.cols.
89 for (uint32 i = 0; i < shape1.rows; ++i)
90 {
91 for (uint32 j = 0; j < shape2.cols; ++j)
92 {
// NOTE(review): listing line 93 is missing; the lines below are the trailing arguments of
// a call that began there. Presumably a transform_reduce over the j-indexed range of
// array2T and the i-indexed range of inArray1, stored into returnArray at (i, j) — confirm.
94 array2T.cend(j),
95 inArray1.cbegin(i),
96 std::complex<dtype>{ 0 });
97 }
98 }
99
100 return returnArray;
101 }
102
// Neither case applied: assemble a message that reports both shapes as [rows, cols].
103 std::string errStr = "shapes of [" + utils::num2str(shape1.rows) + ", " + utils::num2str(shape1.cols) + "]";
104 errStr += " and [" + utils::num2str(shape2.rows) + ", " + utils::num2str(shape2.cols) + "]";
105 errStr += " are not consistent.";
// NOTE(review): listing line 106 is missing; it presumably raises an error carrying errStr
// (THROW_INVALID_ARGUMENT_ERROR appears in this page's cross-reference index) — confirm.
107
108 return NdArray<std::complex<dtype>>(); // get rid of compiler warning
109 }
110
111 //============================================================================
112 // Method Description:
// Dot-product overload that reads inArray1 and inArray2 and returns an
// NdArray<std::complex<dtype>>. The visible body is textually the same as the overload
// listed just before it in this file.
// NOTE(review): the signature (listing line 125) is missing from this extract, so the
// parameter types that distinguish this overload are not visible here — confirm against
// the original header.
124 template<typename dtype>
126 {
// NOTE(review): listing line 127 is missing from this extract.
128
129 const auto shape1 = inArray1.shape();
130 const auto shape2 = inArray2.shape();
131
// Vector case: both operands have the same shape, and that shape is a single row or a
// single column.
132 if (shape1 == shape2 && (shape1.rows == 1 || shape1.cols == 1))
133 {
// Combine the two element sequences into one complex value, starting from complex zero.
// Presumably an inner-product style reduction — verify against stl_algorithms::transform_reduce.
134 const std::complex<dtype> dotProduct = stl_algorithms::transform_reduce(inArray1.cbegin(),
135 inArray1.cend(),
136 inArray2.cbegin(),
137 std::complex<dtype>{ 0 });
// NOTE(review): listing line 138 is missing; it presumably builds returnArray from
// dotProduct — confirm.
139 return returnArray;
140 }
// Matrix case: the column count of the first operand equals the row count of the second.
141 if (shape1.cols == shape2.rows)
142 {
143 // 2D array, use matrix multiplication
// NOTE(review): listing line 144 is missing; it presumably declares returnArray — confirm.
// The transpose of inArray2 is taken once, outside the loops, and indexed by j below.
145 auto array2T = inArray2.transpose();
146
// One iteration per (i, j) pair: i ranges over shape1.rows, j over shape2.cols.
147 for (uint32 i = 0; i < shape1.rows; ++i)
148 {
149 for (uint32 j = 0; j < shape2.cols; ++j)
150 {
// NOTE(review): listing line 151 is missing; the lines below are the trailing arguments of
// a call that began there. Presumably a transform_reduce over the j-indexed range of
// array2T and the i-indexed range of inArray1, stored into returnArray at (i, j) — confirm.
152 array2T.cend(j),
153 inArray1.cbegin(i),
154 std::complex<dtype>{ 0 });
155 }
156 }
157
158 return returnArray;
159 }
160
// Neither case applied: assemble a message that reports both shapes as [rows, cols].
161 std::string errStr = "shapes of [" + utils::num2str(shape1.rows) + ", " + utils::num2str(shape1.cols) + "]";
162 errStr += " and [" + utils::num2str(shape2.rows) + ", " + utils::num2str(shape2.cols) + "]";
163 errStr += " are not consistent.";
// NOTE(review): listing line 164 is missing; it presumably raises an error carrying errStr
// (THROW_INVALID_ARGUMENT_ERROR appears in this page's cross-reference index) — confirm.
165
166 return NdArray<std::complex<dtype>>(); // get rid of compiler warning
167 }
168} // namespace nc
#define THROW_INVALID_ARGUMENT_ERROR(msg)
Definition Error.hpp:37
#define STATIC_ASSERT_ARITHMETIC(dtype)
Definition StaticAsserts.hpp:39
Holds 1D and 2D arrays, the main workhorse of the NumCpp library.
Definition NdArrayCore.hpp:139
self_type dot(const self_type &inOtherArray) const
Definition NdArrayCore.hpp:2795
T transform_reduce(ForwardIt1 first1, ForwardIt1 last1, ForwardIt2 first2, T init)
Definition StlAlgorithms.hpp:825
std::string num2str(dtype inNumber)
Definition num2str.hpp:44
Definition Cartesian.hpp:40
NdArray< dtype > dot(const NdArray< dtype > &inArray1, const NdArray< dtype > &inArray2)
Definition dot.hpp:48
NdArray< dtype > arange(dtype inStart, dtype inStop, dtype inStep=1)
Definition arange.hpp:59
std::uint32_t uint32
Definition Types.hpp:40