import os
import sys
from pathlib import Path
from argparse import ArgumentParser

import torch
from matplotlib import pyplot as plt
import matplotlib as mlp
import math

from torchvision import models, transforms as T
import numpy as np
import pandas as pd
from lucent.optvis import objectives, render, param
from time import time


import PIL

TORCH_VERSION = torch.__version__

from tqdm import tqdm

from copy import deepcopy
sys.path.append("circuit_explorer")

sys.path.append("scripts")

from utils import get_train_dataloader_Subset


from circuit_explorer.score import actgrad_kernel_score, snip_score, compute_activations #, act_grad_filter_score

from circuit_explorer.mask import mask_from_scores, apply_mask, setup_net_for_mask
from torchmetrics import SpearmanCorrCoef, PearsonCorrCoef, KendallRankCorrCoef


def _kernel_scores(scores):
    """Average every score tensor over dims >= 2, giving one score per kernel.

    Presumably each value is conv-weight shaped ``(out, in, kh, kw)`` — confirm
    against ``snip_score`` — so the result is ``(out, in)``.
    """
    return {key: value.flatten(start_dim=2).mean(axis=-1) for key, value in scores.items()}


def _channel_scores(kernel_scores):
    """Collapse a per-kernel score tensor to one mean score per index of dim 0."""
    return kernel_scores.flatten(start_dim=1).mean(axis=-1)


def _rank_channels(kernel_scores):
    """Map each layer name to its dim-0 indices sorted by decreasing channel score."""
    return {
        layer: torch.sort(_channel_scores(layer_scores), descending=True).indices
        for layer, layer_scores in kernel_scores.items()
    }


# TODO: make it work for networks different from AlexNet (the default `layers`
# are AlexNet layer names).
def main(args, results_directory, init_model, model, device = "cpu",
         layers = ("features_10", "features_8", "features_6", "features_3", "features_0")):
    """Compare the circuit of ``args.unit`` in ``args.layer`` between two models.

    For both ``init_model`` and ``model`` this computes SNIP scores (averaged per
    kernel), per-layer channel rankings, and the unmasked activations of the
    target unit. It then runs ``get_correlation_circuit`` over a sparsity sweep
    and computes, for each layer in ``layers`` plus ``args.layer``, the Kendall
    rank correlation between the two models' per-channel scores. Everything is
    saved to ``{args.results_directory}/{args.layer}_{args.unit}/circuit_results.pt``.

    Args:
        args: parsed CLI namespace. Reads ``imagenet``, ``batch_size``, ``layer``,
            ``unit`` and ``results_directory``, and is forwarded to
            ``get_correlation_circuit``.
        results_directory: currently unused; output goes to ``args.results_directory``.
        init_model: reference ("before") model; moved to ``device``.
        model: trained ("after") model; moved to ``device``.
        device: device both models are moved to.
        layers: layer names (besides ``args.layer``) for which Kendall tau is
            computed. They are used as keys into the score dicts after
            ``"bn" -> "conv"`` replacement. NOTE(review): the AlexNet defaults use
            ``_`` while the CLI passes ``.``-separated names — confirm which form
            ``snip_score`` keys use.
    """
    train_set_loader = get_train_dataloader_Subset(imagenet_dir=args.imagenet, batch_size=args.batch_size)

    # "Before" model: per-kernel scores (for structured pruning), activations, ranks.
    init_model.to(device)
    init_scores = _kernel_scores(snip_score(init_model, train_set_loader, args.layer, args.unit))
    init_init_activations = compute_activations(init_model, args.layer, args.unit, train_set_loader)
    ranks_init = _rank_channels(init_scores)

    # "After" model: same quantities.
    model.to(device)
    final_scores = _kernel_scores(snip_score(model, train_set_loader, args.layer, args.unit))
    ranks_final = _rank_channels(final_scores)
    init_final_activations = compute_activations(model, args.layer, args.unit, train_set_loader)

    sparsity_array = np.linspace(0.1, 1, 10)
    # get_correlation_circuit returns the init-model results first, then the
    # final-model results; unpack in that order so each list is stored under
    # the key naming its model.
    corr_list_init, corr_list_final, list_masks_init, list_masks_final = get_correlation_circuit(
        train_set_loader, init_init_activations, init_final_activations,
        init_scores, final_scores, ranks_init, ranks_final,
        init_model, model, sparsity_array, args)

    kendall_tau = {"layers": list(layers) + [args.layer]}
    for layer_t in kendall_tau["layers"]:
        key = layer_t.replace("bn", "conv")
        kendall_tau[key] = KendallRankCorrCoef()(_channel_scores(init_scores[key]),
                                                 _channel_scores(final_scores[key]))

    dico_circuits_results = {'sparsity': sparsity_array, 'corr_final_init': corr_list_final,
                             'corr_init_init': corr_list_init, 'dico_KT': kendall_tau,
                             "list_masks_init": list_masks_init, "list_masks_final": list_masks_final,
                             "init_scores": init_scores, "final_scores": final_scores}

    # Create the output directory so main() also works when the CLI did not.
    output_dir = Path(args.results_directory) / f"{args.layer}_{args.unit}"
    output_dir.mkdir(parents=True, exist_ok=True)
    torch.save(dico_circuits_results, output_dir / "circuit_results.pt")

def generate_results_artificial_images(args, model, scores, init_ranks, final_ranks, before_or_after = "before"):
    """Render artificial images of the top channels of the circuit selected by ``scores``.

    Masks ``model`` in place with ``mask_from_scores(scores, ...)`` for
    ``args.unit`` of ``args.layer`` at ``args.sparsity``. Then, for every Conv2d
    layer preceding ``args.layer`` (in reverse registration order) and finally
    ``args.layer`` itself, it renders the top ``args.K`` channels of the current
    ranking and of the other ranking. Images go to
    ``{args.results_directory}/{args.layer}_{args.unit}/circuit_{before_or_after}/
    sparsity_{round(args.sparsity, 3)}/{target_layer}``.
    ``setup_net_for_mask(model)`` is always called on exit, even if rendering
    raises (presumably resetting the mask — confirm).

    Args:
        args: CLI namespace; reads ``sparsity``, ``unit``, ``layer``, ``residual``,
            ``K`` and ``results_directory``.
        model: network to mask and visualise (mutated in place).
        scores: per-layer score dict used both to build the mask and to label images.
        init_ranks: layer name -> channel indices sorted by decreasing init score.
        final_ranks: same for the final model; may be ``None``, in which case
            only the current ranking is rendered when it is the "other" one.
        before_or_after: ``"after"`` ranks with ``final_ranks`` (other:
            ``init_ranks``); any other value ranks with ``init_ranks`` (other:
            ``final_ranks``).
    """
    mask = mask_from_scores(scores, sparsity = args.sparsity, model = model, unit=args.unit, target_layer=args.layer)
    apply_mask(model, mask)

    if before_or_after == "after":
        current_ranks, other_ranks = final_ranks, init_ranks
    else:
        current_ranks, other_ranks = init_ranks, final_ranks

    # Built from args only, so this no longer depends on globals created by the
    # script entry point.
    circuit_dir = os.path.join(args.results_directory, f"{args.layer}_{args.unit}",
                               f"circuit_{before_or_after}", f"sparsity_{round(args.sparsity,3)}")
    try:
        previous_layers = get_preceding_conv_layers_names(model, args.layer)

        for target_layer in tqdm(previous_layers[::-1] + [args.layer]):
            if args.residual:
                target_layer_activation = target_layer.replace("conv", "bn")
            else:
                target_layer_activation = target_layer

            rank_key = target_layer.replace("bn", "conv")
            topK = current_ranks[rank_key][:args.K]
            previous_topK = None if other_ranks is None else other_ranks[rank_key][:args.K]

            print("Generation for layer ", target_layer)
            out_dir = os.path.join(circuit_dir, target_layer)
            Path(out_dir).mkdir(parents=True, exist_ok=True)

            layer_scores = scores[target_layer.replace("_", ".").replace("bn", "conv")].flatten(start_dim=1).mean(axis=-1)
            generate_artificial_images(model, target_layer_activation.replace(".", "_"), out_dir,
                                       layer_scores, args, topK, previous_topK, before_or_after)
    finally:
        setup_net_for_mask(model)

def get_preceding_conv_layers_names(network, target_layer_name):
    """List the ``Conv2d`` submodules visited before ``target_layer_name``.

    Walks ``network``'s module tree depth-first in pre-order, following
    ``named_children()`` registration order (often, but not necessarily, the
    forward-execution order). Qualified names use ``.`` separators, e.g.
    ``"features.3"``. The walk stops as soon as the target name is reached, so
    neither the target nor anything inside it is included.

    Raises:
        ValueError: if no submodule is named ``target_layer_name``.
    """
    collected = []
    # Explicit stack in place of recursion; children are pushed in reverse so
    # they are popped (visited) in registration order.
    pending = list(reversed(list(network.named_children())))

    while pending:
        qualified_name, submodule = pending.pop()

        if qualified_name == target_layer_name:
            return collected

        if isinstance(submodule, torch.nn.Conv2d):
            collected.append(qualified_name)

        grandchildren = [(f"{qualified_name}.{child_name}", child)
                         for child_name, child in submodule.named_children()]
        pending.extend(reversed(grandchildren))

    raise ValueError(f"Target layer '{target_layer_name}' not found in the network.")


def generate_artificial_images(model, layer, root_path, scores, args, topK, previous_topK = None, before_or_after = "after"):
    """Render lucent channel visualisations for ``layer`` and save them under ``root_path``.

    If ``layer`` is the target layer (``args.layer``, compared as-is and with
    ``_`` replaced by ``.``), only ``args.unit`` is rendered. Otherwise every
    channel in ``topK`` is rendered, followed by each channel of
    ``previous_topK`` not already in ``topK``. A channel in both would map to the
    same file name, so rendering it again would only overwrite the first image.

    Each image is a 224x224 FFT-parameterised optimisation saved as
    ``{root_path}/score_{round(scores[unit], 2)}_unit_{unit}.png``.

    Args:
        model: network passed to ``render.render_vis``.
        layer: layer name as passed to ``objectives.channel``.
        root_path: existing directory the PNGs are written to.
        scores: indexable by channel; ``scores[unit]`` must be a one-element
            tensor (used only in the file name).
        args: CLI namespace; reads ``layer`` and ``unit``.
        topK: channel indices to render.
        previous_topK: optional extra channel indices (e.g. the other model's top-K).
        before_or_after: unused here; kept for signature compatibility.
    """
    image_side_length = 224
    do_fft = True
    param_f = lambda: param.image(image_side_length, fft=do_fft)
    save_image = True

    def _render(unit):
        # One optimisation run per channel; nothing is displayed, only saved.
        render.render_vis(model, objectives.channel(layer, unit), param_f, save_image=save_image,
                          show_image=False,
                          image_name=f"{root_path}/score_{round(scores[unit].item(),2)}_unit_{unit}.png",
                          progress=False,
                          show_inline=False)

    if args.layer == layer or args.layer == layer.replace("_", "."):
        _render(args.unit)
        return

    rendered = set()
    for unit in tqdm(topK):
        _render(unit)
        rendered.add(int(unit))

    if previous_topK is not None:
        for unit in tqdm(previous_topK):
            if int(unit) in rendered:
                continue  # same channel -> same file; skip the redundant render
            _render(unit)
                
def load_model_with_unmatched_names(model, state_dict):
    """Load the entries of ``state_dict`` whose names match ``model`` after prefix removal.

    A single leading ``"module."`` (as typically added when a model is wrapped in
    ``torch.nn.DataParallel``) is stripped from each saved key. Keys that still
    have no counterpart in ``model`` are silently ignored. Parameters absent from
    ``state_dict`` keep their current values, because the model's own state dict
    is used as the base. If several saved keys map to the same name, the last
    one wins.

    Only the prefix is removed: stripping every occurrence of ``"module."``
    would corrupt names such as ``"submodule.weight"`` (-> ``"subweight"``).
    """
    prefix = "module."
    model_dict = model.state_dict()

    for saved_key, value in state_dict.items():
        new_key = saved_key[len(prefix):] if saved_key.startswith(prefix) else saved_key
        if new_key in model_dict:
            model_dict[new_key] = value

    model.load_state_dict(model_dict)

# get_preceding_conv_layers_names is defined once, earlier in this module; an
# identical second definition here used to silently shadow it.

def _circuit_correlation(model, scores, reference_activations, sparsity, spatial_dim,
                         train_set_loader, args, ranks_init, ranks_final, before_or_after):
    """Mask ``model`` to its circuit at ``sparsity`` and correlate it with the unmasked model.

    Returns ``(corr, mask_copy)``:

    - ``corr`` is the median, over the flattened per-image positions, of the
      Pearson correlation (across images) between masked activations and
      ``reference_activations``. This assumes activations are ``(N, H, W)``;
      confirm against ``compute_activations``.
    - ``mask_copy`` is a deep copy of the mask, taken before it is applied.

    Optionally renders artificial images of the circuit. ``setup_net_for_mask``
    is always called on ``model`` before returning or propagating an error.
    """
    mask = mask_from_scores(scores, sparsity = sparsity, model = model, unit=args.unit, target_layer=args.layer)
    mask_copy = deepcopy(mask)

    apply_mask(model, mask)
    try:
        masked_activations = compute_activations(model, args.layer, args.unit, train_set_loader)
        corr_coef = PearsonCorrCoef(num_outputs=spatial_dim)(masked_activations.flatten(start_dim=1),
                                                              reference_activations.flatten(start_dim=1)).median()

        # Generating artificial images on the circuit
        if args.intermediate_artificial:
            generate_results_artificial_images(args, model, scores, ranks_init, ranks_final,
                                               before_or_after = before_or_after)
    finally:
        setup_net_for_mask(model)

    return corr_coef, mask_copy


def get_correlation_circuit(train_set_loader, init_init_activations, init_final_activations,
                            init_scores, final_scores, ranks_init, ranks_final,
                            init_model, final_model, sparsity_array, args, K = 5):
    """Sweep circuit sparsities and measure how well each circuit reproduces its full model.

    For every value in ``sparsity_array``, both ``init_model`` (scored by
    ``init_scores``) and ``final_model`` (scored by ``final_scores``) are masked
    to their circuit for ``args.unit`` of ``args.layer``. The masked activations
    are correlated with the unmasked ones (``init_init_activations`` and
    ``init_final_activations`` respectively).

    ``args`` is not modified. Intermediate image rendering (when
    ``args.intermediate_artificial`` is set) uses a private copy of it with
    ``K`` and the current sparsity filled in.

    Returns:
        ``(corr_list_init, corr_list_final, list_masks_init, list_masks_final)``:
        one median correlation and one mask per sparsity, for the init model
        and the final model.
    """
    corr_list_init, corr_list_final = [], []
    list_masks_init, list_masks_final = [], []

    spatial_dim = init_init_activations.shape[1] * init_init_activations.shape[2]

    # Private copy: previously K/sparsity were written into the caller's args,
    # permanently overriding e.g. the --K command-line value.
    circuit_args = deepcopy(args)
    circuit_args.K = K

    tic = time()
    for sparsity in tqdm(sparsity_array):
        print(f"-Generation for spasity {sparsity}")
        iteration_tic = time()
        circuit_args.sparsity = sparsity

        corr_coef, mask = _circuit_correlation(init_model, init_scores, init_init_activations, sparsity,
                                               spatial_dim, train_set_loader, circuit_args,
                                               ranks_init, ranks_final, "before")
        list_masks_init.append(mask)
        corr_list_init.append(corr_coef)

        corr_coef, mask = _circuit_correlation(final_model, final_scores, init_final_activations, sparsity,
                                               spatial_dim, train_set_loader, circuit_args,
                                               ranks_init, ranks_final, "after")
        list_masks_final.append(mask)
        corr_list_final.append(corr_coef)

        # Per-sparsity elapsed time (the total is reported after the loop).
        print(f"time for spasirty {sparsity}: ", time() - iteration_tic)

    print("Time for computing corr results and intermediate artificial images: ", time() - tic)

    return corr_list_init, corr_list_final, list_masks_init, list_masks_final

        


if __name__ == "__main__":
    # --arch value -> (message printed on selection, builder of the pretrained model).
    # Builders are lambdas so only the selected model/weights are instantiated.
    architectures = {
        "vgg19": ('Using VGG19!', lambda: models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1)),
        "alexnet": ('using AlexNet!', lambda: models.alexnet(weights=models.AlexNet_Weights.IMAGENET1K_V1)),
        "resnet50": ('using Resnet50!', lambda: models.resnet50(weights=models.ResNet50_Weights.DEFAULT)),
        "resnet18": ('using Resnet18!', lambda: models.resnet18(weights=models.ResNet18_Weights.DEFAULT)),
        # Must be the weights enum; passing the `models.resnet34` builder itself is a TypeError.
        "resnet34": ('using Resnet34!', lambda: models.resnet34(weights=models.ResNet34_Weights.DEFAULT)),
        "efficientnet": ('using EfficientNet!',
                         lambda: models.efficientnet_b0(weights=models.EfficientNet_B0_Weights.IMAGENET1K_V1)),
    }

    parser = ArgumentParser()
    parser.add_argument("--results-directory", type=str, required=True)
    # NOTE(review): --pretrained and --before are parsed but not read by this script.
    parser.add_argument("--pretrained", action="store_true")
    parser.add_argument("--imagenet", type = str, default="/home/shared_data/imagenet")
    parser.add_argument("--arch", type = str, default="alexnet", choices=sorted(architectures))
    parser.add_argument("--batch_size", type = int, default = 256)
    parser.add_argument('--cuda',
                    action='store_true')
    parser.add_argument("--layer", default="features.10", type = str)
    parser.add_argument("--checkpoint_path", default="final_model.pt", type = str)
    parser.add_argument("--unit", default = 0, type = int)
    parser.add_argument("--sparsity", default = 0.1, type = float)
    parser.add_argument("--residual", action="store_true")
    parser.add_argument("--before", action="store_true")
    # NOTE(review): get_correlation_circuit renders with its own K (default 5), not this value.
    parser.add_argument("--K", default= 50, type = int )
    parser.add_argument("--intermediate-artificial", action="store_true")

    args = parser.parse_args()

    arch_message, build_model = architectures[args.arch]
    print(arch_message)
    model = build_model()

    # This block runs at module scope, so `results_directory` and `ensuredir`
    # are module-level globals; functions in this file may look them up by
    # name, so keep both defined here.
    results_directory = args.results_directory

    def ensuredir(directory):
        Path(directory).mkdir(parents=True, exist_ok=True)

    ensuredir(os.path.join(results_directory, f"{args.layer}_{args.unit}", "circuit_before"))
    ensuredir(os.path.join(results_directory, f"{args.layer}_{args.unit}", "circuit_after"))

    previous_layers = get_preceding_conv_layers_names(model, args.layer.replace("_","."))

    # NOTE(review): both models stay in training mode (the nn.Module default);
    # confirm snip_score / compute_activations call eval() if BatchNorm or
    # Dropout affect the analysed layers.
    final_model = deepcopy(model)
    # Load the checkpoint actually named by --checkpoint_path. map_location keeps
    # GPU-saved checkpoints loadable on CPU; main() moves models to the device.
    checkpoint_file = os.path.join(results_directory, args.checkpoint_path)
    if args.checkpoint_path.endswith(".pt"):
        print("Loading the final model!!!")
        final_model.load_state_dict(torch.load(checkpoint_file, map_location="cpu"))
    else:
        print("Loading the model with unmached layers????")
        load_model_with_unmatched_names(final_model, torch.load(checkpoint_file, map_location="cpu")["state_dict"])

    print(previous_layers, args.unit)

    main(args, results_directory, model, final_model, layers= previous_layers, device = "cuda" if args.cuda else "cpu")


#commandline: 
#CUDA_VISIBLE_DEVICES=1 python3 circuits/analyse_circuit.py --results-directory alexnet/synthetic/features_10/single_channel_target_target2 --pretrained --imagenet /home/shared_data/imagenet --arch alexnet --batch_size 32 --cuda