#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import numpy as np
import torch
from torchvision.transforms.functional import gaussian_blur
from xai.attribution import mmbs, ig, gig, xrai, gradcam, smoothgrad
from xai.imputation import ConstantImputation
from xai.problems import ImageNetValProblem, FashionMnistProblem
from xai.evaluation import calc_deletion_curve, calc_auc
import tifffile
from tqdm import tqdm
from pathlib import Path
import argparse
from functools import partial

import ct_experiment_utils as ceu
from folder_locations import get_imagenet_val_data_path, get_experiments_path, get_fashion_mnist_data_path

if __name__ == "__main__":
    # Script entry point: for every sample index in [start_index, end_index)
    # compute a set of attribution heatmaps, store each heatmap as a TIFF,
    # compute a deletion curve per heatmap, and append the resulting areas
    # under the deletion curve (AUDCs) as one CSV row per sample.
    parser = argparse.ArgumentParser(description="Run MMBS on imagenet.")
    parser.add_argument("--gpu", type=int, default=0, help="Index of the GPU to use.")
    parser.add_argument("--network_name", help="Name of the network architecture.")
    parser.add_argument("--start_index", type=int, default=0, help="Index to start with")
    parser.add_argument("--end_index", type=int, default=1000, help="Index to end on")
    args = parser.parse_args()

    # Experiment hyper-parameters.
    auc_steps = 200        # passed to calc_deletion_curve
    ig_steps = 256         # num_steps for ig / gig
    mmbs_steps = 8         # num_steps for mmbs
    mmbs_samples = 1024    # num_paths for plain mmbs (assumed from the positional order — TODO confirm)
    mmbs_sg_samples = 128  # num_paths for mmbs when wrapped in smoothgrad

    # "Fashion" selects the Fashion-MNIST problem; any other network name is
    # forwarded to the ImageNet validation problem.
    device = f"cuda:{args.gpu}"
    if args.network_name == "Fashion":
        problem = FashionMnistProblem(
            weights_path="fashion_mnist_weights_paper.pt",
            data_path=get_fashion_mnist_data_path(),
            device=device)
    else:
        problem = ImageNetValProblem(
            data_path=get_imagenet_val_data_path(),
            network_name=args.network_name,
            num_per_class=1,
            class_step=1,
            class_offset=0,
            device=device)

    # Output layout: <experiment>/audcs.csv, <experiment>/curves/, <experiment>/heatmaps/
    experiment_path = ceu.make_new_experiment_folder(get_experiments_path())
    audcs_file_name = experiment_path / "audcs.csv"
    curves_folder = experiment_path / "curves"
    curves_folder.mkdir()
    heatmaps_folder = experiment_path / "heatmaps"
    heatmaps_folder.mkdir()

    # Names must stay in the same order as the `heatmaps` list built per sample.
    heatmap_names = [
        "mmbs",
        "ig",
        "gig_paper",
        "gig_saliency",
        "mmbs_sg",
        "ig_sg",
        "gig_paper_sg",
        "gig_saliency_sg",
        "xrai",
        "xrai_bl",
        "random"
        ]

    # Look the Grad-CAM layer up once: it decides both the CSV header and
    # whether a Grad-CAM heatmap is computed for every sample. The name is
    # registered exactly once here (it used to be re-appended on every loop
    # iteration, which made the name list grow with each processed sample).
    gradcam_layer = problem.get_gradcam_layer()
    if gradcam_layer is not None:
        heatmap_names.append("gradcam")

    # Write the CSV header; one row per sample is appended inside the loop.
    with open(audcs_file_name, "w") as audcs_file:
        audcs_file.write("label,"+",".join(heatmap_names)+"\n")

    for index in tqdm(range(args.start_index, args.end_index)):
        img, label = problem.get_sample(index)

        # All-zero baseline, used both as the IG/GIG baseline and as the
        # constant imputation value for MMBS and the deletion curve.
        bl = torch.zeros_like(img)
        imputation = ConstantImputation(bl)

        # Gaussian-noise heatmap serving as the random reference.
        heatmap_random = torch.randn_like(img)
        heatmap_mmbs = mmbs(problem.model, img, label, imputation, mmbs_steps, mmbs_samples, progress_bar=False)
        heatmap_ig = ig(problem.model, img, label, bl, ig_steps)
        heatmap_gig_paper = gig(problem.model, img, label, bl, ig_steps)
        heatmap_gig_saliency = gig(problem.model, img, label, bl, ig_steps, fraction=0.25, max_dist=0.02)

        # SmoothGrad variants: the attribution method is bound to everything
        # except the image, which smoothgrad supplies itself.
        heatmap_mmbs_sg = smoothgrad(
            partial(mmbs, model=problem.model, label=label, imputation=imputation, num_steps=mmbs_steps, num_paths=mmbs_sg_samples, progress_bar=False),
            img)
        heatmap_ig_sg = smoothgrad(
            partial(ig, model=problem.model, label=label, baseline=bl, num_steps=ig_steps),
            img)
        heatmap_gig_paper_sg = smoothgrad(
            partial(gig, model=problem.model, label=label, baseline=bl, num_steps=ig_steps),
            img)
        heatmap_gig_saliency_sg = smoothgrad(
            partial(gig, model=problem.model, label=label, baseline=bl, num_steps=ig_steps, fraction=0.25, max_dist=0.02),
            img)

        heatmap_xrai = xrai(problem.model, img, label)
        heatmap_xrai_bl = xrai(problem.model, img, label, [bl])

        # Same order as `heatmap_names`.
        heatmaps = [
            heatmap_mmbs,
            heatmap_ig,
            heatmap_gig_paper,
            heatmap_gig_saliency,
            heatmap_mmbs_sg,
            heatmap_ig_sg,
            heatmap_gig_paper_sg,
            heatmap_gig_saliency_sg,
            heatmap_xrai,
            heatmap_xrai_bl,
            heatmap_random]

        if gradcam_layer is not None:
            # Only the heatmap is appended; its name is already the last
            # entry of `heatmap_names`.
            heatmaps.append(gradcam(problem.model, img, label, gradcam_layer))

        for heatmap, heatmap_name in zip(heatmaps, heatmap_names):
            # Drop the leading axis and move the next one to the end before
            # saving (presumably (1, C, H, W) -> (H, W, C) — TODO confirm).
            # NOTE(review): the file name carries no ".tif" extension; left
            # unchanged because readers elsewhere may rely on these names.
            heatmap_path = heatmaps_folder / f"{heatmap_name}_index={index}_label={label}"
            tifffile.imwrite(str(heatmap_path), np.moveaxis(heatmap.cpu().detach().numpy()[0, ...], 0, 2))

        # Evaluation runs under no_grad.
        with torch.no_grad():
            audcs = []
            for heatmap, heatmap_name in zip(heatmaps, heatmap_names):
                outs, fracs_deleted = calc_deletion_curve(
                    problem.model, img, heatmap, label, imputation, auc_steps
                    )
                audcs.append(calc_auc(outs, fracs_deleted))

                # Persist the full curve as two columns: fraction deleted, model output.
                curve = np.stack([fracs_deleted, outs.cpu().detach().numpy()], axis=1)
                np.savetxt(curves_folder/f"{index}_{heatmap_name}.csv", curve, delimiter=",", header="fracs_deleted,outs")

            # One CSV row per sample: the label followed by one AUDC per heatmap.
            with open(audcs_file_name, "a") as audcs_file:
                write_str = f"{label},"+",".join([repr(a) for a in audcs])
                audcs_file.write(write_str+"\n")

            print(list(zip(heatmap_names, audcs)))
