#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import torch
import numpy as np
import pandas as pd
from xai.problems import ImageNetValProblem
from tqdm import tqdm
from pathlib import Path
import tifffile
from torchvision.transforms.functional import gaussian_blur
import argparse

from folder_locations import get_experiments_path, get_imagenet_val_data_path

def get_indices(device="cuda:0", noise_img_path="noise_ResNet50_img.tiff"):
    """Collect, per baseline, the sample indices the model scores higher than that baseline.

    Builds an ``ImageNetValProblem`` with the ``ResNet50_V2`` network (one sample
    per class, ``class_step=1``). For each sample, it compares the model output at
    ``[0, label]`` on the image against the same output on each of five baselines.
    A sample index is recorded for a baseline when the image's output is strictly
    greater than the baseline's.

    Baselines, in list order (this order must match ``baseline_names`` in the
    ``__main__`` block, which zips the two together):
      0. ``problem.normalize_intensity`` applied to an all-zeros tensor
      1. ``problem.normalize_intensity`` applied to an all-ones tensor
      2. an all-zeros tensor (no normalization applied)
      3. the noise image loaded from ``noise_img_path``
      4. the sample Gaussian-blurred with kernel size 101 and sigma 25

    NOTE(review): the code calls the compared value a "prob". Whether
    ``problem.model`` returns softmax probabilities or logits is not visible
    here. Confirm this against ``ImageNetValProblem``.

    Args:
        device: Device passed to ``ImageNetValProblem``. Defaults to the
            previously hard-coded ``"cuda:0"``.
        noise_img_path: Path of the TIFF file holding the noise baseline. It is
            resolved relative to the current working directory, as before.

    Returns:
        A list of five lists of int sample indices, one list per baseline, in the
        order above.

    Side effects:
        Prints the number of included samples for each baseline.
    """
    problem = ImageNetValProblem(
        data_path=get_imagenet_val_data_path(),
        network_name="ResNet50_V2",
        num_per_class=1,
        class_step=1,
        device=device)

    # The noise baseline is loaded once and reused for every sample.
    noise_bl = torch.from_numpy(tifffile.imread(Path(noise_img_path))).to(device=problem.device)

    # One result list per baseline; the length must equal len(baselines) below.
    indices = [[] for _ in range(5)]

    # This is pure inference. Disabling autograd avoids building (and then
    # discarding) a computation graph for each of the 6 forward passes per sample.
    with torch.no_grad():
        # 1000 samples: num_per_class=1 with class_step=1, presumably over the
        # 1000 ImageNet classes. TODO confirm against ImageNetValProblem.
        for index in tqdm(range(1000)):
            img, label = problem.get_sample(index)

            baselines = [
                problem.normalize_intensity(torch.zeros_like(img)),
                problem.normalize_intensity(torch.ones_like(img)),
                torch.zeros_like(img),
                noise_bl,
                gaussian_blur(img, 101, 25),
            ]
            img_prob = problem.model(img)[0, label]
            for li, bl in zip(indices, baselines):
                bl_prob = problem.model(bl)[0, label]
                if img_prob > bl_prob:
                    li.append(index)

    for li in indices:
        print(len(li))
    return indices


def format_stat_cell(data, low=0.05, high=0.95):
    """Format a sample as ``"mean [q_low, q_high]"`` with three decimals.

    Args:
        data: Array-like of numbers.
        low: Lower quantile to report (default 0.05).
        high: Upper quantile to report (default 0.95).

    Returns:
        The formatted string, e.g. ``"0.500 [0.050, 0.950]"``. An empty *data*
        returns ``"--"``. The statistics are undefined in that case: the mean
        would be NaN and ``np.quantile`` may fail.
    """
    values = np.asarray(data, dtype=float)
    if values.size == 0:
        return "--"
    mean = np.mean(values)
    quantile_low = np.quantile(values, low)
    quantile_high = np.quantile(values, high)
    return f"{mean:.3f} [{quantile_low:.3f}, {quantile_high:.3f}]"


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Parse the results from the baseline AUDCs comparison experiment.")
    # Without required=True, a missing flag would crash later with a TypeError
    # on `Path / None` instead of printing a usage error.
    parser.add_argument("--experiment", required=True, help="Name of the experiment folder")
    args = parser.parse_args()

    df = pd.read_csv(get_experiments_path() / args.experiment / "audcs.csv", delimiter=",")
    # Order must match the baseline order returned by get_indices().
    baseline_names = ["black", "white", "zero", "noise", "blur"]
    method_names = ["ig", "mmbs"]

    # One list of included sample indices per evaluation baseline.
    included_indices = get_indices()

    for att_baseline in baseline_names:
        # Rows whose attributions were computed with this baseline. The filter
        # does not depend on the method, so it is computed once here.
        att_rows = df[df["baseline"] == att_baseline]
        for method in method_names:
            # The first method row opens a two-row \multirow holding the
            # baseline name; the second row leaves that column empty.
            if method == method_names[0]:
                line = r"\multirow{2}{*}{" + f"{att_baseline.capitalize()}}} & {method.upper()} & "
            else:
                line = f" & {method.upper()} & "

            cells = []
            for eva_baseline, indices in zip(baseline_names, included_indices):
                column = att_rows[f"{eva_baseline}_{method}"].to_numpy()
                # NOTE(review): `indices` are sample indices from get_indices(). Indexing
                # positionally assumes this baseline's rows are ordered by sample
                # index (row i == sample i). Confirm this against how audcs.csv is written.
                cells.append(format_stat_cell(column[indices]))

            # Cells are separated by "& "; the row ends with the LaTeX "\\ ".
            line += "& ".join(cells) + r"\\ "
            # Close the two-row group after its last method row.
            if method == method_names[-1]:
                line += r"\hline"
            print(line)
