NumCpp  2.12.1
A Templatized Header Only C++ Implementation of the Python NumPy Library
NdArrayBroadcast.hpp
Go to the documentation of this file.
1
28#pragma once
29
30#include <cmath>
31#include <utility>
32
35#include "NumCpp/Core/Types.hpp"
37
39{
40 //============================================================================
41 // Method Description:
// In-place broadcast: applies `function` element-wise to inArray1 and
// inArray2, writes every result back into inArray1, and returns inArray1.
// Each element lambda declares a return type of dtypeIn1, so every result is
// converted to inArray1's element type before being stored.
// Accepted operand combinations, checked in this order:
//   1. equal shapes    -> function(inArray1 elem, matching inArray2 elem)
//   2. inArray2 scalar -> function(inArray1 elem, inArray2.item())
//   3. inArray2 flat with numRows() > 1 equal to inArray1.numRows()
//                      -> every element of row r of inArray1 is combined
//                         with inArray2[r]
//   4. inArray2 flat with numCols() > 1 equal to inArray1.numCols()
//                      -> every element of column c of inArray1 is
//                         combined with inArray2[c]
// Any other combination is reported through THROW_INVALID_ARGUMENT_ERROR
// with "operands could not be broadcast together".
// The inArray1 element is always the FIRST argument passed to `function`;
// additionalFunctionArgs are appended after the two element arguments on
// every call.
// NOTE(review): this listing skips some header lines (the embedded line
// numbers jump, e.g. 51 -> 53 and 58 -> 60). Presumably these are the
// signature line and the heads of the transform(...) calls, so the text
// below is not compilable as shown. Confirm against NdArrayBroadcast.hpp.
// NOTE(review): `const AdditionalFunctionArgs&&...` is not a forwarding
// reference. std::forward<AdditionalFunctionArgs>(...) applied to the
// resulting const lvalues looks like it would not compile for non-reference
// argument types, so confirm whether any caller actually passes extra
// arguments. Even if it compiles, forwarding the same pack once per element
// could move from the same objects repeatedly.
51 template<typename dtypeIn1, typename dtypeIn2, typename Function, typename... AdditionalFunctionArgs>
53 const NdArray<dtypeIn2>& inArray2,
54 const Function& function,
55 const AdditionalFunctionArgs&&... additionalFunctionArgs)
56 {
57 if (inArray1.shape() == inArray2.shape())
58 {
60 inArray1.cbegin(),
61 inArray1.cend(),
62 inArray2.cbegin(),
63 inArray1.begin(),
64 [&function, &additionalFunctionArgs...](const auto& inValue1, const auto& inValue2) -> dtypeIn1 {
65 return function(inValue1,
66 inValue2,
67 std::forward<AdditionalFunctionArgs>(additionalFunctionArgs)...);
68 });
69 }
70 else if (inArray2.isscalar())
71 {
// Single value from inArray2, applied to every element of inArray1.
72 const auto value = inArray2.item();
74 inArray1.cbegin(),
75 inArray1.cend(),
76 inArray1.begin(),
77 [&value, &function, &additionalFunctionArgs...](const auto& inValue) -> dtypeIn1
78 { return function(inValue, value, std::forward<AdditionalFunctionArgs>(additionalFunctionArgs)...); });
79 }
80 else if (inArray2.isflat())
81 {
// inArray2 has more than one row (presumably a column vector -- confirm
// isflat() semantics) with one value per row of inArray1.
82 if (inArray2.numRows() > 1 && inArray2.numRows() == inArray1.numRows())
83 {
84 for (uint32 row = 0; row < inArray1.numRows(); ++row)
85 {
86 const auto value = inArray2[row];
88 inArray1.cbegin(row),
89 inArray1.cend(row),
90 inArray1.begin(row),
91 [&value, &function, &additionalFunctionArgs...](const auto& inValue) -> dtypeIn1 {
92 return function(inValue,
93 value,
94 std::forward<AdditionalFunctionArgs>(additionalFunctionArgs)...);
95 });
96 }
97 }
// inArray2 has more than one column with one value per column of inArray1;
// column iterators walk each column of inArray1.
98 else if (inArray2.numCols() > 1 && inArray2.numCols() == inArray1.numCols())
99 {
100 for (uint32 col = 0; col < inArray1.numCols(); ++col)
101 {
102 const auto value = inArray2[col];
104 inArray1.ccolbegin(col),
105 inArray1.ccolend(col),
106 inArray1.colbegin(col),
107 [&value, &function, &additionalFunctionArgs...](const auto& inValue) -> dtypeIn1 {
108 return function(inValue,
109 value,
110 std::forward<AdditionalFunctionArgs>(additionalFunctionArgs)...);
111 });
112 }
113 }
114 else
115 {
116 THROW_INVALID_ARGUMENT_ERROR("operands could not be broadcast together");
117 }
118 }
119 else
120 {
121 THROW_INVALID_ARGUMENT_ERROR("operands could not be broadcast together");
122 }
123
124 return inArray1;
125 }
126
127 //============================================================================
128 // Method Description:
// Out-of-place broadcast: returns a new NdArray<dtypeOut> holding
// `function` applied to element pairs drawn from inArray1 and inArray2.
// The inputs are only read, never written. Operand combinations, checked
// in this order:
//   1. equal shapes       -> element-wise; result has inArray1's shape
//   2. inArray1 scalar    -> result has inArray2's shape
//   3. inArray2 scalar    -> result has inArray1's shape
//   4. both flat          -> result is max(rows) x max(cols), and each cell
//                            combines one element of each operand (an
//                            outer-product layout when one operand is a
//                            column and the other a row)
//   5. only inArray1 flat -> recurses with the two operands swapped
//   6. only inArray2 flat, numRows() > 1 and equal to inArray1.numRows()
//                         -> row r of inArray1 is combined with inArray2[r]
//   7. only inArray2 flat, numCols() > 1 and equal to inArray1.numCols()
//                         -> both operands are transposed, the function
//                            recurses, and the result is transposed back
// Any other combination is reported through THROW_INVALID_ARGUMENT_ERROR
// with "operands could not be broadcast together". additionalFunctionArgs
// are appended after the two element arguments on every call.
// NOTE(review): cases 2 and 5 end up calling
// function(inArray2 elem, inArray1 elem), which reverses the argument order
// of case 1. For a non-commutative `function` (e.g. subtraction) this
// appears to compute rhs OP lhs instead of lhs OP rhs. Looks like a bug;
// confirm against callers.
// NOTE(review): case 4 does not reject two row vectors (or two column
// vectors) of different lengths. The loops then write only part of
// returnArray and pair elements as if the operands were orthogonal, with
// no error raised. This probably should report "operands could not be
// broadcast together".
// NOTE(review): `const AdditionalFunctionArgs&&...` is not a forwarding
// reference. std::forward<AdditionalFunctionArgs>(...) applied to the
// resulting const lvalues looks like it would not compile for non-reference
// argument types. Confirm whether extra arguments are ever passed.
// NOTE(review): this listing skips some header lines (the embedded line
// numbers jump, e.g. 142 -> 144 and 152 -> 154). Presumably these are the
// signature line and the heads of the transform(...) calls.
138 template<typename dtypeOut,
139 typename dtypeIn1,
140 typename dtypeIn2,
141 typename Function,
142 typename... AdditionalFunctionArgs>
144 const NdArray<dtypeIn2>& inArray2,
145 const Function& function,
146 const AdditionalFunctionArgs&&... additionalFunctionArgs)
147 {
148 if (inArray1.shape() == inArray2.shape())
149 {
150 return [&inArray1, &inArray2, &function, &additionalFunctionArgs...]
151 {
152 NdArray<dtypeOut> returnArray(inArray1.shape());
154 inArray1.cbegin(),
155 inArray1.cend(),
156 inArray2.cbegin(),
157 returnArray.begin(),
158 [&function, &additionalFunctionArgs...](const auto& inValue1, const auto& inValue2) -> dtypeOut {
159 return function(inValue1,
160 inValue2,
161 std::forward<AdditionalFunctionArgs>(additionalFunctionArgs)...);
162 });
163
164 return returnArray;
165 }();
166 }
// NOTE(review): in this branch the inArray2 element is passed FIRST and the
// inArray1 scalar SECOND, the reverse of the equal-shape branch above.
167 else if (inArray1.isscalar())
168 {
169 const auto value = inArray1.item();
170 return [&inArray2, &value, &function, &additionalFunctionArgs...]
171 {
172 NdArray<dtypeOut> returnArray(inArray2.shape());
174 inArray2.cbegin(),
175 inArray2.cend(),
176 returnArray.begin(),
177 [&value, &function, &additionalFunctionArgs...](const auto& inValue) -> dtypeOut {
178 return function(inValue,
179 value,
180 std::forward<AdditionalFunctionArgs>(additionalFunctionArgs)...);
181 });
182 return returnArray;
183 }();
184 }
185 else if (inArray2.isscalar())
186 {
187 const auto value = inArray2.item();
188 return [&inArray1, &value, &function, &additionalFunctionArgs...]
189 {
190 NdArray<dtypeOut> returnArray(inArray1.shape());
192 inArray1.cbegin(),
193 inArray1.cend(),
194 returnArray.begin(),
195 [&value, &function, &additionalFunctionArgs...](const auto& inValue) -> dtypeOut {
196 return function(inValue,
197 value,
198 std::forward<AdditionalFunctionArgs>(additionalFunctionArgs)...);
199 });
200 return returnArray;
201 }();
202 }
203 else if (inArray1.isflat() && inArray2.isflat())
204 {
205 return [&inArray1, &inArray2, &function, &additionalFunctionArgs...]
206 {
207 const auto numRows = std::max(inArray1.numRows(), inArray2.numRows());
208 const auto numCols = std::max(inArray1.numCols(), inArray2.numCols());
209 NdArray<dtypeOut> returnArray(numRows, numCols);
// Assumes (without checking) that a multi-row inArray1 is paired with a
// row-shaped inArray2, and vice versa in the else branch.
210 if (inArray1.numRows() > 1)
211 {
212 for (uint32 row = 0; row < inArray1.numRows(); ++row)
213 {
214 for (uint32 col = 0; col < inArray2.numCols(); ++col)
215 {
216 returnArray(row, col) =
217 function(inArray1[row],
218 inArray2[col],
219 std::forward<AdditionalFunctionArgs>(additionalFunctionArgs)...);
220 }
221 }
222 }
223 else
224 {
225 for (uint32 row = 0; row < inArray2.numRows(); ++row)
226 {
227 for (uint32 col = 0; col < inArray1.numCols(); ++col)
228 {
229 returnArray(row, col) =
230 function(inArray1[col],
231 inArray2[row],
232 std::forward<AdditionalFunctionArgs>(additionalFunctionArgs)...);
233 }
234 }
235 }
236 return returnArray;
237 }();
238 }
// NOTE(review): swapping the operands here means the recursive call passes
// the inArray2 element as the first argument to `function`.
239 else if (inArray1.isflat())
240 {
241 return broadcaster<dtypeOut>(inArray2,
242 inArray1,
243 function,
244 std::forward<AdditionalFunctionArgs>(additionalFunctionArgs)...);
245 }
246 else if (inArray2.isflat())
247 {
248 if (inArray2.numRows() > 1 && inArray2.numRows() == inArray1.numRows())
249 {
250 return [&inArray1, &inArray2, &function, &additionalFunctionArgs...]
251 {
252 NdArray<dtypeOut> returnArray(inArray1.shape());
253 for (uint32 row = 0; row < inArray1.numRows(); ++row)
254 {
255 const auto value = inArray2[row];
257 inArray1.cbegin(row),
258 inArray1.cend(row),
259 returnArray.begin(row),
260 [&value, &function, &additionalFunctionArgs...](const auto& inValue) -> dtypeOut {
261 return function(inValue,
262 value,
263 std::forward<AdditionalFunctionArgs>(additionalFunctionArgs)...);
264 });
265 }
266 return returnArray;
267 }();
268 }
// Column case is reduced to the row case above by transposing both
// operands, then transposing the result back.
269 else if (inArray2.numCols() > 1 && inArray2.numCols() == inArray1.numCols())
270 {
271 return broadcaster<dtypeOut>(inArray1.transpose(),
272 inArray2.transpose(),
273 function,
274 std::forward<AdditionalFunctionArgs>(additionalFunctionArgs)...)
275 .transpose();
276 }
277 else
278 {
279 THROW_INVALID_ARGUMENT_ERROR("operands could not be broadcast together");
280 }
281 }
282 else
283 {
284 THROW_INVALID_ARGUMENT_ERROR("operands could not be broadcast together");
285 }
286
287 return {}; // get rid of compiler warning
288 }
289} // namespace nc::broadcast
#define THROW_INVALID_ARGUMENT_ERROR(msg)
Definition: Error.hpp:37
Holds 1D and 2D arrays, the main work horse of the NumCpp library.
Definition: NdArrayCore.hpp:139
const_iterator cbegin() const noexcept
Definition: NdArrayCore.hpp:1365
const_column_iterator ccolbegin() const noexcept
Definition: NdArrayCore.hpp:1442
self_type transpose() const
Definition: NdArrayCore.hpp:4882
bool isflat() const noexcept
Definition: NdArrayCore.hpp:2945
size_type numCols() const noexcept
Definition: NdArrayCore.hpp:3465
column_iterator colbegin() noexcept
Definition: NdArrayCore.hpp:1392
const Shape & shape() const noexcept
Definition: NdArrayCore.hpp:4511
size_type numRows() const noexcept
Definition: NdArrayCore.hpp:3477
const_iterator cend() const noexcept
Definition: NdArrayCore.hpp:1673
iterator begin() noexcept
Definition: NdArrayCore.hpp:1315
bool isscalar() const noexcept
Definition: NdArrayCore.hpp:2956
value_type item() const
Definition: NdArrayCore.hpp:3022
const_column_iterator ccolend() const noexcept
Definition: NdArrayCore.hpp:1827
Definition: NdArrayBroadcast.hpp:39
NdArray< dtypeIn1 > & broadcaster(NdArray< dtypeIn1 > &inArray1, const NdArray< dtypeIn2 > &inArray2, const Function &function, const AdditionalFunctionArgs &&... additionalFunctionArgs)
Definition: NdArrayBroadcast.hpp:52
OutputIt transform(InputIt first, InputIt last, OutputIt destination, UnaryOperation unaryFunction)
Definition: StlAlgorithms.hpp:775
std::uint32_t uint32
Definition: Types.hpp:40
NdArray< dtype > max(const NdArray< dtype > &inArray, Axis inAxis=Axis::NONE)
Definition: max.hpp:44