#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import torch
from xai.attribution import mbshap, mmbs, ig, sample_order_map
from xai.imputation import ConstantImputation
from xai.problems import FashionMnistProblem, FastMnistNet
import tifffile
from tqdm import tqdm
import argparse

import ct_experiment_utils as ceu
from folder_locations import get_experiments_path, get_fashion_mnist_data_path

# Step counts passed to MMBS; one output folder ``MMBS_<m>`` is produced per entry.
MMBS_STEP_COUNTS = (1, 2, 4, 8, 16, 32, 64, 128)


def build_parser():
    """Build the command-line parser for this experiment script.

    Returns:
        argparse.ArgumentParser with ``--gpu``, ``--range_start``, ``--range_stop``
        and ``--num_iterations`` options.
    """
    parser = argparse.ArgumentParser(
        description="Compute MMBS and MBShap heatmaps (plus a many-sample MBShap "
                    "reference) for a range of Fashion-MNIST images.")
    parser.add_argument("--gpu", type=int, default=0, help="GPU index.")
    parser.add_argument("--range_start", type=int, default=0, help="Starting image index.")
    parser.add_argument("--range_stop", type=int, default=25,
                        help="Stopping image index (exclusive).")
    parser.add_argument("--num_iterations", type=int, default=10000,
                        help="Number of single-sample heatmaps saved per method and image; "
                             "also the sample count passed to mbshap for the reference heatmap.")
    return parser


def heatmap_to_numpy(heatmap):
    """Return the ``[0, 0, :, :]`` slice of *heatmap* as a NumPy array on the CPU.

    ``detach()`` is applied first because ``Tensor.numpy()`` raises on tensors
    that still require grad; for tensors that do not, it is a no-op.
    """
    return heatmap.detach().cpu().numpy()[0, 0, :, :]


def zero_imputation(img):
    """Create a fresh ``ConstantImputation`` whose constant is an all-zero tensor shaped like *img*."""
    # A new object per call, exactly as the original script did, so no state
    # can leak between attribution runs.
    return ConstantImputation(torch.zeros_like(img))


def main(argv=None):
    """Run the experiment and write every heatmap to a new experiment folder.

    Output layout inside the experiment folder (``i`` = image index,
    ``j`` = iteration, ``m`` = MMBS step count, ``N`` = ``--num_iterations``):

    - ``img_<i>.tiff``: the input image.
    - ``MMBS_<m>/img_<i>/mmbs_steps_<m>_iter_<j>.tiff``: N independent mmbs runs.
    - ``MBShap/img_<i>/mbshap_iter_<j>.tiff``: N independent mbshap runs.
    - ``reference_MBShap_<N>/reference_img_<i>.tiff``: one mbshap run called with N.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    parser = build_parser()
    args = parser.parse_args(argv)
    if args.num_iterations < 1:
        parser.error("--num_iterations must be a positive integer")

    device = f"cuda:{args.gpu}"
    num_iterations = args.num_iterations

    problem = FashionMnistProblem(
        weights_path="fashion_mnist_weights_paper.pt",
        data_path=get_fashion_mnist_data_path(),
        device=device)

    experiment_path = ceu.make_new_experiment_folder(get_experiments_path())
    mbshap_path = experiment_path / f"reference_MBShap_{num_iterations}"
    mbshap_path.mkdir(parents=True)

    model = problem.model

    for i in tqdm(range(args.range_start, args.range_stop)):
        img, label = problem.get_sample(i)
        tifffile.imwrite(str(experiment_path / f"img_{i}.tiff"), heatmap_to_numpy_input(img))

        for m in MMBS_STEP_COUNTS:
            heatmaps_path = experiment_path / f"MMBS_{m}" / f"img_{i}"
            heatmaps_path.mkdir(parents=True)
            for j in range(num_iterations):
                # NOTE(review): the trailing positional ``1`` is presumably a
                # per-call sample/iteration count for mmbs; confirm against
                # xai.attribution.
                heatmap = mmbs(model, img, label, zero_imputation(img), m, 1,
                               progress_bar=False)
                tifffile.imwrite(str(heatmaps_path / f"mmbs_steps_{m}_iter_{j}.tiff"),
                                 heatmap_to_numpy(heatmap))

        heatmaps_path = experiment_path / "MBShap" / f"img_{i}"
        heatmaps_path.mkdir(parents=True)
        for j in range(num_iterations):
            heatmap = mbshap(model, img, label, zero_imputation(img), 1, progress_bar=False)
            tifffile.imwrite(str(heatmaps_path / f"mbshap_iter_{j}.tiff"),
                             heatmap_to_numpy(heatmap))

        # Reference heatmap: a single mbshap call with num_iterations as its count argument.
        reference_heatmap = mbshap(model, img, label, zero_imputation(img), num_iterations,
                                   progress_bar=True)
        tifffile.imwrite(str(mbshap_path / f"reference_img_{i}.tiff"),
                         heatmap_to_numpy(reference_heatmap))


def heatmap_to_numpy_input(img):
    """Return the full input image tensor as a detached CPU NumPy array (no slicing)."""
    return img.detach().cpu().numpy()


if __name__ == "__main__":
    main()
