prime's diary

そすうの日々を垂れ流しちゃうやつだよ

効率的かつハードウェアフレンドリーな3値データのエンコーディングの提案 【Densely Packed Ternary, DPT】

皆さん、3値データ、使っていますか?

どうやら最近はLLMの文脈で3値データが注目を浴びているようです。

arxiv.org

この論文では、重みに3値データを使うことで大幅な軽量化が可能、という触れ込みです。 ここ一か月でも動作する実装が公開されたり、続きの論文が発表されるなど、ホットな状況が続いています。

3値(ternary)データとは、3種類の値を取るデータのことです。 3種類の値の例としては、 \lbrace 0, 1, 2 \rbrace の他に、 \lbrace -1, 0, 1 \rbrace  \lbrace ✊,✌,✋ \rbrace  \lbrace 松,竹,梅 \rbrace  \lbrace \lt , = , \gt \rbrace などもあります。

✊↔0, ✌↔1, ✋↔2のように、あらかじめ \lbrace 0, 1, 2 \rbrace との対応関係を決めておくことで、これらの値を \lbrace 0, 1, 2 \rbrace で表現することができます。 以下、3値データは \lbrace 0, 1, 2 \rbrace の形で表されているものとします。

用語・記法

2値(binary)データの大きさはbit(s)で表しますが、3値データの大きさを表すのにはtrit(s)が用いられます。 tritという名前はtrinary digitから来ているそうです。

2進法・3進法・16進法表記の接頭辞としてそれぞれ0b, 0t, 0xを使います。(10進法表記で)89は16進法表記で0x59、2進法表記で0b1011001、3進法表記で0t10022となります。 10進法表記と同じく、上位桁が前(左)に来るように並べます。

既存のternay->binaryエンコードアルゴリズムの課題

3値データをバイナリデータにエンコードすることを考えます。 5tritのデータは243通りあるので、8bit(256通り)のバイナリにエンコードすることができます。 5tritより長いデータは5trit単位で区切り、それぞれを8bitにエンコードしてから連結すればよいです*1。 こうすることでそこそこの空間効率を実現しつつ、計算コストを抑えることができます。

5tritを8bitにエンコードする方法として、愚直に3進2進変換をするアルゴリズムを考えます。 このアルゴリズムでは、デコード時に繰り返し3で除算する必要があり、デコードの計算が比較的重いという問題があります。 あらかじめ \frac{256}{243}倍しておいて、デコード時には3を繰り返し掛ける、という方法もあります。 詳細は解説記事(英語)を読んでください。

3を掛けるのはx + (x << 1)でできるため、デコード時には乗算器は不要で、加算器だけあればよいです。 しかし、今度はエンコード時に定数による除算(乗算に置き換え可能)が必要になってしまいます。 このエンコード処理を論理回路として実装すると、回路規模がかなり大きくなってしまいます。

Densely Packet Ternary(DPT)エンコード

今回紹介するDensely Packed Ternary(DPT)はエンコードもデコードも比較的低コストでできるアルゴリズムとなっています。

DPTでは同じく5tritを8bitに変換します。 まず、5tritを2,2,1tritの「サブブロック」に分割し、以後各サブブロックをB1, B2, B3とします。B1が下位trit側でB3が上位trit側となることに注意してください。 各サブブロックを3進2進変換でそれぞれ4,4,2bitsに変換します。 2tritを4bitに変換する際、上位tritに3を掛けるのに乗算が必要ですが、定数3と0,1,2のいずれかを掛けるだけなので、0,3,6から選択する回路で実現できます。

3 * t\_2 + t\_0

大きい数、小さい数

DPTにおいて重要な概念である、「大きい数」「小さい数」について説明します。

2tritを3進2進変換でエンコードすると、

ternary binary
00 0000
01 0001
02 0010
10 0011
11 0100
12 0101
20 0110
21 0111
22 1000

となります。ここで、binaryの最上位ビットに注目すると、1通りだけ1でそれ以外は0です。 デコードするときには、

  • 最上位ビットが1なら、下位ビットを見ずに22であると確定。
  • 最上位ビットが0なら、下位3ビットの組み合わせ全8通りにそれぞれ対応する値(00~21)があるのでそれを求める。

1tritをエンコードするときは、

ternary binary
0 00
1 01
2 10

となり、デコード時にはbinaryの最上位ビットが0b1なら0t2で確定、0b0なら下位1ビットに応じて0t0か0t1かが決まります。

1trit, 2tritのいずれの場合でも最上位ビットが0なら残りのビットをエンコードする必要があり、1なら残りのビットは省略することができます。 最上位ビットが0の数を「小さい数」、1の数を「大きい数」と呼んで区別することにします。

大小による場合分け

B1, B2の2ブロックの大小の組み合わせは全部で4通りあり、B3は1tritなので3通りあります。 B3は00,01,10のいずれかにエンコードされるため、11に別の組み合わせを割り当てることでうまく8bitに収めることができます。

具体的には、B3をエンコードした値を0bXY、B1,B2をエンコードした値の下位3bitを0bEFG 0bBCDと書くことにすると、

B2 B1 7 6 5 4 3 2 1 0
S S X Y B C D E F G
S L 1 1 X Y 0 B C D
L S 1 1 X Y 1 E F G
L L 1 1 1 1 X Y 0 0

(ただし、S が小さい数、 L が大きい数。上位trit/bitを左に書いていることに注意)

エンコードすればよいです。 デコード時には、上位ビットから読めば場合分けできます。

ここからさらに"よい"性質を持たせることを考えます。 例えば、上位3tritが0であるとわかっているようなときにエンコード・デコードを簡略化できるとうれしいです。 具体的には、0t00000から0t00022をエンコードした結果の下位4bitが0b0000から0b1000になっていると、この範囲のデコードが非常に単純になります。 今のエンコードでは0t22のときに0b11000000にエンコードされるので、下位ビットが0b1000とは異なっています。

実は、次のようにエンコードすることで実現できます。 小さい数の下位ビットをB1,B2,B3それぞれについて0bEFG 0bBCD 0bAと書くことにすると、

B3 B2 B1 7 6 5 4 3 2 1 0
S S S 0 B C D A E F G
L S S 1 B C D 0 E F G
S S L 1 B C D 1 0 0 A
L S L 1 B C D 1 0 1 0
S L S 1 E F G 1 1 0 A
L L S 1 E F G 1 1 1 0
S L L 1 0 0 A 1 0 1 1
L L L 1 0 1 0 1 0 1 1

(S が小さい数、 L が大きい数)

実際に確認してみましょう。 まず、上位3trit以上が0のとき、B2,B3は0となるため、必ず小さい数となります。 表からB2,B3がLとなっている行を取り除いて、

B1 7 6 5 4 3 2 1 0
S 0 B C D A E F G
L 1 B C D 1 0 0 A

さらに、A,B,C,Dは0となるため、

B1 7 6 5 4 3 2 1 0
S 0 0 0 0 0 E F G
L 1 0 0 0 1 0 0 0

下位4bitを見ると、確かに下位2tritであるB1を4bitにエンコードした結果と一致しています。また、上位bitは7bit目は3bit目をコピーし、あとは0埋めしたものになっています。

実は、このエンコーディングでは0t02122(=71)以下の値をエンコードすると、素直に2trit→4bit変換をして順にくっつけたものと下位7bitが一致します。 この範囲ではB3が0、B2が小さい数となるため、エンコード規則は次のように非常に単純になります。

B1 7 6 5 4 3 2 1 0
S 0 B C D 0 E F G
L 1 B C D 1 0 0 0

元ネタ

このDPTには「元ネタ」があります。10進数を2進数にエンコードするときに用いられるDPD(Densely Packed Decimal)です。 Wikipediaの記事にアルゴリズムの詳しい解説が載っています。

Densely packed decimal - Wikipedia

この記事の解説もWikipediaのDPDの解説を参考にさせていただきました。

実装

Rustにて実装したものがこちらになります。

https://github.com/primenumber/densely_packed_ternary

このリポジトリには、DPT以外にも他にいくつかのアルゴリズムの実装もおいてあります。 ぜひ使ってみてください(あとでドキュメント等を整備しつつcrates.ioにも上げておきます)。 それでは楽しい3値ライフを!

*1:長さが5の倍数でないときは0埋めしたり端数を別で持ったりすればよい