import timm
import torch
import clip
from my_utils import *
import torchvision
import vits
import resnet_wider
from eval import *
from ood_eval import *
from tqdm import tqdm
from paths import *


class DinoWrapper(torch.nn.Module):
    """DINO backbone loaded from torch.hub with a pre-trained linear head on top.

    The backbone is fetched with ``torch.hub.load('facebookresearch/dino:main', mkey)``
    and the linear head is read from
    ``{_DINO_ROOT}/dino_lin_weights/{key}_linearweights.pth``, where ``key`` is
    derived from ``mkey`` by :meth:`_linear_weights_key`.

    Args:
        mkey: torch.hub entry name of the DINO model (e.g. ``'dino_vits16'``,
            ``'dino_vitb16'``, ``'dino_resnet50'``).

    Attributes set here:
        resnet: True when ``'resnet'`` appears in ``mkey``.
        n: number of last transformer blocks whose CLS tokens are concatenated.
        avgpool: whether average-pooled patch tokens are appended to the CLS
            token before the linear head (used for the ViT-Base variants).
    """

    def __init__(self, mkey):
        super(DinoWrapper, self).__init__()
        self.dino_encoder = torch.hub.load('facebookresearch/dino:main', mkey)
        self.resnet = ('resnet' in mkey)
        linweights_key = self._linear_weights_key(mkey)

        # map_location keeps loading independent of the device the checkpoint
        # was saved from; the caller moves the whole wrapper afterwards.
        linear_weights = torch.load(
            '{}/dino_lin_weights/{}_linearweights.pth'.format(_DINO_ROOT, linweights_key),
            map_location='cpu')
        # init linear layer
        weight, bias = [linear_weights['state_dict']['module.linear.{}'.format(x)] for x in ['weight', 'bias']]
        state_dict = {'weight': weight, 'bias': bias}
        self.linear_layer = torch.nn.Linear(in_features=weight.shape[1], out_features=weight.shape[0], bias=True)
        self.linear_layer.load_state_dict(state_dict)

        # The hub key spells the base variant as 'vitb', so the check has to be
        # made on the expanded weights key; testing `mkey` for 'base' never matched.
        # NOTE(review): n=1 plus avg-pooled patch tokens for ViT-Base follows the
        # upstream DINO linear-evaluation recipe -- confirm against the weights in use.
        is_base = (not self.resnet) and 'base' in linweights_key
        self.n = 1 if is_base else 4
        self.avgpool = is_base

    @staticmethod
    def _linear_weights_key(mkey):
        """Map a hub entry name to the stem of its linear-weights file name.

        Keys containing ``'resnet'`` are returned unchanged. Otherwise a key
        containing ``'s'`` has ``'s'`` expanded to ``'small'`` and ``'vit'``
        renamed to ``'deit'``; failing that, a key containing ``'b'`` has
        ``'b'`` expanded to ``'base'``.
        """
        if 'resnet' in mkey:
            return mkey
        if 's' in mkey:
            return mkey.replace('s', 'small').replace('vit', 'deit')
        if 'b' in mkey:
            return mkey.replace('b', 'base')
        return mkey

    def forward(self, x):
        """Return the linear head's output for the batch ``x``."""
        if self.resnet:
            tens = self.dino_encoder(x)
        else:
            io = self.dino_encoder.get_intermediate_layers(x, self.n)
            # Concatenate the CLS token (index 0) of each returned block.
            tens = torch.cat([y[:, 0] for y in io], dim=-1)
            if self.avgpool:
                # Interleave CLS features with the mean of the last block's
                # patch tokens, element by element, as in DINO's eval_linear.
                patch_mean = torch.mean(io[-1][:, 1:], dim=1)
                tens = torch.cat((tens.unsqueeze(-1), patch_mean.unsqueeze(-1)), dim=-1)
                tens = tens.reshape(tens.shape[0], -1)
        out = self.linear_layer(tens)
        return out

class ClipZeroShot(torch.nn.Module):
    """CLIP image classifier: zero-shot text weights or a finetuned linear head.

    ``mtype`` is the CLIP model name passed to ``clip.load``. When it contains
    ``'finetuned'`` the model uses a linear head loaded from
    ``{_CLIP_ROOT}/clip_linear_head_weights/``; otherwise it uses pickled
    zero-shot weights from ``{_CLIP_ROOT}/clip_zeroshot_weights/``. A trailing
    ``'/finetuned'`` marker is cut off before the name reaches ``clip.load``.
    """

    def __init__(self, mtype):
        super().__init__()
        self.zero_shot = 'finetuned' not in mtype
        if '/finetuned' in mtype:
            # Drop the last 10 characters, i.e. the '/finetuned' marker.
            mtype = mtype[:-10]
        self.clip_model, self.clip_preprocess = clip.load(mtype)
        self.to_pil = torchvision.transforms.ToPILImage()
        self.mtype = mtype.replace('/', '_')

        if not self.zero_shot:
            # File stem is 'b' or 'rn' followed by the last two characters of the model name.
            family = 'b' if 'b' in mtype.lower() else 'rn'
            suffix = '{}{}'.format(family, mtype[-2:])
            print('Loading finetuned weights')
            head_path = '{}/clip_linear_head_weights/{}.pth'.format(_CLIP_ROOT, suffix)
            lin_weights = torch.load(head_path)
            head = torch.nn.Linear(in_features=lin_weights['weight'].shape[1], out_features=1000, bias=True)
            self.ft_linear_layer = head.cuda()
            self.ft_linear_layer.load_state_dict(lin_weights)
            return

        weights_path = '{}/clip_zeroshot_weights/{}.pkl'.format(_CLIP_ROOT, self.mtype)
        with open(weights_path, 'rb') as f:
            self.zeroshot_weights = pickle.load(f)

    def forward(self, img):
        """Return class logits for a batch of image tensors.

        Each sample is converted to a PIL image and run through CLIP's own
        preprocessing before encoding, and the batch is moved to CUDA.
        """
        preprocessed = [self.clip_preprocess(self.to_pil(sample)) for sample in img]
        img = torch.stack(preprocessed).cuda()
        image_features = self.clip_model.encode_image(img)

        if not self.zero_shot:
            return self.ft_linear_layer(image_features.float())

        # Zero-shot path: L2-normalise features in place, scale by 100, project onto the class weights.
        image_features /= image_features.norm(dim=-1, keepdim=True)
        logits = 100. * image_features @ self.zeroshot_weights
        return logits

def _build_eval_model(mkey):
    """Instantiate the pretrained model selected by ``mkey``.

    The model family is chosen by substring match on ``mkey``, checked in this
    order: 'timm', 'dino', 'clip', 'robust', 'pytorch_resnet50', 'moco', 'simclr'.

    Returns:
        A tuple ``(m, model, apply_norm)``: the key under which results are
        stored, the constructed model, and the flag handed on to the
        evaluation function (False for the CLIP and SimCLR branches).

    Raises:
        ValueError: if ``mkey`` matches none of the known families.
    """
    apply_norm = True
    if 'timm' in mkey:
        m = mkey[5:]
        model = timm.create_model(m, pretrained=True)
    elif 'dino' in mkey:
        m = mkey
        model = DinoWrapper(mkey)
    elif 'clip' in mkey:
        # 'clip_ViT-B_16' -> 'ViT-B/16', the form ClipZeroShot expects.
        m = mkey[5:].replace('_', '/')
        model = ClipZeroShot(m)
        apply_norm = False
    elif 'robust' in mkey:
        m = mkey
        model = load_robust_resnet(mkey)
    elif 'pytorch_resnet50' in mkey:
        m = mkey
        model = torchvision.models.resnet50(pretrained=True)
    elif 'moco' in mkey:
        if 'vit' in mkey:
            arch = 'vit_small' if 'vit-s' in mkey else 'vit_base'
            model = vits.__dict__[arch]()
            m = 'moco_{}'.format(arch)
            key = mkey[5:] + '-'
        else:
            model = torchvision.models.resnet50()
            m = 'moco_resnet50'
            key = ''
        ckpt = torch.load('/cmlscratch/mmoayeri/models/moco_weights/linear-{}300ep.pth.tar'.format(key))
        og_states = ckpt['state_dict']
        # Drop the leading 'module.'-length prefix from every checkpoint key.
        state_dict = {k[len('module.'):]: v for k, v in og_states.items()}
        model.load_state_dict(state_dict)
    elif 'simclr' in mkey:
        m = mkey
        model = resnet_wider.__dict__[mkey[7:]]()
        # The last character of the key selects the checkpoint width multiplier.
        ckpt = torch.load('/cmlscratch/mmoayeri/models/simclr_weights/pytorch/resnet50-{}x.pth'.format(m[-1]))
        model.load_state_dict(ckpt['state_dict'])
        apply_norm = False
    else:
        # Without this, an unknown key either hit an unbound local or silently
        # re-evaluated the model left over from the previous loop iteration.
        raise ValueError('Unrecognized model key: {!r}'.format(mkey))
    return m, model, apply_norm


def eval_models(model_names, eval_fn, results_path='results/pretrained_models2.pkl'):
    """Evaluate every model in ``model_names`` and cache the results.

    Args:
        model_names: iterable of model keys understood by ``_build_eval_model``.
        eval_fn: callable invoked as ``eval_fn(model, apply_norm, prior_results)``
            that returns ``(raw, model_results)``; ``raw`` is printed as a
            percentage and ``model_results`` replaces the cached entry.
        results_path: path passed to ``load_cached_results`` / ``cache_results``.

    Raises:
        ValueError: if a key in ``model_names`` matches no known model family.
    """
    results = load_cached_results(results_path)

    for mkey in tqdm(model_names):
        m, model, apply_norm = _build_eval_model(mkey)
        model = model.eval().cuda()

        if m not in results:
            results[m] = dict()

        raw, model_results = eval_fn(model, apply_norm, results[m])

        # Persist after every model so partial progress survives a crash.
        results[m] = model_results
        cache_results(results_path, results)

        print(f"Model: {m:<40}, Raw Stat: {100.*raw:.3f}")

def ood_eval_models(model_list):
    """Run every model in ``model_list`` on the sketch, objectnet and imagenet-r evals.

    Results for each dataset are cached in ``./results/ood/{dset}.pkl``.
    """
    ood_evals = [
        ('sketch', eval_on_sketch),
        ('objectnet', eval_on_objectnet),
        ('imagenet-r', eval_on_imagenet_r),
    ]
    for dset, eval_fn in ood_evals:
        results_path = f'./results/ood/{dset}.pkl'
        # eval_models takes (model_names, eval_fn, results_path); the two
        # trailing arguments were previously passed in swapped order.
        eval_models(model_list, eval_fn, results_path)
        print(f'Finished with {dset} eval')

def spurious_gap_eval(model_list):
    ''' Counterfactual eval: check accuracy on inpainted images.

    Runs eval_spurious_gap on every model in model_list through eval_models,
    caching results in ./results/new_spurious_gap.pkl.
    '''
    def _gap_eval(model, apply_norm, results):
        # Forward eval_models' positional arguments to eval_spurious_gap by keyword.
        return eval_spurious_gap(model=model, apply_norm=apply_norm, results=results)

    eval_models(model_list, _gap_eval, './results/new_spurious_gap.pkl')


if __name__=='__main__':
    # Script entry point: fetch the model suite from get_test_suite() and run
    # the spurious-gap evaluation on it.
    model_list = get_test_suite()
    spurious_gap_eval(model_list)