"""
Helpers to load checkpoints for learned feature maps (attentions) or other parameters
"""
import torch
import torch.nn as nn

from src.utils.logging import print_header
from .convert_model import convert_attention


def load_and_convert_attns(model: nn.Module,
                           model_config: dict,
                           attention_type: str = None,
                           checkpoint_path: str = None,
                           print_model: bool = False,
                           train_attention: bool = False,  # Should be true if converting attentions for first time
                           freeze_weights: bool = True,
                           rank: int = 0,
                          ) -> nn.Module:
    """
    Convert a model's attention layers and optionally load trained attention weights.

    Steps:
      1. If ``freeze_weights``, set ``requires_grad = False`` on every existing
         parameter of ``model`` (before conversion).
      2. Write ``attention_type`` (when given) and ``rank`` into
         ``model_config['attention']``. NOTE: this mutates the caller's dict.
      3. Replace attentions via ``convert_attention(model,
         model_config['attention'], train_attention)``.
      4. If ``checkpoint_path`` is given, load its ``'model_state_dict'`` entry
         with ``strict=False``: missing keys are tolerated (presumably the
         checkpoint only holds the trained attention parameters -- confirm),
         while unexpected keys are reported on rank 0 without raising.

    Args:
        model: Model whose attention layers are converted.
        model_config: Config dict; must contain an ``'attention'`` sub-dict.
        attention_type: If not None, overrides
            ``model_config['attention']['attention_type']``.
        checkpoint_path: Path to a ``torch.save``-d dict containing a
            ``'model_state_dict'`` key; nothing is loaded when None.
        print_model: On rank 0, print the model before and after loading,
            the checkpoint keys (on a clean load), and trainable parameters.
        train_attention: Forwarded to ``convert_attention``; should be True
            when converting attentions for the first time.
        freeze_weights: Freeze all pre-existing parameters before conversion.
        rank: Process rank; diagnostic output (other than the
            "Loading weights" line) is printed only on rank 0.

    Returns:
        The model returned by ``convert_attention`` (with weights loaded if a
        checkpoint was given).
    """
    if freeze_weights:
        for p in model.parameters():
            p.requires_grad = False

    attn_config = model_config['attention']
    if attention_type is not None:  # override default
        attn_config['attention_type'] = attention_type
    attn_config['rank'] = rank   # multi-gpu debugging

    model = convert_attention(model, attn_config, train_attention)

    if print_model and rank == 0:  # Look at model
        print_header('*** Model before checkpoint load ***')
        print(model)

    # Load any trained attentions
    if checkpoint_path is not None:
        print(f'Loading weights from {checkpoint_path}...')
        # Deserialize onto CPU so every rank does not materialize the tensors
        # on the device they were saved from (e.g. cuda:0); load_state_dict
        # copies the values into the model's parameters wherever they live.
        state_dict = torch.load(checkpoint_path,
                                map_location='cpu')['model_state_dict']
        incompatible = model.load_state_dict(state_dict, strict=False)
        # Explicit check instead of `assert`, which is stripped under `python -O`
        # and would make a mismatched checkpoint look like a clean load.
        if rank == 0:
            if incompatible.unexpected_keys:
                print_header('*** Error: unexpected keys in checkpoint ***')
                print('Unexpected keys:')
                for k in incompatible.unexpected_keys:
                    print(k)
            else:
                print_header('*** All expected keys matched successfully ***')
                if print_model:
                    for k in state_dict:
                        print(k)

    if print_model and rank == 0:  # Look at model
        print_header('*** Model ***')
        print(model)
        print_header('*** Trainable Parameters ***')
        for n, p in model.named_parameters():
            if p.requires_grad:
                print(f'├── {n} (dtype = {p.dtype})')
    return model
