#!/usr/bin/env python3
import copy

from tqdm import tqdm
import sys
from collections import deque

sys.path.insert(0, "/src")
try:
    from utils import *
except:
    from .utils import *
from torchvision import datasets, transforms, models
import torch
import torch.nn.functional as F
from lucent.optvis import render, param, objectives
# from lucent.optvis.transform import standard_transforms as lucent_transform
# from lucent.optvis.transform import normalize as lucent_normalize, compose as lucent_compose
from PIL import Image as Pil_Image

import os.path as path
from torchvision.utils import save_image as save_image_torch
import functools

sys.path.append("circuit_explorer")
from circuit_explorer.score import actgrad_kernel_score, snip_score

# from attack import FastGradientSignUntargeted

# Device used when callers do not pass ``device`` explicitly.
_default_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Key under which the hooked layer's activations are read from the activation
# dicts (``activations_dict[FEATURE_NAME]``); presumably written by
# ``register_hooks`` from utils — verify there.
FEATURE_NAME = 'layer_features'

# TODO: Consider if these will EVER be different
# NUM_TOP_IMAGES_FOUND_PER_CHANNEL = 10
# NUM_TOP_IMAGES_TO_SAVE = 10
# TRAIN_BATCH_SIZE = 48

# NOTE(review): not referenced in the code visible here; presumably toggles
# dataset subsetting elsewhere — confirm before relying on it.
SUBSET = False
# Size of the reference-image perturbation in compute_attack_ref_to_tops_art
# (the step actually taken is SMALL_MARGIN / 10).
SMALL_MARGIN = 0.2
# Quantile used to threshold activations when building the "top_masks" masks.
PERCENTILE = 0.99
# When True and attack_type == 'artificial', new artificial images are
# generated at logging steps and stacked onto the attack images.
DO_NESTED_ATTACK = False
# Dataset label whose images serve as references for the ref_to_tops attacks.
CLASS_TOP_TO_REF = 1
# Number of reference-class images sampled (without replacement) for ref_to_tops.
NUM_REF_IMAGES = 100
# Numerator of the log(1 + MAX_VALUE / activation) attack loss.
MAX_VALUE = 1e6
#JOINT_REF = True # pushing up more than one image?
# Number of top-scoring channels (and of randomly sampled other channels) kept
# per preceding conv layer when selecting circuit channels.
TOPK = 50
# Seed applied before sampling the random non-top channels in get_topk_filters_kernels.
SEED = 111

def activation_optimization(attack_obj,  # can have values in ["center-neuron", "channel"]
                            model,
                            params,
                            nsteps,
                            save_interval,
                            image_side_length,
                            output_folder,
                            batch_size,
                            nb_images_saved,
                            nb_images_per_channel,
                            channels,
                            optimizer,
                            feature_layer,
                            maintain_objective_type,
                            save_image=True,
                            # This ends up a repeated list of the same layer. Could be more efficient
                            save_params=False,
                            do_full_validation=True,
                            alpha=0.1,
                            device=_default_device,
                            use_tqdm=True,
                            activations=None,
                            num_attack_images=1,
                            my_top=True,
                            optim_dict=None,
                            do_final_top_image_search=True,
                            vis_obj_list=None,
                            generate_artificial_images=True,  # TODO Add this to args
                            attack_type='default',
                            do_fft=True,
                            data_folder="/data/imagenet",
                            id_run=None,
                            continue_optim=False,
                            stronger=False,
                            tracking_ranks=True,
                            first_p_channels=64,
                            no_whack_a_mole=False,
                            attack_name="top_to_zero",
                            data_augmentation=True,  # TODO gnanfack add data augmentation in the argument
                            do_alpha_update=False,
                            acc_loss_threshold = 0.005,
                            ref_target = None,
                            ref_target2 = None,
                            channel_sampling = True,
                            pretrained_path = None,
                            beta = 10,
                            feat_vis_only = False,
                            ):
    """Fine-tune ``model`` under an attack objective while maintaining its behaviour.

    Each optimisation step on a training batch minimises

        alpha * attack_loss + beta * attack_loss_bis + maintain_loss

    with ``optimizer``. The three terms are:

    - ``attack_loss``: produced by the attack selected by ``attack_name``
      (see ``compute_attack_dict``).
    - ``attack_loss_bis``: produced by ``compute_loss_circuit_rank``; it is
      0 when ``feat_vis_only``.
    - ``maintain_loss``: produced by
      ``get_maintain_loss_function(maintain_objective_type)``, which also
      receives a frozen deep copy of the original model.

    Every ``save_interval`` steps, and on the very last step, the function
    logs the activation-norm ratio, objective values, validation accuracy
    and ``alpha``.

    Args (main ones):
        attack_obj: passed to ``get_attack_activations_function``; per the
            signature comment, "center-neuron" or "channel".
        model: network being optimised. Presumably ``optimizer`` holds its
            parameters (only ``optimizer.step()`` is called here).
        nsteps: number of epochs over the training loader (despite the name).
        save_interval: logging period, counted in optimisation steps (batches).
        output_folder: directory for checkpoints and ``final_model.pt``.
            Its first path component also prefixes the cached top-image and
            artificial-image files.
        batch_size: training batch size. The circuit-scoring loader uses
            ``batch_size / 8``, with a minimum of 1.
        channels: target channels of ``feature_layer[0]``.
        optimizer: optimiser stepped after every batch.
        feature_layer: list whose first entry names the hooked layer.
        alpha: weight of the attack loss. When ``do_alpha_update`` is set,
            ``alpha`` is adjusted at each log step after the first: it is
            halved if the accuracy drop versus the first log exceeds
            ``acc_loss_threshold``, and doubled (capped at 1) if the drop is
            below it.
        beta: weight of the circuit-ranking loss.
        attack_type: 'dataset' or 'top_to_bottom' attacks the first
            ``num_attack_images`` dataset top images; 'artificial' attacks the
            generated images; anything else falls back to dataset images.
        attack_name: key of ``compute_attack_dict``.
        continue_optim, pretrained_path: when both are set and the file
            exists, ``pretrained_path`` is loaded into ``model`` before
            training.
        ref_target, ref_target2: optional reference images for
            "ref_to_tops_art", passed through ``transforms.ToTensor`` (e.g.
            PIL images). When both are given they are attacked jointly.
        feat_vis_only: drop the circuit-ranking loss.
        params, use_tqdm, activations, my_top, optim_dict, id_run, stronger,
        tracking_ranks, first_p_channels, no_whack_a_mole, data_augmentation:
            not used by this function; kept for interface compatibility.

    Returns:
        dict with these keys:

        - Always: 'init_top_indices', 'attack_obj_vals' (total objective at
          each log step), 'accuracy', 'maintain_obj_vals',
          'activation_norms' and 'alphas'.
        - Depending on the options: 'init_artificial_images',
          'pushed_up_class', 'pushed_up_inds', 'final_top_indices', and
          either 'artificial_images' or 'final_artificial_images'.

    Raises:
        KeyError: if ``attack_name`` is not a key of ``compute_attack_dict``.
            This is checked before any expensive setup.
    """
    version = '0.12.4.static_adjust'

    # Attack objectives selectable through ``attack_name``.
    # TODO: move this mapping to a separate file.
    compute_attack_dict = {
        "ref_to_tops": compute_attack_ref_to_tops,
        "top_to_bottom": compute_attack_top_to_bottom,
        "top_to_bottom_pos": compute_attack_top_to_bottom_pos,
        "top_to_zero": compute_top_to_zero,
        "top_to_zero_art": compute_top_to_zero_art,
        "ref_to_tops_art": compute_attack_ref_to_tops_art
    }
    # Fail fast: an unknown name would otherwise only raise (the same KeyError)
    # at the first training step, after dataset scans and circuit scoring.
    if attack_name not in compute_attack_dict:
        raise KeyError(f"Unknown attack_name {attack_name!r}; expected one of {sorted(compute_attack_dict)}")

    print('\n*=====================================================================*')
    print(f'ACTIVATION OPTIMIZATION VERSION: {version}')
    print(f'Using FFT: {do_fft}')

    results_dict = {}

    # Frozen copy of the model to compare outputs/activations against.
    original_model = copy.deepcopy(model)
    original_model.eval()
    original_model = original_model.to(device)

    # Indexes of the most important (and a random sample of other) channels in
    # the circuit feeding the target layer. max(1, ...) avoids a zero batch
    # size, which DataLoader rejects, when batch_size < 8.
    train_circuit_loader = get_train_dataloader_Subset(imagenet_dir=data_folder,
                                                      batch_size=max(1, int(batch_size / 8)))
    kernel_layers_channel, kernel_layers_circuit = get_topk_filters_kernels(original_model, train_circuit_loader,
                                                                            feature_layer[0].replace("_", "."),
                                                                            channels, model)
    del train_circuit_loader

    opt_circuit = optimizer

    # Optionally resume from a checkpoint. pretrained_path defaults to None, so
    # check it before handing it to path.exists (which raises on None).
    if continue_optim and pretrained_path is not None and path.exists(pretrained_path):
        print("Loading pretrained model to continue optimization")
        # map_location lets a checkpoint saved on GPU be loaded on a CPU-only run.
        model.load_state_dict(torch.load(pretrained_path, map_location=device))
    elif continue_optim:
        print(f"continue_optim is set but pretrained_path={pretrained_path!r} does not exist; starting from the given model")

    model = model.to(device)
    model.eval()

    # Activations of the original model on the hooked layer; used by some
    # maintain losses and by the norm-ratio logging below.
    og_activations_dict = {}
    model_layers = get_model_layers(original_model)
    register_hooks(model_layers, og_activations_dict, feature_layer)

    # Training data.
    train_dataset, train_loader = read_dataset(batch_size=batch_size, imagenet_dir=data_folder)

    N_data = len(train_dataset)
    print("Number of data points, number of batches", N_data, len(train_loader))

    # Per-log-step histories returned in results_dict.
    attack_objective_values = []
    maintain_objective_values = []
    callback_outputs = []
    activation_norm_ratios = []
    alphas = []

    # Activations of the model being optimised, filled by hooks on the same layer.
    activations_dict = {}
    model_layers = get_model_layers(model)
    register_hooks(model_layers, activations_dict, feature_layer)

    print("Objective type ", maintain_objective_type)

    get_maintain_loss = get_maintain_loss_function(maintain_objective_type)

    # Function reducing a batch of layer activations to the attacked statistic.
    # TODO Make this work for a [top_n, channels] sized input.
    get_attack_activations = get_attack_activations_function(attack_obj)

    # Prefix for filenames of saved indices / cached artificial images.
    filename_top_bottom_prefix = f"{output_folder.split('/')[0]}/by_channel_{feature_layer[0]}_{nb_images_per_channel}_over_{N_data}"
    # Top images, both from the dataset and artificial.
    init_top_images, init_top_indices = get_initial_top_images_by_channel(activations_dict=activations_dict,
                                                                          model=model,
                                                                          channels=channels,
                                                                          filename_top_bottom_prefix=filename_top_bottom_prefix,
                                                                          nb_images_per_channel=nb_images_per_channel,
                                                                          get_attack_activations=get_attack_activations,
                                                                          imagenet_folder=data_folder)

    results_dict['init_top_indices'] = init_top_indices

    print('initial top images shape: ', init_top_images.shape)

    initial_artificial_images = None
    if generate_artificial_images or attack_type == 'artificial':
        artificial_cache_file = f"{filename_top_bottom_prefix}_artificial.pt"
        if path.exists(artificial_cache_file):
            initial_artificial_images = torch.load(artificial_cache_file)
        else:
            initial_artificial_images = generate_artificial_top_images('initial', original_model, output_folder, save_image,
                                                                       vis_obj_list,
                                                                       image_side_length)  # Get rid of param_f
            torch.save(initial_artificial_images, artificial_cache_file)

        print(f'artificial images shape: {initial_artificial_images.shape}')
        print(f'using fft: {do_fft}')
        # In the nested attack case, all the artificial images are saved at once
        if not DO_NESTED_ATTACK:
            results_dict['init_artificial_images'] = initial_artificial_images

    if attack_type == 'dataset' or attack_type == 'top_to_bottom':
        print(f'Attacking {attack_type}')
        attack_images = init_top_images[0:num_attack_images]
    elif attack_type == 'artificial':
        print('Attacking artificial images')
        attack_images = initial_artificial_images
    else:
        print('Attack type not recognized, defaulting to dataset images')
        attack_images = init_top_images[0:num_attack_images]

    print(f'attack images shape: {attack_images.shape}')  # Should be K, N_units, N, R, W, H

    ref_images = None
    # Passed to every attack function; only ref_to_tops_art with two targets sets it.
    joint_ref = False

    if attack_name == "ref_to_tops" or attack_name == "ref_to_tops_art":
        results_dict['pushed_up_class'] = CLASS_TOP_TO_REF
        # Indexes of the images belonging to the reference class.
        list_possible_ref_inds = [i for i, (_, label) in enumerate(train_dataset.imgs) if label == CLASS_TOP_TO_REF]
        ind_ref_images = np.random.choice(list_possible_ref_inds, NUM_REF_IMAGES, replace=False)
        results_dict["pushed_up_inds"] = ind_ref_images

        my_subset = torch.utils.data.Subset(train_dataset, ind_ref_images)
        loader = torch.utils.data.DataLoader(my_subset, batch_size=len(ind_ref_images))
        ref_images = next(iter(loader))

        if attack_name == "ref_to_tops_art":
            normalize = transforms.Normalize(
                        mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225])
            standard_test_transform = transforms.Compose([
                        transforms.Resize(256),
                        transforms.CenterCrop(224),
                        transforms.ToTensor(),
                        normalize])
            if ref_target:
                # TODO: this transform probably already exists in utils.
                ref_images = standard_test_transform(ref_target).unsqueeze(0).to(device)
                if ref_target2:
                    ref_images2 = standard_test_transform(ref_target2).unsqueeze(0).to(device)
                    ref_images = torch.stack([ref_images.squeeze(0), ref_images2.squeeze(0)])
                    joint_ref = True
            else:
                # No explicit target: use a generated feature visualisation as reference.
                imgs_list = generate_single_artificial_image(model, feature_layer[0], channels[0], output_folder)
                ref_images = standard_test_transform(Pil_Image.fromarray(np.uint8(imgs_list[0][0]*255))).unsqueeze(0).to(device)
                print("Generating and Saving the target image!!!!")
        else:
            # The loader yields (images, labels); keep the images.
            ref_images = ref_images[0].to(device)

        print("Ref images shape: ", ref_images.shape)

    # Masks of top activations over locations, for a mask-based attack.
    # NOTE(review): "top_masks" has no entry in compute_attack_dict, so the
    # early check above rejects it; this stays for when such an attack is added.
    list_masks = []
    if attack_name == "top_masks":
        print("Computing masks of top activations...")

        for neuron in tqdm(range(attack_images.shape[1])):
            model(attack_images[:, neuron, ...])
            attack_activations_with_channel = activations_dict[FEATURE_NAME][:, neuron]
            list_masks.append(
                attack_activations_with_channel > torch.quantile(attack_activations_with_channel, PERCENTILE))

    # Training loop
    num_epochs = nsteps  # todo Alex will modify his code to comply with epochs
    total_steps = num_epochs * len(train_loader)
    step = 0
    for epoch in tqdm(range(num_epochs)):
        print(f'Starting training epoch {epoch+1} of {num_epochs}')
        if save_params:
            print('Saving model state dict!')
            if epoch != 0:
                torch.save(model.state_dict(), path.join(output_folder, f"model_checkpoint_epoch_{epoch}.pt"))
        for data, target in tqdm(train_loader):
            data, target = data.to(device), target.to(device)
            model.eval()

            opt_circuit.zero_grad()

            # Forward pass on the batch: ``output`` is required by the maintain
            # loss, and it refreshes activations_dict for this batch.
            output = model(data)
            batch_activations_with_channel = None
            if attack_name not in ("top_to_zero", "top_to_zero_art"):
                # Attack objectives that consume the batch activations.
                batch_activations_with_channel = activations_dict[FEATURE_NAME]

            if feat_vis_only:
                attack_loss_bis = torch.tensor(0.0, device=device)
            else:
                attack_loss_bis = compute_loss_circuit_rank(model, batch_activations_with_channel, channels,
                                                            kernel_layers_channel, kernel_layers_circuit)
            maintain_loss = get_maintain_loss(data, target, output, original_model, activations_dict,
                                              og_activations_dict)
            objective_value2 = beta * attack_loss_bis + maintain_loss

            attack_loss = compute_attack_dict[attack_name](model=model, activations_dict=activations_dict,
                                                           attack_images=attack_images, device=device,
                                                           get_attack_activations=get_attack_activations,
                                                           channels=channels, list_masks=list_masks,
                                                           ref_images=ref_images,
                                                           batch_activations_with_channel=batch_activations_with_channel,
                                                           initial_artificial_images=initial_artificial_images,
                                                           original_model=original_model,
                                                           og_activations_dict=og_activations_dict,
                                                           joint_ref=joint_ref,
                                                           channel_sampling=channel_sampling,
                                                           alpha=alpha
                                                           )
            objective_value1 = alpha * attack_loss

            objective = objective_value1 + objective_value2

            # Log every save_interval steps and on the very last step.
            # TODO: See why it doesn't seem to save at the last step of the last epoch
            if step % save_interval == 0 or step == total_steps - 1:

                with torch.no_grad():
                    # Fresh activations from both models on the same batch.
                    model(data)
                    original_model(data)
                    step_activation_norm = get_attack_activations(activations_dict[FEATURE_NAME])
                    og_step_activation_norm = get_attack_activations(og_activations_dict[FEATURE_NAME])
                    step_norm = step_activation_norm.mean()
                    og_step_norm = og_step_activation_norm.mean()
                    norm_ratio = step_norm / og_step_norm

                # NOTE(review): ``step`` counts batches while ``nsteps`` counts
                # epochs, so ``step != nsteps`` compares different units.
                if DO_NESTED_ATTACK and attack_type == 'artificial' and step != nsteps and step != 0:
                    print('Generating new artificial images for the nested attack!')
                    new_artificial_images = generate_artificial_top_images(f'step_{step}', model, output_folder,
                                                                           save_image,
                                                                           vis_obj_list,
                                                                           image_side_length,
                                                                           do_fft=do_fft)
                    attack_images = torch.vstack((attack_images, new_artificial_images))
                    print(f'New attack images shape: {attack_images.shape}')

                activation_norm_ratios.append(norm_ratio.detach().cpu().numpy())
                print(f"(Step {step}) Ratio between current model norm and original model norm:",
                      activation_norm_ratios[-1])

                # Save objective values
                attack_objective_values.append(objective.detach().cpu().numpy())
                maintain_objective_values.append(maintain_loss.detach().cpu().numpy())
                print(f"(Step {step}) Attack objective:", attack_objective_values[-1])
                print(f"(Step {step}) Maintain objective:", maintain_objective_values[-1])
                print(f"(Step {step}) Objective1:", objective_value1.detach().cpu().numpy())
                print(f"(Step {step}) Objective2:", objective_value2.detach().cpu().numpy())
                print(f"(Step {step}) Push objective:", attack_loss.detach().cpu().numpy())
                print(f"(Step {step}) Ranking objective:", attack_loss_bis.detach().cpu().numpy())

                # Validation accuracy (the "callback" output).
                callback_outputs.append(validate_model(model, do_full_run=do_full_validation, folder=data_folder))
                print(f"(Step {step}) Accuracy:", callback_outputs[-1])

                # Alpha update: halve alpha while the accuracy drop vs. the
                # first log exceeds the threshold, otherwise double it (max 1).
                # Never updated on the first step.
                if step != 0 and do_alpha_update:
                    print('Updating Alpha!')
                    if (callback_outputs[0] - callback_outputs[-1]) > acc_loss_threshold:
                        alpha = alpha / 2
                    elif (callback_outputs[0] - callback_outputs[-1]) < acc_loss_threshold:
                        alpha = min(alpha * 2, 1)
                alphas.append(alpha)
                print(f"(Step {step}) alpha:", alphas[-1])

            # The no_grad logging above does not touch the graph behind ``objective``.
            objective.backward()
            opt_circuit.step()
            step = step + 1

    # TODO combine all top image indices into a dict(?)
    if do_final_top_image_search:
        print("Performing final top image search")
        final_top_indices = get_topk_image_indices_by_channel(model, activations_dict, nb_images_saved, channels,
                                                              get_attack_activations, data_folder, data_loader=None)
        results_dict['final_top_indices'] = final_top_indices

    if DO_NESTED_ATTACK and attack_type == 'artificial':
        final_artificial_images = generate_artificial_top_images('final', model, output_folder, save_image,
                                                                 vis_obj_list, image_side_length,
                                                                 do_fft=do_fft)
        attack_images = torch.vstack((attack_images, final_artificial_images))
        results_dict['artificial_images'] = attack_images
    elif generate_artificial_images:
        final_artificial_images = generate_artificial_top_images('final', model, output_folder, save_image,
                                                                 vis_obj_list, image_side_length, do_fft=do_fft)
        results_dict['final_artificial_images'] = final_artificial_images

    results_dict['attack_obj_vals'] = attack_objective_values
    results_dict['accuracy'] = callback_outputs
    results_dict['maintain_obj_vals'] = maintain_objective_values
    results_dict['activation_norms'] = activation_norm_ratios
    results_dict['alphas'] = alphas

    # ``path`` is os.path; ``os`` itself is not bound by ``import os.path as path``.
    torch.save(model.state_dict(), path.join(output_folder, "final_model.pt"))
    return results_dict


# def get_topk_filters_kernels(original_model, train_circuit_loader, layer, channels, layers = ["features.6", "features.8"]):
#     kernel_layers = {}
#     init_scores = actgrad_kernel_score(original_model, train_circuit_loader,layer, ch)
#     for layer_t in layers:
#         for ch in channels:
#             values, indices = torch.topk(init_scores[layer_t].mean(axis = -1), TOPK)
#             kernel_layers[layer_t] = indices
#     return kernel_layers

def get_preceding_conv_layers_names(network, target_layer_name):
    """List the Conv2d modules reached before ``target_layer_name``.

    The walk is depth-first and pre-order over ``named_children()``, building
    dotted names such as ``"features.0"``. It stops as soon as
    ``target_layer_name`` is reached. Neither the target itself nor anything
    inside it is included.

    Args:
        network: the torch.nn.Module to walk.
        target_layer_name: dotted name of the layer at which to stop.

    Returns:
        list of dotted names of the Conv2d modules visited before the target,
        in visiting order.

    Raises:
        ValueError: if no module named ``target_layer_name`` is encountered.
    """
    conv_names = []

    def _walk(parent, prefix):
        # Returns True once the target has been reached, which stops every
        # enclosing loop as well.
        for child_name, child in parent.named_children():
            qualified = child_name if not prefix else f"{prefix}.{child_name}"
            if qualified == target_layer_name:
                return True
            if isinstance(child, torch.nn.Conv2d):
                conv_names.append(qualified)
            if _walk(child, qualified):
                return True
        return False

    if not _walk(network, ""):
        raise ValueError(f"Target layer '{target_layer_name}' not found in the network.")

    return conv_names


def get_topk_filters_kernels(original_model, train_circuit_loader, layer, channels, model):
    """Select top-scoring and randomly sampled other channels of each conv layer before ``layer``.

    Channels are scored with
    ``snip_score(original_model, train_circuit_loader, layer, channels[0])``;
    only the first target channel is used. For every Conv2d layer that precedes
    ``layer`` (see ``get_preceding_conv_layers_names``), the per-output-channel
    scores are averaged over all remaining dimensions and sorted ascending.

    Args:
        original_model: model that is scored and searched for preceding conv layers.
        train_circuit_loader: data loader passed to ``snip_score``.
        layer: target layer name; underscores are converted to dots.
        channels: target channels; only ``channels[0]`` is used.
        model: unused; kept for interface compatibility.

    Returns:
        ``(kernel_layers_channel, kernel_layers_circuit)``, two dicts keyed by
        conv layer name:

        - ``kernel_layers_channel``: indices of the ``min(TOPK, n_channels)``
          highest-scoring channels.
        - ``kernel_layers_circuit``: up to ``TOPK`` indices drawn at random
          from the remaining, non-top channels.
    """
    ch = channels[0]

    layer = layer.replace("_", ".")
    previous_layers = get_preceding_conv_layers_names(original_model, layer)

    kernel_layers_circuit = {}
    kernel_layers_channel = {}

    init_scores = snip_score(original_model, train_circuit_loader, layer, ch)
    print("Previous layers", previous_layers)
    for layer_t in previous_layers:
        print(len(init_scores), list(init_scores.keys()))
        # Ascending sort: the highest-scoring channels are at the end of ``indices``.
        _, indices = init_scores[layer_t].flatten(start_dim=1).mean(axis=-1).sort()
        n_channels = indices.shape[0]
        # Clamp so layers with fewer than TOPK channels do not hit
        # torch.randperm with a negative size.
        k = min(TOPK, n_channels)

        # Random sample of channels outside the top-k. Reseeding before every
        # layer makes the sample reproducible. Note that torch.manual_seed also
        # resets torch's global RNG state as a side effect.
        torch.manual_seed(SEED)
        rand_perm = torch.randperm(n_channels - k)[:TOPK]
        kernel_layers_circuit[layer_t] = indices[rand_perm]

        # The top-k channels.
        kernel_layers_channel[layer_t] = indices[n_channels - k:]

    return kernel_layers_channel, kernel_layers_circuit

def get_topk_filters_kernels_and_params(original_model, train_circuit_loader, layer, channels, model, layers = ("features.0", "features.3", "features.6", "features.8")):
    """Select top/random channels for explicit ``layers`` and also return their parameters.

    Channels are scored with
    ``snip_score(original_model, train_circuit_loader, layer, channels[0])``.
    For each name in ``layers`` (underscores converted to dots), the
    per-output-channel scores are averaged and sorted ascending.

    Args:
        original_model: model that is scored.
        train_circuit_loader: data loader passed to ``snip_score``.
        layer: layer name passed to ``snip_score`` unchanged.
        channels: target channels; only ``channels[0]`` is used.
        model: model whose parameters are extracted.
        layers: names of the conv layers to process. Their parameters are
            excluded from ``rest_params``.

    Returns:
        ``(kernel_layers_channel, kernel_layers_circuit, params_channel,
        params_circuit, rest_params)``:

        - ``kernel_layers_channel``: the ``min(TOPK, n_channels)`` top channel
          indices per layer.
        - ``kernel_layers_circuit``: up to ``TOPK`` random non-top channel
          indices per layer.
        - ``params_channel`` / ``params_circuit``: the matching
          ``(weights, biases)`` from ``get_layer_params_and_channels``.
        - ``rest_params``: every parameter of ``model`` outside ``layers``.
    """
    kernel_layers_circuit = {}
    kernel_layers_channel = {}

    rest_params = get_learnable_params(model, except_layers=layers)

    params_circuit = {}
    params_channel = {}

    ch = channels[0]
    init_scores = snip_score(original_model, train_circuit_loader, layer, ch)

    for layer_t in layers:
        layer_t = layer_t.replace("_", ".")
        # Ascending sort: the highest-scoring channels are at the end of ``indices``.
        _, indices = init_scores[layer_t].flatten(start_dim=1).mean(axis=-1).sort()
        n_channels = indices.shape[0]
        # Clamp so layers with fewer than TOPK channels do not hit
        # torch.randperm with a negative size.
        k = min(TOPK, n_channels)

        # No seed is set here (unlike get_topk_filters_kernels), so this sample
        # depends on torch's current global RNG state.
        rand_perm = torch.randperm(n_channels - k)[:TOPK]
        kernel_layers_circuit[layer_t] = indices[rand_perm]

        kernel_layers_channel[layer_t] = indices[n_channels - k:]

        params_circuit[layer_t] = get_layer_params_and_channels(model, layer_t, kernel_layers_circuit[layer_t])
        params_channel[layer_t] = get_layer_params_and_channels(model, layer_t, kernel_layers_channel[layer_t])

    return kernel_layers_channel, kernel_layers_circuit, params_channel, params_circuit, rest_params

#Only for AlexNet currently
def get_learnable_params(model, except_layers):
    """Collect the parameters of every module whose name is not in ``except_layers``.

    A parameter's owning module name is everything before the last dot of the
    parameter name (``"features.0.weight"`` -> ``"features.0"``).

    Args:
        model: a torch.nn.Module.
        except_layers: container of module names whose parameters are skipped.

    Returns:
        dict mapping parameter name to the model's own Parameter object (not a
        copy), in ``named_parameters()`` order.
    """
    kept = {
        full_name: tensor
        for full_name, tensor in model.named_parameters()
        if full_name.rpartition(".")[0] not in except_layers
    }
    print("List of learnable rest of params, ", list(kept))
    return kept



#chatGPT answer
def get_layer_params_and_channels(model, layer_name, channel_indices):
    """Return the weights and biases of selected output channels of a Conv2d layer.

    Args:
        model: the torch.nn.Module to search.
        layer_name: dotted module name as reported by ``model.named_modules()``.
        channel_indices: index into the layer's output-channel dimension
            (dim 0 of the weight), e.g. an int, list or LongTensor.

    Returns:
        tuple ``(selected_weights, selected_biases)``, taken from
        ``layer.weight.data`` and ``layer.bias.data``.

        - ``selected_biases`` is ``None`` when the layer was built without a
          bias.
        - Both tensors are detached from autograd.
        - With list or tensor indices, PyTorch advanced indexing returns
          copies, so modifying them does not change the layer.

    Raises:
        ValueError: if no module is named ``layer_name``, or if that module is
            not a Conv2d.
    """
    # Look the name up directly so "exists but is not a Conv2d" is reported as
    # such, instead of falling through to "not found".
    layer = dict(model.named_modules()).get(layer_name)

    if layer is None:
        raise ValueError(f"Layer with name '{layer_name}' not found in the model.")

    if not isinstance(layer, torch.nn.Conv2d):
        raise ValueError("Specified layer must be a Conv2d layer.")

    selected_weights = layer.weight.data[channel_indices, :, :, :]
    # Conv2d(bias=False) stores ``bias`` as None.
    selected_biases = None if layer.bias is None else layer.bias.data[channel_indices]

    return (selected_weights, selected_biases)


def compute_top_to_zero(model, activations_dict, attack_images, device, get_attack_activations, alpha, channels, **kwargs):
    """Average attack activation of the target channels on their attack images.

    Steps:

    1. Iterate over the first dimension of ``attack_images``.
    2. Run ``model`` on each element; this refreshes
       ``activations_dict[FEATURE_NAME]``.
    3. Reduce the activations with ``get_attack_activations`` and take the
       diagonal.
    4. Index the diagonal with ``channels`` and average.

    The per-element means are then averaged again over the first dimension.

    NOTE(review): ``diag()[channels]`` pairs image i with output i and then
    selects diagonal positions by the values in ``channels``. This presumably
    assumes each element holds one image per target channel, in the same order
    as the channel ids. Confirm against get_initial_top_images_by_channel.

    Args:
        model: model under attack.
        activations_dict: dict the forward hooks write into.
        attack_images: tensor iterated over its first dimension.
        device: device each element is moved to before the forward pass.
        get_attack_activations: reduction applied to the hooked activations.
        alpha: unused here.
        channels: indices selected from the diagonal.

    Returns:
        scalar tensor; an empty ``attack_images`` raises ZeroDivisionError.
    """
    def _element_loss(images):
        model(images.to(device))
        reduced = get_attack_activations(activations_dict[FEATURE_NAME])
        return reduced.diag()[channels].mean()

    # sum() starts from 0 and adds each loss in iteration order, consuming the
    # generator lazily so each forward pass happens right before its read-out.
    total_loss = sum((_element_loss(images) for images in attack_images), 0)
    return total_loss / attack_images.shape[0]

#This function helps in splitting batches
def split_range(start, end, sub_range_size):
    """Split ``range(start, end)`` into consecutive chunks of ``sub_range_size`` integers.

    Args:
        start: first integer (inclusive).
        end: last integer (exclusive).
        sub_range_size: chunk length; the final chunk may be shorter.

    Returns:
        list of lists of ints, e.g. ``split_range(0, 5, 2) -> [[0, 1], [2, 3], [4]]``.
        Empty when ``start >= end``.

    Raises:
        ValueError: if ``sub_range_size`` is 0 (raised by ``range``).
    """
    return [list(range(lower, min(lower + sub_range_size, end)))
            for lower in range(start, end, sub_range_size)]


def compute_attack_ref_to_tops_art(model, activations_dict, device, get_attack_activations, channels, initial_artificial_images, original_model,
                                                           og_activations_dict, ref_images, with_ball = True, joint_ref = False, **kwargs):
    """Loss that decreases as the reference image(s) activate ``channels`` more strongly.

    The loss is ``mean(log(1 + MAX_VALUE / (1e-12 + a)))`` over the attack
    activations ``a`` of ``channels``, computed by ``get_attack_activations``
    from ``activations_dict[FEATURE_NAME]``.

    With ``with_ball``, the reference input is first perturbed. It takes one
    step of size ``SMALL_MARGIN / 10`` along the normalised input-gradient of
    the same loss. That step increases the loss, i.e. lowers positive
    activations. The final loss is then evaluated at the perturbed point.

    Args:
        model: model under attack; run on the reference input(s).
        activations_dict: dict the forward hooks write into.
        device: device the reference input is moved to.
        get_attack_activations: reduction applied to the hooked activations.
        channels: channel indices whose activations enter the loss.
        ref_images: reference images. Only ``ref_images[0]`` is used unless
            ``joint_ref``, in which case the whole batch is used.
        with_ball: whether to apply the gradient perturbation step.
        joint_ref: attack all reference images jointly.
        initial_artificial_images, original_model, og_activations_dict, kwargs:
            unused here; accepted so all attacks share one call signature.

    Returns:
        scalar tensor differentiable w.r.t. the model parameters.
    """
    def _inverse_log_loss(acts):
        # Decreases as ``acts`` grows; the 1e-12 offset avoids dividing by
        # exactly zero when an activation is 0.
        return (1 + MAX_VALUE / (1e-12 + acts)).log().mean()

    # Detached copy of the reference input(s).
    if joint_ref:
        x = ref_images.to(device).detach()
    else:
        x = ref_images[0].unsqueeze(0).to(device).detach()

    if with_ball:
        x.requires_grad_()
        with torch.enable_grad():
            model(x)
            ball_loss = _inverse_log_loss(get_attack_activations(activations_dict[FEATURE_NAME])[:, channels])

        grad_x = torch.autograd.grad(ball_loss, [x])[0]

        # Move x along the gradient, i.e. towards a larger loss.
        # NOTE(review): F.normalize defaults to dim=1 (the channel axis of an
        # NCHW tensor), so this normalises per pixel across colour channels
        # rather than over the whole image — confirm that is intended.
        x = x.detach() + (SMALL_MARGIN / 10) * torch.nn.functional.normalize(grad_x.detach())

        model.zero_grad()

    model(x)
    activations_values = get_attack_activations(activations_dict[FEATURE_NAME])

    return _inverse_log_loss(activations_values[:, channels])

def compute_top_to_zero_art(model, activations_dict, device, get_attack_activations, channels, initial_artificial_images, original_model,
                                                           og_activations_dict, **kwargs):
    """Augmentation-averaged loss on the artificial top images of the attacked channels.

    NOTE(review): this module defines ``compute_top_to_zero_art`` a second time further down
    with an identical body. The later definition rebinds the name at import time, so this copy
    is never called. Consider deleting one of the two.

    For each augmentation in ``augmentations`` the (squeezed, augmented, normalized) artificial
    images are run through ``model``. With ``with_similarity = False`` (hard-coded) the term is
    the mean over ``channels`` of the diagonal of the activation-norm matrix returned by
    ``get_attack_activations``, i.e. image i's norm for channel i. Presumably image i was
    rendered for channel i; TODO confirm with the caller. The result is averaged over the
    augmentations and moved to ``device``.
    """
    with_similarity = False

    augmentations = [lambda x: torch.normal(mean = x, std = 1e-1),
                     transforms.RandomRotation(degrees=(10,170)),
                     transforms.ElasticTransform(alpha=250.0),
                     transforms.RandomInvert(),
                     #transforms.RandomSolarize(threshold=-15),
                     #transforms.AugMix(),
                     transforms.RandomHorizontalFlip(),
                     transforms.RandomVerticalFlip(),
                     transforms.ColorJitter(brightness=.5, hue=.3),
                     transforms.GaussianBlur(kernel_size=(3,3), sigma=(.1,2.))]
    
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225])
    
    # soft_plus is never used below.
    soft_plus = torch.nn.Softplus(beta=10)
    sim = torch.nn.CosineSimilarity(dim = - 1)
    
    attack_loss = 0
    #final_attack_images = []
    for augment in augmentations: #todo paralellize
        #final_attack_images.append()
        #print(initial_artificial_images.shape)
        #print(augment(initial_artificial_images.squeeze()).shape)
        if with_similarity:
            # Absolute cosine similarity between the attacked and the original model's features.
            synth_images = normalize(augment(initial_artificial_images.squeeze()))
            model(synth_images.to(_default_device))
            synth_features = activations_dict[FEATURE_NAME].flatten(start_dim = 2)
            original_model(synth_images.to(_default_device))
            synth_features_init = og_activations_dict[FEATURE_NAME].flatten(start_dim=2)
            
            attack_loss = attack_loss + torch.abs(sim(synth_features, synth_features_init)).mean()

        else:
            model(normalize(augment(initial_artificial_images.squeeze()).to(_default_device)))
            atk_acts_norms = get_attack_activations(activations_dict[FEATURE_NAME])
            # Diagonal = image i's activation norm for channel i.
            atk_acts_diagonal = atk_acts_norms.diag()
            attack_loss = attack_loss + atk_acts_diagonal[channels].mean()

    attack_loss = attack_loss / len(augmentations)
    attack_loss = attack_loss.to(device)

    return attack_loss


def _circuit_kernel_scores(model, batch_activations_with_channel, channel, key_layer,
                           kernel_layers_channel, kernel_layers_circuit):
    """Return ``(circuit_scores, channel_scores)``: per-kernel ``|dA/dW * W|`` for one layer.

    ``A`` is the summed absolute activation of ``channel`` over the whole batch, and ``W`` is the
    ``weight`` of the submodule reached by the dotted path ``key_layer`` (e.g. ``"features.10"``).
    ``kernel_layers_circuit[key_layer]`` and ``kernel_layers_channel[key_layer]`` index dim 0 of
    that weight. Each selected kernel's attribution is averaged over its remaining dims.
    """
    # Resolve a dotted submodule path by repeated getattr.
    submodule = functools.reduce(getattr, key_layer.split("."), model)
    # Same value as the former ``.squeeze().flatten(start_dim=1).abs().sum()`` (a sum over every
    # element). It does not raise when squeezing leaves fewer than two dims, e.g. batch size 1
    # with a 1x1 feature map.
    objective = batch_activations_with_channel[:, channel].abs().sum()
    # create_graph keeps the attribution differentiable for training. retain_graph is needed
    # because the same activation graph is differentiated again for the next layer.
    grad_weight = torch.autograd.grad(objective, submodule.weight,
                                      create_graph=True, retain_graph=True)[0]
    attribution = grad_weight * submodule.weight
    circuit_scores = attribution[kernel_layers_circuit[key_layer]].abs().flatten(start_dim=1).mean(dim=-1)
    channel_scores = attribution[kernel_layers_channel[key_layer]].abs().flatten(start_dim=1).mean(dim=-1)
    return circuit_scores, channel_scores


def compute_loss_circuit_rank_scale_invariant(model, batch_activations_with_channel, channels, 
                                              kernel_layers_channel, kernel_layers_circuit, input_shape = 224,**kwargs):
    """Scale-invariant ranking loss between circuit kernels and channel kernels.

    Only ``channels[0]`` is used. For every layer key in ``kernel_layers_circuit``, both score
    vectors from ``_circuit_kernel_scores`` are divided by their joint maximum (floored at 1e-8).
    Each (circuit kernel, channel kernel) pair then contributes
    ``relu(channel_score - circuit_score)``. The per-layer means are summed.
    ``input_shape`` and ``kwargs`` are unused.
    """
    scale_invariant_loss = 0
    channel = channels[0]

    for key_layer in kernel_layers_circuit:
        circuit_scores, channel_scores = _circuit_kernel_scores(
            model, batch_activations_with_channel, channel, key_layer,
            kernel_layers_channel, kernel_layers_circuit)

        normfactor = torch.max(torch.max(circuit_scores), torch.max(channel_scores))
        # Guard against division by zero when every attribution vanishes.
        normfactor = torch.max(normfactor, torch.tensor(1e-8).to(normfactor.device))
        normalized_circuit = circuit_scores / normfactor
        normalized_channel = channel_scores / normfactor

        scale_invariant_loss += torch.relu(-normalized_circuit[:, None] + normalized_channel[None, :]).mean()

    return scale_invariant_loss


def  compute_loss_circuit_rank(model, batch_activations_with_channel, channels, kernel_layers_channel, kernel_layers_circuit, input_shape = 224,**kwargs):
    """Ranking loss pushing circuit-kernel attributions above channel-kernel attributions.

    Only ``channels[0]`` is used. For every layer key in ``kernel_layers_circuit``, each
    (circuit kernel, channel kernel) pair contributes ``relu(channel_score - circuit_score)``,
    using the unnormalized scores from ``_circuit_kernel_scores``. The per-layer means are
    summed. ``input_shape`` and ``kwargs`` are unused.
    """
    attack_loss = 0
    channel = channels[0]

    for key_layer in kernel_layers_circuit:
        circuit_scores, channel_scores = _circuit_kernel_scores(
            model, batch_activations_with_channel, channel, key_layer,
            kernel_layers_channel, kernel_layers_circuit)
        attack_loss += torch.relu(-circuit_scores[:, None] + channel_scores[None, :]).mean()
    return attack_loss


#todo modify this to handle several types of models
# def compute_loss_circuit(model, batch_activations_with_channel, channel, layer_kernels, input_shape = 224,**kwargs):

#     attack_loss = 0
    
#     #print(layer_kernels)
#     #print(batch_activations_with_channel[channel].squeeze().shape)
#     #print(ind_layer, channel, layer_kernels)


#     for key_layer in layer_kernels: #todo remove this loop
#         ind_layer = int(key_layer.split(".")[-1])
#         grad_features = torch.autograd.grad(batch_activations_with_channel[:,channel].squeeze().flatten(start_dim=1).abs().sum(),
#                                             model.features[ind_layer].weight, 
#                                             create_graph=True, retain_graph=True)[0]
    
#         #print(grad_features.shape, model.features[ind_layer].weight[layer_kernels[key_layer]].shape)

#         #print("tootoooooo", grad_features.shape, (model.features[ind_layer].weight).shape, layer_kernels[key_layer])
#         #print(key_layer, input_shape*(grad_features*model.features[ind_layer].weight)[layer_kernels[key_layer]].abs().mean())
#         attack_loss += (1 + 1e5/(1e-12 + input_shape*(grad_features*model.features[ind_layer].weight)[layer_kernels[key_layer]].abs())).log().mean()
#         #attack_loss += input_shape*(grad_features*model.features[ind_layer].weight)[layer_kernels[key_layer]].abs().mean()

#     #print(attack_loss)

#    return attack_loss


def compute_top_to_zero_art(model, activations_dict, device, get_attack_activations, channels, initial_artificial_images, original_model,
                                                           og_activations_dict, **kwargs):
    """Augmentation-averaged loss on the artificial top images of the attacked channels.

    NOTE: this module defines ``compute_top_to_zero_art`` twice. This later definition is the
    one bound at import time.

    For each augmentation in a fixed list, the augmented and normalized artificial images are
    run through ``model``. With the hard-coded ``with_similarity = False``, the term is the mean
    over ``channels`` of the diagonal of the ``(num_images, num_channels)`` norm matrix from
    ``get_attack_activations``, i.e. image i's norm for channel i. This assumes image i was
    rendered for channel i; TODO confirm with the caller. The ``with_similarity`` branch instead
    uses the absolute cosine similarity between ``model`` and ``original_model`` features. The
    result is averaged over the augmentations.

    Args:
        initial_artificial_images: tensor whose last three dims are (C, H, W). All leading dims
            are collapsed into one batch dim, e.g. ``[1, N, 3, H, W] -> [N, 3, H, W]``.
        device: device the returned scalar is moved to.

    Returns:
        Scalar loss tensor on ``device``.
    """
    with_similarity = False

    augmentations = [lambda x: torch.normal(mean = x, std = 1e-1),
                     transforms.RandomRotation(degrees=(10,170)),
                     transforms.ElasticTransform(alpha=250.0),
                     transforms.RandomInvert(),
                     #transforms.RandomSolarize(threshold=-15),
                     #transforms.AugMix(),
                     transforms.RandomHorizontalFlip(),
                     transforms.RandomVerticalFlip(),
                     transforms.ColorJitter(brightness=.5, hue=.3),
                     transforms.GaussianBlur(kernel_size=(3,3), sigma=(.1,2.))]

    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225])

    sim = torch.nn.CosineSimilarity(dim=-1)

    # Collapse all leading dims into one batch dim. A bare .squeeze() also drops the batch dim
    # when only one image is given, feeding the model an unbatched [3, H, W] tensor.
    images = initial_artificial_images.reshape(-1, *initial_artificial_images.shape[-3:])

    attack_loss = 0
    for augment in augmentations:  # TODO parallelize
        if with_similarity:
            synth_images = normalize(augment(images))
            model(synth_images.to(_default_device))
            synth_features = activations_dict[FEATURE_NAME].flatten(start_dim=2)
            original_model(synth_images.to(_default_device))
            synth_features_init = og_activations_dict[FEATURE_NAME].flatten(start_dim=2)
            attack_loss = attack_loss + torch.abs(sim(synth_features, synth_features_init)).mean()
        else:
            model(normalize(augment(images).to(_default_device)))
            atk_acts_norms = get_attack_activations(activations_dict[FEATURE_NAME])
            # Diagonal = image i's activation norm for channel i.
            attack_loss = attack_loss + atk_acts_norms.diag()[channels].mean()

    attack_loss = attack_loss / len(augmentations)
    return attack_loss.to(device)


def compute_top_to_zero_inference(original_model, model, og_activations_dict, activations_dict, data,
                                  get_attack_activations, **kwargs):
    """Return mean attack activation of ``model`` divided by that of ``original_model`` on ``data``.

    Both models are run without gradients. Their forward hooks fill ``activations_dict`` and
    ``og_activations_dict`` respectively, and ``get_attack_activations`` reduces each hooked
    output before averaging.
    """
    with torch.no_grad():
        model(data)
        original_model(data)
        attacked_mean = get_attack_activations(activations_dict[FEATURE_NAME]).mean()
        reference_mean = get_attack_activations(og_activations_dict[FEATURE_NAME]).mean()
        return attacked_mean / reference_mean


def compute_attack_top_to_bottom(model, activations_dict,
                                 attack_images, device, get_attack_activations,
                                 no_walk_a_mol=False, **kwargs):
    """Hinge loss penalizing attack images that activate more strongly than the current batch.

    The current batch's norms are read from ``activations_dict`` *before* any new forward pass,
    i.e. from whatever forward pass the caller last ran. Then, for each index ``j`` along dim 1
    of ``attack_images``, ``attack_images[:, j]`` is run through ``model``. The term is:

    * ``no_walk_a_mol=True``: ``relu(attack_norms - batch_norms)`` over all channels,
      broadcast as (batch, num_attack, channels);
    * otherwise: ``relu(attack_norms[:, j] - batch_norms[:, j])`` for channel ``j`` only,
      broadcast as (batch, num_attack).

    The per-index means are summed. Returns ``0`` (int) when ``attack_images.shape[1] == 0``.
    """
    batch_norms = get_attack_activations(activations_dict[FEATURE_NAME])

    total = 0
    for idx in tqdm(range(attack_images.shape[1])):  # TODO vectorize this loop
        model(attack_images[:, idx, ...].to(device))
        attack_norms = get_attack_activations(activations_dict[FEATURE_NAME])

        if no_walk_a_mol:
            hinge = torch.relu(attack_norms - batch_norms[:, None, :])
        else:
            hinge = torch.relu(attack_norms[:, idx][None, ...] - batch_norms[:, idx][..., None])

        total = total + torch.mean(hinge)

    return total


def compute_attack_ref_to_tops(model, activations_dict,
                               device, get_attack_activations,
                               ref_images, channels,
                               default_neuron=0, **kwargs):
    """Hinge loss pushing reference-image activations above batch activations by SMALL_MARGIN.

    For each attacked channel c, each reference image r and each batch image b, the term is
    ``relu(batch[b, c] - ref[r, c] + SMALL_MARGIN)``. The result is the mean over all three.
    ``default_neuron`` is unused.
    """
    # Read the batch norms first: the reference forward pass below overwrites the hooked
    # activations_dict[FEATURE_NAME].
    batch_norms = get_attack_activations(activations_dict[FEATURE_NAME])
    model(ref_images.to(device))
    ref_norms = get_attack_activations(activations_dict[FEATURE_NAME])

    # Transpose to (channel, image) and keep only the attacked channels.
    batch_by_channel = batch_norms.T[channels]
    ref_by_channel = ref_norms.T[channels]

    # Broadcast to (channel, ref_image, batch_image).
    margin_violation = batch_by_channel[:, None, :] - ref_by_channel[:, :, None] + SMALL_MARGIN
    return torch.mean(torch.relu(margin_violation))


def compute_attack_top_to_bottom_pos(model, activations_dict,
                                     attack_images, device, list_masks,
                                     **kwargs):
    """Position-masked hinge loss between batch activations and attack-image activations.

    For each index ``j`` along dim 1 of ``attack_images``:
    - the masked batch activations of channel ``j`` (``list_masks[j]``) are compared with every
      activation of channel ``j`` produced by ``attack_images[:, j]``;
    - each pair contributes ``relu(batch - attack)``.

    The per-index means are summed. Returns ``0`` (int) when ``attack_images.shape[1] == 0``.
    """
    # Keep a reference to the batch activations captured before this call. The hook rebinds the
    # dict entry to a fresh tensor on every forward pass, so this reference stays valid.
    batch_acts = activations_dict[FEATURE_NAME]

    total = 0
    for idx in tqdm(range(attack_images.shape[1])):  # TODO vectorize this loop
        model(attack_images[:, idx, ...].to(device))
        attack_acts = activations_dict[FEATURE_NAME]

        mask = list_masks[idx].to(_default_device)
        selected_batch = batch_acts[:, idx][mask]
        flat_attack = attack_acts[:, idx].flatten()
        total = total + torch.mean(torch.relu(selected_batch[:, None] - flat_attack[None]))

    return total


# generates and returns a tensor filled with the artificial top images for the designated channels
def generate_artificial_top_images(title, model, output_folder, save_image, vis_obj_list, image_side_length,
                                   do_fft=True, batch = None, nsteps = 512):
    """Render one Lucent feature visualization per objective in ``vis_obj_list``.

    Rendering happens on a deep copy of ``model`` in eval mode. When ``save_image`` is set, each
    image is written to ``<output_folder>/optimal_images/channel_<i>_<title>.png``.

    Args:
        vis_obj_list: Lucent objectives to render; ``None`` skips generation.
        image_side_length: side length passed to ``param.image``.
        do_fft: use Lucent's FFT image parameterization.
        batch: if truthy, forwarded as ``batch=`` to ``param.image``.
        nsteps: optimization steps per image (passed as ``thresholds=(nsteps,)``).

    Returns:
        ``None`` when ``vis_obj_list`` is ``None``. Otherwise a tensor of shape
        ``[1, num_images, 3, H, W]``: every render's images stacked on dim 0, plus a leading dim.
    """
    if vis_obj_list is None:
        return None

    print('Generating Artificial Optimal Images')
    print(f'save_image: {save_image}')
    if batch:
        param_f = lambda: param.image(image_side_length, batch = batch, fft=do_fft)
    else:
        param_f = lambda: param.image(image_side_length, fft=do_fft)

    # Render on a deep copy so Lucent's gradient bookkeeping never touches the caller's model.
    model_copy = copy.deepcopy(model)
    model_copy.eval()
    # The original model is not used for rendering. Reset it once; both calls are idempotent.
    model.zero_grad()
    model.eval()

    images = []
    for channel_num, vis in enumerate(tqdm(vis_obj_list)):
        # ``path`` is os.path (``import os.path as path``). Using it avoids relying on ``os``
        # leaking into this module through ``from utils import *``.
        image_name = path.join(output_folder, "optimal_images", f"channel_{channel_num}_{title}.png")
        image = render.render_vis(model_copy, vis, param_f, save_image=save_image,
                                  show_image=False,
                                  image_name=image_name,
                                  progress=False,
                                  thresholds=(nsteps,))
        # render_vis yields channels-last arrays; permute to channels-first so they vstack.
        images.append(torch.from_numpy(image[0]).permute([0, 3, 1, 2]))

    images = torch.vstack(images).unsqueeze(0)
    print(f'artificial image tensor shape: {images.shape}')
    return images


def get_initial_top_images_by_channel(activations_dict,
                                      model,
                                      channels,
                                      filename_top_bottom_prefix,
                                      get_attack_activations,
                                      nb_images_per_channel,
                                      imagenet_folder):
    """Load (or compute and cache) the top dataset images for each channel.

    The top-image indices are cached at ``<filename_top_bottom_prefix>_image_indices_top.pt``.
    If that file is missing, the indices are computed with ``get_topk_image_indices_by_channel``
    and saved there. The images are then fetched with ``get_images_from_indices``.

    Returns:
        ``(init_top_images, init_top_indices)``.
    """
    print('Getting initial top images for each channel')

    cache_path = filename_top_bottom_prefix + "_image_indices_top.pt"
    print(f'Path for initial top Images: {cache_path}')

    if not path.exists(cache_path):
        print("Preexisting top images not found")
        loader = get_topk_dataset_loader(imagenet_dir=imagenet_folder)
        init_top_indices = get_topk_image_indices_by_channel(model, activations_dict, nb_images_per_channel,
                                                             channels, get_attack_activations,
                                                             imagenet_folder, loader)
        torch.save(init_top_indices, cache_path)
    else:
        print('loading precalculated top images')
        init_top_indices = torch.load(cache_path)

    print(f'init top image indices shape:{init_top_indices.shape}')
    init_top_images = get_images_from_indices(init_top_indices, nb_images_per_channel, imagenet_folder)
    print('shape of initial top images:\n', init_top_images.shape)
    return init_top_images, init_top_indices


def get_topk_image_indices_by_channel(model, activations_dict, k, channels, get_attack_activations, imagenet_folder,
                                      data_loader=None):
    """Return, per channel, the dataset indices of the ``k`` images with the largest attack score.

    ``model`` is run over ``data_loader`` (built from ``imagenet_folder`` when omitted) with
    gradients disabled. For each batch:
    - the forward hook fills ``activations_dict[FEATURE_NAME]``;
    - ``get_attack_activations`` reduces it to a ``(batch, channels)`` score matrix;
    - a running top-k along the batch dimension is updated.

    Args:
        k: number of images to keep per channel; ``0`` returns ``None`` immediately.
        channels: unused (every channel produced by ``get_attack_activations`` is ranked).
        data_loader: must yield ``(data, label, dataset_index)`` triples.

    Returns:
        ``(k, num_channels)`` tensor of dataset indices, highest score first. It has fewer than
        ``k`` rows only if the loader holds fewer than ``k`` images in total.

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    # If no images are requested, return without doing anything else
    if k == 0:
        return
    with torch.no_grad():
        if data_loader is None:
            print('getting the dataloader inside the get_topk function')
            data_loader = get_topk_dataset_loader(imagenet_dir=imagenet_folder)
        model.eval()

        top_norms = None
        top_dataset_indices = None

        print('Beginning top image search!')
        for data, _, batch_indices in tqdm(data_loader):
            data = data.to(_default_device)
            batch_indices = batch_indices.to(_default_device)

            # Populate the activations dict, then reduce to a (batch, channels) score tensor.
            model(data)
            activ_norms = get_attack_activations(activations_dict[FEATURE_NAME])

            # A batch (typically the last one) may contain fewer than k images.
            batch_top_norms, batch_top_rows = torch.topk(activ_norms, k=min(k, activ_norms.shape[0]), dim=0)
            # Map batch rows to dataset indices for every batch, including the first. Rows equal
            # dataset indices only for an unshuffled loader whose first batch starts at index 0.
            batch_top_dataset_indices = batch_indices[batch_top_rows]

            if top_norms is None:
                top_norms, top_dataset_indices = batch_top_norms, batch_top_dataset_indices
                continue

            # Merge the running top-k with this batch's top-k and keep the best entries.
            norms_stack = torch.cat((top_norms, batch_top_norms))
            indices_stack = torch.cat((top_dataset_indices, batch_top_dataset_indices))
            top_norms, top_rows = torch.topk(norms_stack, k=min(k, norms_stack.shape[0]), dim=0)
            top_dataset_indices = torch.gather(indices_stack, 0, top_rows)

        if top_norms is None:
            raise ValueError('data_loader yielded no batches; cannot compute top images')

        print('Top images found!')
        print('Top image activation norms:\n', top_norms.shape)
        print('Top indices shape :\n', top_dataset_indices.shape)
        return top_dataset_indices


# Helper function to get the top images from the indices
# TODO make this faster? May not be possible to vectorize :(
def get_images_from_indices(indices, num_top_images_per_channel, imagenet_folder=None, data_loader=None):
    """Fetch dataset images for a grid of indices.

    Args:
        indices: 2-D index tensor; row ``r`` holds one dataset index per channel.
        num_top_images_per_channel: number of leading rows of ``indices`` to fetch.
        imagenet_folder: used to build a loader via ``get_topk_dataset_loader`` when
            ``data_loader`` is None.
        data_loader: optional loader whose ``.dataset[i][0]`` is an image tensor.

    Returns:
        Tensor of shape ``(num_top_images_per_channel, indices.shape[1], *image_shape)``.
    """
    if data_loader is not None:
        dataset = data_loader.dataset
        print(indices.shape, "indices' shapes")
    else:
        dataset = get_topk_dataset_loader(imagenet_dir=imagenet_folder).dataset

    per_rank = [
        torch.stack([dataset[idx][0] for idx in indices[rank]])
        for rank in range(num_top_images_per_channel)
    ]
    top_image_tensor = torch.stack(per_rank)

    print('images from the indices look like:\n', top_image_tensor.shape)
    return top_image_tensor


# SHOULD BE 56.52% WITHOUT TRAINING.
def validate_model(model, do_full_run, num_batches=1, folder= '/data/imagenet_data'):
    """Compute top-1 accuracy of ``model`` on the loader from ``get_test_dataloader``.

    Args:
        model: classifier whose output has class scores along dim 1.
        do_full_run: if True, evaluate the whole loader; otherwise stop after ``num_batches``.
        num_batches: number of batches evaluated when ``do_full_run`` is False.
        folder: dataset root passed to ``get_test_dataloader`` as ``imagenet_dir``.

    Returns:
        Fraction of correctly classified samples, or ``0.0`` if no batch was evaluated.
    """
    print(f'Validating model! do_full_run={do_full_run}')
    test_set = get_test_dataloader(imagenet_dir=folder)
    model.eval()
    total = 0
    correct = 0
    with torch.no_grad():
        print('Performing validation!')
        for batch_id, (data, target) in enumerate(tqdm(test_set)):
            # ">=" so exactly num_batches batches are evaluated; ">" ran one extra batch.
            if not do_full_run and batch_id >= num_batches:
                print(f'Reached batch limit for validation: {num_batches}')
                break

            data, target = data.to(_default_device), target.to(_default_device)

            output = model(data).to(_default_device)
            predicted = output.argmax(dim=1)
            total += target.shape[0]
            correct += (predicted == target).sum().item()

    # Avoid ZeroDivisionError when the loader is empty or num_batches is 0.
    return correct / total if total else 0.0


def get_maintain_loss_function(maintain_objective):
    """Return a loss callable that keeps the trained model close to the original behaviour.

    Every returned callable accepts the keywords ``data, target, output, original_model,
    act_dict, og_act_dict`` and returns a scalar loss.

    * ``'softmax'``: cross-entropy of ``output`` against ``softmax(original_model(data))``.
    * ``'kl-div'``: ``KL(softmax(original_model(data)) || softmax(output))``, batch-mean.
    * ``'labels'`` (and any unknown value): cross-entropy of ``output`` against ``target``.
    * ``'activations'``: MSE between ``act_dict[FEATURE_NAME]`` and ``og_act_dict[FEATURE_NAME]``.
      ``original_model(data)`` is run to refresh ``og_act_dict``; this presumably relies on a
      forward hook on ``original_model`` writing into it (TODO confirm with caller).

    The reference model is always evaluated under ``torch.no_grad()``. It is only a target, so
    gradients (and graph memory) should not flow into it.
    """
    def _labels_loss(data=None, target=None, output=None, original_model=None, act_dict=None,
                     og_act_dict=None):
        return F.cross_entropy(output, target)

    if maintain_objective == 'softmax':
        print("Using Original Model Softmax Outputs as maintain objective")

        def get_maintain_loss(data=None, target=None, output=None, original_model=None, act_dict=None,
                              og_act_dict=None):
            with torch.no_grad():
                soft_targets = F.softmax(original_model(data), dim=1)
            return F.cross_entropy(output, soft_targets)

    elif maintain_objective == 'kl-div':
        print("Using KL_Div as maintain objective")

        def get_maintain_loss(data=None, target=None, output=None, original_model=None, act_dict=None,
                              og_act_dict=None):
            with torch.no_grad():
                reference = F.softmax(original_model(data), dim=1)
            # kl_div expects log-probabilities as input; "batchmean" matches the KL definition.
            return F.kl_div(F.log_softmax(output, dim=1), reference, reduction="batchmean")

    elif maintain_objective == 'labels':
        print(f"Using labels (actual targets) as maintain objective")
        get_maintain_loss = _labels_loss

    elif maintain_objective == 'activations':
        print("Using hooked layer activations as maintain objective")

        def get_maintain_loss(data=None, target=None, output=None, original_model=None, act_dict=None,
                              og_act_dict=None):
            with torch.no_grad():
                original_model(data)
            return (act_dict[FEATURE_NAME] - og_act_dict[FEATURE_NAME]).square().mean()

    else:
        print(f"maintain objective \'{maintain_objective}\' is not defined. Defaulting to  CE Loss against target")
        get_maintain_loss = _labels_loss

    return get_maintain_loss


# Calculates the maintenance objective loss. Chooses which type based on input.
# def get_maintain_loss(data, maintain_objective, original_model, output, target):
#     if maintain_objective == 'softmax':
#         # print('Maintaining Softmax of outputs')
#         maintain_loss = F.cross_entropy(output, F.softmax(original_model(data), dim=1).detach())
#
#         # Purely for testing
#         # maintain_loss = torch.norm((output) - (original_model(data)))
#
#     # loss = F.binary_cross_entropy_with_logits(output, target.float())
#     else:
#         # print('Maintaining Accuracy')
#         maintain_loss = F.cross_entropy_with_logits(output, target)
#     return maintain_loss


# Returns a function that takes a batch of layer activations
# Calculates a norm for each image for each channel and returns a (n, tot_channels) shaped tensor
# Exact norm calculation is based on the attack objective (channels/center_neuron currently)
def get_attack_activations_function(attack_obj):
    """Return a function mapping layer activations to a ``(batch, channels)`` attack score.

    * ``'center_neuron'`` / ``'center-neuron'``: square of each channel's activation at the
      spatial centre ``[H // 2, W // 2]`` (expects a 4-D ``(n, c, h, w)`` input).
    * ``'channel'`` (and any unsupported value, with a warning):
      ``calc_norms_by_channel(relu(activations))``.

    Both spellings of the centre-neuron objective are accepted. The ``activation_optimization``
    comment in this module lists ``"center-neuron"``, while this dispatcher previously matched
    only ``"center_neuron"`` and silently fell back to the channel objective.
    """
    if attack_obj in ('center_neuron', 'center-neuron'):
        print('Using center neuron as attack objective')

        def attack_activation_func(feature_activations):
            shape = feature_activations.shape
            return feature_activations[:, :, shape[2] // 2, shape[3] // 2].square()

        return attack_activation_func

    if attack_obj == 'channel':
        print('Using channel as attack objective')
    else:
        print("ATTACK_OBJ NOT SUPPORTED. DEFAULTING TO CHANNEL")

    def attack_activation_func(feature_activations):
        return calc_norms_by_channel(torch.nn.functional.relu(feature_activations))

    return attack_activation_func

    # sum(torch.norm(activ, p=2) for activ in grab_activ(activations)) 
    # optim_loss = torch.zeros(1).to(_default_device)
    # for i in range(feature_activations.shape[0]):
    #     # Use only the channels we ask to use. There may be a better way to vectorize this*
    #     # * Maintaining functionality for [1,5,10] may be weird?
    #
    #     for channel in channels:
    #         #print('Channel: ', channel, type(channel))
    #         optim_loss = optim_loss + torch.linalg.matrix_norm(feature_activations[i, int(channel)])
    # # Scale down the optim loss to the AVERAGE loss of each channel and image
    # optim_loss = optim_loss / (feature_activations.shape[0]*len(channels)

    # Creates a dataloader with a training set.


def read_dataset(batch_size, num_workers=10,
                 imagenet_dir="/data/imagenet_data"):
    """Build the ImageNet training split and a shuffled, pinned-memory DataLoader over it.

    Training images get a random resized crop (224), a random horizontal flip, tensor
    conversion and mean/std normalization.

    Returns:
        ``(train_dataset, train_loader)``.
    """
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    train_dataset = datasets.ImageNet(imagenet_dir, split="train", transform=train_transform)

    loader_kwargs = dict(batch_size=batch_size, num_workers=num_workers, pin_memory=True, shuffle=True)
    train_loader = torch.utils.data.DataLoader(train_dataset, **loader_kwargs)
    return train_dataset, train_loader


def get_model_layers(model):
    """
    Map every leaf submodule of ``model`` to an underscore-joined name path.

    Containers are walked depth-first in registration order. ``None`` children are skipped, and
    only modules with no children of their own are recorded, e.g. ``"features_10"`` or
    ``"layer1_0_conv2"``.
    """
    assert hasattr(model, "_modules")
    leaves = {}

    def _walk(module, name_path):
        for child_name, child in module._modules.items():
            if child is None:
                continue
            child_path = name_path + (child_name,)
            if child._modules:
                _walk(child, child_path)
            else:
                leaves["_".join(child_path)] = child

    _walk(model, ())
    return leaves


# Helper function, takes a dict of the model layers, an optim_layer name to select which one,
# a features dict where the results are stored with key feature_name
# TODO Update this to generate name based on the name of the layer
def register_hooks(model_layers, features, optim_layer):
    """Hook ``model_layers[optim_layer[0]]`` so each forward pass stores its output in ``features``.

    After every forward pass, ``features[FEATURE_NAME]`` is rebound to a clone of the layer
    output. The clone is not detached, so losses built from it stay differentiable w.r.t. the
    model.

    Returns:
        The handle from ``register_forward_hook``; call ``.remove()`` on it to detach the hook.
        It was previously discarded, so the hook could never be removed. Returning it is
        backward compatible with callers that ignore the result.
    """
    def get_features(name):
        def hook(model, input, output):
            features[name] = output.clone()  # .detach()

        return hook

    return model_layers[optim_layer[0]].register_forward_hook(get_features(FEATURE_NAME))


# Accepts an (n,c,h,w) or (n,d1,d2,...,dk) with k>=1 batch of image activations
# returns an (n,) vector populated with each image's activation magnitude
# Use for ALL activation magnitude calculations for consistency
def calc_norms(activations):
    """Return the per-sample mean of squared activations as an ``(n,)`` tensor.

    Every dimension after the first is averaged over.
    """
    return activations.flatten(start_dim=1).square().mean(dim=1)


# Returns a [batch_size, channels] tensor with the activation norm of each channel for each image in the batch
# Commented code for future use if I succeed in passing a [topn,channels] of images through the model.
# TODO remove this as it is in utils
def calc_norms_by_channel(activations):
    """Return an ``(n, c)`` tensor: mean squared activation of each channel of each sample.

    Every dimension after the first two (e.g. the spatial ``h, w``) is averaged over.
    """
    return activations.flatten(start_dim=2).square().mean(dim=2)



def generate_single_artificial_image(model, layer, unit, root_path):
    """Render (and save) one 224x224 Lucent visualization of channel ``unit`` in ``layer``.

    A deep copy of ``model`` in eval mode is rendered with an FFT image parameterization. The
    image is saved to ``<root_path>/target_<layer>_<unit>.png``.

    Returns:
        The value returned by ``render.render_vis``.
    """
    side = 224
    model_copy = copy.deepcopy(model)
    model_copy.eval()

    return render.render_vis(
        model_copy,
        objectives.channel(layer, unit),
        lambda: param.image(side, fft=True),
        save_image=True,
        show_image=False,
        image_name=f"{root_path}/target_{layer}_{unit}.png",
        progress=False,
        show_inline=False,
    )



#CUDA_VISIBLE_DEVICES=0 python3 scripts/do_optimization_by_channel_circuit.py --nsteps 2 --batch_size 256 --save-interval 1001 --arch alexnet --imagenet /home/shared_data/imagenet --output alexnet/circuit_results/features_10_unit_beta_0 --layer features_10 --channel a --alpha 0.1 --beta 0 --optim-all --vis-obj channel --maintain-obj softmax --do-final-image-search --attack-name ref_to_tops_art --attack-type ref_to_tops_art --ref_target_path target.JPEG --learning-rate 1e-4 --attack-obj channel
#CUDA_VISIBLE_DEVICES=1 python3 circuits/analyse_circuit.py --results-directory resnet50/circuit_results/features_10_unit_0.05_beta_0.008 --pretrained --imagenet /home/shared_data/imagenet --arch resnet50 --batch_size 32 --cuda --layer layer1.0.conv2 --sparsity 0.5 --before
