import torch
from torchvision import models
import numpy as np
import pickle
import os
from tqdm import tqdm
from my_utils import *
from data_utils import *
from torchvision.transforms import GaussianBlur

# Blur used for the 'blur' noise type in core_spur_accuracy. Because sigma is given as a
# (min, max) tuple, torchvision's GaussianBlur draws sigma uniformly from that range on
# every call, so repeated applications are not identical.
blur = GaussianBlur(sigma=(4, 4.5), kernel_size=25)

def _perturb_regions(imgs, masks, noise_type, noise_sigma):
    '''
    Return (noisy_core, noisy_spur) for one batch, both clamped to [0, 1].

    noisy_core perturbs the pixels selected by `masks`; noisy_spur perturbs those selected
    by `1 - masks`. The perturbation depends on `noise_type`:
      'blur'     -> blend towards the module-level `blur` of the images,
      'ablation' -> blend towards constant grey (0.5),
      otherwise  -> add Gaussian noise with std `noise_sigma` (one draw shared by both outputs).
    '''
    regions = (masks, 1 - masks)
    if noise_type == 'blur':
        target = blur(imgs)
    elif noise_type == 'ablation':  # greying
        target = torch.ones_like(imgs, device=imgs.device) * 0.5
    else:  # default noise is Gaussian (randn)
        noise = torch.randn_like(imgs, device=imgs.device) * noise_sigma
        return tuple(torch.clamp(imgs + m * noise, 0, 1) for m in regions)
    return tuple(torch.clamp(imgs * (1 - m) + m * target, 0, 1) for m in regions)


def core_spur_accuracy(model, core=True, noise_sigma=0.25, num_trials=5, apply_norm=True, noise_type='gaussian'):
    '''
    Measure accuracy when either the spurious or the core region of each image is perturbed.

    Core accuracy = accuracy under noise in the spurious regions (core left intact); spur
    accuracy is the reverse (noise inside the core regions). Core regions are the dataset's
    core masks passed through `dilate_erode_fast`. Only test images with a non-empty mask
    are scored.

    Args:
        model: classifier; the predicted class is `model(x).argmax(1)`. The model is not
            switched to eval mode here; the caller must do that.
        core: unused; kept for backward compatibility (both accuracies are always computed).
        noise_sigma: std of the Gaussian noise. With 0, no perturbation is applied: the
            returned "core" accuracy is then the clean accuracy and spur accuracy is 0.
        num_trials: number of perturbation draws per batch. Forced to 1 when
            noise_sigma == 0 or noise_type == 'ablation' (both are deterministic).
        apply_norm: apply ImageNet mean/std normalization before the forward pass.
        noise_type: 'gaussian' (default; any unrecognized value also means Gaussian),
            'blur' or 'ablation' (grey fill). Only used when noise_sigma > 0.

    Returns:
        (core_acc, spur_acc, core_acc_by_class, spur_acc_by_class).
        The first two are class-balanced accuracies in percent: the mean of per-class
        accuracies. The dicts map class index -> accuracy in [0, 1] and contain only
        classes with at least one masked image.
    '''
    dset = CustomDataSet('/REDACTED/salient_imagenet_dataset/test/',
                         '/REDACTED/salient_imagenet_dataset/',
                         resize_size=224, split='test')
    loader = torch.utils.data.DataLoader(dset, batch_size=32, shuffle=True, num_workers=16, pin_memory=True)
    # NOTE(review): `transforms` is not imported explicitly in this file; presumably it is
    # re-exported by one of the star imports -- confirm.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    num_classes = 1000
    # Per-class tallies are kept as CPU tensors and updated with bincount. This avoids one
    # GPU->CPU sync per class per batch.
    cnt_by_class = torch.zeros(num_classes, dtype=torch.long)
    core_cc_by_class = torch.zeros(num_classes, dtype=torch.long)
    spur_cc_by_class = torch.zeros(num_classes, dtype=torch.long)

    def tally(class_labels):
        # Histogram of class indices; raises if a label is >= num_classes.
        return torch.bincount(class_labels, minlength=num_classes).cpu()

    # Without noise, or with constant greying, extra trials would repeat identical work.
    if noise_sigma == 0 or noise_type == 'ablation':
        num_trials = 1

    # Inference only: no_grad avoids storing activations for backprop.
    with torch.no_grad():
        for imgs, masks, labels in tqdm(loader):
            # Keep only images that have a non-empty core mask.
            has_mask = masks.flatten(1).sum(1) != 0
            if not has_mask.any():
                continue  # nothing to score in this batch
            imgs = imgs[has_mask].cuda()
            labels = labels[has_mask].long().cuda()

            if noise_sigma > 0:
                masks = dilate_erode_fast(masks[has_mask].cuda())
            else:
                clean = normalize(imgs) if apply_norm else imgs

            for _ in range(num_trials):
                if noise_sigma > 0:
                    noisy_core, noisy_spur = _perturb_regions(imgs, masks, noise_type, noise_sigma)
                    if apply_norm:
                        noisy_core, noisy_spur = normalize(noisy_core), normalize(noisy_spur)
                    noisy_core_preds = model(noisy_core).argmax(1)
                    noisy_spur_preds = model(noisy_spur).argmax(1)
                    # Core accuracy: spurious region perturbed, core intact (and vice versa).
                    core_cc_by_class += tally(labels[noisy_spur_preds == labels])
                    spur_cc_by_class += tally(labels[noisy_core_preds == labels])
                else:
                    preds = model(clean).argmax(1)
                    core_cc_by_class += tally(labels[preds == labels])
                # Counted once per trial so the ratios below average over trials.
                cnt_by_class += tally(labels)

    core_acc_by_class, spur_acc_by_class = dict(), dict()
    for c, (n, core_cc, spur_cc) in enumerate(zip(cnt_by_class.tolist(),
                                                  core_cc_by_class.tolist(),
                                                  spur_cc_by_class.tolist())):
        if n == 0:
            continue  # class has no masked test images
        core_acc_by_class[c] = core_cc / n
        spur_acc_by_class[c] = spur_cc / n

    # Class-balanced averages (mean of per-class accuracies), in percent.
    core_acc, spur_acc = [100. * np.average(list(d.values())) for d in (core_acc_by_class, spur_acc_by_class)]
    return core_acc, spur_acc, core_acc_by_class, spur_acc_by_class

def load_model(key, phase=2, arch='resnet50', aug=False, verbose=False, trial=1, epoch=None):
    '''
    Build a DataParallel-wrapped GPU model and load a training checkpoint into it.

    The checkpoint is read from
        ./results/trained_models{suffix}/{key}_phase{phase}/checkpoint_epoch{epoch}.pth.tar
    where {suffix} is empty for trial 1 and the trial number otherwise.

    Args:
        key: experiment name (prefix of the checkpoint sub-directory).
        phase: training phase; selects the sub-directory and the default epoch.
        arch: name of a torchvision.models constructor (default 'resnet50').
        aug: currently unused; the checkpoint path does not depend on it.
        verbose: print the loaded accuracy and path.
        trial: run index; selects the results directory.
        epoch: checkpoint epoch. Defaults to 15 for phase 2 and 25 otherwise.

    Returns:
        (model, acc): the model in eval mode and the checkpoint's 'best_prec1' entry.
    '''
    # Previously `arch` was ignored and resnet50 was always built.
    model = getattr(models, arch)().cuda()
    model = torch.nn.DataParallel(model)

    if epoch is None:
        epoch = 15 if phase == 2 else 25
    mpath = './results/trained_models{}/{}_phase{}/checkpoint_epoch{}.pth.tar'.format(
        '' if trial == 1 else trial, key, phase, epoch)
    # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
    ckpt = torch.load(mpath)
    # Keys carry the DataParallel 'module.' prefix, hence loading into the wrapped model.
    model.load_state_dict(ckpt['state_dict'])
    model.eval()
    acc = ckpt['best_prec1']
    if verbose:
        print('Loaded model {} with accuracy {:.2f} from: {}'.format(key, acc, mpath))

    return model, acc

def _save_results(results, fpath):
    '''Pickle `results` to `fpath` atomically: write a temp file, then rename it over the target.'''
    tmp_path = fpath + '.tmp'
    with open(tmp_path, 'wb') as f:
        pickle.dump(results, f)
    # os.replace is atomic, so an interrupted run cannot leave a truncated cache behind.
    os.replace(tmp_path, fpath)


def eval_models(keys, all_noise_levels=(0.25,), aug=False, trial=1):
    '''
    Evaluate clean/core/spur accuracy and RCA for each model key, caching results to disk.

    Results are stored in ./cached_results/from_scratch{trial}{_aug}.pkl as
    {key: {'clean_acc': ..., sigma: {'core', 'core_by_class', 'spur', 'spur_by_class', 'rca'}}}.
    Noise levels already present in the cache are not recomputed, but every model is
    still loaded so that its clean accuracy is refreshed and printed.

    Args:
        keys: model keys. Keys containing 'phase3' load the phase-3 checkpoint of the
            prefix before '_phase3'.
        all_noise_levels: Gaussian noise stds passed to core_spur_accuracy.
        aug: selects the '_aug' cache file; forwarded to load_model.
        trial: run index; selects both the checkpoint directory and the cache file.
    '''
    fpath = './cached_results/from_scratch{}{}.pkl'.format('' if trial==1 else trial, '_aug' if aug else '')
    os.makedirs(os.path.dirname(fpath), exist_ok=True)

    if os.path.exists(fpath):
        # Local cache written by this script; pickle must not be used on untrusted files.
        with open(fpath, 'rb') as f:
            results = pickle.load(f)
    else:
        results = dict()

    for k in keys:
        results.setdefault(k, dict())

        # `trial` is forwarded for phase-3 models too: otherwise trial-1 checkpoints would
        # be evaluated and then stored in another trial's cache file.
        if 'phase3' in k:
            model, clean_acc = load_model(k[:k.find('_phase3')], phase=3, aug=aug, trial=trial)
        else:
            model, clean_acc = load_model(k, aug=aug, trial=trial)

        results[k]['clean_acc'] = clean_acc

        for sigma in all_noise_levels:
            if sigma not in results[k]:
                core_acc, spur_acc, core_acc_by_class, spur_acc_by_class = core_spur_accuracy(model, noise_sigma=sigma)
                rca = rel_score(core_acc / 100, spur_acc / 100)
                results[k][sigma] = dict({'core': core_acc, 'core_by_class': core_acc_by_class,
                                        'spur': spur_acc, 'spur_by_class': spur_acc_by_class, 'rca': rca*100})

            # Persist after every noise level so progress survives an interrupted run.
            _save_results(results, fpath)

            print('Model: {:<30}, Clean Acc: {:.2f}, Core Acc: {:.2f}, Spur Acc: {:.2f}, RCA: {:.2f}'
                  .format(k, clean_acc, results[k][sigma]['core'], results[k][sigma]['spur'], results[k][sigma]['rca']))

def cnt_per_class():
    '''
    Count, per class, the test images that have a non-empty core mask.

    Pickles a {class index: count} dict covering classes 0..999 to
    ./cached_results/cnts_per_class.pkl, then prints every class with no masked image.
    '''
    dset = CustomDataSet('/REDACTED/salient_imagenet_dataset/test/',
                        '/REDACTED/salient_imagenet_dataset/',
                        resize_size=224, split='test')

    num_classes = 1000
    # Order is irrelevant for counting, so no shuffling.
    loader = torch.utils.data.DataLoader(dset, batch_size=128, shuffle=False, num_workers=16, pin_memory=True)
    counts = torch.zeros(num_classes, dtype=torch.long)
    for _, masks, labels in tqdm(loader):
        has_mask = masks.flatten(1).sum(1) != 0
        # bincount histograms the masked labels; raises if a label is >= num_classes.
        counts += torch.bincount(labels[has_mask].long(), minlength=num_classes)
    cnts_dict = dict(enumerate(counts.tolist()))

    # The cache directory may not exist yet on a fresh checkout.
    os.makedirs('./cached_results', exist_ok=True)
    with open('./cached_results/cnts_per_class.pkl', 'wb') as f:
        pickle.dump(cnts_dict, f)

    for y, n in cnts_dict.items():
        if n == 0:
            print('No core masks for class {}'.format(_IMAGENET_CLASSNAMES[y]))



if __name__ == '__main__':
    # Current run: evaluate these checkpoints for trial 1 (results cached by eval_models).
    run_keys = [
        'baseline',
        'baseline_aug_scale0.2',
        'baseline_aug_scale0.5',
        'baseline_aug',
        'mixup',
        'cutmix',
        'mixup_cutmix',
    ]
    eval_models(run_keys, trial=1)

    # Earlier invocations, kept for reference:
    # cnt_per_class()
    # eval_models(['baseline_phase3', 'sal_reg_noise_0.25_half_phase3', 'noise_0.25_half_phase3', 'sal_reg_phase3'])#, aug=True)
    # eval_models(['sal_reg', 'noise_0.25_half'])
    # eval_models(['sal_reg_noise_0.25_half', 'sal_reg', 'noise_0.25_half', 'baseline'])
    # eval_models(['sal_reg_1_noise_0.25_half', 'baseline', 'sal_reg', 'noise_0.25_half', 'sal_reg_phase3', 'noise_0.25_half_phase3', 'sal_reg_noise_0.25_half'])#'sal_reg_1_noise_0.25_half_phase3'])
    # eval_models(['sal_reg_noise_0.25_half_phase3', 'sal_reg_phase3', 'noise_0.25_half_phase3', 'baseline_phase3']) #'baseline',
    # eval_models(['dilate_sal_reg_noise_0.25_half', 'dilate_sal_reg_noise_0.50_half', 'dilate_sal_reg_1', 'dilate30_sal_reg_noise_0.25_half'])
    # eval_models(['aug_baseline', 'aug_sal_reg_1_noise_0.25_half', 'aug_sal_reg_1', 'aug_noise_0.25_half'], aug=True)

    # for trial in [1,2,3,4,5]:
    #     eval_models(['sal_reg_1', 'sal_reg_1_noise_0.25_half', 'baseline', 'noise_0.25_half'], trial=trial)

    # eval_models(['noise_0.25_half', 'sal_reg', 'sal_reg_noise_0.25_half', 'baseline',
    #             'sal_reg_phase3', 'baseline_phase3', 'sal_reg_noise_0.25_half_phase3',
    #             'sal_reg_1_noise_0.25_half', 'sal_reg_1_noise_0.25_half_phase3', 'sal_reg_1', 'sal_reg_1_phase3', 'noise_0.25_half_phase3'])

    # eval_models(['baseline_aug', 'baseline', 'sal_reg_noise_0.25_half', 'dilate_sal_reg_noise_0.25_half','dilate30_sal_reg_noise_0.25_half', 'noise_0.25_half',
    #                 'noise_0.25', 'noise_0.50', 'extra_noise_0.25', 'joint_noise_0.25', 'sal_reg', 'sal_reg_1'])

    # NOTE(review): the snippet below is stale. core_spur_accuracy now returns a 4-tuple
    # (core_acc, spur_acc, core_by_class, spur_by_class), so these assignments would need
    # updating before use.
    # baseline = torch.nn.DataParallel(models.resnet50(pretrained=True).cuda().eval())
    # core_acc = core_spur_accuracy(baseline)
    # spur_acc, _ = core_spur_accuracy(baseline, core=False)
    # clean_acc, _ = core_spur_accuracy(baseline, noise_sigma=0, num_trials=1)
    # print('Model: {:<30}, Clean Acc: {:.2f}, Core Acc: {:.2f}, Spur Acc: {:.2f}, RCA: {:.2f}'
    #         .format('Baseline', clean_acc, core_acc, spur_acc, rel_score(core_acc/100, spur_acc/100)*100))
