NumCpp  2.12.1
A Templatized Header Only C++ Implementation of the Python NumPy Library
cross.hpp
Go to the documentation of this file.
1
#pragma once

#include <string>

#include "NumCpp/Core/Internal/Error.hpp"
#include "NumCpp/Core/Internal/StaticAsserts.hpp"
#include "NumCpp/Core/Shape.hpp"
#include "NumCpp/Core/Types.hpp"
#include "NumCpp/NdArray.hpp"
38namespace nc
39{
40 //============================================================================
41 // Method Description:
51 template<typename dtype>
52 NdArray<dtype> cross(const NdArray<dtype>& inArray1, const NdArray<dtype>& inArray2, Axis inAxis = Axis::NONE)
53 {
55
56 if (inArray1.shape() != inArray2.shape())
57 {
58 THROW_INVALID_ARGUMENT_ERROR("the input array dimensions are not consistant.");
59 }
60
61 switch (inAxis)
62 {
63 case Axis::NONE:
64 {
65 const uint32 arraySize = inArray1.size();
66 if (arraySize != inArray2.size() || arraySize < 2 || arraySize > 3)
67 {
69 "incompatible dimensions for cross product (dimension must be 2 or 3)");
70 }
71
72 NdArray<dtype> in1 = inArray1.flatten();
73 NdArray<dtype> in2 = inArray2.flatten();
74
75 switch (arraySize)
76 {
77 case 2:
78 {
79 NdArray<dtype> returnArray = { in1[0] * in2[1] - in1[1] * in2[0] };
80 return returnArray;
81 }
82 case 3:
83 {
84 dtype i = in1[1] * in2[2] - in1[2] * in2[1];
85 dtype j = -(in1[0] * in2[2] - in1[2] * in2[0]);
86 dtype k = in1[0] * in2[1] - in1[1] * in2[0];
87
88 NdArray<dtype> returnArray = { i, j, k };
89 return returnArray;
90 }
91 default:
92 {
93 THROW_INVALID_ARGUMENT_ERROR("Unimplemented array size.");
94 return {}; // get rid of compiler warning
95 }
96 }
97 }
98 case Axis::ROW:
99 {
100 const Shape arrayShape = inArray1.shape();
101 if (arrayShape != inArray2.shape() || arrayShape.rows < 2 || arrayShape.rows > 3)
102 {
104 "incompatible dimensions for cross product (dimension must be 2 or 3)");
105 }
106
107 Shape returnArrayShape;
108 returnArrayShape.cols = arrayShape.cols;
109 if (arrayShape.rows == 2)
110 {
111 returnArrayShape.rows = 1;
112 }
113 else
114 {
115 returnArrayShape.rows = 3;
116 }
117
118 NdArray<dtype> returnArray(returnArrayShape);
119 for (uint32 col = 0; col < arrayShape.cols; ++col)
120 {
121 const auto theCol = static_cast<int32>(col);
122 NdArray<dtype> vec1 = inArray1(inArray1.rSlice(), { theCol, theCol + 1 });
123 NdArray<dtype> vec2 = inArray2(inArray2.rSlice(), { theCol, theCol + 1 });
124 NdArray<dtype> vecCross = cross(vec1, vec2, Axis::NONE);
125
126 returnArray.put({ 0, static_cast<int32>(returnArrayShape.rows) }, { theCol, theCol + 1 }, vecCross);
127 }
128
129 return returnArray;
130 }
131 case Axis::COL:
132 {
133 const Shape arrayShape = inArray1.shape();
134 if (arrayShape != inArray2.shape() || arrayShape.cols < 2 || arrayShape.cols > 3)
135 {
137 "incompatible dimensions for cross product (dimension must be 2 or 3)");
138 }
139
140 Shape returnArrayShape;
141 returnArrayShape.rows = arrayShape.rows;
142 if (arrayShape.cols == 2)
143 {
144 returnArrayShape.cols = 1;
145 }
146 else
147 {
148 returnArrayShape.cols = 3;
149 }
150
151 NdArray<dtype> returnArray(returnArrayShape);
152 for (uint32 row = 0; row < arrayShape.rows; ++row)
153 {
154 const auto theRow = static_cast<int32>(row);
155 NdArray<dtype> vec1 = inArray1({ theRow, theRow + 1 }, inArray1.cSlice());
156 NdArray<dtype> vec2 = inArray2({ theRow, theRow + 1 }, inArray2.cSlice());
157 NdArray<dtype> vecCross = cross(vec1, vec2, Axis::NONE);
158
159 returnArray.put({ theRow, theRow + 1 }, { 0, static_cast<int32>(returnArrayShape.cols) }, vecCross);
160 }
161
162 return returnArray;
163 }
164 default:
165 {
166 THROW_INVALID_ARGUMENT_ERROR("Unimplemented axis type.");
167 return {}; // get rid of compiler warning
168 }
169 }
170 }
171} // namespace nc
#define THROW_INVALID_ARGUMENT_ERROR(msg)
Definition: Error.hpp:37
#define STATIC_ASSERT_ARITHMETIC_OR_COMPLEX(dtype)
Definition: StaticAsserts.hpp:56
Holds 1D and 2D arrays, the main work horse of the NumCpp library.
Definition: NdArrayCore.hpp:139
size_type size() const noexcept
Definition: NdArrayCore.hpp:4524
const Shape & shape() const noexcept
Definition: NdArrayCore.hpp:4511
self_type flatten() const
Definition: NdArrayCore.hpp:2847
Slice rSlice(index_type inStartIdx=0, size_type inStepSize=1) const
Definition: NdArrayCore.hpp:1022
Slice cSlice(index_type inStartIdx=0, size_type inStepSize=1) const
Definition: NdArrayCore.hpp:1008
self_type & put(index_type inIndex, const value_type &inValue)
Definition: NdArrayCore.hpp:3693
A Shape Class for NdArrays.
Definition: Core/Shape.hpp:41
uint32 rows
Definition: Core/Shape.hpp:44
uint32 cols
Definition: Core/Shape.hpp:45
constexpr auto j
Definition: Core/Constants.hpp:42
Definition: Cartesian.hpp:40
Axis
Enum To describe an axis.
Definition: Enums.hpp:36
std::int32_t int32
Definition: Types.hpp:36
NdArray< dtype > cross(const NdArray< dtype > &inArray1, const NdArray< dtype > &inArray2, Axis inAxis=Axis::NONE)
Definition: cross.hpp:52
std::uint32_t uint32
Definition: Types.hpp:40