Model quantization investigations¶

quantization.py and model_quantization.py provide the helper functions used in this notebook.

In the following notebook, we quantize the weights of a distilled model, Unet576k (2.20 MB down to 0.55 MB). In practice, we choose an 8-bit quantization to achieve good performance.

The results below are shown for the 8-bit case. Re-run the notebook with 4 bits to see the extreme case, where quality degrades!

Feel free to read the Quantization section of the Weights & Biases report.

🔎🔎🔎 Weights & Biases report 🔎🔎🔎¶
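The exact helpers live in quantization.py and model_quantization.py and are not reproduced here. As a rough sketch (an assumption about the approach, not the actual implementation), uniform affine quantization maps each weight to an integer code through a scale and an offset:

In [ ]:
import numpy as np

def quantize(w: np.ndarray, num_bits: int = 8):
    """Map float weights onto integer codes in [0, 2**num_bits - 1] (uniform affine grid)."""
    w_min, w_max = float(w.min()), float(w.max())
    scale = (w_max - w_min) / (2**num_bits - 1)
    codes = np.round((w - w_min) / scale).astype(np.uint8 if num_bits <= 8 else np.uint16)
    return codes, scale, w_min

def dequantize(codes: np.ndarray, scale: float, w_min: float) -> np.ndarray:
    """Recover approximate float weights from the integer codes."""
    return codes.astype(np.float32) * scale + w_min

w = np.random.randn(1000).astype(np.float32)
codes, scale, w_min = quantize(w, num_bits=8)
print(np.abs(w - dequantize(codes, scale, w_min)).max(), "<=", scale / 2)  # round-to-nearest error bound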

In [2]:
from quantization import get_array_size_in_bytes, dequantize_weights, dequantize_weights_per_layer
from model_quantization import quantize_model_per_layer
from infer import load_model
import torch
from math import sqrt, ceil
import numpy as np
from matplotlib import pyplot as plt
from metrics import compute_metrics
from tqdm import tqdm
from shared import (
    ACCURACY, PRECISION, RECALL, F1_SCORE, IOU,
    VALIDATION, TEST, TRAIN,
    DEVICE
)
from pathlib import Path
import pandas as pd
%load_ext autoreload
%autoreload 2
device = DEVICE
from evaluate import evaluate_model, evaluate_test_mode
In [3]:
exp = 1004
num_bits = 8  # if you want to change the number of bits, you need to reload the model
# even 3 bits still works more or less correctly
model, dl_dict, model_config = load_model(exp)
original_model, _, _ = load_model(exp, get_data_loaders_flag=False)
TOTAL ELEMENTS 7737
TOTAL ELEMENTS 1935
TOTAL ELEMENTS 2538
In [ ]:
# labeled_dict = evaluate_test_mode(original_model, dl_dict, save_path=Path(f'__submission_{exp:04d}_dataset_update.csv'))
In [4]:
evaluate_model(original_model, dl_dict)
100%|██████████| 16/16 [00:20<00:00,  1.31s/it]
Metrics on validation set
{'accuracy': 0.9750800728797913, 'precision': 0.8122708797454834, 'recall': 0.8545903563499451, 'dice': 0.7893725037574768, 'iou': 0.7154967784881592}

Out[4]:
({'accuracy': 0.9750800728797913,
  'precision': 0.8122708797454834,
  'recall': 0.8545903563499451,
  'dice': 0.7893725037574768,
  'iou': 0.7154967784881592},
 [])
In [5]:
params = torch.cat([p.flatten() for p in model.parameters() if p.requires_grad])
params = params.detach().cpu().numpy()
print(len(params), "=", model.count_parameters(), "->", get_array_size_in_bytes(params), "Bytes")
quantized_weights, quantization_parameters = quantize_model_per_layer(model, num_bits=num_bits)
params_dequant = dequantize_weights_per_layer(quantized_weights, quantization_parameters)

# Reinject dequantized weights into the model
for name, param in model.named_parameters():
    if name in params_dequant:
        param.data = torch.from_numpy(params_dequant[name]).to(device=device)
        print(name, "has been updated with quantized weights")
2.20 Mb
576721 = 576721 -> 2306884 Bytes
0.55 Mb
2.20 Mb
compression ratio = 4.004
encoder_list.0.conv_block.conv_stage.0.conv.weight has been updated with quantized weights
encoder_list.0.conv_block.conv_stage.1.conv.weight has been updated with quantized weights
encoder_list.1.conv_block.conv_stage.0.conv.weight has been updated with quantized weights
encoder_list.2.conv_block.conv_stage.0.conv.weight has been updated with quantized weights
decoder_list.0.conv_block.conv_stage.0.conv.weight has been updated with quantized weights
decoder_list.1.conv_block.conv_stage.0.conv.weight has been updated with quantized weights
decoder_list.2.conv_block.conv_stage.0.conv.weight has been updated with quantized weights
decoder_list.2.conv_block.conv_stage.1.conv.weight has been updated with quantized weights
bottleneck.conv_stage.0.conv.weight has been updated with quantized weights
refinement_stage.conv_stage.0.conv.weight has been updated with quantized weights
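The printed sizes are easy to sanity-check with pure arithmetic (the values are reported in MiB): 576721 float32 parameters take 4 bytes each, while 8-bit codes take 1 byte each.

In [ ]:
# 576721 float32 parameters: 4 bytes each before quantization, 1 byte each at 8 bits
n_params = 576_721
fp32_bytes = n_params * 4  # 2306884 bytes, matching the printed value above
int8_bytes = n_params * 1
print(f"{fp32_bytes / 2**20:.2f} MiB -> {int8_bytes / 2**20:.2f} MiB, ratio = {fp32_bytes / int8_bytes:.1f}")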
In [6]:
# Quantized model: validation metrics are essentially unchanged (IoU 0.7155 -> 0.7158)
evaluate_model(model, dl_dict)
100%|██████████| 16/16 [00:05<00:00,  2.92it/s]
Metrics on validation set
{'accuracy': 0.975077748298645, 'precision': 0.8139002919197083, 'recall': 0.8537895679473877, 'dice': 0.7895622849464417, 'iou': 0.7158442735671997}

Out[6]:
({'accuracy': 0.975077748298645,
  'precision': 0.8139002919197083,
  'recall': 0.8537895679473877,
  'dice': 0.7895622849464417,
  'iou': 0.7158442735671997},
 [])
In [7]:
layer_key= "conv_in_modality.conv_h.weight"
layer_key = list(original_model.named_parameters())[0][0]
model_params_dict = dict(original_model.named_parameters())
params_no_qant = model_params_dict[layer_key].flatten().detach().cpu().numpy()
# Back to the original weights

plt.figure(figsize=(20,5))
plt.plot(params_no_quant, label="Original")
plt.plot(params_dequant[layer_key].flatten(), ".", label="Dequantized")
plt.legend()
plt.grid()
plt.title(f"Quantization result {layer_key}")
plt.show()
plt.figure(figsize=(20,5))
plt.hist(params_no_quant - params_dequant[layer_key].flatten(), bins=100, label="Error")
plt.legend()
plt.title(f"Quantization Error {layer_key}")
plt.grid()
plt.show()
[Figure: original vs. dequantized weights for the first layer]
[Figure: histogram of the per-weight quantization error]
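For a uniform grid, the rounding error on each weight is bounded by half the quantization step. We can estimate the step for this layer from the spacing of the distinct dequantized values (a rough check that assumes quantize_model_per_layer uses such a grid):

In [ ]:
# Estimate the quantization step as the smallest gap between distinct dequantized values,
# then verify the observed error stays within +/- step / 2 (uniform-grid assumption)
levels = np.unique(params_dequant[layer_key].flatten())
step = np.diff(levels).min()
err = params_no_quant - params_dequant[layer_key].flatten()
print(f"step = {step:.2e}, max |error| = {np.abs(err).max():.2e}, bound = {step / 2:.2e}")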
In [19]:
img, label = next(iter(dl_dict[VALIDATION]))
with torch.no_grad():
    output = original_model(img)
    output_with_quant = model(img)
selected_index = 18  # image to pick from the first validation batch


plt.figure(figsize=(10, 10))
plt.subplot(2, 2, 1)
plt.imshow(torch.sigmoid(output[selected_index, 0, ...]).detach().cpu().numpy())
plt.title(f'Probability prediction \noriginal model = {exp}')
plt.subplot(2, 2, 2)
plt.imshow(torch.sigmoid(output_with_quant[selected_index, 0, ...]).detach().cpu().numpy())
plt.title(f"Output with quantized weights on {num_bits} bits")

plt.subplot(2, 2, 3)
plt.title("Input")
plt.imshow(img[selected_index, 0, ...].detach().cpu().numpy().astype(float), cmap="gray")

plt.subplot(2, 2, 4)
plt.title("Ground truth")
plt.imshow(label[selected_index, 0, ...].detach().cpu().numpy().astype(float))
plt.show()
[Figure: original vs. quantized probability predictions, input, and ground truth]
In [21]:
plt.figure(figsize=(10, 10))
plt.imshow(torch.sigmoid(output[selected_index, 0, ...]).detach().cpu().numpy()>0.5)
plt.title("Binary Prediction mask")
plt.show()
[Figure: thresholded binary prediction mask]
In [22]:
error = torch.sigmoid(output_with_quant) - torch.sigmoid(output)
error = error[:, 0, ...].detach().cpu().numpy()
plt.figure(figsize=(20, 20))
n = int(sqrt(error.shape[0])+0.5)
for idx in range(error.shape[0]):
    plt.subplot(n, ceil(error.shape[0]/n), idx+1)
    plt.title(f"{np.abs(error)[idx].mean():.2%}")
    plt.imshow(error[idx])
plt.suptitle(f"Probability difference due to quantization error - {num_bits} bits")
plt.show()
[Figure: per-image probability difference maps between quantized and original models]

Global weights distribution¶

In [24]:
original_params = torch.cat([p.flatten() for p in original_model.parameters() if p.requires_grad])
original_params = original_params.detach().cpu().numpy()
In [25]:
plt.hist(original_params, bins=1000)
plt.yscale('log')
plt.ylabel('log count')
plt.xlabel('parameter value')
plt.grid()
plt.title('Parameter distribution before quantization - all layers mixed')
plt.show()
[Figure: log-scale histogram of all parameter values]

Need for per-layer quantization¶

The following graph shows that if we perform global model quantization (the same scale for all weights), we lose a lot of precision, since each layer's weights have a slightly different dynamic range.
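Before the plots, a quick numerical check (a sketch reusing original_params from above and the same uniform-grid assumption): compare the step size implied by a single global range with the step each layer would get from its own range. A layer whose range is much narrower than the global one wastes most of the available levels.

In [ ]:
# Step size implied by one global range vs. per-layer ranges at num_bits
global_step = (original_params.max() - original_params.min()) / (2**num_bits - 1)
for name, param in original_model.named_parameters():
    if param.requires_grad and "weight" in name:
        w = param.detach().cpu().numpy()
        layer_step = (w.max() - w.min()) / (2**num_bits - 1)
        print(f"{name}: per-layer step {layer_step:.2e} vs global step {global_step:.2e}")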

In [26]:
tot = len(list(original_model.named_parameters()))
plt.figure(figsize=(10, tot//2*5))
for idx, (name, param) in enumerate(original_model.named_parameters()):
    if 'bias' not in name:
        plt.subplot(tot//2, 2, idx//2 * 2 + 1)
    else:
        plt.subplot(tot//2, 2, idx//2 * 2 + 2)
    if param.requires_grad:
        layer_params = param.detach().cpu().numpy().flatten()
        plt.hist(layer_params, bins=100, density=True, alpha=1, label=name)
        # plt.title(f'Histogram for Layer: {name}')
        plt.title('Histogram for weight' if 'weight' in name else 'Histogram for bias')
        plt.yscale('log')
        plt.ylabel('log count')
        plt.xlabel('parameter value')
        plt.legend()
        plt.xlim(-1., 1.)
        plt.grid()
plt.suptitle('Histogram of weights and biases for each layer')
plt.show()
[Figure: per-layer histograms of weights (left column) and biases (right column)]