#!/usr/bin/env python3
from scripts.optimization_requirements import *
from argparse import ArgumentParser
from itertools import chain
from lucent.optvis import param, render
from archived_scripts.objective_regions_dataloader import ObjectiveRegionsDataloader

def _str2bool(value):
    """Convert a command-line string into a boolean.

    ``type=bool`` does not work with argparse: ``bool("False")`` is ``True``
    because every non-empty string is truthy.  Here the usual "false"
    spellings (and the empty string) map to ``False``; any other text maps
    to ``True``, so values that used to enable the flag still do.
    """
    if isinstance(value, bool):
        return value
    return value.strip().lower() not in ("", "0", "false", "f", "no", "n", "off")


# Command-line interface of the script.
parser = ArgumentParser()
parser.add_argument("--nsteps", default=1000, type=int)
parser.add_argument("--save-interval", default=100, type=int)
parser.add_argument("--img-size", default=224, type=int)
parser.add_argument("--arch", default="vgg19", choices=model_names)
parser.add_argument("--top-activations-neuron", default="ceph/imagenetactiv")
parser.add_argument("--top-activations-channel", default="ceph/imagenetactivpostprocessed")
parser.add_argument("--imagenet", default="/data/imagenet_data")
parser.add_argument("--output", default='ceph/optimization_output')
# action="append": repeat the flag once per layer (--layer A --layer B);
# the parsed value is a list, or None when the flag is never given.
parser.add_argument("--layer", type=str, action="append")
# action="append" is disabled here, so the parsed value is a single string.
parser.add_argument("--channel", type=str,
                    # action="append",
                    default="all")
parser.add_argument("--alpha", type=float, default=0.01)
parser.add_argument("--optim-type", default="SGD",
                    #choices=optim_names   # Not declared anywhere.
                    )
parser.add_argument("--optim-lr", type=float, default=1e-5)
parser.add_argument("--optim-kwargs", type=str, default="{'momentum':0.9, 'weight_decay':5e-4}")
parser.add_argument("--optim-layers", action="append")
parser.add_argument("--optim-all", action="store_true")
parser.add_argument("--optim-obj", default="center-neuron", choices=["center-neuron", "channel", "layer"])
parser.add_argument("--optim-data", default="top-image", choices=["top-image", "top10", "optimal-image"])
parser.add_argument("--vis-obj", default=None, choices=["center-neuron", "channel"])
parser.add_argument("--maintain-obj", default="labels", choices=["labels", "softmax", "activations"])
parser.add_argument("--maintain-activations-layer", action="store_true")
parser.add_argument("--validate", default=None, choices=["accuracy"])
parser.add_argument("--title", default="<args.title>")
parser.add_argument("--save-image", action="store_true")
parser.add_argument("--inverse", action="store_true")
parser.add_argument("--optimal-image-dir", type=str, help="Directory where the optimal images for the network are stored")
parser.add_argument("--optim-batch", type=int, default=64, help="Number of objective_regions per batch")
parser.add_argument("--divide-lr-by-objectives", action="store_true", help="Whether to divide the learning rate by the number of objectives that are being summed together -- helps to normalize the scale of the gradient")
# Accepts "--testing", "--testing true" and "--testing false".  The previous
# type=bool turned every non-empty value, including "False", into True.
parser.add_argument("--testing", type=_str2bool, nargs="?", const=True, default=False)
args = parser.parse_args()

# Post-process the parsed arguments.
# Without an explicit --vis-obj, visualize the same objective that is optimized.
if not args.vis_obj:
    args.vis_obj = args.optim_obj

# NOTE(review): eval() runs whatever expression is passed via --optim-kwargs.
# That is fine for a trusted operator, but ast.literal_eval would be the safer
# choice if this string can ever come from an untrusted source.
args.optim_kwargs = eval(args.optim_kwargs)

if args.testing:
    print('Testing!')

def ensuredir(directory):
    """Create *directory* together with missing parents; an existing one is left alone."""
    Path(directory).mkdir(parents=True, exist_ok=True)


# Make sure the output root and each of its subfolders exist.
ensuredir(args.output)
for subfolder in ("parameter_checkpoints", "optimal_image_checkpoints", "figures"):
    ensuredir(os.path.join(args.output, subfolder))

# Instantiate the requested architecture; pretrained weights are requested
# unless --inverse is set.
model_factory = vars(models)[args.arch]
model = model_factory(pretrained=(not args.inverse))
model.to(device)
layers = get_model_layers(model)

# Print every entry produced by iterating over `layers`.
for layer_entry in layers:
    print('layer:', layer_entry)

# Extra diagnostics about the raw --channel value when --testing is on.
if args.testing:
    channel_is_all = args.channel =='all'
    print('channel: ' , args.channel)
    print('args.channel == all', channel_is_all)

# Expand the requested layers/channels into parallel lists, one entry per
# objective region.
#
# --layer may be repeated (argparse "append") while --channel is a single
# string.  The earlier loop enumerated that string character by character
# and then called .pop() / item assignment on it, which a str does not
# support, so it failed for every non-empty value.  Fresh lists are built
# here instead of mutating the inputs while iterating over them.
if not args.layer:
    raise ValueError("At least one --layer must be given")

if isinstance(args.channel, (list, tuple)):
    # Already one channel spec per layer (e.g. if action="append" is re-enabled).
    channel_specs = list(args.channel)
    if len(channel_specs) != len(args.layer):
        raise ValueError("Got {} --layer values but {} --channel values".format(
            len(args.layer), len(channel_specs)))
else:
    # A single --channel value applies to every requested layer.
    channel_specs = [args.channel] * len(args.layer)

expanded_layers = []
expanded_channels = []
for layer, channel in zip(args.layer, channel_specs):
    if channel == "all":
        # One region for each output channel of this layer.
        for j in range(layers[layer].out_channels):
            expanded_layers.append(layer)
            expanded_channels.append(j)
    else:
        expanded_layers.append(layer)
        expanded_channels.append(int(channel))

args.layer = expanded_layers
args.channel = expanded_channels

# Each (layer, channel) pair is one region to optimize.
objective_regions = [(layer, channel) for layer, channel in zip(args.layer, args.channel)]

# Group the channel indices by layer; np.unique supplies one key per distinct layer.
channels = {}
for layer in np.unique(args.layer):
    channels[layer] = []
for layer, channel in objective_regions:
    channels[layer].append(channel)

# Debug output of the selected regions
print(objective_regions)

# Select the parameters that will be handed to the optimizer.
if args.optim_all:
    params = model.parameters()
else:
    if not args.optim_layers:
        # Previously this surfaced as "TypeError: object of type 'NoneType' has no len()".
        raise ValueError("Pass --optim-layers at least once, or use --optim-all")
    if len(args.optim_layers) == 1:
        params = get_layer_parameters(model, args.optim_layers[0])
    else:
        # chain(<generator>) would yield each layer's parameter collection as a
        # single item; from_iterable flattens them into one stream of parameters.
        # Assumes get_layer_parameters returns an iterable of parameters, as the
        # single-layer case already relies on -- TODO confirm.
        params = chain.from_iterable(
            get_layer_parameters(model, layer) for layer in args.optim_layers)

# Source of the optimization images: the ImageNet "train" split.
dataset = datasets.ImageNet(args.imagenet, split = "train")


def grab_image(index):
    """Return element 0 of the dataset sample stored at *index*."""
    return dataset[index][0]


# Look up, per objective region, the image indices stored in the index tables
# (indices_neuron / indices_channel are not defined in this file; presumably
# they rank images by activation -- verify against where they are built).
if args.optim_obj == "center-neuron":
    top_image_indices = []
    for layer, channel in objective_regions:
        neuron_indices = indices_neuron[layer][:, channel]
        top_image_indices.append(get_center(neuron_indices).tolist())
elif args.optim_obj == "channel":
    top_image_indices = []
    for layer, channel in objective_regions:
        top_image_indices.append(indices_channel[layer][:, channel].tolist())

# Assemble the images the optimization objective is evaluated on.
if args.optim_data == "top-image":
    # Keep only the first index of every region, still wrapped in a list so
    # the nesting matches what map_over receives in the "top10" branch.
    top_image_index = [[indices[0]] for indices in top_image_indices]
    optim_data = map_over(grab_image, top_image_index)
    optim_data = map_over(standard_transform, optim_data)

elif args.optim_data == "top10":
    optim_data = map_over(grab_image, top_image_indices)
    optim_data = map_over(standard_transform, optim_data)
    # Stack each region's images into a single tensor.
    optim_data = list(map(torch.stack, optim_data))

elif args.optim_data == "optimal-image":
    if len(args.channel) > 1:
        raise NotImplementedError("Need to figure out how to extract arrays of optimal images and use them in optimization")

    # The module-level `vis_obj` is only assigned further down the script, so
    # referring to it here raised NameError.  Build the objective for the
    # first (layer, channel) pair locally instead.
    if args.vis_obj == "center-neuron":
        optimal_image_obj = objectives.neuron(args.layer[0], args.channel[0])
    elif args.vis_obj == "channel":
        optimal_image_obj = objectives.channel(args.layer[0], args.channel[0])
    else:
        raise NotImplementedError(
            "No visualization objective available for {!r}".format(args.vis_obj))

    print("Generating optimal image")

    def param_f():
        """Image parameterization passed to render_vis."""
        return param.image(128)

    optimal_image = render.render_vis(model, optimal_image_obj, param_f, progress=True)[0]
    # Reorder the axes with transpose([0, 3, 1, 2]) before moving to the device
    # (presumably NHWC -> NCHW; confirm against render_vis' output layout).
    optim_data = torch.from_numpy(optimal_image.transpose([0,3,1,2])).to(device)

optim_dataloader = ObjectiveRegionsDataloader(objective_regions, optim_data, batch_size = args.optim_batch)

# Sign applied to the optimization loss: negative when --inverse is given.
sign = 1
if args.inverse:
    sign = -1


def _center_value(feature_map, first_index, channel):
    """Index *feature_map* at the middle position of its last two axes.

    Unpacking ``shape[2:]`` into two values means the map must have exactly
    four axes.
    """
    n_rows, n_cols = feature_map.shape[2:]
    return feature_map[first_index, channel, n_rows // 2, n_cols // 2]


# Activation selectors.  grab_activ works on the regions of one dataloader
# batch (layer, channel, index triples); grab_activ_maintain works on the
# global objective_regions and keeps the whole first axis.
if args.optim_obj == "center-neuron":
    def grab_activ(activations, obj_regions):
        """Return the center activation of each (layer, channel, index) region."""
        return [_center_value(activations[layer], index, channel)
                for layer, channel, index in obj_regions]

    def grab_activ_maintain(activations):
        """Return the center activation of every objective region, over the full first axis."""
        return [_center_value(activations[layer], slice(None), channel)
                for layer, channel in objective_regions]
elif args.optim_obj == "channel":
    def grab_activ(activations, obj_regions):
        """Return the channel activation of each (layer, channel, index) region."""
        selected = []
        for layer, channel, index in obj_regions:
            selected.append(activations[layer][index, channel])
        return selected

    def grab_activ_maintain(activations):
        """Return the channel activation of every objective region, over the full first axis."""
        selected = []
        for layer, channel in objective_regions:
            selected.append(activations[layer][:, channel])
        return selected
elif args.optim_obj == "layer":
    # Whole-layer objectives are not supported; no selectors are defined for them.
    raise NotImplementedError("Layers don't line up with the layer,channel scheme")

def optim_obj(rate=1.0):
    """Evaluate the optimization objective over every batch and backpropagate it.

    For each batch from ``optim_dataloader`` the model is run on the images,
    the activations chosen by ``grab_activ`` are read from
    ``model.activations``, and their L2 norms are summed and multiplied by
    ``sign``.  ``rate * loss`` is backpropagated for each batch.

    Args:
        rate: Factor applied to the loss before ``backward()``.

    Returns:
        The sum of the unscaled batch losses, accumulated as NumPy values.
    """
    running_total = 0.0
    for batch_regions, batch_images in optim_dataloader:
        # The forward pass is run for its side effect on model.activations.
        model(batch_images)
        norms = [torch.norm(region_activ, p=2)
                 for region_activ in grab_activ(model.activations, batch_regions)]
        batch_loss = sign * sum(norms)
        (rate * batch_loss).backward()
        running_total = running_total + batch_loss.detach().cpu().numpy()
    return running_total

# Build the visualization objective from the first (layer, channel) pair.
# TODO: combine the objectives of every requested channel instead of only the first.
vis_obj_builders = {
    "center-neuron": objectives.neuron,
    "channel": objectives.channel,
}
if args.vis_obj in vis_obj_builders:
    vis_obj = vis_obj_builders[args.vis_obj](args.layer[0], args.channel[0])

# Validation callback: ImageNet accuracy when requested, otherwise none.
validate = (
    validate_with_imagenet(model, batch_size = 100, num_workers = 10)
    if args.validate == "accuracy"
    else None
)

# Build the "maintain" objective.
if args.maintain_obj == "labels":
    maintain_obj = maintain_obj_imagenet(model)
else:
    # "softmax" and "activations" both compare against a separate pretrained
    # copy of the architecture, kept in eval mode.
    orig_model = models.__dict__[args.arch](pretrained=True)
    orig_model.to(device).eval()

    if args.maintain_obj == "softmax":
        maintain_obj = maintain_softmax_imagenet(model, orig_model)
    elif args.maintain_obj == "activations":
        if args.maintain_activations_layer:
            def activ_func(activations):
                """Return the full activation of every layer that has an objective."""
                return [activations[layer] for layer in channels]
        else:
            activ_func = grab_activ_maintain
        loss = torch.nn.MSELoss()

        def criterion(activations, orig_activations):
            """Sum the MSE between matching selected activations of the two models."""
            current_vals = activ_func(activations)
            reference_vals = activ_func(orig_activations)
            total = 0
            for activ, orig_activ in zip(current_vals, reference_vals):
                total = total + loss(activ, orig_activ)
            return total

        maintain_obj = maintain_activation_imagenet(model, orig_model, criterion)

# Optionally scale the learning rate down by the number of objective regions.
if args.divide_lr_by_objectives:
    args.optim_lr = args.optim_lr / len(objective_regions)


def optimizer(params):
    """Build the torch.optim optimizer named by --optim-type for *params*.

    The learning rate comes from args.optim_lr and any further keyword
    arguments from args.optim_kwargs.
    """
    optimizer_cls = vars(torch.optim)[args.optim_type]
    return optimizer_cls(params, lr=args.optim_lr, **args.optim_kwargs)

# Collect every setting for the optimization run, then launch it.
run_settings = dict(
    model=model,
    params=params,
    output_folder=args.output,
    optim_obj=optim_obj,
    maintain_obj=maintain_obj,
    vis_obj=vis_obj,
    callback=validate,
    image_side_length=args.img_size,
    nsteps=args.nsteps,
    save_interval=args.save_interval,
    alpha=args.alpha,
    optimizer=optimizer,
    device=device,
    use_tqdm=False,
    save_image=args.save_image,
)
optimal_images, objective_values, accuracy, maintain_values = activation_optimization(**run_settings)

# Plot the recorded objective values and accuracy into the figures folder.
figure_path = os.path.join(args.output, "figures", "objectives_visual.png")
visualize_objectives(
    objective_values,
    accuracy,
    args.save_interval,
    filename=figure_path,
    title=args.title,
    inline=False,
)
