7. TI-NPU QAT in PyTorch
This section explains how to perform Quantization-Aware Training (QAT) on a PyTorch model for TI-NPU.
For supported layers to run on the NPU, the model must be trained with a specific QAT scheme.
This section is intended for users who are familiar with PyTorch and would like to integrate an existing model with TI’s QAT module for TI-NPU. If you do not have a model for your application, we recommend using Model Composer, where you can select a model from TI’s Model Zoo.
7.1. Environment Setup
Follow these steps to set up your environment for TI-NPU QAT and compilation.
Linux
1. Create a Python virtual environment.
    python3 -m venv .venv
    source ./.venv/bin/activate
2. Clone the tinyml-tensorlab repository.
    git clone https://github.com/TexasInstruments/tinyml-tensorlab.git
3. Set up the Python module for TI-NPU QAT.
    cd tinyml-tensorlab/tinyml-modeloptimization/torchmodelopt
    ./setup.sh
4. Install compiler dependencies. Please follow the compilation Environment Setup.
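Once the steps above are complete, a quick import check can confirm that the virtual environment sees the packages used later in this section. The following is a minimal sketch using only the Python standard library. The helper name `check_modules` is hypothetical, and the module names `torch` and `tinyml_torchmodelopt` are taken from the imports in the usage example in section 7.2.2.

```python
import importlib.util


def check_modules(names):
    """Return a dict mapping each top-level module name to True if it can be found."""
    return {name: importlib.util.find_spec(name) is not None for name in names}


if __name__ == "__main__":
    # Module names assumed from the imports used later in this section.
    for name, found in check_modules(["torch", "tinyml_torchmodelopt"]).items():
        print(f"{name}: {'found' if found else 'MISSING'}")
```

Run it with the virtual environment activated. A module reported as missing usually means the corresponding setup step did not run inside that environment.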
7.2. TI-NPU QAT
7.2.1. Adding TI-NPU QAT to PyTorch
QAT for TI-NPU is easy to incorporate into existing PyTorch training code. A wrapper module called TINPUTinyMLQATFxModule automates the tasks required for QAT. The user must wrap their model in TINPUTinyMLQATFxModule and perform further training.
TINPUTinyMLQATFxModule performs the following operations on the model:
Replaces layers in the model with their “Fake Quantized” versions.
Applies other modifications to help the learning process.
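To make the term “Fake Quantized” concrete, the sketch below shows the general idea in plain Python: a value is rounded onto an integer grid, clamped to the integer range, and immediately mapped back to floating point, so the rest of the network sees the quantization error during training. This is an illustration of the common technique only. The function name, the scale of 0.25, and the signed 8-bit range are assumptions, and it is not a description of the scheme TINPUTinyMLQATFxModule actually implements.

```python
def fake_quantize(x, scale, zero_point=0, qmin=-128, qmax=127):
    """Quantize x to an integer grid, then map it straight back to float."""
    q = round(x / scale) + zero_point      # snap to the nearest integer level
    q = max(qmin, min(qmax, q))            # clamp to the representable range
    return (q - zero_point) * scale        # dequantize: float value on the grid


if __name__ == "__main__":
    # Hypothetical inputs: one in range, one near zero, one that saturates.
    for value in (0.30, -0.10, 100.0):
        print(value, "->", fake_quantize(value, scale=0.25))
```

In a training framework the rounding step has no useful gradient, so QAT implementations typically pass gradients through it unchanged (a straight-through estimator). That detail is omitted here.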
The training flow has two parts:
1. Train the existing model in floating point as usual.
2. Wrap the pre-trained floating-point model with TINPUTinyMLQATFxModule and perform QAT.
7.2.2. How to use TINPUTinyMLQATFxModule
The following describes the changes an existing PyTorch training script needs in order to incorporate TINPUTinyMLQATFxModule.
    import os

    import torch
    import tinyml_torchmodelopt.quantization as tinpu_quantization

    # pretrained_path, epochs, my_dataset_train and save_path are placeholders
    # that your own training script is expected to define.

    # create your model here:
    model = ...

    # load your pretrained checkpoint/weights here or run your usual floating-point training
    pretrained_data = torch.load(pretrained_path)
    model.load_state_dict(pretrained_data)

    # wrap your model in TINPUTinyMLQATFxModule
    model = tinpu_quantization.TINPUTinyMLQATFxModule(model, total_epochs=epochs)

    # train the wrapped model in your training loop here with loss, backward, optimizer, etc.
    # your usual training loop
    model.train()
    for e in range(epochs):
        for images, target in my_dataset_train:
            output = model(images)
            # loss, backward(), optimizer step, etc. come here as usual in training

    model.eval()

    # convert the model to operate with integer operations (instead of QDQ FakeQuantize operations)
    model = model.convert()

    # create a dummy input - this is required for onnx export.
    dummy_input = torch.rand((1, 1, 256, 1))

    # export the quantized model to onnx format
    torch.onnx.export(model.module, dummy_input, os.path.join(save_path, 'model_int8.onnx'), input_names=['input'])
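The conversion step in the script above replaces quantize-dequantize (QDQ) FakeQuantize operations with integer operations. The sketch below illustrates, in plain Python, why the two forms can compute the same result for a dot product: multiplying dequantized operands in floating point is equivalent to accumulating in integers and applying the combined scale once. It assumes symmetric quantization (zero point of 0) with one scale per tensor. The function names and values are hypothetical, and this is not a description of what convert() does internally.

```python
def qdq_dot(xq, wq, x_scale, w_scale):
    """Fake-quantized path: dequantize each operand, then multiply in float."""
    return sum((a * x_scale) * (b * w_scale) for a, b in zip(xq, wq))


def integer_dot(xq, wq, x_scale, w_scale):
    """Integer path: accumulate in integers, then apply the combined scale once."""
    acc = sum(a * b for a, b in zip(xq, wq))   # pure integer arithmetic
    return acc * (x_scale * w_scale)


if __name__ == "__main__":
    # Hypothetical int8 activations and weights with power-of-two scales.
    xq, wq = [12, -7, 3], [5, 4, -128]
    print(qdq_dot(xq, wq, 0.25, 0.5), integer_dot(xq, wq, 0.25, 0.5))
```

With power-of-two scales the two paths agree exactly. With arbitrary scales they agree up to floating-point rounding.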
7.2.3. Supported Layer Configurations
For optimal inference latency, consider changing your existing model architecture to use Layer Configurations Supported on the NPU.
7.2.4. Example Model
An example PyTorch model, with QAT for TI-NPU and layer configurations supported by the NPU, is included in the compiler package. Please complete Environment Setup before running the example.
    ls `python3 -c "import tvm; print(tvm.__path__[0])"`/ti_docs/examples/arc_fault_ti_qat.py
Install additional modules for running the example.
    pip3 install -r `python3 -c "import tvm; print(tvm.__path__[0])"`/ti_docs/examples/requirements.txt
    pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
To run the example:
    cd `python3 -c "import tvm; print(tvm.__path__[0])"`/ti_docs/examples/
    python3 ./arc_fault_ti_qat.py
The resulting arc_fault_int8.onnx contains the TI-NPU quantized model.
Note
The arc_fault_ti_qat.py example script is included only to illustrate TI-NPU QAT. The model architecture, dataset, and training hyperparameters are not intended to be adopted by external users.
7.3. Compile Model with TI MCU NNC
Please follow the Compilation Command to compile the quantized ONNX model.