Model distillation and quantization evaluation¶
This notebook evaluates and compares three models:
- Teacher: UNET 6.4M (exp 53) -> 24.6 MB on disk
- Student: UNET 576k (exp 1004) -> 2.2 MB on disk
- Student: UNET 576k + per-layer 8-bit quantization -> 400 kB of weights on disk
Feel free to read the Weights and Biases report:
🔎🔎🔎 Weights and biases report 🔎🔎🔎¶
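As a quick sanity check on the sizes above, weight storage can be estimated from the parameter count and the bit width. This is a rough sketch; real checkpoints add serialization overhead, and not every tensor necessarily gets quantized:
In [ ]:
# Rough on-disk size estimate: n_params * bytes per parameter.
# float32 = 4 bytes/param, int8 = 1 byte/param (checkpoint overhead ignored).
estimates = {
    "teacher UNET 6.4M, fp32": 6.4e6 * 4,
    "student UNET 576k, fp32": 576e3 * 4,
    "student UNET 576k, int8": 576e3 * 1,
}
for name, n_bytes in estimates.items():
    print(f"{name}: ~{n_bytes / 1e6:.2f} MB")
# ~25.6 MB / ~2.3 MB / ~0.58 MB, in line with the figures quoted above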
In [1]:
from pathlib import Path

import pandas as pd
import torch

from evaluate import (
    compare_performance_per_well,
    evaluate_model,
    evaluate_test_mode,
    get_global_metrics_str,
    visualize_performance_per_well,
)
from infer import load_model
from metrics import compute_metrics
from model_quantization import quantize_model_per_layer
from quantization import dequantize_weights_per_layer
from shared import (
    ACCURACY, PRECISION, RECALL, F1_SCORE, IOU,
    VALIDATION, TEST, TRAIN,
    DEVICE, DISTILLATION,
)

%load_ext autoreload
%autoreload 2

device = DEVICE
def get_n_param_str(model_config: dict) -> str:
    """Format a model's parameter count as a human-readable string."""
    nb_param = model_config["model"]["n_params"]
    if nb_param < 1e3:
        return f"{nb_param:.0f} parameters"
    elif nb_param < 1e6:
        return f"{nb_param / 1e3:.1f}k parameters"
    else:
        return f"{nb_param / 1e6:.1f}M parameters"
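For example, with hypothetical config dicts sized like the two models above:
In [ ]:
# get_n_param_str picks the unit from the magnitude of n_params
print(get_n_param_str({"model": {"n_params": 576_000}}))    # -> '576.0k parameters'
print(get_n_param_str({"model": {"n_params": 6_400_000}}))  # -> '6.4M parameters'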
In [2]:
checkpoint = 'best_model'
# (experiment id, quantization bit-width); None means no quantization
comparison_list = [(53, None), (1004, None), (1004, 8)]
In [3]:
results_dict = {}
for exp, quantization_num_bits in comparison_list:
    model, dl_dict, model_config = load_model(exp, batch_size=16, model_name=checkpoint + ".pt")
    if quantization_num_bits is not None:
        print("Quantizing model...")
        # Quantize every layer's weights, then write the dequantized values back
        # into the model to measure the accuracy cost of quantization.
        quantized_weights, quantization_parameters = quantize_model_per_layer(model, num_bits=quantization_num_bits)
        params_dequant = dequantize_weights_per_layer(quantized_weights, quantization_parameters)
        for name, param in model.named_parameters():
            if name in params_dequant:
                param.data = torch.from_numpy(params_dequant[name]).to(device=device)
                print(name, "has been updated with quantized weights")
    metrics_validation, detailed_metrics_validation = evaluate_model(model, dl_dict, phase=VALIDATION, detailed_metrics_flag=True)
    metrics_train, detailed_metrics_train = evaluate_model(model, dl_dict, phase=TRAIN, detailed_metrics_flag=True)
    df_train = pd.DataFrame(detailed_metrics_train)
    df_validation = pd.DataFrame(detailed_metrics_validation)
    model_name = f"{model_config['model']['name']} - {exp} {checkpoint}"
    if quantization_num_bits is not None:
        model_name += f" - quantized {quantization_num_bits}"
    result_key = f"{exp:04d}" + (f"_quantized {quantization_num_bits}" if quantization_num_bits is not None else "")
    results_dict[result_key] = (
        metrics_validation, detailed_metrics_validation,
        metrics_train, detailed_metrics_train,
        model_config, df_train, df_validation, model_name,
    )
TOTAL ELEMENTS 7737
TOTAL ELEMENTS 1935
TOTAL ELEMENTS 2538
100%|██████████| 121/121 [00:21<00:00, 5.60it/s]
Metrics on validation set {'accuracy': 0.9747447371482849, 'precision': 0.7753747701644897, 'recall': 0.9004101157188416, 'dice': 0.7825585603713989, 'iou': 0.7083650231361389}
100%|██████████| 484/484 [01:07<00:00, 7.15it/s]
Metrics on train set {'accuracy': 0.9785500764846802, 'precision': 0.8026665449142456, 'recall': 0.9166544675827026, 'dice': 0.8157516121864319, 'iou': 0.744350254535675}
TOTAL ELEMENTS 7737
TOTAL ELEMENTS 1935
TOTAL ELEMENTS 2538
100%|██████████| 121/121 [00:02<00:00, 47.16it/s]
Metrics on validation set {'accuracy': 0.9750799536705017, 'precision': 0.8122709393501282, 'recall': 0.8545907139778137, 'dice': 0.7893723845481873, 'iou': 0.7154967784881592}
100%|██████████| 484/484 [00:12<00:00, 37.36it/s]
Metrics on train set {'accuracy': 0.9769202470779419, 'precision': 0.8295949697494507, 'recall': 0.8614419102668762, 'dice': 0.8057383894920349, 'iou': 0.7301884889602661}
TOTAL ELEMENTS 7737
TOTAL ELEMENTS 1935
TOTAL ELEMENTS 2538
Quantizing model...
0.55 Mb 2.20 Mb compression ratio = 4.004
encoder_list.0.conv_block.conv_stage.0.conv.weight has been updated with quantized weights
encoder_list.0.conv_block.conv_stage.1.conv.weight has been updated with quantized weights
encoder_list.1.conv_block.conv_stage.0.conv.weight has been updated with quantized weights
encoder_list.2.conv_block.conv_stage.0.conv.weight has been updated with quantized weights
decoder_list.0.conv_block.conv_stage.0.conv.weight has been updated with quantized weights
decoder_list.1.conv_block.conv_stage.0.conv.weight has been updated with quantized weights
decoder_list.2.conv_block.conv_stage.0.conv.weight has been updated with quantized weights
decoder_list.2.conv_block.conv_stage.1.conv.weight has been updated with quantized weights
bottleneck.conv_stage.0.conv.weight has been updated with quantized weights
refinement_stage.conv_stage.0.conv.weight has been updated with quantized weights
100%|██████████| 121/121 [00:02<00:00, 45.12it/s]
Metrics on validation set {'accuracy': 0.9750775098800659, 'precision': 0.8139001131057739, 'recall': 0.8537893891334534, 'dice': 0.7895625233650208, 'iou': 0.7158442139625549}
100%|██████████| 484/484 [00:12<00:00, 39.07it/s]
Metrics on train set {'accuracy': 0.9769366383552551, 'precision': 0.8303963541984558, 'recall': 0.8607764840126038, 'dice': 0.805854856967926, 'iou': 0.730309009552002}
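The log reports a compression ratio close to 32/8 = 4, which is what per-layer affine (min/max) quantization of fp32 weights to 8 bits should give. The sketch below illustrates that scheme on a single tensor; it is an assumption about what quantize_model_per_layer / dequantize_weights_per_layer do internally, not the actual implementation.
In [ ]:
import numpy as np

def quantize_tensor(w: np.ndarray, num_bits: int = 8):
    # Affine min/max quantization of one layer's weights (illustrative sketch).
    qmax = 2 ** num_bits - 1
    w_min, w_max = float(w.min()), float(w.max())
    scale = (w_max - w_min) / qmax if w_max > w_min else 1.0
    q = np.round((w - w_min) / scale).astype(np.uint8)  # num_bits <= 8 assumed
    return q, (scale, w_min)

def dequantize_tensor(q: np.ndarray, params) -> np.ndarray:
    scale, w_min = params
    return q.astype(np.float32) * scale + w_min

w = np.random.randn(64, 32, 3, 3).astype(np.float32)
q, qparams = quantize_tensor(w, num_bits=8)
w_hat = dequantize_tensor(q, qparams)
print("max abs error:", np.abs(w - w_hat).max())  # bounded by ~scale / 2
print("compression  :", w.nbytes / q.nbytes)      # 4.0 (fp32 -> uint8)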
In [5]:
list_df_plot = []
label_list = []
chosen_metric = F1_SCORE
selected_exp_list = list(results_dict.keys())
extra_title = ""
for exp in selected_exp_list:
    metrics_validation, detailed_metrics_validation, metrics_train, detailed_metrics_train, model_config, df_train, df_validation, model_name = results_dict[exp]
    list_df_plot.append(df_train)
    label_list.append(
        model_name + (' - Distilled' if model_config.get(DISTILLATION, False) else "") + f' {get_n_param_str(model_config)}'
    )
    extra_title += "\n" + model_name + " " + get_n_param_str(model_config) + "\n" + get_global_metrics_str(metrics_train)
compare_performance_per_well(list_df_plot, label_list=label_list, chosen_metric=chosen_metric, title=f"Per well TRAINING {chosen_metric} score" + extra_title)

list_df_plot = []
label_list = []
extra_title = ""
for exp in selected_exp_list:
    metrics_validation, detailed_metrics_validation, metrics_train, detailed_metrics_train, model_config, df_train, df_validation, model_name = results_dict[exp]
    list_df_plot.append(df_validation)
    label_list.append(
        model_name + (' - Distilled' if model_config.get(DISTILLATION, False) else "") + f' {get_n_param_str(model_config)}'
    )
    extra_title += "\n" + model_name + " " + get_n_param_str(model_config) + "\n" + get_global_metrics_str(metrics_validation)
compare_performance_per_well(list_df_plot, label_list=label_list, chosen_metric=chosen_metric, title=f"Per well VALIDATION {chosen_metric} score" + extra_title)
Quantization effect only¶
In [ ]:
results_dict = {}
for exp, quantization_num_bits in [(1004, None), (1004, 8)]:
    model, dl_dict, model_config = load_model(exp, batch_size=16, model_name=checkpoint + ".pt")
    if quantization_num_bits is not None:
        print("Quantizing model...")
        # Same quantize/dequantize round trip as above, restricted to exp 1004,
        # so any metric change is attributable to quantization alone.
        quantized_weights, quantization_parameters = quantize_model_per_layer(model, num_bits=quantization_num_bits)
        params_dequant = dequantize_weights_per_layer(quantized_weights, quantization_parameters)
        for name, param in model.named_parameters():
            if name in params_dequant:
                param.data = torch.from_numpy(params_dequant[name]).to(device=device)
                print(name, "has been updated with quantized weights")
    metrics_validation, detailed_metrics_validation = evaluate_model(model, dl_dict, phase=VALIDATION, detailed_metrics_flag=True)
    metrics_train, detailed_metrics_train = evaluate_model(model, dl_dict, phase=TRAIN, detailed_metrics_flag=True)
    df_train = pd.DataFrame(detailed_metrics_train)
    df_validation = pd.DataFrame(detailed_metrics_validation)
    model_name = f"{model_config['model']['name']} - {exp} {checkpoint}"
    if quantization_num_bits is not None:
        model_name += f" - quantized {quantization_num_bits}"
    result_key = f"{exp:04d}" + (f"_quantized {quantization_num_bits}" if quantization_num_bits is not None else "")
    results_dict[result_key] = (
        metrics_validation, detailed_metrics_validation,
        metrics_train, detailed_metrics_train,
        model_config, df_train, df_validation, model_name,
    )
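To put a number on the quantization cost, the metric dictionaries stored above can be diffed directly. A small sketch; the keys match the results_dict entries built in the previous cell:
In [ ]:
# Validation metrics side by side, plus the int8 - fp32 delta for exp 1004.
val_fp32 = results_dict["1004"][0]              # metrics_validation, fp32 weights
val_int8 = results_dict["1004_quantized 8"][0]  # metrics_validation, 8-bit weights
df_delta = pd.DataFrame({"fp32": val_fp32, "int8": val_int8})
df_delta["delta"] = df_delta["int8"] - df_delta["fp32"]
print(df_delta)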
In [50]:
chosen_metric = F1_SCORE
selected_exp_list = list(results_dict.keys())
for exp in selected_exp_list:
    list_df_plot = []
    label_list = []
    metrics_validation, detailed_metrics_validation, metrics_train, detailed_metrics_train, model_config, df_train, df_validation, model_name = results_dict[exp]
    list_df_plot.append(df_train)
    label_list.append(model_name + (' - Distilled' if model_config.get(DISTILLATION, False) else "") + f' {get_n_param_str(model_config)}' + " train")
    list_df_plot.append(df_validation)
    label_list.append(model_name + (' - Distilled' if model_config.get(DISTILLATION, False) else "") + f' {get_n_param_str(model_config)}' + " validation")
    compare_performance_per_well(list_df_plot, label_list=label_list, chosen_metric=chosen_metric)
In [55]:
list_df_plot = []
label_list = []
for exp in selected_exp_list:
    metrics_validation, detailed_metrics_validation, metrics_train, detailed_metrics_train, model_config, df_train, df_validation, model_name = results_dict[exp]
    list_df_plot.append(df_train)
    label_list.append(model_name + (' - Distilled' if model_config.get(DISTILLATION, False) else "") + f' {get_n_param_str(model_config)}' + " train")
compare_performance_per_well(list_df_plot, label_list=label_list, title=f"{chosen_metric} per well TRAINING", chosen_metric=chosen_metric)

list_df_plot = []
label_list = []
for exp in selected_exp_list:
    metrics_validation, detailed_metrics_validation, metrics_train, detailed_metrics_train, model_config, df_train, df_validation, model_name = results_dict[exp]
    list_df_plot.append(df_validation)
    label_list.append(model_name + (' - Distilled' if model_config.get(DISTILLATION, False) else "") + f' {get_n_param_str(model_config)}' + " validation")
compare_performance_per_well(list_df_plot, label_list=label_list, title=f"{chosen_metric} per well VALIDATION", chosen_metric=chosen_metric)
In [5]:
chosen_metrics = [F1_SCORE]
# chosen_metrics = [PRECISION, RECALL]
for exp in results_dict:
    metrics_validation, detailed_metrics_validation, metrics_train, detailed_metrics_train, model_config, df_train, df_validation, model_name = results_dict[exp]
    visualize_performance_per_well(df_train, chosen_metrics=chosen_metrics, title=f'Metrics per Well (Train)\n{model_name}\n{get_global_metrics_str(metrics_train)}')
    visualize_performance_per_well(df_validation, chosen_metrics=chosen_metrics, title=f'Metrics per Well (Validation)\n{model_name}\n{get_global_metrics_str(metrics_validation)}')
Test 1 3 4 5¶
In [ ]:
# Export submissions for the test set: https://challengedata.ens.fr/participants/challenges/144/
for exp in dict.fromkeys(e for e, _ in comparison_list):  # unique experiment ids
    model, dl_dict, model_config = load_model(exp, batch_size=16, model_name=checkpoint + ".pt")
    labeled_dict = evaluate_test_mode(model, dl_dict, save_path=Path(f'__submission_{exp:04d}_dataset_update.csv'))