Grok  9.7.5
traits-inl.h
Go to the documentation of this file.
1 // Copyright 2021 Google LLC
2 // SPDX-License-Identifier: Apache-2.0
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 // http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 
16 // Per-target
17 #if defined(HIGHWAY_HWY_CONTRIB_SORT_TRAITS_TOGGLE) == \
18  defined(HWY_TARGET_TOGGLE)
19 #ifdef HIGHWAY_HWY_CONTRIB_SORT_TRAITS_TOGGLE
20 #undef HIGHWAY_HWY_CONTRIB_SORT_TRAITS_TOGGLE
21 #else
22 #define HIGHWAY_HWY_CONTRIB_SORT_TRAITS_TOGGLE
23 #endif
24 
26 #include "hwy/contrib/sort/shared-inl.h" // SortConstants
27 #include "hwy/contrib/sort/vqsort.h" // SortDescending
28 #include "hwy/highway.h"
29 
31 namespace hwy {
32 namespace HWY_NAMESPACE {
33 namespace detail {
34 
35 // Highway does not provide a lane type for 128-bit keys, so we use uint64_t
36 // along with an abstraction layer for single-lane vs. lane-pair, which is
37 // independent of the order.
38 struct KeyLane {
39  constexpr size_t LanesPerKey() const { return 1; }
40 
41  // For HeapSort
42  template <typename T>
43  HWY_INLINE void Swap(T* a, T* b) const {
44  const T temp = *a;
45  *a = *b;
46  *b = temp;
47  }
48 
49  // Broadcasts one key into a vector
50  template <class D>
51  HWY_INLINE Vec<D> SetKey(D d, const TFromD<D>* key) const {
52  return Set(d, *key);
53  }
54 
55  template <class D>
57  return Reverse(d, v);
58  }
59 
60  template <class D>
62  return Reverse2(d, v);
63  }
64 
65  template <class D>
67  return Reverse4(d, v);
68  }
69 
70  template <class D>
72  return Reverse8(d, v);
73  }
74 
75  template <class D>
77  static_assert(SortConstants::kMaxCols <= 16, "Assumes u32x16 = 512 bit");
78  return ReverseKeys(d, v);
79  }
80 
81  template <class V>
82  HWY_INLINE V OddEvenKeys(const V odd, const V even) const {
83  return OddEven(odd, even);
84  }
85 
86  template <class D, HWY_IF_LANE_SIZE_D(D, 2)>
88  const Repartition<uint32_t, D> du32;
89  return BitCast(d, Shuffle2301(BitCast(du32, v)));
90  }
91  template <class D, HWY_IF_LANE_SIZE_D(D, 4)>
92  HWY_INLINE Vec<D> SwapAdjacentPairs(D /* tag */, const Vec<D> v) const {
93  return Shuffle1032(v);
94  }
95  template <class D, HWY_IF_LANE_SIZE_D(D, 8)>
96  HWY_INLINE Vec<D> SwapAdjacentPairs(D /* tag */, const Vec<D> v) const {
97  return SwapAdjacentBlocks(v);
98  }
99 
100  template <class D, HWY_IF_NOT_LANE_SIZE_D(D, 8)>
102 #if HWY_HAVE_FLOAT64 // in case D is float32
103  const RepartitionToWide<D> dw;
104 #else
106 #endif
107  return BitCast(d, SwapAdjacentPairs(dw, BitCast(dw, v)));
108  }
109  template <class D, HWY_IF_LANE_SIZE_D(D, 8)>
111  // Assumes max vector size = 512
112  return ConcatLowerUpper(d, v, v);
113  }
114 
115  template <class D, HWY_IF_NOT_LANE_SIZE_D(D, 8)>
117  const Vec<D> even) const {
118 #if HWY_HAVE_FLOAT64 // in case D is float32
119  const RepartitionToWide<D> dw;
120 #else
122 #endif
123  return BitCast(d, OddEven(BitCast(dw, odd), BitCast(dw, even)));
124  }
125  template <class D, HWY_IF_LANE_SIZE_D(D, 8)>
126  HWY_INLINE Vec<D> OddEvenPairs(D /* tag */, Vec<D> odd, Vec<D> even) const {
127  return OddEvenBlocks(odd, even);
128  }
129 
130  template <class D, HWY_IF_NOT_LANE_SIZE_D(D, 8)>
131  HWY_INLINE Vec<D> OddEvenQuads(D d, Vec<D> odd, Vec<D> even) const {
132 #if HWY_HAVE_FLOAT64 // in case D is float32
133  const RepartitionToWide<D> dw;
134 #else
136 #endif
137  return BitCast(d, OddEvenPairs(dw, BitCast(dw, odd), BitCast(dw, even)));
138  }
139  template <class D, HWY_IF_LANE_SIZE_D(D, 8)>
140  HWY_INLINE Vec<D> OddEvenQuads(D d, Vec<D> odd, Vec<D> even) const {
141  return ConcatUpperLower(d, odd, even);
142  }
143 };
144 
145 // Anything order-related depends on the key traits *and* the order (see
146 // FirstOfLanes). We cannot implement just one Compare function because Lt128
147 // only compiles if the lane type is u64. Thus we need either overloaded
148 // functions with a tag type, class specializations, or separate classes.
149 // We avoid overloaded functions because we want all functions to be callable
150 // from a SortTraits without per-function wrappers. Specializing would work, but
151 // we are anyway going to specialize at a higher level.
152 struct OrderAscending : public KeyLane {
154 
155  template <typename T>
156  HWY_INLINE bool Compare1(const T* a, const T* b) {
157  return *a < *b;
158  }
159 
160  template <class D>
161  HWY_INLINE Mask<D> Compare(D /* tag */, Vec<D> a, Vec<D> b) const {
162  return Lt(a, b);
163  }
164 
165  // Two halves of Sort2, used in ScanMinMax.
166  template <class D>
167  HWY_INLINE Vec<D> First(D /* tag */, const Vec<D> a, const Vec<D> b) const {
168  return Min(a, b);
169  }
170 
171  template <class D>
172  HWY_INLINE Vec<D> Last(D /* tag */, const Vec<D> a, const Vec<D> b) const {
173  return Max(a, b);
174  }
175 
176  template <class D>
178  TFromD<D>* HWY_RESTRICT /* buf */) const {
179  return MinOfLanes(d, v);
180  }
181 
182  template <class D>
184  TFromD<D>* HWY_RESTRICT /* buf */) const {
185  return MaxOfLanes(d, v);
186  }
187 
188  template <class D>
190  return Set(d, hwy::LowestValue<TFromD<D>>());
191  }
192 
193  template <class D>
195  return Set(d, hwy::HighestValue<TFromD<D>>());
196  }
197 };
198 
199 struct OrderDescending : public KeyLane {
201 
202  template <typename T>
203  HWY_INLINE bool Compare1(const T* a, const T* b) {
204  return *b < *a;
205  }
206 
207  template <class D>
208  HWY_INLINE Mask<D> Compare(D /* tag */, Vec<D> a, Vec<D> b) const {
209  return Lt(b, a);
210  }
211 
212  template <class D>
213  HWY_INLINE Vec<D> First(D /* tag */, const Vec<D> a, const Vec<D> b) const {
214  return Max(a, b);
215  }
216 
217  template <class D>
218  HWY_INLINE Vec<D> Last(D /* tag */, const Vec<D> a, const Vec<D> b) const {
219  return Min(a, b);
220  }
221 
222  template <class D>
224  TFromD<D>* HWY_RESTRICT /* buf */) const {
225  return MaxOfLanes(d, v);
226  }
227 
228  template <class D>
230  TFromD<D>* HWY_RESTRICT /* buf */) const {
231  return MinOfLanes(d, v);
232  }
233 
234  template <class D>
236  return Set(d, hwy::HighestValue<TFromD<D>>());
237  }
238 
239  template <class D>
241  return Set(d, hwy::LowestValue<TFromD<D>>());
242  }
243 };
244 
245 // Shared code that depends on Order.
246 template <class Base>
247 struct TraitsLane : public Base {
248  constexpr bool Is128() const { return false; }
249 
250  // For each lane i: replaces a[i] with the first and b[i] with the second
251  // according to Base.
252  // Corresponds to a conditional swap, which is one "node" of a sorting
253  // network. Min/Max are cheaper than compare + blend at least for integers.
254  template <class D>
255  HWY_INLINE void Sort2(D d, Vec<D>& a, Vec<D>& b) const {
256  const Base* base = static_cast<const Base*>(this);
257 
258  const Vec<D> a_copy = a;
259  // Prior to AVX3, there is no native 64-bit Min/Max, so they compile to 4
260  // instructions. We can reduce it to a compare + 2 IfThenElse.
261 #if HWY_AVX3 < HWY_TARGET && HWY_TARGET <= HWY_SSSE3
262  if (sizeof(TFromD<D>) == 8) {
263  const Mask<D> cmp = base->Compare(d, a, b);
264  a = IfThenElse(cmp, a, b);
265  b = IfThenElse(cmp, b, a_copy);
266  return;
267  }
268 #endif
269  a = base->First(d, a, b);
270  b = base->Last(d, a_copy, b);
271  }
272 
273  // Conditionally swaps even-numbered lanes with their odd-numbered neighbor.
274  template <class D, HWY_IF_LANE_SIZE_D(D, 8)>
276  const Base* base = static_cast<const Base*>(this);
277  Vec<D> swapped = base->ReverseKeys2(d, v);
278  // Further to the above optimization, Sort2+OddEvenKeys compile to four
279  // instructions; we can save one by combining two blends.
280 #if HWY_AVX3 < HWY_TARGET && HWY_TARGET <= HWY_SSSE3
281  const Vec<D> cmp = VecFromMask(d, base->Compare(d, v, swapped));
282  return IfVecThenElse(DupOdd(cmp), swapped, v);
283 #else
284  Sort2(d, v, swapped);
285  return base->OddEvenKeys(swapped, v);
286 #endif
287  }
288 
289  // (See above - we use Sort2 for non-64-bit types.)
290  template <class D, HWY_IF_NOT_LANE_SIZE_D(D, 8)>
292  const Base* base = static_cast<const Base*>(this);
293  Vec<D> swapped = base->ReverseKeys2(d, v);
294  Sort2(d, v, swapped);
295  return base->OddEvenKeys(swapped, v);
296  }
297 
298  // Swaps with the vector formed by reversing contiguous groups of 4 keys.
299  template <class D>
301  const Base* base = static_cast<const Base*>(this);
302  Vec<D> swapped = base->ReverseKeys4(d, v);
303  Sort2(d, v, swapped);
304  return base->OddEvenPairs(d, swapped, v);
305  }
306 
307  // Conditionally swaps lane 0 with 4, 1 with 5 etc.
308  template <class D>
310  const Base* base = static_cast<const Base*>(this);
311  Vec<D> swapped = base->SwapAdjacentQuads(d, v);
312  // Only used in Merge16, so this will not be used on AVX2 (which only has 4
313  // u64 lanes), so skip the above optimization for 64-bit AVX2.
314  Sort2(d, v, swapped);
315  return base->OddEvenQuads(d, swapped, v);
316  }
317 };
318 
319 } // namespace detail
320 // NOLINTNEXTLINE(google-readability-namespace-comments)
321 } // namespace HWY_NAMESPACE
322 } // namespace hwy
324 
325 #endif // HIGHWAY_HWY_CONTRIB_SORT_TRAITS_TOGGLE
#define HWY_RESTRICT
Definition: base.h:63
#define HWY_INLINE
Definition: base.h:64
HWY_INLINE Vec128< T, N > OddEven(hwy::SizeTag< 1 >, const Vec128< T, N > a, const Vec128< T, N > b)
Definition: wasm_128-inl.h:2568
HWY_INLINE Vec128< T, 1 > MinOfLanes(hwy::SizeTag< sizeof(T)>, const Vec128< T, 1 > v)
Definition: arm_neon-inl.h:4309
HWY_INLINE Vec128< T, 1 > MaxOfLanes(hwy::SizeTag< sizeof(T)>, const Vec128< T, 1 > v)
Definition: arm_neon-inl.h:4314
HWY_INLINE Vec128< T, N > IfThenElse(hwy::SizeTag< 1 >, Mask128< T, N > mask, Vec128< T, N > yes, Vec128< T, N > no)
Definition: x86_128-inl.h:680
d
Definition: rvv-inl.h:1656
HWY_API Vec128< T, N > OddEvenBlocks(Vec128< T, N >, Vec128< T, N > even)
Definition: arm_neon-inl.h:4038
HWY_API Vec128< T, N > DupOdd(Vec128< T, N > v)
Definition: arm_neon-inl.h:4003
HWY_API Vec128< T > Shuffle1032(const Vec128< T > v)
Definition: arm_neon-inl.h:3531
HWY_API auto Lt(V a, V b) -> decltype(a==b)
Definition: arm_neon-inl.h:5252
Repartition< MakeWide< TFromD< D > >, D > RepartitionToWide
Definition: ops/shared-inl.h:210
HWY_API Vec128< uint64_t, N > Min(const Vec128< uint64_t, N > a, const Vec128< uint64_t, N > b)
Definition: arm_neon-inl.h:1957
HWY_API Vec128< uint64_t, N > Max(const Vec128< uint64_t, N > a, const Vec128< uint64_t, N > b)
Definition: arm_neon-inl.h:1995
HWY_API Vec128< T, N > ConcatLowerUpper(const Simd< T, N, 0 > d, Vec128< T, N > hi, Vec128< T, N > lo)
Definition: arm_neon-inl.h:3869
HWY_API Vec128< T, N > IfVecThenElse(Vec128< T, N > mask, Vec128< T, N > yes, Vec128< T, N > no)
Definition: arm_neon-inl.h:1505
HWY_API Vec128< T, N > VecFromMask(Simd< T, N, 0 > d, const Mask128< T, N > v)
Definition: arm_neon-inl.h:1681
HWY_API Vec128< T, N > Reverse4(Simd< T, N, 0 > d, const Vec128< T, N > v)
Definition: arm_neon-inl.h:3490
HWY_API Vec128< T, N > SwapAdjacentBlocks(Vec128< T, N > v)
Definition: arm_neon-inl.h:4045
HWY_API Vec128< T, N > Reverse2(Simd< T, N, 0 > d, const Vec128< T, N > v)
Definition: arm_neon-inl.h:3461
svuint16_t Set(Simd< bfloat16_t, N, kPow2 > d, bfloat16_t arg)
Definition: arm_sve-inl.h:282
HWY_API Vec128< T, N > Reverse8(Simd< T, N, 0 > d, const Vec128< T, N > v)
Definition: arm_neon-inl.h:3513
HWY_API Vec128< T, N > ConcatUpperLower(Simd< T, N, 0 > d, Vec128< T, N > hi, Vec128< T, N > lo)
Definition: arm_neon-inl.h:3895
HWY_API Vec128< T, N > BitCast(Simd< T, N, 0 > d, Vec128< FromT, N *sizeof(T)/sizeof(FromT)> v)
Definition: arm_neon-inl.h:710
HWY_API Vec64< uint32_t > Shuffle2301(const Vec64< uint32_t > v)
Definition: arm_neon-inl.h:1778
typename D::template Repartition< T > Repartition
Definition: ops/shared-inl.h:207
decltype(MaskFromVec(Zero(D()))) Mask
Definition: generic_ops-inl.h:38
HWY_API Vec128< T, 1 > Reverse(Simd< T, 1, 0 >, const Vec128< T, 1 > v)
Definition: arm_neon-inl.h:3430
const vfloat64m1_t v
Definition: rvv-inl.h:1656
typename D::T TFromD
Definition: ops/shared-inl.h:192
decltype(Zero(D())) Vec
Definition: generic_ops-inl.h:32
Definition: aligned_allocator.h:27
constexpr HWY_API T LowestValue()
Definition: base.h:512
constexpr HWY_API T HighestValue()
Definition: base.h:525
#define HWY_NAMESPACE
Definition: set_macros-inl.h:80
Definition: traits-inl.h:38
HWY_INLINE Vec< D > ReverseKeys2(D d, Vec< D > v) const
Definition: traits-inl.h:61
HWY_INLINE V OddEvenKeys(const V odd, const V even) const
Definition: traits-inl.h:82
constexpr size_t LanesPerKey() const
Definition: traits-inl.h:39
HWY_INLINE Vec< D > OddEvenQuads(D d, Vec< D > odd, Vec< D > even) const
Definition: traits-inl.h:131
HWY_INLINE Vec< D > SwapAdjacentQuads(D d, const Vec< D > v) const
Definition: traits-inl.h:101
HWY_INLINE Vec< D > ReverseKeys(D d, Vec< D > v) const
Definition: traits-inl.h:56
HWY_INLINE Vec< D > SetKey(D d, const TFromD< D > *key) const
Definition: traits-inl.h:51
HWY_INLINE void Swap(T *a, T *b) const
Definition: traits-inl.h:43
HWY_INLINE Vec< D > ReverseKeys8(D d, Vec< D > v) const
Definition: traits-inl.h:71
HWY_INLINE Vec< D > OddEvenPairs(D d, const Vec< D > odd, const Vec< D > even) const
Definition: traits-inl.h:116
HWY_INLINE Vec< D > SwapAdjacentPairs(D d, const Vec< D > v) const
Definition: traits-inl.h:87
HWY_INLINE Vec< D > SwapAdjacentPairs(D, const Vec< D > v) const
Definition: traits-inl.h:92
HWY_INLINE Vec< D > OddEvenPairs(D, Vec< D > odd, Vec< D > even) const
Definition: traits-inl.h:126
HWY_INLINE Vec< D > ReverseKeys4(D d, Vec< D > v) const
Definition: traits-inl.h:66
HWY_INLINE Vec< D > ReverseKeys16(D d, Vec< D > v) const
Definition: traits-inl.h:76
Definition: traits-inl.h:152
HWY_INLINE Vec< D > LastOfLanes(D d, Vec< D > v, TFromD< D > *HWY_RESTRICT) const
Definition: traits-inl.h:183
HWY_INLINE Vec< D > Last(D, const Vec< D > a, const Vec< D > b) const
Definition: traits-inl.h:172
HWY_INLINE Vec< D > First(D, const Vec< D > a, const Vec< D > b) const
Definition: traits-inl.h:167
HWY_INLINE Vec< D > LastValue(D d) const
Definition: traits-inl.h:194
HWY_INLINE bool Compare1(const T *a, const T *b)
Definition: traits-inl.h:156
HWY_INLINE Vec< D > FirstValue(D d) const
Definition: traits-inl.h:189
HWY_INLINE Mask< D > Compare(D, Vec< D > a, Vec< D > b) const
Definition: traits-inl.h:161
HWY_INLINE Vec< D > FirstOfLanes(D d, Vec< D > v, TFromD< D > *HWY_RESTRICT) const
Definition: traits-inl.h:177
HWY_INLINE bool Compare1(const T *a, const T *b)
Definition: traits-inl.h:203
HWY_INLINE Mask< D > Compare(D, Vec< D > a, Vec< D > b) const
Definition: traits-inl.h:208
HWY_INLINE Vec< D > First(D, const Vec< D > a, const Vec< D > b) const
Definition: traits-inl.h:213
HWY_INLINE Vec< D > Last(D, const Vec< D > a, const Vec< D > b) const
Definition: traits-inl.h:218
HWY_INLINE Vec< D > FirstValue(D d) const
Definition: traits-inl.h:235
HWY_INLINE Vec< D > FirstOfLanes(D d, Vec< D > v, TFromD< D > *HWY_RESTRICT) const
Definition: traits-inl.h:223
HWY_INLINE Vec< D > LastOfLanes(D d, Vec< D > v, TFromD< D > *HWY_RESTRICT) const
Definition: traits-inl.h:229
HWY_INLINE Vec< D > LastValue(D d) const
Definition: traits-inl.h:240
Definition: traits-inl.h:247
HWY_INLINE Vec< D > SortPairsReverse4(D d, Vec< D > v) const
Definition: traits-inl.h:300
HWY_INLINE void Sort2(D d, Vec< D > &a, Vec< D > &b) const
Definition: traits-inl.h:255
constexpr bool Is128() const
Definition: traits-inl.h:248
HWY_INLINE Vec< D > SortPairsDistance1(D d, Vec< D > v) const
Definition: traits-inl.h:275
HWY_INLINE Vec< D > SortPairsDistance4(D d, Vec< D > v) const
Definition: traits-inl.h:309
Definition: vqsort.h:35
static constexpr size_t kMaxCols
Definition: contrib/sort/shared-inl.h:34
Definition: vqsort.h:38
HWY_AFTER_NAMESPACE()
HWY_BEFORE_NAMESPACE()