49 template<
typename dtype>
58 auto returnArray =
exp(inArray).template astype<double>();
59 returnArray /=
static_cast<double>(returnArray.sum().item());
64 auto returnArray =
exp(inArray).template astype<double>();
65 auto expSums = returnArray.sum(inAxis);
67 for (
uint32 row = 0; row < returnArray.shape().rows; ++row)
69 const auto rowExpSum =
static_cast<double>(expSums[row]);
72 [rowExpSum](
double& value) { value /= rowExpSum; });
79 auto returnArray =
exp(inArray.
transpose()).template astype<double>();
80 auto expSums = returnArray.sum(
Axis::COL);
82 for (
uint32 row = 0; row < returnArray.shape().rows; ++row)
84 const auto rowExpSum =
static_cast<double>(expSums[row]);
87 [rowExpSum](
double& value) { value /= rowExpSum; });
90 return returnArray.transpose();
#define THROW_INVALID_ARGUMENT_ERROR(msg)
Definition: Error.hpp:37
#define STATIC_ASSERT_ARITHMETIC(dtype)
Definition: StaticAsserts.hpp:39
self_type transpose() const
Definition: NdArrayCore.hpp:4882
Definition: airy_ai.hpp:39
NdArray< double > softmax(const NdArray< dtype > &inArray, Axis inAxis=Axis::NONE)
Definition: softmax.hpp:50
void for_each(InputIt first, InputIt last, UnaryFunction f)
Definition: StlAlgorithms.hpp:225
Axis
Enum To describe an axis.
Definition: Enums.hpp:36
auto exp(dtype inValue) noexcept
Definition: exp.hpp:49
std::uint32_t uint32
Definition: Types.hpp:40