import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as model_zoo
import numpy as np 
from resnet_simclr import get_resnet
from robust_resnet import get_robust_resnet50
from debiased_resnet import Model
from transformers import ViTForImageClassification

## wrapper around vit from hugging face
class LogitsViT(nn.Module):
    """Adapter that makes a Hugging Face ViT classifier return bare logits.

    The wrapped model's call result is expected to expose a ``.logits``
    attribute; this wrapper forwards every positional and keyword argument
    unchanged and hands back only that attribute, so the model can be used
    like an ordinary ``nn.Module`` classifier.
    """

    def __init__(self, original_model):
        super().__init__()
        # When ``original_model`` is an nn.Module, assigning it as an attribute
        # registers it as a submodule (so .to(), .parameters(), state_dict see it).
        self.original_model = original_model

    def forward(self, *args, **kwargs):
        """Call the wrapped model with the given arguments and return its ``logits``."""
        outputs = self.original_model(*args, **kwargs)
        logits = outputs.logits
        return logits


    
# Size of the classification head for each supported ``args.dataset`` value.
_DATASET_N_CLASSES = {
    'camelyon17': 2,
    'waterbird': 2,
    'oh-65cls': 65,
    # NOTE(review): CIFAR-10 proper has 10 classes. A 2-way head is kept here
    # because that is what this project has always used (presumably a binary
    # CIFAR-10 variant) -- confirm before relying on it.
    'cifar-10': 2,
}


def _imagenet_weights(pretrained):
    """Map a boolean ``pretrained`` flag to torchvision's ``weights`` argument.

    ``'IMAGENET1K_V1'`` is the weight set that the deprecated
    ``pretrained=True`` flag selected, so this is a drop-in replacement.
    """
    return 'IMAGENET1K_V1' if pretrained else None


def _replace_fc(net, n_classes):
    """Replace ``net.fc`` with a freshly initialised ``n_classes``-way linear layer."""
    net.fc = nn.Linear(net.fc.in_features, n_classes)
    return net


def _hf_vit(checkpoint, n_classes):
    """Load a Hugging Face ViT checkpoint and attach a fresh ``n_classes``-way head.

    The model is wrapped in ``LogitsViT`` so that calling it returns logits only.
    """
    net = ViTForImageClassification.from_pretrained(checkpoint)
    net.classifier = nn.Linear(net.config.hidden_size, n_classes, bias=True)
    return LogitsViT(net)


def _load_checkpoint(checkpoint_path):
    """Load a local checkpoint file, or raise if no path was supplied.

    ``map_location='cpu'`` lets GPU-saved checkpoints load on any machine; the
    model is moved to ``args.device`` afterwards anyway.
    NOTE: ``torch.load`` unpickles the file -- only pass trusted checkpoints.

    Raises:
        NotImplementedError: if ``checkpoint_path`` is ``None``.
    """
    if checkpoint_path is None:
        raise NotImplementedError("checkpoint not downloaded")
    return torch.load(checkpoint_path, map_location='cpu')


def get_model_func(args, model, checkpoint_path=None):
    """Return a zero-argument factory that builds ``model`` for ``args.dataset``.

    Each call of the returned factory constructs a new network, replaces its
    classification head with a freshly initialised linear layer sized for the
    dataset, and moves the network to ``args.device``.

    Args:
        args: namespace-like object. ``dataset`` is read immediately;
            ``device`` and (for ``resnet50``, ``resnet18``, ``vit_b_16``)
            ``pretrained`` are read when the factory is called.
        model: model identifier, e.g. ``'resnet50'``, ``'vit_dino'``,
            ``'resnet50MocoV2'``.
        checkpoint_path: local checkpoint file used by ``'resnet50MocoV2'``,
            ``'resnet50SIMCLRv2'`` and ``'robust_resnet50'``. When ``None``
            (the default) the factory for those models raises
            ``NotImplementedError("checkpoint not downloaded")``.

    Returns:
        A callable taking no arguments and returning the ``nn.Module``.

    Raises:
        NotImplementedError: immediately, for an unknown dataset or model name.
    """
    try:
        n_classes = _DATASET_N_CLASSES[args.dataset]
    except KeyError:
        raise NotImplementedError(
            f"Missing implementation for dataset {args.dataset}") from None

    if model == 'resnet50':
        def build():
            net = model_zoo.resnet50(weights=_imagenet_weights(args.pretrained))
            return _replace_fc(net, n_classes)
    elif model == 'resnet50_np':
        # Same architecture as 'resnet50', but always randomly initialised.
        def build():
            return _replace_fc(model_zoo.resnet50(weights=None), n_classes)
    elif model == 'resnet18':
        def build():
            net = model_zoo.resnet18(weights=_imagenet_weights(args.pretrained))
            return _replace_fc(net, n_classes)
    elif model == 'vit_b_16':
        def build():
            net = model_zoo.vit_b_16(weights=_imagenet_weights(args.pretrained))
            # The whole ``heads`` container is replaced by a single linear layer.
            net.heads = nn.Linear(net.hidden_dim, n_classes, bias=True)
            return net
    elif model == 'vit_mae':
        def build():
            return _hf_vit("facebook/vit-mae-base", n_classes)
    elif model == 'vit_dino':
        def build():
            return _hf_vit("facebook/dino-vitb16", n_classes)
    elif model == 'resnet50SwAV':
        def build():
            net = torch.hub.load('facebookresearch/swav:main', 'resnet50')
            return _replace_fc(net, n_classes)
    elif model == 'resnet50MocoV2':
        # Checkpoint: https://dl.fbaipublicfiles.com/moco/moco_checkpoints/moco_v2_800ep/moco_v2_800ep_pretrain.pth.tar
        def build():
            state = _load_checkpoint(checkpoint_path)
            # Strip the query-encoder prefix so keys match torchvision's ResNet-50;
            # keys that still don't match are ignored by strict=False below.
            new_state = {k.replace("module.encoder_q.", ""): v
                         for k, v in state["state_dict"].items()}
            # Discard the checkpoint's fc.0 / fc.2 head; a new head is attached below.
            for i in ("0", "2"):
                new_state.pop(f"fc.{i}.bias")
                new_state.pop(f"fc.{i}.weight")
            net = model_zoo.resnet50(weights=None)
            net.load_state_dict(new_state, strict=False)
            return _replace_fc(net, n_classes)
    elif model == 'resnet50SIMCLRv2':
        # Checkpoint: r50_1x_sk0.pth, produced with the code at
        # https://github.com/Separius/SimCLRv2-Pytorch
        def build():
            state = _load_checkpoint(checkpoint_path)
            net, _ = get_resnet(depth=50, width_multiplier=1, sk_ratio=0)
            net.load_state_dict(state["resnet"])
            return _replace_fc(net, n_classes)
    elif model == 'robust_resnet50':
        # Checkpoint: resnet50_l2_eps0.05.ckpt from
        # https://robustnessws4285631339.blob.core.windows.net/public-models/robust_imagenet/resnet50_l2_eps0.05.ckpt?sv=2020-08-04&ss=bfqt&srt=sco&sp=rwdlacupitfx&se=2051-10-06T07:09:59Z&st=2021-10-05T23:09:59Z&spr=https,http&sig=U69sEOSMlliobiw8OgiZpLTaYyOA5yt5pHHH5%2FKUYgI%3D
        def build():
            state = _load_checkpoint(checkpoint_path)
            # Skip "attacker" entries (presumably the adversarial-attack copy of
            # the model) and strip the "module." prefix.
            new_state = {k.replace("module.", ""): v
                         for k, v in state["model"].items()
                         if "attacker" not in k}
            robust = get_robust_resnet50()
            robust.load_state_dict(new_state)
            _replace_fc(robust.model, n_classes)
            return robust
    else:
        raise NotImplementedError(f"Missing implementation for model '{model}'.")

    def m_f():
        return build().to(args.device)

    return m_f
    







