from pathlib import Path
from argparse import ArgumentParser

import torch
from matplotlib import pyplot as plt
from torchvision import models, transforms as T
import numpy as np
from lucent.optvis import objectives, transform
from lucent.optvis.render import ModuleHook
from lucent.optvis.render import hook_model
from lucent.optvis.param.color import to_valid_rgb
from lucent.misc.io import show
from lucent.optvis import render, param

import PIL

# Installed PyTorch version string (torch.__version__); not read by the functions below.
TORCH_VERSION = torch.__version__

from tqdm import tqdm

from copy import deepcopy
import utils

from activation_optimization_by_channel import *




def main(args, feature_id, features_id_previous, channel_id = 0, device = "cuda:0"):
    """Render feature visualisations of two AlexNet layers for the initial and final models.

    The "initial" model is ImageNet-pretrained AlexNet; the "final" model is
    AlexNet with the state dict stored at ``<args.results_directory>/final_model.pt``.
    Four image sets are written under ``args.results_directory``:

    * ``previous_layer_initial`` / ``previous_layer_final``: every channel of
      ``features_<features_id_previous>``, with filenames prefixed by the
      channel's rank by summed kernel weight into output channel ``channel_id``
      of ``features[feature_id]`` (ranking taken from the initial model);
    * ``current_layer_initial`` / ``current_layer_final``: every channel of
      ``features_<feature_id>``.

    The output sub-directories must already exist.

    Args:
        args: Parsed CLI namespace; only ``results_directory`` is read.
        feature_id: Index into ``model.features`` of the layer under study.
        features_id_previous: Index into ``model.features`` of the layer whose
            channels are ranked by their weight into ``features[feature_id]``.
        channel_id: Output channel of ``features[feature_id]`` whose incoming
            weights define that ranking.
        device: Torch device both models are moved to.

    Returns:
        None.
    """
    model = models.alexnet(weights=models.AlexNet_Weights.IMAGENET1K_V1)
    model.to(device)
    model.eval()

    root_path = args.results_directory
    final_path = f"{root_path}/final_model.pt"

    # Every parameter is overwritten by the checkpoint, so don't fetch the
    # pretrained weights for this instance. Loading onto the CPU first lets a
    # checkpoint saved on any device load before moving to `device`.
    final_model = models.alexnet(weights=None)
    final_model.load_state_dict(torch.load(final_path, map_location="cpu"))
    final_model.to(device).eval()

    # Rank the input channels of output channel `channel_id` by total kernel
    # weight (for a Conv2d, weight[channel_id] is (in_channels, kH, kW)).
    # No autograd graph is needed for this bookkeeping.
    with torch.no_grad():
        kernel = model.features[feature_id].weight[channel_id]
        _, w_order = kernel.flatten(start_dim=1).sum(dim=-1).sort(descending=True)
    w_order = w_order.cpu().numpy()

    previous_layer = f"features_{features_id_previous}"
    current_layer = f"features_{feature_id}"

    generate_artificial(model, path = f"{root_path}/previous_layer_initial", layer = previous_layer, w_order = w_order)
    generate_artificial(final_model, path = f"{root_path}/previous_layer_final", layer = previous_layer, w_order = w_order)
    generate_artificial(model, path = f"{root_path}/current_layer_initial", layer = current_layer)
    generate_artificial(final_model, path = f"{root_path}/current_layer_final", layer = current_layer)

    return None

def generate_artificial(model, path, layer, w_order = None):
    """Render and save one Lucent feature-visualisation image per output channel of a layer.

    For each output channel of ``layer``, an ``objectives.channel`` objective is
    optimised with ``render.render_vis`` on a 224x224 FFT-parameterised image.
    The result is saved to ``path`` as ``[order_<rank>_]channel_<idx>.png``.

    Args:
        model: Network to visualise. It is never rendered directly; a fresh
            deep copy is rendered for every channel, so ``model`` is left untouched.
        path: Directory the PNG files are written into (must already exist).
        layer: Lucent layer name such as ``"features_8"``. It must be a key of
            ``get_model_layers(model)`` whose module has ``out_channels``.
        w_order: Optional sequence of channel indices ordered by rank. When
            given, channel ``w_order[r]`` gets the filename prefix ``order_<r>_``,
            and it must contain every channel index of ``layer``.

    Returns:
        None. Images are only written to disk.

    Raises:
        ValueError: If ``w_order`` is given but does not rank every channel.
            This is checked before any rendering starts.
    """
    import gc

    model_layers = get_model_layers(model)
    num_channels = model_layers[layer].out_channels

    save_image = True
    print('Generating Artificial Optimal Images')
    print(f'save_image: {save_image}')
    param_f = lambda: param.image(224, fft=True)

    # Filename prefix per channel index. Keys are cast to int because w_order
    # may be a numpy array of indices.
    if w_order is not None:
        prefixes = {int(channel): f"order_{rank}_" for rank, channel in enumerate(w_order)}
        missing = set(range(num_channels)) - prefixes.keys()
        if missing:
            raise ValueError(f"w_order does not rank channels {sorted(missing)} of layer {layer}")
    else:
        prefixes = {channel: "" for channel in range(num_channels)}

    out_dir = Path(path)
    for channel_num in tqdm(range(num_channels)):
        vis = objectives.channel(layer, channel_num)
        # Render on a fresh copy for every channel. render_vis hooks the model's
        # modules and backprops through its parameters; reusing one model would
        # carry hooks/gradients over from previous channels
        # (NOTE(review): based on Lucent's hook_model not removing its hooks).
        channel_model = deepcopy(model).eval()
        image_name = str(out_dir / f"{prefixes[channel_num]}channel_{channel_num}.png")
        render.render_vis(channel_model, vis, param_f, save_image=save_image,
                          show_image=False,
                          image_name=image_name,
                          progress=False,
                          )
        # The hook objects may reference the copy cyclically, so collect it now
        # instead of letting stale copies accumulate until a later GC pass.
        del channel_model
        gc.collect()
    return None

if __name__ == "__main__":
    # CLI entry point: create the four output sub-directories, then render.
    parser = ArgumentParser()
    parser.add_argument("--results-directory", type=str, required=True,
                        help="Directory containing final_model.pt; images go into sub-directories here.")
    parser.add_argument("--features-id", type=int, default=10,
                        help="Index into model.features of the layer under study.")
    parser.add_argument("--features-id-previous", type=int, default=8,
                        help="Index into model.features of the layer whose channels are ranked.")
    parser.add_argument("--channel-id", type=int, default=0,
                        help="Output channel of --features-id whose weights define the ranking.")
    parser.add_argument("--device", type=str, default="cuda:0",
                        help="Torch device to run on.")

    args = parser.parse_args()
    results_directory = Path(args.results_directory)

    VERSION = '0.2.0'

    for subdir in ("previous_layer_initial", "current_layer_initial",
                   "previous_layer_final", "current_layer_final"):
        (results_directory / subdir).mkdir(parents=True, exist_ok=True)

    main(args, args.features_id, args.features_id_previous,
         channel_id=args.channel_id, device=args.device)
