dot-inl.h
// Copyright 2021 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <cmath>

// Include guard (still compiled once per target)
#if defined(HIGHWAY_HWY_CONTRIB_DOT_DOT_INL_H_) == \
    defined(HWY_TARGET_TOGGLE)
#ifdef HIGHWAY_HWY_CONTRIB_DOT_DOT_INL_H_
#undef HIGHWAY_HWY_CONTRIB_DOT_DOT_INL_H_
#else
#define HIGHWAY_HWY_CONTRIB_DOT_DOT_INL_H_
#endif

#include "hwy/highway.h"

HWY_BEFORE_NAMESPACE();
namespace hwy {
namespace HWY_NAMESPACE {

struct Dot {
  // Specify zero or more of these, ORed together, as the kAssumptions template
  // argument to Compute. Each one may improve performance or reduce code size,
  // at the cost of additional requirements on the arguments. See the example
  // after the enum.
  enum Assumptions {
    // num_elements is at least N, which may be up to HWY_MAX_LANES(T).
    kAtLeastOneVector = 1,
    // num_elements is divisible by N (a power of two, so this can be used if
    // the problem size is known to be a power of two >= HWY_MAX_LANES(T)).
    kMultipleOfVector = 2,
    // RoundUpTo(num_elements, N) elements are accessible; their value does not
    // matter (will be treated as if they were zero).
    kPaddedToVector = 4,
    // Pointers pa and pb, respectively, are multiples of N * sizeof(T).
    // For example, aligned_allocator.h ensures this. Note that it is still
    // beneficial to ensure such alignment even if these flags are not set.
    // If not set, the pointers need only be aligned to alignof(T).
    kVectorAlignedA = 8,
    kVectorAlignedB = 16,
  };
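
  // Example (sketch; `d`, `pa`, `pb` and `num` are hypothetical caller
  // variables): for padded, vector-aligned inputs a caller could write
  //   constexpr int kAssume = Dot::kPaddedToVector | Dot::kVectorAlignedA |
  //                           Dot::kVectorAlignedB;
  //   const float dot = Dot::Compute<kAssume>(d, pa, pb, num);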

  // Returns sum{pa[i] * pb[i]} for float or double inputs.
  template <int kAssumptions, class D, typename T = TFromD<D>,
            HWY_IF_NOT_LANE_SIZE_D(D, 2)>
  static HWY_INLINE T Compute(const D d, const T* const HWY_RESTRICT pa,
                              const T* const HWY_RESTRICT pb,
                              const size_t num_elements) {
    static_assert(IsFloat<T>(), "MulAdd requires float type");
    using V = decltype(Zero(d));

    const size_t N = Lanes(d);
    size_t i = 0;

    constexpr bool kIsAtLeastOneVector =
        (kAssumptions & kAtLeastOneVector) != 0;
    constexpr bool kIsMultipleOfVector =
        (kAssumptions & kMultipleOfVector) != 0;
    constexpr bool kIsPaddedToVector = (kAssumptions & kPaddedToVector) != 0;
    constexpr bool kIsAlignedA = (kAssumptions & kVectorAlignedA) != 0;
    constexpr bool kIsAlignedB = (kAssumptions & kVectorAlignedB) != 0;

    // Won't be able to do a full vector load without padding => scalar loop.
    if (!kIsAtLeastOneVector && !kIsMultipleOfVector && !kIsPaddedToVector &&
        HWY_UNLIKELY(num_elements < N)) {
      // Only 2x unroll to avoid excessive code size.
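      // Two accumulators so consecutive adds do not serialize on FP latency.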
      T sum0 = T(0);
      T sum1 = T(0);
      for (; i + 2 <= num_elements; i += 2) {
        sum0 += pa[i + 0] * pb[i + 0];
        sum1 += pa[i + 1] * pb[i + 1];
      }
      if (i < num_elements) {
        sum1 += pa[i] * pb[i];
      }
      return sum0 + sum1;
    }

    // Compiler doesn't make independent sum* accumulators, so unroll manually.
    // 2 FMA ports * 4 cycle latency = up to 8 in-flight, but that is excessive
    // for unaligned inputs (each unaligned pointer halves the throughput
    // because it occupies both L1 load ports for a cycle). We cannot have
    // arrays of vectors on RVV/SVE, so always unroll 4x.
    V sum0 = Zero(d);
    V sum1 = Zero(d);
    V sum2 = Zero(d);
    V sum3 = Zero(d);

    // Main loop: unrolled
    for (; i + 4 * N <= num_elements; /* i += 4 * N */) {  // incr in loop
      const auto a0 = kIsAlignedA ? Load(d, pa + i) : LoadU(d, pa + i);
      const auto b0 = kIsAlignedB ? Load(d, pb + i) : LoadU(d, pb + i);
      i += N;
      sum0 = MulAdd(a0, b0, sum0);
      const auto a1 = kIsAlignedA ? Load(d, pa + i) : LoadU(d, pa + i);
      const auto b1 = kIsAlignedB ? Load(d, pb + i) : LoadU(d, pb + i);
      i += N;
      sum1 = MulAdd(a1, b1, sum1);
      const auto a2 = kIsAlignedA ? Load(d, pa + i) : LoadU(d, pa + i);
      const auto b2 = kIsAlignedB ? Load(d, pb + i) : LoadU(d, pb + i);
      i += N;
      sum2 = MulAdd(a2, b2, sum2);
      const auto a3 = kIsAlignedA ? Load(d, pa + i) : LoadU(d, pa + i);
      const auto b3 = kIsAlignedB ? Load(d, pb + i) : LoadU(d, pb + i);
      i += N;
      sum3 = MulAdd(a3, b3, sum3);
    }

    // Up to 3 iterations of whole vectors
    for (; i + N <= num_elements; i += N) {
      const auto a = kIsAlignedA ? Load(d, pa + i) : LoadU(d, pa + i);
      const auto b = kIsAlignedB ? Load(d, pb + i) : LoadU(d, pb + i);
      sum0 = MulAdd(a, b, sum0);
    }

    if (!kIsMultipleOfVector) {
      const size_t remaining = num_elements - i;
      if (remaining != 0) {
        if (kIsPaddedToVector) {
          const auto mask = FirstN(d, remaining);
          const auto a = kIsAlignedA ? Load(d, pa + i) : LoadU(d, pa + i);
          const auto b = kIsAlignedB ? Load(d, pb + i) : LoadU(d, pb + i);
          sum1 = MulAdd(IfThenElseZero(mask, a), IfThenElseZero(mask, b), sum1);
        } else {
          // Unaligned load such that the last element is in the highest lane -
          // ensures we do not touch any elements outside the valid range.
          // If we get here, then num_elements >= N.
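          // Example: N = 4, remaining = 3: step i back by 1 so the load ends
          // exactly at num_elements, and `skip` zeroes the one lane that was
          // already accumulated above.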
          HWY_DASSERT(i >= N);
          i += remaining - N;
          const auto skip = FirstN(d, N - remaining);
          const auto a = LoadU(d, pa + i);  // always unaligned
          const auto b = LoadU(d, pb + i);
          sum1 = MulAdd(IfThenZeroElse(skip, a), IfThenZeroElse(skip, b), sum1);
        }
      }
    }  // kMultipleOfVector

    // Reduction tree: sum of all accumulators by pairs, then across lanes.
    sum0 = Add(sum0, sum1);
    sum2 = Add(sum2, sum3);
    sum0 = Add(sum0, sum2);
    return GetLane(SumOfLanes(d, sum0));
  }

  // Returns sum{pa[i] * pb[i]} for bfloat16 inputs.
  template <int kAssumptions, class D>
  static HWY_INLINE float Compute(const D d,
                                  const bfloat16_t* const HWY_RESTRICT pa,
                                  const bfloat16_t* const HWY_RESTRICT pb,
                                  const size_t num_elements) {
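    // du16 carries the masks (mask ops for bfloat16 are not generally
    // available); df32 holds the f32 accumulators, with half as many lanes.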
    const RebindToUnsigned<D> du16;
    const Repartition<float, D> df32;

    using V = decltype(Zero(df32));
    const size_t N = Lanes(d);
    size_t i = 0;

    constexpr bool kIsAtLeastOneVector =
        (kAssumptions & kAtLeastOneVector) != 0;
    constexpr bool kIsMultipleOfVector =
        (kAssumptions & kMultipleOfVector) != 0;
    constexpr bool kIsPaddedToVector = (kAssumptions & kPaddedToVector) != 0;
    constexpr bool kIsAlignedA = (kAssumptions & kVectorAlignedA) != 0;
    constexpr bool kIsAlignedB = (kAssumptions & kVectorAlignedB) != 0;

    // Won't be able to do a full vector load without padding => scalar loop.
    if (!kIsAtLeastOneVector && !kIsMultipleOfVector && !kIsPaddedToVector &&
        HWY_UNLIKELY(num_elements < N)) {
      float sum0 = 0.0f;  // Only 2x unroll to avoid excessive code size for..
      float sum1 = 0.0f;  // this unlikely(?) case.
      for (; i + 2 <= num_elements; i += 2) {
        sum0 += F32FromBF16(pa[i + 0]) * F32FromBF16(pb[i + 0]);
        sum1 += F32FromBF16(pa[i + 1]) * F32FromBF16(pb[i + 1]);
      }
      if (i < num_elements) {
        sum1 += F32FromBF16(pa[i]) * F32FromBF16(pb[i]);
      }
      return sum0 + sum1;
    }

    // See comment in the other Compute() overload. Unroll 2x, but we need
    // twice as many sums for ReorderWidenMulAccumulate.
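    // Each ReorderWidenMulAccumulate call widens bf16 pairs to f32 and
    // accumulates into two sums in an unspecified lane order, which is fine
    // because only the total across all lanes is returned.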
    V sum0 = Zero(df32);
    V sum1 = Zero(df32);
    V sum2 = Zero(df32);
    V sum3 = Zero(df32);

    // Main loop: unrolled
    for (; i + 2 * N <= num_elements; /* i += 2 * N */) {  // incr in loop
      const auto a0 = kIsAlignedA ? Load(d, pa + i) : LoadU(d, pa + i);
      const auto b0 = kIsAlignedB ? Load(d, pb + i) : LoadU(d, pb + i);
      i += N;
      sum0 = ReorderWidenMulAccumulate(df32, a0, b0, sum0, sum1);
      const auto a1 = kIsAlignedA ? Load(d, pa + i) : LoadU(d, pa + i);
      const auto b1 = kIsAlignedB ? Load(d, pb + i) : LoadU(d, pb + i);
      i += N;
      sum2 = ReorderWidenMulAccumulate(df32, a1, b1, sum2, sum3);
    }

    // Possibly one more iteration of whole vectors
    if (i + N <= num_elements) {
      const auto a0 = kIsAlignedA ? Load(d, pa + i) : LoadU(d, pa + i);
      const auto b0 = kIsAlignedB ? Load(d, pb + i) : LoadU(d, pb + i);
      i += N;
      sum0 = ReorderWidenMulAccumulate(df32, a0, b0, sum0, sum1);
    }

    if (!kIsMultipleOfVector) {
      const size_t remaining = num_elements - i;
      if (remaining != 0) {
        if (kIsPaddedToVector) {
          const auto mask = FirstN(du16, remaining);
          const auto va = kIsAlignedA ? Load(d, pa + i) : LoadU(d, pa + i);
          const auto vb = kIsAlignedB ? Load(d, pb + i) : LoadU(d, pb + i);
          const auto a16 = BitCast(d, IfThenElseZero(mask, BitCast(du16, va)));
          const auto b16 = BitCast(d, IfThenElseZero(mask, BitCast(du16, vb)));
          sum2 = ReorderWidenMulAccumulate(df32, a16, b16, sum2, sum3);
        } else {
          // Unaligned load such that the last element is in the highest lane -
          // ensures we do not touch any elements outside the valid range.
          // If we get here, then num_elements >= N.
          HWY_DASSERT(i >= N);
          i += remaining - N;
          const auto skip = FirstN(du16, N - remaining);
          const auto va = LoadU(d, pa + i);  // always unaligned
          const auto vb = LoadU(d, pb + i);
          const auto a16 = BitCast(d, IfThenZeroElse(skip, BitCast(du16, va)));
          const auto b16 = BitCast(d, IfThenZeroElse(skip, BitCast(du16, vb)));
          sum2 = ReorderWidenMulAccumulate(df32, a16, b16, sum2, sum3);
        }
      }
    }  // kMultipleOfVector

    // Reduction tree: sum of all accumulators by pairs, then across lanes.
    sum0 = Add(sum0, sum1);
    sum2 = Add(sum2, sum3);
    sum0 = Add(sum0, sum2);
    return GetLane(SumOfLanes(df32, sum0));
  }
};

// NOLINTNEXTLINE(google-readability-namespace-comments)
}  // namespace HWY_NAMESPACE
}  // namespace hwy
HWY_AFTER_NAMESPACE();

#endif  // HIGHWAY_HWY_CONTRIB_DOT_DOT_INL_H_
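
A minimal usage sketch (not part of this file), using static dispatch: inside
a per-target namespace, create a tag and call Dot::Compute. The function and
namespace names here are hypothetical; only Dot::Compute and the Highway
types are from the library.

  #include "hwy/contrib/dot/dot-inl.h"
  #include "hwy/highway.h"

  namespace project {
  namespace HWY_NAMESPACE {

  // No assumptions (kAssumptions = 0): any length, any alignment.
  float DotAny(const float* HWY_RESTRICT pa, const float* HWY_RESTRICT pb,
               size_t num) {
    const hwy::HWY_NAMESPACE::ScalableTag<float> d;
    return hwy::HWY_NAMESPACE::Dot::Compute<0>(d, pa, pb, num);
  }

  }  // namespace HWY_NAMESPACE
  }  // namespace project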