#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import numpy as np
import torch
from torchvision.transforms.functional import gaussian_blur
from xai.attribution import mmbs, ig, gig, xrai, gradcam, smoothgrad
from xai.imputation import ConstantImputation
from xai.problems import ImageNetValProblem, FashionMnistProblem
from xai.evaluation import calc_deletion_curve, calc_auc
import tifffile
from tqdm import tqdm
from pathlib import Path
import argparse
from functools import partial

import ct_experiment_utils as ceu
from folder_locations import get_imagenet_val_data_path, get_experiments_path, get_fashion_mnist_data_path

if __name__ == "__main__":
    # Script entry point.
    #
    # For every sample index in range(range_start, range_stop, range_step) this
    # computes one attribution heatmap per method, evaluates each heatmap with
    # a deletion curve at several step counts, and writes the resulting areas
    # under the curve to "<experiment>/results/<index>.csv" (one row per
    # method, one column per entry of ``auc_steps_list``).
    parser = argparse.ArgumentParser(description="Run MMBS on imagenet.")
    parser.add_argument("--gpu", type=int, default=0, help="Index of the GPU to use.")
    parser.add_argument("--network_name", help="Name of the network architecture.")
    parser.add_argument("--range_start", type=int, default=0, help="Index to start with")
    parser.add_argument("--range_step", type=int, default=20, help="Step size")
    parser.add_argument("--range_stop", type=int, default=1000, help="Index to end on")
    args = parser.parse_args()

    # Step counts at which the deletion curve is evaluated; these also form
    # the column headers of every result CSV.
    auc_steps_list = [50, 100, 200, 400, 800, 1600]
    # Passed as ``num_steps`` to ig / gig.
    ig_steps = 256
    # Passed as ``num_steps`` to mmbs.
    mmbs_steps = 8
    # Sixth positional argument of the plain mmbs call
    # (presumably ``num_paths`` -- TODO confirm against xai.attribution.mmbs).
    mmbs_samples = 1024
    # ``num_paths`` used for mmbs when it is wrapped in smoothgrad.
    mmbs_sg_samples = 128

    device = f"cuda:{args.gpu}"
    if args.network_name == "Fashion":
        # The special network name "Fashion" selects the Fashion-MNIST problem.
        problem = FashionMnistProblem(
            weights_path="fashion_mnist_weights_paper.pt",
            data_path=get_fashion_mnist_data_path(),
            device=device)
    else:
        # Any other network name is handed to the ImageNet validation problem.
        problem = ImageNetValProblem(
            data_path=get_imagenet_val_data_path(),
            network_name=args.network_name,
            num_per_class=1,
            class_step=1,
            class_offset=0,
            device=device)

    experiment_path = ceu.make_new_experiment_folder(get_experiments_path())
    results_folder = experiment_path / "results"
    results_folder.mkdir()

    # Row labels of the result CSVs. The order MUST match the order of the
    # ``heatmaps`` list assembled for every sample below, because the two
    # lists are paired with zip().
    heatmap_names = [
        "mmbs",
        "ig",
        "gig_paper",
        "gig_saliency",
        "mmbs_sg",
        "ig_sg",
        "gig_paper_sg",
        "gig_saliency_sg",
        "xrai",
        "xrai_bl",
    ]

    # Queried once so that the name list and the per-sample heatmap list are
    # always driven by the same value. "gradcam" is registered exactly once
    # here; appending it again per sample would grow ``heatmap_names`` by one
    # stale entry per iteration (only hidden by zip() truncating to the
    # shorter list).
    gradcam_layer = problem.get_gradcam_layer()
    if gradcam_layer is not None:
        heatmap_names.append("gradcam")

    for index in tqdm(range(args.range_start, args.range_stop, args.range_step)):
        img, label = problem.get_sample(index)

        # All-zero baseline with the same shape/dtype/device as the image; it
        # is used both as the attribution baseline and as the constant value
        # of the imputation.
        bl = torch.zeros_like(img)
        imputation = ConstantImputation(bl)

        # Plain attribution methods.
        heatmap_mmbs = mmbs(problem.model, img, label, imputation, mmbs_steps, mmbs_samples, progress_bar=False)
        heatmap_ig = ig(problem.model, img, label, bl, ig_steps)
        heatmap_gig_paper = gig(problem.model, img, label, bl, ig_steps)
        heatmap_gig_saliency = gig(problem.model, img, label, bl, ig_steps, fraction=0.25, max_dist=0.02)

        # The same methods wrapped in smoothgrad. Everything except the image
        # is bound with partial(); smoothgrad receives the image separately.
        heatmap_mmbs_sg = smoothgrad(
            partial(mmbs, model=problem.model, label=label, imputation=imputation, num_steps=mmbs_steps, num_paths=mmbs_sg_samples, progress_bar=False),
            img)
        heatmap_ig_sg = smoothgrad(
            partial(ig, model=problem.model, label=label, baseline=bl, num_steps=ig_steps),
            img)
        heatmap_gig_paper_sg = smoothgrad(
            partial(gig, model=problem.model, label=label, baseline=bl, num_steps=ig_steps),
            img)
        heatmap_gig_saliency_sg = smoothgrad(
            partial(gig, model=problem.model, label=label, baseline=bl, num_steps=ig_steps, fraction=0.25, max_dist=0.02),
            img)

        # XRAI with its default arguments and with the zero baseline passed
        # explicitly as a one-element list.
        heatmap_xrai = xrai(problem.model, img, label)
        heatmap_xrai_bl = xrai(problem.model, img, label, [bl])

        # Same order as ``heatmap_names``.
        heatmaps = [
            heatmap_mmbs,
            heatmap_ig,
            heatmap_gig_paper,
            heatmap_gig_saliency,
            heatmap_mmbs_sg,
            heatmap_ig_sg,
            heatmap_gig_paper_sg,
            heatmap_gig_saliency_sg,
            heatmap_xrai,
            heatmap_xrai_bl,
        ]

        if gradcam_layer is not None:
            heatmaps.append(gradcam(problem.model, img, label, gradcam_layer))

        # The evaluation runs with autograd disabled.
        with torch.no_grad():
            with open(results_folder / f"{index}.csv", "w") as file:
                file.write("method," + ",".join([str(s) for s in auc_steps_list]) + "\n")
                for heatmap, heatmap_name in zip(heatmaps, heatmap_names):
                    audcs = []
                    for auc_steps in auc_steps_list:
                        outs, fracs_deleted = calc_deletion_curve(
                            problem.model, img, heatmap, label, imputation, auc_steps
                            )
                        audcs.append(calc_auc(outs, fracs_deleted))
                    file.write(heatmap_name + "," + ",".join([repr(a) for a in audcs]) + "\n")