from my_utils import *
import numpy as np
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
from tqdm import tqdm
import torch.nn as nn
from ood_eval import ObjectNet
import json
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import seaborn as sns
from torchvision.utils import make_grid
'''
We compare feature (ftr) activations on correctly classified (CC) and misclassified (MC) samples.
The higher a feature's average activation on CC images is relative to its average activation on MC images,
the more reliable the feature is.

Pipeline:
1. Obtain a data loader (repeated over several distribution-shifted datasets).
2. Per class, compute activations of the annotated (core and spurious) features on CC images and on MC images,
   both those misclassified away from the class (source) and those misclassified into it (target).
3. Plot 'feature reliability' for all classes, i.e. the relative difference in average activation between
   CC and MC images: (avg_CC - avg_MC) / max(avg_CC, avg_MC).
'''

# ImageNet per-channel mean/std (the usual torchvision values). The dataset
# transforms in this file stop at ToTensor(), so this is applied to each batch
# on the GPU right before it is fed to the model.
normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)

def _collect_ftrs_and_preds(loader, ftr_encoder, fc):
    '''Run the model over `loader`; return (preds, ftrs, ys) as numpy arrays.

    preds: argmax of fc(ftrs) per sample; ftrs: flattened encoder output, one row
    per sample; ys: labels exactly as yielded by the loader.
    Raises ValueError if the loader yields nothing.
    '''
    import torch  # `import torch.nn as nn` at file top does not bind `torch` itself

    all_preds, all_ftrs, all_ys = [], [], []
    # Inference only: no_grad avoids building (and holding) an autograd graph per batch.
    with torch.no_grad():
        for x, y in tqdm(loader):
            x = normalize(x.cuda())
            ftrs = ftr_encoder(x).flatten(1)
            preds = fc(ftrs).argmax(1)
            all_preds.append(preds.cpu().numpy())
            all_ftrs.append(ftrs.cpu().numpy())
            all_ys.append(y.numpy())
    if not all_ys:
        raise ValueError('loader yielded no samples')
    return np.concatenate(all_preds), np.concatenate(all_ftrs), np.concatenate(all_ys)


def _annotated_ftrs(by_class_dict, inet_classes):
    '''Return (core_ftrs, spur_ftrs): annotated feature indices pooled over `inet_classes`.'''
    core_ftrs, spur_ftrs = [], []
    for inet_c in inet_classes:
        core_ftrs.extend(by_class_dict[inet_c]['core'])
        spur_ftrs.extend(by_class_dict[inet_c]['spurious'])
    return core_ftrs, spur_ftrs


def _rel_diff(avg_cc, avg_mc):
    '''(avg_cc - avg_mc) / max(avg_cc, avg_mc), or 0.0 when that max is exactly 0.'''
    denom = max(avg_cc, avg_mc)
    if denom == 0:
        # Neither group activates the feature: report "no difference" instead of 0/0 = nan,
        # which would otherwise poison the per-class average.
        return 0.0
    return (avg_cc - avg_mc) / denom


def _pct(count, total):
    '''100 * count / total, or nan when total is 0 (keeps the summaries from crashing).'''
    return 100. * count / total if total else float('nan')


def _print_reliability_summary(core_ftr_diffs, spur_ftr_diffs, n_kept):
    '''Print how often core/spurious features are unreliable and how often they disagree.'''
    print('CORE Ftrs. % classes with negative rel diff for... source: {:.2f}%, target: {:.2f}%'.format(
        *[_pct(np.sum(core_ftr_diffs[x] < 0), n_kept) for x in ['source', 'target']]))
    print('SPUR Ftrs. % classes with negative rel diff for... source: {:.2f}%, target: {:.2f}%'.format(
        *[_pct(np.sum(spur_ftr_diffs[x] < 0), n_kept) for x in ['source', 'target']]))

    for k in ['source', 'target']:
        core, spur = core_ftr_diffs[k], spur_ftr_diffs[k]
        # A class "conflicts" when its core and spurious rel diffs have opposite signs.
        num_core_spur_disagree_classes = int(np.sum(core * spur < 0))
        help_pct = _pct(np.sum((core < 0) & (spur > 0)), num_core_spur_disagree_classes)
        hurt_pct = _pct(np.sum((core > 0) & (spur < 0)), num_core_spur_disagree_classes)
        print('{}....Conflicting classes: {:.2f} ({}/{}), Spur ftrs help: {:.2f}%, hurt: {:.2f}%'.format(
            k.upper(), _pct(num_core_spur_disagree_classes, n_kept), num_core_spur_disagree_classes,
            n_kept, help_pct, hurt_pct))


def compute_annotated_ftr_activations(loader, dset_name, label_map=None,
                                      by_class_dict_path='/cmlscratch/mmoayeri/analysis_causal_imagenet/meta/ftr_types/by_class_dict.pkl'):
    '''
    Compare annotated-feature activations on correctly and misclassified samples, per class.

    ARGS
        loader: DataLoader yielding (x, y) batches; x is normalized here, so it should not be
            normalized already. y is the dataset's own label.
        dset_name: name of the cache file, ./results/reliability/{dset_name}.pkl
        label_map: dict mapping dataset labels to lists of ImageNet indices; a sample of class c
            counts as "predicted as c" when its prediction is in label_map[c].
            Defaults to the identity over the 1000 ImageNet classes.
        by_class_dict_path: pickle read with load_cached_results; by_class_dict[inet_c]['core']
            and ['spurious'] list annotated feature indices for ImageNet class inet_c.

    A class c is kept only if it has core AND spurious annotated features and at least one
    sample in each of these groups:
        'cc':        label c, predicted as c
        'mc_source': label c, predicted as something else
        'mc_target': another label, predicted as c
    For each kept class and each annotated feature f, the relative difference
    (avg_cc - avg_mc) / max(avg_cc, avg_mc) is computed against both MC groups and averaged
    over the class's core features and, separately, its spurious features.

    RETURNS (and caches to ./results/reliability/{dset_name}.pkl) a dict:
        'avg_ftr_diffs': {'core': {'source': arr, 'target': arr},
                          'spur': {'source': arr, 'target': arr},
                          'cls_idx': list of kept classes (aligned with the arrays)}
        'ftr_vals': ftr_vals[c][f][group] = activations of feature f over that group's samples
        'dset_idx': dset_idx[c][group] = dataset indices of that group's samples
    '''
    by_class_dict = load_cached_results(by_class_dict_path)

    model = load_robust_resnet('robust_resnet50_l2_eps3').eval().cuda()
    # All children except the last (the fc layer) form the feature encoder.
    ftr_encoder = nn.Sequential(*list(model.children())[:-1])
    fc = model.fc

    if label_map is None:
        label_map = {i: [i] for i in range(1000)}

    print('Computing features and predictions')
    all_preds, all_ftrs, all_ys = _collect_ftrs_and_preds(loader, ftr_encoder, fc)
    n_correct = sum(p in label_map[y] for y, p in zip(all_ys, all_preds))
    print('Sanity check... Accuracy: {:.2f}%'.format(100. * n_correct / all_ys.shape[0]))

    core_ftr_diffs, spur_ftr_diffs = [{'source': [], 'target': []} for _ in range(2)]
    ftr_vals, dset_idx = {}, {}
    print('Extracting and organizing annotated feature values')
    for c in tqdm(list(label_map.keys())):
        # Vectorized group membership (replaces three O(N) Python scans per class).
        is_c = all_ys == c
        pred_c = np.isin(all_preds, label_map[c])
        cc_idx = np.flatnonzero(is_c & pred_c)
        source_mc_idx = np.flatnonzero(is_c & ~pred_c)
        target_mc_idx = np.flatnonzero(~is_c & pred_c)

        core_ftrs, spur_ftrs = _annotated_ftrs(by_class_dict, label_map[c])
        if not core_ftrs or not spur_ftrs or min(cc_idx.size, source_mc_idx.size, target_mc_idx.size) == 0:
            continue

        dset_idx[c] = {'cc': cc_idx, 'mc_source': source_mc_idx, 'mc_target': target_mc_idx}
        ftr_vals[c] = {}
        for ftrs, running_diffs in zip([core_ftrs, spur_ftrs], [core_ftr_diffs, spur_ftr_diffs]):
            source_rel_diffs, target_rel_diffs = [], []
            for f in ftrs:
                vals = {group: all_ftrs[idx, f] for group, idx in dset_idx[c].items()}
                ftr_vals[c][f] = vals
                avg_cc, avg_mc_s, avg_mc_t = [np.average(vals[g]) for g in ['cc', 'mc_source', 'mc_target']]
                source_rel_diffs.append(_rel_diff(avg_cc, avg_mc_s))
                target_rel_diffs.append(_rel_diff(avg_cc, avg_mc_t))

            running_diffs['source'].append(np.average(source_rel_diffs))
            running_diffs['target'].append(np.average(target_rel_diffs))

    core_ftr_diffs, spur_ftr_diffs = [{k: np.array(v) for k, v in d.items()} for d in [core_ftr_diffs, spur_ftr_diffs]]
    # ftr_vals keys are inserted in the same order the diffs were appended, so they stay aligned.
    cls_inds_kept = list(ftr_vals.keys())

    _print_reliability_summary(core_ftr_diffs, spur_ftr_diffs, len(cls_inds_kept))

    save_dict = {
        'avg_ftr_diffs': {'core': core_ftr_diffs, 'spur': spur_ftr_diffs, 'cls_idx': cls_inds_kept},
        'ftr_vals': ftr_vals, 'dset_idx': dset_idx,
    }
    cache_results(f'./results/reliability/{dset_name}.pkl', save_dict)
    return save_dict

### some auxiliary code: getting label maps
def get_label_map(dsetname):
    '''
    Map each label of dataset `dsetname` to a list of ImageNet class indices.

    'in_r'      -> labels 0..199, each mapped to a single ImageNet index (listed below)
    'objectnet' -> every value of mappings/inet_id_to_onet_id.json, mapped to the list of
                   ImageNet ids (JSON keys, converted to int) that point to it
    anything else -> identity over the 1000 ImageNet classes
    '''
    if dsetname == 'objectnet':
        mapping_path = f'{_OBJECTNET_ROOT}/mappings/inet_id_to_onet_id.json'
        with open(mapping_path, 'r') as handle:
            inet_to_onet = json.load(handle)
        grouped = {}
        for inet_id, onet_id in inet_to_onet.items():
            grouped.setdefault(onet_id, []).append(int(inet_id))
        return grouped

    if dsetname == 'in_r':
        # ImageNet index of each ImageNet-R label, in label order.
        in_r_inet_ids = (
            1, 2, 4, 6, 8, 9, 11, 13, 22, 23, 26, 29, 31, 39, 47, 63, 71, 76, 79, 84, 90, 94, 96, 97, 99, 100,
            105, 107, 113, 122, 125, 130, 132, 144, 145, 147, 148, 150, 151, 155, 160, 161, 162, 163, 171, 172,
            178, 187, 195, 199, 203, 207, 208, 219, 231, 232, 234, 235, 242, 245, 247, 250, 251, 254, 259, 260,
            263, 265, 267, 269, 276, 277, 281, 288, 289, 291, 292, 293, 296, 299, 301, 308, 309, 310, 311, 314,
            315, 319, 323, 327, 330, 334, 335, 337, 338, 340, 341, 344, 347, 353, 355, 361, 362, 365, 366, 367,
            368, 372, 388, 390, 393, 397, 401, 407, 413, 414, 425, 428, 430, 435, 437, 441, 447, 448, 457, 462,
            463, 469, 470, 471, 472, 476, 483, 487, 515, 546, 555, 558, 570, 579, 583, 587, 593, 594, 596, 609,
            613, 617, 621, 629, 637, 657, 658, 701, 717, 724, 763, 768, 774, 776, 779, 780, 787, 805, 812, 815,
            820, 824, 833, 847, 852, 866, 875, 883, 889, 895, 907, 928, 931, 932, 933, 934, 936, 937, 943, 945,
            947, 948, 949, 951, 953, 954, 957, 963, 965, 967, 980, 981, 983, 988,
        )
        return {label: [inet_id] for label, inet_id in enumerate(in_r_inet_ids)}

    return {i: [i] for i in range(1000)}


def eval_many_dsets(dsetnames=('sketch', 'in_r', 'objectnet'), batch_size=32, num_workers=16):
    '''
    Run compute_annotated_ftr_activations on each named dataset, in order.

    ARGS
        dsetnames: names from 'imagenet', 'sketch', 'in_r', 'objectnet'. Only the requested
            datasets are constructed.
        batch_size, num_workers: DataLoader settings (shuffle is always off).

    Raises ValueError for an unknown dataset name.
    '''
    t = transforms.Compose([transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor()])
    # Lazy constructors: a dataset is only built (and its root only touched) when requested.
    dset_builders = {
        'imagenet': lambda: datasets.ImageNet(root=_IMAGENET_ROOT, split='val', transform=t),
        'sketch': lambda: datasets.ImageFolder(root=_SKETCH_ROOT, transform=t),
        'in_r': lambda: datasets.ImageFolder(root=_IMAGENET_R_ROOT, transform=t),
        'objectnet': lambda: ObjectNet(normalize=None),
    }
    unknown = [name for name in dsetnames if name not in dset_builders]
    if unknown:
        raise ValueError(f'Unknown dataset name(s): {unknown}; expected any of {list(dset_builders)}')

    for dset_name in dsetnames:
        print(dset_name.upper())
        loader = DataLoader(dset_builders[dset_name](), num_workers=num_workers, shuffle=False, batch_size=batch_size)
        compute_annotated_ftr_activations(loader, dset_name, get_label_map(dset_name))
        print()

### Plotting Results
def help_hurt_targets(dsets=('imagenet', 'sketch', 'in_r', 'objectnet')):
    '''
    Plot per-class core vs. spurious feature reliability for each dataset and save the
    panels side by side to plots/reliability_kde.jpg.

    Reliability here is the cached 'target' relative difference from
    results/reliability/<dset>.pkl (positive = CC samples activate the features more
    than samples misclassified into the class).
    Red quadrant (spur < 0, core > 0): spurious features hurt.
    Green quadrant (spur > 0, core < 0): spurious features help.
    Each quadrant is annotated with (#classes in it)/(#classes whose core and spur signs differ).
    '''
    dsets = list(dsets)
    # squeeze=False keeps axs 2-D even for a single dataset, so iterating it always works.
    fig, axs = plt.subplots(1, len(dsets), figsize=(4*len(dsets), 4.25), squeeze=False)
    dset_to_nickname = {'imagenet': 'ImageNet', 'sketch': 'ImageNet Sketch', 'in_r': 'ImageNet-R', 'objectnet': 'ObjectNet'}

    for ax, dset in zip(axs.ravel(), dsets):
        d = load_cached_results(f'results/reliability/{dset}.pkl')
        core_ftr_diffs, spur_ftr_diffs = [np.asarray(d['avg_ftr_diffs'][x]['target']) for x in ['core', 'spur']]

        # seaborn >= 0.11 API: x/y by keyword and `fill` (positional x, y is rejected from
        # 0.12 on and `shade` is deprecated). `fmt` is not a kdeplot option, so it is dropped.
        sns.kdeplot(x=spur_ftr_diffs, y=core_ftr_diffs, ax=ax, fill=True, levels=20)
        ax.add_patch(plt.Rectangle((-1, 0), 1, 1, color='red', alpha=0.2))
        ax.add_patch(plt.Rectangle((0, -1), 1, 1, color='limegreen', alpha=0.2))
        ax.set_xlim([-1, 1]); ax.set_ylim([-1, 1])
        ticks = np.round(np.arange(-2, 3) / 2, 1)
        ax.set_yticks(ticks); ax.set_yticklabels(ticks)
        ax.set_xlabel('Spurious Feature Reliability', fontsize=13); ax.set_ylabel('Core Feature Reliability', fontsize=13)

        n_classes = len(core_ftr_diffs)
        spur_reliable_pct, core_reliable_pct = [
            100. * np.sum(x > 0) / n_classes if n_classes else float('nan')
            for x in [spur_ftr_diffs, core_ftr_diffs]
        ]
        ax.set_title(f'Dataset: {dset_to_nickname[dset]}\nCore Ftr Reliability Rate: {core_reliable_pct:.1f}%\nSpur Ftr Reliability Rate: {spur_reliable_pct:.1f}%',
            fontsize=13)

        num_disagree_classes = int(np.sum(core_ftr_diffs * spur_ftr_diffs < 0))
        num_helped = int(np.sum((core_ftr_diffs < 0) & (spur_ftr_diffs > 0)))
        num_hurt = int(np.sum((core_ftr_diffs > 0) & (spur_ftr_diffs < 0)))
        ax.text(-0.95, 0.85, f'{num_hurt}/{num_disagree_classes}')
        ax.text(0.65, -0.9, f'{num_helped}/{num_disagree_classes}')

    fig.tight_layout(); fig.savefig('plots/reliability_kde.jpg', dpi=300, pad_inches=0.05)

def _save_single_example(img, title, color, path):
    '''Save one channels-first image tensor (C, H, W — as produced by ToTensor) as a small titled figure, then close it.'''
    f_cool, ax_cool = plt.subplots(1, 1, figsize=(3, 3.2))
    ax_cool.imshow(img.numpy().transpose(1, 2, 0))  # CHW -> HWC for imshow
    ax_cool.set_title(title, fontsize=14, color=color)
    ax_cool.set_axis_off()
    f_cool.tight_layout(); f_cool.savefig(path, dpi=300, bbox_inches='tight', pad_inches=0.05)
    plt.close(f_cool)  # these figures are only written to disk; free them


def view_most_helpful_and_hurtful_cases(show_dsets=('objectnet', 'in_r'), save_fig=False):
    '''
    For each dataset in `show_dsets`, build a 3x3 figure of image grids:
      - 6 classes where spurious features hurt most (smallest spur-minus-core 'target'
        reliability); each grid shows up to 6 samples misclassified INTO the class
        ('mc_target') ranked by largest (avg spurious - avg core) annotated ftr activation;
      - 3 classes where spurious features help most (largest spur-minus-core reliability);
        each grid shows up to 6 correctly classified ('cc') samples, ranked the same way.
    Also saves two hand-picked single-image examples to plots/reliability/.

    ARGS
        show_dsets: names from 'imagenet', 'sketch', 'in_r', 'objectnet'; each needs
            results/reliability/<name>.pkl. Only these datasets are constructed.
        save_fig: if True, save each 3x3 figure to plots/reliability/egs_<name>.jpg.
    '''
    t = transforms.Compose([transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor()])
    dset_to_nickname = {'imagenet': 'ImageNet', 'sketch': 'ImageNet Sketch', 'in_r': 'ImageNet-R', 'objectnet': 'ObjectNet'}
    # Lazy constructors so datasets that are not shown are never loaded.
    dset_builders = {
        'imagenet': lambda: datasets.ImageNet(root=_IMAGENET_ROOT, split='val', transform=t),
        'sketch': lambda: datasets.ImageFolder(root=_SKETCH_ROOT, transform=t),
        'in_r': lambda: datasets.ImageFolder(root=_IMAGENET_R_ROOT, transform=t),
        'objectnet': lambda: ObjectNet(normalize=None),
    }

    by_class_dict = load_cached_results('ftr_types/by_class_dict.pkl')
    to_pil = transforms.ToPILImage()

    for dsetname in dset_to_nickname:
        if dsetname not in show_dsets:
            continue
        dset = dset_builders[dsetname]()
        label_map = get_label_map(dsetname)

        d = load_cached_results(f'results/reliability/{dsetname}.pkl')
        ftr_vals, dset_idx = d['ftr_vals'], d['dset_idx']
        # ftr_diffs ~ reliability: high value means CC samples activate the ftr much more than MC ones
        core_ftr_diffs, spur_ftr_diffs = [np.array(d['avg_ftr_diffs'][x]['target']) for x in ['core', 'spur']]
        kept_cls_idx = np.array(d['avg_ftr_diffs']['cls_idx'])

        # Ascending sort of (spur - core): first entries = classes most hurt by spurious ftrs,
        # last entries = classes most helped by them.
        order = np.argsort(spur_ftr_diffs - core_ftr_diffs)
        most_hurt_cls_idx, most_helped_cls_idx = kept_cls_idx[order[:6]], kept_cls_idx[order[-3:]]

        fig, axs_raw = plt.subplots(3, 3, figsize=(16, 10))
        axs = axs_raw.ravel()
        for axi in axs:
            axi.set_axis_off()
        fig_ctr = 0
        for cls_idx, k, help_or_hurt in zip([most_hurt_cls_idx, most_helped_cls_idx], ['mc_target', 'cc'], ['Hurt', 'Help']):
            grids, cls_names = [], []
            for c in cls_idx:
                core_ftrs, spur_ftrs = [], []
                for inet_c in label_map[c]:
                    core_ftrs.extend(by_class_dict[inet_c]['core'])
                    spur_ftrs.extend(by_class_dict[inet_c]['spurious'])

                # Per-sample activation averaged over the class's core / spurious ftrs.
                avg_core_ftr_vals, avg_spur_ftr_vals = [np.average([ftr_vals[c][ftr][k] for ftr in ftrs], 0) for ftrs in [core_ftrs, spur_ftrs]]
                gaps = avg_spur_ftr_vals - avg_core_ftr_vals
                # Largest (spur - core) gap first, for both hurt and helped classes.
                top_sample_idx = np.argsort(-1*gaps)[:6]

                # Load each sample once and split it into image and label.
                samples = [dset[dset_idx[c][k][i]] for i in top_sample_idx]
                imgs = [s[0] for s in samples]
                ys = [s[1] for s in samples]
                grids.append(to_pil(make_grid(imgs, nrow=3)))
                cls_names.append(['\n'+x if i % 3 == 0 else x for i, x in enumerate([imagenet_classes[label_map[y][0]].title() for y in ys])])

                # some cool examples (hand-picked sample positions)
                class_names = [imagenet_classes[inet_c] for inet_c in label_map[c]]
                if dsetname == 'objectnet' and 'computer mouse' in class_names:
                    _save_single_example(imgs[1], 'Class: Banana\nPred: Computer Mouse', 'orangered',
                                         'plots/reliability/objectnet_banana.jpg')
                if dsetname == 'in_r' and 'lawn mower' in class_names:
                    # '\\c' keeps the literal backslash for mathtext without an invalid escape sequence.
                    _save_single_example(imgs[4], 'Class: Lawn Mower\nPred: Lawn Mower $\\checkmark$', 'darkgreen',
                                         'plots/reliability/in_r_lawn_mower.jpg')

            for g, c, curr_cls_names in zip(grids, cls_idx, cls_names):
                axs[fig_ctr].imshow(g)
                axs[fig_ctr].set_title(f"Spur Ftrs {help_or_hurt} for {imagenet_classes[label_map[c][0]].title()} Class. True Classes: {', '.join(curr_cls_names)}")
                fig_ctr += 1

        if save_fig:
            fig.tight_layout(); fig.savefig(f'plots/reliability/egs_{dsetname}.jpg', dpi=300)


if __name__ == '__main__':
    # Pipeline stages; uncomment the ones to run:
    #   eval_many_dsets()    -> compute and cache results/reliability/<dset>.pkl
    #   help_hurt_targets()  -> KDE plots of core vs. spurious ftr reliability
    view_most_helpful_and_hurtful_cases()