import os.path as path
import os
# from reportlab.platypus import SimpleDocTemplate, Image, Paragraph, Spacer
# from reportlab.lib.pagesizes import A4, letter
# from PIL import Image as PImage
# from reportlab.lib.styles import ParagraphStyle
# import torchvision.transforms as transforms
import torch
import pickle
import pandas as pd
import shutil
import numpy as np
# from optimization_requirements import visualize_objectives_v3
# from torchvision import datasets, models, transforms

import clip_descriptions
import utils
# from tqdm import tqdm
import matplotlib.pyplot as plt
from analyze_activations_by_channel import get_overall_kt_results, load_in_activations_from_paths , get_kt_from_vectors
# from PIL import Image
#import get_imagenet_activations_by_channel
import get_imagenet_activations_by_channel_v2
# import matplotlib.image as mpimg

# from scripts.get_imagenet_activations_by_channel import get_and_sort_activations
# import torch.nn.functional as F

def visualize_objectives_v3(objective_values, maintain_values, accuracy, activation_norms, save_interval,
                            filename=None, title="<insert-title>", inline=False):
    """Plot losses, accuracy and activation norm against the training step.

    Builds a figure with two stacked panels, both plotted against step:
      * top ("Losses"): attack loss on the left axis (red) and maintain loss
        on a twin right axis (blue);
      * bottom ("Accuracy and Activation Norm"): accuracy on the left axis
        (blue) and the relative activation norm on a twin right axis (red).

    Args:
        objective_values: Attack-loss values, one per saved step.
        maintain_values: Maintain-loss values, one per saved step.
        accuracy: Accuracy values, one per saved step.
        activation_norms: Relative activation-norm values, one per saved step.
        save_interval: Steps between consecutive saved values; the x
            coordinates are ``save_interval * [0, 1, ..., len(objective_values) - 1]``.
        filename: If not None, the figure is saved to this path.
        title: Figure-level title.
        inline: If True, also display the figure with ``plt.show()``.
    """
    # x coordinates shared by every series.
    steps = save_interval * np.arange(0, len(objective_values))

    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 12))
    fig.suptitle(title, fontsize=14)

    # Top panel: attack loss (left) and maintain loss (right twin axis).
    ax1.set_title('Losses')
    ax1.set_xlabel("step", fontsize=14)
    ax1.plot(steps, objective_values, label="Attack Loss", color='red', marker='.')
    ax1.set_ylabel("attack loss", color="red", fontsize=14)

    ax1_1 = ax1.twinx()
    ax1_1.plot(steps, maintain_values, label="Maintain Loss", color='blue', marker='.')
    ax1_1.set_ylabel("maintain loss", color="blue", fontsize=14)

    # Bottom panel: accuracy (left) and relative activation norm (right twin axis).
    ax2.set_title('Accuracy and Activation Norm')
    # The x-label must go on ax2: a twinx() axis has an invisible x-axis, so a
    # label set on the twin would never be shown.
    ax2.set_xlabel("step", fontsize=14)
    ax2.plot(steps, accuracy, label='Accuracy', color='blue', marker='.')
    ax2.set_ylabel("accuracy", color='blue', fontsize=14)

    ax2_1 = ax2.twinx()
    ax2_1.plot(steps, activation_norms, label='Step Activation Relative Norm', color='red', marker='.')
    ax2_1.set_ylabel("relative norm", color='red', fontsize=14)

    if filename is not None:
        fig.savefig(filename)

    if inline:
        plt.show()
    # Close this specific figure so repeated calls do not accumulate open figures.
    plt.close(fig)

def _ensure_dir(directory):
    """Create *directory* (and any missing parents) if it does not already exist."""
    os.makedirs(directory, exist_ok=True)


def _read_configuration(results_directory):
    """Parse ``configuration.txt`` in *results_directory* into a ``{key: value}`` dict.

    Each non-blank line must look like ``key: value``. Only the first ``':'``
    separates key from value, so values that themselves contain colons are kept
    intact. Values are whitespace-stripped; keys are kept exactly as written.

    Raises:
        ValueError: If a non-blank line contains no ``':'``.
    """
    config_dict = {}
    with open(path.join(results_directory, "configuration.txt")) as f:
        for line in f:
            if not line.strip():
                continue
            key, val = line.split(':', 1)
            config_dict[key] = val.strip()
    return config_dict


def _compute_clip_metrics(init_indices, final_indices, init_val_indices, final_val_indices,
                          metrics_folder, csv_folder, do_validation):
    """Compute CLIP-description similarity metrics for the given top-image indices and save them.

    Full similarity tensors are written to *metrics_folder* (``full_clip_ii.pt`` /
    ``full_clip_if.pt``, plus ``val_*`` variants when *do_validation*). Reduced
    matrices are written to *csv_folder* via ``utils.save_matrix``.
    """
    # Hard-coded switches choosing how the similarities are reduced.
    use_full_embeddings = True
    use_basic_means = True

    if use_full_embeddings:
        init_clip_similarities, final_clip_similarities = clip_descriptions.get_overall_cos_sim_results(
            init_indices, final_indices)
        torch.save(init_clip_similarities, path.join(metrics_folder, 'full_clip_ii.pt'))
        torch.save(final_clip_similarities, path.join(metrics_folder, 'full_clip_if.pt'))
        if do_validation:
            val_init_clip_similarities, val_final_clip_similarities = \
                clip_descriptions.get_overall_cos_sim_results(init_val_indices, final_val_indices)
            torch.save(val_init_clip_similarities, path.join(metrics_folder, 'val_full_clip_ii.pt'))
            torch.save(val_final_clip_similarities, path.join(metrics_folder, 'val_full_clip_if.pt'))
        print('using full embeddings')
        if use_basic_means:
            print('using means of embeddings cos sim')
            # Average over dims 2 and 3 of the full similarity tensor.
            init_clip_similarities_to_use = init_clip_similarities.mean(dim=(2, 3))
            final_clip_similarities_to_use = final_clip_similarities.mean(dim=(2, 3))
            print(f'clip similarities to use shape: {init_clip_similarities_to_use.shape}')
        else:
            print('using means of the max similarity embeddings')
            # Max over dim 3, then mean over dim 2.
            init_clip_similarities_to_use = init_clip_similarities.max(dim=3)[0].mean(dim=2)
            final_clip_similarities_to_use = final_clip_similarities.max(dim=3)[0].mean(dim=2)
    else:
        # Mean the embeddings first, then compare.
        print('using mean of embeddings')
        init_clip_similarities_to_use, final_clip_similarities_to_use = \
            clip_descriptions.get_overall_cos_sim_results_with_mean_embeddings(init_indices, final_indices)

    print('Similiarites to use shapes:')
    print(init_clip_similarities_to_use.shape)
    print(final_clip_similarities_to_use.shape)
    utils.save_matrix(init_clip_similarities_to_use, csv_folder, 'Averaged_CosSim_ii')
    utils.save_matrix(final_clip_similarities_to_use, csv_folder, 'Averaged_CosSim_if')


def generate_metrics(results_directory, do_kt, do_clip, do_validation, do_kt_overwrite):
    """Compute and save post-run metrics, plots and CSVs for one run directory.

    Reads ``results_dict.pkl`` and ``configuration.txt`` from *results_directory*.
    Loads the initial/final models via ``utils.get_init_final_models``, collects
    their activations on the default loader and on the validation top-k loader,
    and writes under ``<results_directory>/results``:

      * ``metrics/``: top-10 indices (``torch.topk(..., k=10, dim=1)``) for train and
        validation activations; optionally KT tensors (``kt_*.pt``, from
        ``get_overall_kt_results`` -- presumably Kendall tau) and CLIP
        similarity tensors (``full_clip_*.pt``);
      * ``csv/``: reduced CLIP similarity matrices (when *do_clip*);
      * ``images/`` and ``PDFs_<run_name>/``: objective/accuracy plots. The PDF
        folder also gets a copy of ``configuration.txt`` and the objectives CSV.

    Args:
        results_directory: Directory of a finished run.
        do_kt: Compute KT metrics.
        do_clip: Compute CLIP-description similarity metrics.
        do_validation: Also compute KT/CLIP metrics on the validation activations.
        do_kt_overwrite: Forwarded to ``get_overall_kt_results``.

    Raises:
        FileNotFoundError: If *results_directory* does not exist.
    """
    if not path.exists(results_directory):
        print(f'ERROR, the results directory {results_directory} does not exist!')
        print('Exiting program')
        # Explicit raise instead of ``assert False`` so the check survives ``python -O``.
        raise FileNotFoundError(results_directory)

    # NOTE: pickle can execute arbitrary code on load; only use on trusted run outputs.
    with open(path.join(results_directory, 'results_dict.pkl'), 'rb') as f:
        results_dict = pickle.load(f)

    for key in results_dict:
        print(f'dict has key: {key}')

    config_dict = _read_configuration(results_directory)

    # TODO alter code to use csv and image folder for intermediary steps!
    output_folder = path.join(results_directory, 'results')
    # normpath drops a trailing separator, so "runs/x/" still yields "x".
    run_name = path.basename(path.normpath(results_directory))
    csv_folder = path.join(output_folder, 'csv')
    pdf_folder = path.join(output_folder, 'PDFs_' + run_name)
    metrics_folder = path.join(output_folder, 'metrics')
    image_folder = path.join(output_folder, 'images')

    model_arch = config_dict['arch']
    print(f'using {model_arch}')
    init_model, final_model = utils.get_init_final_models(results_directory, model_arch)

    for folder in (output_folder, pdf_folder, csv_folder, image_folder, metrics_folder):
        _ensure_dir(folder)

    # Per-saved-step run metrics used for the plots and the CSV.
    attack_objective_values = results_dict['attack_obj_vals']
    maintain_values = results_dict['maintain_obj_vals']
    accuracy = results_dict['accuracy']
    activation_norms = results_dict['activation_norms']
    alphas = results_dict['alphas']

    shutil.copyfile(path.join(results_directory, "configuration.txt"), path.join(pdf_folder, "configuration.txt"))

    # Generate and store the activations; these calls return the paths to them.
    init_acts_path, final_acts_path = get_imagenet_activations_by_channel_v2.get_and_sort_activations_v2(
        results_directory, init_model, final_model, model_name=model_arch)
    init_val_acts_path, final_val_acts_path = get_imagenet_activations_by_channel_v2.get_and_sort_activations_v2(
        results_directory, init_model, final_model, model_name=model_arch,
        data_loader=utils.get_val_topk_dataset_loader(), name='val')

    init_val_acts, final_val_acts = load_in_activations_from_paths(init_val_acts_path, final_val_acts_path)
    init_acts, final_acts = load_in_activations_from_paths(init_acts_path, final_acts_path)

    # Top-10 indices along dim=1; only the indices are used.
    _, init_top_indices = torch.topk(init_acts, k=10, dim=1)
    _, final_top_indices = torch.topk(final_acts, k=10, dim=1)
    torch.save(init_top_indices, path.join(metrics_folder, 'init_top_indices.pt'))
    torch.save(final_top_indices, path.join(metrics_folder, 'final_top_indices.pt'))

    print('validation activation shape:')
    print(init_val_acts.shape)
    _, init_top_val_indices = torch.topk(init_val_acts, k=10, dim=1)
    _, final_top_val_indices = torch.topk(final_val_acts, k=10, dim=1)
    torch.save(init_top_val_indices, path.join(metrics_folder, 'init_top_indices_val.pt'))
    torch.save(final_top_val_indices, path.join(metrics_folder, 'final_top_indices_val.pt'))

    # TODO: Rewrite this whole thing. Use more functions and less script.
    if do_kt:
        kt_dict = get_overall_kt_results(output_folder, init_acts_path, final_acts_path, do_kt_overwrite)
        torch.save(kt_dict['cor_ii'], path.join(metrics_folder, 'kt_ii.pt'))
        torch.save(kt_dict['cor_if'], path.join(metrics_folder, 'kt_if.pt'))
        if do_validation:
            val_kt_dict = get_overall_kt_results(output_folder, init_val_acts_path, final_val_acts_path,
                                                 do_kt_overwrite)
            torch.save(val_kt_dict['cor_ii'], path.join(metrics_folder, 'val_kt_ii.pt'))
            torch.save(val_kt_dict['cor_if'], path.join(metrics_folder, 'val_kt_if.pt'))

    # Use the top indices computed in this script (transposed) instead of the
    # ones saved during training. results_dict is deliberately NOT re-loaded
    # below, since that would discard these.
    results_dict['init_top_indices'] = init_top_indices.T
    results_dict['final_top_indices'] = final_top_indices.T

    if do_clip:
        _compute_clip_metrics(results_dict['init_top_indices'], results_dict['final_top_indices'],
                              init_top_val_indices.T, final_top_val_indices.T,
                              metrics_folder, csv_folder, do_validation)

    # Plot losses, accuracy and norms (JPG for quick viewing, PDF for the report folder).
    save_interval = int(config_dict['save_interval'])
    title = (f"lr: {config_dict['lr']}, alpha: {config_dict['alpha']}, "
             f"optim: {config_dict['optimizer_type']}, maintain obj: {config_dict['maintain_obj']}")
    for plot_path in (path.join(image_folder, "objectives_visual.jpg"),
                      path.join(pdf_folder, "objectives_visual.pdf")):
        visualize_objectives_v3(
            attack_objective_values,
            maintain_values,
            accuracy,
            activation_norms,
            save_interval,
            filename=plot_path,
            title=title,
            inline=False)

    # Save the raw data as a CSV.
    objectives_data = pd.DataFrame({"accuracy": accuracy,
                                    "attack_objective_values": attack_objective_values,
                                    "maintain_values": maintain_values,
                                    "activation norms": activation_norms,
                                    'alphas': alphas,
                                    })

    objectives_data.to_csv(path.join(output_folder, "objectives_data.csv"))
    # Extra copy next to the PDFs so the report folder is self-contained.
    objectives_data.to_csv(path.join(pdf_folder, "objectives_data.csv"))