Gosuke Miyashita about feed

Pyxisの開発: safetensorsパーサー実装

30 June 2026

前回の記事 でPyxisを軽く紹介した。LLM推論エンジンを自前実装するプロジェクトで、今回はその最初のステップであるsafetensorsパーサーの実装について書く。


safetensorsとは

LLMの重みファイルのフォーマットはいくつかあるが、HuggingFaceが策定した safetensors が量子化されていないモデルの配布フォーマットとして現在のデファクトになっている。名前の通り、従来のPickleベースのフォーマット(PyTorchの .pt ファイル)に比べて安全で、任意コードの実行ができない。

バイナリフォーマットの構造はシンプルで、こうなっている。

[ 8 bytes : header size (u64 LE) ]
[ N bytes : JSON header          ]
[ tensor data                    ]

最初の8バイトにヘッダーのサイズがリトルエンディアンのu64で入っていて、その後ろにそのサイズ分のJSONが続く。JSONの中に各テンソルのメタデータ(名前、dtype、shape、データのオフセット)が入っている。テンソルの実データはJSONの後ろにそのまま並んでいる。

{
  "__metadata__": {"format": "pt"},
  "model.embed_tokens.weight": {
    "dtype": "BF16",
    "shape": [151936, 2048],
    "data_offsets": [0, 622329856]
  }
}

dtype はテンソルの要素の数値型。主なものはこんな感じ。

dtype ビット数 説明
F32 32bit 単精度浮動小数点。指数部8bit、仮数部23bit。精度が高いが重い
F16 16bit 半精度浮動小数点。指数部5bit、仮数部10bit
BF16 16bit Brain Float 16。Google Brainが考案(現在はGoogle DeepMindに統合)。指数部8bit、仮数部7bit。F32と指数部の幅が同じなので表現できる数値のスケールがF32と同じ(最大値約3.38×10^38、最小値約1.17×10^-38)。精度はF32より低いがスケールを保ったままメモリを半分に抑えられるため、LLMの学習・推論では広く使われる
I8 8bit 8ビット整数。量子化済みモデルで使われる

BF16はF32の指数部をそのままに仮数部を23bitから7bitに切り詰めた、いわばF32の軽量版。メモリ使用量はF32の半分になりながら、数値のスケール(桁の範囲)はF32と同じに保てる。LLMの学習では勾配の値が幅広い範囲に散らばるため、スケールを保ったBF16はF32からの置き換えとして相性がよく、最近のモデルはBF16で配布されているものが多い。F16も同じ16bitだが指数部が5bitと狭く、スケールはBF16より大幅に小さい。

shape はテンソルの次元数と各次元のサイズで、たとえば [151936, 2048] なら151936行×2048列の行列を表す。


実装

パーサー本体はこんな感じ。

pub fn parse_header(path: &Path) -> io::Result<HashMap<String, TensorInfo>> {
    let mut file = File::open(path)?;

    let mut size_buf = [0u8; 8];
    file.read_exact(&mut size_buf)?;
    let header_size = u64::from_le_bytes(size_buf) as usize;

    let mut header_buf = vec![0u8; header_size];
    file.read_exact(&mut header_buf)?;

    let header_json: HashMap<String, serde_json::Value> =
        serde_json::from_slice(&header_buf)?;

    let mut tensors = HashMap::new();
    for (key, value) in header_json {
        if key == "__metadata__" {
            continue;
        }
        let info: TensorInfo = serde_json::from_value(value)?;
        tensors.insert(key, info);
    }

    Ok(tensors)
}

ファイルを開いて最初の8バイトを読んでヘッダーサイズを得る、ヘッダーサイズ分のバイトを読んでJSONとしてパースする、 __metadata__ キーをスキップして各テンソルのメタデータを取り出す、という流れ。

#[derive(Debug, Clone, Deserialize)]
pub struct TensorInfo {
    pub dtype: String,
    pub shape: Vec<usize>,
    pub data_offsets: [usize; 2],
}

実際にQwen3-1.7BのモデルファイルにこのCLIを通してみると、こんな出力が得られる。

Tensor                                Dtype           Shape               Offsets
---------------------------------------------------------------------------------
lm_head.weight                         BF16  [151936, 2048]          0..622329856
model.embed_tokens.weight              BF16  [151936, 2048]          0..622329856
model.layers.0.input_layernorm.weight  BF16          [2048]  622329856..622333952
...
Total: 311 tensors

テンソルの名前、dtype、shape、データのオフセット範囲が一覧で確認できる。311個のテンソルが含まれていることがわかる。