bandarra.me

Count tokens with the Gemma 2 Tokenizer in Rust

For those working with Large Language Models, counting the number of tokens in an input is a frequent task. As Gemini and Gemma share the same tokenizer (at least for now), it is quite useful to be able to count tokens locally, without making network calls to an endpoint, which can be much slower.

In Rust, this can be achieved with the tokenizers crate. The code below is a minimalistic version of this sample code: it removes the need for the candle-examples crate, but still uses the hf_hub crate to manage the model download (the files could also be downloaded manually).

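use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
use tokenizers::Tokenizer;

// The Gemma repositories on Hugging Face are gated, so an access token is required.
const HF_TOKEN: &str = "YOUR_TOKEN_HERE";
const MODEL_ID: &str = "google/gemma-2-2b";
const MODEL_REVISION: &str = "main";

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Build an authenticated Hugging Face Hub client.
    let api = ApiBuilder::new()
        .with_token(Some(HF_TOKEN.to_string()))
        .build()?;
    let repo = api.repo(Repo::with_revision(
        MODEL_ID.to_string(),
        RepoType::Model,
        MODEL_REVISION.to_string(),
    ));

    // Download tokenizer.json, or reuse the locally cached copy.
    let tokenizer_filename = repo.get("tokenizer.json")?;
    // tokenizers returns its own boxed error type; convert it to a String for `?`.
    let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(|e| e.to_string())?;

    let prompt = "Why is the sky blue?";
    // `true` adds the special tokens (such as <bos>), so they are part of the count.
    let encoding = tokenizer.encode(prompt, true).map_err(|e| e.to_string())?;
    let token_count = encoding.get_ids().len();

    println!("Number of tokens: {}", token_count);
    Ok(())
}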

The hf_hub crate caches the files once they are downloaded. Even so, loading the tokenizer from the cached file still takes about 600ms, so it should be done only once in the application. Counting tokens afterwards is quite fast, generally under 1ms.
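Since loading is the expensive step, one way to make sure it happens only once is to keep the tokenizer in a std::sync::OnceLock. The sketch below is self-contained, so it uses a hypothetical WhitespaceTokenizer stand-in that simply splits on whitespace and does not reproduce Gemma's tokenization; in the real program, the static would hold a tokenizers::Tokenizer loaded with Tokenizer::from_file. It also uses std::time::Instant to time the first and second calls, which is one way to check numbers like the ones above on your own machine.

```rust
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::OnceLock;
use std::time::Instant;

/// Counts how many times the (expensive) loader has run; for illustration only.
static LOADS: AtomicUsize = AtomicUsize::new(0);

/// Hypothetical stand-in for tokenizers::Tokenizer: it splits on whitespace,
/// which does NOT reproduce Gemma's tokenization.
struct WhitespaceTokenizer;

impl WhitespaceTokenizer {
    /// Stands in for the slow Tokenizer::from_file call.
    fn load() -> Self {
        LOADS.fetch_add(1, Ordering::SeqCst);
        WhitespaceTokenizer
    }

    fn count(&self, text: &str) -> usize {
        text.split_whitespace().count()
    }
}

/// Shared tokenizer, initialized lazily on first use.
static TOKENIZER: OnceLock<WhitespaceTokenizer> = OnceLock::new();

/// Counts tokens, loading the tokenizer only the first time it is called.
fn count_tokens(text: &str) -> usize {
    TOKENIZER.get_or_init(WhitespaceTokenizer::load).count(text)
}

fn main() {
    // The first call pays the one-time loading cost.
    let start = Instant::now();
    let first = count_tokens("Why is the sky blue?");
    println!("first call: {first} tokens in {:?}", start.elapsed());

    // Later calls reuse the already-loaded tokenizer.
    let start = Instant::now();
    let second = count_tokens("Hello world");
    println!("second call: {second} tokens in {:?}", start.elapsed());

    println!("loader ran {} time(s)", LOADS.load(Ordering::SeqCst));
}
```

With the real crate, load() would wrap Tokenizer::from_file and count() would call tokenizer.encode(text, true) followed by get_ids().len(). OnceLock requires the stored type to be Send and Sync; if that is a concern, creating the tokenizer once at startup and passing a reference around works just as well.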

aarch64-pc-windows-msvc issues with candle-examples

The original example this code is based on depends on the candle-examples crate, which fails to build on aarch64 architectures. The issue is caused by one of its dependencies, the gemm-f16 crate. There are workarounds described in this issue.

For aarch64-pc-windows-msvc, adding the configuration below to the .cargo/config.toml file should do the trick:

[build]
rustflags = [
    "-Ctarget-feature=+fp16,+fhm"
]
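If the same project is also built for other targets, a possible variation (an assumption, not verified against this particular issue) is to scope the flags to the Windows on ARM target using Cargo's per-target configuration table, so they are only passed when compiling for aarch64-pc-windows-msvc:

```toml
# Unlike build.rustflags, these flags are only used when compiling
# for this target triple.
[target.aarch64-pc-windows-msvc]
rustflags = ["-Ctarget-feature=+fp16,+fhm"]
```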