NumCpp  2.12.1
A Templatized Header Only C++ Implementation of the Python NumPy Library
softmax.hpp
Go to the documentation of this file.
1
28#pragma once
29
32#include "NumCpp/Core/Types.hpp"
34#include "NumCpp/NdArray.hpp"
35
36namespace nc::special
37{
38 //============================================================================
39 // Method Description:
49 template<typename dtype>
51 {
53
54 switch (inAxis)
55 {
56 case Axis::NONE:
57 {
58 auto returnArray = exp(inArray).template astype<double>();
59 returnArray /= static_cast<double>(returnArray.sum().item());
60 return returnArray;
61 }
62 case Axis::COL:
63 {
64 auto returnArray = exp(inArray).template astype<double>();
65 auto expSums = returnArray.sum(inAxis);
66
67 for (uint32 row = 0; row < returnArray.shape().rows; ++row)
68 {
69 const auto rowExpSum = static_cast<double>(expSums[row]);
70 stl_algorithms::for_each(returnArray.begin(row),
71 returnArray.end(row),
72 [rowExpSum](double& value) { value /= rowExpSum; });
73 }
74
75 return returnArray;
76 }
77 case Axis::ROW:
78 {
79 auto returnArray = exp(inArray.transpose()).template astype<double>();
80 auto expSums = returnArray.sum(Axis::COL);
81
82 for (uint32 row = 0; row < returnArray.shape().rows; ++row)
83 {
84 const auto rowExpSum = static_cast<double>(expSums[row]);
85 stl_algorithms::for_each(returnArray.begin(row),
86 returnArray.end(row),
87 [rowExpSum](double& value) { value /= rowExpSum; });
88 }
89
90 return returnArray.transpose();
91 }
92 default:
93 {
94 THROW_INVALID_ARGUMENT_ERROR("Unimplemented axis type.");
95 return {}; // get rid of compiler warning
96 }
97 }
98 }
99} // namespace nc::special
#define THROW_INVALID_ARGUMENT_ERROR(msg)
Definition: Error.hpp:37
#define STATIC_ASSERT_ARITHMETIC(dtype)
Definition: StaticAsserts.hpp:39
self_type transpose() const
Definition: NdArrayCore.hpp:4882
Definition: airy_ai.hpp:39
NdArray< double > softmax(const NdArray< dtype > &inArray, Axis inAxis=Axis::NONE)
Definition: softmax.hpp:50
void for_each(InputIt first, InputIt last, UnaryFunction f)
Definition: StlAlgorithms.hpp:225
Axis
Enum To describe an axis.
Definition: Enums.hpp:36
auto exp(dtype inValue) noexcept
Definition: exp.hpp:49
std::uint32_t uint32
Definition: Types.hpp:40