#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import torch
import numpy as np
from xai.attribution import mmbs, mbshap
from xai.imputation import ConstantImputation
from xai.problems import ImageNetValProblem
from xai.utils import convert_torch_img_to_np
import argparse
import tifffile
import itertools
import time

import ct_experiment_utils as ceu
from folder_locations import get_imagenet_val_data_path, get_experiments_path

if __name__ == "__main__":
    # Benchmark MMBS (8 and 16 steps) and MB-SHAP on selected ImageNet
    # validation images. Every run's heatmap is saved as a TIFF under
    # <experiment>/heatmaps and its runtime is appended to
    # <experiment>/durations.csv (columns: label, method, repetition, duration).
    torch.backends.cudnn.benchmark = False
    # Clock-based seed so separate invocations draw different randomness;
    # torch.manual_seed takes an integer seed.
    torch.manual_seed(int(time.time()))

    parser = argparse.ArgumentParser(description="Run MMBS on imagenet.")
    parser.add_argument("--gpu", type=int, required=True,
                        help="Index of the GPU to use.")
    parser.add_argument("--network_name", required=True,
                        help="Name of the network architecture.")
    parser.add_argument("--indices", required=True,
                        help="Indices of the input images (comma separated).")
    parser.add_argument("--repetitions", type=int, default=16,
                        help="Number of timed repetitions per image and method.")
    args = parser.parse_args()

    # Skip empty entries so e.g. "3,7," does not crash on int("").
    indices = [int(s) for s in args.indices.split(",") if s.strip()]

    device_name = f"cuda:{args.gpu}"
    device = torch.device(device_name)

    experiment_path = ceu.make_new_experiment_folder(get_experiments_path())
    heatmaps_path = experiment_path / "heatmaps"
    heatmaps_path.mkdir()

    problem = ImageNetValProblem(
        data_path=get_imagenet_val_data_path(),
        network_name=args.network_name,
        num_per_class=1,
        class_step=1,
        device=device_name)

    def timed(fn, *fn_args):
        """Call fn(*fn_args) and return (result, elapsed seconds).

        CUDA kernels are launched asynchronously, so the device is
        synchronized both before starting and before stopping the clock;
        otherwise pending GPU work could leak into or out of the measurement.
        """
        torch.cuda.synchronize(device)
        start = time.perf_counter()
        result = fn(*fn_args)
        torch.cuda.synchronize(device)
        return result, time.perf_counter() - start

    def save_result(csv_file, label, method, rep, heatmap, duration):
        """Write one heatmap TIFF and its duration row (outside the timed region)."""
        tifffile.imwrite(
            str(heatmaps_path / f"heatmap_{label}_{method}_rep_{rep}.tiff"),
            convert_torch_img_to_np(heatmap))
        csv_file.write(f"{label},{method},{rep},{duration}\n")

    # Line buffering flushes each row as it is written, so results gathered
    # so far survive an interrupted run.
    with open(experiment_path / "durations.csv", "w", buffering=1) as csv_file:
        csv_file.write("label,method,repetition,duration\n")

        for index in indices:
            img, label = problem.get_sample(index)
            # NOTE(review): outputs are keyed by label rather than index; two
            # indices with the same label would overwrite each other's
            # heatmaps and produce indistinguishable CSV rows.

            for rep in range(args.repetitions):
                # NOTE(review): rep 0 may include one-time warm-up costs
                # (CUDA context / kernel initialization); consider discarding
                # it during analysis.
                for mmbs_steps in [8, 16]:
                    # A fresh all-zeros baseline per call, built before timing.
                    imputation = ConstantImputation(torch.zeros_like(img))
                    heatmap_mmbs, duration = timed(
                        mmbs, problem.model, img, label, imputation,
                        mmbs_steps, 10)
                    save_result(csv_file, label, f"mmbs_{mmbs_steps}", rep,
                                heatmap_mmbs, duration)

                imputation = ConstantImputation(torch.zeros_like(img))
                heatmap_mbshap, duration = timed(
                    mbshap, problem.model, img, label, imputation, 10)
                save_result(csv_file, label, "mbshap", rep,
                            heatmap_mbshap, duration)
