読者です 読者をやめる 読者になる 読者になる

prime's diary

そすうの日々を垂れ流しちゃうやつだよ

Bit scan reverseをAVXで並列化する

Bit scan reverse (bsr)

00000001010100010100000100101101
<-clz->*<---bit scan reverse--->

1になっている最も高位のビットのインデックスを0-indexedで計算します。0に対する結果は未定義(不定)です。 似た処理にcount leading zero(clz)があり、こちらは最上位ビット側から数えます。(0に対する結果は整数型のビット数num_bitsになる) したがって、0以外の値に対してはbsr + 1 + clz = num_bitsが成り立ちます。

bsr/lzcnt命令

x86にはズバリbsr命令があり、32ビット/64ビットの値に対してBit scan reverseを計算できます。 また、BMI拡張命令セットにはlzcntという命令があり、clzも1命令で計算できます。

しかし、bsr、lzcntともにSSEやAVX拡張命令セットの中に対応する命令がないため、並列に計算することができません。

今回はAVX命令を使って256ビット幅で並列にBit scan reverseを計算する方法を考えます。

また、8/16/32/64ビットそれぞれについて、どのアルゴリズムが最適なのかベンチマークを取ります。

アルゴリズム

スカラーに変換してbsrで処理 (naive)

まずは単純にスカラー値に変換してbsr命令で処理するのが一番自然でしょう。 pextr{b,w,d,q}は意外にコストが高いのでvmovdqaで一括でメモリ上(実際にはL1キャッシュに乗るはず)にコピーして処理することにします。 計算結果をYMMレジスタに格納するのは_mm256_set_epi* intrinsicsに任せました(メモリ上に書き戻してvmovdqaで格納するより速かった)

inline __m256i bsr_256_8_naive(__m256i x) {
  alignas(32) uint8_t b[32];
  _mm256_store_si256((__m256i*)b, x);
  return _mm256_setr_epi8(
      __bsrd(b[ 0]), __bsrd(b[ 1]), __bsrd(b[ 2]), __bsrd(b[ 3]),
      __bsrd(b[ 4]), __bsrd(b[ 5]), __bsrd(b[ 6]), __bsrd(b[ 7]),
      __bsrd(b[ 8]), __bsrd(b[ 9]), __bsrd(b[10]), __bsrd(b[11]),
      __bsrd(b[12]), __bsrd(b[13]), __bsrd(b[14]), __bsrd(b[15]),
      __bsrd(b[16]), __bsrd(b[17]), __bsrd(b[18]), __bsrd(b[19]),
      __bsrd(b[20]), __bsrd(b[21]), __bsrd(b[22]), __bsrd(b[23]),
      __bsrd(b[24]), __bsrd(b[25]), __bsrd(b[26]), __bsrd(b[27]),
      __bsrd(b[28]), __bsrd(b[29]), __bsrd(b[30]), __bsrd(b[31]));
}

inline __m256i bsr_256_16_naive(__m256i x) {
  alignas(32) uint16_t b[16];
  _mm256_store_si256((__m256i*)b, x);
  return _mm256_setr_epi16(
      __bsrd(b[ 0]), __bsrd(b[ 1]), __bsrd(b[ 2]), __bsrd(b[ 3]),
      __bsrd(b[ 4]), __bsrd(b[ 5]), __bsrd(b[ 6]), __bsrd(b[ 7]),
      __bsrd(b[ 8]), __bsrd(b[ 9]), __bsrd(b[10]), __bsrd(b[11]),
      __bsrd(b[12]), __bsrd(b[13]), __bsrd(b[14]), __bsrd(b[15]));
}

inline __m256i bsr_256_32_naive(__m256i x) {
  alignas(32) uint32_t b[8];
  _mm256_store_si256((__m256i*)b, x);
  return _mm256_setr_epi32(
      __bsrd(b[0]), __bsrd(b[1]), __bsrd(b[2]), __bsrd(b[3]),
      __bsrd(b[4]), __bsrd(b[5]), __bsrd(b[6]), __bsrd(b[7]));
}

inline __m256i bsr_256_64_naive(__m256i x) {
  alignas(32) uint64_t b[4];
  _mm256_store_si256((__m256i*)b, x);
  return _mm256_setr_epi64x(
      __bsrq(b[0]), __bsrq(b[1]), __bsrq(b[2]), __bsrq(b[3]));
}

popcountに帰着 (popcount)

uint32_t bsr32(uint32_t x) {
  x |= x >> 1;
  x |= x >> 2;
  x |= x >> 4;
  x |= x >> 8;
  x |= x >> 16;
  return popcount32(x>>1);
}

右シフトとORを組み合わせることで1の立っている一番高位のビットより下位のビットを1で埋めることができます。

これをpopcountと組み合わせればbsrが求められます。

inline __m256i popcount_256_8(__m256i x) {
  x = _mm256_sub_epi8(x, _mm256_and_si256(_mm256_srli_epi16(x, 1), _mm256_set1_epi8(0x55)));
  x = _mm256_add_epi8(_mm256_and_si256(x, _mm256_set1_epi8(0x33)), _mm256_and_si256(_mm256_srli_epi16(x, 2), _mm256_set1_epi8(0x33)));
  return _mm256_and_si256(_mm256_add_epi8(x, _mm256_srli_epi16(x, 4)), _mm256_set1_epi8(0x0F));
}

inline __m256i bsr_256_8_popcnt(__m256i x) {
  x = _mm256_or_si256(x, _mm256_and_si256(_mm256_srli_epi16(x, 1), _mm256_set1_epi8(0x7F)));
  x = _mm256_or_si256(x, _mm256_and_si256(_mm256_srli_epi16(x, 2), _mm256_set1_epi8(0x3F)));
  x = _mm256_and_si256(_mm256_srli_epi16(_mm256_or_si256(x, _mm256_and_si256(_mm256_srli_epi16(x, 4), _mm256_set1_epi8(0x0F))), 1), _mm256_set1_epi8(0x7F));
  return popcount_256_8(x);
}

inline __m256i popcount_256_16(__m256i x) {
  x = _mm256_add_epi16(_mm256_and_si256(x, _mm256_set1_epi16(0x5555)), _mm256_and_si256(_mm256_srli_epi16(x, 1), _mm256_set1_epi16(0x5555)));
  x = _mm256_add_epi16(_mm256_and_si256(x, _mm256_set1_epi16(0x3333)), _mm256_and_si256(_mm256_srli_epi16(x, 2), _mm256_set1_epi16(0x3333)));
  x = _mm256_add_epi16(_mm256_and_si256(x, _mm256_set1_epi16(0x0F0F)), _mm256_and_si256(_mm256_srli_epi16(x, 4), _mm256_set1_epi16(0x0F0F)));
  return _mm256_maddubs_epi16(x, _mm256_set1_epi16(0x0101));
}

inline __m256i bsr_256_16_popcnt(__m256i x) {
  x = _mm256_or_si256(x, _mm256_srli_epi16(x, 1));
  x = _mm256_or_si256(x, _mm256_srli_epi16(x, 2));
  x = _mm256_or_si256(x, _mm256_srli_epi16(x, 4));
  x = _mm256_srli_epi16(_mm256_or_si256(x, _mm256_srli_epi16(x, 8)), 1);
  return popcount_256_16(x);
}

inline __m256i popcount_256_32(__m256i x) {
  x = _mm256_sub_epi32(x, _mm256_and_si256(_mm256_srli_epi32(x, 1), _mm256_set1_epi8(0x55)));
  x = _mm256_add_epi32(_mm256_and_si256(x, _mm256_set1_epi8(0x33)), _mm256_and_si256(_mm256_srli_epi32(x, 2), _mm256_set1_epi8(0x33)));
  x = _mm256_and_si256(_mm256_add_epi8(x, _mm256_srli_epi32(x, 4)), _mm256_set1_epi8(0x0F));
  return _mm256_srli_epi16(_mm256_madd_epi16(x, _mm256_set1_epi16(0x0101)), 8);
}

inline __m256i bsr_256_32_popcnt(__m256i x) {
  x = _mm256_or_si256(x, _mm256_srli_epi32(x, 1));
  x = _mm256_or_si256(x, _mm256_srli_epi32(x, 2));
  x = _mm256_or_si256(x, _mm256_srli_epi32(x, 4));
  x = _mm256_or_si256(x, _mm256_srli_epi32(x, 8));
  x = _mm256_srli_epi32(_mm256_or_si256(x, _mm256_srli_epi32(x, 16)), 1);
  return popcount_256_32(x);
}

inline __m256i bsr_256_64_popcnt(__m256i x) {
  x = _mm256_or_si256(x, _mm256_srli_epi64(x, 1));
  x = _mm256_or_si256(x, _mm256_srli_epi64(x, 2));
  x = _mm256_or_si256(x, _mm256_srli_epi64(x, 4));
  x = _mm256_or_si256(x, _mm256_srli_epi64(x, 8));
  x = _mm256_or_si256(x, _mm256_srli_epi64(x, 16));
  x = _mm256_srli_epi64(_mm256_or_si256(x, _mm256_srli_epi64(x, 32)), 1);
  return popcount_256_64(x);
}

シャッフル命令vpshufbを使った最適化 (rev popcount)

64ビット単位のときのみ、しかも若干効果がある程度ですが、vpshufbを使ってバイト順を反転させ、立っている一番下位のビットを算術演算で求めるテクニックを使うことができます。

inline __m256i bswap_256_64(__m256i x) {
  __m128i swap_table_128 = _mm_setr_epi8(7, 6, 5, 4, 3, 2, 1, 0, 15, 14, 13, 12, 11, 10, 9, 8);
  __m256i swap_table = _mm256_broadcastsi128_si256(swap_table_128);
  return _mm256_shuffle_epi8(x, swap_table);
}

inline __m256i bsr_256_64_rev_popcnt(__m256i x) {
  x = _mm256_or_si256(x, _mm256_srli_epi64(x, 1));
  x = _mm256_or_si256(x, _mm256_srli_epi64(x, 2));
  x = _mm256_or_si256(x, _mm256_srli_epi64(x, 4));
  x = _mm256_andnot_si256(_mm256_srli_epi64(x, 1), x);
  x = bswap_256_64(x);
  x = _mm256_and_si256(x, _mm256_sub_epi64(_mm256_setzero_si256(), x));
  x = bswap_256_64(x);
  return popcount_256_64(_mm256_sub_epi64(x, _mm256_set1_epi64x(1)));
}

ギャザー命令を使ったテーブル参照 (table gather)

32/64ビット単位で表引きができるvpgather{d,q}{d,q}を使ってpopcountを表引きします。32ビット単位なら4byte*231byte = 8GBに収まります。 表引きは000111...111の形のものしかないので実は32通りしかないので、L1キャッシュに乗る可能性が高いです。 しかもmallocでメモリを確保してもアクセスしていない領域には実際にはページ割当がされないので物理メモリ使用量はそこまで大きくなりません。

int32_t *table_32;

void init_table_32() {
  table_32 = (int32_t*)malloc(sizeof(int32_t)*((size_t)1 << 31));
  table_32[0] = 0;
  int64_t j = 1;
  for (int i = 1; i < 31; ++i) {
    table_32[j] = i;
    j = 2*j + 1;
  }
}

inline __m256i bsr_256_32_table_gather(__m256i x) {
  x = _mm256_or_si256(x, _mm256_srli_epi32(x, 1));
  x = _mm256_or_si256(x, _mm256_srli_epi32(x, 2));
  x = _mm256_or_si256(x, _mm256_srli_epi32(x, 4));
  x = _mm256_or_si256(x, _mm256_srli_epi32(x, 8));
  x = _mm256_srli_epi32(_mm256_or_si256(x, _mm256_srli_epi32(x, 16)), 1);
  return _mm256_i32gather_epi32(table_32, x, 4);
}

vcvtdq2ps命令による浮動小数点数への変換 (cvtfloat)

浮動小数点数のビット表現は

kivantium.hateblo.jp

などを参照してください。

vcvtdq2ps命令で浮動小数点数へ変換して、指数部だけになるようにシフトして、下駄履き(単精度なら127)を引いてやることでbsrが得られます。 しかしこれには罠があります。具体的には、

  • ビットが浮動小数点数の精度以上に連続していると丸めによって間違った結果になる(1大きくなってしまう)。
  • 最上位ビットは符号として扱われるので間違った結果になる(負数扱いになり全く違う値になる)。

これらへの対策については

或るプログラマの一生 » x86 のビットスキャン命令の SIMD 化

或るプログラマの一生 » x86 のビットスキャン命令の SIMD 化(その2)

を参照してください。

16ビット以下では対策は不要になります。 また、64ビット版では64ビットを3分割することで問題を回避したバージョンも示します。

inline __m256i bsr_256_32_cvtfloat_impl(__m256i x, int32_t sub) {
  __m256i cvt_fl = _mm256_castps_si256(_mm256_cvtepi32_ps(x));
  __m256i shifted = _mm256_srli_epi32(cvt_fl, 23);
  return _mm256_sub_epi32(shifted, _mm256_set1_epi32(sub));
}

inline __m256i bsr_256_8_cvtfloat(__m256i x) {
  __m256i r0 = bsr_256_32_cvtfloat_impl(_mm256_and_si256(x, _mm256_set1_epi32(0x000000FF)), 127);
  __m256i r1 = bsr_256_32_cvtfloat_impl(_mm256_and_si256(_mm256_srli_epi32(x,  8), _mm256_set1_epi32(0x000000FF)), 127);
  __m256i r2 = bsr_256_32_cvtfloat_impl(_mm256_and_si256(_mm256_srli_epi32(x, 16), _mm256_set1_epi32(0x000000FF)), 127);
  __m256i r3 = bsr_256_32_cvtfloat_impl(_mm256_srli_epi32(x, 24), 127);
  __m256i r02 = _mm256_blend_epi16(r0, _mm256_slli_epi32(r2, 16), 0xAA);
  __m256i r13 = _mm256_blend_epi16(r1, _mm256_slli_epi32(r3, 16), 0xAA);
  return _mm256_blendv_epi8(r02, _mm256_slli_epi16(r13, 8), _mm256_set1_epi16(0xFF00));
}

inline __m256i bsr_256_16_cvtfloat(__m256i x) {
  __m256i lo = bsr_256_32_cvtfloat_impl(_mm256_and_si256(x, _mm256_set1_epi32(0x0000FFFF)), 127);
  __m256i hi = bsr_256_32_cvtfloat_impl(_mm256_srli_epi32(x, 16), 127);
  return _mm256_blend_epi16(lo, _mm256_slli_epi32(hi, 16), 0xAA);
}

inline __m256i bsr_256_32_cvtfloat(__m256i x) {
  x = _mm256_andnot_si256(_mm256_srli_epi32(x, 1), x); // 連続するビット対策
  __m256i result = bsr_256_32_cvtfloat_impl(x, 127);
  result = _mm256_or_si256(result, _mm256_srai_epi32(x, 31));
  return _mm256_and_si256(result, _mm256_set1_epi32(0x0000001F));
}

inline __m256i bsr_256_64_cvtfloat(__m256i x) {
  __m256i bsr32 = bsr_256_32_cvtfloat(x);
  __m256i higher = _mm256_add_epi32(_mm256_srli_epi64(bsr32, 32), _mm256_set1_epi64x(0x0000000000000020));
  __m256i mask = _mm256_shuffle_epi32(_mm256_cmpeq_epi32(x, _mm256_setzero_si256()), 0xF5);
  return _mm256_blendv_epi8(higher, _mm256_and_si256(bsr32, _mm256_set1_epi64x(0xFFFFFFFF)), mask);
}

inline __m256i bsr_256_64_cvtfloat2(__m256i x) {
  __m256i bsr0 = bsr_256_32_cvtfloat_impl(_mm256_and_si256(x, _mm256_set1_epi64x(0x3FFFFF)), 127);
  __m256i bsr1 = bsr_256_32_cvtfloat_impl(_mm256_and_si256(_mm256_srli_epi64(x, 22), _mm256_set1_epi64x(0x3FFFFF)), 105);
  __m256i bsr2 = bsr_256_32_cvtfloat_impl(_mm256_srli_epi64(x, 44), 83);
  __m256i result = _mm256_max_epi32(_mm256_max_epi32(bsr0, bsr1), bsr2);
  return _mm256_and_si256(result, _mm256_set1_epi64x(0xFFFFFFFF));
}

vpshufbを使ったテーブル参照 (table)

バイト単位のシャッフル命令vpshufbは、バイト列の並べ替え以外に、4ビット入力、8ビット出力のルックアップテーブルとして用いることができます。 これを用いて、4ビット単位でbsrを求めて、それを組み合わせることで8ビット以上のbsrを求めることができます。

inline __m256i bsr_256_8_table(__m256i x) {
  __m128i table128_lo = _mm_setr_epi8(0, 0, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3);
  __m128i table128_hi = _mm_setr_epi8(0, 4, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7);
  __m256i table_lo = _mm256_broadcastsi128_si256(table128_lo);
  __m256i table_hi = _mm256_broadcastsi128_si256(table128_hi);
  return _mm256_max_epi8(
      _mm256_shuffle_epi8(table_lo, _mm256_and_si256(x, _mm256_set1_epi8(0x0F))),
      _mm256_shuffle_epi8(table_hi, _mm256_and_si256(_mm256_srli_epi16(x, 4), _mm256_set1_epi8(0x0F))));
}

inline __m256i bsr_256_16_table(__m256i x) {
  __m256i half = bsr_256_8_table(x);
  __m256i mask = _mm256_cmpgt_epi8(x, _mm256_setzero_si256());
  __m256i res8 = _mm256_add_epi8(half, _mm256_and_si256(mask, _mm256_set1_epi16(0x0800)));
  return _mm256_max_epi16(_mm256_and_si256(res8, _mm256_set1_epi16(0x00FF)), _mm256_srli_epi16(res8, 8)); 
}

ベンチマーク

Haswell/Broadwell/Skylakeの各アーキテクチャのCPUで各アルゴリズムを並列に実行した時のの実行時間を計測しました。

bsrが合計231回実行されるように、8ビット幅なら226回、64ビット幅なら229回ループを回します。

Parallel bit scan reverse

環境

Haswell:

  • Google compute engine n1-highmem-2 (2vCPU, 13GB mem)
  • Xeon 2.3GHz, Turbo Boost時は不明
  • OS: Ubuntu 16.10
  • Compiler: GCC 6.1.1

Broadwell:

  • Google compute engine n1-highmem-2 (2vCPU, 13GB mem)
  • Xeon 2.2GHz, Turbo Boost時は不明
  • OS: Ubuntu 16.10
  • Compiler: GCC 6.1.1

Skylake:

  • ローカル
  • Core i7-6700K 4コア8スレッド 4GHz, Turbo Boost 4.2GHz 64GB mem
  • OS: Arch Linux
  • Compiler: GCC 6.2.1

コンパイルオプションはいずれも

-O3 -flto -mtune=(アーキテクチャ名) -march=(アーキテクチャ名) -lboost_system -lboost_timer

結果

(異なるアーキテクチャ同士で実行時間を比較するのはTurbo boostなどの影響があるのであまり意味がないと思います。)

8ビット

アルゴリズム Haswell Broadwell Skylake
naive 1.323s 0.993s 0.832s
cvtfloat 0.284s 0.209s 0.165s
popcnt 0.264s 0.201s 0.169s
table 0.079s 0.058s 0.049s

いずれも表引きが一番高速なようです。

16ビット

アルゴリズム Haswell Broadwell Skylake
naive 1.562s 1.192s 0.948s
cvtfloat 0.257s 0.200s 0.146s
popcnt 0.550s 0.423s 0.337s
table 0.262s 0.199s 0.179s

表引きあるいは浮動小数点数への変換が一番高速な結果になりました。

32ビット

アルゴリズム Haswell Broadwell Skylake
naive 1.581s 1.220s 0.993s
cvtfloat 0.416s 0.318s 0.237s
popcnt 1.278s 0.966s 0.701s
table gather 1.202s 0.915s 0.523s

浮動小数点数への変換が圧倒的に速い結果になりました。 gather命令を使った表引きはSkylakeで若干高速になったようです。

64ビット

アルゴリズム Haswell Broadwell Skylake
naive 1.574s 1.220s 0.911s
cvtfloat 1.369s 1.037s 0.804s
cvtfloat2 1.595s 1.225s 0.897s
popcnt 2.521s 1.924s 1.498s
rev popcnt 2.448s 1.898s 1.476s

ここでも浮動小数点数への変換が一番速い結果となりました。3分割してコーナーケースの処理を不要にしたバージョンは思ったほど性能が出ませんでした。 64ビット幅になるとpopcntは命令数が増えてナイーブな実装より遅くなってしまいました。

来るAVX512時代には

少なくとも32ビット幅以上なら並列でclzを求めるvplzcnt{d,q}命令が追加されるのでそれを使えば良さそうです。 8ビット幅の場合vpermi2bを使って7ビットの表引きをして求めるのがいいかもしれません。 16ビット幅の場合はやってみないとわからなさそうです。

参考文献

或るプログラマの一生 » x86 のビットスキャン命令の SIMD 化

或るプログラマの一生 » x86 のビットスキャン命令の SIMD 化(その2)

bsrを並列処理する方法について考察した記事。当時はAVX2が使えるプロセッサはまだ発売されていなかった。