import torch

from argparse import ArgumentParser
import os.path as path
import os
from torchvision.utils import save_image
import torchvision.transforms as transforms

import pandas as pd
from optimization_requirements import (
    visualize_objectives_v3,
    save_top_class_info,
    get_class_labels_dict,
)
import metrics_analysis
from torchvision import datasets, models, transforms

from utils import get_results_dataloader

from tqdm import tqdm
import matplotlib.pyplot as plt
from scripts.analyze_activations_by_channel import (
    get_overall_kt_results,
    load_in_activations_from_paths,
    get_kt_from_vectors,
)

# import get_imagenet_activations_by_channel

import torch.nn.functional as F


# Should be a version of this in results generator
def ensure_dir(directory):
    """Create ``directory`` (including missing parents) if it does not already exist.

    ``exist_ok=True`` makes this safe when several processes create the same
    directory concurrently; the previous exists-then-create check could raise
    if the directory appeared between the two calls. An empty path refers to
    the current working directory, which always exists, so it is a no-op.

    Raises:
        FileExistsError: if ``directory`` exists but is not a directory.
    """
    if directory:
        os.makedirs(directory, exist_ok=True)


def _load_cpu(file_path):
    """Load an object saved with ``torch.save``, mapping every tensor in it onto the CPU.

    ``map_location="cpu"`` (instead of ``.to("cpu")`` after loading) lets files
    written by a GPU process be opened on a CPU-only machine.
    """
    return torch.load(file_path, map_location="cpu")


def _load_labelled_images(dataset, indices, class_names):
    """Fetch ``dataset[i]`` once per index and return ``(image, short_class_name)`` pairs.

    Each ``dataset[i]`` must yield a sequence whose first two items are a
    channels-first image tensor and a class id that keys ``class_names``. The
    short class name is the text before the first comma of that label.
    """
    samples = [dataset[int(index)] for index in indices]
    return [(sample[0], class_names[sample[1]].split(",")[0]) for sample in samples]


def _plot_image_row(subfig, title, samples, highlight_classes, ncols, title_kwargs, label_kwargs):
    """Draw ``samples`` as one captioned row of ``ncols`` image panels inside ``subfig``.

    Captions are blue when the class name is in ``highlight_classes`` and black
    otherwise. Panels beyond the number of samples are left blank.
    """
    subfig.suptitle(title, **title_kwargs)
    # squeeze=False keeps a 2-D axes array even when ncols == 1, so ravel() always works.
    axes = subfig.subplots(
        nrows=1, ncols=ncols, squeeze=False, gridspec_kw={"wspace": 0, "hspace": 0}
    ).ravel()
    for ax, (image, class_text) in zip(axes, samples):
        color_text = "blue" if class_text in highlight_classes else "black"
        ax.imshow(image.permute(1, 2, 0))  # CHW tensor -> HWC for imshow
        ax.set_title(class_text, color=color_text, **label_kwargs)
    for ax in axes:
        ax.axis("off")


def make_visualization(
    top_indices_path_i,
    top_indices_path_f,
    cos_ii_path,
    cos_if_path,
    kt_ii_path,
    kt_if_path,
    layer,
    channel,
    data_loader,
    val_data_loader,
    output,
    do_validation=False,
    top_indices_path_val_i=None,
    top_indices_path_val_f=None,
    excluded_initial_indices=(237080,),
):
    """Save a PDF comparing a channel's top images before and after an attack.

    The figure title shows the channel's Kendall-tau (``kt_if[channel, channel]``)
    and CLIP-delta score (``metrics_analysis.calc_clip_d``). Under the title there
    is one row of ``K = 5`` images per set: the initial and final top training
    images and, if ``do_validation`` is set, the initial and final top validation
    images. A training-row caption is blue when its class also appears in the
    other training row. A validation-row caption is blue when its class appears
    in the initial training row.

    Args:
        top_indices_path_i: ``torch.save`` file of initial top dataset indices,
            indexed ``[channel, rank]``.
        top_indices_path_f: same, for the final (post-attack) top indices.
        cos_ii_path: CLIP similarity tensor passed to ``calc_clip_d``.
        cos_if_path: CLIP similarity tensor passed to ``calc_clip_d``.
        kt_ii_path: unused; kept for call compatibility.
        kt_if_path: Kendall-tau matrix; the ``[channel, channel]`` entry is shown.
        layer: label used in the title and in the output file name.
        channel: channel (row of the index tensors) to visualise.
        data_loader: its ``.dataset`` supplies the training images.
        val_data_loader: its ``.dataset`` supplies the validation images.
        output: directory the PDF is written to (created if missing).
        do_validation: also plot validation rows loaded from
            ``top_indices_path_val_i`` / ``top_indices_path_val_f``.
        excluded_initial_indices: dataset indices skipped in the initial training
            row (237080 by default, to remove a repeated image).
    """
    K = 5  # images per row

    class_names = get_class_labels_dict()
    train_data = data_loader.dataset
    val_data = val_data_loader.dataset

    init_top_indices = _load_cpu(top_indices_path_i)
    final_top_indices = _load_cpu(top_indices_path_f)
    clip_sims_ii = _load_cpu(cos_ii_path)
    clip_sims_if = _load_cpu(cos_if_path)
    kt_if = _load_cpu(kt_if_path)

    clip_dscore = metrics_analysis.calc_clip_d(clip_sims_ii, clip_sims_if)
    kendall_tau = float(kt_if[channel, channel])
    clip_delta = float(clip_dscore[channel])
    title = (
        f"Channel {channel} of {layer}: Kendall-"
        rf"$\tau$: {kendall_tau:.3f}, CLIP-"
        rf"$\delta$: {clip_delta:.3f}"
    ).replace("_", " ")

    # The initial row skips excluded indices; highlighting is derived from the
    # images actually displayed so rows stay consistent with each other.
    excluded = set(excluded_initial_indices)
    init_shown = [idx for idx in init_top_indices[channel].tolist() if idx not in excluded][:K]
    init_samples = _load_labelled_images(train_data, init_shown, class_names)
    final_samples = _load_labelled_images(train_data, final_top_indices[channel, :K].tolist(), class_names)
    classes_initial = {name for _, name in init_samples}
    classes_final = {name for _, name in final_samples}

    bold_label = {"fontsize": 13, "fontweight": "bold"}
    # (row title, samples, classes to highlight, row-title kwargs, caption kwargs)
    rows = [
        ("Initial top", init_samples, classes_final, {"fontsize": 18}, bold_label),
        ("Final top", final_samples, classes_initial, {"fontsize": 18, "fontweight": "bold"}, bold_label),
    ]
    if do_validation:
        init_top_indices_val = _load_cpu(top_indices_path_val_i)
        final_top_indices_val = _load_cpu(top_indices_path_val_f)
        init_val_samples = _load_labelled_images(val_data, init_top_indices_val[channel, :K].tolist(), class_names)
        final_val_samples = _load_labelled_images(val_data, final_top_indices_val[channel, :K].tolist(), class_names)
        rows += [
            ("Initial Validation top", init_val_samples, classes_initial, {"fontsize": 18}, {}),
            ("Final Validation top", final_val_samples, classes_initial, {"fontsize": 18}, {}),
        ]

    fig = plt.figure(constrained_layout=True, figsize=(15, 12) if do_validation else (15, 8))
    fig.suptitle(title, fontsize=20, fontweight="bold")
    subfigs = fig.subfigures(nrows=len(rows), ncols=1, height_ratios=[5] * len(rows))
    for subfig, (row_title, samples, highlight, title_kwargs, label_kwargs) in zip(subfigs, rows):
        _plot_image_row(subfig, row_title, samples, highlight, K, title_kwargs, label_kwargs)

    ensure_dir(output)
    out_path = os.path.join(output, f"{layer}_channel_{channel}.pdf")
    fig.savefig(out_path)
    print("saving to: ", out_path)
    plt.close(fig)


# Parallel attack configuration: entry i of attack_results_directories,
# attack_names, channel_lists and do_validation_list together describe one
# attack to visualise. Keep the four lists the same length, and (un)comment
# alternative runs in all four lists together.
attack_results_directories = [
    "/home/a_fuller/projects/Attacking-Interpretability/alexnet/v12/f10_c0_only/results/metrics",
    "/home/a_fuller/projects/Attacking-Interpretability/alexnet/v13/single_channel/f10_c1_top10_to_zero/results/metrics",
    "/home/a_fuller/projects/Attacking-Interpretability/alexnet/v13/single_channel/f10_c2_top10_to_zero/results/metrics",
    #     "/home/a_fuller/projects/Attacking-Interpretability/alexnet/v12/f0_dataset_top10_to_zero/results/metrics",
    #     "/home/a_fuller/projects/Attacking-Interpretability/alexnet/v12/f3_dataset_top10_to_zero_a01/results/metrics",
    #     "/home/a_fuller/projects/Attacking-Interpretability/alexnet/v12/f6_dataset_top10_to_zero/results/metrics",
    #     "/home/a_fuller/projects/Attacking-Interpretability/alexnet/v12/f8_dataset_top10_to_zero/results/metrics",
    #     "/home/a_fuller/projects/Attacking-Interpretability/alexnet/v12/f10_dataset_top10_to_zero/results/metrics",
    #     "/home/a_fuller/projects/Attacking-Interpretability/alexnet/v12/f10_dataset_refs_to_top/results/metrics",
    #     "/home/a_fuller/projects/Attacking-Interpretability/alexnet/v12/f10_c0_only/results/metrics",
    #     "/home/a_fuller/projects/Attacking-Interpretability/efficientnet/v12/f7_0_b3_0_dataset_top10_to_zero/results/metrics",
]
# Human-readable attack labels, used in figure titles and output file names.
attack_names = [
    "Conv 5 C0 Only",
    "Conv 5 C1 Only",
    "Conv 5 C2 Only",
    # "Conv1_Push-Down",
    # "Conv2_Push-Down",
    # "Conv3_Push-Down",
    # "Conv4_Push-Down",
    # "Conv5_Push-Down",
    # "Conv5_Push-Up",
    # "Conv5_Single_Channel",
    # "Efficientnet_f7_b3"
]

# Channels to visualise for each attack.
channel_lists = [
    [0],
    [1],
    [2],
    # [1],
    # [33],
    # [78],
    # [56],
    # [2, 3, 9, 15, 46, 170],
    # [125, 135],
    # [0],
    # [13,15,46]
]

# Whether to also plot validation-set top images for each attack.
# NOTE(review): this list used to hold 8 values that did not line up with the 3
# active attacks. They presumably belonged to the 8 commented-out runs above,
# in the same order, so they are kept commented out alongside them.
do_validation_list = [
    False,  # Conv 5 C0 Only
    False,  # Conv 5 C1 Only
    False,  # Conv 5 C2 Only
    # False,  # Conv1_Push-Down
    # False,  # Conv2_Push-Down
    # False,  # Conv3_Push-Down
    # False,  # Conv4_Push-Down
    # True,  # Conv5_Push-Down
    # True,  # Conv5_Push-Up
    # False,  # Conv5_Single_Channel
    # True,  # Efficientnet_f7_b3
]

if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--output", default="visualizations", type=str)
    args = parser.parse_args()

    # The configuration lists are parallel (entry i of each describes one attack).
    # Check them before building the dataloaders, so a mismatch fails with a clear
    # message instead of an IndexError partway through the run.
    n_attacks = len(channel_lists)
    for list_name, config_list in (
        ("attack_results_directories", attack_results_directories),
        ("attack_names", attack_names),
        ("do_validation_list", do_validation_list),
    ):
        if len(config_list) < n_attacks:
            raise ValueError(
                f"{list_name} has {len(config_list)} entries but channel_lists has {n_attacks}"
            )

    # make_visualization keyword -> file name inside each attack's metrics folder.
    metric_files = {
        "top_indices_path_i": "init_top_indices.pt",
        "top_indices_path_f": "final_top_indices.pt",
        "cos_ii_path": "full_clip_ii.pt",
        "cos_if_path": "full_clip_if.pt",
        "kt_if_path": "kt_if.pt",
        "kt_ii_path": "kt_ii.pt",
        "top_indices_path_val_f": "final_top_indices_val.pt",
        "top_indices_path_val_i": "init_top_indices_val.pt",
    }

    data_loader = get_results_dataloader(split="train")
    data_loader_val = get_results_dataloader(split="val")
    for i, channels in enumerate(channel_lists):
        metrics_folder = attack_results_directories[i]
        print("generating for attack:", attack_names[i])
        metric_paths = {kwarg: path.join(metrics_folder, name) for kwarg, name in metric_files.items()}
        for channel in channels:
            print("generating for channel: ", channel)
            make_visualization(
                channel=channel,
                layer=attack_names[i],
                output=args.output,
                data_loader=data_loader,
                val_data_loader=data_loader_val,
                do_validation=do_validation_list[i],
                **metric_paths,
            )
