import os.path as path
from argparse import ArgumentParser
import os
# from reportlab.platypus import SimpleDocTemplate, Image, Paragraph, Spacer
# from reportlab.lib.pagesizes import A4, letter
# from PIL import Image as PImage
# from reportlab.lib.styles import ParagraphStyle
# import torchvision.transforms as transforms
import torch
import pickle
import pandas as pd
import shutil
import numpy as np
# from optimization_requirements import visualize_objectives_v3
# from torchvision import datasets, models, transforms

import clip_descriptions
import utils
# from tqdm import tqdm
import matplotlib.pyplot as plt
from analyze_activations_by_channel import get_overall_kt_results, load_in_activations_from_paths , get_kt_from_vectors
# from PIL import Image
#import get_imagenet_activations_by_channel
from torchvision import models
import get_imagenet_activations_by_channel_v2
# import matplotlib.image as mpimg

# from scripts.get_imagenet_activations_by_channel import get_and_sort_activations
# import torch.nn.functional as F


def get_clip_wam_normalized(clip_sims_ii, clip_sims_if):
    """Return the per-row ratio of "if" to "ii" whack-a-mole scores.

    Both inputs are passed to ``get_whackamole_score_from_tensor``; the
    score obtained from ``clip_sims_if`` is divided by the one obtained
    from ``clip_sims_ii`` for each of the ``clip_sims_ii.shape[0]`` rows.

    Args:
        clip_sims_ii: Similarity matrix used as the denominator.
        clip_sims_if: Similarity matrix used as the numerator.

    Returns:
        A 1-D tensor with one ratio per row of ``clip_sims_ii``.
    """
    num_rows = clip_sims_ii.shape[0]

    # Only the scores are needed here; the matching indices are discarded.
    baseline_scores, _ = get_whackamole_score_from_tensor(clip_sims_ii)
    shifted_scores, _ = get_whackamole_score_from_tensor(clip_sims_if)

    ratios = torch.zeros(num_rows)
    for row in range(num_rows):
        ratios[row] = shifted_scores[row] / baseline_scores[row]
    return ratios




#pass the if version
def calc_clip_d(clip_sims_ii, clip_sims_if):
    """Compute a per-channel score from two 4-D similarity tensors.

    Both tensors are first averaged over their last two axes (dims 2 and 3).
    For channel ``c`` the score is the drop of the diagonal entry from the
    "ii" matrix to the "if" matrix, divided by the mean of the off-diagonal
    entries in row ``c`` of the "ii" matrix.

    Args:
        clip_sims_ii: Tensor indexed as ``[channel, channel, ., .]``.
        clip_sims_if: Tensor of matching layout.

    Returns:
        A 1-D tensor of length ``clip_sims_ii.shape[0]``.
    """
    num_channels = clip_sims_ii.shape[0]

    # Collapse the two trailing axes so each (row, col) pair is one scalar.
    mean_ii = clip_sims_ii.mean(dim=(2, 3))
    mean_if = clip_sims_if.mean(dim=(2, 3))

    scores = torch.zeros(num_channels)
    for ch in range(num_channels):
        self_ii = mean_ii[ch, ch]
        self_if = mean_if[ch, ch]
        # Row mean of the ii matrix with the diagonal entry left out.
        off_diag_mean = (1 / (num_channels - 1)) * (mean_ii[ch].sum() - self_ii)
        scores[ch] = (self_ii - self_if) / off_diag_mean

    return scores




def get_whackamole_score_from_tensor(metrics_tensor):
    """For each row, find the highest-scoring column other than the row itself.

    The two largest entries of every row are taken with ``torch.topk``. If
    the largest one sits on the diagonal (column index equals row index),
    the runner-up is reported instead; otherwise the largest is reported.

    Args:
        metrics_tensor: 2-D tensor with at least two columns.

    Returns:
        A ``(scores, indices)`` pair of 1-D tensors with one entry per row.
        Both are allocated with ``torch.zeros``, so the indices come back in
        the default floating dtype rather than as integers.
    """
    num_rows = metrics_tensor.shape[0]
    best_scores = torch.zeros(num_rows)
    best_indices = torch.zeros(num_rows)

    top_vals, top_idx = torch.topk(metrics_tensor, k=2, dim=1)
    for row in range(num_rows):
        # Position 0 is the best match; fall back to position 1 when the
        # best match is the row's own diagonal entry.
        pick = 1 if top_idx[row][0] == row else 0
        best_scores[row] = top_vals[row][pick]
        best_indices[row] = top_idx[row][pick]

    return best_scores, best_indices

def get_self_kt(kt_if):
    """Return the main diagonal of ``kt_if`` (each row paired with itself)."""
    self_scores = torch.diag(kt_if)
    return self_scores

def get_objectives_data(folder, file_name='objectives_data.csv' ):
    """Load the objectives CSV stored in ``folder`` as a DataFrame.

    Args:
        folder: Directory that holds the CSV file.
        file_name: Name of the CSV file inside ``folder``.

    Returns:
        The ``pandas.DataFrame`` produced by ``pd.read_csv``.
    """
    csv_path = os.path.join(folder, file_name)
    return pd.read_csv(csv_path)


def make_cos_sim_chart(ii_tensor, if_tensor, output_path, run_name):
    """Plot the per-row maximum off-diagonal similarity for two matrices.

    Both inputs are viewed as ``(if_tensor.shape[0], if_tensor.shape[1])``
    float matrices, their diagonals are removed, and the row-wise maximum of
    what is left is plotted. Rows are ordered by ascending "init" value.

    Args:
        ii_tensor: Matrix plotted in blue (the "init" series).
        if_tensor: Matrix plotted in red (the "final" series); its first two
            dimensions define the shape both inputs are viewed as.
        output_path: File the figure is saved to.
        run_name: Used as the title, with underscores replaced by spaces.
    """
    n_rows, n_cols = if_tensor.shape[0], if_tensor.shape[1]
    init_matrix = ii_tensor.float().view(n_rows, n_cols)
    final_matrix = if_tensor.float().view(n_rows, n_cols)

    def _row_max_without_diagonal(matrix):
        # Boolean-mask out the diagonal, then reshape back to one row per
        # original row before taking the maximum.
        size = matrix.shape[0]
        off_diagonal = ~torch.eye(size, dtype=bool)
        trimmed = matrix[off_diagonal].view(size, -1)
        return trimmed.max(dim=1)[0].detach().cpu().numpy()

    series = pd.DataFrame({
        "init": _row_max_without_diagonal(init_matrix),
        "final": _row_max_without_diagonal(final_matrix),
    })
    series = series.sort_values(by="init", ascending=True)

    plt.plot(series["init"].values, color="blue", label="Max Self-Similarity")
    plt.plot(series["final"].values, color="red", label="Max Cross-Similarity")
    plt.ylabel("Clip Similarity")
    plt.xlabel("Sorted Channels")
    plt.legend()
    plt.ylim(0, 1)

    plt.title(run_name.replace('_', ' '))
    plt.savefig(output_path)
    plt.close()
    return


def make_kt_chart(ii_tensor, if_tensor, output_path, run_name):
    """Plot the per-row maximum off-diagonal KT score for two matrices.

    Both inputs are viewed as ``(if_tensor.shape[0], if_tensor.shape[1])``
    float matrices, their diagonals are removed, and the row-wise maximum of
    what is left is plotted on a y-axis fixed to [-1, 1]. Rows are ordered by
    ascending "init" value.

    Args:
        ii_tensor: Matrix plotted in blue (the "init" series).
        if_tensor: Matrix plotted in red (the "final" series); its first two
            dimensions define the shape both inputs are viewed as.
        output_path: File the figure is saved to.
        run_name: Used as the title, with underscores replaced by spaces.
    """
    n_rows, n_cols = if_tensor.shape[0], if_tensor.shape[1]
    init_matrix = ii_tensor.float().view(n_rows, n_cols)
    final_matrix = if_tensor.float().view(n_rows, n_cols)

    def _row_max_without_diagonal(matrix):
        # Drop the diagonal with a boolean mask, restore one row per
        # original row, and keep only the largest remaining entry.
        size = matrix.shape[0]
        off_diagonal = ~torch.eye(size, dtype=bool)
        trimmed = matrix[off_diagonal].view(size, -1)
        return trimmed.max(dim=1)[0].detach().cpu().numpy()

    series = pd.DataFrame({
        "init": _row_max_without_diagonal(init_matrix),
        "final": _row_max_without_diagonal(final_matrix),
    })
    series = series.sort_values(by="init", ascending=True)

    plt.plot(series["init"].values, color="blue", label="Max KT Self-Score")
    plt.plot(series["final"].values, color="red", label="Max KT Cross-Score")
    plt.ylabel("KT Coefficient")
    plt.xlabel("Sorted Channels")
    plt.legend()
    plt.title(run_name.replace('_', ' '))
    plt.ylim(-1, 1)
    plt.savefig(output_path)

    plt.close()
    return


def visualize_objectives_v3(objective_values, maintain_values, accuracy, activation_norms, save_interval,
                            filename=None, title="<insert-title>", inline=False):
    """Draw a two-panel summary of a run: losses on top, accuracy and norm below.

    Each panel uses a twin y-axis so that two series with different scales
    share the same x-axis of training steps.

    Args:
        objective_values: Attack loss per saved step (red, top-left axis).
        maintain_values: Maintain loss per saved step (blue, top-right axis).
        accuracy: Accuracy per saved step (blue, bottom-left axis).
        activation_norms: Relative norm per saved step (red, bottom-right axis).
        save_interval: Number of steps between consecutive saved values.
        filename: If given, the figure is saved to this path.
        title: Figure-level title.
        inline: If true, the figure is shown with ``plt.show()``.
    """
    # x positions: 0, save_interval, 2 * save_interval, ...
    n_points = len(objective_values)
    steps = save_interval * np.arange(0, n_points)

    figure, panels = plt.subplots(2, 1, figsize=(12, 12))
    loss_ax, acc_ax = panels
    plt.suptitle(title, fontsize=14)

    # Top panel: attack loss on the left axis, maintain loss on a twin axis.
    loss_ax.set_title('Losses')
    loss_ax.set_xlabel("step", fontsize=14)
    loss_ax.plot(steps, objective_values, label="Attack Loss", color='red', marker='.')
    loss_ax.set_ylabel("attack loss", color="red", fontsize=14)

    maintain_ax = loss_ax.twinx()
    maintain_ax.plot(steps, maintain_values, label="Maintain Loss", color='blue', marker='.')
    maintain_ax.set_ylabel("maintain loss", color="blue", fontsize=14)

    # Bottom panel: accuracy on the left axis, relative norm on a twin axis.
    acc_ax.set_title('Accuracy and Activation Norm')
    acc_ax.plot(steps, accuracy, label='Accuracy', color='blue', marker='.')
    acc_ax.set_ylabel("accuracy", color='blue', fontsize=14)

    norm_ax = acc_ax.twinx()
    norm_ax.plot(steps, activation_norms, label='Step Activation Relative Norm', color='red', marker='.')
    norm_ax.set_xlabel("step", fontsize=14)
    norm_ax.set_ylabel("relative norm", color='red', fontsize=14)

    if filename is not None:
        plt.savefig(filename)

    if inline:
        plt.show()
    plt.close()


def visualize_objectives_v4(objective_values,  accuracy,  save_interval,
                            filename=None, title="<insert-title>", inline=False):
    """Draw a two-panel summary of a run: loss on top, accuracy below.

    Args:
        objective_values: Loss per saved step (red, top panel).
        accuracy: Accuracy per saved step (blue, bottom panel).
        save_interval: Number of steps between consecutive saved values.
        filename: If given, the figure is saved to this path.
        title: Figure-level title.
        inline: If true, the figure is shown with ``plt.show()``.
    """
    # x positions: 0, save_interval, 2 * save_interval, ...
    n_points = len(objective_values)
    steps = save_interval * np.arange(0, n_points)

    figure, panels = plt.subplots(2, 1, figsize=(12, 12))
    loss_ax, acc_ax = panels
    plt.suptitle(title, fontsize=14)

    # Top panel: the single loss curve.
    loss_ax.set_title('Losses')
    loss_ax.set_xlabel("step", fontsize=14)
    loss_ax.plot(steps, objective_values, label="Loss", color='red', marker='.')
    loss_ax.set_ylabel("loss", color="red", fontsize=14)

    # Bottom panel: accuracy only.
    # NOTE(review): the panel title still mentions an activation norm, but no
    # norm series is plotted here -- confirm whether the title should change.
    acc_ax.set_title('Accuracy and Activation Norm')
    acc_ax.plot(steps, accuracy, label='Accuracy', color='blue', marker='.')
    acc_ax.set_ylabel("accuracy", color='blue', fontsize=14)

    if filename is not None:
        plt.savefig(filename)

    if inline:
        plt.show()
    plt.close()

def generate_metrics_for_multilayer_attack(results_directory, do_kt, do_clip, do_validation, do_kt_overwrite, default_baseline_alexnet, graph_output_folder = None):
    """Compute and save per-layer metrics and plots for a multi-layer run.

    Reads ``results_dict.pkl`` and the run configuration from
    ``results_directory``. For every (layer, channels) pair listed in the
    configuration it collects activations, optionally computes KT and CLIP
    similarity results, and writes tensors, CSV files and charts below
    ``<results_directory>/results``.

    Args:
        results_directory: Run folder holding ``results_dict.pkl`` and
            ``configuration.txt``. If it does not exist the program exits.
        do_kt: If true, call ``get_overall_kt_results`` and save its
            ``cor_ii`` / ``cor_if`` entries.
        do_clip: If true, compute CLIP similarities for the top-10 activating
            indices of the initial and final activations.
        do_validation: If true, repeat the activation (and KT / CLIP) steps
            with the loader from ``utils.get_val_topk_dataset_loader()``.
        do_kt_overwrite: Forwarded unchanged to ``get_overall_kt_results``.
        default_baseline_alexnet: If true, the initial model is torchvision's
            AlexNet with ``IMAGENET1K_V1`` weights instead of the run's own
            initial model.
        graph_output_folder: Unused; kept so existing callers keep working.

    Returns:
        None. All results are written to disk.

    Notes:
        The per-layer summary (KT/CLIP charts and ``*_run_metrics.csv``) needs
        both the KT results and the full CLIP similarities. When either is
        missing the summary is skipped for that layer with a message.
    """
    if not path.exists(results_directory):
        print(f'ERROR, the results directory {results_directory} does not exist!')
        print('Exiting program')
        exit(0)

    # SECURITY NOTE: pickle.load can execute arbitrary code; only point this
    # function at results directories produced by trusted runs.
    with open(f'{results_directory}/results_dict.pkl', 'rb') as f:
        results_dict = pickle.load(f)
    print('results dict:')
    for key in results_dict.keys():
        print(key)

    config_dict = utils.unpack_config_file_into_dict(results_directory)
    print('config:')
    for key in config_dict.keys():
        print(key)
    print(config_dict['layer'])

    # The layer entry in the config is the string form of a list of lists,
    # each inner list repeating one layer name. Keep one name per inner list.
    all_layers = []
    for layer_string in config_dict['layer'][1:-1].split('], '):
        all_layers.append(layer_string.split(', ')[0].strip('[').strip(']').strip("\'"))

    # The channel entry has the same list-of-lists string form; parse each
    # inner list into a list of ints.
    all_channels = []
    for channel_string in config_dict['channel'][1:-1].split('], '):
        channel_string = channel_string.strip('[')
        channel_string = channel_string.strip(']')
        all_channels.append([int(c) for c in channel_string.split(', ')])

    #TODO alter code to use csv and image folder for intermediary steps!
    output_folder = results_directory + '/results'
    run_name = results_directory.split('/')[-1]
    csv_folder = output_folder+'/csv'
    pdf_folder = output_folder + '/PDFs_' + run_name
    metrics_folder = output_folder + '/metrics'
    image_folder = output_folder+'/images'
    upload_folder = metrics_folder+f'/{run_name}_visualizations/graphs'

    model_arch = config_dict['arch']
    print(f'using {model_arch}')
    print(f'sending {model_arch}')
    if default_baseline_alexnet:
        print('using baseline alexnet')
        # Fix: a trailing comma previously turned `device` into a 1-tuple,
        # which `.to(device)` rejects.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        init_model = models.alexnet(weights=models.AlexNet_Weights.IMAGENET1K_V1).to(device)
        # Fix: `layers` is not defined in this function; use the parsed list.
        # NOTE(review): the result is used as the final model although
        # get_final_model=False is passed -- confirm against utils.
        final_model = utils.get_init_final_models(results_directory, layer=all_layers[0],
                                                  model_arch=model_arch, get_final_model=False)
    else:
        init_model, final_model = utils.get_init_final_models(results_directory, layer=all_layers[0],
                                                              model_arch=model_arch, get_final_model=True)

    utils.ensure_dir(output_folder)
    utils.ensure_dir(pdf_folder)
    utils.ensure_dir(csv_folder)
    utils.ensure_dir(image_folder)
    utils.ensure_dir(metrics_folder)
    utils.ensure_dir(upload_folder)

    # Unpack the run-level series used by the objective plots and CSVs.
    attack_objective_values = results_dict['attack_obj_vals']
    maintain_values = results_dict['maintain_obj_vals']
    accuracy = results_dict['accuracy']
    activation_norms = results_dict['activation_norms']
    alphas = results_dict['alphas']

    shutil.copyfile(path.join(results_directory, "configuration.txt"), path.join(pdf_folder, "configuration.txt"))

    save_interval = int(config_dict['save_interval'])
    objectives_title = f"lr: {config_dict['lr']}, alpha: {config_dict['alpha']}, optim: {config_dict['optimizer_type']}, maintain obj: {config_dict['maintain_obj']}"

    for layer, channels in zip(all_layers, all_channels):
        print(f'layer: {layer} commencing!')

        # Reset per-layer results so a skipped stage can be detected below
        # instead of raising a NameError or reusing a previous layer's data.
        kt_dict = None
        init_clip_similarities = None
        final_clip_similarities = None

        # Generate and store the activations; the helper returns their paths.
        init_acts_path, final_acts_path = get_imagenet_activations_by_channel_v2.get_and_sort_activations_v2(
            results_directory, init_model, final_model,
            layer=[layer], channels=channels, model_name=model_arch)

        if do_validation:
            init_val_acts_path, final_val_acts_path = get_imagenet_activations_by_channel_v2.get_and_sort_activations_v2(
                results_directory, init_model, final_model,
                layer=[layer], channels=channels, model_name=model_arch,
                data_loader=utils.get_val_topk_dataset_loader(), name='val')
            # Fix: load the validation activations BEFORE taking their top-k;
            # previously they were referenced before being assigned.
            init_val_acts, final_val_acts = load_in_activations_from_paths(init_val_acts_path, final_val_acts_path)
            _, init_top_val_indices = torch.topk(init_val_acts, k=10, dim=1)
            _, final_top_val_indices = torch.topk(final_val_acts, k=10, dim=1)

            torch.save(init_top_val_indices, path.join(metrics_folder, 'init_top_indices_val.pt'))
            torch.save(final_top_val_indices, path.join(metrics_folder, 'final_top_indices_val.pt'))
            init_top_val_indices = init_top_val_indices.T
            final_top_val_indices = final_top_val_indices.T

        init_acts, final_acts = load_in_activations_from_paths(init_acts_path, final_acts_path)
        _, init_top_indices = torch.topk(init_acts, k=10, dim=1)
        _, final_top_indices = torch.topk(final_acts, k=10, dim=1)

        # NOTE(review): these file names do not include the layer, so every
        # layer overwrites the previous layer's tensors in metrics_folder.
        torch.save(init_top_indices, path.join(metrics_folder, 'init_top_indices.pt'))
        torch.save(final_top_indices, path.join(metrics_folder, 'final_top_indices.pt'))

        #TODO: Rewrite this whole thing. Use more functions and less script. Possibly split this whole .py into multiple scripts
        if do_kt:
            kt_dict = get_overall_kt_results(output_folder, init_acts_path, final_acts_path, do_kt_overwrite, layer=layer+'_')
            torch.save(kt_dict['cor_ii'], path.join(metrics_folder, 'kt_ii.pt'))
            torch.save(kt_dict['cor_if'], path.join(metrics_folder, 'kt_if.pt'))
            if do_validation:
                # NOTE(review): unlike the call above, no `layer` prefix is
                # passed here -- confirm this cannot collide between layers.
                val_kt_dict = get_overall_kt_results(output_folder, init_val_acts_path, final_val_acts_path, do_kt_overwrite)
                torch.save(val_kt_dict['cor_ii'], path.join(metrics_folder, 'val_kt_ii.pt'))
                torch.save(val_kt_dict['cor_if'], path.join(metrics_folder, 'val_kt_if.pt'))

        if do_clip:
            use_full_embeddings = True
            print('commencing clip analysis!')
            # Use the top indices computed in this script (transposed), not
            # the ones stored in the results dict during training.
            init_indices = init_top_indices.T
            final_indices = final_top_indices.T

            if use_full_embeddings:
                use_basic_means = True
                init_clip_similarities, final_clip_similarities = clip_descriptions.get_overall_cos_sim_results(init_indices, final_indices)
                torch.save(init_clip_similarities, path.join(metrics_folder, 'full_clip_ii.pt'))
                torch.save(final_clip_similarities, path.join(metrics_folder, 'full_clip_if.pt'))
                if do_validation:
                    val_init_clip_similarities, val_final_clip_similarities = clip_descriptions.get_overall_cos_sim_results(init_top_val_indices, final_top_val_indices)
                    torch.save(val_init_clip_similarities, path.join(metrics_folder, 'val_full_clip_ii.pt'))
                    torch.save(val_final_clip_similarities, path.join(metrics_folder, 'val_full_clip_if.pt'))
                print('using full embeddings')
                if use_basic_means:
                    print('using means of embeddings cos sim')
                    init_clip_similarities_to_use = init_clip_similarities.mean(dim=(2,3))
                    final_clip_similarities_to_use = final_clip_similarities.mean(dim=(2,3))
                else:
                    print('using means of the max similarity embeddings')
                    init_clip_similarities_to_use = init_clip_similarities.max(dim=3)[0].mean(dim=2)
                    final_clip_similarities_to_use = final_clip_similarities.max(dim=3)[0].mean(dim=2)
            else:
                # Averages the embeddings first and compares the averages.
                print('using mean of embeddings')
                init_clip_similarities_to_use, final_clip_similarities_to_use = clip_descriptions.get_overall_cos_sim_results_with_mean_embeddings(init_indices, final_indices)

            utils.save_matrix(init_clip_similarities_to_use, csv_folder, f'{layer}_Averaged_CosSim_ii')
            utils.save_matrix(final_clip_similarities_to_use, csv_folder, f'{layer}_Averaged_CosSim_if')

        # Generate the graphs for the losses, accuracy and norms.
        visualize_objectives_v3(
            attack_objective_values,
            maintain_values,
            accuracy,
            activation_norms,
            save_interval,
            filename=path.join(image_folder, f"{layer}_objectives_visual.jpg"),
            title=objectives_title,
            inline=False)

        visualize_objectives_v3(
            attack_objective_values,
            maintain_values,
            accuracy,
            activation_norms,
            save_interval,
            filename=path.join(upload_folder, f"{layer}_objectives_visual.pdf"),
            title=objectives_title,
            inline=False)

        # Save the raw data as a CSV.
        objectives_data = pd.DataFrame({"accuracy": accuracy,
                                        "attack_objective_values": attack_objective_values,
                                        "maintain_values": maintain_values,
                                        "activation norms": activation_norms,
                                        'alphas': alphas
                                        })

        objectives_data.to_csv(path.join(output_folder, f"{layer}_objectives_data.csv"))
        # this is just in case we want a copy to look at directly later
        objectives_data.to_csv(path.join(upload_folder, f"{layer}_objectives_data.csv"))

        # Fix: the summary below needs both KT and full CLIP results. Skip it
        # with a message instead of crashing on undefined names.
        if kt_dict is None or init_clip_similarities is None:
            print(f'layer: {layer} skipping KT/CLIP charts and run metrics '
                  '(requires both do_kt and do_clip with full embeddings)')
            continue

        kt_ii = kt_dict['cor_ii']
        kt_if = kt_dict['cor_if']
        # Collapse the two trailing axes once and reuse the result.
        clip_ii_mean = init_clip_similarities.mean(dim=(2,3))
        clip_if_mean = final_clip_similarities.mean(dim=(2,3))

        make_cos_sim_chart(clip_ii_mean,
                           clip_if_mean,
                           upload_folder+f'/{run_name}_{layer}_pruned_clip.pdf',
                           'Clip Comparisons')
        make_kt_chart(kt_ii,
                      kt_if,
                      upload_folder+f'/{run_name}_{layer}_pruned_kt.pdf',
                      'KT Comparisons')

        clip_d = calc_clip_d(init_clip_similarities, final_clip_similarities)
        kt_wam_scores, kt_wam_indices = get_whackamole_score_from_tensor(kt_if)
        clip_wam_scores = get_clip_wam_normalized(clip_ii_mean, clip_if_mean)
        self_kt = get_self_kt(kt_if)
        # Read the CSV written above to pick up the last recorded accuracy.
        obj_data = get_objectives_data(output_folder, f"{layer}_objectives_data.csv")

        #TODO: VERIFY THIS! AND ARTIFICIAL FOR ViT?
        results = pd.DataFrame(data = {
            'run name' : [run_name.replace('_', ' ')],
            'clip d mean' : [clip_d.mean().item()],
            'kt self sim' : [self_kt.mean().item()],
            'clip whack-a-mole' : [clip_wam_scores.mean().item()],
            'kt whack-a-mole' : [kt_wam_scores.mean().item()],
            'accuracy' : [obj_data['accuracy'].iloc[-1]],
        })
        results.to_csv(f'{upload_folder}/{layer}_run_metrics.csv')

def generate_metrics(results_directory, do_kt, do_clip, do_validation, do_kt_overwrite, default_baseline_alexnet, graph_output_folder = None):
    if not path.exists(results_directory):
        print(f'ERROR, the results directory {results_directory} does not exist!')
        print('Exiting program')
        exit(0)

    with open(f'{results_directory}/results_dict.pkl', 'rb') as f:
        results_dict = pickle.load(f)

    # for key in results_dict.keys():
    #     print(f'dict has key: {key}')

    config_dict = utils.unpack_config_file_into_dict(results_directory)
    #  {}
    # with open(path.join(results_directory, "configuration.txt")) as f:
    #     for line in f:
    #         (key, val) = line.split(':')
    #         config_dict[key] = val.strip()

    #TODO alter code to use csv and image folder for intermediary steps!
    output_folder = results_directory + '/results'
    run_name = results_directory.split('/')[-1]
    csv_folder = output_folder+'/csv'
    pdf_folder = output_folder + '/PDFs_' + run_name
    metrics_folder = output_folder + '/metrics'
    image_folder = output_folder+'/images'
    
    upload_folder = metrics_folder+f'/{run_name}_visualizations/graphs'
    

    model_arch = config_dict['arch']
    print(f'using {model_arch}')
    # init_model, final_model = utils.get_init_final_models(results_directory,layer=None, model_arch=model_arch)
    layers = utils.get_tuple_from_config_dict(config_dict, 'layer')
    print(f'sending {model_arch}')
    if default_baseline_alexnet:
        print('using baseline alexnet')
        device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
        init_model = models.alexnet(weights=models.AlexNet_Weights.IMAGENET1K_V1).to(device)
        final_model = utils.get_init_final_models(results_directory,layer = layers[0],#.strip("\'"),
                                                        model_arch = model_arch, get_final_model=False)
    else:
        init_model, final_model = utils.get_init_final_models(results_directory, #layer = layers[0],#.strip("\'"),
                                                        model_arch = model_arch,)# get_final_model=True)
    ### To device?
    # def ensure_dir(directory):
    #     if not os.path.exists(directory):
    #         os.makedirs(directory)

    utils.ensure_dir(output_folder)
    utils.ensure_dir(pdf_folder)
    utils.ensure_dir(csv_folder)
    utils.ensure_dir(image_folder)
    utils.ensure_dir(metrics_folder)
    utils.ensure_dir(upload_folder)
    # Use the configuration file to get information about the run
    # Allows for easier organization of the following code.

    # Generate a line plot of the results

    # unpack some values for run metrics and info
    attack_objective_values = results_dict['attack_obj_vals']
    maintain_values = results_dict['maintain_obj_vals']
    accuracy = results_dict['accuracy']
    activation_norms = results_dict['activation_norms']
    alphas = results_dict['alphas']

    shutil.copyfile(path.join(results_directory, "configuration.txt"), path.join(pdf_folder, "configuration.txt"))

    # print(d['channel'])
    # In particular we need to see what channels are present in the run, and make a list of them
    if model_arch == 'vit_b32':
        channels = utils.get_tuple_from_config_dict(config_dict, 'features')
        # print(channels)
        if channels[0] == "'a'":
            channels = list(range(768))
    else:
        channels = utils.get_tuple_from_config_dict(config_dict, 'channel')
        channels = [int(i) for i in channels]
    # get kt results dict
    all_steps = [i for i in
                range(0, int(config_dict['nsteps']) + 1, int(config_dict['save_interval']))]

    #Generate and store the activations and return the paths to them
    init_acts_path, final_acts_path = get_imagenet_activations_by_channel_v2.get_and_sort_activations_v2(results_directory, init_model, final_model, model_name = model_arch)
    if do_validation:
        init_val_acts_path, final_val_acts_path = get_imagenet_activations_by_channel_v2.get_and_sort_activations_v2(results_directory, init_model, final_model, model_name=model_arch, data_loader=utils.get_val_topk_dataset_loader(), name = 'val')
        init_top_val_norms, init_top_val_indices = torch.topk(init_val_acts, k=10, dim=1)
        final_top_val_norms, final_top_val_indices = torch.topk(final_val_acts, k=10, dim=1)


        torch.save(init_top_val_indices,path.join(metrics_folder, 'init_top_indices_val.pt'))
        torch.save(final_top_val_indices,path.join(metrics_folder, 'final_top_indices_val.pt'))
        init_top_val_indices = init_top_val_indices.T
        final_top_val_indices = final_top_val_indices.T
    # '''
    # Need to add logic for getting top image indices across top validation activations
    # '''
    
        init_val_acts, final_val_acts =  load_in_activations_from_paths(init_val_acts_path, final_val_acts_path)

    init_acts, final_acts =  load_in_activations_from_paths(init_acts_path, final_acts_path)
    init_top_norms, init_top_indices = torch.topk(init_acts, k=10, dim=1)
    final_top_norms, final_top_indices = torch.topk(final_acts, k=10, dim=1)

    torch.save(init_top_indices, path.join(metrics_folder, 'init_top_indices.pt'))
    torch.save(final_top_indices, path.join(metrics_folder, 'final_top_indices.pt'))
    #Confirm the same top indices
    # init_top_norms, test_init_top_indices = torch.topk(init_acts, k=10, dim=1)
    # res = results_dict['init_top_indices']-test_init_top_indices.T
    # print(res)

    # init_top_norms, test_final_top_indices = torch.topk(final_acts, k=10, dim=1)
    # res = results_dict['final_top_indices']-test_final_top_indices.T
    # print(res)
    # print('validation activation shape:')
    # print(init_val_acts.shape)
    
    
    ###Time to overwrite the top indices from the results dict with the new ones mwahahahahahahahahahahahaha
    #init_activations, final_activations = load_in_activations_from_paths(init_activations_path, final_activations_path)
    #TODO: Rewrite this whole thing. Use more functions and less script. Possibly split this whole .py into multiple scripts
    if do_kt:
        kt_dict = get_overall_kt_results(output_folder, init_acts_path, final_acts_path, do_kt_overwrite)
        torch.save(kt_dict['cor_ii'], path.join(metrics_folder,'kt_ii.pt'))
        torch.save(kt_dict['cor_if'], path.join(metrics_folder,'kt_if.pt'))
        if do_validation:
            val_kt_dict = get_overall_kt_results(output_folder, init_val_acts_path, final_val_acts_path, do_kt_overwrite)
            torch.save(val_kt_dict['cor_ii'], path.join(metrics_folder,'val_kt_ii.pt'))
            torch.save(val_kt_dict['cor_if'], path.join(metrics_folder,'val_kt_if.pt'))
    init_norms_for_test, init_indices_for_test = torch.topk(init_acts, k=10, dim=1)
    final_norms_for_test, final_indices_for_test = torch.topk(final_acts, k=10, dim=1)
    #Instead of taking the top from the saved indices from training, use the ones we calc'd in this script
    results_dict['init_top_indices'] = init_indices_for_test.T
    results_dict['final_top_indices'] = final_indices_for_test.T

    if do_clip:
        use_full_embeddings = True
        
        with open(f'{results_directory}/results_dict.pkl', 'rb') as f: #HERE
            results_dict = pickle.load(f)

        init_indices = results_dict['init_top_indices']
        final_indices = results_dict['final_top_indices']

        if use_full_embeddings:
            use_basic_means = True
            init_clip_similarities, final_clip_similarities = clip_descriptions.get_overall_cos_sim_results(init_indices, final_indices)
            torch.save(init_clip_similarities, path.join(metrics_folder,'full_clip_ii.pt'))
            torch.save(final_clip_similarities, path.join(metrics_folder,'full_clip_if.pt'))
            if do_validation:
                val_init_clip_similarities, val_final_clip_similarities = clip_descriptions.get_overall_cos_sim_results(init_top_val_indices, final_top_val_indices)
                torch.save(val_init_clip_similarities, path.join(metrics_folder,'val_full_clip_ii.pt'))
                torch.save(val_final_clip_similarities, path.join(metrics_folder,'val_full_clip_if.pt'))
            print('using full embeddings')
            if use_basic_means:
                print('using means of embeddings cos sim')
                clip_description = 'Mean Similarity'
                init_clip_similarities_to_use = init_clip_similarities.mean(dim=(2,3))
                final_clip_similarities_to_use = final_clip_similarities.mean(dim=(2,3))
                #print(f'clip similarities to use shape: {init_clip_similarities_to_use}')

            else:
                print('using means of the max similarity embeddings')
                clip_description = 'Max Similarity'
                init_clip_similarities_to_use = init_clip_similarities.max(dim=3)[0].mean(dim=2)
                final_clip_similarities_to_use = final_clip_similarities.max(dim=3)[0].mean(dim=2)

        #Takes the mean of the embeddings and gets the 256,256 results to use
        else:
            print('using mean of embeddings')
            clip_description = 'Averaged Embeddings'

            init_clip_similarities_to_use, final_clip_similarities_to_use = clip_descriptions.get_overall_cos_sim_results_with_mean_embeddings(init_indices, final_indices)
        # print('Similiarites to use shapes:')
        # print(init_clip_similarities_to_use.shape)
        # print(final_clip_similarities_to_use.shape)
        utils.save_matrix(init_clip_similarities_to_use, csv_folder, 'Averaged_CosSim_ii')
        utils.save_matrix(final_clip_similarities_to_use, csv_folder, 'Averaged_CosSim_if')



    data_dict = {
        'clip_sims_ii': init_clip_similarities,
        'clip_sims_if': final_clip_similarities,
        'kt_ii': kt_dict['cor_ii'],
        'kt_if': kt_dict['cor_if'],
    }
    # Generate the graphs for the losses, accuracy and norms
    visualize_objectives_v3(
        attack_objective_values,
        maintain_values,
        accuracy,
        activation_norms,
        int(config_dict['save_interval']),
        filename=path.join(image_folder, "objectives_visual.jpg"),
        title=f"lr: {config_dict['lr']}, alpha: {config_dict['alpha']}, optim: {config_dict['optimizer_type']}, maintain obj: {config_dict['maintain_obj']}",
        inline=False)

    visualize_objectives_v3(
        attack_objective_values,
        maintain_values,
        accuracy,
        activation_norms,
        int(config_dict['save_interval']),
        filename=path.join(upload_folder, "objectives_visual.pdf"),
        title=f"lr: {config_dict['lr']}, alpha: {config_dict['alpha']}, optim: {config_dict['optimizer_type']}, maintain obj: {config_dict['maintain_obj']}",
        inline=False)

    # Save the raw data as a CSV
    objectives_data = pd.DataFrame({"accuracy": accuracy,
                                    "attack_objective_values": attack_objective_values,
                                    "maintain_values": maintain_values,
                                    "activation norms": activation_norms,
                                    'alphas': alphas
                                    })

    objectives_data.to_csv(path.join(output_folder, "objectives_data.csv"))
    # this is just in case we want a copy to look at directly later
    objectives_data.to_csv(path.join(upload_folder, "objectives_data.csv"))

    make_cos_sim_chart(data_dict['clip_sims_ii'].mean(dim=(2,3)), 
                                    data_dict['clip_sims_if'].mean(dim=(2,3)), 
                                    upload_folder+f'/{run_name}_pruned_clip.pdf',#HERE
                                    'Clip Comparisons'
                                    )
    make_kt_chart(data_dict['kt_ii'], 
                            data_dict['kt_if'], 
                            upload_folder+f'/{run_name}_pruned_kt.pdf', #HERE
                            'KT Comparisons'
                            )
    clip_d = calc_clip_d(data_dict['clip_sims_ii'], data_dict['clip_sims_if'])
    #print(clip_d.shape)
    #print(f'clip d mean: {clip_d_mean}')
    kt_wam_scores, kt_wam_indices = get_whackamole_score_from_tensor(data_dict['kt_if'])
    #print(kt_wam_scores.shape)
    #print(kt_wam_mean)
    #clip_wam_scores, clip_wam_indices = get_whackamole_score_from_tensor(data_dict['clip_sims_if'].mean(dim=(2,3)))
    clip_wam_scores = get_clip_wam_normalized(data_dict['clip_sims_ii'].mean(dim=(2,3)), data_dict['clip_sims_if'].mean(dim=(2,3)))
    #print(f'clip wam scores shape: {clip_wam_scores.shape}')
    self_kt = get_self_kt(data_dict['kt_if'])
    obj_data = get_objectives_data(output_folder)
    #print(obj_data)
    #print(accuracy)

    # run_names.append(run_name.replace('_', ' '))
    # clip_wams.append(clip_wam_mean.item())
    # clip_d_means.append(clip_d_mean.item())
    # kt_wams.append(kt_wam_mean.item())
    # accuracies.append(accuracy)
    # kt_self_sims.append(self_kt_mean.item())
    # obj_datas.append(obj_data)


    #TODO: VERIFY THIS! AND ARTIFICIAL FOR ViT?
    results = pd.DataFrame(data = {
        'run name' : [run_name.replace('_', ' ')],     
        'clip d mean' : [clip_d.mean().item()],
        'kt self sim' : [self_kt.mean().item()],
        'clip whack-a-mole' : [clip_wam_scores.mean().item()],
        'kt whack-a-mole' : [kt_wam_scores.mean().item()],
        
        'accuracy' : [obj_data['accuracy'].iloc[-1]],
        
    }, #index = run_names
    )
    results.to_csv(f'{upload_folder}/run_metrics.csv')#HERE
    #print(f'{args.results_title}_run_metrics.csv')



def _load_cached_metric_pair(metrics_folder, ii_name, if_name, hint):
    """Load an (init-init, init-final) tensor pair previously saved to disk.

    Args:
        metrics_folder: Directory the tensors were saved into with torch.save.
        ii_name: File name of the init-vs-init tensor.
        if_name: File name of the init-vs-final tensor.
        hint: Short text appended to the error message telling the caller how
            to produce the missing files.

    Returns:
        Tuple ``(ii_tensor, if_tensor)`` as returned by ``torch.load``.

    Raises:
        FileNotFoundError: If either file is absent.
    """
    ii_path = path.join(metrics_folder, ii_name)
    if_path = path.join(metrics_folder, if_name)
    missing = [p for p in (ii_path, if_path) if not path.exists(p)]
    if missing:
        raise FileNotFoundError(
            f'Cached metrics not found: {missing}. {hint}')
    return torch.load(ii_path), torch.load(if_path)


def generate_metrics_for_pruned_net(net_directory,
                                    model_arch,
                                    layer,
                                    run_name,
                                    num_pruned,
                                    do_kt, 
                                    do_clip, 
                                    do_validation, 
                                    do_kt_overwrite, 
                                    default_baseline_alexnet=True,
                                    graph_output_folder = None):
    """Compute and save KT / CLIP comparison metrics for a pruned, fine-tuned net.

    The "init" model is the baseline and the "final" model is the pruned and
    fine-tuned one. Activations are extracted for both, the top-10 indices
    along dim 1 are saved, and (depending on the flags) Kendall-tau and CLIP
    similarity tensors are computed and saved under
    ``<net_directory>/<layer>_l2_pruned_<num_pruned>_pruned_alexnet/results``.
    Finally the objective plots, ``objectives_data.csv`` and
    ``run_metrics.csv`` are written.

    Args:
        net_directory: Directory holding the pruned run's
            ``*_results_dict.pt`` and ``*_finetuned_state_dict.pt`` files.
        model_arch: Architecture name, forwarded to the activation extraction
            and (when not using the baseline AlexNet) to
            ``utils.get_init_final_models``.
        layer: Layer identifier; part of the input/output file names.
        run_name: Name used for output sub-folders, chart file names and the
            'run name' column of ``run_metrics.csv``.
        num_pruned: Number of pruned units; part of the input/output file names.
        do_kt: Compute the KT results. When False, previously saved
            ``kt_ii.pt`` / ``kt_if.pt`` in the metrics folder are used instead.
        do_clip: Compute the CLIP similarities. When False, previously saved
            ``full_clip_ii.pt`` / ``full_clip_if.pt`` are used instead.
        do_validation: Also compute the validation-set variant of whichever
            of KT / CLIP is being computed.
        do_kt_overwrite: Forwarded to ``get_overall_kt_results``.
        default_baseline_alexnet: When True, compare torchvision's pretrained
            AlexNet against the fine-tuned pruned state dict; otherwise the
            models come from ``utils.get_init_final_models``.
        graph_output_folder: Currently unused; kept for interface compatibility.

    Returns:
        None. All results are written to disk.

    Raises:
        SystemExit: With status 1 if ``net_directory`` does not exist.
        FileNotFoundError: If KT or CLIP computation is skipped and no cached
            tensors are available in the metrics folder.
    """
    if not path.exists(net_directory):
        print(f'ERROR, the results directory {net_directory} does not exist!')
        print('Exiting program')
        # Exit with a non-zero status so calling scripts can detect the failure.
        raise SystemExit(1)

    # Common prefix of every artefact belonging to this pruned run.
    run_prefix = f'{net_directory}/{layer}_l2_pruned_{num_pruned}_pruned_alexnet'
    results_dict = torch.load(f'{run_prefix}_results_dict.pt')

    #TODO alter code to use csv and image folder for intermediary steps!
    output_folder = run_prefix + '/results'
    csv_folder = output_folder+'/csv'
    pdf_folder = output_folder + '/PDFs_' + run_name
    metrics_folder = output_folder + '/metrics'
    image_folder = output_folder+'/images'
    upload_folder = metrics_folder+f'/{run_name}_visualizations/graphs'

    print(f'sending {model_arch}')
    if default_baseline_alexnet:
        print('using baseline alexnet')
        # Fall back to CPU when CUDA is unavailable instead of hard-coding 'cuda'.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        init_model = models.alexnet(weights=models.AlexNet_Weights.IMAGENET1K_V1)
        init_model.to(device)
        final_model = models.alexnet()
        # Load onto CPU first so checkpoints saved on a GPU also load on
        # CPU-only hosts; the model is moved to `device` right after.
        final_state = torch.load(f'{run_prefix}_finetuned_state_dict.pt', map_location='cpu')
        final_model.load_state_dict(final_state)
        final_model.to(device)
    else:
        init_model, final_model = utils.get_init_final_models(net_directory, layer,
                                                              model_arch = model_arch, get_final_model=True)

    for folder in (output_folder, pdf_folder, csv_folder, image_folder, metrics_folder, upload_folder):
        os.makedirs(folder, exist_ok=True)

    # unpack some values for run metrics and info
    maintain_values = results_dict['loss']
    accuracy = results_dict['accuracies']

    #Generate and store the activations and return the paths to them
    init_acts_path, final_acts_path = get_imagenet_activations_by_channel_v2.get_and_sort_activations_v2(
        net_directory, init_model, final_model, model_name = model_arch, is_pruned_net=True)

    init_val_acts_path, final_val_acts_path = get_imagenet_activations_by_channel_v2.get_and_sort_activations_v2(
        net_directory, init_model, final_model, model_name=model_arch, data_loader=utils.get_val_topk_dataset_loader(),
        name = 'val', is_pruned_net=True)

    init_val_acts, final_val_acts = load_in_activations_from_paths(init_val_acts_path, final_val_acts_path)
    init_acts, final_acts = load_in_activations_from_paths(init_acts_path, final_acts_path)

    # Top-10 entries along dim 1 of each activation tensor.
    # NOTE(review): presumably rows are channels and columns are images - confirm.
    _, init_top_indices = torch.topk(init_acts, k=10, dim=1)
    _, final_top_indices = torch.topk(final_acts, k=10, dim=1)
    torch.save(init_top_indices, path.join(metrics_folder, 'init_top_indices.pt'))
    torch.save(final_top_indices, path.join(metrics_folder, 'final_top_indices.pt'))

    _, init_top_val_indices = torch.topk(init_val_acts, k=10, dim=1)
    _, final_top_val_indices = torch.topk(final_val_acts, k=10, dim=1)
    torch.save(init_top_val_indices, path.join(metrics_folder, 'init_top_indices_val.pt'))
    torch.save(final_top_val_indices, path.join(metrics_folder, 'final_top_indices_val.pt'))
    # The CLIP helper is fed the transposed index tensors (saved files stay untransposed).
    init_top_val_indices = init_top_val_indices.T
    final_top_val_indices = final_top_val_indices.T

    #TODO: Rewrite this whole thing. Use more functions and less script. Possibly split this whole .py into multiple scripts
    kt_ii = kt_if = None
    if do_kt:
        print('Doing KT!')
        kt_dict = get_overall_kt_results(output_folder, init_acts_path, final_acts_path, do_kt_overwrite)
        kt_ii, kt_if = kt_dict['cor_ii'], kt_dict['cor_if']
        torch.save(kt_ii, path.join(metrics_folder,'kt_ii.pt'))
        torch.save(kt_if, path.join(metrics_folder,'kt_if.pt'))
        if do_validation:
            val_kt_dict = get_overall_kt_results(output_folder, init_val_acts_path, final_val_acts_path, do_kt_overwrite)
            torch.save(val_kt_dict['cor_ii'], path.join(metrics_folder,'val_kt_ii.pt'))
            torch.save(val_kt_dict['cor_if'], path.join(metrics_folder,'val_kt_if.pt'))

    #Instead of taking the top from the saved indices from training, use the ones we calc'd in this script
    results_dict['init_top_indices'] = init_top_indices.T
    results_dict['final_top_indices'] = final_top_indices.T

    init_clip_similarities = final_clip_similarities = None
    if do_clip:
        print('Doing Clip!')
        print('clip path: ', path.join(metrics_folder,'full_clip_ii.pt'))
        # Developer toggles selecting how the CSV similarity matrices are reduced.
        use_full_embeddings = True
        use_basic_means = True

        init_indices = init_top_indices.T
        final_indices = final_top_indices.T

        if use_full_embeddings:
            init_clip_similarities, final_clip_similarities = clip_descriptions.get_overall_cos_sim_results(init_indices, final_indices)
            torch.save(init_clip_similarities, path.join(metrics_folder,'full_clip_ii.pt'))
            torch.save(final_clip_similarities, path.join(metrics_folder,'full_clip_if.pt'))
            if do_validation:
                val_init_clip_similarities, val_final_clip_similarities = clip_descriptions.get_overall_cos_sim_results(init_top_val_indices, final_top_val_indices)
                torch.save(val_init_clip_similarities, path.join(metrics_folder,'val_full_clip_ii.pt'))
                torch.save(val_final_clip_similarities, path.join(metrics_folder,'val_full_clip_if.pt'))
            print('using full embeddings')
            if use_basic_means:
                print('using means of embeddings cos sim')
                init_clip_similarities_to_use = init_clip_similarities.mean(dim=(2,3))
                final_clip_similarities_to_use = final_clip_similarities.mean(dim=(2,3))
            else:
                print('using means of the max similarity embeddings')
                init_clip_similarities_to_use = init_clip_similarities.max(dim=3)[0].mean(dim=2)
                final_clip_similarities_to_use = final_clip_similarities.max(dim=3)[0].mean(dim=2)
        #Takes the mean of the embeddings and gets the 256,256 results to use
        else:
            print('using mean of embeddings')
            init_clip_similarities_to_use, final_clip_similarities_to_use = clip_descriptions.get_overall_cos_sim_results_with_mean_embeddings(init_indices, final_indices)

        utils.save_matrix(init_clip_similarities_to_use, csv_folder, 'Averaged_CosSim_ii')
        utils.save_matrix(final_clip_similarities_to_use, csv_folder, 'Averaged_CosSim_if')

    # The summary below needs both metric families. When a computation was
    # skipped, reuse the tensors an earlier run saved instead of failing with
    # a NameError at the very end.
    if kt_ii is None:
        kt_ii, kt_if = _load_cached_metric_pair(
            metrics_folder, 'kt_ii.pt', 'kt_if.pt',
            'Re-run with do_kt=True to compute them.')
    if init_clip_similarities is None:
        init_clip_similarities, final_clip_similarities = _load_cached_metric_pair(
            metrics_folder, 'full_clip_ii.pt', 'full_clip_if.pt',
            'Re-run with do_clip=True (full embeddings) to compute them.')

    # Generate the graphs for the losses and accuracy (one image copy, one PDF copy).
    # NOTE(review): the positional 1 presumably stands in for the save interval - confirm.
    objectives_title = f"lr: {1e-5}, alpha: {0}, optim: Adam, maintain obj: CE"
    for objectives_file in (path.join(image_folder, "objectives_visual.jpg"),
                            path.join(upload_folder, "objectives_visual.pdf")):
        visualize_objectives_v4(
            maintain_values,
            accuracy,
            1,
            filename=objectives_file,
            title=objectives_title,
            inline=False)

    # Save the raw data as a CSV
    objectives_data = pd.DataFrame({"accuracy": accuracy,
                                    "maintain_values": maintain_values,
                                    })
    objectives_data.to_csv(path.join(output_folder, "objectives_data.csv"))
    # this is just in case we want a copy to look at directly later
    objectives_data.to_csv(path.join(upload_folder, "objectives_data.csv"))

    # Reduce the last two dims of the CLIP similarities once and reuse the result.
    clip_mean_ii = init_clip_similarities.mean(dim=(2,3))
    clip_mean_if = final_clip_similarities.mean(dim=(2,3))

    make_cos_sim_chart(clip_mean_ii,
                       clip_mean_if,
                       upload_folder+f'/{run_name}_pruned_clip.pdf',
                       'Clip Comparisons'
                       )
    make_kt_chart(kt_ii,
                  kt_if,
                  upload_folder+f'/{run_name}_pruned_kt.pdf',
                  'KT Comparisons'
                  )

    # calc_clip_d averages over dims (2, 3) itself, so it gets the unreduced tensors.
    clip_d = calc_clip_d(init_clip_similarities, final_clip_similarities)
    kt_wam_scores, kt_wam_indices = get_whackamole_score_from_tensor(kt_if)
    clip_wam_scores = get_clip_wam_normalized(clip_mean_ii, clip_mean_if)
    self_kt = get_self_kt(kt_if)
    obj_data = get_objectives_data(output_folder)

    #TODO: VERIFY THIS!AND ARTIFICIAL FOR ViT?
    results = pd.DataFrame(data = {
        'run name' : [run_name.replace('_', ' ')],
        'clip d mean' : [clip_d.mean().item()],
        'kt self sim' : [self_kt.mean().item()],
        'clip whack-a-mole' : [clip_wam_scores.mean().item()],
        'kt whack-a-mole' : [kt_wam_scores.mean().item()],
        'accuracy' : [obj_data['accuracy'].iloc[-1]],
    })
    results.to_csv(f'{upload_folder}/run_metrics.csv')


if __name__ =='__main__':
    # Command-line entry point: parse the flags and dispatch to the
    # single-layer or multi-layer metrics generator.
    parser = ArgumentParser()
    # The generators cannot do anything without a run directory, so reject a
    # missing value at parse time rather than passing None further down.
    parser.add_argument("--results-directory", type=str, required=True)
    parser.add_argument("--do-kt-overwrite", action='store_true')
    parser.add_argument("--do-clip", action='store_true')
    parser.add_argument("--do-kt", action='store_true')
    parser.add_argument("--do-validation", action='store_true')
    parser.add_argument("--default-baseline-alexnet", action='store_true')
    parser.add_argument("--is-multilayer", action='store_true') 
    args = parser.parse_args()
    if args.is_multilayer:
        generate_metrics_for_multilayer_attack(args.results_directory, 
                                               args.do_kt, args.do_clip, args.do_validation, args.do_kt_overwrite, args.default_baseline_alexnet)
    else:        
        generate_metrics(args.results_directory, args.do_kt, args.do_clip, args.do_validation, args.do_kt_overwrite, args.default_baseline_alexnet)