import torch
from torchvision import models, transforms
import numpy as np

import os
import json
import argparse

import my_utils
from load_datasets import load_breeds_dataset, load_celeba_age, load_celeba_blond, load_waterbirds, load_imagenet

import os
import matplotlib.pyplot as plt
import torchvision
import seaborn as sns

# Global plot styling, applied once at import time: seaborn's "darkgrid"
# theme first, then ask matplotlib to typeset all figure text with LaTeX
# (titles passed to the plotting code below contain LaTeX markup).
sns.set_theme(style='darkgrid')
plt.rcParams.update({'text.usetex': True})

def make_plots(i1, i2, t1, t2, name, path):
    """Save a two-panel figure of image grids as ``<path>/<name>.pdf``.

    Each panel shows one batch of images tiled by
    ``torchvision.utils.make_grid`` (4 images per row, padding value 1)
    under its own title, with all axis ticks and tick labels removed.

    Args:
        i1: Images for the top panel, in any form accepted by
            ``torchvision.utils.make_grid``.
        i2: Images for the bottom panel.
        t1: Title of the top panel.
        t2: Title of the bottom panel.
        name: Output file name without the ``.pdf`` extension.
        path: Output directory; created if it does not exist.
    """
    os.makedirs(path, exist_ok=True)

    fig, axes = plt.subplots(2, 1)
    try:
        to_pil = torchvision.transforms.ToPILImage()
        for axis, images, title in zip(axes, (i1, i2), (t1, t2)):
            grid = torchvision.utils.make_grid(images, nrow=4, pad_value=1)
            axis.set_title(title, fontsize=14)
            axis.imshow(to_pil(grid))
            # Pure image panels: hide ticks and their labels.
            axis.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

        fig.savefig(
            os.path.join(path, f'{name}.pdf'),
            bbox_inches='tight',
            pad_inches=0.1,
        )
    finally:
        # pyplot keeps every figure alive until it is closed explicitly;
        # release it so repeated calls do not accumulate open figures.
        plt.close(fig)



def description_tokens(description):
    """Return the non-empty, whitespace-stripped tokens of a ';'-separated description."""
    stripped = (part.strip() for part in description.split(';'))
    return [token for token in stripped if token]


def latex_title(class_name, class_accuracy, tokens, mode_accuracy):
    """Build the LaTeX panel title for one failure mode.

    The result has the form
    ``<class>\\ (<class acc>\\%) $\\rightarrow$ \\ \\textbf{tok}; ... (<mode acc>\\%)``.

    Args:
        class_name: Name of the class the failure mode belongs to.
        class_accuracy: Accuracy reported for the whole class (str or number).
        tokens: Description tokens; each one is wrapped in ``\\textbf{}``.
        mode_accuracy: Accuracy reported for the failure mode (str or number).

    Returns:
        The title string, ready to be rendered by LaTeX.
    """
    bold_tokens = ''.join(f'\\textbf{{{token}}}; ' for token in tokens)
    return (
        f'{class_name}\\ ({class_accuracy}\\%) $\\rightarrow$ \\ '
        f'{bold_tokens}({mode_accuracy}\\%)'
    )


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--num', type=int, required=True)
    parser.add_argument('--token', type=int, required=True)
    parser.add_argument('--drop', type=int, required=True)
    parser.add_argument('--dset', type=str, required=True)
    args = parser.parse_args()

    dset_key = args.dset

    # Directory holding failure_modes.json for this configuration.
    # Alternative layout without the "n<num>" component:
    #   dsets/<dset>/failure_modes_t<token>_d<drop>/
    failure_modes_dir = (
        f'dsets/{dset_key}/failure_modes_n{args.num}_t{args.token}_d{args.drop}/'
    )

    # The two failure-mode descriptions to plot (matched verbatim against
    # the JSON, trailing "; " included) and where the figure is written.
    # Alternative selections kept for reference:
    #   class filters: 'bear', 'landbird'
    #   'grass; stand; field; dry; '   / 'white; zoo; '
    #   'stand; water; boat; '         / 'water; person; '
    #   'black; climb; tree branch; '  / 'fence; cage; '
    #   'catch; stand; '               / 'woman; '
    #   'lush; grass; grassy; field; ' / 'pigeon; '
    #   output name: 'landbirds_first_fig'
    target_descriptions = ('diver; underwater; aquarium; ', 'pigeon; ')
    output_name = 'parrot_first_fig'
    output_dir = 'paper_plots'
    images_per_mode = 4

    transform = transforms.Compose(
        [transforms.Resize((224, 224)), transforms.ToTensor()]
    )

    loaders = {
        'waterbirds': load_waterbirds,
        'celeba_blond': load_celeba_blond,
        'celeba_age': load_celeba_age,
        'imagenet': load_imagenet,
    }
    if dset_key in loaders:
        dset = loaders[dset_key](transform=transform)
    else:
        # Any other key is handed to the BREEDS loader, which returns a
        # pair; only its first element is used here.
        dset, _ = load_breeds_dataset(dset_key, transform, transform)

    with open(os.path.join(failure_modes_dir, 'failure_modes.json'), 'r') as fh:
        class_entries = json.load(fh)

    # Fixed seed so the sampled example images are reproducible.
    np.random.seed(31)

    grids, titles = [], []

    for entry in class_entries:
        class_name = entry['key']
        class_accuracy = entry['accuracy_val']

        print(class_name)

        for mode in entry['value']:
            desc = mode['description']
            if desc not in target_descriptions:
                continue

            tokens = description_tokens(desc)
            for token in tokens:
                print(token)

            title = latex_title(
                class_name, class_accuracy, tokens, mode['accuracy_val']
            )
            print(title)

            fm_inds = np.array(mode['train_mc_inds'] + mode['val_mc_inds'])

            # Sampling without replacement cannot draw more items than
            # exist, so cap the request at the number of available indices.
            n_show = min(images_per_mode, len(fm_inds))
            if n_show == 0:
                print(f'skipping "{desc}": no sample indices available')
                continue

            chosen = np.random.choice(fm_inds, n_show, replace=False)
            # Plain ints for indexing; element 0 of each sample is the image.
            grids.append([dset[int(sample_idx)][0] for sample_idx in chosen])
            titles.append(title)

        print('-------------\n')

    if len(grids) < 2:
        raise SystemExit(
            f'expected at least 2 failure modes matching {target_descriptions!r} '
            f'in {failure_modes_dir}failure_modes.json, found {len(grids)}'
        )

    make_plots(grids[0], grids[1], titles[0], titles[1], output_name, output_dir)


