皆さん、3値データ、使っていますか?
どうやら最近はLLMの文脈で3値データが注目を浴びているようです。
この論文では、重みに3値データを使うことで大幅な軽量化が可能、という触れ込みです。 ここ一か月でも動作する実装が公開されたり、続きの論文が発表されるなど、ホットな状況が続いています。
3値(ternary)データとは、3種類の値を取るデータのことです。 3種類の値の例としては、の他に、や、、 などもあります。
✊↔0, ✌↔1, ✋↔2のように、あらかじめとの対応関係を決めておくことで、これらの値をで表現することができます。 以下、3値データはの形で表されているものとします。
用語・記法
2値(binary)データの大きさはbit(s)で表しますが、3値データの大きさを表すのにはtrit(s)が用いられます。 tritという名前はtrinary digitから来ているそうです。
2進法・3進法・16進法表記の接頭辞としてそれぞれ0b
, 0t
, 0x
を使います。(10進法表記で)89
は16進法表記で0x59
、2進法表記で0b1011001
、3進法表記で0t10022
となります。
10進法表記と同じく、上位桁が前(左)に来るように並べます。
既存のternay->binaryエンコードアルゴリズムの課題
3値データをバイナリデータにエンコードすることを考えます。 5tritのデータは243通りあるので、8bit(256通り)のバイナリにエンコードすることができます。 5tritより長いデータは5trit単位で区切り、それぞれを8bitにエンコードしてから連結すればよいです*1。 こうすることでそこそこの空間効率を実現しつつ、計算コストを抑えることができます。
5tritを8bitにエンコードする方法として、愚直に3進2進変換をするアルゴリズムを考えます。 このアルゴリズムでは、デコード時に繰り返し3で除算する必要があり、デコードの計算が比較的重いという問題があります。 あらかじめ倍しておいて、デコード時には3を繰り返し掛ける、という方法もあります。 詳細は解説記事(英語)を読んでください。
3を掛けるのはx + (x << 1)
でできるため、デコード時には乗算器は不要で、加算器だけあればよいです。
しかし、今度はエンコード時に定数による除算(乗算に置き換え可能)が必要になってしまいます。
このエンコード処理を論理回路として実装すると、回路規模がかなり大きくなってしまいます。
Densely Packet Ternary(DPT)エンコード
今回紹介するDensely Packed Ternary(DPT)はエンコードもデコードも比較的低コストでできるアルゴリズムとなっています。
DPTでは同じく5tritを8bitに変換します。 まず、5tritを2,2,1tritの「サブブロック」に分割し、以後各サブブロックをB1, B2, B3とします。B1が下位trit側でB3が上位trit側となることに注意してください。 各サブブロックを3進2進変換でそれぞれ4,4,2bitsに変換します。 2tritを4bitに変換する際、上位tritに3を掛けるのに乗算が必要ですが、定数3と0,1,2のいずれかを掛けるだけなので、0,3,6から選択する回路で実現できます。
大きい数、小さい数
DPTにおいて重要な概念である、「大きい数」「小さい数」について説明します。
2tritを3進2進変換でエンコードすると、
ternary | binary |
---|---|
00 | 0000 |
01 | 0001 |
02 | 0010 |
10 | 0011 |
11 | 0100 |
12 | 0101 |
20 | 0110 |
21 | 0111 |
22 | 1000 |
となります。ここで、binaryの最上位ビットに注目すると、1通りだけ1でそれ以外は0です。 デコードするときには、
- 最上位ビットが1なら、下位ビットを見ずに22であると確定。
- 最上位ビットが0なら、下位3ビットの組み合わせ全8通りにそれぞれ対応する値(00~21)があるのでそれを求める。
1tritをエンコードするときは、
ternary | binary |
---|---|
0 | 00 |
1 | 01 |
2 | 10 |
となり、デコード時にはbinaryの最上位ビットが0b1なら0t2で確定、0b0なら下位1ビットに応じて0t0か0t1かが決まります。
1trit, 2tritのいずれの場合でも最上位ビットが0なら残りのビットをエンコードする必要があり、1なら残りのビットは省略することができます。 最上位ビットが0の数を「小さい数」、1の数を「大きい数」と呼んで区別することにします。
大小による場合分け
B1, B2の2ブロックの大小の組み合わせは全部で4通りあり、B3は1tritなので3通りあります。 B3は00,01,10のいずれかにエンコードされるため、11に別の組み合わせを割り当てることでうまく8bitに収めることができます。
具体的には、B3をエンコードした値を0bXY
、B1,B2をエンコードした値の下位3bitを0bEFG
0bBCD
と書くことにすると、
B2 | B1 | 7 | 6 | 5 | 4 | 3 | 2 | 1 | 0 |
---|---|---|---|---|---|---|---|---|---|
S | S | X | Y | B | C | D | E | F | G |
S | L | 1 | 1 | X | Y | 0 | B | C | D |
L | S | 1 | 1 | X | Y | 1 | E | F | G |
L | L | 1 | 1 | 1 | 1 | X | Y | 0 | 0 |
(ただし、S
が小さい数、 L
が大きい数。上位trit/bitを左に書いていることに注意)
とエンコードすればよいです。 デコード時には、上位ビットから読めば場合分けできます。
ここからさらに"よい"性質を持たせることを考えます。 例えば、上位3tritが0であるとわかっているようなときにエンコード・デコードを簡略化できるとうれしいです。 具体的には、0t00000から0t00022をエンコードした結果の下位4bitが0b0000から0b1000になっていると、この範囲のデコードが非常に単純になります。 今のエンコードでは0t22のときに0b11000000にエンコードされるので、下位ビットが0b1000とは異なっています。
実は、次のようにエンコードすることで実現できます。
小さい数の下位ビットをB1,B2,B3それぞれについて0bEFG
0bBCD
0bA
と書くことにすると、
B3 | B2 | B1 | 7 | 6 | 5 | 4 | 3 | 2 | 1 | 0 |
---|---|---|---|---|---|---|---|---|---|---|
S | S | S | 0 | B | C | D | A | E | F | G |
L | S | S | 1 | B | C | D | 0 | E | F | G |
S | S | L | 1 | B | C | D | 1 | 0 | 0 | A |
L | S | L | 1 | B | C | D | 1 | 0 | 1 | 0 |
S | L | S | 1 | E | F | G | 1 | 1 | 0 | A |
L | L | S | 1 | E | F | G | 1 | 1 | 1 | 0 |
S | L | L | 1 | 0 | 0 | A | 1 | 0 | 1 | 1 |
L | L | L | 1 | 0 | 1 | 0 | 1 | 0 | 1 | 1 |
(S
が小さい数、 L
が大きい数)
実際に確認してみましょう。 まず、上位3trit以上が0のとき、B2,B3は0となるため、必ず小さい数となります。 表からB2,B3がLとなっている行を取り除いて、
B1 | 7 | 6 | 5 | 4 | 3 | 2 | 1 | 0 |
---|---|---|---|---|---|---|---|---|
S | 0 | B | C | D | A | E | F | G |
L | 1 | B | C | D | 1 | 0 | 0 | A |
さらに、A,B,C,Dは0となるため、
B1 | 7 | 6 | 5 | 4 | 3 | 2 | 1 | 0 |
---|---|---|---|---|---|---|---|---|
S | 0 | 0 | 0 | 0 | 0 | E | F | G |
L | 1 | 0 | 0 | 0 | 1 | 0 | 0 | 0 |
下位4bitを見ると、確かに下位2tritであるB1を4bitにエンコードした結果と一致しています。また、上位bitは7bit目は3bit目をコピーし、あとは0埋めしたものになっています。
実は、このエンコーディングでは0t02122(=71)以下の値をエンコードすると、素直に2trit→4bit変換をして順にくっつけたものと下位7bitが一致します。 この範囲ではB3が0、B2が小さい数となるため、エンコード規則は次のように非常に単純になります。
B1 | 7 | 6 | 5 | 4 | 3 | 2 | 1 | 0 |
---|---|---|---|---|---|---|---|---|
S | 0 | B | C | D | 0 | E | F | G |
L | 1 | B | C | D | 1 | 0 | 0 | 0 |
元ネタ
このDPTには「元ネタ」があります。10進数を2進数にエンコードするときに用いられるDPD(Densely Packed Decimal)です。 Wikipediaの記事にアルゴリズムの詳しい解説が載っています。
Densely packed decimal - Wikipedia
この記事の解説もWikipediaのDPDの解説を参考にさせていただきました。
実装
Rustにて実装したものがこちらになります。
https://github.com/primenumber/densely_packed_ternary
このリポジトリには、DPT以外にも他にいくつかのアルゴリズムの実装もおいてあります。 ぜひ使ってみてください(あとでドキュメント等を整備しつつcrates.ioにも上げておきます)。 それでは楽しい3値ライフを!
*1:長さが5の倍数でないときは0埋めしたり端数を別で持ったりすればよい