NumCpp  2.12.1
A Templatized Header Only C++ Implementation of the Python NumPy Library
dot.hpp
Go to the documentation of this file.
1
28#pragma once
29
30#include <complex>
31
32#include "NumCpp/NdArray.hpp"
33
34namespace nc
35{
36 //============================================================================
37 // Method Description:
46 template<typename dtype>
47 NdArray<dtype> dot(const NdArray<dtype>& inArray1, const NdArray<dtype>& inArray2)
48 {
49 return inArray1.dot(inArray2);
50 }
51
52 //============================================================================
53 // Method Description:
65 template<typename dtype>
66 NdArray<std::complex<dtype>> dot(const NdArray<dtype>& inArray1, const NdArray<std::complex<dtype>>& inArray2)
67 {
69
70 const auto shape1 = inArray1.shape();
71 const auto shape2 = inArray2.shape();
72
73 if (shape1 == shape2 && (shape1.rows == 1 || shape1.cols == 1))
74 {
75 const std::complex<dtype> dotProduct =
76 std::inner_product(inArray1.cbegin(), inArray1.cend(), inArray2.cbegin(), std::complex<dtype>{ 0 });
77 NdArray<std::complex<dtype>> returnArray = { dotProduct };
78 return returnArray;
79 }
80 if (shape1.cols == shape2.rows)
81 {
82 // 2D array, use matrix multiplication
83 NdArray<std::complex<dtype>> returnArray(shape1.rows, shape2.cols);
84 auto array2T = inArray2.transpose();
85
86 for (uint32 i = 0; i < shape1.rows; ++i)
87 {
88 for (uint32 j = 0; j < shape2.cols; ++j)
89 {
90 returnArray(i, j) = std::inner_product(array2T.cbegin(j),
91 array2T.cend(j),
92 inArray1.cbegin(i),
93 std::complex<dtype>{ 0 });
94 }
95 }
96
97 return returnArray;
98 }
99
100 std::string errStr = "shapes of [" + utils::num2str(shape1.rows) + ", " + utils::num2str(shape1.cols) + "]";
101 errStr += " and [" + utils::num2str(shape2.rows) + ", " + utils::num2str(shape2.cols) + "]";
102 errStr += " are not consistent.";
104
105 return NdArray<std::complex<dtype>>(); // get rid of compiler warning
106 }
107
108 //============================================================================
109 // Method Description:
121 template<typename dtype>
122 NdArray<std::complex<dtype>> dot(const NdArray<std::complex<dtype>>& inArray1, const NdArray<dtype>& inArray2)
123 {
125
126 const auto shape1 = inArray1.shape();
127 const auto shape2 = inArray2.shape();
128
129 if (shape1 == shape2 && (shape1.rows == 1 || shape1.cols == 1))
130 {
131 const std::complex<dtype> dotProduct =
132 std::inner_product(inArray1.cbegin(), inArray1.cend(), inArray2.cbegin(), std::complex<dtype>{ 0 });
133 NdArray<std::complex<dtype>> returnArray = { dotProduct };
134 return returnArray;
135 }
136 if (shape1.cols == shape2.rows)
137 {
138 // 2D array, use matrix multiplication
139 NdArray<std::complex<dtype>> returnArray(shape1.rows, shape2.cols);
140 auto array2T = inArray2.transpose();
141
142 for (uint32 i = 0; i < shape1.rows; ++i)
143 {
144 for (uint32 j = 0; j < shape2.cols; ++j)
145 {
146 returnArray(i, j) = std::inner_product(array2T.cbegin(j),
147 array2T.cend(j),
148 inArray1.cbegin(i),
149 std::complex<dtype>{ 0 });
150 }
151 }
152
153 return returnArray;
154 }
155
156 std::string errStr = "shapes of [" + utils::num2str(shape1.rows) + ", " + utils::num2str(shape1.cols) + "]";
157 errStr += " and [" + utils::num2str(shape2.rows) + ", " + utils::num2str(shape2.cols) + "]";
158 errStr += " are not consistent.";
160
161 return NdArray<std::complex<dtype>>(); // get rid of compiler warning
162 }
163} // namespace nc
#define THROW_INVALID_ARGUMENT_ERROR(msg)
Definition: Error.hpp:37
#define STATIC_ASSERT_ARITHMETIC(dtype)
Definition: StaticAsserts.hpp:39
Holds 1D and 2D arrays, the main work horse of the NumCpp library.
Definition: NdArrayCore.hpp:139
const_iterator cbegin() const noexcept
Definition: NdArrayCore.hpp:1365
self_type transpose() const
Definition: NdArrayCore.hpp:4882
self_type dot(const self_type &inOtherArray) const
Definition: NdArrayCore.hpp:2719
const Shape & shape() const noexcept
Definition: NdArrayCore.hpp:4511
const_iterator cend() const noexcept
Definition: NdArrayCore.hpp:1673
constexpr auto j
Definition: Core/Constants.hpp:42
std::string num2str(dtype inNumber)
Definition: num2str.hpp:44
Definition: Cartesian.hpp:40
NdArray< dtype > dot(const NdArray< dtype > &inArray1, const NdArray< dtype > &inArray2)
Definition: dot.hpp:47
std::uint32_t uint32
Definition: Types.hpp:40