import torch
import torch.nn as nn
import torchvision.models as models
import timm
import os


def set_train_model(args):
    """Build a model whose backbone is frozen and whose classifier stays trainable.

    Reads ``args.arch`` (``'resnet50'`` or ``'vit-s'``), ``args.method``,
    ``args.num_classes`` and, for the checkpoint-based methods,
    ``args.ssl_ckpt_path``.

    Supported combinations:
        * ``resnet50``: ``'sup'`` (timm pretrained weights), ``'moco'``,
          ``'swav'``, ``'byol'`` (weights read from ``args.ssl_ckpt_path``).
          Every parameter except ``fc.weight``/``fc.bias`` is frozen and ``fc``
          is then replaced by a fresh ``nn.Linear`` with ``args.num_classes``
          outputs.
        * ``vit-s``: ``'sup'`` (timm pretrained DeiT-S), ``'dino'``, ``'ibot'``
          (weights read from ``args.ssl_ckpt_path``). Every parameter except
          ``head.weight``/``head.bias`` is frozen. Any other method returns the
          ViT-S without loading weights.
          NOTE(review): that fall-through is kept from the previous
          behaviour -- confirm it is intended.

    Returns:
        The constructed model.

    Raises:
        ValueError: if ``args.arch`` is not supported, or ``args.method`` is
            not supported for ``resnet50``.
        AssertionError: if a MoCo/BYOL/DINO/iBOT checkpoint leaves any keys
            other than the classifier weights unloaded.
    """
    # NOTE: torch.load unpickles the file -- only load checkpoints from a
    # trusted source. Checkpoints are mapped to CPU so loading also works on
    # hosts without a GPU; load_state_dict copies the tensors into the
    # (CPU-resident) model either way.
    if args.arch == 'resnet50':
        if args.method == 'sup':
            model = timm.create_model('resnet50',
                                      pretrained=True,
                                      num_classes=args.num_classes)
        elif args.method == 'moco':
            model = models.resnet50(pretrained=False)
            checkpoint = torch.load(args.ssl_ckpt_path, map_location="cpu")["state_dict"]
            # Keep only the query encoder, minus its embedding (fc) layer, and
            # strip the prefix so the keys match torchvision's ResNet naming.
            prefix = "module.encoder_q."
            state_dict = {
                k[len(prefix):]: v
                for k, v in checkpoint.items()
                if k.startswith(prefix) and not k.startswith(prefix + "fc")
            }
            msg = model.load_state_dict(state_dict, strict=False)
            assert set(msg.missing_keys) == {"fc.weight", "fc.bias"}, \
                "unexpected missing keys: {}".format(msg.missing_keys)
        elif args.method == 'swav':
            model = models.resnet50(pretrained=False)
            state_dict = torch.load(args.ssl_ckpt_path, map_location="cpu")
            if "state_dict" in state_dict:
                state_dict = state_dict["state_dict"]
            # remove the "module." prefix
            state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
            print("swav load")
            msg = model.load_state_dict(state_dict, strict=False)
        elif args.method == 'byol':
            model = models.resnet50(pretrained=False)
            state_dict = torch.load(args.ssl_ckpt_path, map_location="cpu")['online_backbone']
            state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
            msg = model.load_state_dict(state_dict, strict=False)
            assert set(msg.missing_keys) == {"fc.weight", "fc.bias"}, \
                "unexpected missing keys: {}".format(msg.missing_keys)
        else:
            raise ValueError(
                "unsupported method '{}' for arch 'resnet50'".format(args.method))

        # Freeze the backbone, then attach a fresh (trainable) classifier.
        for name, param in model.named_parameters():
            if name not in ['fc.weight', 'fc.bias']:
                param.requires_grad = False
        num_ftrs = model.fc.in_features
        model.fc = nn.Linear(num_ftrs, args.num_classes)

    elif args.arch == 'vit-s':
        if args.method == 'sup':
            model = timm.create_model('deit_small_patch16_224',
                                      pretrained=True,
                                      num_classes=args.num_classes)
        else:
            model = timm.create_model('vit_small_patch16_224',
                                      pretrained=False,
                                      num_classes=args.num_classes)
            if args.method == 'dino':
                checkpoint = torch.load(args.ssl_ckpt_path, map_location="cpu")
                msg = model.load_state_dict(checkpoint, strict=False)
                assert set(msg.missing_keys) == {'head.weight', 'head.bias'}, \
                    "unexpected missing keys: {}".format(msg.missing_keys)
            elif args.method == 'ibot':
                checkpoint = torch.load(args.ssl_ckpt_path, map_location="cpu")
                msg = model.load_state_dict(checkpoint['state_dict'], strict=False)
                assert set(msg.missing_keys) == {'head.weight', 'head.bias'}, \
                    "unexpected missing keys: {}".format(msg.missing_keys)

        # Freeze everything except the classification head.
        for name, param in model.named_parameters():
            if name not in ['head.weight', 'head.bias']:
                param.requires_grad = False
    else:
        raise ValueError("unsupported arch '{}'".format(args.arch))
    return model


def set_eval_model(args):
    """Build a timm model for evaluation and load fine-tuned weights if present.

    Reads ``args.arch`` (``'resnet50'`` or ``'vit-s'``), ``args.num_classes``
    and ``args.pretrained_path``. When ``args.pretrained_path`` is an existing
    file, it is loaded onto the CPU, any leading ``"module."`` prefix is
    stripped from the keys, and the weights are loaded non-strictly; the
    load result is printed. When the file does not exist, a notice is printed
    and the model is returned without loading weights.

    Returns:
        The constructed model.

    Raises:
        ValueError: if ``args.arch`` is not supported.
    """
    timm_names = {
        'resnet50': 'resnet50',
        'vit-s': 'vit_small_patch16_224',
    }
    if args.arch not in timm_names:
        raise ValueError("unsupported arch '{}'".format(args.arch))
    model = timm.create_model(timm_names[args.arch],
                              num_classes=args.num_classes)

    if os.path.isfile(args.pretrained_path):
        print("=> loading checkpoint '{}'".format(args.pretrained_path))
        # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
        checkpoint = torch.load(args.pretrained_path, map_location="cpu")
        # Accept both a wrapped checkpoint and a bare state dict.
        state_dict = checkpoint['state_dict'] if 'state_dict' in checkpoint else checkpoint

        # Strip the leading "module." prefix from the keys that carry it.
        prefix = "module."
        state_dict = {
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in state_dict.items()
        }
        msg = model.load_state_dict(state_dict, strict=False)
        print(msg)
        print("=> loaded pre-trained model '{}'".format(args.pretrained_path))
    else:
        # Make the fallback visible instead of silently evaluating an
        # unloaded model.
        print("=> no checkpoint found at '{}'".format(args.pretrained_path))

    return model
