- 001
- 002
- 003
- 004
- 005
- 006
- 007
- 008
- 009
- 010
- 011
- 012
- 013
- 014
- 015
- 016
- 017
- 018
- 019
- 020
- 021
- 022
- 023
- 024
- 025
- 026
- 027
- 028
- 029
- 030
- 031
- 032
- 033
- 034
- 035
- 036
- 037
- 038
- 039
- 040
- 041
- 042
- 043
- 044
- 045
- 046
- 047
- 048
- 049
- 050
- 051
- 052
- 053
- 054
- 055
- 056
- 057
- 058
- 059
- 060
- 061
- 062
- 063
- 064
- 065
- 066
- 067
- 068
- 069
- 070
- 071
- 072
- 073
- 074
- 075
- 076
- 077
- 078
- 079
- 080
- 081
- 082
- 083
- 084
- 085
- 086
- 087
- 088
- 089
- 090
- 091
- 092
- 093
- 094
- 095
- 096
- 097
- 098
- 099
- 100
// https://github.com/google/ruy/blob/2887692065c38ef6617f423feafc6b69dd0a0681/ruy/pack_avx2_fma.cc#L66
inline void Pack8bitColMajorForAvx2Packer(
const std::int8_t* src_ptr, std::int8_t input_xor,
const std::int8_t* zerobuf, int src_stride, int remaining_src_cols,
int src_rows, std::int8_t* packed_ptr, std::int32_t* sums_ptr,
std::int8_t* trailing_buf) {
using Layout = PackImpl8bitAvx2::Layout;
RUY_DCHECK_EQ(Layout::kCols, 8);
RUY_DCHECK_EQ(Layout::kRows, 4);
// Each Layout::Rows is 4 contiguous input, contiguous packed elements.
// We process 8 of these chunks at a time, padding short input chunks.
constexpr int kNumRowChunks = 8;
constexpr int kNumChunkedSrcRows = kNumRowChunks * Layout::kRows;
const std::int8_t* src_ptr0 = src_ptr;
const std::int8_t* src_ptr1 = src_ptr0 + src_stride;
const std::int8_t* src_ptr2 = src_ptr1 + src_stride;
const std::int8_t* src_ptr3 = src_ptr2 + src_stride;
const std::int8_t* src_ptr4 = src_ptr3 + src_stride;
const std::int8_t* src_ptr5 = src_ptr4 + src_stride;
const std::int8_t* src_ptr6 = src_ptr5 + src_stride;
const std::int8_t* src_ptr7 = src_ptr6 + src_stride;
std::int64_t src_inc0 = kNumChunkedSrcRows;
std::int64_t src_inc1 = kNumChunkedSrcRows;
std::int64_t src_inc2 = kNumChunkedSrcRows;
std::int64_t src_inc3 = kNumChunkedSrcRows;
std::int64_t src_inc4 = kNumChunkedSrcRows;
std::int64_t src_inc5 = kNumChunkedSrcRows;
std::int64_t src_inc6 = kNumChunkedSrcRows;
std::int64_t src_inc7 = kNumChunkedSrcRows;
// Handle cases where source does not have Layout::kCols (8) columns.
if (remaining_src_cols < 8) {
if (remaining_src_cols <= 0) {
src_ptr0 = zerobuf;
src_inc0 = 0;
}
if (remaining_src_cols <= 1) {
src_ptr1 = zerobuf;
src_inc1 = 0;
}
if (remaining_src_cols <= 2) {
src_ptr2 = zerobuf;
src_inc2 = 0;
}
if (remaining_src_cols <= 3) {
src_ptr3 = zerobuf;
src_inc3 = 0;
}
if (remaining_src_cols <= 4) {
src_ptr4 = zerobuf;
src_inc4 = 0;
}
if (remaining_src_cols <= 5) {
src_ptr5 = zerobuf;
src_inc5 = 0;
}
if (remaining_src_cols <= 6) {
src_ptr6 = zerobuf;
src_inc6 = 0;
}
src_ptr7 = zerobuf;
src_inc7 = 0;
}
const std::int8_t zero_point = zerobuf[0];
if (sums_ptr) {
// i: Layout::kCols.
for (int i = 0; i < 8; ++i) {
sums_ptr[i] = 0;
}
}
std::int32_t sums_adjustment = 0;
const __m256i ones_16bit = _mm256_set1_epi16(1);
__m256i sums_4x2_32bit_lo = _mm256_set1_epi32(0);
__m256i sums_4x2_32bit_hi = _mm256_set1_epi32(0);
// The overall packing effectively pads the source rows to
// (src_rows + 63) & ~63. The iteration over k may skip when m=1, and then we
// only pack for (src_rows + 31) & ~31. When there is an incomplete
// destination block, this is stored into trailing_buf instead of packed_ptr.
for (int k = 0; k < src_rows; k += kNumChunkedSrcRows) {
// Available source rows.
// If this is less than 0 (for m=1), we skip, having filled trailing
// buffer for m=0. Also, if source rows is zero on m=1, then we filled
// exactly to the end of the column in the packed buffer.
const int available_src_rows = src_rows - k;
// Effectively,
// available rows = std::max(0, std::min(8, src_rows - k));
// treat each case separately.
if (available_src_rows >= kNumChunkedSrcRows) {
if (sums_ptr) {
__m256i t0, t1, t2, t3, t4, t5, t6, t7;
__m256i r0, r1, r2, r3, r4, r5, r6, r7;
const __m256i input_xor_v = _mm256_set1_epi8(input_xor);
t0 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr0));
t4 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr4));
t1 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_ptr1));
Интересно, они это всё вручную писали или какой-то хуйнёй генерировали?