Table of Contents
- Introduction
- Quantization - Trading Precision for Size and Speed
- K-Means Weight Quantization
- Pruning and Sparse Matrices
- Knowledge Distillation
- References
Accelerating Neural Network Inference
When I first learned about Machine Learning, I got to know algorithms like gradient descent, which took a bit of time to wrap my head around and truly understand. I thought the bulk of the work was in building and training models, and once you hit that 98% accuracy on the test set, you were set. Inference, after all, seemed straightforward: just a dot product of features and weights plus a bias, essentially y = Wx + b. Since I started with classical ML rather than deep learning, I didn't realize how challenging inference could become when dealing with large neural networks.
Running inference on large neural networks efficiently is a real challenge in the ML Infra domain, especially for real-time applications and for devices with limited computational resources, particularly at the edge. Techniques such as quantization, pruning, and knowledge distillation have emerged as powerful strategies to speed up inference and reduce model size, making it feasible to deploy sophisticated ML models on edge devices without a significant loss in accuracy.
Quantization - Trading Precision for Size and Speed
Quantization reduces the size of the data type used to represent model parameters, thereby decreasing the model's memory footprint and accelerating computation, at the cost of losing precision in the lowest-order bits. This technique commonly converts float32 numbers into int8.
In traditional ALUs, integer arithmetic is on the order of 2 to 10 times faster than floating-point arithmetic because integer operations are simpler: they don't have to manage floating-point details like mantissas and exponents. The gap has been closing over the last few decades thanks to modern hardware optimizations, but integer arithmetic is still the better choice when maximizing computational speed.
Quantization not only shrinks the model weights, but it also speeds up inference because the operations run on smaller data types, which are simpler and take fewer processor cycles. Wow, so many benefits, so what's the catch? Again, you do lose some precision, but this is a highly researched area, and people have found ways to compress these models with really minimal loss. There's always a tradeoff, but if done correctly, quantization has a pretty low cost, if you ask me.
Simple Truncation and Its Limitations
Simply casting a float like 0.12 to an integer truncates it to 0, throwing away essentially all of its information. Quantization avoids this by scaling values into the integer range before converting them, which is what the example below walks through.
Quantization Example: Representing 0.12 as an int8
One might ask how a floating-point value like 0.12 can be represented by an integer data type such as int8, which typically only handles whole numbers. This transformation is made possible through the quantization process, which uses a scaling operation to fit the floating-point values within the integer range. Let's dive into a simplified explanation.
Consider the floating-point value x = 0.12 that we aim to quantize into the int8 range. The first step in this process is to determine a suitable scale factor S that maps the original floating-point range to our target integer range. Assuming we're working with a floating-point range of [-1.0, 1.0] and we want to utilize the full int8 range, we can set S = 127, since int8 can represent values from -128 to 127 and we're adjusting for a symmetric range around zero.
The quantized value q can then be calculated as follows:
q = round(x * S) = round(0.12 * 127) = round(15.24) = 15
In this scenario, the floating-point value 0.12 is quantized to 15 in the int8 representation, showcasing how integers, with the help of scaling, can approximate floating-point values effectively.
To understand how 15 relates back to the original floating-point value, we perform the inverse operation using the same scale factor S. To retrieve the approximate original floating-point value from our quantized integer q, we apply:
x ≈ q / S = 15 / 127 ≈ 0.118
The quantized value of 15, when scaled back using S = 127, closely approximates the original floating-point value of 0.12, showing us the practicality and efficiency of quantization in representing floating-point values within integer data types, albeit with some loss of precision.
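To make this concrete, here is a minimal sketch of the same symmetric int8 scheme in Python with NumPy. The quantize and dequantize helpers and the fixed [-1.0, 1.0] float range are assumptions for illustration; real frameworks usually calibrate the scale from the observed range of the weights or activations.

import numpy as np

S = 127.0  # scale factor mapping a symmetric [-1.0, 1.0] float range onto int8

def quantize(x: np.ndarray) -> np.ndarray:
    # Scale, round to the nearest integer, and clamp to the int8 range.
    return np.clip(np.round(x * S), -128, 127).astype(np.int8)

def dequantize(q: np.ndarray) -> np.ndarray:
    # Map the int8 values back to approximate floats.
    return q.astype(np.float32) / S

weights = np.array([0.12, -0.5, 0.99], dtype=np.float32)
q = quantize(weights)       # -> [ 15, -64, 126 ]
approx = dequantize(q)      # -> [ 0.1181..., -0.5039..., 0.9921... ]
print(q, approx)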

K-Means Weight Quantization
The general idea is to apply the K-Means clustering algorithm to the weights of a trained neural network, grouping them into a smaller number of clusters. Each weight is then approximated by the centroid of the cluster it belongs to. The result is once again a compressed representation of the model that requires less storage and can be more efficient for inference, especially on devices with limited resources.
How K-Means Weight Quantization Works:
- Clustering: Perform K-Means clustering on the weights of the neural network. The number of clusters (k) is chosen based on the desired level of compression. Each cluster will represent a range of weights.
- Centroids: Calculate the centroid of each cluster. The centroid is the mean of all weights that belong to that cluster and will be used to represent them.
- Indexing: Replace each original weight with the index of the centroid it is closest to. These indices are significantly smaller in size compared to the original weights.
- Reconstruction: During inference, the weights are reconstructed by looking up the centroid value corresponding to each index. A small sketch of this procedure is shown below.
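Here is a minimal sketch of that idea using scikit-learn's KMeans on a flattened weight matrix. The layer shape, the choice of k = 16 clusters, and the random example weights are assumptions for illustration, not a production recipe.

import numpy as np
from sklearn.cluster import KMeans

# Pretend these are the weights of one layer of a trained network.
weights = np.random.randn(64, 64).astype(np.float32)

k = 16  # number of clusters; with 16 centroids, each weight index fits in 4 bits
kmeans = KMeans(n_clusters=k, n_init=10, random_state=0)
indices = kmeans.fit_predict(weights.reshape(-1, 1))   # one small index per weight
centroids = kmeans.cluster_centers_.flatten()          # k float32 centroid values

# Reconstruction at inference time: look up each index's centroid.
reconstructed = centroids[indices].reshape(weights.shape)
print(np.abs(weights - reconstructed).max())           # worst-case quantization error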

Post-Training Quantization in TensorFlow
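TensorFlow Lite makes basic post-training quantization a short exercise. A minimal sketch of the conversion is shown below, assuming saved_model_dir points at a model you have already exported as a SavedModel: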
import tensorflow as tf

# Convert an exported SavedModel into a quantized TFLite model.
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
converter.optimizations = [tf.lite.Optimize.DEFAULT]  # enables default (dynamic range) quantization
tflite_quant_model = converter.convert()
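The converter returns a byte buffer that you would typically write to a .tflite file and run on-device with tf.lite.Interpreter; supplying a representative dataset to the converter additionally enables full integer quantization of the activations.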
Pruning and Sparse Matrices
Pruning neural networks is an optimization technique that reduces model complexity by eliminating less important weights. Pruning is often combined with other optimizations for its benefits to be fully realized, and in some cases it has eliminated around 90% of weight parameters without significant loss of accuracy, which suggests that many weights contribute very little to the final output.

- Threshold Selection: The first step in pruning is to choose a pruning threshold, often by evaluating the distribution of the weights' magnitudes. A common approach is to derive the threshold from that distribution, though a reasonable constant of your choosing can also be used:
Pruning Threshold = Mean Weight Magnitude - (STD of Weights * Factor)
- Pruning Process: Weights that fall below the threshold are set to zero, effectively removing them from the network. This process can be visualized as:
if |weight| < threshold: weight = 0
- Model Compression: Post-pruning, the neural network becomes sparse, so it can be stored in a sparse format such as an M:N structured-sparsity layout, which keeps only the non-zero elements (plus a small amount of index metadata) and can be processed efficiently by specialized hardware. Without it, we still perform a lot of useless multiply-by-zero operations instead of skipping them.
- Hardware Acceleration: With specialized hardware that supports sparse matrix operations, pruned networks benefit from accelerated computation; a sparse matrix multiplication engine can skip the zero-value computations. A small sketch of the threshold-and-zero step follows below.
Unlike quantization, pruning allows for more fine-grained control, since you can choose the sparsity level (what fraction of weights to remove), while quantization only gives you the decision of which data type to quantize to, and even that is usually int8.
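To make the threshold step concrete, here is a minimal NumPy/SciPy sketch of magnitude pruning and sparse storage. The random weight matrix, the Factor value of 0.5, and the use of a CSR matrix as the sparse format are assumptions for illustration.

import numpy as np
from scipy.sparse import csr_matrix

weights = np.random.randn(512, 512).astype(np.float32)

# Pruning Threshold = Mean Weight Magnitude - (STD of Weights * Factor)
factor = 0.5
threshold = np.abs(weights).mean() - weights.std() * factor

# if |weight| < threshold: weight = 0
pruned = np.where(np.abs(weights) < threshold, 0.0, weights)

# Store only the non-zero entries; sparsity = fraction of weights removed.
sparse = csr_matrix(pruned)
print("sparsity:", 1.0 - sparse.nnz / pruned.size)

In TensorFlow, the Model Optimization Toolkit wraps this workflow with prune_low_magnitude, which zeroes out the lowest-magnitude weights during training according to a pruning schedule: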
import tensorflow as tf
from tensorflow import keras
import tensorflow_model_optimization as tfmot

prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude

model = keras.Sequential([...])      # your original architecture
model = prune_low_magnitude(model)   # wrap the layers with pruning logic
model.compile(optimizer='adam', loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), metrics=['accuracy'])

# UpdatePruningStep advances the pruning schedule at every training step.
model.fit(x_train, y_train, epochs=2, validation_data=(x_test, y_test),
          callbacks=[tfmot.sparsity.keras.UpdatePruningStep()])

# Strip the pruning wrappers, keeping the now-sparse weights for export.
model_for_export = tfmot.sparsity.keras.strip_pruning(model)
Knowledge Distillation - AI teaching AI
Knowledge Distillation is built upon the "teacher-student" paradigm, where a lightweight student model with far fewer parameters is trained to emulate the output of a significantly larger teacher model or an ensemble of models. As revealed in the original research (Hinton et al., 2015), the distilled single model surprisingly mirrors the performance of an ensemble of 10 models, with negligible differences on famous problems like MNIST handwritten digit recognition.
- Knowledge Transfer:
  - The student learns from both the traditional dataset (hard targets) and the teacher's output probabilities (soft targets).
  - Soft targets provide nuanced insights into the data, revealing the teacher's confidence levels across different classes.
- Soft Targets:
  - Unlike hard targets, which are the exact labels, soft targets are the probabilities predicted by the teacher for all possible outputs.
  - This richer information helps the student grasp subtle distinctions between categories.
- Temperature Scaling:
  - A technique used to adjust the softmax function, making the output probabilities smoother.
  - By using a higher temperature, the differences between probabilities are less pronounced, allowing the student to learn more effectively from the teacher's outputs.
- Training Procedure:
  - The student model's training involves a combined objective: accuracy against hard targets and similarity to the teacher's soft targets (see the sketch after this list).
  - This usually combines cross-entropy loss (for hard targets) and KL divergence (for matching the student's predictions with the teacher's).
- Iterative Optimization:
  - The student model iteratively updates its parameters to improve alignment with both the hard targets and the teacher's predictions.
  - Optimization is typically done using gradient descent or similar algorithms.
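As a rough illustration of that combined objective, here is a minimal sketch of a distillation loss in TensorFlow. The temperature value, the alpha weighting, and the assumption that both models output raw logits are choices made for this example, not a canonical recipe.

import tensorflow as tf

def distillation_loss(y_true, student_logits, teacher_logits,
                      temperature=4.0, alpha=0.1):
    # Hard-target term: standard cross-entropy against the true labels.
    hard_loss = tf.keras.losses.sparse_categorical_crossentropy(
        y_true, student_logits, from_logits=True)

    # Soft-target term: KL divergence between temperature-softened
    # teacher and student distributions.
    teacher_probs = tf.nn.softmax(teacher_logits / temperature)
    student_probs = tf.nn.softmax(student_logits / temperature)
    soft_loss = tf.keras.losses.kl_divergence(teacher_probs, student_probs)

    # Weighted combination; the soft term is scaled by T^2, as suggested in
    # Hinton et al. (2015), to keep its gradients comparable in magnitude.
    return alpha * hard_loss + (1.0 - alpha) * (temperature ** 2) * soft_loss

During training, the teacher's logits come from a forward pass with its weights frozen, and only the student's parameters are updated against this loss.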
Knowledge Distillation transforms the bulky wisdom of a high-performing neural network into a nimble form, ready for efficient deployment. This technique ensures that the speed of inference no longer bounds the reach of AI, enabling sophisticated models to operate even on resource-constrained devices.

References
- Hinton, G. E., Osindero, S., & Teh, Y. W. (2006). A fast learning algorithm for deep belief nets. Neural Computation, 18(7), 1527-1554.
- LeCun, Y., Bengio, Y., & Hinton, G. (2015). Deep learning. Nature, 521(7553), 436-444.
- Han, S., Mao, H., & Dally, W. J. (2015). Deep compression: Compressing deep neural networks with pruning, trained quantization and Huffman coding. International Conference on Learning Representations.
- Hinton, G., Vinyals, O., & Dean, J. (2015). Distilling the knowledge in a neural network. arXiv preprint arXiv:1503.02531.