#!/usr/bin/env python3
# -*- coding: utf-8 -*-

from pathlib import Path
import numpy as np
import pandas as pd
import argparse

from folder_locations import get_experiments_path

def load_indices(path):
    """Return the sample indices whose ``*_ig.csv`` curve ends lower than it starts.

    Scans ``path / "curves"`` for files named ``<index>_ig.csv``. Each file is
    read as CSV with one header row, and column 1 is taken as the curve
    scores. A sample is selected when its first score is strictly greater
    than its last score.

    Args:
        path: ``pathlib.Path`` of an experiment folder containing ``curves/``.

    Returns:
        Sorted list of the selected integer indices. The list is sorted so
        the result does not depend on filesystem directory order.
    """
    indices = []
    for curve_file in (path / "curves").glob("*_ig.csv"):
        # ndmin=1 keeps a single-row curve as a 1-D array. Without it,
        # loadtxt returns a 0-d array and scores[0] raises IndexError.
        scores = np.loadtxt(curve_file, delimiter=",", skiprows=1, usecols=(1,), ndmin=1)
        if scores.size == 0:
            # Header-only file: there is no curve to compare.
            continue
        if scores[0] > scores[-1]:
            # File names look like "<index>_ig.csv". The glob guarantees an "_".
            indices.append(int(curve_file.name.split("_", 1)[0]))

    return sorted(indices)

# (LaTeX row label, audcs.csv column name) for every method in the table.
# Kept as pairs so a label can never drift out of sync with its column.
METHODS = (
    ("MMBS (ours)", "mmbs"),
    ("MMBS + SG", "mmbs_sg"),
    ("IG", "ig"),
    ("IG + SG", "ig_sg"),
    ("GIG (paper)", "gig_paper"),
    ("GIG (paper) + SG", "gig_paper_sg"),
    ("GIG (Saliency)", "gig_saliency"),
    ("GIG (Saliency) + SG", "gig_saliency_sg"),
    ("XRAI (B + W)", "xrai"),
    ("XRAI (zero)", "xrai_bl"),
    ("GradCAM", "gradcam"),
    ("Random", "random"),
)


def format_cell(audcs):
    """Format one table cell as ``mean [5th pct, 95th pct]`` with 3 decimals.

    Returns ``"N.A"`` when *audcs* is empty, because the mean and quantiles
    are undefined there (``np.quantile`` raises on empty input).
    """
    values = np.asarray(audcs, dtype=float)
    if values.size == 0:
        return "N.A"
    return f"{np.mean(values):.03f} [{np.quantile(values, 0.05):.03f}, {np.quantile(values, 0.95):.03f}]"


def format_row(latex_name, col_name, dataset_dfs, dataset_indices):
    """Build one LaTeX table row: the method label, then one cell per dataset.

    Args:
        latex_name: Label printed in the first column.
        col_name: Column of each dataframe holding this method's values.
        dataset_dfs: One dataframe per dataset.
        dataset_indices: Row labels to aggregate, one list per dataset.

    Returns:
        The row string, ``&``-separated and terminated with ``\\\\``.
        Datasets lacking *col_name* contribute ``"N.A"``.
    """
    cells = [latex_name]
    for df, indices in zip(dataset_dfs, dataset_indices):
        if col_name in df.columns:
            # NOTE(review): .loc selects by row label, i.e. assumes row label i
            # of audcs.csv corresponds to sample index i (default RangeIndex) - confirm.
            cells.append(format_cell(df.loc[indices, col_name]))
        else:
            cells.append("N.A")
    return " & ".join(cells) + r" \\"


def main():
    """Parse CLI args, print per-dataset selected-sample counts, then the LaTeX rows."""
    parser = argparse.ArgumentParser(description="Parse the results from the method comparison experiment.")
    # required=True: a missing folder would otherwise be None and fail later,
    # when it is joined onto the experiments path.
    parser.add_argument("--experiment_fashion", required=True, help="Name of the fashion mnist experiment folder")
    parser.add_argument("--experiment_resnet", required=True, help="Name of the ResNet mnist experiment folder")
    parser.add_argument("--experiment_vit", required=True, help="Name of the ViT mnist experiment folder")
    args = parser.parse_args()

    experiments_root = get_experiments_path()
    paths = [
        experiments_root / folder
        for folder in (args.experiment_fashion, args.experiment_resnet, args.experiment_vit)
    ]
    dataset_indices = [load_indices(p) for p in paths]
    dataset_dfs = [pd.read_csv(p / "audcs.csv", delimiter=",") for p in paths]

    for indices in dataset_indices:
        print(len(indices))

    for latex_name, col_name in METHODS:
        print(format_row(latex_name, col_name, dataset_dfs, dataset_indices))


if __name__ == "__main__":
    main()
