8.9. On-Device Training — Advanced API & Config
Deep dive into ODT library API and trainable model configuration.
8.9.1. Library Architecture
The ODT system has three layers:
┌─────────────────────────────┐
│ Application Code            │ User's training loop
│ (application_main.c)        │
└─────────────────────────────┘
              ↓
┌─────────────────────────────┐
│ Task-Specific Library       │ Anomaly detection, classification
│ (e.g., anomaly_odt.c)       │ Orchestrates training workflow
└─────────────────────────────┘
              ↓
┌─────────────────────────────┐
│ Core Training Library       │ Forward pass, backward pass, SGD
│ (ondevice_training_lib.c)   │ Task-agnostic
└─────────────────────────────┘
              ↓
┌─────────────────────────────┐
│ Model Configuration         │ Auto-generated
│ (trainable_model_config)    │ Weights, buffers, layer defs
└─────────────────────────────┘
Core Library Files:
- ondevice_training_lib.h: API declarations, structures
- ondevice_training_lib.c: forward/backward pass, SGD, weight updates
Location:
{device-sdk}/ai/common/ondevice_training/
Key Principles:
- Task-agnostic: the library knows nothing about classification vs. anomaly detection
- Zero dynamic allocation: all memory is pre-allocated statically
- Flat array architecture: weights, gradients, and buffers live in contiguous arrays
- Compile-time optimization: a batch size of 1 eliminates the gradient-accumulation code
8.9.2. ModelContext_t — Central Training State
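Every ODT function that operates on model state (init, forward, backward, update, checkpointing) takes a ModelContext_t* as its first argument; the loss functions take plain arrays instead.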
Structure fields:
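| Field | Type | Description |
|---|---|---|
| layers | LayerParams_t[] | Layer descriptors (copied from LAYER_PARAMS_INIT during init) |
| num_layers | | Number of trainable layers |
| intermediate_buffers | float*[] | Forward pass activation storage. Index 0 = frozen model output; index i = output of layer i-1 |
| gradient_buffers | float*[] | Backward pass gradients. Index i+1 = gradient into layer i; index i = gradient out of layer i |
| current_weights | float* | Points to all trainable weights (being updated) |
| best_weights | float* | Checkpoint storage for best-performing weights |
| weight_gradients | float* | Accumulated gradients (only if USE_GRADIENT_ACCUMULATION = 1) |
| | | Samples processed in current batch |
| is_training_mode | | Current phase; ODT_Init sets it to PHASE_INFERENCE |
| learning_rate | float | SGD learning rate (e.g., 0.001) |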
8.9.3. Memory Layout
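Weights, activations, and gradients each live in one flat array, and each layer's slice is located by an offset. The activation and gradient buffers share the same offsets: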
Example: Linear(4→3) → ReLU(3) → Linear(3→2)
ALL_WEIGHTS[23] (total params = 4×3 + 3 + 3×2 + 2 = 23)
┌──────────────────┬─────────┬──────────────────┬─────────┐
│ Layer 0 weights  │ biases  │ Layer 2 weights  │ biases  │
│ W[3×4] = 12      │ b = 3   │ W[2×3] = 6       │ b = 2   │
└──────────────────┴─────────┴──────────────────┴─────────┘
offset=0           offset=12 offset=15          offset=21
INTERMEDIATE_BUFFERS[12] (4 + 3 + 3 + 2)
┌───────┬──────────┬──────────┬──────────┐
│ in    │ L0_out   │ L1_out   │ L2_out   │
│ (4)   │ (3)      │ (3)      │ (2)      │
└───────┴──────────┴──────────┴──────────┘
GRADIENT_BUFFERS[12] (same layout as intermediate)
┌───────┬──────────┬──────────┬──────────┐
│ grad  │ grad     │ grad     │ loss     │
│ input │ into L0  │ into L1  │ gradient │
└───────┴──────────┴──────────┴──────────┘
Forward/Backward Flow:
Forward: frozen_output → buffer[0] → Layer 0 → buffer[1] → Layer 1 → buffer[2] → Layer 2 → buffer[3]
Backward: loss_grad → grad_buffer[3] → Layer 2 → grad_buffer[2] → Layer 1 → grad_buffer[1] → Layer 0 → grad_buffer[0]
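The offsets in these diagrams follow from the layer sizes alone. The sketch below recomputes them for the Linear(4→3) → ReLU(3) → Linear(3→2) example. SketchLayer and compute_layout() are hypothetical names, not part of the ODT library, and the packing rule (row-major W[out×in] followed by the biases) is taken from the ALL_WEIGHTS diagram.

```c
#include <assert.h>
#include <stdint.h>

/* Hypothetical layer description, used only for this sketch. */
typedef enum { SKETCH_LINEAR, SKETCH_RELU } SketchKind;

typedef struct {
    SketchKind kind;
    uint16_t in_size;
    uint16_t out_size;
} SketchLayer;

/* Fills weight_offsets[num_layers] (-1 for parameter-free layers) and
 * buffer_offsets[num_layers + 1]. Returns the total parameter count and
 * stores the activation-buffer length (in floats) in *buffer_total. */
static int compute_layout(const SketchLayer* layers, int num_layers,
                          int* weight_offsets, int* buffer_offsets,
                          int* buffer_total)
{
    assert(num_layers > 0);
    int param_total = 0;
    int buf = layers[0].in_size;      /* buffer[0] holds the frozen model output */
    buffer_offsets[0] = 0;
    for (int i = 0; i < num_layers; i++) {
        if (layers[i].kind == SKETCH_LINEAR) {
            weight_offsets[i] = param_total;
            /* W[out x in] first, then out biases */
            param_total += layers[i].out_size * layers[i].in_size + layers[i].out_size;
        } else {
            weight_offsets[i] = -1;
        }
        buffer_offsets[i + 1] = buf;  /* buffer[i+1] holds the output of layer i */
        buf += layers[i].out_size;
    }
    *buffer_total = buf;
    return param_total;
}
```

For this network it returns 23 parameters, weight offsets {0, -1, 15}, buffer offsets {0, 4, 7, 10} and a 12-float activation buffer; the gradient buffer reuses the same offsets.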
8.9.4. Core API
8.9.4.1. Initialization
int ODT_Init(ModelContext_t* ctx, float learning_rate);
Parameters:
- ctx: model context to initialize
- learning_rate: SGD learning rate (e.g., 0.0001); must be > 0
Returns: 0 on success, non-zero on error
Actions:
- Copies layer parameters from LAYER_PARAMS_INIT to ctx->layers[]
- Initializes all buffer pointers from offsets table
- Sets is_training_mode = PHASE_INFERENCE
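The pointer setup done by ODT_Init amounts to pointing each entry of the context's pointer arrays into the flat storage through the offset table. This is a sketch under assumptions: SketchContext is a hypothetical stand-in for ModelContext_t with a subset of its fields, the field types are guesses, and the offset table is the one for the 4→3→2 example.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>

#define SK_NUM_LAYERS 3
#define SK_BUFFER_SIZE 12
#define SK_TOTAL_PARAMS 23

/* Offset table for Linear(4->3) -> ReLU(3) -> Linear(3->2). */
static const uint16_t SK_BUFFER_OFFSETS[SK_NUM_LAYERS + 1] = {0, 4, 7, 10};

/* Statically allocated flat storage (no malloc). */
static float SK_WEIGHTS[SK_TOTAL_PARAMS];
static float SK_INTERMEDIATE[SK_BUFFER_SIZE];
static float SK_GRADIENTS[SK_BUFFER_SIZE];

/* Hypothetical stand-in for ModelContext_t (subset of its fields). */
typedef struct {
    uint16_t num_layers;
    float* intermediate_buffers[SK_NUM_LAYERS + 1];
    float* gradient_buffers[SK_NUM_LAYERS + 1];
    float* current_weights;
    float learning_rate;
} SketchContext;

/* Returns 0 on success, non-zero on invalid arguments. */
static int sketch_init(SketchContext* ctx, float learning_rate)
{
    if (ctx == NULL || !(learning_rate > 0.0f)) {
        return -1;
    }
    ctx->num_layers = SK_NUM_LAYERS;
    ctx->current_weights = SK_WEIGHTS;
    ctx->learning_rate = learning_rate;
    for (int i = 0; i <= SK_NUM_LAYERS; i++) {
        assert(SK_BUFFER_OFFSETS[i] < SK_BUFFER_SIZE);
        /* Activations and gradients share one offset table. */
        ctx->intermediate_buffers[i] = &SK_INTERMEDIATE[SK_BUFFER_OFFSETS[i]];
        ctx->gradient_buffers[i] = &SK_GRADIENTS[SK_BUFFER_OFFSETS[i]];
    }
    return 0;
}
```

Because every pointer refers into a static array, initialization performs no allocation and cannot run out of memory at runtime.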
8.9.4.2. Forward Pass
int ODT_Forward(ModelContext_t* ctx);
Computes activations for all layers sequentially:
- Read the input from intermediate_buffers[0] (frozen model output)
- For each layer i:
  - Read input from intermediate_buffers[i]
  - Compute the layer operation (linear, ReLU, etc.)
  - Write output to intermediate_buffers[i+1]
- The final output is in intermediate_buffers[num_layers]
Usage:
ODT_Forward(&ctx);
float* output = ctx.intermediate_buffers[ctx.num_layers];
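A minimal sketch of what the per-layer forward operations could look like on the flat layout, assuming Linear weights are stored row-major as W[out×in] followed by the biases (as in the memory-layout diagram). FwLayer, linear_forward(), relu_forward() and forward() are hypothetical; the library's internal function names are not documented here.

```c
#include <assert.h>
#include <stddef.h>

typedef enum { FW_LINEAR, FW_RELU } FwKind;

/* Hypothetical layer descriptor for this sketch. */
typedef struct {
    FwKind kind;
    int in_size;
    int out_size;
    int weights_offset;   /* -1 for layers without parameters */
} FwLayer;

/* y = W x + b; W is row-major [out][in], followed by b[out]. */
static void linear_forward(const float* x, float* y, const float* params, int in, int out)
{
    const float* W = params;
    const float* b = params + (size_t)in * out;
    for (int o = 0; o < out; o++) {
        float acc = b[o];
        for (int i = 0; i < in; i++) {
            acc += W[o * in + i] * x[i];
        }
        y[o] = acc;
    }
}

static void relu_forward(const float* x, float* y, int n)
{
    for (int i = 0; i < n; i++) {
        y[i] = x[i] > 0.0f ? x[i] : 0.0f;
    }
}

/* buffers[l] is the input of layer l; buffers[l + 1] receives its output. */
static void forward(const FwLayer* layers, int num_layers, float** buffers, const float* all_weights)
{
    for (int l = 0; l < num_layers; l++) {
        const FwLayer* layer = &layers[l];
        if (layer->kind == FW_LINEAR) {
            assert(layer->weights_offset >= 0);
            linear_forward(buffers[l], buffers[l + 1], all_weights + layer->weights_offset,
                           layer->in_size, layer->out_size);
        } else {
            relu_forward(buffers[l], buffers[l + 1], layer->in_size);
        }
    }
}
```

For Linear(2→2) with W = [[1, -1], [2, 0.5]], b = [0.5, -6] and input [2, 1], the linear output is [1.5, -1.5] and the following ReLU produces [1.5, 0].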
8.9.4.3. Backward Pass
int ODT_Backward(ModelContext_t* ctx);
Computes weight gradients for all layers:
- The loss gradient is placed in gradient_buffers[num_layers]
- For each layer i (reverse order):
  - Read the incoming gradient from gradient_buffers[i+1]
  - Compute gradients w.r.t. the weights using the input from intermediate_buffers[i]
  - Accumulate into weight_gradients (if enabled)
  - Compute the outgoing gradient (for the layer below)
  - Write the outgoing gradient to gradient_buffers[i]
Requires: Forward pass already completed, gradient_buffers[num_layers] populated with loss gradient
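For the two supported layer types the per-layer steps reduce to two small kernels. A sketch, assuming the same row-major W[out×in]-then-biases layout and that weight gradients are accumulated with += (so the accumulator must be cleared after each update, which ODT_SGDUpdate is documented to do); linear_backward() and relu_backward() are hypothetical names.

```c
#include <assert.h>
#include <stddef.h>

/* Backward pass of y = W x + b (W row-major [out][in], then b[out]).
 *   dL/dW[o][i] += g[o] * x[i]
 *   dL/db[o]    += g[o]
 *   dL/dx[i]     = sum_o W[o][i] * g[o]
 * x is the layer input saved by the forward pass, grad_y is dL/dy. */
static void linear_backward(const float* x, const float* grad_y, float* grad_x,
                            const float* params, float* grad_params, int in, int out)
{
    assert(in > 0 && out > 0);
    const float* W = params;
    float* gW = grad_params;
    float* gb = grad_params + (size_t)in * out;
    for (int i = 0; i < in; i++) {
        grad_x[i] = 0.0f;
    }
    for (int o = 0; o < out; o++) {
        gb[o] += grad_y[o];
        for (int i = 0; i < in; i++) {
            gW[o * in + i] += grad_y[o] * x[i];
            grad_x[i] += W[o * in + i] * grad_y[o];
        }
    }
}

/* ReLU passes the gradient only where its forward input was positive. */
static void relu_backward(const float* x, const float* grad_y, float* grad_x, int n)
{
    for (int i = 0; i < n; i++) {
        grad_x[i] = x[i] > 0.0f ? grad_y[i] : 0.0f;
    }
}
```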
8.9.4.4. Weight Update (SGD)
int ODT_SGDUpdate(ModelContext_t* ctx);
Updates all weights using accumulated gradients:
current_weights[i] -= learning_rate * weight_gradients[i]
Then resets gradient accumulators to zero.
When to call: After ODT_Backward() completes
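The update rule above, written out as a loop over the flat parameter array. sgd_update() is a hypothetical stand-in for ODT_SGDUpdate, not its source: it applies the documented rule and then zeroes the gradient accumulator.

```c
#include <assert.h>
#include <stddef.h>

/* Plain SGD step, then clear the accumulator so the next
 * backward pass starts from zero. */
static void sgd_update(float* weights, float* grads, size_t n, float learning_rate)
{
    assert(weights != NULL && grads != NULL && learning_rate > 0.0f);
    for (size_t i = 0; i < n; i++) {
        weights[i] -= learning_rate * grads[i];
        grads[i] = 0.0f;
    }
}
```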
8.9.4.5. Loss Functions
Classification (CrossEntropy):
float ODT_LossCrossEntropy(const float* output, const uint16_t* targets,
uint16_t num_classes, uint16_t batch_size);
Anomaly Detection (MSE):
float ODT_LossMSE(const float* output, const float* target,
uint16_t size, uint16_t batch_size);
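ODT_Backward needs dL/dOutput in gradient_buffers[num_layers], so each loss needs a matching gradient. The sketch below pairs an MSE loss with its gradient, assuming the loss is the mean over all size × batch_size elements; whether ODT_LossMSE normalizes this way is not stated in this section, and mse_loss()/mse_grad() are hypothetical helpers.

```c
#include <assert.h>
#include <stdint.h>

/* Mean squared error over n = size * batch_size elements. */
static float mse_loss(const float* output, const float* target, uint16_t size, uint16_t batch_size)
{
    const uint32_t n = (uint32_t)size * batch_size;
    assert(n > 0);
    float sum = 0.0f;
    for (uint32_t i = 0; i < n; i++) {
        const float d = output[i] - target[i];
        sum += d * d;
    }
    return sum / (float)n;
}

/* Gradient of mse_loss w.r.t. each output element: 2 (y - t) / n. */
static void mse_grad(const float* output, const float* target, float* grad,
                     uint16_t size, uint16_t batch_size)
{
    const uint32_t n = (uint32_t)size * batch_size;
    assert(n > 0);
    for (uint32_t i = 0; i < n; i++) {
        grad[i] = 2.0f * (output[i] - target[i]) / (float)n;
    }
}
```

For output [1, 3] and target [0, 1] the loss is 2.5 and the gradient is [1, 2].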
8.9.4.6. Weight Checkpointing
Save/restore best weights seen during training:
void ODT_SaveBestWeights(ModelContext_t* ctx);
void ODT_RestoreBestWeights(ModelContext_t* ctx);
Usage pattern:
float best_loss = FLT_MAX;  // FLT_MAX from <float.h>
for (int epoch = 0; epoch < max_epochs; epoch++) {
    ODT_Forward(&ctx);
    float loss = ODT_LossMSE(...);
    if (loss < best_loss) {
        best_loss = loss;
        ODT_SaveBestWeights(&ctx);
    }
    // Populate ctx.gradient_buffers[ctx.num_layers] with dL/dOutput first
    ODT_Backward(&ctx);
    ODT_SGDUpdate(&ctx);
}
ODT_RestoreBestWeights(&ctx); // Use best weights for inference
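Saving and restoring a checkpoint is a copy between two flat arrays of equal size. The sketch below shows that pattern with hypothetical names (CK_WEIGHTS, CK_BEST_WEIGHTS, checkpoint_if_better()) and a 4-parameter model; it illustrates the idea and is not the body of ODT_SaveBestWeights / ODT_RestoreBestWeights.

```c
#include <assert.h>
#include <string.h>

#define CK_TOTAL_PARAMS 4

static float CK_WEIGHTS[CK_TOTAL_PARAMS];        /* live weights, updated by SGD */
static float CK_BEST_WEIGHTS[CK_TOTAL_PARAMS];   /* checkpoint copy */

static void save_best_weights(void)
{
    memcpy(CK_BEST_WEIGHTS, CK_WEIGHTS, sizeof CK_WEIGHTS);
}

static void restore_best_weights(void)
{
    memcpy(CK_WEIGHTS, CK_BEST_WEIGHTS, sizeof CK_WEIGHTS);
}

/* Checkpoints when the loss improves; returns the (possibly updated) best loss. */
static float checkpoint_if_better(float loss, float best_loss)
{
    assert(loss == loss);   /* reject NaN losses in debug builds */
    if (loss < best_loss) {
        save_best_weights();
        return loss;
    }
    return best_loss;
}
```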
8.9.5. Trainable Model Configuration
Auto-generated files:
- trainable_model_config.h: architecture defines, layer descriptors
- trainable_model_config.c: weight storage, buffers
Generated by: ModelZoo Python pipeline with ondevice_training: enabled
8.9.5.1. trainable_model_config.h
Architecture Defines:
#define NUM_TRAINABLE_LAYERS 11
#define FROZEN_OUTPUT_SIZE 48 // Input to first trainable layer
#define FINAL_OUTPUT_SIZE 48 // Output of last trainable layer
#define TOTAL_PARAMS 3150 // All weights + biases
#define TOTAL_INTERMEDIATE_BUFFER_SIZE 252
#define TOTAL_GRADIENT_BUFFER_SIZE 252
Enumerations:
typedef enum {
    LAYER_TYPE_LINEAR,
    LAYER_TYPE_RELU,
} LayerType_t;
typedef enum {
    TASK_TYPE_CLASSIFICATION,
    TASK_TYPE_ANOMALY_DETECTION,
} TaskType_t;
typedef enum {
    LOSS_FUNCTION_MSE,
    LOSS_FUNCTION_CROSSENTROPY,
} LossFunction_t;
Task Configuration:
#define TASK_TYPE TASK_TYPE_ANOMALY_DETECTION
#define LOSS_FUNCTION LOSS_FUNCTION_MSE
#define TRAIN_BATCH_SIZE 1
#define VAL_BATCH_SIZE 1
Layer Descriptors (auto-generated):
static const LayerParams_t LAYER_PARAMS_INIT[NUM_TRAINABLE_LAYERS] = {
    // Layer 0: Linear(48 -> 24)
    {
        .layer_type = LAYER_TYPE_LINEAR,
        .input_size = 48,
        .output_size = 24,
        .weights_offset = 0,
        .num_weights = 48 * 24 + 24 // 1176 (weights + biases)
    },
    // Layer 1: ReLU(24)
    {
        .layer_type = LAYER_TYPE_RELU,
        .input_size = 24,
        .output_size = 24,
        .weights_offset = -1, // No weights for ReLU
        .num_weights = 0
    },
    // ...more layers...
};
Buffer Offsets (auto-generated):
static const uint16_t BUFFER_OFFSETS[NUM_TRAINABLE_LAYERS + 1] = {
    0,  // buffer[0]: frozen output (48 floats)
    48, // buffer[1]: Layer 0 output (24 floats)
    72, // buffer[2]: Layer 1 output (24 floats)
    // ...more offsets...
};
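Because LAYER_PARAMS_INIT, BUFFER_OFFSETS and TOTAL_PARAMS must agree with each other, a startup check can catch a stale or hand-edited config. A sketch, assuming int-typed fields named as in LAYER_PARAMS_INIT (SketchLayerParams and validate_config() are hypothetical), exercised on the small 4→3→2 example rather than the 11-layer model whose table is elided above.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>

typedef enum { LT_LINEAR, LT_RELU } LtKind;

/* Mirrors the fields shown in LAYER_PARAMS_INIT; the field types are assumed. */
typedef struct {
    LtKind layer_type;
    int input_size;
    int output_size;
    int weights_offset;   /* -1 when the layer has no weights */
    int num_weights;
} SketchLayerParams;

/* Returns 0 if consistent, i + 1 for the first inconsistent layer i,
 * or -1 if the parameter count does not add up to total_params. */
static int validate_config(const SketchLayerParams* layers, int num_layers,
                           const uint16_t* buffer_offsets, int total_params)
{
    assert(layers != NULL && buffer_offsets != NULL);
    int next_weight = 0;
    for (int i = 0; i < num_layers; i++) {
        const SketchLayerParams* p = &layers[i];
        /* Adjacent layers must agree on the activation size. */
        if (i > 0 && p->input_size != layers[i - 1].output_size) return i + 1;
        /* buffer[i] holds input_size floats, so buffer[i + 1] starts right after it. */
        if (buffer_offsets[i + 1] != buffer_offsets[i] + p->input_size) return i + 1;
        if (p->layer_type == LT_LINEAR) {
            /* Weights are packed back to back: W[out x in] plus out biases. */
            if (p->weights_offset != next_weight) return i + 1;
            if (p->num_weights != p->input_size * p->output_size + p->output_size) return i + 1;
            next_weight += p->num_weights;
        } else if (p->weights_offset != -1 || p->num_weights != 0) {
            return i + 1;
        }
    }
    return next_weight == total_params ? 0 : -1;
}
```

Applied to the generated header, the same rule confirms the visible entries: Layer 0 needs 48 × 24 + 24 = 1176 weights, and BUFFER_OFFSETS advances by 48 and then by 24.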
8.9.5.2. trainable_model_config.c
Weight Storage:
// All trained weights live here (NOT reinitialized)
float ALL_WEIGHTS[TOTAL_PARAMS] = { ... };
// Backup for checkpointing
float ALL_BEST_WEIGHTS[TOTAL_PARAMS];
// Gradient accumulation (only if USE_GRADIENT_ACCUMULATION=1)
#if USE_GRADIENT_ACCUMULATION
float ALL_WEIGHT_GRADS[TOTAL_PARAMS];
#endif
Buffer Storage:
// Forward pass activations
float INTERMEDIATE_BUFFERS[TOTAL_INTERMEDIATE_BUFFER_SIZE];
// Backward pass gradients
float GRADIENT_BUFFERS[TOTAL_GRADIENT_BUFFER_SIZE];
Memory Sections:
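The arrays can be placed in specific MCU memory regions. The lines below use the TI compiler's C syntax, #pragma DATA_SECTION(symbol, "section"); the named sections are then mapped to physical memory in the linker command file: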
#pragma DATA_SECTION(ALL_WEIGHTS, ".trainable_weights")
#pragma DATA_SECTION(INTERMEDIATE_BUFFERS, ".fast_ram")
#pragma DATA_SECTION(GRADIENT_BUFFERS, ".fast_ram")
8.9.6. Custom Training Loop Example
Complete epoch:
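#include <stdio.h>
#include <string.h>
#include "ondevice_training_lib.h"
#include "trainable_model_config.h"

// Application-provided data sources (hypothetical names)
float* get_frozen_model_output(void);
float* get_target(void);

int train_one_epoch(int num_batches)
{
    // Initialize (weights in ALL_WEIGHTS are not reinitialized)
    ModelContext_t ctx;
    if (ODT_Init(&ctx, 0.001f) != 0) {
        return -1;
    }
    // Training epoch
    for (int batch = 0; batch < num_batches; batch++) {
        // Load sample: the frozen model output is the input of the trainable layers
        float* frozen_output = get_frozen_model_output();
        memcpy(ctx.intermediate_buffers[0], frozen_output, FROZEN_OUTPUT_SIZE * sizeof(float));
        // Forward pass
        ODT_Forward(&ctx);
        // Compute loss
        float* output = ctx.intermediate_buffers[ctx.num_layers];
        float* target = get_target();
        float loss = ODT_LossMSE(output, target, FINAL_OUTPUT_SIZE, TRAIN_BATCH_SIZE);
        printf("Batch %d: loss = %.4f\n", batch, loss);
        // Backward pass: seed dL/dOutput for every output element, assuming
        // ODT_LossMSE is the mean over FINAL_OUTPUT_SIZE elements
        float* loss_grad = ctx.gradient_buffers[ctx.num_layers];
        for (int j = 0; j < FINAL_OUTPUT_SIZE; j++) {
            loss_grad[j] = 2.0f * (output[j] - target[j]) / FINAL_OUTPUT_SIZE;
        }
        ODT_Backward(&ctx);
        // Weight update
        ODT_SGDUpdate(&ctx);
    }
    return 0;
}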
8.9.7. Adding a New Layer Type
To add support for a new layer (e.g., Conv2D):
1. Update enumerations in ``trainable_model_config.h``:
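typedef enum {
    LAYER_TYPE_LINEAR,
    LAYER_TYPE_RELU,
    LAYER_TYPE_CONV2D, // NEW
} LayerType_t;
2. Add a parameter struct in ``ondevice_training_lib.h`` and make it reachable from the layer descriptor (the dispatcher in step 4 reads it as layer->params.conv2d):
typedef struct {
    uint16_t kernel_height;
    uint16_t kernel_width;
    uint16_t in_channels;
    uint16_t out_channels;
    uint16_t stride;
    uint16_t padding;
    // ... more params
} Conv2DParams_t;
3. Implement forward/backward in ``ondevice_training_lib.c``. The backward pass also needs the weights, because the gradient with respect to the input is computed from them:
void Conv2D_Forward(const float* input, float* output,
                    const Conv2DParams_t* params, const float* weights);
void Conv2D_Backward(const float* input, const float* grad_output, float* grad_input,
                     float* grad_weights, const Conv2DParams_t* params,
                     const float* weights);
4. Call from the main forward and backward dispatchers:
// Forward dispatcher
case LAYER_TYPE_CONV2D:
    Conv2D_Forward(input, output, &layer->params.conv2d, weights);
    break;
// Backward dispatcher
case LAYER_TYPE_CONV2D:
    Conv2D_Backward(input, grad_output, grad_input, grad_weights,
                    &layer->params.conv2d, weights);
    break;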
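Steps 2 to 4 can be sketched end to end with a direct (loop-based) Conv2D forward pass and a tagged-union dispatcher. All names here are hypothetical, and the sketch assumes a CHW activation layout, input height/width stored in the parameter struct, zero padding, and weights stored as [out_channels][in_channels][kernel_height][kernel_width] followed by one bias per output channel; the library may choose differently.

```c
#include <assert.h>
#include <stdint.h>

typedef enum { LAYER_LINEAR, LAYER_RELU, LAYER_CONV2D } Kind;

typedef struct {
    uint16_t in_h, in_w;                  /* input spatial size (assumed field) */
    uint16_t kernel_height, kernel_width;
    uint16_t in_channels, out_channels;
    uint16_t stride, padding;
} Conv2DParams;

typedef struct {
    Kind type;
    union { Conv2DParams conv2d; } params;   /* other layer kinds omitted */
} Layer;

static int conv_out_dim(int in, int k, int stride, int pad)
{
    return (in + 2 * pad - k) / stride + 1;
}

/* Direct convolution; taps that fall outside the input read as zero. */
static void conv2d_forward(const float* in, float* out, const Conv2DParams* p, const float* weights)
{
    assert(p->stride > 0);
    const int oh = conv_out_dim(p->in_h, p->kernel_height, p->stride, p->padding);
    const int ow = conv_out_dim(p->in_w, p->kernel_width, p->stride, p->padding);
    const int ksz = p->in_channels * p->kernel_height * p->kernel_width;
    const float* bias = weights + p->out_channels * ksz;
    for (int oc = 0; oc < p->out_channels; oc++) {
        for (int y = 0; y < oh; y++) {
            for (int x = 0; x < ow; x++) {
                float acc = bias[oc];
                for (int ic = 0; ic < p->in_channels; ic++) {
                    for (int ky = 0; ky < p->kernel_height; ky++) {
                        for (int kx = 0; kx < p->kernel_width; kx++) {
                            const int iy = y * p->stride + ky - p->padding;
                            const int ix = x * p->stride + kx - p->padding;
                            if (iy < 0 || iy >= p->in_h || ix < 0 || ix >= p->in_w) continue;
                            acc += weights[((oc * p->in_channels + ic) * p->kernel_height + ky) * p->kernel_width + kx]
                                 * in[(ic * p->in_h + iy) * p->in_w + ix];
                        }
                    }
                }
                out[(oc * oh + y) * ow + x] = acc;
            }
        }
    }
}

/* Dispatcher: returns 0 if the layer type was handled here. */
static int dispatch_forward(const Layer* layer, const float* input, float* output, const float* weights)
{
    switch (layer->type) {
    case LAYER_CONV2D:
        conv2d_forward(input, output, &layer->params.conv2d, weights);
        return 0;
    default:
        return -1;   /* LINEAR/RELU are not part of this sketch */
    }
}
```

The output size per spatial dimension is (in + 2·padding - kernel) / stride + 1. For a 3×3 input holding 1..9, a 2×2 all-ones kernel, bias 0.5, stride 1 and no padding, the output is [12.5, 16.5, 24.5, 28.5].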
8.9.8. Batch Size Optimization
Compile-time flag (``trainable_model_config.h``):
#define USE_GRADIENT_ACCUMULATION (TRAIN_BATCH_SIZE > 1)
When ``TRAIN_BATCH_SIZE = 1``:
- Gradient accumulation code is eliminated via #if
- Weight updates happen immediately after backward pass
- Saves memory (no ALL_WEIGHT_GRADS array)
- Slightly faster (fewer accumulation operations)
When ``TRAIN_BATCH_SIZE > 1``:
- Gradients are accumulated over the batch
- One weight update per batch (not per sample)
- Better gradient estimates
- More memory required
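The compile-time switch can be sketched as follows. TRAIN_BATCH_SIZE and USE_GRADIENT_ACCUMULATION mirror the documented macros; everything else is hypothetical, and dividing the accumulated gradient by the batch size is an assumption (the documented ODT_SGDUpdate rule shows no division).

```c
#include <assert.h>
#include <stddef.h>

#ifndef TRAIN_BATCH_SIZE
#define TRAIN_BATCH_SIZE 2
#endif
#define USE_GRADIENT_ACCUMULATION (TRAIN_BATCH_SIZE > 1)
#define N_PARAMS 2

static float weights[N_PARAMS];
#if USE_GRADIENT_ACCUMULATION
static float grad_accum[N_PARAMS];   /* exists only when batching */
static int samples_in_batch;
#endif

/* Called once per sample with that sample's weight gradients. */
static void apply_sample_gradient(const float* grad, float lr)
{
    assert(grad != NULL);
#if USE_GRADIENT_ACCUMULATION
    for (int i = 0; i < N_PARAMS; i++) {
        grad_accum[i] += grad[i];
    }
    if (++samples_in_batch == TRAIN_BATCH_SIZE) {
        /* One update per batch using the averaged gradient. */
        for (int i = 0; i < N_PARAMS; i++) {
            weights[i] -= lr * grad_accum[i] / TRAIN_BATCH_SIZE;
            grad_accum[i] = 0.0f;
        }
        samples_in_batch = 0;
    }
#else
    /* Batch size 1: update immediately, no accumulator array. */
    for (int i = 0; i < N_PARAMS; i++) {
        weights[i] -= lr * grad[i];
    }
#endif
}
```

With TRAIN_BATCH_SIZE defined as 1 the preprocessor removes the accumulator array and counter entirely, which is where the memory saving comes from.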
8.9.9. Logging System
Macros for debugging:
ODT_LOG("format string", args); // Info level
ODT_LOG_DEBUG("format string", args); // Debug level
ODT_LOG_ERROR("format string", args); // Error level
Control verbosity via:
#define ODT_LOG_LEVEL ODT_LOG_DEBUG // Set in trainable_model_config.h
Output: Logged data useful for on-device training diagnostics
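One common way to build levelled logging macros is to compare a constant message level against a compile-time threshold, so the compiler can drop disabled calls as dead code. The SKETCH_* names and numeric levels below are hypothetical; the library's actual level constants and output sink are not documented in this section.

```c
#include <assert.h>
#include <stdarg.h>
#include <stdio.h>

/* Hypothetical numeric levels. */
#define SKETCH_LEVEL_ERROR 1
#define SKETCH_LEVEL_INFO  2
#define SKETCH_LEVEL_DEBUG 3

#ifndef SKETCH_LOG_LEVEL
#define SKETCH_LOG_LEVEL SKETCH_LEVEL_INFO
#endif

static int sketch_log_count;   /* number of messages actually emitted */

static void sketch_log(const char* tag, const char* fmt, ...)
{
    assert(fmt != NULL);
    va_list args;
    va_start(args, fmt);
    printf("[%s] ", tag);
    vprintf(fmt, args);
    va_end(args);
    sketch_log_count++;
}

/* The comparison is a constant expression, so disabled calls become dead code. */
#define SKETCH_LOG_AT(level, tag, ...) \
    do { if ((level) <= SKETCH_LOG_LEVEL) sketch_log(tag, __VA_ARGS__); } while (0)
#define SKETCH_LOG_ERROR(...) SKETCH_LOG_AT(SKETCH_LEVEL_ERROR, "ERR", __VA_ARGS__)
#define SKETCH_LOG(...)       SKETCH_LOG_AT(SKETCH_LEVEL_INFO, "INFO", __VA_ARGS__)
#define SKETCH_LOG_DEBUG(...) SKETCH_LOG_AT(SKETCH_LEVEL_DEBUG, "DBG", __VA_ARGS__)
```

With the default threshold (info), info and error messages are printed and debug messages are skipped.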
8.9.10. API Quick Reference
| Function | Purpose |
|---|---|
| ODT_Init | Initialize model context with learning rate |
| ODT_Forward | Compute forward pass activations |
| ODT_Backward | Compute gradients (requires the loss gradient in gradient_buffers[num_layers]) |
| ODT_SGDUpdate | Update weights using accumulated gradients |
| ODT_LossMSE | Compute MSE loss (anomaly detection) |
| ODT_LossCrossEntropy | Compute cross-entropy loss (classification) |
| ODT_SaveBestWeights | Checkpoint current weights to best_weights |
| ODT_RestoreBestWeights | Restore best-performing weights |
| | Switch to training mode (enable dropout, batch norm, etc.) |
| | Switch to inference mode |