#!/usr/bin/env python3
import sys

sys.path.append('./archived_scripts')
sys.path.insert(0, "/src")
from optimization_requirements import *
from argparse import ArgumentParser
from itertools import chain
from lucent.optvis import param, render
import torch
from torchvision import datasets, models, transforms
from validate_with_imagenet import maintain_activation_imagenet_old as maintain_activation_imagenet
from activation_optimization_full_training import activation_optimization
import upload_to_gdrive
from torchvision.utils import save_image
import pandas as pd



VERSION = '0.5.1_STABLE'

# Command-line interface. Options declared with action="append" may be
# repeated; the i-th --layer is paired with the i-th --channel.
parser = ArgumentParser()
parser.add_argument("--nsteps", default=1000, type=int)
parser.add_argument("--save-interval", default=10, type=int)
parser.add_argument("--img-size", default=224, type=int)
parser.add_argument("--arch", default="vgg19", choices=model_names)
parser.add_argument("--top-activations-neuron", default="alexnet/imagenetactiv")
parser.add_argument("--top-activations-channel", default="alexnet/imagenetactivpostprocessed")
parser.add_argument("--imagenet", default="/data/imagenet_data")
# --output, --layer and --channel are required: the script creates the output
# directory tree and reads the first --layer/--channel unconditionally, so
# omitting any of them used to crash later with an opaque TypeError.
parser.add_argument("--output", required=True,
                    help="Directory for all results (created if it does not exist)")
parser.add_argument("--layer", type=str, action="append", required=True,
                    help="Layer name to target; repeat once per --channel")
parser.add_argument("--channel", type=str, action="append", required=True,
                    help="Channel index to target, or 'a' for every channel of the matching --layer")
parser.add_argument("--alpha", type=float)
parser.add_argument("--optim-layers", action="append")
parser.add_argument("--optim-all", action="store_true")
parser.add_argument("--attack-obj", default="center-neuron",
                    choices=["center-neuron", "channel", "layer", "inverse-test"])
parser.add_argument("--attack-data", default="top-image", choices=["top-image", "top10", "optimal-image"])
# Defaults to the value of --attack-obj when omitted (resolved right after parsing).
parser.add_argument("--vis-obj", default=None, choices=["center-neuron", "channel"])
parser.add_argument("--maintain-obj", default=None, choices=["labels", "softmax", "kl-div"])
parser.add_argument("--maintain-activations-layer", action="store_true")
parser.add_argument("--title", default="<args.title>")
parser.add_argument("--save-image", action="store_true")
parser.add_argument("--inverse", action="store_true")
parser.add_argument("--optimal-image-dir", type=str,
                    help="Directory where the optimal images for the network are stored")
parser.add_argument("--track-training-params", action="store_true")
parser.add_argument("--num-top-images", default=0, type=int)
parser.add_argument("--optimizer-type", default='SGD')
parser.add_argument("--learning-rate", default=1e-8, type=float)
parser.add_argument("--do-full-validation", action="store_true")
parser.add_argument("--num-attack-images", default=1, type=int)
parser.add_argument("--do-final-image-search", action="store_true")
args = parser.parse_args()


# Post-process the parsed arguments: with no explicit --vis-obj, the
# visualization objective follows the attack objective.
if not args.vis_obj:
    args.vis_obj = args.attack_obj


def ensuredir(directory):
    """Create *directory* and any missing parents; no-op if it already exists."""
    Path(directory).mkdir(parents=True, exist_ok=True)


# Output root plus the subdirectories the rest of the script writes into.
ensuredir(args.output)
for subdirectory in ("parameter_checkpoints", "optimal_image_checkpoints", "figures"):
    ensuredir(os.path.join(args.output, subdirectory))

# The arguments used for this run are saved for reference via save_args() near the end of the script.

# Load an ImageNet-pretrained torchvision model: VGG19 when requested,
# AlexNet for any other --arch value.
if args.arch != 'vgg19':
    print('using AlexNet!')
    model = models.alexnet(weights=models.AlexNet_Weights.IMAGENET1K_V1)
else:
    print('Using VGG19!')
    model = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1)

layers = get_model_layers(model)
output_folder = str(args.output)
print('output folder: ', output_folder)

# Expand the "a" wildcard: `--channel a` paired with the i-th `--layer X` means
# "every output channel of layer X". Each wildcard pair is removed, the
# remaining explicit channels are converted to int, and one (layer, channel)
# pair per channel of X is appended after them.
#
# The lists are rebuilt rather than popped while being enumerated. Popping
# during enumerate() skips the element that follows each pop, which left
# explicit channels after a wildcard as strings and never expanded
# consecutive wildcards.
wildcard_positions = [i for i, spec in enumerate(args.channel) if spec == "a"]
wildcard_layers = [args.layer[i] for i in wildcard_positions]
skipped = set(wildcard_positions)
args.layer = [name for i, name in enumerate(args.layer) if i not in skipped]
args.channel = [int(spec) for i, spec in enumerate(args.channel) if i not in skipped]

for layer in wildcard_layers:
    print('layer: ', layer)
    # Requires the layer module to expose `out_channels` (e.g. a conv layer).
    n_channels = layers[layer].out_channels
    args.layer.extend([layer] * n_channels)
    args.channel.extend(range(n_channels))
# Group the requested channel indices by layer name. Keys come from
# np.unique, so they are in sorted order.
channels = {}
for layer_name in np.unique(args.layer):
    channels[layer_name] = []
for layer_name, channel_index in zip(args.layer, args.channel):
    channels[layer_name].append(channel_index)

# Collect the parameters to optimize. The result is materialized as a list
# because it is consumed twice: by the torch optimizer below and again by
# activation_optimization(). A bare generator (model.parameters()) or chain
# object would already be exhausted on the second use.
if args.optim_all:
    print('optim_all')
    params = list(model.parameters())
else:
    if not args.optim_layers:
        parser.error("--optim-layers is required unless --optim-all is given")
    # chain.from_iterable flattens the per-layer parameter iterables into one
    # sequence of tensors. chain(generator) would instead yield the per-layer
    # iterables themselves, which the optimizer cannot consume.
    params = list(chain.from_iterable(
        get_layer_parameters(model, layer) for layer in args.optim_layers))

# Build the lucent objective used to visualize the targeted unit.
# Only the first (layer, channel) pair is used for now.
# TODO: combine objectives over every requested (layer, channel) pair.
if args.vis_obj == "center-neuron":
    vis_obj = objectives.neuron(args.layer[0], args.channel[0])
elif args.vis_obj in ("channel", "inverse-test"):
    vis_obj = objectives.channel(args.layer[0], args.channel[0])
else:
    # e.g. "layer" inherited from --attack-obj: no visualization objective is
    # defined for it. Without this branch, vis_obj stayed undefined and the
    # script failed later with a NameError.
    parser.error(f"no visualization objective for {args.vis_obj!r}; "
                 "pass --vis-obj center-neuron or --vis-obj channel")



# Optimizer hyper-parameters. The same dict is also handed to
# activation_optimization() and save_args() further down.
optim_dict = dict(lr=args.learning_rate, momentum=0.9, weight_decay=1e-4)

# Pick the optimizer by (case-insensitive) name. SGD uses every entry of
# optim_dict; Adam, which is also the fallback for unknown names, uses only
# the learning rate.
optimizer_name = args.optimizer_type.lower()
if optimizer_name == "sgd":
    print("Using SGD Optimizer")
    optimizer = torch.optim.SGD(params, **optim_dict)
else:
    if optimizer_name == "adam":
        print("Using Adam Optimizer")
    else:
        print("Unrecognized Optimizer type. Defaulting to Adam")
    optimizer = torch.optim.Adam(params, lr=optim_dict['lr'])

# Run the optimization. The first eight arguments are positional; everything
# else is passed through as keywords.
print('Performing the optimization!')
optimization_kwargs = dict(
    num_top_images=args.num_top_images,
    channels=args.channel,
    save_image=args.save_image,
    optimizer=optimizer,
    feature_layer=args.layer,
    save_params=args.track_training_params,
    do_full_validation=args.do_full_validation,
    alpha=args.alpha,
    use_tqdm=False,
    maintain_objective_type=args.maintain_obj,
    optim_dict=optim_dict,
    num_attack_images=args.num_attack_images,
    do_final_top_image_search=args.do_final_image_search,
)
(optimal_images, attack_objective_values, accuracy,
 maintain_values, top_images) = activation_optimization(
    args.attack_obj, vis_obj, model, params,
    args.nsteps, args.save_interval, args.img_size, output_folder,
    **optimization_kwargs)
print('Optimization complete!')
# Plot the tracked objective curves, then dump the raw values to CSV in the
# same figures directory.
print('Visualizing objectives!')
figures_dir = os.path.join(output_folder, "figures")
plot_title = (f"lr: {args.learning_rate}, alpha: {args.alpha}, "
              f"optim: {args.optimizer_type}, maintain obj: {args.maintain_obj}")
visualize_objectives_v2(attack_objective_values, maintain_values, accuracy,
                        args.save_interval,
                        filename=os.path.join(figures_dir, "objectives_visual.png"),
                        title=plot_title,
                        inline=False)
print('Visualization complete!')


objective_history = pd.DataFrame({
    "accuracy": accuracy,
    "attack_objective_values": attack_objective_values,
    "maintain_values": maintain_values,
})
objective_history.to_csv(os.path.join(figures_dir, "objectives_data.csv"))

# Persist the run configuration (parsed args, optimizer hyper-parameters,
# script version).
save_args(args, optim_dict, VERSION)

# Save each batch of top images returned by activation_optimization() as one
# grid. Two batches are treated as (initial, final), a single batch as final
# only, and any other count as one batch per save interval.
if len(top_images) == 2:
    top_image_dirs = [output_folder + '/init_top', output_folder + '/final_top']
elif len(top_images) == 1:
    top_image_dirs = [output_folder + '/final_top']
else:
    top_image_dirs = [output_folder + f'/step_{index * args.save_interval}_top'
                      for index in range(len(top_images))]

for output_destination, image_batch in zip(top_image_dirs, top_images):
    ensuredir(output_destination)
    save_image(image_batch, output_destination + '/top_images.png', nrow=5)

# TODO: uploading the results (excluding optimization checkpoints) to Google
# Drive is not performed here.

print(f'Script complete! Saving results to {output_folder}')
