In the previous article, I wrote about using a feature extraction model to generate embeddings from text, training a custom classification model for sentiment analysis with those embeddings as input, and finally exporting the model to run in the browser with LiteRT and TensorFlow Lite.
While the classification model in that solution runs on the client side, it relies on Google AI's embedding API, a Cloud API, to generate the text embeddings, so the solution is not entirely client-side.
In this article, I'll explore a fully client-side solution for toxicity detection, using Kaggle's Toxic Comment Classification Challenge dataset and Transformers.js's feature extraction pipeline, and running the resulting model in the browser with the ONNX web runtime.
Choosing the tools and libraries
PyTorch was the ML framework used in the previous article, and there's no reason to choose a different approach.
Since the goal is a fully client-side solution, the embedding model needs to run both in the training pipeline and in the browser, for inference. The all-MiniLM-L6-v2 model is supported in Python with Sentence Transformers and in the browser with Transformers.js, making it a great choice.
Transformers.js is a great library for running off-the-shelf AI models in the browser. Because Transformers.js uses ONNX Runtime under the hood to run models, ONNX Runtime Web is a natural choice for running the custom model in the browser: it creates synergy between Transformers.js and the custom model and avoids adding extra dependencies to the web application.
Data preprocessing
The data preprocessing step transforms the original dataset of text comments and labels into a new dataset containing the labels and the embeddings generated by the feature extraction model:
import csv
import json

from sentence_transformers import SentenceTransformer

model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')

# output_file holds the path of the JSONL file the embeddings are written to.
with open("data/train.csv", "r", encoding="utf-8") as dataset_file:
    dataset_csv = csv.DictReader(dataset_file)
    with open(output_file, "a") as output:
        for entry in dataset_csv:
            embeddings = model.encode([entry['comment_text']])
            result = {
                'id': entry['id'],
                'embeddings': embeddings[0].tolist(),
                'toxic': entry['toxic'],
                'severe_toxic': entry['severe_toxic'],
                'obscene': entry['obscene'],
                'threat': entry['threat'],
                'insult': entry['insult'],
                'identity_hate': entry['identity_hate'],
            }
            output.write(json.dumps(result) + "\n")
            output.flush()
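For reference, here is a minimal sketch of how the generated JSONL file could be loaded back for training. It isn't part of the original pipeline, and the class name and label order are assumptions:

import json

import torch
from torch.utils.data import Dataset

LABELS = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']

class EmbeddingsDataset(Dataset):
    """Reads the precomputed embeddings and labels from the JSONL file."""
    def __init__(self, path):
        with open(path, "r", encoding="utf-8") as f:
            self.entries = [json.loads(line) for line in f]

    def __len__(self):
        return len(self.entries)

    def __getitem__(self, idx):
        entry = self.entries[idx]
        embeddings = torch.tensor(entry['embeddings'], dtype=torch.float32)
        labels = torch.tensor([float(entry[label]) for label in LABELS])
        return embeddings, labels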
Model architecture and training
The model architecture is similar to the one used in the previous article. In this case, the all-MiniLM-L6-v2 feature extraction model generates embeddings as an array of 384 float values, so the input layer needs to change to reflect that.
A normalization layer was also added after each hidden layer of the model, as that slightly improved performance on the validation set and reduced the number of epochs required for the model to converge.
import torch.nn as nn
import torch.nn.functional as F

class ToxicityModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear0 = nn.Linear(384, 128)
        self.norm0 = nn.BatchNorm1d(128)
        self.linear1 = nn.Linear(128, 32)
        self.norm1 = nn.BatchNorm1d(32)
        self.linear_out = nn.Linear(32, 6)

    def forward(self, x):
        x = self.linear0(x)
        x = self.norm0(x)
        x = F.relu(x)
        x = self.linear1(x)
        x = self.norm1(x)
        x = F.relu(x)
        x = self.linear_out(x)
        return x
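As a quick sanity check (not part of the training code), the model can be instantiated and called with a random batch to confirm that it maps 384-dimensional embeddings to 6 logits:

import torch

model = ToxicityModel()
model.eval()  # eval mode keeps BatchNorm well-behaved outside of training
logits = model(torch.randn(2, 384))
print(logits.shape)  # torch.Size([2, 6]) - one logit per toxicity label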
Binary Cross Entropy (BCE) is used as the loss function, through BCEWithLogitsLoss, which applies a sigmoid to the model's output logits before computing the loss, allowing the results to be compared with the labels from the training set.
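As an illustration, a single training step with BCEWithLogitsLoss could look like the sketch below; the optimizer choice and the batch variables are assumptions, not taken from the original training loop:

import torch
import torch.nn as nn

model = ToxicityModel()
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# embeddings: (batch_size, 384), labels: (batch_size, 6) with values 0.0 or 1.0
def train_step(embeddings, labels):
    optimizer.zero_grad()
    logits = model(embeddings)        # raw logits, no sigmoid applied here
    loss = criterion(logits, labels)  # sigmoid is applied inside BCEWithLogitsLoss
    loss.backward()
    optimizer.step()
    return loss.item()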
An accuracy of 98% is achieved with this model on the validation set.
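One way to compute that accuracy for a multi-label model, assuming predictions are thresholded at 0.5, is sketched below (this is an illustration, not the original evaluation code):

import torch

@torch.no_grad()
def accuracy(model, embeddings, labels):
    model.eval()
    probabilities = torch.sigmoid(model(embeddings))
    predictions = (probabilities > 0.5).float()
    # Fraction of individual label predictions that match the ground truth.
    return (predictions == labels).float().mean().item()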
Model conversion from PyTorch to ONNX
Converting from the PyTorch format to ONNX requires installing the onnx and onnxscript dependencies, and the implementation itself is straightforward.
import torch

from all_minilm_l6.toxicity_model import ToxicityModel

torch_model = ToxicityModel()
torch_model.load_state_dict(torch.load(
    "SP-all-MiniLM-L6-v2.safetensors",
    weights_only=True,
    map_location=torch.device('cpu')
))
torch_model.eval()

example_inputs = (torch.randn(1, 384),)
onnx_program = torch.onnx.export(torch_model, example_inputs, dynamo=True)
onnx_program.optimize()
onnx_program.save("SP-all-MiniLM-L6-v2.onnx")
Similar to the LiteRT conversion, exporting the model requires passing an example input, which can be randomly generated.
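Before moving to the browser, the exported file can be sanity-checked with the onnxruntime Python package. This is an optional verification step, not part of the original conversion script; the input name "x" matches the one used in the browser code later:

import numpy as np
import onnxruntime as ort

session = ort.InferenceSession("SP-all-MiniLM-L6-v2.onnx")
example = np.random.randn(1, 384).astype(np.float32)
# "x" is the same input name passed to model.run() in the browser.
outputs = session.run(None, {"x": example})
print(outputs[0].shape)  # (1, 6)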
Running the ONNX model in the browser
Running ONNX models in the browser is achieved with the onnxruntime-web library.
Because the model takes as input embeddings generated with all-MiniLM-L6-v2, Transformers.js is required for the preprocessing step, transforming the user input into embeddings.
The model outputs logits, which can be transformed into probabilities with a sigmoid function. ONNX Runtime doesn't provide one out of the box, but it is a one-line implementation:
function sigmoid(xs) {
  return xs.map(x => 1 / (1 + Math.exp(-x)));
}
Putting everything together then becomes a matter of importing the required libraries, transforming the user's input into embeddings, calling the custom model with those embeddings, and applying the sigmoid function to the model's results to generate a probability for each toxicity type:
import { pipeline } from '@huggingface/transformers';
import * as ort from 'onnxruntime-web';

// Create the Transformers.js pipeline using the all-MiniLM-L6-v2 feature
// extraction model.
const extractor = await pipeline('feature-extraction', 'Xenova/all-MiniLM-L6-v2');

// Instantiate the custom ONNX Runtime model.
const model = await ort.InferenceSession.create('SP-all-MiniLM-L6-v2.onnx');

const sentences = ["This is an example input"];

// Generate embeddings from the user input.
const output = await extractor(sentences, { pooling: 'mean', normalize: true });

// Classify the embeddings.
const outputTensor = await model.run({ x: output });

// Transform the output logits into probabilities using the sigmoid function.
const probabilities = sigmoid(outputTensor.linear_2.data);
Conclusion
The combination of a feature extraction model with a custom classification model looks like a promising pattern. With the classifier hitting 98% accuracy, both models together weighing under 5 MB, and inference that is almost instant, this looks like a good use case to run on the client side.
As a next step, make sure to check out the model in action or take a look at the source code for the model and web application.