from pathlib import Path
from typing import Union
import json
import yaml

import torch


def fetch_checkpoints_by_params(
    checkpoint_dir: Union[str, Path],
    hparams_filename: str = 'hparams.json',
    checkpoint_filename_ext: str = 'pth.tar',
    verbose: bool = False,
    search_dict: Union[dict, None] = None,
    exclude_dict: Union[dict, None] = None,
    **kwargs
) -> list:
    """Collect checkpoint files from runs whose hyper-parameters match a query.

    Every immediate child of ``checkpoint_dir`` is treated as a candidate run
    directory. The file ``<child>/<hparams_filename>`` is parsed with
    ``yaml.safe_load`` and the run matches when every ``key: value`` pair of
    the query is present in the parsed hyper-parameters. For each matching
    run, all files below it (searched recursively) whose name ends with
    ``.<checkpoint_filename_ext>`` are collected.

    Args:
        checkpoint_dir: Directory whose children are the run directories.
        hparams_filename: Name of the hyper-parameter file inside each run
            directory.
        checkpoint_filename_ext: Extension (without the leading dot) of the
            checkpoint files to collect.
        verbose: When true, print the query before searching.
        search_dict: Required ``key: value`` pairs. It is not modified.
        exclude_dict: When non-empty, a run is skipped if *all* of these
            ``key: value`` pairs are present in its hyper-parameters.
        **kwargs: Additional required ``key: value`` pairs; on a key clash
            they take precedence over ``search_dict``.

    Returns:
        A list of ``Path`` objects, one per checkpoint file found. The order
        follows the directory listing and is not sorted.

    Raises:
        AssertionError: If ``search_dict`` is given but is not a dictionary,
            or if neither ``search_dict`` nor keyword parameters are given.
    """
    checkpoint_dir = Path(checkpoint_dir)

    # Explicit raises (rather than `assert` statements) keep the validation
    # active under `python -O` while preserving the exception type.
    if search_dict is not None:
        if not isinstance(search_dict, dict):
            raise AssertionError(
                f"A dictionary should be provided for `search_dict` but {type(search_dict)} type was given"
            )
    elif not kwargs:
        raise AssertionError(
            "Either search_dict must be set or target parameters should be passed as keyword arguments."
        )

    # Build the query in a fresh dict so the caller's `search_dict` is left untouched.
    query = dict(search_dict) if search_dict is not None else {}
    query.update(kwargs)

    if verbose:
        print("Searching for the following parameters...")
        print(query.items())

    match_list = []
    for cur_dir in checkpoint_dir.iterdir():
        try:
            with open(cur_dir / hparams_filename) as fn:
                hparams = yaml.safe_load(fn)
        except (FileNotFoundError, NotADirectoryError):
            # Either the run has no hparams file, or the entry is a regular
            # file rather than a run directory; neither can match.
            continue

        # An empty or non-mapping hparams file cannot satisfy any query.
        if not isinstance(hparams, dict):
            continue

        # Items views are set-like: `<=` checks that each query pair is in hparams.
        if not query.items() <= hparams.items():
            continue

        # Skip the run only when every excluded pair is present.
        if exclude_dict and exclude_dict.items() <= hparams.items():
            continue

        match_list.extend(cur_dir.glob(f"**/*.{checkpoint_filename_ext}"))

    return match_list



def weights_update(model, checkpoint):
    """Load the matching parameters of a checkpoint into ``model`` in place.

    The checkpoint's state dict is taken from ``checkpoint['state_dict']``
    when that key exists, otherwise ``checkpoint`` itself is used as the state
    dict. A leading ``model.`` prefix is stripped from each key, and only
    entries whose (stripped) key exists in ``model.state_dict()`` are applied;
    all other model parameters keep their current values.

    Args:
        model: Object exposing ``state_dict()`` and ``load_state_dict()``
            (e.g. a ``torch.nn.Module``). It is updated in place.
        checkpoint: Mapping that is either a state dict or contains one under
            the ``'state_dict'`` key.

    Returns:
        None. A summary of the overridden parameters is printed.
    """
    prefix = 'model.'

    with torch.no_grad():
        # get the current model state dict
        model_dict: dict = model.state_dict()

        # Pytorch lightning (PL) saves model state dict with `model.` prefix. Remove this prefix.
        # Only the leading prefix is stripped: a plain `str.replace` would also
        # mangle keys that merely contain the text, e.g. `model.submodel.weight`.
        checkpoint_state_dict = checkpoint['state_dict'] if 'state_dict' in checkpoint else checkpoint
        checkpoint_state_dict = {
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in checkpoint_state_dict.items()
        }

        # PL also saves other PL Module parameters in the state dict. Only consider model parameters.
        pretrained_dict = {k: v for k, v in checkpoint_state_dict.items() if k in model_dict}

        overridden_params = list(pretrained_dict)
        if len(overridden_params) < 10:
            print('The following model parameters will be overridden from the checkpoint state:\t' + '\t'.join(overridden_params))
        else:
            print('{} paramaters will be overridden from the checkpoint state'.format(len(overridden_params)))

        # update the model from the proper state dict
        model_dict.update(pretrained_dict)
        model.load_state_dict(model_dict)