import sys

sys.path.insert(0, "")
from activation_optimization import *
from validate_with_imagenet import *

sys.path.append('./src')
from utils import *

import os
from pathlib import Path
from PIL import Image
from tqdm import tqdm
import matplotlib.pyplot as plt

import numpy as np
import torch
from torchvision import models, transforms
from lucent.optvis import objectives, render, param
from torchvision.utils import save_image


# TODO: these locations must go; replace them with proper configuration.
# WARNING: the top-activation directories are hard-coded to one specific dataset.
TOP_ACTIVATIONS_NEURON = "alexnet/imagenetactiv"
TOP_ACTIVATIONS_CHANNEL = "alexnet/imagenetactivpostprocessed"


def ensuredir(directory):
    """Create ``directory`` and any missing parents; do nothing if it already exists."""
    Path(directory).mkdir(parents=True, exist_ok=True)


# Run on the first CUDA device when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Load in the dataset, which matches indices to image file names
# dataset = torch.load( os.path.join(TOP_ACTIVATIONS_NEURON, "dataset"))
# dataset = list( map(lambda entry: entry[0], dataset) )

# # Load top_activations_neuron
# filenames = os.listdir( TOP_ACTIVATIONS_NEURON )
# filenames.pop( filenames.index("dataset") )
# layer_names = list( map(lambda file: file.split(".")[0], filenames) )
# layer_names = list( np.unique(layer_names) )
# #activations_neuron = {
# #        layer : np.load( os.path.join(TOP_ACTIVATIONS_NEURON, f"{layer}.activations.npy") )
# #        for layer in layer_names
# #}
# indices_neuron = {
#         layer : np.load( os.path.join(TOP_ACTIVATIONS_NEURON, f"{layer}.indices.npy") )
#         for layer in layer_names
# }

# # Load top_activations_channel
# filenames = os.listdir( TOP_ACTIVATIONS_CHANNEL )
# #filenames.pop( filenames.index("dataset") )  #This line seems unnecessary as no dataset file is generated by postprocessing?
# layer_names = list( map(lambda file: file.split(".")[0], filenames) )
# layer_names = list( np.unique(layer_names) )
# #activations_channel = {
# #        layer : np.load( os.path.join(TOP_ACTIVATIONS_CHANNEL, f"{layer}.activations.npy") )
# #        for layer in layer_names
# #}
# indices_channel = {
#         layer : np.load( os.path.join(TOP_ACTIVATIONS_CHANNEL, f"{layer}.indices.npy") )
#         for layer in layer_names
# }
# dataset = torch.load( os.path.join(TOP_ACTIVATIONS_NEURON, "dataset"))
# dataset = list( map(lambda entry: entry[0], dataset) )

# # Load top_activations_neuron
# filenames = os.listdir( TOP_ACTIVATIONS_NEURON )
# filenames.pop( filenames.index("dataset") )
# layer_names = list( map(lambda file: file.split(".")[0], filenames) )
# layer_names = list( np.unique(layer_names) )
# #activations_neuron = {
# #        layer : np.load( os.path.join(TOP_ACTIVATIONS_NEURON, f"{layer}.activations.npy") )
# #        for layer in layer_names
# #}
# indices_neuron = {
#         layer : np.load( os.path.join(TOP_ACTIVATIONS_NEURON, f"{layer}.indices.npy") )
#         for layer in layer_names
# }

# # Load top_activations_channel
# filenames = os.listdir( TOP_ACTIVATIONS_CHANNEL )
# #filenames.pop( filenames.index("dataset") )  #This line seems unnecessary as no dataset file is generated by postprocessing?
# layer_names = list( map(lambda file: file.split(".")[0], filenames) )
# layer_names = list( np.unique(layer_names) )
#activations_channel = {
#        layer : np.load( os.path.join(TOP_ACTIVATIONS_CHANNEL, f"{layer}.activations.npy") )
#        for layer in layer_names
#}
# indices_channel = {
#         layer : np.load( os.path.join(TOP_ACTIVATIONS_CHANNEL, f"{layer}.indices.npy") )
#         for layer in layer_names
# }

# Preprocessing applied to input images: resize to 256, center-crop to 224,
# convert to a tensor, then normalize each colour channel.
# NOTE(review): the mean/std values look like the usual ImageNet statistics —
# confirm they match the model being analysed.
_CHANNEL_MEAN = [0.485, 0.456, 0.406]
_CHANNEL_STD = [0.229, 0.224, 0.225]
normalize = transforms.Normalize(mean=_CHANNEL_MEAN, std=_CHANNEL_STD)

_preprocessing_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    normalize,
]
standard_transform = transforms.Compose(_preprocessing_steps)

def visualize_objectives(objective_values, accuracy, save_interval, filename = None, title="<insert-title>", inline=True):
    """Plot neuron activation and accuracy against the optimization step.

    Both series share one x-axis; the activation uses the left y-axis (red)
    and the accuracy uses a twin right y-axis (blue).

    Args:
        objective_values: Sequence of objective values, one per recorded
            checkpoint. May be empty, in which case an empty plot is produced.
        accuracy: Sequence of accuracy values; it is plotted against the same
            steps, so it must have the same length as ``objective_values``.
        save_interval: Number of optimization steps between two recorded
            values; entry ``i`` is drawn at ``x = i * save_interval``.
        filename: If not ``None``, the figure is written to this path.
        title: Title drawn above the axes.
        inline: If true, the figure is displayed with ``plt.show()``.
    """
    # x positions: entry i was recorded at step i * save_interval.
    steps = save_interval * np.arange(0, len(objective_values))

    fig, ax = plt.subplots()

    ax.set_title(title, fontsize=14)
    ax.set_xlabel("step", fontsize=14)
    ax.set_ylabel("neuron_activation", color="red", fontsize=14)

    # Anchor the activation axis at zero, but only when that yields a valid
    # range: np.max raises on empty input, and a non-positive maximum would
    # give a degenerate or inverted axis. Otherwise let matplotlib autoscale.
    if len(objective_values) > 0:
        highest = np.max(objective_values)
        if highest > 0:
            ax.set_ylim(0, highest)

    # Second y-axis on the right for the accuracy curve.
    ax2 = ax.twinx()
    ax2.set_ylabel("accuracy", color="blue", fontsize=14)

    ax.plot(steps, objective_values, color="red", marker="o")
    ax2.plot(steps, accuracy, color="blue", marker="o")

    if filename is not None:
        fig.savefig(filename)

    if inline:
        plt.show()

    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)


def visualize_objectives_v2(objective_values, maintain_values, accuracy, activation_norms, save_interval,
                            filename=None, title="<insert-title>", inline=False):
    """Plot the attack loss and the maintain loss on a single set of axes.

    Args:
        objective_values: Sequence of attack-loss values, one per recorded
            checkpoint.
        maintain_values: Sequence of maintain-loss values, same length as
            ``objective_values``.
        accuracy: Unused here; kept so the signature matches
            ``visualize_objectives_v3``.
        activation_norms: Unused here; kept so the signature matches
            ``visualize_objectives_v3``.
        save_interval: Number of optimization steps between two recorded
            values; entry ``i`` is drawn at ``x = i * save_interval``.
        filename: If not ``None``, the figure is written to this path.
        title: Figure-level title.
        inline: If true, the figure is displayed with ``plt.show()``.

    Note:
        ``filename`` and ``inline`` used to be ignored by this function; they
        are now honoured. The figure is intentionally left open afterwards
        (as before) so callers that save or show it themselves keep working;
        callers invoking this in a loop should close it with ``plt.close()``.
    """
    # x positions: entry i was recorded at step i * save_interval.
    steps = save_interval * np.arange(0, len(objective_values))

    fig, ax1 = plt.subplots(1, 1, figsize=(12, 12))

    fig.suptitle(title, fontsize=14)
    ax1.set_title('Losses')
    ax1.set_xlabel("step", fontsize=14)
    ax1.set_ylabel("attack loss", color="red", fontsize=14)

    # Both losses share the same y-axis in this variant (no twin axis).
    ax1.plot(steps, objective_values, label="Attack Loss", color='red', marker='.')
    ax1.plot(steps, maintain_values, label="Maintain Loss", color='blue', marker='.')
    ax1.legend(loc='best')

    # Save before showing: showing may hand the figure over to the GUI.
    if filename is not None:
        fig.savefig(filename)

    if inline:
        plt.show()


def visualize_objectives_v3(objective_values, maintain_values, accuracy, activation_norms, save_interval,
                            filename=None, title="<insert-title>", inline=False):
    """Plot losses (top panel) and accuracy / activation norm (bottom panel).

    Each panel has two y-axes: the left one for the first series and a twin
    right one for the second series.

    Args:
        objective_values: Sequence of attack-loss values, one per recorded
            checkpoint (top panel, left axis, red).
        maintain_values: Sequence of maintain-loss values (top panel, right
            axis, blue).
        accuracy: Sequence of accuracy values (bottom panel, left axis, blue).
        activation_norms: Sequence of relative activation norms (bottom panel,
            right axis, red).
        save_interval: Number of optimization steps between two recorded
            values; entry ``i`` is drawn at ``x = i * save_interval``. All
            four series are plotted against these steps and must therefore
            have the same length.
        filename: If not ``None``, the figure is written to this path.
        title: Figure-level title.
        inline: If true, the figure is displayed with ``plt.show()``.
    """
    # x positions: entry i was recorded at step i * save_interval.
    steps = save_interval * np.arange(0, len(objective_values))

    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 12))
    fig.suptitle(title, fontsize=14)

    # --- Top panel: attack loss (left axis) and maintain loss (right axis).
    ax1.set_title('Losses')
    ax1.set_xlabel("step", fontsize=14)
    ax1.plot(steps, objective_values, label="Attack Loss", color='red', marker='.')
    ax1.set_ylabel("attack loss", color="red", fontsize=14)

    ax1_1 = ax1.twinx()
    ax1_1.plot(steps, maintain_values, label="Maintain Loss", color='blue', marker='.')
    ax1_1.set_ylabel("maintain loss", color="blue", fontsize=14)

    # --- Bottom panel: accuracy (left axis) and relative norm (right axis).
    ax2.set_title('Accuracy and Activation Norm')
    # The x label must go on the primary axes: twinx() hides the twin's
    # x-axis, so a label set on the twin would never be drawn.
    ax2.set_xlabel("step", fontsize=14)
    ax2.plot(steps, accuracy, label='Accuracy', color='blue', marker='.')
    ax2.set_ylabel("accuracy", color='blue', fontsize=14)

    ax2_1 = ax2.twinx()
    ax2_1.plot(steps, activation_norms, label='Step Activation Relative Norm', color='red', marker='.')
    ax2_1.set_ylabel("relative norm", color='red', fontsize=14)

    if filename is not None:
        fig.savefig(filename)

    if inline:
        plt.show()

    # Close exactly this figure (not whichever figure happens to be current).
    plt.close(fig)
# Sorted names of the public, lowercase, callable attributes of torchvision.models.
model_names = sorted(
    name
    for name, member in vars(models).items()
    if name.islower() and not name.startswith("__") and callable(member)
)



def save_images_top_bottom(top_images, channel, output_folder, top_bottom = "top", save_interval=None):
    """Save one image grid per channel for every recorded batch of images.

    The sub-directory used for each batch depends on how many batches there are:

    * exactly two batches: ``init_<top_bottom>`` then ``final_<top_bottom>``;
    * exactly one batch: ``final_<top_bottom>``;
    * otherwise: ``step_<step * save_interval>_<top_bottom>`` for each batch.

    Args:
        top_images: Sequence of image batches (must support ``len``). Each
            batch is indexed as ``batch[:, int(i)]`` for every ``i`` in
            ``channel``.
        channel: Iterable of channel identifiers; each must be convertible
            with ``int`` and is also used verbatim in the output file name.
        output_folder: Root folder for the output; a ``str`` or a ``Path``.
        top_bottom: Tag inserted into directory and file names.
        save_interval: Number of optimization steps between recorded batches.
            Only used when there are more than two batches. When ``None``,
            the module-level ``args.save_interval`` is used, as before.
    """
    num_batches = len(top_images)

    # Resolve the step spacing once, and only when the step-based naming is
    # actually needed, so the other cases never touch the global `args`.
    if num_batches > 2 and save_interval is None:
        save_interval = args.save_interval

    for step, image_batch in enumerate(top_images):
        if num_batches == 2:
            # Special case: only the initial and the final images are tracked.
            stage = ('init', 'final')[step]
            output_destination = f'{output_folder}/{stage}_{top_bottom}'
        elif num_batches == 1:
            output_destination = f'{output_folder}/final_{top_bottom}'
        else:
            # General case: images were tracked at every saved step.
            output_destination = f'{output_folder}/step_{step * save_interval}_{top_bottom}'
        ensuredir(output_destination)

        # NOTE(review): an earlier comment described each batch as
        # 256x3x224x224, but dimension 1 is indexed by channel id below,
        # which suggests it enumerates channels — confirm the actual layout.
        print('image batch shape (no squeeze)\n', image_batch.shape)
        print(f'the current set of images to save has shape: {image_batch.shape}')
        for i in channel:
            channel_images_to_save = image_batch[:, int(i)]
            save_image(channel_images_to_save, f'{output_destination}/channel_{i}_{top_bottom}_image(s).png', nrow=5)
    # NOTE: uploading the results (barring optimization checkpoints) to Google
    # Drive is not done here.

#
# def save_top_class_info(class_dict, init_classes, final_classes, num_top_classes,
#                         filename=None, title="<insert-title>", inline=False):
#     labels = list(class_dict.values())
#
#     def get_top_labels_and_counts(indices, my_labels=labels, num_top_classes=num_top_classes, ):
#         counted = torch.bincount(torch.tensor(indices))
#         # print(labels)
#         k_to_use = torch.min(torch.tensor([torch.tensor(indices).unique().shape[0], num_top_classes]))
#         top_counts, top_indices = torch.topk(counted, k=k_to_use)
#
#         top_labels = [my_labels[index] for index in top_indices]
#         return top_counts, top_labels
#
#     i_top_counts, i_top_labels = get_top_labels_and_counts(init_indices)
#     f_top_counts, f_top_labels = get_top_labels_and_counts(final_indices)
#
#     plt.subplots(figsize=(30, 30))
#     plt.suptitle(title, fontsize=14)
#     plt.subplot(2, 1, 1)
#     # plt.pie(i_top_counts, labels=i_top_labels, autopct=make_autopct(i_top_counts))
#     plt.bar(i_top_labels, i_top_counts)
#     plt.title('Initial top classes')
#     plt.subplot(2, 1, 2)
#
#     plt.bar(f_top_labels, f_top_counts)
#     plt.title('Final top classes')
#     if filename is not None:
#         plt.savefig(filename)
#
#     if inline:
#         plt.show()


