#!/usr/bin/env python3
from optimization_requirements import *
from argparse import ArgumentParser
from itertools import chain
from lucent.optvis import param, render
from utils import *
from validate_with_imagenet import maintain_activation_imagenet_old as maintain_activation_imagenet
from activation_optimization_original import activation_optimization

VERSION = '1.0.3'

# Command-line interface.
# --output, --layer and --channel are required: the script unconditionally
# creates directories under --output and iterates over the --layer/--channel
# pairs, so omitting any of them previously only surfaced as a TypeError.
parser = ArgumentParser()
parser.add_argument("--nsteps", default=1000, type=int)
parser.add_argument("--save-interval", default=100, type=int)
parser.add_argument("--img-size", default=224, type=int)
parser.add_argument("--arch", default="vgg19", choices=model_names)
parser.add_argument("--top-activations-neuron", default="ceph/imagenetactiv")
parser.add_argument("--top-activations-channel", default="ceph/imagenetactivpostprocessed")
parser.add_argument("--imagenet", default="/data/imagenet_data")
parser.add_argument("--output", required=True,
                    help="Directory that receives checkpoints, figures and data")
parser.add_argument("--layer", type=str, action="append", required=True,
                    help="Layer name; give it once per --channel")
parser.add_argument("--channel", type=str, action="append", required=True,
                    help='Channel index, or "a" for every channel of the matching --layer')
parser.add_argument("--alpha", type=float)
parser.add_argument("--optim-layers", action="append",
                    help="Layer whose parameters are optimized (required unless --optim-all)")
parser.add_argument("--optim-all", action="store_true")
parser.add_argument("--optim-obj", default="center-neuron", choices=["center-neuron", "channel", "layer", "inverse-test"])
parser.add_argument("--optim-data", default="top-image", choices=["top-image", "top10", "optimal-image"])
parser.add_argument("--vis-obj", default=None, choices=["center-neuron", "channel"])
parser.add_argument("--maintain-obj", default=None, choices=["labels", "softmax", "activations"])
parser.add_argument("--maintain-activations-layer", action="store_true")
parser.add_argument("--validate", default="accuracy", choices=["accuracy", "inverse-test"])
parser.add_argument("--title", default="<args.title>")
parser.add_argument("--save-image", action="store_true")
parser.add_argument("--inverse", action="store_true")
parser.add_argument("--optimal-image-dir", type=str, help="Directory where the optimal images for the network are stored")
args = parser.parse_args()

# Argument preprocessing: when no separate visualization objective was given,
# visualize with the optimization objective.
if not args.vis_obj:
    args.vis_obj = args.optim_obj


def ensuredir(directory):
    """Create *directory* and any missing parents; do nothing if it exists."""
    Path(directory).mkdir(parents=True, exist_ok=True)


# Output layout: the root folder plus one sub-folder per artifact type.
ensuredir(args.output)
for _subdir in ("parameter_checkpoints", "optimal_image_checkpoints", "figures"):
    ensuredir(os.path.join(args.output, _subdir))

# Load the architecture named by --arch: pretrained weights normally, untrained
# weights when --inverse is set. Move it to `device` and switch to eval mode.
model = models.__dict__[args.arch](pretrained=not args.inverse)
model.to(device)
model.eval()
# Layer lookup used below, e.g. for `layers[name].out_channels`.
layers = get_model_layers(model)

output_folder = str(args.output)
print('output folder: ', output_folder)
print('channels: ', args.channel)

# Normalize the --layer/--channel pairs.
# Numeric channel arguments become ints. A channel of "a" expands into one
# (layer, channel) pair for every output channel of that layer; expanded pairs
# are appended after the explicitly listed ones.
# Fresh lists are built rather than popping from args.channel while enumerating
# it, which skipped the argument following each "a" (leaving it a str).
if len(args.layer) != len(args.channel):
    parser.error("--layer and --channel must be given the same number of times")

_explicit_pairs = []
_expanded_pairs = []
for _layer, _channel in zip(args.layer, args.channel):
    if _channel == "a":
        _expanded_pairs.extend((_layer, j) for j in range(layers[_layer].out_channels))
    else:
        _explicit_pairs.append((_layer, int(_channel)))

_all_pairs = _explicit_pairs + _expanded_pairs
args.layer = [pair_layer for pair_layer, _ in _all_pairs]
args.channel = [pair_channel for _, pair_channel in _all_pairs]
print('after loop: ', args.channel)

def _region_batch_indices(pair_index):
    """Return the batch rows the `pair_index`-th (layer, channel) pair optimizes.

    With --optim-data top10 each pair owns ten consecutive rows; otherwise it owns
    the single row equal to its position.
    """
    if args.optim_data == "top10":
        return list(range(10 * pair_index, 10 * (pair_index + 1)))
    return [pair_index]


# Specify the regions that are being optimized: (layer, channel, batch rows).
objective_regions = [
    (layer, channel, _region_batch_indices(pair_index))
    for pair_index, (layer, channel) in enumerate(zip(args.layer, args.channel))
]

# Group channels by layer (layers in np.unique order, channels in argument order).
channels = {layer: [] for layer in np.unique(args.layer)}
for layer, channel in zip(args.layer, args.channel):
    channels[layer].append(channel)

# Some debugging
print('Objective Regions: ', objective_regions)

# Get the parameters to optimize: every model parameter with --optim-all,
# otherwise only those of the layers named with --optim-layers.
if args.optim_all:
    print('optim_all')
    params = model.parameters()
else:
    if not args.optim_layers:
        parser.error("--optim-layers is required unless --optim-all is given")
    if len(args.optim_layers) > 1:
        # chain.from_iterable flattens the per-layer iterables into a single
        # sequence of parameters. A plain chain(...) over the generator would
        # yield the per-layer iterables themselves instead of the parameters.
        # A list is built so that params can be iterated more than once.
        params = list(chain.from_iterable(get_layer_parameters(model, layer) for layer in args.optim_layers))
    else:
        params = get_layer_parameters(model, args.optim_layers[0])
    
# Create optimization objective.
# --inverse flips the sign of the objective.
sign = -1 if args.inverse else 1

if args.optim_obj == "center-neuron":
    def grab_activ(activations):
        """Return, per objective region, the center-pixel activation of the region's
        channel at the region's batch rows.

        `activations` maps layer name -> tensor indexed as [batch, channel, row, col].
        """
        output = []
        for layer, channel, index in objective_regions:
            nrows, ncols = activations[layer].shape[2:]
            output.append(activations[layer][index, channel, nrows // 2, ncols // 2])
        return output

    def grab_activ_maintain(activations):
        """Like `grab_activ`, but over every batch row instead of the region's rows."""
        output = []
        for layer, channel, _ in objective_regions:
            nrows, ncols = activations[layer].shape[2:]
            output.append(activations[layer][:, channel, nrows // 2, ncols // 2])
        return output
elif args.optim_obj in ("channel", "inverse-test"):
    def grab_activ(activations):
        """Return, per objective region, the full feature map of the region's channel
        at the region's batch rows."""
        return [activations[layer][index, channel] for layer, channel, index in objective_regions]

    def grab_activ_maintain(activations):
        """Like `grab_activ`, but over every batch row instead of the region's rows."""
        return [activations[layer][:, channel] for layer, channel, _ in objective_regions]
elif args.optim_obj == "layer":
    raise NotImplementedError("Layers don't line up with the layer,channel scheme")


def optim_obj(activations):
    """Signed sum of the L2 norms of the activations selected by `grab_activ`."""
    return sign * sum(torch.norm(activ, p=2) for activ in grab_activ(activations))

if args.optim_obj == "inverse-test":
    # Stream of shuffled, augmented ImageNet training batches used as the
    # "sub-optimal" reference inputs for the inverse-test objective.
    _imagenet_dir = args.imagenet
    _normalize = transforms.Normalize(
                            mean = [0.485, 0.456, 0.406],
                            std = [0.229, 0.224, 0.225])
    _standard_train_transform = transforms.Compose([
                                transforms.RandomResizedCrop(224),
                                transforms.RandomHorizontalFlip(),
                                transforms.ToTensor(),
                                _normalize])
    _train_dataset = datasets.ImageNet(_imagenet_dir, split="train", transform=_standard_train_transform)
    _train_loader = torch.utils.data.DataLoader(_train_dataset,
                                                batch_size=100,
                                                num_workers=10,
                                                pin_memory=True,
                                                shuffle=True)
    _train_iter = iter(_train_loader)
    # Filled in after optim_data is known (see the "inverse-test" setup further down).
    prev_optim_activs = None
    prev_optim_norms = None
    # Weight of the first term (stay close to the initial norms) versus the second.
    beta = 0.5

    def _next_train_batch():
        """Return the next batch of training images.

        Starts a new pass over the loader when the current one is exhausted.
        """
        global _train_iter
        try:
            imgs, _ = next(_train_iter)
        except StopIteration:
            _train_iter = iter(_train_loader)
            imgs, _ = next(_train_iter)
        return imgs

    def optim_obj(activations):
        """Inverse-test objective, computed per objective region.

        Penalizes squared drift of the region's activation norm from its initial
        value (`prev_optim_norms`). Also penalizes, via ReLU, any excess of the
        length-rescaled norm on a fresh ImageNet batch over the norm on the
        optimization data.
        """
        imgs = _next_train_batch().to(device)
        optim_activ = grab_activ(activations)
        model(imgs)
        # NOTE(review): this re-reads `activations` after model(imgs). That only
        # sees the ImageNet batch if the model's hooks update this dict in
        # place — confirm.
        suboptim_activ = grab_activ_maintain(activations)
        optim_norms = [torch.norm(activ, p=2) for activ in optim_activ]
        suboptim_norms = [torch.norm(activ, p=2) for activ in suboptim_activ]
        first_term = [torch.square(p - o) for o, p in zip(optim_norms, prev_optim_norms)]
        second_term = [
            torch.nn.functional.relu((len(oa) / len(sa)) * s - o)
            for o, s, oa, sa in zip(optim_norms, suboptim_norms, optim_activ, suboptim_activ)
        ]
        return sign * sum(beta * first + (1 - beta) * second for first, second in zip(first_term, second_term))

# Create visualization objective.
# Only the first (layer, channel) pair is visualized for now.
#TODO: build a summed objective over every (layer, channel) pair, e.g.
#vis_obj = sum(objectives.neuron(args.layer, channel, batch=i) for i, el in enumerate(args.channel))
#vis_obj = sum(objectives.channel(args.layer, channel) for i, channel in enumerate(args.channel))
_vis_factories = {
    "center-neuron": objectives.neuron,
    "channel": objectives.channel,
    "inverse-test": objectives.channel,
}
if args.vis_obj in _vis_factories:
    vis_obj = _vis_factories[args.vis_obj](args.layer[0], args.channel[0])

# Set optimization data: the untransformed ImageNet training split.
dataset = datasets.ImageNet(args.imagenet, split="train")


def grab_image(index):
    """Return only the image (not the label) of training sample *index*."""
    image, _label = dataset[index]
    return image

def map_over(func, arr):
    """Apply *func* to every leaf of a possibly nested list, keeping the nesting.

    Only objects whose type is exactly ``list`` are recursed into; anything else
    (tuples, numpy arrays, tensors, scalars) is treated as a leaf.
    """
    if type(arr) is list:
        return [map_over(func, item) for item in arr]
    return func(arr)

def multistack(arr):
    """Recursively stack a nested list of tensors into one tensor.

    Each list level becomes a new leading dimension. Non-list objects are
    returned unchanged.
    """
    if type(arr) is list:
        stacked_children = [multistack(item) for item in arr]
        return torch.stack(stacked_children, dim=0)
    return arr

# Collect dataset indices of the top-activating images for every requested
# channel. Per-layer index arrays are concatenated column-wise, layers in the
# order of `channels`.
# TODO: Fix args.channel to be channels[layer]
if args.optim_obj == "center-neuron":
    def _center_neuron_indices(layer):
        """Top-image indices at the spatial center of each requested channel."""
        layer_indices = indices_neuron[layer]
        center_row = layer_indices.shape[2] // 2
        center_col = layer_indices.shape[3] // 2
        return layer_indices[:, channels[layer], center_row, center_col]

    top_image_indices = np.concatenate(
        [_center_neuron_indices(layer) for layer in channels.keys()], axis=1)
elif args.optim_obj in ("channel", "inverse-test"):
    top_image_indices = np.concatenate(
        [indices_channel[layer][:, channels[layer]] for layer in channels.keys()], axis=1)

if args.optim_data == "top-image":
    # One image per (layer, channel) pair, taken from row 0 of top_image_indices
    # (presumably the highest-ranked image per column — confirm against how the
    # index files are produced). .tolist() is needed because map_over only
    # recurses into Python lists.
    top_image_index = top_image_indices[0].tolist()
    top_image = map_over(grab_image, top_image_index)
    top_image = map_over(standard_transform, top_image)
    optim_data = multistack(top_image).to(device)
    print("OPTIM DATA", optim_data)
    print("OPTIM DATA SHAPE", optim_data.shape)

elif args.optim_data == "top10":
    # objective_regions assigns batch rows 10*i .. 10*i+9 to the i-th
    # (layer, channel) pair. The batch is therefore laid out pair-major: the
    # first ten rows of pair 0's column, then pair 1's, and so on. Flattening
    # to a plain list also keeps map_over/multistack producing a single 4-D
    # batch, since map_over only recurses into Python lists.
    top10_indices = top_image_indices[:10].T.reshape(-1).tolist()
    top_images = map_over(grab_image, top10_indices)
    # Transform the loaded images (not their indices).
    top_images = map_over(standard_transform, top_images)
    optim_data = multistack(top_images).to(device)

elif args.optim_data == "optimal-image":
    if len(args.channel) > 1:
        raise NotImplementedError("Need to figure out how to extract arrays of optimal images and use them in optimization")
    print("Generating optimal image")
    param_f = lambda: param.image(128)
    optimal_image = render.render_vis(model, vis_obj, param_f, progress=True)[0]
    # Channels-last numpy output -> channels-first tensor for the model.
    optim_data = torch.from_numpy(optimal_image.transpose([0,3,1,2])).to(device)

# TODO: Look at this
# For the inverse-test objective, record the initial activations (and their L2
# norms) on the optimization data. They are detached so they act as fixed
# references.
if args.optim_obj == "inverse-test":
    set_model_activations(model, require_grad=True)
    model(optim_data)
    prev_optim_activs = list(map(torch.Tensor.detach, grab_activ(model.activations)))
    prev_optim_norms = [torch.norm(reference, p=2).detach() for reference in prev_optim_activs]

# Create the validation callback passed to activation_optimization.
if args.validate is None:
    validate = None
elif args.validate == "inverse-test":
    validate = activations_callback_over_imagenet_inverse_test(model, prev_optim_activs, callback = grab_activ_maintain)
elif args.validate == "accuracy":
    validate = validate_with_imagenet(model, batch_size = 100, num_workers = 10)

# Create maintain objective.
# The "softmax" and "activations" variants compare against an untouched
# pretrained copy of the same architecture.
if args.maintain_obj in ["softmax", "activations"]:
    orig_model = models.__dict__[args.arch](pretrained=True)
    orig_model.to(device).eval()

if args.maintain_obj is None:
    maintain_obj = None
elif args.maintain_obj == "labels":
    maintain_obj = maintain_obj_imagenet(model)
elif args.maintain_obj == "softmax":
    maintain_obj = maintain_softmax_imagenet(model, orig_model)
elif args.maintain_obj == "activations":
    if not args.maintain_activations_layer:
        # Maintain only the requested channels, over the whole batch.
        activ_func = grab_activ_maintain
    else:
        def _whole_layer_activations(activations):
            """Every activation of each layer that has a requested channel."""
            return [activations[layer] for layer in channels.keys()]
        activ_func = _whole_layer_activations
    maintain_obj = maintain_activation_imagenet(model, orig_model, activ_func=activ_func)

# Set the optimizer //Currently this is not used


def optimizer(params):
    """Build an SGD optimizer over *params* (lr 1e-5, momentum 0.9, weight decay 5e-4)."""
    return torch.optim.SGD(params, lr=1e-5, momentum=0.9, weight_decay=5e-4)

# Run the optimization.
print('Performing the optimization!')
optimal_images, objective_values, accuracy, maintain_values = activation_optimization(
    optim_obj, vis_obj, optim_data, model, params,
    args.nsteps, args.save_interval, args.img_size, output_folder,
    save_image=args.save_image,
    # optimizer=optimizer,  # Force use of default Adam optimizer
    maintain_obj=maintain_obj,
    callback=validate,
    alpha=args.alpha,
    use_tqdm=False,
)
print('Optimization complete!')

# Plot objective values and accuracy, then dump the same series to CSV; both
# go into the "figures" sub-folder of the output directory.
figures_dir = os.path.join(output_folder, "figures")

print('Visualizing objectives!')
visualize_objectives(
    objective_values,
    accuracy,
    args.save_interval,
    filename=os.path.join(figures_dir, "objectives_visual.png"),
    title=args.title,
    inline=False,
)
print('Visualization complete!')

# TODO: Remove this
import pandas as pd
objectives_data = pd.DataFrame({"objective_values": objective_values, "accuracy": accuracy})
objectives_data.to_csv(os.path.join(figures_dir, "objectives_data.csv"))

# Save a .txt file with the config info.
save_args(args, VERSION)

# Upload the results (excluding the optimization checkpoints) to Google Drive:

