#!/usr/bin/env python3
import sys

sys.path.append('./archived_scripts')
sys.path.insert(0, "/src")
from optimization_requirements import *

from argparse import ArgumentParser
from itertools import chain
from lucent.optvis import param, render
import torch
from torchvision import datasets, models, transforms
#from validate_with_imagenet import maintain_activation_imagenet_old as maintain_activation_imagenet


from activation_optimization_by_channel import activation_optimization
#import upload_to_gdrive
from torchvision.utils import save_image
import pandas as pd
import numpy as np
#from activation_optimization_by_channel import activation_optimization
import pickle
import PIL


# Version tag handed to save_args() at the end of the script.
VERSION = '0.9.0_TEST'
#TODO Get this import to work - but it's only needed for resnet
#from robustness import model_utils
#from robustness.datasets import ImageNet

# ---------------------------------------------------------------------------
# Command-line interface.
#
# --output and --layer are marked as required: the script dereferences both
# unconditionally, so leaving them out used to end in an opaque TypeError
# instead of a usage message.
# ---------------------------------------------------------------------------
parser = ArgumentParser()

# Run size, image and model selection.
parser.add_argument("--nsteps", default=1000, type=int)
parser.add_argument("--batch_size", default=256, type=int)
parser.add_argument("--nb_images_saved", default=10, type=int)
parser.add_argument("--nb_images_per_channel", default=10, type=int)
parser.add_argument("--save-interval", default=10, type=int)
parser.add_argument("--img-size", default=224, type=int)
parser.add_argument("--arch", default="vgg19", choices=model_names + ["resnet50-rob", "resnet18-rob", "efficientnet"])
parser.add_argument("--top-activations-neuron", default="alexnet/imagenetactiv")
parser.add_argument("--top-activations-channel", default="alexnet/imagenetactivpostprocessed")
parser.add_argument("--imagenet", default="/data/imagenet_data")

# Output directory; the script creates parameter_checkpoints/, optimal_images/
# and figures/ beneath it.
parser.add_argument("--output", required=True)
# --layer and --channel are repeatable and paired by position. A channel value
# of "a" is expanded by the script to every channel of the paired layer.
parser.add_argument("--layer", type=str, action="append", required=True)
parser.add_argument("--channel", type=str, action="append")
parser.add_argument("--alpha", type=float)
parser.add_argument("--optim-layers", action="append")
parser.add_argument("--optim-all", action="store_true")
parser.add_argument("--attack-obj", default="center-neuron",
                    choices=["center-neuron", "channel", "layer", "inverse-test"])
#parser.add_argument("--attack-data", default="top-image", choices=["top-image", "top10", "optimal-image"])
# When --vis-obj is omitted the script falls back to the value of --attack-obj.
parser.add_argument("--vis-obj", default=None, choices=["center-neuron", "channel",]) #TODO, this should match optim_obj by default
parser.add_argument("--do-alpha-update", action="store_true")
parser.add_argument("--maintain-obj", default=None, choices=["labels", "softmax", "kl-div", "activations", "adv-rob"])

parser.add_argument("--maintain-activations-layer", action="store_true")
#parser.add_argument("--validate", default="accuracy", choices=["accuracy", "inverse-test"])
parser.add_argument("--title", default="<args.title>")
parser.add_argument("--save-image", action="store_true")
parser.add_argument("--inverse", action="store_true")

#parser.add_argument("--optimal-image-dir", type=str,
#                    help="Directory where the optimal images for the network are stored")

# Optimizer and validation settings.
parser.add_argument("--track-training-params", action="store_true")
parser.add_argument("--optimizer-type", default='SGD')
parser.add_argument("--learning-rate", default=1e-8, type=float)
parser.add_argument("--do-full-validation", action="store_true")
parser.add_argument("--num-attack-images", default=1, type=int)
parser.add_argument("--do-final-image-search", action="store_true")

parser.add_argument("--acc-loss-threshold", default=0.005, type=float)

parser.add_argument("--do-continue_optim", action="store_true")
# 10000 acts as the "not set" sentinel: the script then derives the value from
# the first --layer instead of attacking channels 0..N-1.
parser.add_argument("--first_p_channels", default=10000, type=int)
parser.add_argument("--stronger", action="store_true")
parser.add_argument("--no_whack-a-mole", action="store_true")
parser.add_argument("--attack-name", type=str, default="top_to_bottom")

parser.add_argument("--pretrained_path", default="pretrained_models/resnet18_l2_eps3.ckpt", type=str)

parser.add_argument("--attack-type", type=str)
parser.add_argument("--gen-artificial-images", action="store_true")
parser.add_argument("--do-fft", action="store_true")
# An empty path (the default) means "no reference image".
parser.add_argument("--ref_target_path", default="", type = str)
parser.add_argument("--ref_target_path2", default="", type = str)


args = parser.parse_args()


# Some preprocessing of arguments
# With no explicit visualization objective, reuse the attack objective.
if not args.vis_obj:
    args.vis_obj = args.attack_obj


# Ensure all required directories exist
def ensuredir(directory):
    """Create `directory` (and any missing parents); an existing one is fine."""
    Path(directory).mkdir(parents=True, exist_ok=True)


ensuredir(args.output)
for subdirectory in ("parameter_checkpoints", "optimal_images", "figures"):
    ensuredir(os.path.join(args.output, subdirectory))

# NOTE: the arguments used are recorded for reference by the save_args() call at the end of the script.

# Load model
# Instantiate the pretrained torchvision network selected by --arch.
if args.arch == 'vgg19':
    print('Using VGG19!')
    model = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1)

elif args.arch == "alexnet":
    print('using AlexNet!')
    model = models.alexnet(weights=models.AlexNet_Weights.IMAGENET1K_V1)
elif args.arch == "resnet50":
    print('using Resnet50!')
    model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
elif args.arch == "resnet18":
    print('using Resnet18!')
    model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
elif args.arch == "resnet34":
    print('using Resnet34!')
    # `weights` must be a weights enum member; the model constructor itself
    # was being passed here before.
    model = models.resnet34(weights=models.ResNet34_Weights.DEFAULT)
elif args.arch =='efficientnet':
    print('using EfficientNet!')
    model = models.efficientnet_b0(weights=models.EfficientNet_B0_Weights.IMAGENET1K_V1)
else:
    # --arch accepts more names than are handled here (e.g. "resnet50-rob",
    # "resnet18-rob"). Raise instead of `assert False`: asserts are stripped
    # under `python -O`, which would leave `model` undefined.
    print("Need to add the arch!!!")
    raise NotImplementedError(f"No model loader for --arch {args.arch!r}")



layers = get_model_layers(model)
output_folder = str(args.output)
# output_folder = '_layer_ ' + str(args.layer) + '_optim_layers_' \
#              + str(args.optim_layers) \
#             + '_channel_' + str(args.channel)

  # + output_folder
print('output folder: ', output_folder)
# print('channels: ', args.channel)


def _channel_count(module):
    """Return the channel count of `module`: `out_channels`, else `num_features`."""
    try:
        return module.out_channels
    except AttributeError:
        # For conv layers followed by batch normalization (Geraldin's edit).
        return module.num_features


# Resolve the (layer, channel) pairs under attack. After this block
# `args.layer` and `args.channel` are equal-length lists and every channel is
# an int.
if args.first_p_channels == 10000: #Default Value... Alex's code
    # Explicit --channel list; "a" stands for "all channels of that layer".
    if not args.channel:
        parser.error("--channel is required unless --first_p_channels is given")
    if len(args.layer) != len(args.channel):
        parser.error("--layer and --channel must be given the same number of times")

    out_channels = _channel_count(layers[args.layer[0]])
    args.first_p_channels = out_channels

    # Build fresh lists rather than popping from the ones being iterated:
    # mutating them mid-loop skipped the entry after each "a".
    expanded_layers = []
    expanded_channels = []
    for layer, channel in zip(args.layer, args.channel):
        if channel == "a":
            print('layer: ', layer)
            # Use the width of this pair's own layer, not that of the first one.
            layer_width = _channel_count(layers[layer])
            expanded_layers.extend([layer] * layer_width)
            expanded_channels.extend(range(layer_width))
        else:
            expanded_layers.append(layer)
            expanded_channels.append(int(channel))
    args.layer = expanded_layers
    args.channel = expanded_channels
else:
    # Explicit --first_p_channels: attack channels 0..N-1 and ignore --channel.
    print(f"---- We are attacking the first {args.first_p_channels} channels ")
    args.channel = list(range(args.first_p_channels))
    args.layer = args.layer*args.first_p_channels

# channels_for_optimization = args.channel
# Specify the regions that are being optimized
# Map each distinct layer name to the list of channel indices paired with it.
channels = {}
for layer_name in np.unique(args.layer):
    channels[layer_name] = []
for layer_name, channel_index in zip(args.layer, args.channel):
    channels[layer_name].append(channel_index)

# Get parameters
# `params` is materialised as a list because it is used twice: the optimizer
# constructor consumes it and it is also passed to activation_optimization().
# A one-shot iterator would already be exhausted by the second use.
if args.optim_all:
    print('optim_all')
    params = list(model.parameters())
else:
    if not args.optim_layers:
        parser.error("--optim-layers is required unless --optim-all is given")
    # chain.from_iterable flattens the per-layer parameter iterables;
    # chain(<generator>) yielded one iterable per layer, not the parameters.
    params = list(chain.from_iterable(
        get_layer_parameters(model, layer) for layer in args.optim_layers))

# Create visualization objective


# Make an objective for each channel you are attacking
print('Defining artificial image objectives.')
if args.vis_obj == "center-neuron":
    make_objective = objectives.neuron
elif args.vis_obj == "channel" or args.vis_obj == "inverse-test":
    make_objective = objectives.channel
else:
    # e.g. "layer", inherited from --attack-obj: no per-channel objective.
    make_objective = None

if make_objective is None:
    vis_obj_list = None
else:
    # Pair layers and channels by position. Indexing both lists with the
    # channel *value* only worked when the channels were exactly 0..N-1.
    vis_obj_list = [make_objective(layer, channel)
                    for layer, channel in zip(args.layer, args.channel)]

# vis_obj_list may be None (see above); report 0 rather than crash on len(None).
print('number of visualization objectives: ',
      0 if vis_obj_list is None else len(vis_obj_list))
# Set the optimizer
# Hyper-parameters; momentum and weight decay are only passed to SGD below.
optim_dict = dict(lr=args.learning_rate, momentum=0.9, weight_decay=1e-4)

#optimizer selection logic
optimizer_kind = args.optimizer_type.lower()
if optimizer_kind == "custom_adam":
    # Adam over every model parameter; parameters whose name contains the
    # first attacked layer's dotted prefix get a 10x learning rate.
    boosted_prefix = args.layer[0].replace("_", ".") + "."
    param_groups = []
    for weight_name, weight in model.named_parameters():
        group = {"params": weight}
        if boosted_prefix in weight_name:
            group["lr"] = optim_dict["lr"]*10
        param_groups.append(group)
    optimizer = torch.optim.Adam(param_groups, lr=optim_dict["lr"])
elif optimizer_kind == "sgd":
    print("Using SGD Optimizer")
    optimizer = torch.optim.SGD(
        params,
        lr=optim_dict['lr'],
        momentum=optim_dict['momentum'],
        weight_decay=optim_dict['weight_decay'],
    )
else:
    # Plain Adam, either requested explicitly or as the fallback.
    if optimizer_kind == "adam":
        print("Using Adam Optimizer")
    else:
        print("Unrecognized Optimizer type. Defaulting to Adam")
    optimizer = torch.optim.Adam(params, lr=optim_dict['lr'])

#Set the targeted reference image
def _load_reference_image(image_path):
    """Open `image_path` as an RGB PIL image; an empty path yields None."""
    if not image_path:
        return None
    return PIL.Image.open(image_path).convert('RGB')


ref_target = _load_reference_image(args.ref_target_path)
ref_target2 = _load_reference_image(args.ref_target_path2)

# Run the optimization
print('Performing the optimization!')

#print(channels)
#TODO Wrap these return values into a dict
id_run = f"{args.attack_obj}_{args.arch}_{args.attack_type}_{args.num_attack_images}"

# Save at least as many images as there are attack images.
images_saved = max(args.nb_images_saved, args.num_attack_images)
images_per_channel = max(args.nb_images_per_channel, args.num_attack_images)

# Keyword options forwarded unchanged to activation_optimization().
optimization_options = dict(
    batch_size=args.batch_size,
    nb_images_saved=images_saved,
    nb_images_per_channel=images_per_channel,
    channels=args.channel,
    save_image=args.save_image,
    optimizer=optimizer,
    feature_layer=args.layer,
    save_params=args.track_training_params,
    do_full_validation=args.do_full_validation,
    alpha=args.alpha,
    use_tqdm=False,
    maintain_objective_type=args.maintain_obj,
    optim_dict=optim_dict,
    num_attack_images=args.num_attack_images,
    do_final_top_image_search=args.do_final_image_search,
    vis_obj_list=vis_obj_list,
    attack_type=args.attack_type,
    data_folder=args.imagenet,
    id_run=id_run,
    continue_optim=args.do_continue_optim,
    stronger=args.stronger,
    first_p_channels=args.first_p_channels,
    no_whack_a_mole=args.no_whack_a_mole,
    attack_name=args.attack_name,
    do_alpha_update=args.do_alpha_update,
    acc_loss_threshold=args.acc_loss_threshold,
    ref_target=ref_target,
    ref_target2=ref_target2,
    pretrained_path=args.pretrained_path,
)

results_dict = activation_optimization(
    args.attack_obj,
    model,
    params,
    args.nsteps,
    args.save_interval,
    args.img_size,
    output_folder,
    **optimization_options,
)
print('Optimization complete!')


# Pickle the value returned by activation_optimization() into the output folder.
results_path = f'{output_folder}/results_dict.pkl'
with open(results_path, 'wb') as results_file:
    pickle.dump(results_dict, results_file)

# Hand the parsed arguments, optimizer settings and version tag to save_args().
save_args(args, optim_dict, VERSION)
print(f'Script complete! Saving results to {output_folder}')