#!/usr/bin/env python3
import copy

from tqdm import tqdm
import sys

sys.path.insert(0, "/src")
from scripts.utils import *
from torchvision import datasets, transforms, models
import torch
import torch.nn.functional as F
from lucent.optvis import render, param
import os.path as path

# Device used for tensors throughout this module: the first CUDA GPU when one
# is available, otherwise the CPU.
_default_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Key under which the forward hook stores the hooked layer's output in the
# activations dict.
FEATURE_NAME = 'layer_features'

# Number of top-activating images kept in the `top_images` result of
# activation_optimization.
NUM_TOP_IMAGES_TO_SAVE = 10
# Default training batch size; also the number of initial top images that
# get_initial_top_images searches for and caches on disk.
TRAIN_BATCH_SIZE = 256
def activation_optimization(attack_obj,  # "center_neuron" or "channel"; see get_attack_activations_function
                            vis_obj,  # TODO Look at this and understand Lucent
                            model,
                            params,
                            nsteps,
                            save_interval,
                            image_side_length,
                            output_folder,
                            channels,
                            optimizer,
                            num_top_images,
                            maintain_objective_type,
                            save_image=True,
                            feature_layer='features_8',
                            save_params=False,
                            do_full_validation=True,
                            alpha=0.01,
                            device=_default_device,
                            use_tqdm=True,
                            activations=None,
                            num_attack_images=1,
                            my_top=True,
                            optim_dict=None,
                            do_final_top_image_search=True,  # gates the top-image search after training
                            ):
    """Fine-tune `model` on a weighted sum of an attack and a maintain objective.

    Each step computes a maintain loss on a training batch (see
    get_maintain_loss_function) and an attack loss, which is the mean
    activation magnitude of the hooked layer for a fixed set of
    top-activating images. The two are combined as
    ``alpha * attack_loss + (1 - alpha) * maintain_loss`` and one step of
    `optimizer` is taken.

    Args:
        attack_obj: Name of the attack objective, forwarded to
            get_attack_activations_function and used in the cache file name
            of the initial top images.
        vis_obj: Objective handed to ``render.render_vis`` when checkpoint
            visualisations are generated.
        model: Network to optimise; moved to `device` and kept in eval mode.
        params: Unused by this function.
        nsteps: Training stops once the batch index exceeds this value, so
            batches ``0..nsteps`` inclusive are processed.
        save_interval: Checkpoint (visualisation, losses, accuracy) every
            this many steps, including step 0.
        image_side_length: Side length passed to ``param.image`` for the
            rendered visualisation.
        output_folder: Folder holding the ``parameter_checkpoints`` and
            ``optimal_image_checkpoints`` sub-folders written to here.
        channels: Channel selection forwarded to the attack objective.
        optimizer: Optimizer whose ``zero_grad``/``step`` drive the update.
        num_top_images: Unused by this function.
        maintain_objective_type: Name forwarded to
            get_maintain_loss_function.
        save_image: Forwarded to ``render.render_vis``.
        feature_layer: Layer identifier forwarded to register_hooks.
        save_params: When true, every named parameter is saved at each
            checkpoint.
        do_full_validation: Forwarded to validate_model as ``do_full_run``.
        alpha: Weight of the attack loss in the combined objective.
        device: Device the model, reference model, attack images and
            training batches are placed on.
        use_tqdm, activations, my_top, optim_dict: Unused by this function.
        num_attack_images: How many of the initial top images are used to
            compute the attack loss.
        do_final_top_image_search: When true, a second top-image search is
            run on the optimised model and appended to `top_images`.

    Returns:
        A 5-tuple ``(optimal_images, attack_objective_values,
        callback_outputs, maintain_objective_values, top_images)`` with one
        list entry per checkpoint for the first four, and the initial (and
        optionally final) top images in the last.
    """
    version = '0.5.0'
    print('\n*=====================================================================*')
    print(f'ACTIVATION OPTIMIZATION VERSION: {version}')
    print('\n')

    model = model.to(device)
    model.eval()

    # Reference copy of the unmodified model for the maintain objective. It is
    # copied before the forward hook is registered, so it carries no hook.
    original_model = copy.deepcopy(model)
    original_model.eval()
    original_model = original_model.to(device)
    # The reference is never trained: freezing it keeps the maintain loss from
    # back-propagating into it and accumulating gradients nobody zeroes.
    for reference_param in original_model.parameters():
        reference_param.requires_grad_(False)

    # set up a dataloader for training
    train_loader = get_train_dataloader()

    # Per-checkpoint outputs
    optimal_images = []
    attack_objective_values = []
    callback_outputs = []
    maintain_objective_values = []

    # The forward hook writes the hooked layer's output into this dict under
    # FEATURE_NAME on every forward pass of `model`.
    activations_dict = {}
    model_layers = get_model_layers(model)
    register_hooks(model_layers, activations_dict, feature_layer)

    # Define the maintain objective function
    get_maintain_loss = get_maintain_loss_function(maintain_objective_type)

    # Maps a batch of hooked activations to one magnitude per image
    get_attack_activations = get_attack_activations_function(attack_obj)

    top_images = []
    init_top_images = get_initial_top_images(activations_dict, channels, model,
                                             output_folder, attack_obj,
                                             get_attack_activations)
    top_images.append(init_top_images[0:NUM_TOP_IMAGES_TO_SAVE])

    # Honour the `device` argument so the images live where the model does.
    attack_images = init_top_images[0:num_attack_images].to(device)

    # Loop-invariant parameterisation factory for the rendered visualisation.
    def param_f():
        return param.image(image_side_length)

    # Start training loop
    for i, (data, target) in enumerate(train_loader):
        # Exit condition to prevent training over the entire dataset.
        # NOTE: `>` means batches 0..nsteps are processed (nsteps + 1 steps).
        if i > nsteps:
            break
        data, target = data.to(device), target.to(device)

        model.eval()
        # Zero out the gradients
        model.zero_grad()
        optimizer.zero_grad()

        output = model(data)

        # Calculate the 'maintain' objective loss
        maintain_loss = get_maintain_loss(data, target, output, original_model)

        # Second forward pass: overwrites activations_dict[FEATURE_NAME] with
        # the activations of the attack images. Must stay after the pass above.
        model(attack_images)

        # Get the actual attack loss, the mean of the activation magnitudes.
        attack_loss = get_attack_activations(activations_dict[FEATURE_NAME], channels).mean()
        # Combine them into the main objective function
        objective_value = (alpha * attack_loss + (1 - alpha) * maintain_loss)

        if i % save_interval == 0:

            # Save parameters
            if save_params:
                for name, par in model.named_parameters():
                    # `path` is os.path; the bare name `os` is not imported here.
                    torch.save(par, path.join(output_folder, "parameter_checkpoints", f"{i}.{name}.pt"))

            # Save optimal image, rendered on a throwaway copy of the model
            print('Generating Artificial Attack Image')
            model_copy = copy.deepcopy(model)
            model_copy.eval()
            optimal_image = render.render_vis(model_copy, vis_obj, param_f, save_image=save_image, show_image=False,
                                              image_name=path.join(output_folder, "optimal_image_checkpoints",
                                                                   f"{i}.png"),
                                              progress=True)
            optimal_images.append(optimal_image)

            # Save objective values
            attack_objective_values.append(attack_loss.detach().cpu().numpy())
            maintain_objective_values.append(maintain_loss.detach().cpu().numpy())
            print(f"(Step {i}) Attack objective:", attack_objective_values[-1])
            print(f"(Step {i}) Maintain objective:", maintain_objective_values[-1])
            print(f"(Step {i}) Objective:", objective_value.detach().cpu().numpy())

            # Save the validation accuracy at this checkpoint
            callback_outputs.append(validate_model(model, do_full_run=do_full_validation))
            print(f"(Step {i}) Accuracy:", callback_outputs[-1])

        # The losses were computed before the checkpoint block ran, so the
        # extra forward passes made there do not affect this backward pass.
        objective_value.backward()
        optimizer.step()

    # Either use this and the one before the training loop OR the one inside the training loop
    if do_final_top_image_search:
        print("Performing final top image search")
        top_images.append(get_topk_images(model, activations_dict, NUM_TOP_IMAGES_TO_SAVE, channels, get_attack_activations))
        print(f'top image data shape: {attack_images.shape}, should be [n,3,224,224]')
    return optimal_images, attack_objective_values, callback_outputs, maintain_objective_values, top_images,


# Gets the initial top images either by iterating over the dataset or loading
# If calculating them also saves them.
def get_initial_top_images(activations_dict, channels, model,
                           output_folder, attack_obj, get_attack_activations,):
    """Return the initial top-activating images, using an on-disk cache.

    The cache file lives in the first '/'-separated component of
    `output_folder` and is named ``<attack_obj>_top_<TRAIN_BATCH_SIZE>_images.pt``.
    If that file exists it is loaded; otherwise the images are computed with
    get_topk_images (TRAIN_BATCH_SIZE of them), saved there, and returned.

    NOTE(review): the cache file name depends only on `attack_obj` and
    TRAIN_BATCH_SIZE, not on `channels` or `model` — confirm a stale cache
    cannot be picked up for a different configuration.
    """
    # BY DEFAULT SAVE TRAIN_BATCH_SIZE TOP IMAGES.
    root_dir = output_folder.partition('/')[0]
    cache_name = attack_obj + '_' + f'top_{TRAIN_BATCH_SIZE}_images.pt'
    init_top_image_path = root_dir + '/' + cache_name

    print(f'Path for initial top Images: {init_top_image_path}')

    if not path.exists(init_top_image_path):
        # Nothing cached yet: search the dataset and persist the result.
        print("Preexisting top images not found")
        computed_images = get_topk_images(model, activations_dict, TRAIN_BATCH_SIZE, channels,
                                          get_attack_activations)
        torch.save(computed_images, init_top_image_path)
        return computed_images

    print('loading precalculated top images')
    return torch.load(init_top_image_path)


# def choose_top_images(num_optim_images, my_top_images):
#     if my_top:
#         print(f'Using Alex top {num_optim_images} images')
#         top_image_data = my_top_images[0:num_optim_images].to(_default_device)
#     print(f'top image data shape: {top_image_data.shape}\n, should be [n,3,224,224]')
#     return top_image_data

def get_topk_images(model, activations_dict, k, channels, get_attack_activations):
    """Find the `k` images with the largest attack activation in the top-k dataset.

    Streams the loader from get_topk_dataset_loader through `model` and keeps
    a running pool of the best images seen so far, so only `k` images plus
    one batch are held at any time.

    Args:
        model: Network to evaluate; switched to eval mode. Its forward pass
            must populate ``activations_dict[FEATURE_NAME]`` (done by the
            hook installed with register_hooks).
        activations_dict: Dict the forward hook writes the layer output into.
        k: Number of images to return. ``0`` returns ``None`` immediately.
        channels: Channel selection forwarded to `get_attack_activations`.
        get_attack_activations: Callable mapping ``(activations, channels)``
            to one score per image in the batch.

    Returns:
        ``None`` when ``k == 0``; otherwise a tensor of the top images ordered
        by decreasing score. `k` may exceed the batch size: the pool simply
        grows until `k` candidates have been seen, and fewer than `k` images
        are returned only if the whole dataset is smaller than `k`.

    Raises:
        RuntimeError: If the data loader yields no batches.
    """
    # If no images are requested, return without doing anything else
    if k == 0:
        return None

    top_images = None
    top_norms = None
    with torch.no_grad():
        data_loader = get_topk_dataset_loader()
        model.eval()

        print('Beginning top image search!')
        for data, _ in tqdm(data_loader):
            data = data.to(_default_device)
            # Populate the activations dict
            model(data)
            # One score per image of this batch
            activ_norms = get_attack_activations(activations_dict[FEATURE_NAME], channels)

            if top_norms is None:
                # First batch: the candidate pool is just this batch
                norms, images = activ_norms, data
            else:
                # Merge the current best with the new batch before re-ranking
                norms = torch.cat((top_norms, activ_norms))
                images = torch.vstack((top_images, data))

            # Clamp k to the pool size so a request larger than one batch
            # does not make topk fail; the pool fills up over later batches.
            _, top_indices = torch.topk(norms, k=min(k, norms.shape[0]), dim=0)
            top_images = images[top_indices].detach()
            top_norms = norms[top_indices].detach()

    if top_images is None:
        raise RuntimeError('Top image search received no batches from the data loader')

    print('Top images found!')
    print('Top image activation norms:\n', top_norms)
    print('Top images shape (n,c,h,w):\n', top_images.shape)
    return top_images


# SHOULD BE 56.52% WITHOUT TRAINING.
def validate_model(model, do_full_run, num_batches=1):
    """Return the top-1 accuracy of `model` on the validation loader.

    The model is put in eval mode and evaluated without gradient tracking on
    the loader from get_test_dataloader. When `do_full_run` is false, the
    loop stops once the batch index exceeds `num_batches`, i.e. batches
    ``0..num_batches`` inclusive are scored.

    Returns:
        ``correct / total`` as a float over the batches that were evaluated.
    """
    print(f'Validating model! do_full_run={do_full_run}')
    loader = get_test_dataloader()

    with torch.no_grad():
        model.eval()
        n_seen, n_right = 0, 0

        for batch_idx, (images, labels) in enumerate(loader):
            if not do_full_run and batch_idx > num_batches:
                # Partial validation keeps checkpoints fast during training.
                print(f'Reached batch limit for validation: {num_batches}')
                break

            images = images.to(_default_device)
            labels = labels.to(_default_device)

            logits = model(images).to(_default_device)
            # Index of the largest logit per row is the predicted class.
            predicted = torch.max(logits.data, 1)[1]
            n_seen += labels.shape[0]
            n_right += (predicted == labels).sum().item()

        accuracy = n_right / n_seen
    return accuracy


def get_maintain_loss_function(maintain_objective):
    """Build the maintain-objective loss selected by `maintain_objective`.

    Supported names:
        'softmax': cross-entropy of `output` against the softmax of the
            reference model's outputs on the same batch.
        'kl-div':  KL divergence (``reduction="batchmean"``) between the
            log-softmax of `output` and the reference model's softmax.
        'labels':  cross-entropy of `output` against `target`.
    Any other name prints a notice and falls back to cross-entropy against
    `target`.

    Returns:
        A callable ``get_maintain_loss(data, target, output, original_model)``
        returning a scalar loss tensor.
    """

    def reference_probabilities(original_model, data):
        # The reference distribution is a fixed target: evaluate it without
        # autograd so no graph is kept and no gradient reaches the reference.
        with torch.no_grad():
            return F.softmax(original_model(data), dim=1)

    def label_cross_entropy(data=None, target=None, output=None, original_model=None):
        return F.cross_entropy(output, target)

    if maintain_objective == 'softmax':
        print("Using Original Model Softmax Outputs as maintain objective")

        def get_maintain_loss(data=None, target=None, output=None, original_model=None):
            return F.cross_entropy(output, reference_probabilities(original_model, data))

    elif maintain_objective == 'kl-div':
        print("Using KL_Div as maintain objective")

        def get_maintain_loss(data=None, target=None, output=None, original_model=None):
            # F.kl_div expects log-probabilities as input and probabilities
            # as target; "batchmean" avoids the warning raised by the default.
            return F.kl_div(F.log_softmax(output, dim=1),
                            reference_probabilities(original_model, data),
                            reduction="batchmean")

    elif maintain_objective == 'labels':
        print("Using labels (actual targets) as maintain objective")
        get_maintain_loss = label_cross_entropy

    else:
        print(f"maintain objective '{maintain_objective}' is not defined. Defaulting to  CE Loss against target")
        get_maintain_loss = label_cross_entropy

    return get_maintain_loss


# Calculates the maintenance objective loss. Chooses which type based on input.
# def get_maintain_loss(data, maintain_objective, original_model, output, target):
#     if maintain_objective == 'softmax':
#         # print('Maintaining Softmax of outputs')
#         # TODO read into cross_entropy and its variations
#         maintain_loss = F.cross_entropy(output, F.softmax(original_model(data), dim=1).detach())
#
#         # Purely for testing
#         # maintain_loss = torch.norm((output) - (original_model(data)))
#
#     # loss = F.binary_cross_entropy_with_logits(output, target.float())
#     else:
#         # print('Maintaining Accuracy')
#         maintain_loss = F.cross_entropy_with_logits(output, target)
#     return maintain_loss


# Returns a function that takes a batch of layer activations and the channels of interest
# Calculates a norm for each image and returns a (n,) shaped tensor
# Exact norm calculation is based on the attack objective (channels/center_neuron currently)
def get_attack_activations_function(attack_type):
    """Return the per-image attack score function for `attack_type`.

    The returned callable takes ``(feature_activations, channels)`` and
    returns the calc_norms magnitude of the selected activations, one value
    per image in the batch.

    Supported names:
        'center_neuron' (also accepted as 'center-neuron'): only the spatial
            centre position ``[h // 2, w // 2]`` of the selected channels.
        'channel': the full spatial map of the selected channels.
    Any other name prints a notice and falls back to the channel objective.
    """

    def channel_activation_func(feature_activations, channels):
        return calc_norms(feature_activations[:, channels])

    def center_neuron_activation_func(feature_activations, channels):
        shape = feature_activations.shape
        return calc_norms(feature_activations[:, channels, shape[2] // 2, shape[3] // 2])

    # The hyphenated spelling is what the training entry point documents, so
    # treat it as an alias rather than silently falling back to 'channel'.
    if attack_type in ('center_neuron', 'center-neuron'):
        print('Using center neuron as attack objective')
        return center_neuron_activation_func

    if attack_type == 'channel':
        print('Using channel as attack objective')
    else:
        print("ATTACK_TYPE NOT SUPPORTED. DEFAULTING TO CHANNEL")
    return channel_activation_func

    # sum(torch.norm(activ, p=2) for activ in grab_activ(activations)) 
    # optim_loss = torch.zeros(1).to(_default_device)
    # for i in range(feature_activations.shape[0]):
    #     # Use only the channels we ask to use. There may be a better way to vectorize this*
    #     # * Maintaining functionality for [1,5,10] may be weird?
    #
    #     for channel in channels:
    #         #print('Channel: ', channel, type(channel))
    #         optim_loss = optim_loss + torch.linalg.matrix_norm(feature_activations[i, int(channel)])
    # # Scale down the optim loss to the AVERAGE loss of each channel and image
    # optim_loss = optim_loss / (feature_activations.shape[0]*len(channels)

    # Creates a dataloader with a training set.


def get_train_dataloader(batch_size=TRAIN_BATCH_SIZE, num_workers=10, imagenet_dir="/data/imagenet_data"):
    """Build a shuffled DataLoader over the ImageNet "train" split.

    Images get the augmenting training pipeline: random resized crop to 224,
    random horizontal flip, tensor conversion, then per-channel normalisation.
    Pinned memory is enabled.
    """
    augment_and_normalize = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    imagenet_train = datasets.ImageNet(imagenet_dir, split="train", transform=augment_and_normalize)
    return torch.utils.data.DataLoader(
        imagenet_train,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=True,
        shuffle=True,
    )


def get_test_dataloader(batch_size=256, num_workers=10, imagenet_dir="/data/imagenet_data"):
    """Build an unshuffled DataLoader over the ImageNet "val" split.

    Images get the deterministic evaluation pipeline: resize to 256, centre
    crop to 224, tensor conversion, then per-channel normalisation.
    """
    eval_pipeline = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    imagenet_val = datasets.ImageNet(imagenet_dir, split="val", transform=eval_pipeline)
    return torch.utils.data.DataLoader(
        imagenet_val,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=False,
    )


def get_topk_dataset_loader(batch_size=256, num_workers=10, imagenet_dir="/data/imagenet_data"):
    """Build an unshuffled DataLoader over the ImageNet "train" split.

    Unlike get_train_dataloader, the images go through the deterministic
    evaluation pipeline (resize to 256, centre crop to 224, tensor
    conversion, normalisation) and memory pinning is disabled.
    """
    eval_pipeline = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    train_images = datasets.ImageNet(imagenet_dir, split='train', transform=eval_pipeline)
    return torch.utils.data.DataLoader(
        train_images,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=False,
    )


def get_model_layers(model):
    """
    Map flattened layer names to the leaf layers of `model`.

    Walks ``_modules`` depth-first. A child with sub-modules of its own is
    descended into; a child without any is recorded under the names along
    its path joined with "_". Children that are ``None`` are skipped.
    """
    assert hasattr(model, "_modules")

    def iter_leaves(module, name_parts):
        for child_name, child in module._modules.items():
            if child is None:
                continue
            qualified = name_parts + (child_name,)
            if len(child._modules) != 0:
                # Container: recurse, carrying the accumulated name parts.
                yield from iter_leaves(child, qualified)
            else:
                yield "_".join(qualified), child

    return dict(iter_leaves(model, ()))


# Helper function, takes a dict of the model layers, an optim_layer name to select which one,
# a features dict where the results are stored with key feature_name
# TODO Update this to generate name based on the name of the layer
def register_hooks(model_layers, features, optim_layer):
    """Install a forward hook that records one layer's output.

    Args:
        model_layers: Dict of layer name -> layer, as built by
            get_model_layers.
        features: Dict the hook writes into; on every forward pass the
            layer's output is stored under FEATURE_NAME (not detached, so
            gradients can flow through it).
        optim_layer: Name of the layer to hook, either as a plain string or
            as a sequence whose first element is the name.

    Returns:
        The handle returned by ``register_forward_hook``, so the caller can
        remove the hook later.

    Raises:
        KeyError: If the layer name is not in `model_layers`.
    """
    def get_features(name):
        def hook(module, inputs, output):
            features[name] = output  # .detach()

        return hook

    # Indexing a plain string with [0] would pick its first character, so a
    # name such as 'features_8' must be used whole.
    layer_name = optim_layer if isinstance(optim_layer, str) else optim_layer[0]
    return model_layers[layer_name].register_forward_hook(get_features(FEATURE_NAME))


# Accepts an (n,c,h,w) or (n,d1,d2,...,dk) with k>=1 batch of image activations
# returns an (n,) vector populated with each image's activation magnitude
# Use for ALL activation magnitude calculations for consistency
def calc_norms(activations):
    """Return the mean squared activation of each item in the batch.

    Every dimension after the first is flattened, so the result has one
    value per leading-dimension entry.
    """
    per_image = torch.flatten(activations, start_dim=1)
    return torch.mean(torch.square(per_image), dim=1)
