Model quantization investigations¶
quantization.py provides a few functions used in this notebook.
In this notebook, we quantize the weights of a distilled model, Unet576k (2.2 MB down to about 0.55 MB). In practice, we choose 8-bit quantization, which preserves good performance.
The results shown here are for the 8-bit case. Run the notebook again with 4 bits to see an extreme case where the predictions get degraded!
Feel free to read the section on Quantization in the Weights and biases report
🔎🔎🔎 Weights and biases report 🔎🔎🔎¶
In [2]:
from quantization import get_array_size_in_bytes, dequantize_weights, dequantize_weights_per_layer
from model_quantization import quantize_model_per_layer
from infer import load_model
import torch
from math import sqrt, ceil
import numpy as np
from matplotlib import pyplot as plt
from metrics import compute_metrics
from tqdm import tqdm
from shared import (
ACCURACY, PRECISION, RECALL, F1_SCORE, IOU,
VALIDATION, TEST, TRAIN,
DEVICE
)
from pathlib import Path
import pandas as pd
# Enable IPython's autoreload extension in mode 2, so project modules edited on
# disk are re-imported before each cell runs (no kernel restart needed).
%load_ext autoreload
%autoreload 2
# Notebook-wide alias for the device constant exported by `shared`.
device = DEVICE
# Evaluation helpers from the project's `evaluate` module.
from evaluate import evaluate_model, evaluate_test_mode
The autoreload extension is already loaded. To reload it, use: %reload_ext autoreload
In [3]:
# Identifier of the experiment whose model is loaded.
exp = 1004

# Bit-width used for weight quantization.
# NOTE: `model` is overwritten in place by the quantization cell, so this cell
# must be re-run (model reloaded) whenever `num_bits` is changed.
# 3 bits still works "kind of" correctly.
num_bits = 8

# Working copy: this instance receives the dequantized weights later on.
model, dl_dict, model_config = load_model(exp)

# Reference copy, loaded separately so it keeps its original weights.
# Data loaders are not requested a second time.
original_model, _, _ = load_model(exp, get_data_loaders_flag=False)
TOTAL ELEMENTS 7737 TOTAL ELEMENTS 1935 TOTAL ELEMENTS 2538
In [ ]:
# labeled_dict = evaluate_test_mode(original_model, dl_dict, save_path=Path(f'__submission_{exp:04d}_dataset_update.csv'))
In [4]:
# Baseline: score the non-quantized reference model on the data loaders.
# The recorded output shows a (metrics dict, list) tuple; the list was empty in that run.
evaluate_model(original_model, dl_dict)
100%|██████████| 16/16 [00:20<00:00, 1.31s/it]
Metrics on validation set {'accuracy': 0.9750800728797913, 'precision': 0.8122708797454834, 'recall': 0.8545903563499451, 'dice': 0.7893725037574768, 'iou': 0.7154967784881592}
Out[4]:
({'accuracy': 0.9750800728797913, 'precision': 0.8122708797454834, 'recall': 0.8545903563499451, 'dice': 0.7893725037574768, 'iou': 0.7154967784881592}, [])
In [5]:
# Flatten every trainable tensor into a single vector to report the size of the
# uncompressed weights.
params = torch.cat([p.flatten() for p in model.parameters() if p.requires_grad])
params = params.detach().cpu().numpy()
print(len(params), "=", model.count_parameters(), "->", get_array_size_in_bytes(params), "Bytes")

# Quantize layer by layer, then dequantize right away: `params_dequant` maps a
# parameter name to a numpy array of weights that carry the quantization error.
quantized_weights, quantization_parameters = quantize_model_per_layer(model, num_bits=num_bits)
params_dequant = dequantize_weights_per_layer(quantized_weights, quantization_parameters)

# Reinject dequantized weights into the model.
# Copy in place under no_grad instead of rebinding `param.data` to a freshly
# built nn.Parameter: the existing parameter keeps its identity, dtype and
# device, and a size mismatch fails loudly here rather than at inference time.
with torch.no_grad():
    for name, param in model.named_parameters():
        if name in params_dequant:
            dequantized = torch.from_numpy(params_dequant[name]).reshape(param.shape)
            param.copy_(dequantized)
            print(name, "has been updated with quantized weights")
2.20 Mb 576721 = 576721 -> 2306884 Bytes 0.55 Mb 2.20 Mb compression ratio = 4.004 encoder_list.0.conv_block.conv_stage.0.conv.weight has been updated with quantized weights encoder_list.0.conv_block.conv_stage.1.conv.weight has been updated with quantized weights encoder_list.1.conv_block.conv_stage.0.conv.weight has been updated with quantized weights encoder_list.2.conv_block.conv_stage.0.conv.weight has been updated with quantized weights decoder_list.0.conv_block.conv_stage.0.conv.weight has been updated with quantized weights decoder_list.1.conv_block.conv_stage.0.conv.weight has been updated with quantized weights decoder_list.2.conv_block.conv_stage.0.conv.weight has been updated with quantized weights decoder_list.2.conv_block.conv_stage.1.conv.weight has been updated with quantized weights bottleneck.conv_stage.0.conv.weight has been updated with quantized weights refinement_stage.conv_stage.0.conv.weight has been updated with quantized weights
In [6]:
# Quantized model: same evaluation call as for `original_model`, run on `model`
# to measure the metric change caused by quantization.
evaluate_model(model, dl_dict)
100%|██████████| 16/16 [00:05<00:00, 2.92it/s]
Metrics on validation set {'accuracy': 0.975077748298645, 'precision': 0.8139002919197083, 'recall': 0.8537895679473877, 'dice': 0.7895622849464417, 'iou': 0.7158442735671997}
Out[6]:
({'accuracy': 0.975077748298645, 'precision': 0.8139002919197083, 'recall': 0.8537895679473877, 'dice': 0.7895622849464417, 'iou': 0.7158442735671997}, [])
In [7]:
# Pick the first parameter of the reference model that was actually quantized.
# (A hard-coded "conv_in_modality.conv_h.weight" key used to be assigned here
# but was dead code, overwritten on the very next line. Blindly taking the
# first named parameter would raise KeyError below if it was not quantized.)
layer_key = next(
    name for name, _ in original_model.named_parameters() if name in params_dequant
)
model_params_dict = dict(original_model.named_parameters())

# Original (non-quantized) weights of that layer, flattened to 1-D.
params_no_qant = model_params_dict[layer_key].flatten().detach().cpu().numpy()
# Same layer after the quantize -> dequantize round trip.
params_dequant_flat = params_dequant[layer_key].flatten()

# Overlay: original weights as a line, dequantized weights as dots.
plt.figure(figsize=(20, 5))
plt.plot(params_no_qant, label="Original")
plt.plot(params_dequant_flat, ".", label="Dequantized")
plt.legend()
plt.grid()
plt.title(f"Quantization result {layer_key}")
plt.show()

# Distribution of the per-weight quantization error (original - dequantized).
plt.figure(figsize=(20, 5))
plt.hist(params_no_qant - params_dequant_flat, bins=100, label="Error")
plt.legend()
plt.title(f"Quantization Error {layer_key}")
plt.grid()
plt.show()
In [19]:
# Run both models on the first validation batch, without tracking gradients.
img, label = next(iter(dl_dict[VALIDATION]))
with torch.no_grad():
    output = original_model(img)
    output_with_qant = model(img)

selected_index = 18  # image to pick from the first validation batch

# `plt.figsize = (10, 10)` only set an attribute on the pyplot module and never
# resized anything; the size has to be given when the figure is created.
plt.figure(figsize=(10, 10))

# Top-left: sigmoid of the reference model output, first channel.
plt.subplot(2, 2, 1)
plt.imshow(torch.sigmoid(output[selected_index, 0, ...]).detach().cpu().numpy())
plt.title(f'Probability prediction \noriginal model = {exp}')

# Top-right: same view for the model carrying dequantized weights.
plt.subplot(2, 2, 2)
plt.imshow(torch.sigmoid(output_with_qant[selected_index, 0, ...]).detach().cpu().numpy())
plt.title(f"Output with quantized weights on {num_bits} bits")

# Bottom-left: first channel of the network input, in grayscale.
plt.subplot(2, 2, 3)
plt.title("Input")
plt.imshow(img[selected_index, 0, ...].detach().cpu().numpy().astype(float), cmap="gray")

# Bottom-right: first channel of the label.
plt.subplot(2, 2, 4)
plt.title("Ground truth")
plt.imshow(label[selected_index, 0, ...].detach().cpu().numpy().astype(float))
plt.show()
In [20]:
# Stand-alone view of the label (first channel) for the selected batch item.
ground_truth_mask = label[selected_index, 0, ...].detach().cpu().numpy().astype(float)
plt.title("Ground truth")
plt.imshow(ground_truth_mask)
plt.show()
In [21]:
# `plt.figsize = (10, 10)` was a no-op (it just set a module attribute);
# create the figure with the intended size instead.
plt.figure(figsize=(10, 10))
# Threshold the reference model's sigmoid output at 0.5 to get a binary mask.
binary_mask = torch.sigmoid(output[selected_index, 0, ...]).detach().cpu().numpy() > 0.5
plt.imshow(binary_mask)
plt.title("Binary Prediction mask")
plt.show()
In [22]:
# Signed difference between the sigmoid outputs of the quantized and reference
# models, keeping only the first channel of every batch item.
error = torch.sigmoid(output_with_qant) - torch.sigmoid(output)
error = error[:, 0, ...].detach().cpu().numpy()

# Near-square grid: `n` rows, and enough columns to hold every batch item.
plt.figure(figsize=(20, 20))
n = int(sqrt(error.shape[0]) + 0.5)
for idx, error_map in enumerate(error):
    plt.subplot(n, ceil(error.shape[0] / n), idx + 1)
    # Title each panel with its mean absolute difference, as a percentage.
    plt.title(f"{np.abs(error_map).mean():.2%}")
    plt.imshow(error_map)
plt.suptitle(f"Probability difference due to quantization error - {num_bits} bits")
plt.show()
Global weights distribution¶
In [24]:
# Gather every trainable tensor of the reference model into one flat numpy
# vector, so all layers can be histogrammed together.
trainable_tensors = [p.flatten() for p in original_model.parameters() if p.requires_grad]
original_params = torch.cat(trainable_tensors).detach().cpu().numpy()
In [25]:
# Histogram of all trainable parameters pooled together, with a log-scaled
# count axis so the sparse tails stay visible.
fig, ax = plt.subplots()
ax.hist(original_params, bins=1000)
ax.set_yscale('log')
ax.set_ylabel('log count')
ax.set_xlabel('parameter value')
ax.grid()
ax.set_title('Parameter distribution before quantization - all layers mixed')
plt.show()
Need for per-layer quantization¶
The following graph shows that if we perform global model quantization (same scaling for all weights), we will lose a lot of precision, as each layer's weights have a slightly different dynamic range.
In [26]:
# One histogram per parameter tensor of the reference model: names without
# 'bias' go in the left column, names containing 'bias' in the right column.
# (`plt` is already imported in the notebook's import cell.)
named_params = list(original_model.named_parameters())
tot = len(named_params)

# Round the row count up: with `tot // 2` rows, an odd number of parameter
# tensors made plt.subplot raise on the last tensor (index out of the grid).
n_rows = ceil(tot / 2)

plt.figure(figsize=(10, n_rows * 5))
for idx, (name, param) in enumerate(named_params):
    # Row is idx // 2; the column depends on whether this is a bias tensor.
    column = 2 if 'bias' in name else 1
    plt.subplot(n_rows, 2, idx // 2 * 2 + column)
    if param.requires_grad:
        layer_params = param.detach().cpu().numpy().flatten()
        plt.hist(layer_params, bins=100, density=True, alpha=1, label=name)
        # plt.title(f'Histogram for Layer: {name}')
        plt.title('Histogram for weight' if 'weight' in name else 'Histogram for bias')
        plt.yscale('log')
        plt.ylabel('log count')
        plt.xlabel('parameter value')
        plt.legend()
        # Common x-range so the spread of each layer can be compared visually.
        plt.xlim(-1., 1.)
        plt.grid()
plt.suptitle('Histogram of weights and biases for each layer')
plt.show()