#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import torch
from torchvision.transforms.functional import gaussian_blur
from xai.attribution import mmbs, ig
from xai.imputation import ConstantImputation
from xai.problems import ImageNetValProblem
from xai.evaluation import calc_deletion_curve, calc_auc
import tifffile
from tqdm import tqdm
from pathlib import Path

import ct_experiment_utils as ceu
from folder_locations import get_imagenet_val_data_path, get_experiments_path

# Names of the baseline images, in the order produced by make_baselines().
# The same order defines the CSV score columns.
BASELINE_NAMES = ("black", "white", "zero", "noise", "blur")


def format_csv_row(*fields):
    """Return *fields* as one comma-separated line terminated by a newline.

    Each value is rendered exactly as an f-string would render it
    (``format(value)``), so the output matches the original f-string based
    writer. No quoting or escaping is done; fields must not contain commas.
    """
    return ",".join(format(field) for field in fields) + "\n"


def csv_header(baseline_names):
    """Return the CSV header line.

    Columns are ``index``, ``label``, ``baseline`` (the baseline used for
    attribution), followed by ``<name>_mmbs`` and ``<name>_ig`` for every
    evaluation baseline *name*, in the given order.
    """
    score_columns = (
        f"{name}_{method}" for name in baseline_names for method in ("mmbs", "ig")
    )
    return format_csv_row("index", "label", "baseline", *score_columns)


def load_or_create_noise_baseline(problem, noise_img_path):
    """Return a uniform-noise baseline image, cached on disk for reuse.

    If *noise_img_path* exists, the stored image is loaded and moved to
    ``problem.device``. Otherwise new noise is drawn with ``torch.rand_like``
    using the shape of sample 0, passed through
    ``problem.normalize_intensity``, and written to *noise_img_path*. Later
    runs therefore reuse the same noise.
    """
    if noise_img_path.exists():
        return torch.from_numpy(tifffile.imread(noise_img_path)).to(device=problem.device)
    img, _ = problem.get_sample(0)
    noise_bl = problem.normalize_intensity(torch.rand_like(img))
    tifffile.imwrite(str(noise_img_path), noise_bl.cpu().numpy())
    return noise_bl


def make_baselines(problem, img, noise_bl):
    """Return the baseline images for *img*, ordered like BASELINE_NAMES."""
    with torch.no_grad():
        return [
            problem.normalize_intensity(torch.zeros_like(img)),  # black
            problem.normalize_intensity(torch.ones_like(img)),   # white
            torch.zeros_like(img),  # zero: NOT passed through normalize_intensity
            noise_bl,               # noise: same image for every sample
            gaussian_blur(img, 101, 25),  # blur: kernel size 101, sigma 25
        ]


def audc_scores(problem, img, label, heatmap_mmbs, heatmap_ig, baselines, auc_steps):
    """Return AUDC scores as ``[mmbs, ig, mmbs, ig, ...]``, one pair per baseline.

    Each score is ``calc_auc`` of the deletion curve for the given heatmap.
    The curve uses a ``ConstantImputation`` of the respective evaluation
    baseline with *auc_steps* steps.
    """
    scores = []
    with torch.no_grad():
        for baseline_eva in baselines:
            for heatmap in (heatmap_mmbs, heatmap_ig):
                scores.append(calc_auc(*calc_deletion_curve(
                    problem.model, img, heatmap, label, ConstantImputation(baseline_eva), auc_steps
                    )))
    return scores


def main(num_samples=1000, auc_steps=200, ig_steps=256, mmbs_steps=8, mmbs_samples=1024):
    """Run the baseline cross-evaluation and write results to ``audcs.csv``.

    For each of the first *num_samples* validation samples, and for each
    attribution baseline, this computes an MMBS and an IG heatmap. Both
    heatmaps are then scored against every evaluation baseline.

    Results go to ``audcs.csv`` in a new experiment folder, one row per
    (sample, attribution baseline).
    """
    experiment_path = ceu.make_new_experiment_folder(get_experiments_path())
    csv_path = experiment_path / "audcs.csv"

    problem = ImageNetValProblem(
        data_path = get_imagenet_val_data_path(),
        network_name = "ResNet50_V2",
        num_per_class=1,
        class_step=1,
        device = "cuda:0")

    noise_bl = load_or_create_noise_baseline(problem, Path("noise_ResNet50_img.tiff"))

    with open(csv_path, "a") as csv_file:
        csv_file.write(csv_header(BASELINE_NAMES))

    # NOTE(review): default of 1000 presumably equals one image per ImageNet
    # class (num_per_class=1, class_step=1) - confirm against the dataset size.
    for index in tqdm(range(num_samples)):
        img, label = problem.get_sample(index)
        baselines = make_baselines(problem, img, noise_bl)

        for baseline_att, baseline_name in zip(baselines, BASELINE_NAMES):
            # Computed outside no_grad: the attribution methods (IG in
            # particular) presumably need autograd.
            heatmap_mmbs = mmbs(
                problem.model, img, label, ConstantImputation(baseline_att),
                mmbs_steps, mmbs_samples, progress_bar=False)
            heatmap_ig = ig(problem.model, img, label, baseline_att, ig_steps)

            scores = audc_scores(
                problem, img, label, heatmap_mmbs, heatmap_ig, baselines, auc_steps)

            # Append the completed row in a single write so that a crash during
            # evaluation cannot leave a truncated row (without newline) in the CSV.
            with open(csv_path, "a") as csv_file:
                csv_file.write(format_csv_row(index, label, baseline_name, *scores))


if __name__ == "__main__":
    main()
