import logging
import torch
from timm import create_model
from rich import print as pp
from models import vision_transformer_method as vitm

def _strip_key_prefix(state_dict, prefix):
    """Return a copy of ``state_dict`` with ``prefix`` removed from matching keys.

    Only keys that literally start with ``prefix`` are rewritten; all other
    keys are kept unchanged. Values are not copied.
    """
    return {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in state_dict.items()
    }


def _load_with_size_fallback(net, state_dict):
    """Load ``state_dict`` into ``net`` non-strictly, recovering from a RuntimeError.

    ``load_state_dict(strict=False)`` ignores missing/unexpected keys but can
    still raise ``RuntimeError`` (e.g. on tensor size mismatches). In that case
    the checkpoint values are paired with the model's entries *by position*
    (not by name); any value whose size differs from the model's entry is
    replaced by the model's current value before loading again.

    Returns:
        The object returned by ``net.load_state_dict`` (missing / unexpected keys).
    """
    try:
        return net.load_state_dict(state_dict, strict=False)
    except RuntimeError:
        pp("[red] [Warning] Error(s) in loading "
           "state_dict for current model!!! [/red]")
        model_dict = net.state_dict()
        new_state_dict = {}
        # NOTE(review): positional pairing assumes the checkpoint lists its
        # entries in the same order as the model; zip stops at the shorter one.
        for k, v in zip(model_dict.keys(), state_dict.values()):
            if v.size() == model_dict[k].size():
                new_state_dict[k] = v
            else:
                pp(f"[red] [Warning] size mismatch for {k}. "
                   f"Dropping parameter {k} from checkpoint. [/red]")
                new_state_dict[k] = model_dict[k]
        return net.load_state_dict(new_state_dict, strict=False)


def get_network(args):
    """Build the classification network and optionally load initial weights.

    Models whose lower-cased name exists in ``vision_transformer_method`` are
    built from that module; every other name is passed to ``timm.create_model``.

    Args:
        args: configuration namespace. Reads ``classifier_type``,
            ``is_pretrained_imagenet``, ``num_classes`` and ``initial_checkpoint``;
            for ``vision_transformer_method`` models also ``drop_path_rate``
            and ``trans_fg``.

    Returns:
        The constructed network. When ``args.initial_checkpoint`` is not None,
        its weights are loaded non-strictly (``strict=False``).
    """
    classifier_type = args.classifier_type.lower()
    is_pretrained_imagenet = args.is_pretrained_imagenet
    is_vitm_model = classifier_type in vitm.__dict__

    if is_vitm_model:
        # Index with the lower-cased name: it is the key the membership test
        # above matched, so mixed-case names no longer raise KeyError.
        net = vitm.__dict__[classifier_type](
            pretrained=is_pretrained_imagenet,
            num_classes=args.num_classes,
            drop_rate=0.0,
            drop_path_rate=args.drop_path_rate,
            trans_fg=args.trans_fg
        )
    else:
        net = create_model(
            args.classifier_type,
            num_classes=args.num_classes,
            pretrained=is_pretrained_imagenet)

    logger = logging.getLogger('main')
    logger.info(f"[Info] Building model: {args.classifier_type} ")

    if args.initial_checkpoint is not None:
        # map_location="cpu" lets GPU-saved checkpoints load on CPU-only hosts;
        # load_state_dict copies the tensors into net's own parameters anyway.
        # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
        state_dict = torch.load(args.initial_checkpoint, map_location="cpu")

        if is_vitm_model:
            # Self-distillation style checkpoints keep the weights under "teacher".
            if "teacher" in state_dict:
                pp(f"Take key {'teacher'} in provided checkpoint dict")
                state_dict = state_dict["teacher"]
            # Drop DataParallel ("module.") and wrapper ("backbone.") name parts.
            state_dict = {k.replace("module.", ""): v
                          for k, v in state_dict.items()}
            state_dict = {k.replace("backbone.", ""): v
                          for k, v in state_dict.items()}

        if 'net' in state_dict:
            # Weights nested under "net"; strip a DataParallel "module." prefix.
            miss_key = net.load_state_dict(
                _strip_key_prefix(state_dict['net'], "module."), strict=False)
        elif 'state_dict' in state_dict:
            miss_key = net.load_state_dict(state_dict['state_dict'], strict=False)
        else:
            miss_key = _load_with_size_fallback(net, state_dict)
        pp("[red] [Warning] missing key : {}".format(miss_key))
        pp(f"[green] [Info] Loaded from checkpoint {args.initial_checkpoint} [/green]")

    return net

