from argparse import ArgumentParser
import os.path as path
import os
# from reportlab.platypus import SimpleDocTemplate, Image, Paragraph, Spacer
# from reportlab.lib.pagesizes import A4, letter
# from PIL import Image as PImage
# from reportlab.lib.styles import ParagraphStyle
from torchvision.utils import save_image
import torchvision.transforms as transforms
import torch
import pickle
import pandas as pd
import shutil
from optimization_requirements import visualize_objectives_v3, save_top_class_info, get_class_labels_dict
from torchvision import datasets, models, transforms

import clip_descriptions
import utils
from tqdm import tqdm
import matplotlib.pyplot as plt
from scripts.analyze_activations_by_channel import get_overall_kt_results, load_in_activations_from_paths , get_kt_from_vectors
from PIL import Image
#import get_imagenet_activations_by_channel
import get_imagenet_activations_by_channel_v2
import matplotlib.image as mpimg

from scripts.get_imagenet_activations_by_channel import get_and_sort_activations
import torch.nn.functional as F

#TODO: Scrap this script and do a total rewrite. This has grown way beyond what it should ever have become

# Command-line interface. Each entry is (flag, keyword options for
# add_argument); the flags are registered in the order listed here.
_CLI_OPTION_SPECS = (
    ("--results-directory", dict(type=str)),
    ("--num-img-to-save", dict(type=int, default=10)),
    ("--num-top-classes", dict(type=int, default=6)),
    ("--track-image-rankings", dict(action='store_true')),
    ("--num-indices-to-track", dict(type=int, default=10)),
    ("--do-kt", dict(action='store_true')),
    ("--track-intermediate-steps", dict(action="store_true")),
    ("--do-overwrite", dict(action='store_true')),
    ("--do-kt-overwrite", dict(action='store_true')),
    ("--do-clip", dict(action='store_true')),
    ("--do-validation", dict(action='store_true')),
    ("--do-pdfs", dict(action='store_true')),
)

parser = ArgumentParser()
for _flag, _options in _CLI_OPTION_SPECS:
    parser.add_argument(_flag, **_options)
args = parser.parse_args()
# Directory holding results_dict.pkl and configuration.txt for the run.
results_directory = args.results_directory

VERSION = '0.2.0'
print('Generating results pdf from collected run results')
print(f'VERSION: {VERSION}')
print(f'Torch cuda version: {torch.version.cuda}')
# Keys of results_dict used by this script:
#   results_dict['final_top_indices']
#   results_dict['init_top_indices']
#
#   results_dict['attack_obj_vals']
#   results_dict['accuracy']
#   results_dict['maintain_obj_vals']
#   results_dict['activation_norms']
#
# Note that there will be either an init and a final entry, or one list
# containing all batches:
#   <results_dict['artificial_images']> OR
#   <results_dict['final_artificial_images'] AND
#    results_dict['init_artificial_images']>

# A missing --results-directory arrives here as None; treat it like a missing
# directory instead of letting path.exists(None) raise a TypeError.
if results_directory is None or not path.exists(results_directory):
    print(f'ERROR, the results directory {results_directory} does not exist!')
    print('Exiting program')
    # Non-zero status so that calling shells / job schedulers see the failure.
    raise SystemExit(1)

# NOTE: pickle.load can execute arbitrary code embedded in the file; only run
# this script on results directories produced by trusted runs.
with open(f'{results_directory}/results_dict.pkl', 'rb') as f:
    results_dict = pickle.load(f)

# A run stores either the init/final image pair or a single
# 'artificial_images' list, so only touch the pair when it is present
# (indexing it unconditionally raised KeyError for the other layout).
if 'init_artificial_images' in results_dict:
    print(len(results_dict['init_artificial_images']))
    # A container of length 1 is replaced by its only element.
    if len(results_dict['init_artificial_images']) == 1:
        results_dict['init_artificial_images'] = results_dict['init_artificial_images'][0]
        results_dict['final_artificial_images'] = results_dict['final_artificial_images'][0]


# List the available keys to make it obvious what the run recorded.
for key in results_dict.keys():
    print(f'dict has key: {key}')

# Parse configuration.txt ("key:value" per line) into a dict of strings.
config_dict = {}
with open(path.join(results_directory, "configuration.txt")) as f:
    for line in f:
        # Split on the FIRST ':' only, so values that themselves contain a
        # colon (paths, timestamps, ...) are kept whole instead of raising
        # "too many values to unpack".
        key, _separator, val = line.partition(':')
        if not _separator:
            # Blank line or line without a key/value separator: nothing to store.
            continue
        config_dict[key] = val.strip()

#TODO alter code to use csv and image folder for intermediary steps!
# All generated artefacts go below <results_directory>/results.
output_folder = args.results_directory + '/results'
# Last path component of the results directory. Trailing slashes are removed
# first; otherwise "runs/exp1/" produced an empty run name and a folder
# called just "PDFs_".
run_name = args.results_directory.rstrip('/').split('/')[-1]
csv_folder = output_folder+'/csv'
pdf_folder = output_folder + '/PDFs_' + run_name
metrics_folder = output_folder + '/metrics'
image_folder = output_folder+'/images'


# Architecture name recorded under 'arch' in the run configuration.
model_arch = config_dict['arch']
print(f'using {model_arch}')
init_model, final_model = utils.get_init_final_models(results_directory, model_arch)

### To device?
def ensure_dir(directory):
    """Create ``directory`` (including missing parents) if it does not exist.

    ``exist_ok=True`` replaces the earlier "check, then create" sequence, so a
    directory that appears between the check and the creation (for example
    from a parallel run) no longer causes a failure.

    Args:
        directory: path of the directory to create.

    Raises:
        FileExistsError: if the path exists but is not a directory.
    """
    os.makedirs(directory, exist_ok=True)

# Make sure every output location exists before anything is written to it.
for _required_dir in (output_folder, pdf_folder, csv_folder, image_folder, metrics_folder):
    ensure_dir(_required_dir)

# Run metrics read from the results dictionary; they are plotted and exported
# to CSV further down.
(attack_objective_values,
 maintain_values,
 accuracy,
 activation_norms,
 alphas) = (results_dict[_metric_key] for _metric_key in ('attack_obj_vals',
                                                          'maintain_obj_vals',
                                                          'accuracy',
                                                          'activation_norms',
                                                          'alphas'))

# Keep a copy of the run configuration next to the generated PDFs.
_config_source = path.join(results_directory, "configuration.txt")
shutil.copyfile(_config_source, path.join(pdf_folder, "configuration.txt"))


# Channels listed under 'channel' in the run configuration, as integers.
channels = list(map(int, utils.get_tuple_from_config_dict(config_dict, 'channel')))
# Step numbers 0, save_interval, 2*save_interval, ... up to and including nsteps.
all_steps = list(range(0, int(config_dict['nsteps']) + 1, int(config_dict['save_interval'])))

#Generate and store the activations and return the paths to them
print()
# Activations of the initial and the final model; no data_loader is passed, so
# get_and_sort_activations uses whatever its default data source is.
init_acts_path, final_acts_path = get_imagenet_activations_by_channel_v2.get_and_sort_activations(args, init_model, final_model, model_name = model_arch)

# Same call on the loader returned by utils.get_val_topk_dataset_loader(), tagged with name 'val'.
# NOTE(review): this runs even when --do-validation is not given — confirm that is intended.
init_val_acts_path, final_val_acts_path = get_imagenet_activations_by_channel_v2.get_and_sort_activations(args, init_model, final_model, model_name=model_arch, data_loader=utils.get_val_topk_dataset_loader(), name = 'val')

'''
Need to add logic for getting top image indices across top validation activations
'''

# Load the stored activations back from the returned paths.
# presumably each tensor is (channels, images), since top-k is taken along dim=1 — TODO confirm
init_val_acts, final_val_acts =  load_in_activations_from_paths(init_val_acts_path, final_val_acts_path)
init_acts, final_acts =  load_in_activations_from_paths(init_acts_path, final_acts_path)
# Ten largest activations per row, with their positions along dim 1.
init_top_norms, init_top_indices = torch.topk(init_acts, k=10, dim=1)
final_top_norms, final_top_indices = torch.topk(final_acts, k=10, dim=1)

# Persist the (untransposed) top indices for later analysis.
torch.save(init_top_indices,path.join(metrics_folder, 'init_top_indices.pt'))
torch.save(final_top_indices,path.join(metrics_folder, 'final_top_indices.pt'))
#Confirm the same top indices
# init_top_norms, test_init_top_indices = torch.topk(init_acts, k=10, dim=1)
# res = results_dict['init_top_indices']-test_init_top_indices.T
# print(res)

# init_top_norms, test_final_top_indices = torch.topk(final_acts, k=10, dim=1)
# res = results_dict['final_top_indices']-test_final_top_indices.T
# print(res)
print('validation activation shape:')
print(init_val_acts.shape)
# Same top-10 selection for the 'val' activations.
init_top_val_norms, init_top_val_indices = torch.topk(init_val_acts, k=10, dim=1)
final_top_val_norms, final_top_val_indices = torch.topk(final_val_acts, k=10, dim=1)


torch.save(init_top_val_indices,path.join(metrics_folder, 'init_top_indices_val.pt'))
torch.save(final_top_val_indices,path.join(metrics_folder, 'final_top_indices_val.pt'))
# The files above hold the untransposed indices; from here on the in-memory
# validation indices are transposed (dim 0 and dim 1 swapped).
init_top_val_indices = init_top_val_indices.T
final_top_val_indices = final_top_val_indices.T
# The top indices stored in the results dict at training time are replaced
# below by the ones computed in this script.
#init_activations, final_activations = load_in_activations_from_paths(init_activations_path, final_activations_path)
#TODO: Rewrite this whole thing. Use more functions and less script. Possibly split this whole .py into multiple scripts
if args.do_kt:
    # KT results over the stored activations; the 'cor_ii' and 'cor_if'
    # entries are saved as metrics.
    kt_dict = get_overall_kt_results(output_folder, init_acts_path, final_acts_path, args.do_kt_overwrite)
    torch.save(kt_dict['cor_ii'], path.join(metrics_folder,'kt_ii.pt'))
    torch.save(kt_dict['cor_if'], path.join(metrics_folder,'kt_if.pt'))
    if args.do_validation:
        # Same metrics for the validation activations.
        val_kt_dict = get_overall_kt_results(output_folder, init_val_acts_path, final_val_acts_path, args.do_kt_overwrite)
        torch.save(val_kt_dict['cor_ii'], path.join(metrics_folder,'val_kt_ii.pt'))
        torch.save(val_kt_dict['cor_if'], path.join(metrics_folder,'val_kt_if.pt'))
# The top-10 values/indices of init_acts and final_acts were already computed
# above (init_top_* / final_top_*); reuse them instead of running the same
# torch.topk over the full activation tensors a second time.
init_norms_for_test, init_indices_for_test = init_top_norms, init_top_indices
final_norms_for_test, final_indices_for_test = final_top_norms, final_top_indices
#Instead of taking the top from the saved indices from training, use the ones we calc'd in this script
results_dict['init_top_indices'] = init_indices_for_test.T
results_dict['final_top_indices'] = final_indices_for_test.T


if args.do_clip:
    # Hard-coded switch: compare full per-image embeddings (True) or
    # embeddings averaged beforehand (False).
    use_full_embeddings = True

    # NOTE(review): reloading the pickle rebinds results_dict, which discards
    # the 'init_top_indices' / 'final_top_indices' values assigned from this
    # script's own top-k just above. The CLIP comparison (and everything after
    # it) therefore uses the indices stored in the file — confirm intended.
    # NOTE: pickle.load can execute arbitrary code; trusted files only.
    with open(f'{results_directory}/results_dict.pkl', 'rb') as f:
        results_dict = pickle.load(f)
    # Unwrap a length-1 image container, but only for runs that stored the
    # init/final pair (indexing unconditionally raised KeyError otherwise).
    if 'init_artificial_images' in results_dict and len(results_dict['init_artificial_images']) == 1:
        results_dict['init_artificial_images'] = results_dict['init_artificial_images'][0]
        results_dict['final_artificial_images'] = results_dict['final_artificial_images'][0]

    init_indices = results_dict['init_top_indices']
    final_indices = results_dict['final_top_indices']

    if use_full_embeddings:
        # Hard-coded switch: plain mean (True) or mean of per-row maxima (False).
        use_basic_means = True
        init_clip_similarities, final_clip_similarities = clip_descriptions.get_overall_cos_sim_results(init_indices, final_indices)
        torch.save(init_clip_similarities, path.join(metrics_folder,'full_clip_ii.pt'))
        torch.save(final_clip_similarities, path.join(metrics_folder,'full_clip_if.pt'))
        if args.do_validation:
            val_init_clip_similarities, val_final_clip_similarities = clip_descriptions.get_overall_cos_sim_results(init_top_val_indices, final_top_val_indices)
            torch.save(val_init_clip_similarities, path.join(metrics_folder,'val_full_clip_ii.pt'))
            torch.save(val_final_clip_similarities, path.join(metrics_folder,'val_full_clip_if.pt'))
        print('using full embeddings')
        if use_basic_means:
            print('using means of embeddings cos sim')
            clip_description = 'Mean Similarity'
            # Average over the two trailing dimensions (dims 2 and 3).
            init_clip_similarities_to_use = init_clip_similarities.mean(dim=(2,3))
            final_clip_similarities_to_use = final_clip_similarities.mean(dim=(2,3))
            # Print the shape, as the message says, rather than the whole tensor.
            print(f'clip similarities to use shape: {init_clip_similarities_to_use.shape}')

        else:
            print('using means of the max similarity embeddings')
            clip_description = 'Max Similarity'
            # Maximum over dim 3, then mean over dim 2.
            init_clip_similarities_to_use = init_clip_similarities.max(dim=3)[0].mean(dim=2)
            final_clip_similarities_to_use = final_clip_similarities.max(dim=3)[0].mean(dim=2)

    #Takes the mean of the embeddings and gets the 256,256 results to use
    else:
        print('using mean of embeddings')
        clip_description = 'Averaged Embeddings'

        init_clip_similarities_to_use, final_clip_similarities_to_use = clip_descriptions.get_overall_cos_sim_results_with_mean_embeddings(init_indices, final_indices)
    print('Similiarites to use shapes:')
    print(init_clip_similarities_to_use.shape)
    print(final_clip_similarities_to_use.shape)
    # Export the reduced similarity matrices to the csv folder.
    utils.save_matrix(init_clip_similarities_to_use, csv_folder, 'Averaged_CosSim_ii')
    utils.save_matrix(final_clip_similarities_to_use, csv_folder, 'Averaged_CosSim_if')

# Plot the losses, accuracy and norms twice with identical settings: once as a
# JPG in the image folder and once as a PDF in the PDF folder.
_objectives_title = f"lr: {config_dict['lr']}, alpha: {config_dict['alpha']}, optim: {config_dict['optimizer_type']}, maintain obj: {config_dict['maintain_obj']}"
for _plot_path in (path.join(image_folder, "objectives_visual.jpg"),
                   path.join(pdf_folder, "objectives_visual.pdf")):
    visualize_objectives_v3(
        attack_objective_values,
        maintain_values,
        accuracy,
        activation_norms,
        int(config_dict['save_interval']),
        filename=_plot_path,
        title=_objectives_title,
        inline=False)

# Raw metric series as a table, one column per metric.
objectives_data = pd.DataFrame({
    "accuracy": accuracy,
    "attack_objective_values": attack_objective_values,
    "maintain_values": maintain_values,
    "activation norms": activation_norms,
    'alphas': alphas,
})

# Written to the results folder and, as a convenience copy, next to the PDFs.
for _csv_dir in (output_folder, pdf_folder):
    objectives_data.to_csv(path.join(_csv_dir, "objectives_data.csv"))
print('Visualization complete!')
class_dict = utils.get_class_labels_dict()



def analyze_clip_similarites_for_table(init_clip_similarities_fc, final_clip_similarities_fc, channel_fc, for_text=False, description=''):
    """Build three summary tables of CLIP cosine similarities for one channel.

    Args:
        init_clip_similarities_fc: similarity tensor indexed as
            ``[channel_fc]`` and ``[channel_fc, channel_fc]``; reported under
            the "pre-attack" / "pre-training" descriptions.
        final_clip_similarities_fc: same layout; reported under the
            "post-attack" / "post-training" descriptions.
        channel_fc: index of the channel to summarise.
        for_text: accepted but not used by this function.
        description: text appended to the title of the first table.

    Returns:
        dict with keys ``'table1'``, ``'table2'`` and ``'table3'``; each value
        is a dict with ``'title'``, ``'columns'`` and ``'data'`` (a list of
        rows, every cell a string).
    """
    top_n = 3
    init_row = init_clip_similarities_fc[channel_fc]
    final_row = final_clip_similarities_fc[channel_fc]

    def as_table(title, columns, rows):
        return {'title': title, 'columns': columns, 'data': rows}

    def top_rows(row, suffix, note):
        # The top_n largest entries of the row. Nothing is excluded, so the
        # channel's own entry can be among them.
        scores, neighbours = torch.topk(row, k=top_n)
        return [[f'{score:.3f}', f'{neighbour}{suffix}', note]
                for score, neighbour in zip(scores, neighbours)]

    aggregate_rows = [
        [f'{init_row.mean():.3f}', 'Average similarity between this channel and other pre-attack channels'],
        [f'{final_row.mean():.3f}', 'Average similarity between this channel and post-attack channels'],
        [f'{init_clip_similarities_fc[channel_fc, channel_fc]:.3f}', 'Similarity between this channel and itself pre-training'],
        [f'{final_clip_similarities_fc[channel_fc, channel_fc]:.3f}', 'Similarity between this channel and itself post-training'],
    ]

    return {
        'table1': as_table(f"CLIP AGGREGATE INFO CHANNEL {channel_fc}, {description}",
                           ['Cosine Similarity', 'Description'],
                           aggregate_rows),
        'table2': as_table('Most Similar Channels Pre-training',
                           ['CosSim', 'Channel', 'Description'],
                           top_rows(init_row, 'i', 'Sample similarity of a top similarity channel pre-training')),
        'table3': as_table('Most Similar Channels Post-training',
                           ['CosSim', 'Channel', 'Description'],
                           top_rows(final_row, 'f', 'Sample similarity of a top similarity channel after training')),
    }

def take_activations_topk_them(acts, output_path, descriptor = '', k=100,):
    """Save the top-k values of every row of ``acts`` together with their indices.

    Two files are written below ``output_path``:
    ``<descriptor>top_activations.pt`` (the values) and
    ``<descriptor>top_indices.pt`` (their positions along dim 1). Both paths
    are printed after saving.

    Args:
        acts: tensor with at least two dimensions; top-k is taken along dim 1.
        output_path: existing directory to write into.
        descriptor: prefix for both file names.
        k: number of entries to keep per row. Capped at ``acts.shape[1]``,
            because ``torch.topk`` raises when asked for more entries than
            the dimension holds (e.g. the default of 100 on a small dataset).
    """
    k = min(k, acts.shape[1])
    top_acts, top_indices = torch.topk(acts, k, dim=1)

    activations_file = output_path+f'/{descriptor}top_activations.pt'
    indices_file = output_path+f'/{descriptor}top_indices.pt'
    torch.save(top_acts, activations_file)
    torch.save(top_indices, indices_file)
    print(activations_file)
    print(indices_file)

# Store the top activations and their indices (default k) of both models in
# the metrics folder, prefixed 'init_' and 'final_' respectively.
take_activations_topk_them(acts=init_acts, output_path=metrics_folder, descriptor='init_')
take_activations_topk_them(acts=final_acts, output_path=metrics_folder, descriptor='final_')


def analyze_similarites_for_text_and_matrices(init_clip_similarities_fc,
                                              final_clip_similarities_fc,
                                              channel_fc, kt_dict,
                                              for_text=True,
                                              output_directory = pdf_folder,
                                              file_name = 'metrics'):
    """Write the CLIP and KT summary tables of one channel to a text file.

    The CLIP tables come from ``analyze_clip_similarites_for_table`` and the
    KT tables from ``analyze_kt_channel_for_table``. Each table is written as
    its title, its column names and its rows, one per line, with cells joined
    by ``' & '``. One blank line separates the CLIP tables from the KT tables.

    Args:
        init_clip_similarities_fc: forwarded to the CLIP table builder.
        final_clip_similarities_fc: forwarded to the CLIP table builder.
        channel_fc: channel to report on; also part of the file name.
        kt_dict: forwarded to the KT table builder.
        for_text: forwarded to the CLIP table builder.
        output_directory: directory the file is written to.
        file_name: prefix of the file name.

    Returns:
        Path of the written file,
        ``<output_directory>/<file_name>_channel_<channel_fc>.txt``.
    """
    cell_separator = ' & '

    def render(tables):
        # Flatten a dict of tables into the text fragments to be written.
        fragments = []
        for table in tables.values():
            fragments.extend((table['title'], '\n',
                              cell_separator.join(table['columns']), '\n'))
            for row in table['data']:
                fragments.extend((cell_separator.join(row), '\n'))
        return fragments

    cos_tables = analyze_clip_similarites_for_table(init_clip_similarities_fc, final_clip_similarities_fc, channel_fc, for_text)
    kt_tables, _ = analyze_kt_channel_for_table(kt_dict, channel_fc)

    report = ''.join(render(cos_tables) + ['\n'] + render(kt_tables))

    save_path = os.path.join(output_directory, f'{file_name}_channel_{channel_fc}.txt')
    with open(save_path, 'w') as f:
        f.write(report)

    return save_path

def analyze_kt_channel_for_table(kt_dict, channel):
    """Build three summary tables of KT correlations for one channel.

    Args:
        kt_dict: mapping with the entries ``'cor_ii'``, ``'cor_if'``,
            ``'p_ii'`` and ``'p_if'``. Only the two correlation tensors are
            used; they are indexed as ``[channel]`` and ``[channel, channel]``.
        channel: index of the channel to summarise.

    Returns:
        ``(tables_dict, most_similar_final_channels)``. ``tables_dict`` has
        the keys ``'table1'``, ``'table2'`` and ``'table3'``, each a dict with
        ``'title'``, ``'columns'`` and ``'data'`` (rows of strings). The
        second item holds the indices of the 3 largest entries of
        ``cor_if[channel]``.
    """
    correlations_ii = kt_dict['cor_ii']
    correlations_if = kt_dict['cor_if']
    # The p-value entries must be present in kt_dict, although nothing below
    # makes use of them.
    _pvalues_ii, _pvalues_if = kt_dict['p_ii'], kt_dict['p_if']

    top_n = 3
    row_ii = correlations_ii[channel]
    row_if = correlations_if[channel]

    # The pre-training list takes one entry more than the post-training list;
    # all top_n + 1 of them are reported.
    similarity_init, most_similar_init_channels = torch.topk(row_ii, k=top_n + 1)
    similarity_final, most_similar_final_channels = torch.topk(row_if, k=top_n)

    def as_table(title, columns, rows):
        return {'title': title, 'columns': columns, 'data': rows}

    def ranked(scores, neighbours, suffix, note):
        return [[f'{score:.3f}', f'{neighbour}{suffix}', note]
                for score, neighbour in zip(scores, neighbours)]

    aggregate_rows = [
        [f'{row_ii.mean():.3f}', 'Average similarity between this channel and other pre-attack channels'],
        [f'{row_if.mean():.3f}', 'Average similarity between this channel and post-attack channels'],
        [f'{correlations_if[channel, channel]:.3f}', 'Similarity between this channel and itself post-training'],
    ]

    tables_dict = {
        'table1': as_table(f"KT AGGREGATE INFO CHANNEL {channel}",
                           ['KT correlation', 'Description'],
                           aggregate_rows),
        'table2': as_table('Most Similar Channels Pre-training',
                           ['KT', 'Channel', 'Description'],
                           ranked(similarity_init, most_similar_init_channels, 'i',
                                  'Sample correlation of a top similarity channel pre-training')),
        'table3': as_table('Most Similar Channels Post-training',
                           ['KT', 'Channel', 'Description'],
                           ranked(similarity_final, most_similar_final_channels, 'f',
                                  'Sample correlation of a top similarity channel after training')),
    }
    return tables_dict, most_similar_final_channels

#def analyze_cosine_similarities(cosine_similarites, channel)


def grab_images_from_index_vector(indices, data_loader=None):
    """Fetch the images at ``indices`` from a dataset and stack them.

    Args:
        indices: iterable of dataset indices.
        data_loader: object whose ``.dataset`` is indexed; when ``None`` the
            loader from ``utils.get_results_dataloader()`` is used.

    Returns:
        ``torch.stack`` of element 0 of every selected dataset item, in the
        order of ``indices``.
    """
    source = utils.get_results_dataloader() if data_loader is None else data_loader
    dataset = source.dataset
    print(f'dataset length: {len(dataset)}')
    # Element 0 of a dataset item is the image.
    return torch.stack([dataset[index][0] for index in indices])


def grab_labels_from_index_vector(indices, data_loader=None):
    """Fetch the labels at ``indices`` from a dataset.

    ``data_loader`` used to be read from a module-level variable of the same
    name that is only assigned further down the script. It is now an explicit
    optional parameter, matching ``grab_images_from_index_vector``.

    Args:
        indices: iterable of dataset indices.
        data_loader: object whose ``.dataset`` is indexed; when ``None`` the
            loader from ``utils.get_results_dataloader()`` is used.

    Returns:
        List with element 1 of every selected dataset item, in the order of
        ``indices``.
    """
    if data_loader is None:
        dataset = utils.get_results_dataloader().dataset
    else:
        dataset = data_loader.dataset

    # Element 1 of a dataset item is the label.
    return [dataset[index][1] for index in indices]


def grab_images_labels_activations_from_index_vector(indices, all_activations, data_loader = None):
    """Collect image, label and activation for every index in ``indices``.

    Each dataset item is fetched exactly once. The earlier version indexed
    the dataset twice per index (once for the image, once for the label),
    doubling the loading work.

    Args:
        indices: iterable of dataset indices.
        all_activations: indexable by the same indices; supplies the
            activation reported for each selected item.
        data_loader: object whose ``.dataset`` is indexed; when ``None`` the
            loader from ``utils.get_results_dataloader()`` is used.

    Returns:
        ``(images, labels, activations)``: three lists in the order of
        ``indices``.
    """
    if data_loader is None:
        dataset = utils.get_results_dataloader().dataset
    else:
        dataset = data_loader.dataset

    images = []
    labels = []
    activations = []

    for index in indices:
        sample = dataset[index]  # single fetch; element 0 is the image, element 1 the label
        images.append(sample[0])
        labels.append(sample[1])
        activations.append(all_activations[index])
    return images, labels, activations



def grab_images_and_labels_from_index_vector(indices, data_loader=None):
    """Collect image and label for every index in ``indices``.

    ``data_loader`` used to be read from a module-level variable of the same
    name that is only assigned further down the script. It is now an explicit
    optional parameter, matching the sibling ``grab_*`` helpers. Each dataset
    item is also fetched once instead of twice.

    Args:
        indices: iterable of dataset indices.
        data_loader: object whose ``.dataset`` is indexed; when ``None`` the
            loader from ``utils.get_results_dataloader()`` is used.

    Returns:
        ``(images, labels)``: two lists in the order of ``indices``.
    """
    if data_loader is None:
        dataset = utils.get_results_dataloader().dataset
    else:
        dataset = data_loader.dataset

    images = []
    labels = []

    for index in indices:
        sample = dataset[index]  # single fetch; element 0 is the image, element 1 the label
        images.append(sample[0])
        labels.append(sample[1])

    return images, labels


# Record which optional entries the loaded results dictionary provides.
(has_final_top_indices,
 has_init_top_indices,
 has_artificial_images,
 has_init_artificial_images,
 has_final_artificial_images) = (_entry in results_dict for _entry in ('final_top_indices',
                                                                       'init_top_indices',
                                                                       'artificial_images',
                                                                       'init_artificial_images',
                                                                       'final_artificial_images'))
# Module-level loader obtained from utils.get_results_dataloader().
data_loader = utils.get_results_dataloader()




# takes a 3,h,w image and returns a labelled version
def label_image(image, label, activation=None, image_dim = (420,420), verbose = True):
    """Render ``image`` with ``label`` drawn in a white rounded box on top of it.

    The image is drawn with matplotlib, encoded as JPEG, read back as a tensor
    and cropped by a fixed margin. The JPEG now lives in an in-memory buffer
    instead of a shared ``temp.jpg`` in the working directory, so parallel
    runs can no longer overwrite each other's file and no stray file is left
    behind. The decoded image is closed after use, and the unused second
    figure that used to be created is gone.

    Args:
        image: tensor laid out as (3, h, w); ``image.shape[1]`` scales the
            position and size of the text.
        label: text to draw; only the part before the first comma is used.
        activation: currently unused (see the commented-out lines below).
        image_dim: ``(height, width)`` of the returned image.
        verbose: print a note whenever the rendered image has to be resized.

    Returns:
        Tensor of shape ``(3, image_dim[0], image_dim[1])``.
    """
    import io  # local import: only needed for the in-memory JPEG buffer

    im_size = image.shape[1]
    fig, ax = plt.subplots()
    label = label.split(',')[0]

    #UNCOMMENT TO INCLUDE ACTIVATION NUMBER
    #if activation is not None:
    #    label = label + f" {activation:.1f}"

    plt.text(im_size * 0.05, im_size * 0.15, label, size=im_size * 0.1, rotation=0.,
             bbox=dict(boxstyle="round",
                       ec=(0., 0.0, 0.0),
                       fc=(1., 1, 1),
                       )
             )
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    plt.imshow(image.permute(1, 2, 0))
    fig.tight_layout()

    # A file-like target carries no extension, so the format is given explicitly.
    buffer = io.BytesIO()
    fig.savefig(buffer, format='jpg', bbox_inches='tight')
    plt.close(fig)
    buffer.seek(0)

    with Image.open(buffer) as im:
        # Convert to a tensor and trim the fixed border around the picture.
        labelled_image = transforms.ToTensor()(im)[:, 20:-30, 20:-30]

    expected_shape = [3, image_dim[0], image_dim[1]]
    if list(labelled_image.shape) != expected_shape:
        if verbose:
            print(f'labelled image has shape {labelled_image.shape}')
            # Report the shape that was actually requested, not a fixed 420x420.
            print(f'shape of {expected_shape} was expected. Resizing')
        labelled_image = transforms.Resize(size=[image_dim[0], image_dim[1]])(labelled_image)
    return labelled_image


# Some code to output columns with the top images.
def make_column_images(pdf_folder,
                       n_top, top_init_indices, top_final_indices,
                       validation = False):
    """Save a PDF grid comparing the top images selected before and after.

    The grid has one row per column of ``top_init_indices`` and ``2 * n_top``
    columns. The first ``n_top`` columns show the images selected by rows
    ``0 .. n_top-1`` of ``top_init_indices``; the last ``n_top`` columns show
    those selected by ``top_final_indices``. Every image is stamped with
    ``channel <grid row><suffix>`` through ``label_image``.

    Args:
        pdf_folder: directory the PDF is written to.
        n_top: number of leading rows of each index tensor to show.
        top_init_indices: 2-D tensor of dataset indices
            (presumably (rank, channel) — TODO confirm).
        top_final_indices: same layout as ``top_init_indices``.
        validation: when True the images come from
            ``utils.get_results_dataloader(split="val")``, labels get the
            ``_val`` suffix and the file is ``val_top_images.pdf``; otherwise
            the default loader and ``top_images.pdf`` are used.
    """
    def label_top_images(topk, indices, label, data_loader):
        print(f'labelling images {label}')
        all_images = []

        for k in range(topk):
            images = grab_images_from_index_vector(indices[k], data_loader)
            print(f'labelling image set {k+1} of {topk}')
            for i, image in enumerate(tqdm(images)):
                # Written back into the stacked tensor, so the labelled
                # version is requested at the size of the source image.
                images[i] = label_image(image, f'channel {i}{label}',
                                        image_dim=[image.shape[1], image.shape[2]],
                                        verbose=False)
            all_images.append(images)
        return all_images

    def show_without_ticks(axis, img):
        # Draw a (3, h, w) image on the axis with ticks and tick labels removed.
        axis.imshow(img.permute(1, 2, 0))
        axis.set_xticks([])
        axis.set_yticks([])
        axis.set_xticklabels([])
        axis.set_yticklabels([])

    if validation:
        data_loader = utils.get_results_dataloader(split="val")
        init_suffix, final_suffix = 'i_val', 'f_val'
        image_title = 'val_top_images.pdf'
    else:
        data_loader = None
        init_suffix, final_suffix = 'i', 'f'
        image_title = 'top_images.pdf'

    initial_images = label_top_images(n_top, top_init_indices, init_suffix, data_loader)
    final_images = label_top_images(n_top, top_final_indices, final_suffix, data_loader)

    print('init indices shape', top_init_indices.shape)

    #size in inches of images:
    size = 1.5
    n_rows = top_init_indices.shape[1]
    # squeeze=False keeps axs two-dimensional even for a single row;
    # otherwise axs[i, k] fails when only one channel is present.
    fig, axs = plt.subplots(n_rows,
                            n_top*2,
                            figsize=(2*n_top*size, size*n_rows),
                            squeeze=False,
                            )
    for i in range(n_rows):
        for k in range(n_top):
            show_without_ticks(axs[i, k], initial_images[k][i])
            show_without_ticks(axs[i, n_top+k], final_images[k][i])

    output_path = path.join(pdf_folder, image_title)
    fig.tight_layout()
    fig.savefig(output_path,
                facecolor='white')
    plt.close(fig)
    print(f'output_path: {output_path}')

# NOTE: the trailing `and False` makes this condition always false, so the
# column-image generation below never runs; drop it to re-enable the block.
if has_init_top_indices and has_final_top_indices and False:

    # Number of leading index rows shown per side of the comparison grid.
    n_top = 2
    print('Generating top column image')
    #get_imagenet_activations_by_channel_v2.get_and_sort_activations(args, init_model, final_model, do_sorting=False)


    shape = results_dict['init_top_indices'].shape
    top_init_indices = results_dict['init_top_indices']
    top_final_indices = results_dict['final_top_indices']

    print(f'results shape {shape}')
    make_column_images(pdf_folder,  n_top, top_init_indices, top_final_indices)
    if args.do_validation:
        # Second grid built from the validation top indices.
        make_column_images(pdf_folder,  n_top, init_top_val_indices, final_top_val_indices, validation = True)

# Labels the dataset images at the given indices and saves them as a single
# image grid (5 per row) at save_path. Nothing is returned.
def make_composite_image_with_labels(image_indices, save_path, class_dict, activations, do_overwrite=False, data_loader = None):
    """Save a labelled grid of the dataset images at ``image_indices``.

    Args:
        image_indices: dataset indices of the images to include.
        save_path: file the grid is written to.
        class_dict: maps a label number to the text drawn on the image
            (looked up with ``.get``).
        activations: indexable by dataset index; the matching entry is passed
            on to ``label_image``.
        do_overwrite: when False and ``save_path`` already exists, nothing is
            done.
        data_loader: forwarded to
            ``grab_images_labels_activations_from_index_vector``.

    Returns:
        None.
    """
    # Keep an existing grid unless overwriting was requested.
    if path.exists(save_path) and not do_overwrite:
        return

    images, label_numbers, picked_activations = grab_images_labels_activations_from_index_vector(
        image_indices, all_activations=activations, data_loader=data_loader)

    label_texts = [class_dict.get(number) for number in label_numbers]
    stamped = [label_image(picture, text, value)
               for picture, text, value in zip(images, label_texts, picked_activations)]

    save_image(torch.stack(stamped), save_path, nrow=5)

def _save_top_image_grid(top_indices, channel, image_title, channel_activations, data_loader=None):
    """Render the labelled top-image grid for one channel and return its file path.

    Takes the first ``args.num_img_to_save`` rows of column ``channel`` from
    ``top_indices`` and passes them to make_composite_image_with_labels, which
    writes the grid to ``image_folder/image_title`` (honouring
    ``args.do_overwrite``).
    """
    output_path = path.join(image_folder, image_title)
    channel_top_indices = top_indices[0:args.num_img_to_save, channel]
    make_composite_image_with_labels(channel_top_indices, output_path, class_dict, channel_activations,
                                     args.do_overwrite, data_loader=data_loader)
    return output_path


def generate_all_pdfs(channels, results_dict):
    """Generate one results PDF per channel.

    For every channel a list of (image path, header, size) entries is built
    from whatever is available, in this order: train top-image grids,
    validation top-image grids, CLIP self cosine-similarity matrices, the
    top-class chart, the objectives visual, the position-tracking graph and the
    artificial images. The entries are then handed to
    ``utils.generate_results_pdf_v3``.

    The function relies on module-level state defined elsewhere in this script
    (``args``, ``image_folder``, ``pdf_folder``, ``csv_folder``, ``class_dict``,
    the ``has_*`` flags, the CLIP similarity tensors, ...).

    Args:
        channels: iterable of channel indices to generate PDFs for.
        results_dict: mapping of result tensors. Keys read here (when the
            matching flag/key is present): ``init_top_indices``,
            ``final_top_indices``, ``channel_{c}_top_positions`` /
            ``_mid_positions`` / ``_bot_positions``, ``artificial_images``,
            ``init_artificial_images`` and ``final_artificial_images``.
    """
    print('Generating pdfs!')

    # Load both activation files once up front; the per-channel slices below
    # index into these instead of re-reading the files inside the loop.
    init_activations = torch.load(init_acts_path)
    final_activations = torch.load(final_acts_path)
    print(channels)
    # Debug output only; the key is optional, so do not fail when it is absent.
    if "init_artificial_images" in results_dict:
        print(results_dict["init_artificial_images"].shape)

    for channel in tqdm(channels):
        if args.do_kt and args.do_clip:
            analyze_similarites_for_text_and_matrices(init_clip_similarities_to_use,
                                                      final_clip_similarities_to_use,
                                                      channel,
                                                      kt_dict,
                                                      output_directory=pdf_folder)

        # Each entry is (image path, header text, [width, height]) in PDF order.
        entries = []

        if has_init_top_indices:
            image_title = f'channel_{channel}_init_top_images.jpg'
            output_path = _save_top_image_grid(results_dict['init_top_indices'], channel, image_title,
                                               init_activations[channel])
            entries.append((output_path, image_title, [500, 200]))

        if has_final_top_indices:
            image_title = f'channel_{channel}_final_top_images.jpg'
            output_path = _save_top_image_grid(results_dict['final_top_indices'], channel, image_title,
                                               final_activations[channel])
            entries.append((output_path, image_title, [500, 200]))

        # Top images from the validation set, before and after training.
        if args.do_validation:
            # The activations passed here are the ones loaded from
            # init_acts_path / final_acts_path; the original author marks them
            # as a placeholder kept only to satisfy the callee's signature.
            image_title = f'channel_{channel}_val_init_top_images.jpg'
            output_path = _save_top_image_grid(init_top_val_indices, channel, image_title,
                                               init_activations[channel],
                                               data_loader=utils.get_results_dataloader(split='val'))
            entries.append((output_path, image_title, [500, 200]))

            image_title = f'channel_{channel}_val_final_top_images.jpg'
            output_path = _save_top_image_grid(final_top_val_indices, channel, image_title,
                                               final_activations[channel],
                                               data_loader=utils.get_results_dataloader(split='val'))
            entries.append((output_path, image_title, [500, 200]))

        if args.do_clip:
            init_cos_matrix_path = utils.generate_cos_sim_graph(init_clip_similarities,
                                                                channel1=channel, channel2=channel,
                                                                title='i',
                                                                save_path=image_folder)
            final_cos_matrix_path = utils.generate_cos_sim_graph(final_clip_similarities,
                                                                 channel1=channel, channel2=channel,
                                                                 title='f',
                                                                 save_path=image_folder)
            entries.append((init_cos_matrix_path, 'init self cos sim matrix', [300, 300]))
            entries.append((final_cos_matrix_path, 'final self cos sim matrix', [300, 300]))

            utils.save_matrix(init_clip_similarities[channel, channel], csv_folder, f'CosSim_{channel}{channel}i')
            utils.save_matrix(final_clip_similarities[channel, channel], csv_folder, f'CosSim_{channel}{channel}f')

        # Generate a bar graph tracking top classes
        if has_init_top_indices and has_final_top_indices:
            image_title = f'channel_{channel}_class_info.png'
            output_path = path.join(image_folder, image_title)
            num_classes_to_count = 100
            init_classes = grab_labels_from_index_vector(
                results_dict['init_top_indices'][0:num_classes_to_count, channel])
            final_classes = grab_labels_from_index_vector(
                results_dict['final_top_indices'][0:num_classes_to_count, channel])

            i_top_labels, f_top_labels = utils.save_top_class_info(class_dict,
                                                                   init_classes,
                                                                   final_classes,
                                                                   args.num_top_classes,
                                                                   output_path)

            header = image_title+'\n'+', '.join(i_top_labels)+'\n'+', '.join(f_top_labels)
            entries.append((output_path, header, [500, 500]))

        # The objectives visual is always listed; this function does not create it.
        entries.append((path.join(image_folder, "objectives_visual.jpg"), '', [500, 400]))

        if f'channel_{channel}_top_positions' in results_dict:
            graph_path = path.join(image_folder, f"index_tracking_channel_{channel}.jpg")
            utils.make_position_tracking_graph(results_dict[f'channel_{channel}_top_positions'],
                                               results_dict[f'channel_{channel}_mid_positions'],
                                               results_dict[f'channel_{channel}_bot_positions'],
                                               graph_path,
                                               f'channel {channel} position tracking',
                                               all_steps,
                                               )
            entries.append((graph_path, '', [456, 636]))

        # use save_image() here to save an arbitrarily large batch. adjust size based on num_images
        if has_artificial_images:
            save_interval = int(config_dict['save_interval'])
            image_title = f'channel_{channel}_artificial_images_interval={save_interval}.jpg'
            output_path = path.join(image_folder, image_title)

            save_image(results_dict['artificial_images'][:, channel], output_path, nrow=5)

            # Height grows by 100 per (started) row of five images.
            grid_height = 100 * (results_dict['artificial_images'].shape[0] // 5 + 1)
            entries.append((output_path, image_title, [500, grid_height]))

        elif has_init_artificial_images or has_final_artificial_images:
            if has_init_artificial_images:
                image_title = f'channel_{channel}_artificial_image_init.jpg'
                output_path = path.join(image_folder, image_title)
                save_image(results_dict['init_artificial_images'][channel], output_path)
                entries.append((output_path, image_title, [100, 100]))

            if has_final_artificial_images:
                image_title = f'channel_{channel}_artificial_image_final.jpg'
                output_path = path.join(image_folder, image_title)
                save_image(results_dict['final_artificial_images'][channel], output_path)
                entries.append((output_path, image_title, [100, 100]))

        # The PDF generator takes three parallel lists.
        image_paths = [entry[0] for entry in entries]
        image_headers = [entry[1] for entry in entries]
        image_sizes = [entry[2] for entry in entries]
        utils.generate_results_pdf_v3(channel, image_paths, image_headers, image_sizes, pdf_folder)
    print('Done generating PDFs!')


# Report-generation driver.
# generate_all_pdfs builds one PDF per channel: it checks the results dict for
# various keys and, when they are present, appends the matching elements to the
# lists that are passed on to the PDF generator.
# Per the original note, the PDFs are generated in results/PDFs and the
# intermediate images are saved in results.
if args.do_kt:
    # --do-kt: build the PDF for kt_dict in pdf_folder.
    utils.generate_kt_pdf(kt_dict, pdf_folder)
if args.do_pdfs:
    # --do-pdfs: build the per-channel PDFs.
    generate_all_pdfs(channels, results_dict)
# Printed unconditionally, after any generation above has already finished.
print(f'results saving to {pdf_folder}')


