import torch

from argparse import ArgumentParser
import os.path as path
import os
from torchvision.utils import save_image
import torchvision.transforms as transforms

import pandas as pd
from optimization_requirements import visualize_objectives_v3, save_top_class_info, get_class_labels_dict
from torchvision import datasets, models, transforms

import utils
from tqdm import tqdm
import matplotlib.pyplot as plt
from scripts.analyze_activations_by_channel import get_overall_kt_results, load_in_activations_from_paths , get_kt_from_vectors
#import get_imagenet_activations_by_channel

import torch.nn.functional as F
def ensure_dir(directory):
    """Create *directory* (and any missing parents) if it does not already exist.

    ``exist_ok=True`` replaces a separate ``os.path.exists`` check, so there is
    no window in which another process can create the directory between the
    check and the ``makedirs`` call (which would otherwise raise
    FileExistsError). Calling it on an existing directory is a no-op.
    """
    os.makedirs(directory, exist_ok=True)
# TODO: Generate metrics for validation sets as well.
# This script is meant to be run AFTER the results generator.
def load_in_metrics(metrics_folder):
    """Load the saved comparison tensors for one run.

    Args:
        metrics_folder: directory containing ``kt_ii.pt``, ``kt_if.pt``,
            ``full_clip_ii.pt`` and ``full_clip_if.pt``.

    Returns:
        dict with keys 'kt_ii', 'kt_if', 'clip_sims_ii', 'clip_sims_if', each
        holding whatever ``torch.load`` returns for the matching file.
    """
    # Validation-set counterparts (val_kt_ii.pt, val_kt_if.pt,
    # val_full_clip_ii.pt, val_full_clip_if.pt) are not loaded yet.
    source_files = {
        'kt_ii': 'kt_ii.pt',
        'kt_if': 'kt_if.pt',
        'clip_sims_ii': 'full_clip_ii.pt',
        'clip_sims_if': 'full_clip_if.pt',
    }
    return {
        key: torch.load(path.join(metrics_folder, filename))
        for key, filename in source_files.items()
    }

def get_clip_wam_normalized(clip_sims_ii, clip_sims_if):
    """Ratio of each channel's nearest-other-channel score after vs before.

    For every row, ``get_whackamole_score_from_tensor`` picks the highest
    score in a column other than the row's own; this returns that value from
    *clip_sims_if* divided by the one from *clip_sims_ii*, channel by channel.

    Args:
        clip_sims_ii: 2-D pairwise similarity matrix (initial vs initial).
        clip_sims_if: 2-D pairwise similarity matrix (initial vs final), with
            the same number of rows.

    Returns:
        1-D float tensor with one ratio per row.
    """
    baseline_scores, _ = get_whackamole_score_from_tensor(clip_sims_ii)
    attacked_scores, _ = get_whackamole_score_from_tensor(clip_sims_if)
    # Both score vectors come from torch.zeros(n), so elementwise division
    # yields the same dtype and values as a per-channel loop.
    return attacked_scores / baseline_scores




def calc_clip_d(clip_sims_ii, clip_sims_if):
    """Per-channel drop in CLIP self-similarity, normalised by a baseline.

    Both inputs are averaged over dims 2 and 3 to get pairwise matrices
    ``ii`` and ``if``. For channel ``c`` the score is::

        (ii[c, c] - if[c, c]) / mean_{j != c} ii[c, j]

    i.e. how far the channel's self-similarity fell from the ``ii`` to the
    ``if`` comparison, relative to its average similarity to the other
    channels in ``ii``.

    Args:
        clip_sims_ii: similarity tensor (expected 4-D), initial vs initial.
        clip_sims_if: similarity tensor of the same layout, initial vs final.

    Returns:
        1-D tensor of length ``clip_sims_ii.shape[0]``. A single channel raises
        ZeroDivisionError (there are no other channels to average over).
    """
    channel_count = clip_sims_ii.shape[0]
    ii_avg = clip_sims_ii.mean(dim=(2, 3))
    if_avg = clip_sims_if.mean(dim=(2, 3))
    # Weight that turns an off-diagonal row sum into an off-diagonal mean.
    off_diag_weight = 1 / (channel_count - 1)

    scores = torch.zeros(channel_count)
    for ch in range(channel_count):
        before = ii_avg[ch, ch]
        after = if_avg[ch, ch]
        baseline = off_diag_weight * (ii_avg[ch].sum() - before)
        scores[ch] = (before - after) / baseline
    return scores




def get_whackamole_score_from_tensor(metrics_tensor):
    """For each row, return the strongest score found in a *different* column.

    Each row is searched with ``torch.topk(k=2)``. If the best hit in row ``i``
    is column ``i`` itself, the runner-up is used instead.

    Args:
        metrics_tensor: 2-D tensor whose entry ``[i, i]`` is row ``i`` against
            itself. It needs at least 2 columns for ``topk(k=2)``.

    Returns:
        (scores, indices): two 1-D tensors of length ``metrics_tensor.shape[0]``.
        Both are allocated with ``torch.zeros``, so the column indices are
        stored as floats.
    """
    top_vals, top_cols = torch.topk(metrics_tensor, k=2, dim=1)
    num_rows = metrics_tensor.shape[0]
    best_scores = torch.zeros(num_rows)
    best_cols = torch.zeros(num_rows)
    for row in range(num_rows):
        # Skip the self-match when it ranks first; otherwise keep the top hit.
        rank = 1 if top_cols[row][0] == row else 0
        best_scores[row] = top_vals[row][rank]
        best_cols[row] = top_cols[row][rank]
    return best_scores, best_cols

def get_self_kt(kt_if):
    """Return the diagonal of *kt_if*: each channel's score against itself.

    For a 2-D input this is a 1-D tensor of length ``min(rows, cols)``.
    """
    return torch.diag(kt_if)

def get_objectives_data(folder):
    """Read ``<folder>/objectives_data.csv`` and return it as a DataFrame.

    The whole table is returned unmodified. Callers pick out the columns
    they need (e.g. 'accuracy', 'attack_objective_values',
    'maintain_values').
    """
    csv_path = f"{folder}/objectives_data.csv"
    return pd.read_csv(csv_path)


def make_table(dataframe, output_path='run_metrics.pdf'):
    """Render *dataframe* as a matplotlib table and save it to a file.

    Args:
        dataframe: pandas DataFrame. Its column names become the header row
            and ``dataframe.values`` fills the body cells.
        output_path: destination file. Defaults to 'run_metrics.pdf', the
            name this function has always written to.
    """
    fig, ax = plt.subplots()
    # The table is the only content, so hide the axes frame and ticks.
    ax.axis('off')
    table = ax.table(cellText=dataframe.values, colLabels=dataframe.columns, loc='center')
    table.auto_set_font_size(True)

    fig.savefig(output_path, bbox_inches='tight', pad_inches=0)
    # Close explicitly. Otherwise the figure stays pyplot's current figure,
    # and later plt.plot(...) calls (e.g. the chart helpers) draw onto it.
    plt.close(fig)

def loss_tracking(objective_datas, descriptors):
    """Plot attack and maintain loss curves side by side and save them.

    Only the first 11 rows of each run are drawn. The x ticks at rows
    0, 5 and 10 are labelled as epochs 0, 1 and 2, which presumes 5 logged
    rows per epoch (TODO confirm against the logger). The attack panel uses a
    log y-scale. The figure is written to 'training_curves.pdf'.

    Args:
        objective_datas: DataFrames with 'attack_objective_values' and
            'maintain_values' columns, one per run.
        descriptors: legend labels, paired with *objective_datas* by position.
    """
    _, (attack_ax, maintain_ax) = plt.subplots(1, 2, figsize=(8, 4))

    for frame, label in zip(objective_datas, descriptors):
        attack_ax.plot(frame['attack_objective_values'].to_numpy()[:11], label=label)
        maintain_ax.plot(frame['maintain_values'].to_numpy()[:11], label=label)

    tick_positions = [0, 5, 10]
    tick_labels = [0, 1, 2]
    for ax, title in ((attack_ax, 'Attack Loss'), (maintain_ax, 'Maintain Loss')):
        ax.set_title(title)
        ax.set_xticks(tick_positions)
        ax.set_xticklabels(tick_labels)
        ax.set_xlabel('Training Epoch')
        ax.set_ylabel('Loss')
        ax.legend()
    attack_ax.set_yscale('log')

    plt.savefig('training_curves.pdf')
    plt.close()


def _geometric_mean_last_dim(sims):
    """Flatten *sims* to (rows, cols, -1) and take the geometric mean of the last axis."""
    # float64 delays underflow of the running product: many values in (0, 1)
    # multiplied together hit 0 quickly in float32. Very long trailing axes
    # can still underflow.
    flat = sims.double().reshape(sims.shape[0], sims.shape[1], -1)
    # The root is the number of multiplied values (the flattened trailing
    # axis), not the channel count.
    return torch.pow(torch.prod(flat, dim=-1), 1 / flat.shape[-1])


def _max_off_diagonal(matrix):
    """Return, for each row of a square matrix, its largest off-diagonal entry."""
    n = matrix.shape[0]
    keep = ~torch.eye(n, dtype=torch.bool)
    return matrix[keep].view(n, -1).max(dim=1).values


def make_geometric_mean(clip_sim_ii, clip_sim_if, output_path, run_name):
    """Plot each channel's largest off-diagonal geometric-mean CLIP similarity.

    Each input is flattened to ``(rows, cols, -1)`` and reduced by geometric
    mean over the trailing axis, giving a ``(rows, cols)`` matrix. The
    diagonal is dropped (so rows must equal cols) and the max of each row is
    kept. The initial-vs-initial and initial-vs-final curves are then plotted,
    sorted by the initial curve.

    Args:
        clip_sim_ii: similarity tensor whose first two dims index channel
            pairs, initial vs initial.
        clip_sim_if: same layout, initial vs final.
        output_path: file the figure is saved to.
        run_name: figure title.
    """
    init_max = _max_off_diagonal(_geometric_mean_last_dim(clip_sim_ii))
    final_max = _max_off_diagonal(_geometric_mean_last_dim(clip_sim_if))

    df = pd.DataFrame({
        "init": init_max.detach().cpu().numpy(),
        "final": final_max.detach().cpu().numpy(),
    })
    df = df.sort_values("init", ascending=True)

    # Use a dedicated figure so we never draw onto one left open elsewhere.
    fig, ax = plt.subplots()
    ax.plot(df["init"].values, color="blue", label="Initial_to_initial_max")
    ax.plot(df["final"].values, color="red", label="Initial_to_final_max")
    ax.set_ylabel("geo. mean")
    ax.set_xlabel("channels")
    ax.legend()
    ax.set_title(run_name)
    fig.savefig(output_path)
    plt.close(fig)

def make_cos_sim_chart(ii_tensor, if_tensor, output_path, run_name):
    """Plot each channel's strongest CLIP similarity to any *other* channel.

    Both tensors are viewed as ``(rows, cols)`` using ``if_tensor``'s first two
    dims. The diagonal is masked out (so rows must equal cols) and each row's
    maximum is kept. The two curves are plotted sorted by the ``ii`` curve,
    with the y-axis fixed to [0, 1].

    Args:
        ii_tensor: pairwise similarities, initial vs initial.
        if_tensor: pairwise similarities, initial vs final.
        output_path: file the figure is saved to.
        run_name: figure title; underscores are shown as spaces.
    """
    n_rows, n_cols = if_tensor.shape[0], if_tensor.shape[1]
    after = if_tensor.float().view(n_rows, n_cols)
    before = ii_tensor.float().view(n_rows, n_cols)

    off_diagonal = torch.eye(n_rows, dtype=torch.bool).logical_not()
    best_after = after[off_diagonal].view(n_rows, -1).max(dim=1).values
    best_before = before[off_diagonal].view(n_rows, -1).max(dim=1).values

    curves = pd.DataFrame({
        "init": best_before.detach().cpu().numpy(),
        "final": best_after.detach().cpu().numpy(),
    })
    curves = curves.sort_values("init", ascending=True)

    plt.plot(curves["init"].values, color="blue", label="Max Self-Similarity")
    plt.plot(curves["final"].values, color="red", label="Max Cross-Similarity")
    plt.ylabel("Clip Similarity")
    plt.xlabel("Sorted Channels")
    plt.legend()
    plt.ylim(0, 1)
    plt.title(run_name.replace('_', ' '))
    plt.savefig(output_path)
    plt.close()


def make_kt_chart(ii_tensor, if_tensor, output_path, run_name):
    """Plot each channel's highest Kendall-tau score against any *other* channel.

    Both tensors are viewed as ``(rows, cols)`` using ``if_tensor``'s first two
    dims. The diagonal is masked out (so rows must equal cols) and each row's
    maximum is kept. The two curves are plotted sorted by the ``ii`` curve,
    with the y-axis fixed to [-1, 1].

    Args:
        ii_tensor: pairwise KT coefficients, initial vs initial.
        if_tensor: pairwise KT coefficients, initial vs final.
        output_path: file the figure is saved to.
        run_name: figure title; underscores are shown as spaces.
    """
    n_rows, n_cols = if_tensor.shape[0], if_tensor.shape[1]
    after = if_tensor.float().view(n_rows, n_cols)
    before = ii_tensor.float().view(n_rows, n_cols)

    off_diagonal = torch.eye(n_rows, dtype=torch.bool).logical_not()
    best_after = after[off_diagonal].view(n_rows, -1).max(dim=1).values
    best_before = before[off_diagonal].view(n_rows, -1).max(dim=1).values

    curves = pd.DataFrame({
        "init": best_before.detach().cpu().numpy(),
        "final": best_after.detach().cpu().numpy(),
    })
    curves = curves.sort_values("init", ascending=True)

    plt.plot(curves["init"].values, color="blue", label="Max KT Self-Score")
    plt.plot(curves["final"].values, color="red", label="Max KT Cross-Score")
    plt.ylabel("KT Coefficient")
    plt.xlabel("Sorted Channels")
    plt.legend()
    plt.title(run_name.replace('_', ' '))
    plt.ylim(-1, 1)
    plt.savefig(output_path)
    plt.close()

def _channel_scores(data_dict):
    """Compute every per-channel metric the report tabulates for one run.

    Args:
        data_dict: mapping returned by ``load_in_metrics`` (keys 'kt_ii',
            'kt_if', 'clip_sims_ii', 'clip_sims_if').

    Returns:
        dict of tensors indexed by channel: 'clip_d', 'kt_self', 'clip_wam',
        'kt_wam'.
    """
    kt_wam_scores, _ = get_whackamole_score_from_tensor(data_dict['kt_if'])
    return {
        'clip_d': calc_clip_d(data_dict['clip_sims_ii'], data_dict['clip_sims_if']),
        'kt_self': get_self_kt(data_dict['kt_if']),
        'clip_wam': get_clip_wam_normalized(
            data_dict['clip_sims_ii'].mean(dim=(2, 3)),
            data_dict['clip_sims_if'].mean(dim=(2, 3)),
        ),
        'kt_wam': kt_wam_scores,
    }


def _title_or_default(titles, index, fallback):
    """Return ``titles[index]`` when it exists, otherwise *fallback*."""
    return titles[index] if index < len(titles) else fallback


if __name__ == "__main__":

    parser = ArgumentParser()
    parser.add_argument("--results-title", type=str)
    parser.add_argument("--attack-details", type=str)  # e.g. "v13/fixed_alpha"; not read below
    parser.add_argument("--output", type=str)  # not read below
    parser.add_argument("--do-single-channel", action="store_true")
    args = parser.parse_args()

    directories = [
        '/home/a_fuller/projects/Attacking-Interpretability/alexnet/v12/f10_c0_only',
        '/home/a_fuller/projects/Attacking-Interpretability/alexnet/v13/single_channel/f10_c1_top10_to_zero',
        '/home/a_fuller/projects/Attacking-Interpretability/alexnet/v13/single_channel/f10_c2_top10_to_zero',
        '/home/a_fuller/projects/Attacking-Interpretability/alexnet/v13/single_channel/f10_c3_top10_to_zero',
        '/home/a_fuller/projects/Attacking-Interpretability/alexnet/v13/single_channel/f10_c4_top10_to_zero',
        '/home/a_fuller/projects/Attacking-Interpretability/alexnet/v13/single_channel/f10_c5_top10_to_zero',
        '/home/a_fuller/projects/Attacking-Interpretability/alexnet/v13/single_channel/f10_c6_top10_to_zero',
    ]

    # Legend labels for loss_tracking, paired with `directories` by position.
    training_loop_descriptors = [
        'Conv5 C0 Push Down',
        'Conv5 C1 Push Down',
        'Conv5 C2 Push Down',
        'Conv5 C3 Push Down',
        'Conv5 C4 Push Down',
        'Conv5 C5 Push Down',
        'Conv5 C6 Push Down',
    ]
    # Channel reported for each directory in --do-single-channel mode.
    channels = [0, 1, 2, 3, 4, 5, 6]

    # Chart titles, paired with `directories` by position. Runs past the end of
    # these lists are titled with their run name instead of raising IndexError.
    graph_titles_clip = [
        'Conv5 Push-Down: Clip Similarity of Nearest Channel',
        'Conv5 Push-Up: Clip Similarity of Nearest Channel',
    ]
    graph_titles_kt = [
        'Conv5 Push-Down: Kendall-Tau of Nearest Channel',
        'Conv5 Push-Up: Kendall-Tau of Nearest Channel',
    ]

    graphs_folder = 'graphs'
    ensure_dir(graphs_folder)

    result_columns = [
        'run name',
        'clip d means',
        'kt self sim',
        'clip whack-a-mole',
        'kt whack-a-mole',
        'accuracy',
    ]
    rows = []
    obj_datas = []

    if args.do_single_channel:
        print('Doing single channel runs!')
    else:
        print('handling individual runs!')

    for i, directory in enumerate(directories):
        output_folder = directory + '/results'
        run_name = directory.split('/')[-1]
        data_dict = load_in_metrics(output_folder + '/metrics')
        scores = _channel_scores(data_dict)
        obj_data = get_objectives_data(output_folder)

        if args.do_single_channel:
            # Report only the attacked channel, and the last logged accuracy.
            channel = channels[i]
            summary = {name: values[channel] for name, values in scores.items()}
            accuracy = obj_data['accuracy'].iloc[-1]
        else:
            print('clip sims means:')
            print(data_dict['clip_sims_ii'].mean())
            print(data_dict['clip_sims_if'].mean())

            make_cos_sim_chart(
                data_dict['clip_sims_ii'].mean(dim=(2, 3)),
                data_dict['clip_sims_if'].mean(dim=(2, 3)),
                graphs_folder + f'/{run_name}_pruned_clip.pdf',
                _title_or_default(graph_titles_clip, i, run_name),
            )
            make_kt_chart(
                data_dict['kt_ii'],
                data_dict['kt_if'],
                graphs_folder + f'/{run_name}_pruned_kt.pdf',
                _title_or_default(graph_titles_kt, i, run_name),
            )
            # Average each metric over all channels.
            summary = {name: values.mean() for name, values in scores.items()}
            # Row 10 is the last point of the [0:11] window that loss_tracking plots.
            accuracy = obj_data['accuracy'][10]

        rows.append({
            'run name': run_name.replace('_', ' '),
            'clip d means': summary['clip_d'].item(),
            'kt self sim': summary['kt_self'].item(),
            'clip whack-a-mole': summary['clip_wam'].item(),
            'kt whack-a-mole': summary['kt_wam'].item(),
            'accuracy': accuracy,
        })
        obj_datas.append(obj_data)

    results = pd.DataFrame(rows, columns=result_columns)
    results_path = f'{args.results_title}_run_metrics.csv'
    results.to_csv(results_path)
    print(results_path)

    loss_tracking(objective_datas=obj_datas, descriptors=training_loop_descriptors)


### let's make a function!


# directories = [
#         '/home/a_fuller/projects/Attacking-Interpretability/alexnet/v12/f0_dataset_top10_to_zero',
#         '/home/a_fuller/projects/Attacking-Interpretability/alexnet/v12/f3_dataset_top10_to_zero_a01',
#         '/home/a_fuller/projects/Attacking-Interpretability/alexnet/v12/f6_dataset_top10_to_zero',
#         '/home/a_fuller/projects/Attacking-Interpretability/alexnet/v12/f8_dataset_top10_to_zero',
#         '/home/a_fuller/projects/Attacking-Interpretability/alexnet/v12/f10_dataset_top10_to_zero',
#         '/home/a_fuller/projects/Attacking-Interpretability/alexnet/v12/f10_dataset_refs_to_top',
#         '/home/a_fuller/projects/Attacking-Interpretability/efficientnet/v12/f7_0_b3_0_dataset_top10_to_zero',
#     ]

#     base_directory = '/home/a_fuller/projects/Attacking-Interpretability/'

    

#     attack_names = [
#         'f0_dataset_top10_to_zero',
#         'f3_dataset_top10_to_zero',
#         'f6_dataset_top10_to_zero',
#         'f8_dataset_top10_to_zero',
#         'f10_dataset_top10_to_zero',
#         'f10_refs_to_top'
#     ]
#     run_architectures =[
#         'alexnet',
#         'alexnet',
#         'alexnet',
#         'alexnet',
#         'alexnet',
#         'alexnet',
#         'efficientnet'
#     ]

#     training_loop_descriptors = [
#         'conv1 pushdown',
#         'conv2 pushdown',
#         'conv3 pushdown',
#         'conv4 pushdown',
#         'conv5 pushdown',
#         'conv5 push-up',
#     ]


