17 #ifndef HIGHWAY_HWY_CONTRIB_SORT_VQSORT_INL_H_
18 #define HIGHWAY_HWY_CONTRIB_SORT_VQSORT_INL_H_
22 #ifndef VQSORT_SECURE_RNG
23 #define VQSORT_SECURE_RNG 0
27 #include "third_party/absl/random/random.h"
37 #include <sanitizer/msan_interface.h>
43 #if defined(HIGHWAY_HWY_CONTRIB_SORT_VQSORT_TOGGLE) == \
44 defined(HWY_TARGET_TOGGLE)
45 #ifdef HIGHWAY_HWY_CONTRIB_SORT_VQSORT_TOGGLE
46 #undef HIGHWAY_HWY_CONTRIB_SORT_VQSORT_TOGGLE
48 #define HIGHWAY_HWY_CONTRIB_SORT_VQSORT_TOGGLE
60 #if HWY_TARGET == HWY_SCALAR
70 template <
class Traits,
typename T>
75 for (
size_t i = 1; i < num; i += 1) {
78 const size_t idx_parent = ((j - 1) / 1 / 2);
79 if (!st.Compare1(keys + idx_parent, keys + j)) {
82 Swap(keys + j, keys + idx_parent);
87 for (
size_t i = num - 1; i != 0; i -= 1) {
89 Swap(keys + 0, keys + i);
94 const size_t left = 2 * j + 1;
95 const size_t right = 2 * j + 2;
97 size_t idx_larger = j;
98 if (st.Compare1(keys + j, keys + left)) {
101 if (right < i && st.Compare1(keys + idx_larger, keys + right)) {
104 if (idx_larger == j)
break;
105 Swap(keys + j, keys + idx_larger);
119 template <
class Traits,
typename T>
121 constexpr
size_t N1 = st.LanesPerKey();
124 if (num < 2 * N1)
return;
127 for (
size_t i = N1; i < num; i += N1) {
130 const size_t idx_parent = ((j - N1) / N1 / 2) * N1;
131 if (
AllFalse(
d, st.Compare(
d, st.SetKey(
d, keys + idx_parent),
132 st.SetKey(
d, keys + j)))) {
135 st.Swap(keys + j, keys + idx_parent);
140 for (
size_t i = num - N1; i != 0; i -= N1) {
142 st.Swap(keys + 0, keys + i);
147 const size_t left = 2 * j + N1;
148 const size_t right = 2 * j + 2 * N1;
149 if (left >= i)
break;
150 size_t idx_larger = j;
151 const auto key_j = st.SetKey(
d, keys + j);
152 if (
AllTrue(
d, st.Compare(
d, key_j, st.SetKey(
d, keys + left)))) {
155 if (right < i &&
AllTrue(
d, st.Compare(
d, st.SetKey(
d, keys + idx_larger),
156 st.SetKey(
d, keys + right)))) {
159 if (idx_larger == j)
break;
160 st.Swap(keys + j, keys + idx_larger);
169 template <
class D,
class Traits,
typename T>
173 using V = decltype(
Zero(
d));
180 const size_t num_pow2 =
size_t{1}
182 static_cast<uint32_t
>(num - 1)));
190 for (i = 0; i +
N <= num; i +=
N) {
197 const V kPadding = st.LastValue(
d);
206 for (i = 0; i +
N <= num; i +=
N) {
216 template <
class D,
class Traits,
class T>
217 HWY_NOINLINE void PartitionToMultipleOfUnroll(D
d, Traits st,
219 size_t& left,
size_t& right,
226 const size_t num = right - left;
229 const size_t num_rem =
230 (num < 2 * kUnroll *
N) ? num : (num & (kUnroll *
N - 1));
232 for (; i +
N <= num_rem; i +=
N) {
233 const Vec<D> vL =
LoadU(
d, keys + readL);
236 const auto comp = st.Compare(
d, pivot, vL);
242 const auto mask =
FirstN(
d, num_rem - i);
243 const Vec<D> vL =
LoadU(
d, keys + readL);
245 const auto comp = st.Compare(
d, pivot, vL);
252 __msan_unpoison(buf, bufR *
sizeof(T));
262 memcpy(keys + left, keys + right, bufR *
sizeof(T));
263 memcpy(keys + right, buf, bufR *
sizeof(T));
266 template <
class D,
class Traits,
typename T>
267 HWY_INLINE void StoreLeftRight(D
d, Traits st,
const Vec<D>
v,
269 size_t& writeL,
size_t& writeR) {
272 const auto comp = st.Compare(
d, pivot,
v);
281 const auto mask =
Not(comp);
289 writeR -= (
N - num_left);
296 writeR -= (
N - num_left);
301 template <
class D,
class Traits,
typename T>
302 HWY_INLINE void StoreLeftRight4(D
d, Traits st,
const Vec<D> v0,
303 const Vec<D> v1,
const Vec<D> v2,
304 const Vec<D> v3,
const Vec<D> pivot,
307 StoreLeftRight(
d, st, v0, pivot, keys, writeL, writeR);
308 StoreLeftRight(
d, st, v1, pivot, keys, writeL, writeR);
309 StoreLeftRight(
d, st, v2, pivot, keys, writeL, writeR);
310 StoreLeftRight(
d, st, v3, pivot, keys, writeL, writeR);
317 template <
class D,
class Traits,
typename T>
319 size_t right,
const Vec<D> pivot,
321 using V = decltype(
Zero(
d));
329 const size_t last = right;
330 const V vlast =
LoadU(
d, keys + last);
332 PartitionToMultipleOfUnroll(
d, st, keys, left, right, pivot, buf);
336 size_t writeL = left;
337 size_t writeR = right;
339 const size_t num = right - left;
346 const V vL0 =
LoadU(
d, keys + left + 0 *
N);
347 const V vL1 =
LoadU(
d, keys + left + 1 *
N);
348 const V vL2 =
LoadU(
d, keys + left + 2 *
N);
349 const V vL3 =
LoadU(
d, keys + left + 3 *
N);
351 right -= kUnroll *
N;
352 const V vR0 =
LoadU(
d, keys + right + 0 *
N);
353 const V vR1 =
LoadU(
d, keys + right + 1 *
N);
354 const V vR2 =
LoadU(
d, keys + right + 2 *
N);
355 const V vR3 =
LoadU(
d, keys + right + 3 *
N);
358 while (left != right) {
363 const size_t capacityL = left - writeL;
364 const size_t capacityR = writeR - right;
366 if (capacityR < capacityL) {
367 right -= kUnroll *
N;
368 v0 =
LoadU(
d, keys + right + 0 *
N);
369 v1 =
LoadU(
d, keys + right + 1 *
N);
370 v2 =
LoadU(
d, keys + right + 2 *
N);
371 v3 =
LoadU(
d, keys + right + 3 *
N);
374 v0 =
LoadU(
d, keys + left + 0 *
N);
375 v1 =
LoadU(
d, keys + left + 1 *
N);
376 v2 =
LoadU(
d, keys + left + 2 *
N);
377 v3 =
LoadU(
d, keys + left + 3 *
N);
382 StoreLeftRight4(
d, st, v0, v1, v2, v3, pivot, keys, writeL, writeR);
386 StoreLeftRight4(
d, st, vL0, vL1, vL2, vL3, pivot, keys, writeL, writeR);
387 StoreLeftRight4(
d, st, vR0, vR1, vR2, vR3, pivot, keys, writeL, writeR);
396 const size_t totalR = last - writeL;
397 const size_t startR = totalR <
N ? writeL + totalR -
N : writeL;
401 const auto comp = st.Compare(
d, pivot, vlast);
410 template <
class Traits,
class V>
411 HWY_INLINE V MedianOf3(Traits st, V v0, V v1, V v2) {
417 const auto sum =
Xor(
Xor(v0, v1), v2);
418 const auto first = st.First(
d, st.First(
d, v0, v1), v2);
419 const auto last = st.Last(
d, st.Last(
d, v0, v1), v2);
420 return Xor(
Xor(sum, first), last);
423 v1 = st.Last(
d, v0, v1);
424 v1 = st.First(
d, v1, v2);
430 template <
class D,
class Traits,
typename T>
431 Vec<D> RecursiveMedianOf3(D
d, Traits st, T*
HWY_RESTRICT keys,
size_t num,
434 constexpr
size_t N1 = st.LanesPerKey();
436 if (num < 3 * N1)
return st.SetKey(
d, keys);
442 for (; read + 3 *
N <= num; read += 3 *
N) {
443 const auto v0 =
Load(
d, keys + read + 0 *
N);
444 const auto v1 =
Load(
d, keys + read + 1 *
N);
445 const auto v2 =
Load(
d, keys + read + 2 *
N);
446 Store(MedianOf3(st, v0, v1, v2),
d, buf + written);
451 for (; read + 3 * N1 <= num; read += 3 * N1) {
452 const auto v0 = st.SetKey(
d, keys + read + 0 * N1);
453 const auto v1 = st.SetKey(
d, keys + read + 1 * N1);
454 const auto v2 = st.SetKey(
d, keys + read + 2 * N1);
455 StoreU(MedianOf3(st, v0, v1, v2),
d, buf + written);
460 return RecursiveMedianOf3(
d, st, buf, written, keys);
463 #if VQSORT_SECURE_RNG
464 using Generator = absl::BitGen;
467 #pragma pack(push, 1)
470 Generator(
const void* heap,
size_t num) {
475 uint64_t operator()() {
476 const uint64_t b = b_;
478 const uint64_t next = a_ ^ w_;
479 a_ = (b + (b << 3)) ^ (b >> 11);
480 const uint64_t rot = (b << 24) | (b >> 40);
497 HWY_INLINE size_t RandomChunkIndex(
const uint32_t num_chunks, uint32_t bits) {
498 const uint64_t chunk_index = (
static_cast<uint64_t
>(bits) * num_chunks) >> 32;
500 return static_cast<size_t>(chunk_index);
503 template <
class D,
class Traits,
typename T>
505 const size_t begin,
const size_t end,
507 using V = decltype(
Zero(
d));
514 size_t num = end - begin;
519 const size_t misalign =
520 (
reinterpret_cast<uintptr_t
>(keys) /
sizeof(T)) & (lanes_per_chunk - 1);
522 const size_t consume = lanes_per_chunk - misalign;
528 uint64_t* bits64 =
reinterpret_cast<uint64_t*
>(buf);
529 for (
size_t i = 0; i < 5; ++i) {
532 const uint32_t* bits =
reinterpret_cast<const uint32_t*
>(buf);
534 const uint32_t lpc32 =
static_cast<uint32_t
>(lanes_per_chunk);
537 const size_t num_chunks64 = num >> log2_lpc;
539 const uint32_t num_chunks =
540 static_cast<uint32_t
>(
HWY_MIN(num_chunks64, 0xFFFFFFFFull));
542 const size_t offset0 = RandomChunkIndex(num_chunks, bits[0]) << log2_lpc;
543 const size_t offset1 = RandomChunkIndex(num_chunks, bits[1]) << log2_lpc;
544 const size_t offset2 = RandomChunkIndex(num_chunks, bits[2]) << log2_lpc;
545 const size_t offset3 = RandomChunkIndex(num_chunks, bits[3]) << log2_lpc;
546 const size_t offset4 = RandomChunkIndex(num_chunks, bits[4]) << log2_lpc;
547 const size_t offset5 = RandomChunkIndex(num_chunks, bits[5]) << log2_lpc;
548 const size_t offset6 = RandomChunkIndex(num_chunks, bits[6]) << log2_lpc;
549 const size_t offset7 = RandomChunkIndex(num_chunks, bits[7]) << log2_lpc;
550 const size_t offset8 = RandomChunkIndex(num_chunks, bits[8]) << log2_lpc;
551 for (
size_t i = 0; i < lanes_per_chunk; i +=
N) {
552 const V v0 =
Load(
d, keys + offset0 + i);
553 const V v1 =
Load(
d, keys + offset1 + i);
554 const V v2 =
Load(
d, keys + offset2 + i);
555 const V medians0 = MedianOf3(st, v0, v1, v2);
556 Store(medians0,
d, buf + i);
558 const V v3 =
Load(
d, keys + offset3 + i);
559 const V v4 =
Load(
d, keys + offset4 + i);
560 const V v5 =
Load(
d, keys + offset5 + i);
561 const V medians1 = MedianOf3(st, v3, v4, v5);
562 Store(medians1,
d, buf + i + lanes_per_chunk);
564 const V v6 =
Load(
d, keys + offset6 + i);
565 const V v7 =
Load(
d, keys + offset7 + i);
566 const V v8 =
Load(
d, keys + offset8 + i);
567 const V medians2 = MedianOf3(st, v6, v7, v8);
568 Store(medians2,
d, buf + i + lanes_per_chunk * 2);
571 return RecursiveMedianOf3(
d, st, buf, 3 * lanes_per_chunk,
572 buf + 3 * lanes_per_chunk);
577 template <
class D,
class Traits,
typename T>
583 first = st.LastValue(
d);
584 last = st.FirstValue(
d);
587 for (; i +
N <= num; i +=
N) {
588 const Vec<D>
v =
LoadU(
d, keys + i);
589 first = st.First(
d,
v, first);
590 last = st.Last(
d,
v, last);
594 const Vec<D>
v =
LoadU(
d, keys + num -
N);
595 first = st.First(
d,
v, first);
596 last = st.Last(
d,
v, last);
599 first = st.FirstOfLanes(
d, first, buf);
600 last = st.LastOfLanes(
d, last, buf);
603 template <
class D,
class Traits,
typename T>
604 void Recurse(D
d, Traits st, T*
HWY_RESTRICT keys,
const size_t begin,
605 const size_t end,
const Vec<D> pivot, T*
HWY_RESTRICT buf,
606 Generator& rng,
size_t remaining_levels) {
608 const size_t num = end - begin;
617 const ptrdiff_t base_case_num =
619 const size_t bound = Partition(
d, st, keys, begin, end, pivot, buf);
621 const ptrdiff_t num_left =
622 static_cast<ptrdiff_t
>(bound) -
static_cast<ptrdiff_t
>(begin);
623 const ptrdiff_t num_right =
624 static_cast<ptrdiff_t
>(end) -
static_cast<ptrdiff_t
>(bound);
633 ScanMinMax(
d, st, keys + begin, num, buf, first, last);
638 Recurse(
d, st, keys, begin, end, first, buf, rng, remaining_levels - 1);
643 BaseCase(
d, st, keys + begin,
static_cast<size_t>(num_left), buf);
645 const Vec<D> next_pivot = ChoosePivot(
d, st, keys, begin, bound, buf, rng);
646 Recurse(
d, st, keys, begin, bound, next_pivot, buf, rng,
647 remaining_levels - 1);
650 BaseCase(
d, st, keys + bound,
static_cast<size_t>(num_right), buf);
652 const Vec<D> next_pivot = ChoosePivot(
d, st, keys, bound, end, buf, rng);
653 Recurse(
d, st, keys, bound, end, next_pivot, buf, rng,
654 remaining_levels - 1);
659 template <
class D,
class Traits,
typename T>
660 bool HandleSpecialCases(D
d, Traits st, T*
HWY_RESTRICT keys,
size_t num,
668 const bool partial_128 =
N < 2 && st.Is128();
672 constexpr
bool kPotentiallyHuge =
674 const bool huge_vec = kPotentiallyHuge && (2 *
N > base_case_num);
675 if (partial_128 || huge_vec) {
683 BaseCase(
d, st, keys, num, buf);
707 template <
class D,
class Traits,
typename T>
710 #if HWY_TARGET == HWY_SCALAR
716 #if !HWY_HAVE_SCALABLE
722 static_assert(
sizeof(storage) <= 8192,
"Unexpectedly large, check size");
726 if (detail::HandleSpecialCases(
d, st, keys, num, buf))
return;
728 #if HWY_MAX_BYTES > 64
730 if (
Lanes(
d) > 64 /
sizeof(T)) {
731 return Sort(
CappedTag<T, 64 /
sizeof(T)>(), st, keys, num, buf);
736 detail::Generator rng(keys, num);
737 const Vec<D> pivot = detail::ChoosePivot(
d, st, keys, 0, num, buf, rng);
742 detail::Recurse(
d, st, keys, 0, num, pivot, buf, rng, max_levels);
#define HWY_MAX(a, b)
Definition: base.h:128
#define HWY_RESTRICT
Definition: base.h:63
#define HWY_NOINLINE
Definition: base.h:65
#define HWY_MIN(a, b)
Definition: base.h:127
#define HWY_INLINE
Definition: base.h:64
#define HWY_DASSERT(condition)
Definition: base.h:193
#define HWY_LIKELY(expr)
Definition: base.h:68
#define HWY_UNLIKELY(expr)
Definition: base.h:69
static void Fill24Bytes(const void *seed_heap, size_t seed_num, void *bytes)
void HeapSort(Traits st, T *HWY_RESTRICT keys, const size_t num)
Definition: vqsort-inl.h:71
HWY_INLINE void SortingNetwork(Traits st, T *HWY_RESTRICT buf, size_t cols)
Definition: sorting_networks-inl.h:603
HWY_INLINE bool AllTrue(hwy::SizeTag< 1 >, const Mask128< T > m)
Definition: wasm_128-inl.h:3111
HWY_INLINE bool AllFalse(hwy::SizeTag< 1 >, const Mask256< T > mask)
Definition: x86_256-inl.h:4066
HWY_INLINE Mask128< T, N > Xor(hwy::SizeTag< 1 >, const Mask128< T, N > a, const Mask128< T, N > b)
Definition: x86_128-inl.h:936
HWY_INLINE size_t CountTrue(hwy::SizeTag< 1 >, const Mask128< T > mask)
Definition: arm_neon-inl.h:4680
HWY_INLINE Vec128< T, N > Compress(Vec128< T, N > v, const uint64_t mask_bits)
Definition: arm_neon-inl.h:5020
void Swap(T *a, T *b)
Definition: vqsort-inl.h:63
HWY_INLINE Mask128< T, N > And(hwy::SizeTag< 1 >, const Mask128< T, N > a, const Mask128< T, N > b)
Definition: x86_128-inl.h:825
HWY_INLINE Mask512< T > Not(hwy::SizeTag< 1 >, const Mask512< T > m)
Definition: x86_512-inl.h:1553
hwy::SortConstants Constants
Definition: sorting_networks-inl.h:34
HWY_INLINE Mask128< T, N > AndNot(hwy::SizeTag< 1 >, const Mask128< T, N > a, const Mask128< T, N > b)
Definition: x86_128-inl.h:862
d
Definition: rvv-inl.h:1656
HWY_API auto Eq(V a, V b) -> decltype(a==b)
Definition: arm_neon-inl.h:5244
HWY_API Mask128< T, N > FirstN(const Simd< T, N, 0 > d, size_t num)
Definition: arm_neon-inl.h:1896
typename detail::CappedTagChecker< T, kLimit >::type CappedTag
Definition: ops/shared-inl.h:173
HWY_API Vec128< T, N > Load(Simd< T, N, 0 > d, const T *HWY_RESTRICT p)
Definition: arm_neon-inl.h:2205
HWY_API Vec128< T, N > Zero(Simd< T, N, 0 > d)
Definition: arm_neon-inl.h:733
HWY_API size_t Lanes(Simd< T, N, kPow2 > d)
Definition: arm_sve-inl.h:218
void Sort(D d, Traits st, T *HWY_RESTRICT keys, size_t num, T *HWY_RESTRICT buf)
Definition: vqsort-inl.h:708
HWY_API void StoreU(const Vec128< uint8_t > v, Full128< uint8_t >, uint8_t *HWY_RESTRICT unaligned)
Definition: arm_neon-inl.h:2224
HWY_API Vec128< uint8_t > LoadU(Full128< uint8_t >, const uint8_t *HWY_RESTRICT unaligned)
Definition: arm_neon-inl.h:2031
HWY_API size_t CompressBlendedStore(Vec128< T, N > v, Mask128< T, N > m, Simd< T, N, 0 > d, T *HWY_RESTRICT unaligned)
Definition: arm_neon-inl.h:5061
typename detail::FixedTagChecker< T, kNumLanes >::type FixedTag
Definition: ops/shared-inl.h:189
HWY_API void SafeCopyN(const size_t num, D d, const T *HWY_RESTRICT from, T *HWY_RESTRICT to)
Definition: generic_ops-inl.h:79
N
Definition: rvv-inl.h:1656
HWY_API size_t CompressStore(Vec128< T, N > v, const Mask128< T, N > mask, Simd< T, N, 0 > d, T *HWY_RESTRICT unaligned)
Definition: arm_neon-inl.h:5052
HWY_API void Store(Vec128< T, N > v, Simd< T, N, 0 > d, T *HWY_RESTRICT aligned)
Definition: arm_neon-inl.h:2397
const vfloat64m1_t v
Definition: rvv-inl.h:1656
decltype(Zero(D())) Vec
Definition: generic_ops-inl.h:32
Definition: aligned_allocator.h:27
HWY_INLINE HWY_ATTR_CACHE void Prefetch(const T *p)
Definition: cache_control.h:77
HWY_API size_t Num0BitsAboveMS1Bit_Nonzero32(const uint32_t x)
Definition: base.h:633
HWY_API size_t Num0BitsBelowLS1Bit_Nonzero32(const uint32_t x)
Definition: base.h:598
constexpr size_t CeilLog2(TI x)
Definition: base.h:700
#define HWY_MAX_BYTES
Definition: set_macros-inl.h:82
#define HWY_LANES(T)
Definition: set_macros-inl.h:83
#define HWY_ALIGN
Definition: set_macros-inl.h:81
#define HWY_NAMESPACE
Definition: set_macros-inl.h:80
Definition: arm_neon-inl.h:4797
Definition: contrib/sort/shared-inl.h:28
static constexpr size_t kMaxCols
Definition: contrib/sort/shared-inl.h:34
static constexpr size_t kMaxRows
Definition: contrib/sort/shared-inl.h:43
static constexpr HWY_INLINE size_t BaseCaseNum(size_t N)
Definition: contrib/sort/shared-inl.h:45
static constexpr size_t kMaxRowsLog2
Definition: contrib/sort/shared-inl.h:42
static constexpr size_t kPartitionUnroll
Definition: contrib/sort/shared-inl.h:54
static constexpr HWY_INLINE size_t LanesPerChunk(size_t sizeof_t, size_t N)
Definition: contrib/sort/shared-inl.h:68