19 #if defined(HIGHWAY_HWY_CONTRIB_DOT_DOT_INL_H_) == \
20 defined(HWY_TARGET_TOGGLE)
21 #ifdef HIGHWAY_HWY_CONTRIB_DOT_DOT_INL_H_
22 #undef HIGHWAY_HWY_CONTRIB_DOT_DOT_INL_H_
24 #define HIGHWAY_HWY_CONTRIB_DOT_DOT_INL_H_
55 template <
int kAssumptions,
class D,
typename T = TFromD<D>,
56 HWY_IF_NOT_LANE_SIZE_D(D, 2)>
59 const size_t num_elements) {
60 static_assert(IsFloat<T>(),
"MulAdd requires float type");
61 using V = decltype(
Zero(
d));
66 constexpr
bool kIsAtLeastOneVector =
68 constexpr
bool kIsMultipleOfVector =
70 constexpr
bool kIsPaddedToVector = (kAssumptions &
kPaddedToVector) != 0;
75 if (!kIsAtLeastOneVector && !kIsMultipleOfVector && !kIsPaddedToVector &&
80 for (; i + 2 <= num_elements; i += 2) {
81 sum0 += pa[i + 0] * pb[i + 0];
82 sum1 += pa[i + 1] * pb[i + 1];
84 if (i < num_elements) {
85 sum1 += pa[i] * pb[i];
101 for (; i + 4 *
N <= num_elements; ) {
102 const auto a0 = kIsAlignedA ?
Load(
d, pa + i) :
LoadU(
d, pa + i);
103 const auto b0 = kIsAlignedB ?
Load(
d, pb + i) :
LoadU(
d, pb + i);
105 sum0 =
MulAdd(a0, b0, sum0);
106 const auto a1 = kIsAlignedA ?
Load(
d, pa + i) :
LoadU(
d, pa + i);
107 const auto b1 = kIsAlignedB ?
Load(
d, pb + i) :
LoadU(
d, pb + i);
109 sum1 =
MulAdd(a1, b1, sum1);
110 const auto a2 = kIsAlignedA ?
Load(
d, pa + i) :
LoadU(
d, pa + i);
111 const auto b2 = kIsAlignedB ?
Load(
d, pb + i) :
LoadU(
d, pb + i);
113 sum2 =
MulAdd(a2, b2, sum2);
114 const auto a3 = kIsAlignedA ?
Load(
d, pa + i) :
LoadU(
d, pa + i);
115 const auto b3 = kIsAlignedB ?
Load(
d, pb + i) :
LoadU(
d, pb + i);
117 sum3 =
MulAdd(a3, b3, sum3);
121 for (; i +
N <= num_elements; i +=
N) {
122 const auto a = kIsAlignedA ?
Load(
d, pa + i) :
LoadU(
d, pa + i);
123 const auto b = kIsAlignedB ?
Load(
d, pb + i) :
LoadU(
d, pb + i);
124 sum0 =
MulAdd(a, b, sum0);
127 if (!kIsMultipleOfVector) {
128 const size_t remaining = num_elements - i;
129 if (remaining != 0) {
130 if (kIsPaddedToVector) {
131 const auto mask =
FirstN(
d, remaining);
132 const auto a = kIsAlignedA ?
Load(
d, pa + i) :
LoadU(
d, pa + i);
133 const auto b = kIsAlignedB ?
Load(
d, pb + i) :
LoadU(
d, pb + i);
141 const auto skip =
FirstN(
d,
N - remaining);
142 const auto a =
LoadU(
d, pa + i);
143 const auto b =
LoadU(
d, pb + i);
150 sum0 =
Add(sum0, sum1);
151 sum2 =
Add(sum2, sum3);
152 sum0 =
Add(sum0, sum2);
157 template <
int kAssumptions,
class D>
161 const size_t num_elements) {
165 using V = decltype(
Zero(df32));
169 constexpr
bool kIsAtLeastOneVector =
171 constexpr
bool kIsMultipleOfVector =
173 constexpr
bool kIsPaddedToVector = (kAssumptions &
kPaddedToVector) != 0;
178 if (!kIsAtLeastOneVector && !kIsMultipleOfVector && !kIsPaddedToVector &&
182 for (; i + 2 <= num_elements; i += 2) {
186 if (i < num_elements) {
200 for (; i + 2 *
N <= num_elements; ) {
201 const auto a0 = kIsAlignedA ?
Load(
d, pa + i) :
LoadU(
d, pa + i);
202 const auto b0 = kIsAlignedB ?
Load(
d, pb + i) :
LoadU(
d, pb + i);
205 const auto a1 = kIsAlignedA ?
Load(
d, pa + i) :
LoadU(
d, pa + i);
206 const auto b1 = kIsAlignedB ?
Load(
d, pb + i) :
LoadU(
d, pb + i);
212 if (i +
N <= num_elements) {
213 const auto a0 = kIsAlignedA ?
Load(
d, pa + i) :
LoadU(
d, pa + i);
214 const auto b0 = kIsAlignedB ?
Load(
d, pb + i) :
LoadU(
d, pb + i);
219 if (!kIsMultipleOfVector) {
220 const size_t remaining = num_elements - i;
221 if (remaining != 0) {
222 if (kIsPaddedToVector) {
223 const auto mask =
FirstN(du16, remaining);
224 const auto va = kIsAlignedA ?
Load(
d, pa + i) :
LoadU(
d, pa + i);
225 const auto vb = kIsAlignedB ?
Load(
d, pb + i) :
LoadU(
d, pb + i);
236 const auto skip =
FirstN(du16,
N - remaining);
237 const auto va =
LoadU(
d, pa + i);
238 const auto vb =
LoadU(
d, pb + i);
247 sum0 =
Add(sum0, sum1);
248 sum2 =
Add(sum2, sum3);
249 sum0 =
Add(sum0, sum2);
#define HWY_RESTRICT
Definition: base.h:63
#define HWY_INLINE
Definition: base.h:64
#define HWY_DASSERT(condition)
Definition: base.h:193
#define HWY_UNLIKELY(expr)
Definition: base.h:69
d
Definition: rvv-inl.h:1656
HWY_API uint8_t GetLane(const Vec128< uint8_t, 16 > v)
Definition: arm_neon-inl.h:767
HWY_API Mask128< T, N > FirstN(const Simd< T, N, 0 > d, size_t num)
Definition: arm_neon-inl.h:1896
HWY_API Vec128< float, N > MulAdd(const Vec128< float, N > mul, const Vec128< float, N > x, const Vec128< float, N > add)
Definition: arm_neon-inl.h:1290
HWY_API Vec128< T, N > SumOfLanes(Simd< T, N, 0 >, const Vec128< T, N > v)
Definition: arm_neon-inl.h:4437
Rebind< MakeUnsigned< TFromD< D > >, D > RebindToUnsigned
Definition: ops/shared-inl.h:201
HWY_API Vec128< T, N > Load(Simd< T, N, 0 > d, const T *HWY_RESTRICT p)
Definition: arm_neon-inl.h:2205
HWY_API Vec128< T, N > Zero(Simd< T, N, 0 > d)
Definition: arm_neon-inl.h:733
HWY_API size_t Lanes(Simd< T, N, kPow2 > d)
Definition: arm_sve-inl.h:218
HWY_API Vec128< float, N > ReorderWidenMulAccumulate(Simd< float, N, 0 > df32, Vec128< bfloat16_t, 2 *N > a, Vec128< bfloat16_t, 2 *N > b, const Vec128< float, N > sum0, Vec128< float, N > &sum1)
Definition: arm_neon-inl.h:3688
HWY_API Vec128< T, N > IfThenElseZero(const Mask128< T, N > mask, const Vec128< T, N > yes)
Definition: arm_neon-inl.h:1711
HWY_API V Add(V a, V b)
Definition: arm_neon-inl.h:5217
HWY_API Vec128< uint8_t > LoadU(Full128< uint8_t >, const uint8_t *HWY_RESTRICT unaligned)
Definition: arm_neon-inl.h:2031
HWY_API Vec128< T, N > BitCast(Simd< T, N, 0 > d, Vec128< FromT, N *sizeof(T)/sizeof(FromT)> v)
Definition: arm_neon-inl.h:710
HWY_API Vec128< T, N > IfThenZeroElse(const Mask128< T, N > mask, const Vec128< T, N > no)
Definition: arm_neon-inl.h:1718
typename D::template Repartition< T > Repartition
Definition: ops/shared-inl.h:207
N
Definition: rvv-inl.h:1656
Definition: aligned_allocator.h:27
HWY_API float F32FromBF16(bfloat16_t bf)
Definition: base.h:746
#define HWY_NAMESPACE
Definition: set_macros-inl.h:80
static HWY_INLINE T Compute(const D d, const T *const HWY_RESTRICT pa, const T *const HWY_RESTRICT pb, const size_t num_elements)
Definition: dot-inl.h:57
static HWY_INLINE float Compute(const D d, const bfloat16_t *const HWY_RESTRICT pa, const bfloat16_t *const HWY_RESTRICT pb, const size_t num_elements)
Definition: dot-inl.h:158
Assumptions
Definition: dot-inl.h:37
@ kMultipleOfVector
Definition: dot-inl.h:42
@ kPaddedToVector
Definition: dot-inl.h:45
@ kVectorAlignedA
Definition: dot-inl.h:50
@ kAtLeastOneVector
Definition: dot-inl.h:39
@ kVectorAlignedB
Definition: dot-inl.h:51