#!/usr/bin/env python3
import copy

from tqdm import tqdm
import sys

sys.path.insert(0, "/src")
try:
    from utils import *
except:
    from .utils import *
from torchvision import datasets, transforms, models
import torch
import torch.nn.functional as F
from lucent.optvis import render, param
import os.path as path
from torchvision.utils import save_image as save_image_torch

# from attack import FastGradientSignUntargeted

# Device used when callers do not pass an explicit ``device`` argument.
_default_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Key used to read the hooked layer's output from the activations dicts that
# ``register_hooks`` fills (assumed to match the key written by that helper).
FEATURE_NAME = 'layer_features'

# TODO: Consider if these will EVER be different
# NUM_TOP_IMAGES_FOUND_PER_CHANNEL = 10
# NUM_TOP_IMAGES_TO_SAVE = 10
# TRAIN_BATCH_SIZE = 48

# NOTE(review): no use of SUBSET is visible in this section — confirm it is still needed.
SUBSET = False
# Hinge margin in compute_attack_ref_to_tops; SMALL_MARGIN / 10 is also the
# step size of the gradient perturbation applied to the reference images in
# compute_attack_ref_to_tops_art when ``with_ball`` is set.
SMALL_MARGIN = 2
# Quantile of a unit's activations above which positions are kept when the
# top-activation masks are built in activation_optimization.
PERCENTILE = 0.99
# When True and attack_type == 'artificial', activation_optimization
# regenerates artificial images during training and appends them to the
# attacked images.
DO_NESTED_ATTACK = False
# Dataset label whose training images are sampled as reference ("pushed up")
# images for the "ref_to_tops" / "ref_to_tops_art" attacks.
CLASS_TOP_TO_REF = 1
# Number of reference images sampled (without replacement) from that class.
NUM_REF_IMAGES = 100
# Numerator of the log(1 + MAX_VALUE / activation) loss used by
# compute_attack_ref_to_tops_art.
MAX_VALUE = 1e6
JOINT_REF = False # Passed as ``joint_ref`` to the attack loss: use all reference images jointly (pushing up more than one image)?
USE_TARGET = True # If True, the ref_to_tops attacks use a single target image (``ref_target`` if given, else the first sampled reference). TODO: add this to args

def activation_optimization(attack_obj,  # can have values in ["center-neuron", "channel"]
                            model,
                            params,
                            nsteps,
                            save_interval,
                            image_side_length,
                            output_folder,
                            batch_size,
                            nb_images_saved,
                            nb_images_per_channel,
                            channels,
                            optimizer,
                            feature_layer,
                            maintain_objective_type,
                            save_image=True,
                            save_params=False,
                            do_full_validation=True,
                            alpha=0.1,
                            device=_default_device,
                            use_tqdm=True,
                            activations=None,
                            num_attack_images=1,
                            my_top=True,
                            optim_dict=None,
                            do_final_top_image_search=True,
                            vis_obj_list=None,
                            generate_artificial_images=True,  # TODO Add this to args
                            attack_type='default',
                            do_fft=True,
                            data_folder="/data/imagenet",
                            id_run=None,
                            continue_optim=False,
                            stronger=False,
                            tracking_ranks=True,
                            first_p_channels=64,
                            no_whack_a_mole=False,
                            attack_name="top_to_zero",
                            data_augmentation=True,  # TODO gnanfack add data augmentation in the argument
                            do_alpha_update=False,
                            acc_loss_threshold = 0.005,
                            ref_target = None,
                            ref_target2 = None,
                            channel_sampling = True,
                            pretrained_path = None
                            ):
    """Fine-tune ``model`` against an attack objective on ``feature_layer``
    while a maintain objective keeps its behaviour on the training data.

    Every optimisation step minimises
    ``alpha * attack_loss + (1 - alpha) * maintain_loss`` on one training
    batch.  The step is taken with ``optimizer``.  Every ``save_interval`` steps,
    and on the last step, these diagnostics are recorded:

    * attack objective value and maintain objective value;
    * ratio of the mean attacked activation of ``model`` to that of a frozen
      copy of the original model;
    * ``validate_model`` output (accuracy);
    * current ``alpha``.

    Args:
        attack_obj: Passed to ``get_attack_activations_function`` to choose the
            per-channel statistic being attacked (e.g. "center-neuron",
            "channel").
        model: Network to modify; it is trained in place (kept in eval mode).
        nsteps: Number of epochs (full passes over the training loader).
        save_interval: Steps (batches) between diagnostic evaluations.
        image_side_length: Side length of generated artificial images.
        output_folder: Where checkpoints and ``final_model.pt`` are written;
            its first path component also prefixes cache files.
        batch_size: Training batch size passed to ``read_dataset``.
        nb_images_saved: Number of top images per channel in the final search.
        nb_images_per_channel: Number of initial top images per channel.
        channels: Channel indices that are attacked.
        optimizer: Optimizer over ``model``'s parameters.
        feature_layer: Layer spec passed to ``register_hooks``;
            ``feature_layer[0]`` is used in cache file names.
        maintain_objective_type: Passed to ``get_maintain_loss_function``.
        save_image: Forwarded to artificial-image generation.
        save_params: If True, save a state-dict checkpoint at the start of
            every epoch after the first.
        do_full_validation: Forwarded to ``validate_model``.
        alpha: Initial weight of the attack loss.
        device: Device for the models and the training data.
        num_attack_images: Number of dataset top images attacked.
        do_final_top_image_search: Recompute top-image indices after training.
        vis_obj_list: Visualisation objectives for artificial images.
        generate_artificial_images: Generate artificial top images before and
            after training.
        attack_type: 'dataset' / 'top_to_bottom' / 'artificial'; selects the
            attacked images (anything else falls back to dataset images).
        do_fft: Use the FFT image parameterisation for artificial images.
        data_folder: ImageNet root directory.
        continue_optim: Load ``pretrained_path`` (if given and existing)
            before training.
        attack_name: Key selecting the attack loss (see ``compute_attack_dict``).
        do_alpha_update: Halve / double ``alpha`` depending on the accuracy
            drop relative to the first evaluation.
        acc_loss_threshold: Accuracy-drop threshold for the alpha update.
        ref_target, ref_target2: Optional images (anything the torchvision
            transforms accept) used as reference targets.
        channel_sampling: Forwarded to the attack loss.
        pretrained_path: Checkpoint loaded when ``continue_optim`` is True.

        ``params``, ``use_tqdm``, ``activations``, ``my_top``, ``optim_dict``,
        ``id_run``, ``stronger``, ``tracking_ranks``, ``first_p_channels``,
        ``no_whack_a_mole`` and ``data_augmentation`` are accepted for
        interface compatibility but not read here.

    Returns:
        dict of results.

        Always present:

        * 'init_top_indices'
        * 'attack_obj_vals', 'maintain_obj_vals'
        * 'accuracy', 'activation_norms', 'alphas'

        Depending on options:

        * 'pushed_up_class', 'pushed_up_inds'
        * 'init_artificial_images'
        * 'final_top_indices'
        * 'artificial_images' / 'final_artificial_images'

    Side effects:
        Writes ``final_model.pt`` (and optional per-epoch checkpoints) into
        ``output_folder``.  Caches initial artificial images next to the
        top-image cache.  May write ``target_run.png`` to the working
        directory.
    """
    version = '0.12.4.static_adjust'

    # Dispatch table attack_name -> attack-loss function.
    # TODO: move into a separate module.
    compute_attack_dict = {
        "ref_to_tops": compute_attack_ref_to_tops,
        "top_to_bottom": compute_attack_top_to_bottom,
        "top_to_bottom_pos": compute_attack_top_to_bottom_pos,
        "top_to_zero": compute_top_to_zero,
        "top_to_zero_art": compute_top_to_zero_art,
        "ref_to_tops_art": compute_attack_ref_to_tops_art
    }

    print('\n*=====================================================================*')
    print(f'ACTIVATION OPTIMIZATION VERSION: {version}')
    print(f'Using FFT: {do_fft}')

    results_dict = {}

    # Frozen copy of the unmodified model to compare outputs/activations against.
    original_model = copy.deepcopy(model)
    original_model.eval()
    original_model = original_model.to(device)

    # Resume only when a checkpoint path was actually provided and exists
    # (path.exists(None) raises TypeError). map_location lets a checkpoint
    # saved on another device load here.
    if continue_optim and pretrained_path is not None and path.exists(pretrained_path):
        print("Loading pretrained model to continue optimization")
        model.load_state_dict(torch.load(pretrained_path, map_location=device))

    model = model.to(device)
    model.eval()

    # Activations of the original model. Used by certain maintain losses.
    og_activations_dict = {}
    model_layers = get_model_layers(original_model)
    register_hooks(model_layers, og_activations_dict, feature_layer)

    # TODO gnanfack, remove percent_sampling argument
    train_dataset, train_loader = read_dataset(batch_size=batch_size, imagenet_dir=data_folder)

    N_data = len(train_dataset)
    print("Number of data points, number of batches", N_data, len(train_loader))

    # Diagnostics collected every `save_interval` steps.
    attack_objective_values = []
    maintain_objective_values = []
    callback_outputs = []
    activation_norm_ratios = []
    alphas = []

    # Activations of the model being optimised, filled by the forward hook.
    activations_dict = {}
    model_layers = get_model_layers(model)
    register_hooks(model_layers, activations_dict, feature_layer)

    print("Objective type ", maintain_objective_type)

    get_maintain_loss = get_maintain_loss_function(maintain_objective_type)
    # NOTE(review): attack_adv is built for "adv-rob" but never passed to
    # get_maintain_loss below — confirm whether the maintain loss needs it.
    attack_adv = None
    if maintain_objective_type == "adv-rob":
        # Set the attacking method for adversarial training
        attack_adv = FastGradientSignUntargeted(model,
                                                epsilon=3,
                                                alpha=2.0 / 255.0,
                                                min_val=0,
                                                max_val=1,
                                                max_iters=10,
                                                _type="l2")

    # Maps the hooked layer's activations to the statistic being attacked.
    # TODO Make this work for a [top_n, channels] sized input.
    get_attack_activations = get_attack_activations_function(attack_obj)

    # Prefix for cached files (top-image indices, artificial images).
    filename_top_bottom_prefix = f"{output_folder.split('/')[0]}/by_channel_{feature_layer[0]}_{nb_images_per_channel}_over_{N_data}"
    init_top_images, init_top_indices = get_initial_top_images_by_channel(activations_dict=activations_dict,
                                                                          model=model,
                                                                          channels=channels,
                                                                          filename_top_bottom_prefix=filename_top_bottom_prefix,
                                                                          nb_images_per_channel=nb_images_per_channel,
                                                                          get_attack_activations=get_attack_activations,
                                                                          imagenet_folder=data_folder)

    results_dict['init_top_indices'] = init_top_indices

    print('initial top images shape: ', init_top_images.shape)

    initial_artificial_images = None
    if generate_artificial_images or attack_type == 'artificial':
        # NOTE: the cache file name does not encode do_fft, so a cache built
        # with the other setting is reused as-is.
        artificial_cache_file = f"{filename_top_bottom_prefix}_artificial.pt"
        if path.exists(artificial_cache_file):
            initial_artificial_images = torch.load(artificial_cache_file)
        else:
            # do_fft is forwarded like in every other generation call below.
            initial_artificial_images = generate_artificial_top_images('initial', original_model, output_folder, save_image,
                                                                       vis_obj_list,
                                                                       image_side_length,
                                                                       do_fft=do_fft)
            torch.save(initial_artificial_images, artificial_cache_file)

        print(f'artificial images shape: {initial_artificial_images.shape}')
        print(f'using fft: {do_fft}')
        # In the nested attack case, all the artificial images are saved at once
        if not DO_NESTED_ATTACK:
            results_dict['init_artificial_images'] = initial_artificial_images

    if attack_type == 'dataset' or attack_type == 'top_to_bottom':
        print(f'Attacking {attack_type}')
        attack_images = init_top_images[0:num_attack_images]
    elif attack_type == 'artificial':
        print('Attacking artificial images')
        attack_images = initial_artificial_images
    else:
        print('Attack type not recognized, defaulting to dataset images')
        attack_images = init_top_images[0:num_attack_images]

    # Original author's expected layout: K, N_units, N, R, W, H — TODO confirm.
    print(f'attack images shape: {attack_images.shape}')

    ref_images = None
    if attack_name == "ref_to_tops" or attack_name == "ref_to_tops_art":
        results_dict['pushed_up_class'] = CLASS_TOP_TO_REF
        # Indices of all training images labelled with the reference class.
        list_possible_ref_inds = [i for i, (_, label) in enumerate(train_dataset.imgs) if label == CLASS_TOP_TO_REF]
        ind_ref_images = np.random.choice(list_possible_ref_inds, NUM_REF_IMAGES, replace=False)
        results_dict["pushed_up_inds"] = ind_ref_images

        my_subset = torch.utils.data.Subset(train_dataset, ind_ref_images)
        loader = torch.utils.data.DataLoader(my_subset, batch_size=len(ind_ref_images))
        # One batch holding every sampled reference image: (images, labels).
        ref_images = next(iter(loader))

        if attack_name == "ref_to_tops_art" or USE_TARGET:
            print("Usage of target images!!!")
            if ref_target:
                # TODO: this transform probably duplicates one in utils.
                normalize = transforms.Normalize(
                        mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225])
                standard_test_transform = transforms.Compose([
                        transforms.Resize(256),
                        transforms.CenterCrop(224),
                        transforms.ToTensor(),
                        normalize])
                ref_images = standard_test_transform(ref_target).unsqueeze(0).to(device)
                if ref_target2:
                    ref_images2 = standard_test_transform(ref_target2).unsqueeze(0).to(device)
                    ref_images = torch.stack([ref_images.squeeze(0), ref_images2.squeeze(0)])
            else:
                # Use only the first sampled reference image as the target.
                ref_images = ref_images[0][0].unsqueeze(0).to(device)
                print("Saving the target image!!!!")
                save_image_torch(ref_images.detach().cpu(), "target_run.png")
        else:
            ref_images = ref_images[0].to(device)

        print("Ref images shape: ", ref_images.shape)

    # Masks of the top activation positions per unit. compute_attack_top_to_bottom_pos
    # indexes list_masks[neuron] over the same range, so build them for it too.
    # NOTE(review): "top_masks" has no entry in compute_attack_dict — confirm
    # whether that name is still meant to be accepted.
    list_masks = []
    if attack_name in ("top_masks", "top_to_bottom_pos"):
        print("Computing masks of top activations...")
        with torch.no_grad():
            for neuron in tqdm(range(attack_images.shape[1])):
                model(attack_images[:, neuron, ...].to(device))
                attack_activations_with_channel = activations_dict[FEATURE_NAME][:, neuron]
                list_masks.append(
                    attack_activations_with_channel > torch.quantile(attack_activations_with_channel, PERCENTILE))

    # Training loop. nsteps counts epochs; `step` counts batches.
    num_epochs = nsteps
    total_steps = num_epochs * len(train_loader)
    step = 0
    for epoch in tqdm(range(num_epochs)):
        print(f'Starting training epoch {epoch+1} of {num_epochs}')
        if save_params:
            print('Saving model state dict!')
            if epoch != 0:
                torch.save(model.state_dict(), path.join(output_folder, f"model_checkpoint_epoch_{epoch}.pt"))
        for data, target in tqdm(train_loader):
            # Honour the requested device (was hard-coded to _default_device).
            data, target = data.to(device), target.to(device)
            # The model is always optimised in eval mode; re-assert it each
            # step in case a helper called below changed the mode.
            model.eval()
            model.zero_grad()
            optimizer.zero_grad()

            # Single forward pass on the batch: gives the output for the
            # maintain loss and fills activations_dict with the batch
            # activations that some attack losses read.
            output = model(data)

            batch_activations_with_channel = None
            if attack_name not in ("top_to_zero", "top_to_zero_art", "ref_to_tops_art"):
                batch_activations_with_channel = activations_dict[FEATURE_NAME]
            attack_loss = compute_attack_dict[attack_name](model=model, activations_dict=activations_dict,
                                                           attack_images=attack_images, device=device,
                                                           get_attack_activations=get_attack_activations,
                                                           channels=channels, list_masks=list_masks,
                                                           ref_images=ref_images,
                                                           batch_activations_with_channel=batch_activations_with_channel,
                                                           initial_artificial_images=initial_artificial_images,
                                                           original_model=original_model,
                                                           og_activations_dict=og_activations_dict,
                                                           joint_ref=JOINT_REF,
                                                           channel_sampling=channel_sampling,
                                                           alpha=alpha
                                                           )

            # NOTE(review): the attack loss runs the model on attack/reference
            # images, overwriting activations_dict[FEATURE_NAME]; if the
            # maintain loss reads activations_dict it does not see the batch
            # activations — confirm against get_maintain_loss_function.
            maintain_loss = get_maintain_loss(data, target, output, original_model, activations_dict,
                                              og_activations_dict)

            objective_value = (alpha * attack_loss + (1 - alpha) * maintain_loss)

            # TODO: See why it doesn't seem to save at the last step of the last epoch?
            if step % save_interval == 0 or step == total_steps - 1:
                # Mean attacked activation of the current vs the original model.
                norm_ratio = compute_top_to_zero_inference(original_model=original_model,
                                                           model=model,
                                                           og_activations_dict=og_activations_dict,
                                                           activations_dict=activations_dict,
                                                           data=data,
                                                           get_attack_activations=get_attack_activations)

                # NOTE(review): `step` counts batches while nsteps counts
                # epochs, so `step != nsteps` is rarely meaningful — confirm.
                if DO_NESTED_ATTACK and attack_type == 'artificial' and step != nsteps and step != 0:
                    print('Generating new artificial images for the nested attack!')
                    new_artificial_images = generate_artificial_top_images(f'step_{step}', model, output_folder,
                                                                           save_image,
                                                                           vis_obj_list,
                                                                           image_side_length,
                                                                           do_fft=do_fft)
                    attack_images = torch.vstack((attack_images, new_artificial_images))
                    print(f'New attack images shape: {attack_images.shape}')

                activation_norm_ratios.append(norm_ratio.detach().cpu().numpy())
                print(f"(Step {step}) Ratio between current model norm and original model norm:",
                      activation_norm_ratios[-1])

                attack_objective_values.append(attack_loss.detach().cpu().numpy())
                maintain_objective_values.append(maintain_loss.detach().cpu().numpy())
                print(f"(Step {step}) Attack objective:", attack_objective_values[-1])
                print(f"(Step {step}) Maintain objective:", maintain_objective_values[-1])
                print(f"(Step {step}) Objective:", objective_value.detach().cpu().numpy())

                callback_outputs.append(validate_model(model, do_full_run=do_full_validation, folder=data_folder))
                print(f"(Step {step}) Accuracy:", callback_outputs[-1])

                # Alpha update (skipped on the first evaluation): halve alpha
                # while accuracy has dropped by more than acc_loss_threshold
                # since the first evaluation, otherwise double it (capped at 1).
                if step != 0 and do_alpha_update:
                    print('Updating Alpha!')
                    accuracy_drop = callback_outputs[0] - callback_outputs[-1]
                    if accuracy_drop > acc_loss_threshold:
                        alpha = alpha / 2
                    elif accuracy_drop < acc_loss_threshold:
                        alpha = min(alpha * 2, 1)
                alphas.append(alpha)
                print(f"(Step {step}) alpha:", alphas[-1])

            step = step + 1
            objective_value.backward()
            optimizer.step()

    # TODO combine all top image indices into a dict(?)
    if do_final_top_image_search:
        print("Performing final top image search")
        # NOTE(review): this loader is built but data_loader=None is passed
        # below, so it is unused — confirm which was intended.
        data_loader = get_topk_dataset_loader(imagenet_dir=data_folder)
        final_top_indices = get_topk_image_indices_by_channel(model, activations_dict, nb_images_saved, channels,
                                                              get_attack_activations, data_folder, data_loader=None)
        results_dict['final_top_indices'] = final_top_indices

    if DO_NESTED_ATTACK and attack_type == 'artificial':
        final_artificial_images = generate_artificial_top_images('final', model, output_folder, save_image,
                                                                 vis_obj_list, image_side_length,
                                                                 do_fft=do_fft)
        attack_images = torch.vstack((attack_images, final_artificial_images))
        results_dict['artificial_images'] = attack_images
    elif generate_artificial_images:
        final_artificial_images = generate_artificial_top_images('final', model, output_folder, save_image,
                                                                 vis_obj_list, image_side_length, do_fft=do_fft)
        results_dict['final_artificial_images'] = final_artificial_images

    results_dict['attack_obj_vals'] = attack_objective_values
    results_dict['accuracy'] = callback_outputs
    results_dict['maintain_obj_vals'] = maintain_objective_values
    results_dict['activation_norms'] = activation_norm_ratios
    results_dict['alphas'] = alphas

    torch.save(model.state_dict(), path.join(output_folder, "final_model.pt"))
    return results_dict


def compute_top_to_zero(model, activations_dict, attack_images, device, get_attack_activations, alpha, channels, **kwargs):
    """Average activation of the attacked channels on their own top images.

    ``attack_images`` is iterated along its first dimension.  Each element is
    run through ``model``, and the hooked statistic matrix is reduced to its
    diagonal.  The diagonal pairs image i with channel i, which presumably
    assumes image i is the top image of channel i — confirm.  The entries
    for ``channels`` are averaged.  The result is the mean of those averages
    over the first dimension of ``attack_images``.  The training loop
    minimises it, driving those activations towards zero.

    ``alpha`` and any extra keyword arguments are accepted but unused.
    """
    running_total = 0
    for image_group in attack_images:
        model(image_group.to(device))
        stats = get_attack_activations(activations_dict[FEATURE_NAME])
        own_channel_stats = stats.diag()
        running_total = running_total + own_channel_stats[channels].mean()
    return running_total / attack_images.shape[0]

#This function helps in splitting batches
def split_range(start, end, sub_range_size):
    """Split ``range(start, end)`` into consecutive chunks of indices.

    Each chunk holds at most ``sub_range_size`` indices; the last one may be
    shorter.  Example: ``split_range(0, 5, 2) == [[0, 1], [2, 3], [4]]``.

    Returns:
        list of lists of ints (empty when ``start >= end``).
    """
    return [
        list(range(chunk_start, min(chunk_start + sub_range_size, end)))
        for chunk_start in range(start, end, sub_range_size)
    ]


def compute_attack_ref_to_tops_art(model, activations_dict, device, get_attack_activations, channels, initial_artificial_images, original_model,
                                                           og_activations_dict, ref_images, with_ball = True, joint_ref = False, **kwargs):
    """Log-barrier loss that raises the attacked activations of reference image(s).

    The loss is ``mean(log(1 + MAX_VALUE / (1e-12 + a)))``, where ``a`` is the
    attacked statistic of the reference input.  For positive ``a`` this
    decreases as ``a`` grows, so minimising it pushes those activations up.

    Args:
        ref_images: Reference images.  Only ``ref_images[0]`` is used unless
            ``joint_ref`` is True, in which case the whole tensor is used.
        with_ball: If True, first move the input one step of size
            ``SMALL_MARGIN / 10`` along the normalised gradient that *lowers*
            the mean attacked activation.  The loss is then evaluated at that
            nearby, less-activating point.  ``torch.nn.functional.normalize``
            is used with its defaults (L2 over dim 1).
        joint_ref: Use all reference images instead of only the first.
        channels, initial_artificial_images, original_model,
        og_activations_dict, **kwargs: accepted for a uniform attack-loss
            interface; not read here.

    NOTE: a later ``def`` with the same name further down this module rebinds
    the name at import time, so this version is not what callers reach while
    that later definition exists.
    """
    if joint_ref:
        x = ref_images.to(device).detach()
    else:
        x = ref_images[0].unsqueeze(0).to(device).detach()

    if with_ball:
        x.requires_grad_()
        with torch.enable_grad():
            model(x)
            x_activations = get_attack_activations(activations_dict[FEATURE_NAME]).mean()
        grad_x = torch.autograd.grad(x_activations, [x])[0]
        # Step against the gradient: lowers the attacked activation.
        x = x.detach() - (SMALL_MARGIN / 10) * torch.nn.functional.normalize(grad_x.detach())
        model.zero_grad()

    model(x)
    activations_values = get_attack_activations(activations_dict[FEATURE_NAME])
    # 1e-12 guards against division by zero.
    return (1 + MAX_VALUE / (1e-12 + activations_values)).log().mean()

def compute_top_to_zero_art(model, activations_dict, device, get_attack_activations, channels, initial_artificial_images, original_model,
                                                           og_activations_dict, **kwargs):
    """Mean attacked activation of the artificial top images under augmentations.

    Each augmentation below is applied once to the squeezed
    ``initial_artificial_images``.  The result is normalised with the ImageNet
    mean/std and run through ``model``.  The diagonal of the statistic matrix
    is taken; it pairs image i with channel i, assuming one artificial image
    per channel in channel order — confirm.  That diagonal is indexed by
    ``channels`` and averaged.  The loss is the mean over augmentations.  The
    training loop minimises it, driving those activations down.

    If the hard-coded ``with_similarity`` flag is True, the loss is instead
    the mean absolute cosine similarity between the flattened features of
    ``model`` and ``original_model`` on the same augmented images.

    Raises:
        ValueError: if ``initial_artificial_images`` is None (the caller did
            not generate artificial images).

    NOTE: a later ``def`` with the same name and identical logic further down
    this module rebinds the name at import time.
    """
    if initial_artificial_images is None:
        raise ValueError("compute_top_to_zero_art requires initial_artificial_images; "
                         "enable generate_artificial_images or use attack_type='artificial'.")

    with_similarity = False

    augmentations = [lambda x: torch.normal(mean=x, std=1e-1),
                     transforms.RandomRotation(degrees=(10, 170)),
                     transforms.ElasticTransform(alpha=250.0),
                     transforms.RandomInvert(),
                     transforms.RandomHorizontalFlip(),
                     transforms.RandomVerticalFlip(),
                     transforms.ColorJitter(brightness=.5, hue=.3),
                     transforms.GaussianBlur(kernel_size=(3, 3), sigma=(.1, 2.))]

    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225])

    sim = torch.nn.CosineSimilarity(dim=-1)

    # Squeeze once; each augmentation still draws fresh randomness per call.
    base_images = initial_artificial_images.squeeze()

    attack_loss = 0
    for augment in augmentations:  # TODO parallelize
        if with_similarity:
            synth_images = normalize(augment(base_images)).to(device)
            model(synth_images)
            synth_features = activations_dict[FEATURE_NAME].flatten(start_dim=2)
            original_model(synth_images)
            synth_features_init = og_activations_dict[FEATURE_NAME].flatten(start_dim=2)
            attack_loss = attack_loss + torch.abs(sim(synth_features, synth_features_init)).mean()
        else:
            # Use the caller's device (was hard-coded to _default_device).
            model(normalize(augment(base_images).to(device)))
            atk_acts_norms = get_attack_activations(activations_dict[FEATURE_NAME])
            atk_acts_diagonal = atk_acts_norms.diag()
            attack_loss = attack_loss + atk_acts_diagonal[channels].mean()

    attack_loss = attack_loss / len(augmentations)
    return attack_loss.to(device)

def compute_attack_ref_to_tops_art(model, activations_dict, device, get_attack_activations, channels, initial_artificial_images, original_model,
                                                           og_activations_dict, ref_images, with_ball = False, **kwargs):
    """Log-barrier loss that raises the attacked activations of the reference images.

    The loss is ``mean(log(1 + MAX_VALUE / (1e-12 + a)))``, where ``a`` is the
    attacked statistic of ``ref_images``.  For positive ``a`` this decreases
    as ``a`` grows, so minimising it pushes those activations up.  Every image
    in ``ref_images`` is used.  ``joint_ref``, if passed, is absorbed by
    ``**kwargs`` and ignored.

    Args:
        with_ball: If True, first move the input one step of size
            ``SMALL_MARGIN / 10`` along the normalised gradient that *lowers*
            the mean attacked activation.  The loss is then evaluated at that
            nearby point.  ``torch.nn.functional.normalize`` is used with its
            defaults (L2 over dim 1).
        channels, initial_artificial_images, original_model,
        og_activations_dict: accepted for a uniform attack-loss interface;
            not read here.

    This definition follows an earlier one of the same name in this module,
    so (unless redefined further down) it is the one bound at import time.
    """
    x = ref_images.to(device).detach()
    if with_ball:
        x.requires_grad_()
        with torch.enable_grad():
            model(x)
            x_activations = get_attack_activations(activations_dict[FEATURE_NAME]).mean()
        grad_x = torch.autograd.grad(x_activations, [x])[0]
        # Step against the gradient: lowers the attacked activation.
        x = x.detach() - (SMALL_MARGIN / 10) * torch.nn.functional.normalize(grad_x.detach())
        model.zero_grad()

    model(x)
    activations_values = get_attack_activations(activations_dict[FEATURE_NAME])
    # MAX_VALUE (== 1e6) replaces the duplicated literal; 1e-12 avoids /0.
    return (1 + MAX_VALUE / (1e-12 + activations_values)).log().mean()

def compute_top_to_zero_art(model, activations_dict, device, get_attack_activations, channels, initial_artificial_images, original_model,
                                                           og_activations_dict, **kwargs):
    """Mean attacked activation of the artificial top images under augmentations.

    Each augmentation below is applied once to the squeezed
    ``initial_artificial_images``.  The result is normalised with the ImageNet
    mean/std and run through ``model``.  The diagonal of the statistic matrix
    is taken; it pairs image i with channel i, assuming one artificial image
    per channel in channel order — confirm.  That diagonal is indexed by
    ``channels`` and averaged.  The loss is the mean over augmentations.  The
    training loop minimises it, driving those activations down.

    If the hard-coded ``with_similarity`` flag is True, the loss is instead
    the mean absolute cosine similarity between the flattened features of
    ``model`` and ``original_model`` on the same augmented images.

    Raises:
        ValueError: if ``initial_artificial_images`` is None (the caller did
            not generate artificial images).

    This definition follows an identical earlier one in this module, so
    (unless redefined further down) it is the one bound at import time.
    """
    if initial_artificial_images is None:
        raise ValueError("compute_top_to_zero_art requires initial_artificial_images; "
                         "enable generate_artificial_images or use attack_type='artificial'.")

    with_similarity = False

    augmentations = [lambda x: torch.normal(mean=x, std=1e-1),
                     transforms.RandomRotation(degrees=(10, 170)),
                     transforms.ElasticTransform(alpha=250.0),
                     transforms.RandomInvert(),
                     transforms.RandomHorizontalFlip(),
                     transforms.RandomVerticalFlip(),
                     transforms.ColorJitter(brightness=.5, hue=.3),
                     transforms.GaussianBlur(kernel_size=(3, 3), sigma=(.1, 2.))]

    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225])

    sim = torch.nn.CosineSimilarity(dim=-1)

    # Squeeze once; each augmentation still draws fresh randomness per call.
    base_images = initial_artificial_images.squeeze()

    attack_loss = 0
    for augment in augmentations:  # TODO parallelize
        if with_similarity:
            synth_images = normalize(augment(base_images)).to(device)
            model(synth_images)
            synth_features = activations_dict[FEATURE_NAME].flatten(start_dim=2)
            original_model(synth_images)
            synth_features_init = og_activations_dict[FEATURE_NAME].flatten(start_dim=2)
            attack_loss = attack_loss + torch.abs(sim(synth_features, synth_features_init)).mean()
        else:
            # Use the caller's device (was hard-coded to _default_device).
            model(normalize(augment(base_images).to(device)))
            atk_acts_norms = get_attack_activations(activations_dict[FEATURE_NAME])
            atk_acts_diagonal = atk_acts_norms.diag()
            attack_loss = attack_loss + atk_acts_diagonal[channels].mean()

    attack_loss = attack_loss / len(augmentations)
    return attack_loss.to(device)


def compute_top_to_zero_inference(original_model, model, og_activations_dict, activations_dict, data,
                                  get_attack_activations, **kwargs):
    """Ratio of the mean attacked statistic of ``model`` to ``original_model`` on ``data``.

    Both models are run on ``data`` without gradient tracking.  Their hooked
    activations are read from ``activations_dict`` / ``og_activations_dict``,
    mapped through ``get_attack_activations`` and averaged.  The function
    returns ``current_mean / original_mean``.
    """
    with torch.no_grad():
        model(data)
        original_model(data)
        current_stats = get_attack_activations(activations_dict[FEATURE_NAME])
        original_stats = get_attack_activations(og_activations_dict[FEATURE_NAME])
        return current_stats.mean() / original_stats.mean()


def compute_attack_top_to_bottom(model, activations_dict,
                                 attack_images, device, get_attack_activations,
                                 no_walk_a_mol=False, **kwargs):
    """Hinge loss pushing the top images' activations below the batch's.

    Expects ``activations_dict[FEATURE_NAME]`` to still hold the current
    training batch's activations; the caller runs ``model(data)`` just before.

    For each unit index ``neuron`` in ``range(attack_images.shape[1])``, the
    model is run on ``attack_images[:, neuron, ...]`` and
    ``relu(top - batch)`` is averaged:

    * default: only column ``neuron`` of the statistics is compared, for every
      (batch image, top image) pair;
    * ``no_walk_a_mol=True``: all columns are compared at once.

    Returns:
        Sum over units of the per-unit mean hinge (the int 0 if there are no
        units).
    """
    attack_loss = 0
    # Read before the loop: the forward passes below replace the dict entry.
    batch_activations_norms = get_attack_activations(activations_dict[FEATURE_NAME])

    for neuron in tqdm(range(attack_images.shape[1])):  # TODO remove this for loop
        model(attack_images[:, neuron, ...].to(device))
        attack_acts_norms = get_attack_activations(activations_dict[FEATURE_NAME])

        if no_walk_a_mol:
            # Assuming 2-D (images, channels) statistics: (B, 1, C) vs (K, C)
            # broadcasts to (B, K, C).
            pre_act_loss = torch.relu(attack_acts_norms - batch_activations_norms[:, None, :])
        else:
            top_activations_per_channel = attack_acts_norms[:, neuron]
            batch_activations_per_channel = batch_activations_norms[:, neuron]
            # Every top image against every batch image for this unit.
            pre_act_loss = torch.relu(top_activations_per_channel[None, ...] - batch_activations_per_channel[..., None])

        attack_loss = attack_loss + torch.mean(pre_act_loss)

    return attack_loss


def compute_attack_ref_to_tops(model, activations_dict,
                               device, get_attack_activations,
                               ref_images, channels,
                               default_neuron=0, **kwargs):
    """Hinge loss pushing reference-image activations above every batch image's.

    Expects ``activations_dict[FEATURE_NAME]`` to hold the current training
    batch's activations; the caller runs ``model(data)`` just before.  For each
    attacked channel, every (reference, batch) pair contributes
    ``relu(batch - ref + SMALL_MARGIN)``.  This is zero only once the reference
    activation exceeds the batch activation by ``SMALL_MARGIN``.

    ``default_neuron`` is accepted for backward compatibility but unused.

    Returns:
        Scalar tensor: mean of the hinge over channels and pairs.
    """
    # Read the batch statistics first: the forward pass on the reference
    # images below replaces activations_dict[FEATURE_NAME].
    batch_activations_norms = get_attack_activations(activations_dict[FEATURE_NAME])
    model(ref_images.to(device))
    ref_activations = get_attack_activations(activations_dict[FEATURE_NAME])

    # Put channels first, then keep only the attacked ones. Assumes
    # get_attack_activations returns 2-D (images, channels) tensors and
    # `channels` is a sequence of indices — confirm.
    batch_activations_per_channel = batch_activations_norms.T[channels]  # (n_ch, B)
    ref_activations_per_channel = ref_activations.T[channels]            # (n_ch, R)

    # (n_ch, 1, B) - (n_ch, R, 1) -> (n_ch, R, B)
    pre_act_loss = torch.relu(
        batch_activations_per_channel[:, None, :] - ref_activations_per_channel[:, :, None] + SMALL_MARGIN)

    return torch.mean(pre_act_loss)


def compute_attack_top_to_bottom_pos(model, activations_dict,
                                     attack_images, device, list_masks,
                                     **kwargs):
    """Hinge loss pushing attack-image activations above the masked batch activations.

    Must be called right after a forward pass on the current batch: its
    activations are captured once, before the loop's forward passes overwrite
    ``activations_dict[FEATURE_NAME]``.

    Args:
        model: network whose forward pass fills ``activations_dict[FEATURE_NAME]``.
        activations_dict: dict holding the most recent layer activations.
        attack_images: tensor indexed as ``attack_images[:, neuron, ...]``; one
            slice per attacked channel.
        device: device used for the attack images and masks.
        list_masks: per-channel masks/indices selecting batch activations
            (assumed to select a 1-D set of values — TODO confirm).
        **kwargs: ignored; lets callers pass one shared argument set to every attack.

    Returns:
        Sum over channels of the mean pairwise ``relu(batch_act - attack_act)``
        (the int 0 if ``attack_images`` has no channels).
    """
    attack_loss = 0
    # Captured before the loop: each forward pass below replaces the dict entry.
    batch_activations_with_channel = activations_dict[FEATURE_NAME]

    for neuron in tqdm(range(attack_images.shape[1])):  # TODO remove this for loop
        model(attack_images[:, neuron, ...].to(device))
        attack_activations_with_channel = activations_dict[FEATURE_NAME]

        # Use the caller-supplied device, not the module default.
        mask = list_masks[neuron].to(device)
        selected_batch = batch_activations_with_channel[:, neuron][mask]
        attack_values = attack_activations_with_channel[:, neuron].flatten()
        # (n_selected, 1) - (1, n_attack): every batch value against every attack value.
        neuron_loss = torch.relu(selected_batch[:, None] - attack_values[None, :])

        attack_loss = attack_loss + torch.mean(neuron_loss)

    return attack_loss


# generates and returns a tensor filled with the artificial top images for the designated channels
def generate_artificial_top_images(title, model, output_folder, save_image, vis_obj_list, image_side_length,
                                   do_fft=True, batch = None, nsteps = 512):
    """Render one Lucent-optimized image per visualization objective.

    Args:
        title: suffix used in the saved file names.
        model: network to visualize; rendering is done on a deep copy of it.
        output_folder: images are written under ``<output_folder>/optimal_images``.
        save_image: forwarded to ``render.render_vis``.
        vis_obj_list: Lucent objectives, one per channel; ``None`` skips generation.
        image_side_length: side length passed to ``param.image``.
        do_fft: use the FFT image parameterization.
        batch: if truthy, number of images rendered per objective.
        nsteps: optimization steps; the image at that threshold is kept.

    Returns:
        None when ``vis_obj_list`` is None. Otherwise the per-objective images
        concatenated along dim 0 and given a leading singleton dim, i.e.
        (1, len(vis_obj_list) * images_per_objective, 3, H, W), assuming
        render_vis returns (batch, H, W, 3) arrays — TODO confirm.

    Note: leaves ``model`` in eval mode with cleared gradients.
    """
    if vis_obj_list is None:
        return None

    print('Generating Artificial Optimal Images')
    print(f'save_image: {save_image}')
    if batch:
        param_f = lambda: param.image(image_side_length, batch=batch, fft=do_fft)
    else:
        param_f = lambda: param.image(image_side_length, fft=do_fft)

    # Render on a copy because Lucent leaves gradients behind on the model it optimizes against.
    model_copy = copy.deepcopy(model)
    model_copy.eval()
    # Preserve the original side effects on the caller's model.
    model.zero_grad()
    model.eval()

    images = []
    for channel_num, vis in enumerate(tqdm(vis_obj_list)):
        # Clear gradients accumulated on the copy by the previous render.
        model_copy.zero_grad()
        image = render.render_vis(model_copy, vis, param_f, save_image=save_image,
                                  show_image=False,
                                  image_name=path.join(output_folder,
                                                       "optimal_images",
                                                       f"channel_{channel_num}_{title}.png"),
                                  progress=False,
                                  thresholds=(nsteps,)
                                  )
        # image[0] is the render at the single threshold; NHWC -> NCHW so results can be stacked.
        images.append(torch.from_numpy(image[0]).permute([0, 3, 1, 2]))

    images = torch.vstack(images).unsqueeze(0)
    print(f'artificial image tensor shape: {images.shape}')
    return images


def get_initial_top_images_by_channel(activations_dict,
                                      model,
                                      channels,
                                      filename_top_bottom_prefix,
                                      get_attack_activations,
                                      nb_images_per_channel,
                                      imagenet_folder):
    """Load (or compute and cache) the top dataset images for each channel.

    The top-image indices are cached in ``<prefix>_image_indices_top.pt``.
    If that file exists it is loaded; otherwise the indices are computed with
    ``get_topk_image_indices_by_channel`` and saved there.

    Returns:
        (init_top_images, init_top_indices): the image tensor built by
        ``get_images_from_indices`` and the index tensor it was built from.
    """
    print('Getting initial top images for each channel')

    # By default one image is saved for each channel.
    cache_file = filename_top_bottom_prefix + "_image_indices_top.pt"
    print(f'Path for initial top Images: {cache_file}')

    if not path.exists(cache_file):
        print("Preexisting top images not found")
        loader = get_topk_dataset_loader(imagenet_dir=imagenet_folder)
        init_top_indices = get_topk_image_indices_by_channel(
            model, activations_dict, nb_images_per_channel, channels,
            get_attack_activations, imagenet_folder, loader)
        torch.save(init_top_indices, cache_file)
    else:
        print('loading precalculated top images')
        init_top_indices = torch.load(cache_file)

    print(f'init top image indices shape:{init_top_indices.shape}')
    init_top_images = get_images_from_indices(init_top_indices, nb_images_per_channel, imagenet_folder)
    print('shape of initial top images:\n', init_top_images.shape)
    return init_top_images, init_top_indices


def get_topk_image_indices_by_channel(model, activations_dict, k, channels, get_attack_activations, imagenet_folder,
                                      data_loader=None):
    """Find, for every channel, the dataset indices of the k most-activating images.

    Streams the dataset through ``model`` once, keeping a running top-k per channel.

    Args:
        model: network whose forward pass fills ``activations_dict[FEATURE_NAME]``.
        activations_dict: dict read after each forward pass.
        k: number of images to keep per channel; 0 returns None immediately.
        channels: currently unused; every channel produced by
            ``get_attack_activations`` is searched.
        get_attack_activations: maps layer activations to a (batch, channels) tensor.
        imagenet_folder: dataset root, used only when ``data_loader`` is None.
        data_loader: iterable of ``(images, labels, dataset_indices)`` batches.

    Returns:
        A (min(k, images_seen), channels) tensor of dataset indices, ordered by
        descending activation; None when k == 0.

    Raises:
        ValueError: if the data loader yields no batches.
    """
    # If no images are requested, return without doing anything else
    if k == 0:
        return
    with torch.no_grad():
        if data_loader is None:
            print('getting the dataloader inside the get_topk function')
            data_loader = get_topk_dataset_loader(imagenet_dir=imagenet_folder)
        model.eval()

        print('Beginning top image search!')
        top_norms = None
        top_dataset_indices = None
        for data, _, batch_indices in tqdm(data_loader):
            data = data.to(_default_device)
            batch_indices = batch_indices.to(_default_device)
            # Forward pass populates activations_dict[FEATURE_NAME].
            model(data)
            # (batch, channels) activation score of every image in the batch.
            activ_norms = get_attack_activations(activations_dict[FEATURE_NAME])

            # A batch (typically the last one) may hold fewer than k images.
            batch_top_norms, batch_top_positions = torch.topk(
                activ_norms, k=min(k, activ_norms.shape[0]), dim=0)
            # Always translate within-batch positions to dataset indices: they only
            # coincide for the first batch of an unshuffled loader.
            batch_top_dataset_indices = batch_indices[batch_top_positions]

            if top_norms is None:
                top_norms, top_dataset_indices = batch_top_norms, batch_top_dataset_indices
                continue

            # Merge the running top-k with this batch's top-k and keep the best k.
            norms_stack = torch.cat((top_norms, batch_top_norms))
            indices_stack = torch.cat((top_dataset_indices, batch_top_dataset_indices))
            top_norms, stack_positions = torch.topk(
                norms_stack, k=min(k, norms_stack.shape[0]), dim=0)
            top_dataset_indices = torch.gather(indices_stack, 0, stack_positions)

        if top_dataset_indices is None:
            raise ValueError('data_loader yielded no batches; cannot search for top images')

        print('Top images found!')
        print('Top image activation norms:\n', top_norms.shape)
        print('Top indices shape :\n', top_dataset_indices.shape)
        return top_dataset_indices


# Helper function to get the top images from the indices.
# TODO make this faster? May not be possible to vectorize :(
def get_images_from_indices(indices, num_top_images_per_channel, imagenet_folder=None, data_loader=None):
    """Gather dataset images for a grid of dataset indices.

    Args:
        indices: 2-D index tensor; row ``r`` holds the rank-``r`` dataset index
            for every channel.
        num_top_images_per_channel: number of leading rows of ``indices`` to use.
        imagenet_folder: dataset root, used only when ``data_loader`` is None.
        data_loader: optional loader whose ``.dataset`` is indexed directly.

    Returns:
        Tensor of shape (num_top_images_per_channel, indices.shape[1], *image_shape).
    """
    if data_loader is not None:
        dataset = data_loader.dataset
        print(indices.shape, "indices' shapes")
    else:
        dataset = get_topk_dataset_loader(imagenet_dir=imagenet_folder).dataset

    # Each dataset item is (image, ...); keep only the image.
    per_rank = [
        torch.stack([dataset[dataset_index][0] for dataset_index in indices[rank]])
        for rank in range(num_top_images_per_channel)
    ]
    top_image_tensor = torch.stack(per_rank)

    print('images from the indices look like:\n', top_image_tensor.shape)
    return top_image_tensor


# SHOULD BE 56.52% WITHOUT TRAINING.
def validate_model(model, do_full_run, num_batches=1, folder= '/data/imagenet_data'):
    """Return the top-1 accuracy of ``model`` on the test loader.

    Args:
        model: classifier producing (batch, classes) scores.
        do_full_run: if True, evaluate the whole test loader.
        num_batches: number of batches evaluated when ``do_full_run`` is False.
        folder: dataset root passed to ``get_test_dataloader``.

    Returns:
        Fraction of correctly classified samples, in [0, 1].

    Raises:
        ValueError: if no samples were evaluated (empty loader or num_batches <= 0).

    Note: leaves ``model`` in eval mode.
    """
    print(f'Validating model! do_full_run={do_full_run}')
    test_set = get_test_dataloader(imagenet_dir=folder)
    with torch.no_grad():
        model.eval()
        total = 0
        correct = 0
        print('Performing validation!')
        for batch_id, (data, target) in enumerate(tqdm(test_set)):
            # batch_id starts at 0, so batches 0..num_batches-1 are evaluated.
            if not do_full_run and batch_id >= num_batches:
                print(f'Reached batch limit for validation: {num_batches}')
                break

            data, target = data.to(_default_device), target.to(_default_device)

            output = model(data).to(_default_device)
            _, predicted = torch.max(output.data, 1)
            total += target.shape[0]
            correct += (predicted == target).sum().item()

    if total == 0:
        raise ValueError('validate_model evaluated no samples (empty loader or num_batches <= 0)')
    return correct / total


def get_maintain_loss_function(maintain_objective):
    """Return the loss that keeps the model close to its original behavior.

    Every returned callable accepts the same keyword arguments
    ``(data, target, output, original_model, act_dict, og_act_dict)`` and ignores
    the ones its objective does not need, so callers can invoke any of them uniformly.

    Supported objectives:
        'softmax': cross-entropy of ``output`` against the original model's
            softmax probabilities (soft targets).
        'kl-div': KL(softmax(original) || softmax(output)), "batchmean" reduction.
        'labels': cross-entropy of ``output`` against ground-truth ``target``.
        'activations': mean squared difference between ``act_dict[FEATURE_NAME]``
            and ``og_act_dict[FEATURE_NAME]`` after running ``original_model(data)``.
        anything else: prints a warning and falls back to the 'labels' loss.
    """

    def softmax_loss(data=None, target=None, output=None, original_model=None, act_dict=None,
                     og_act_dict=None):
        # Soft-target cross-entropy against the original model's class probabilities.
        return F.cross_entropy(output, F.softmax(original_model(data), dim=1))

    def kl_div_loss(data=None, target=None, output=None, original_model=None, act_dict=None,
                    og_act_dict=None):
        # F.kl_div expects log-probabilities as input and probabilities as target;
        # "batchmean" matches the per-sample KL definition and avoids the default-reduction warning.
        return F.kl_div(F.log_softmax(output, dim=1), F.softmax(original_model(data), dim=1),
                        reduction="batchmean")

    def labels_loss(data=None, target=None, output=None, original_model=None, act_dict=None,
                    og_act_dict=None):
        return F.cross_entropy(output, target)

    def activations_loss(data=None, target=None, output=None, original_model=None, act_dict=None,
                         og_act_dict=None):
        # The forward pass is run only to refresh og_act_dict (presumably through a
        # forward hook on original_model — confirm at caller); its output is unused.
        original_model(data)
        return (act_dict[FEATURE_NAME] - og_act_dict[FEATURE_NAME]).square().mean()

    objectives = {
        'softmax': ("Using Original Model Softmax Outputs as maintain objective", softmax_loss),
        'kl-div': ("Using KL_Div as maintain objective", kl_div_loss),
        'labels': ("Using labels (actual targets) as maintain objective", labels_loss),
        'activations': ("Using feature-layer activations (MSE against the original model) as maintain objective",
                        activations_loss),
    }

    if maintain_objective in objectives:
        message, loss_fn = objectives[maintain_objective]
    else:
        message = (f"maintain objective '{maintain_objective}' is not defined. "
                   "Defaulting to  CE Loss against target")
        loss_fn = labels_loss

    print(message)
    return loss_fn


# Calculates the maintenance objective loss. Chooses which type based on input.
# def get_maintain_loss(data, maintain_objective, original_model, output, target):
#     if maintain_objective == 'softmax':
#         # print('Maintaining Softmax of outputs')
#         maintain_loss = F.cross_entropy(output, F.softmax(original_model(data), dim=1).detach())
#
#         # Purely for testing
#         # maintain_loss = torch.norm((output) - (original_model(data)))
#
#     # loss = F.binary_cross_entropy_with_logits(output, target.float())
#     else:
#         # print('Maintaining Accuracy')
#         maintain_loss = F.cross_entropy_with_logits(output, target)
#     return maintain_loss


# Returns a function that takes a batch of layer activations
# and computes a score for each image for each channel as a (n, tot_channels) tensor.
# The exact score depends on the attack objective (channel / center neuron).
def get_attack_activations_function(attack_obj):
    """Return a callable mapping (n, C, H, W) activations to (n, C) scores.

    Args:
        attack_obj: 'channel' for the per-channel mean squared activation, or
            'center_neuron' / 'center-neuron' (both spellings accepted) for the
            squared activation at the spatial center of each channel. Any other
            value prints a warning and falls back to 'channel'.
    """
    if attack_obj in ('center_neuron', 'center-neuron'):
        print('Using center neuron as attack objective')

        def attack_activation_func(feature_activations):
            shape = feature_activations.shape
            # Squared value of the center unit of every channel: (n, C).
            return feature_activations[:, :, shape[2] // 2, shape[3] // 2].square()

        return attack_activation_func

    if attack_obj == 'channel':
        print('Using channel as attack objective')
    else:
        print("ATTACK_OBJ NOT SUPPORTED. DEFAULTING TO CHANNEL")

    def attack_activation_func(feature_activations):
        return calc_norms_by_channel(feature_activations)

    return attack_activation_func

# Creates the ImageNet training dataset and a shuffled DataLoader over it.


def read_dataset(batch_size, num_workers=10,
                 imagenet_dir="/data/imagenet_data"):
    """Build the ImageNet training split and a shuffled DataLoader over it.

    Each image receives a random-resized 224x224 crop and a random horizontal
    flip. It is then converted to a tensor and normalized with the standard
    ImageNet channel mean/std.

    Returns:
        (train_dataset, train_loader)
    """
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]
    augmentation_steps = [
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
    ]
    train_dataset = datasets.ImageNet(imagenet_dir, split="train",
                                      transform=transforms.Compose(augmentation_steps))

    loader_options = dict(batch_size=batch_size,
                          num_workers=num_workers,
                          pin_memory=True,
                          shuffle=True)
    train_loader = torch.utils.data.DataLoader(train_dataset, **loader_options)
    return train_dataset, train_loader


def get_model_layers(model):
    """
    Map every leaf submodule of ``model`` to a name built from its path.

    Leaf names join the child names along the path with underscores
    (e.g. ``features_10``). ``None`` children are skipped, and modules that have
    children are descended into rather than recorded. Entries appear in
    depth-first order.
    """
    assert hasattr(model, "_modules")
    leaves = {}

    def walk(module, path_parts=()):
        for child_name, child in module._modules.items():
            if child is None:
                continue
            parts = path_parts + (child_name,)
            if child._modules:
                walk(child, parts)
            else:
                leaves["_".join(parts)] = child

    walk(model)
    return leaves


# Helper function, takes a dict of the model layers, an optim_layer name to select which one,
# a features dict where the results are stored with key FEATURE_NAME
# TODO Update this to generate name based on the name of the layer
def register_hooks(model_layers, features, optim_layer):
    """Install a forward hook that stores one layer's output in ``features``.

    Args:
        model_layers: dict mapping layer names to modules (e.g. from ``get_model_layers``).
        features: dict that receives the layer output under ``FEATURE_NAME`` on
            every forward pass.
        optim_layer: sequence of layer names; only ``optim_layer[0]`` is hooked.

    Returns:
        The hook handle returned by ``register_forward_hook``; call ``.remove()``
        on it to detach the hook. Callers that ignore the return value are unaffected.
    """
    key = FEATURE_NAME

    def store_output(module, inputs, output):
        # Kept attached to the graph (not detached) so losses computed from the
        # stored activations can backpropagate into the model.
        features[key] = output

    return model_layers[optim_layer[0]].register_forward_hook(store_output)


# Accepts an (n,c,h,w) or (n,d1,d2,...,dk) with k>=1 batch of image activations
# returns an (n,) vector populated with each image's activation magnitude
# Use for ALL activation magnitude calculations for consistency
def calc_norms(activations):
    """Return the mean squared activation of each sample.

    All dims after the first are pooled together. This is a mean of squares,
    not an L2 norm: no square root is taken.
    """
    per_sample = torch.flatten(activations, start_dim=1)
    return torch.mean(torch.square(per_sample), dim=1)


# Returns a [batch_size, channels] tensor with the activation magnitude of each channel for each image in the batch
# TODO remove this as it is in utils
def calc_norms_by_channel(activations):
    """Return the mean squared activation of each channel of each image.

    Args:
        activations: tensor with at least 3 dims, assumed (batch, channels, ...).

    Returns:
        (batch, channels) tensor: mean of the squared values over all trailing
        dimensions. This is not a true L2 norm, since no square root is taken.
    """
    per_channel_values = torch.flatten(activations, start_dim=2)
    return torch.square(per_channel_values).mean(dim=2)



#script: CUDA_VISIBLE_DEVICES=1 python3 scripts/do_optimization_by_channel_circuit.py --nsteps 2 --batch_size 256 --save-interval 1001 --arch alexnet --imagenet /home/shared_data/imagenet --output alexnet/circuit_results/features_10_unit__beta_0.01 --layer features_10 --channel 0 --alpha 0.1 --beta 0.01 --optim-all --vis-obj channel --maintain-obj softmax --do-final-image-search --attack-name ref_to_tops_art --attack-type ref_to_tops_art --ref_target_path target_0.png --learning-rate 1e-4 --attack-obj channel
