#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import numpy as np
import torch
from xai.attribution import mbshap, mmbs
from xai.imputation import ConstantImputation
from xai.problems import ImageNetValProblem
import tifffile
from tqdm import tqdm
import argparse

import ct_experiment_utils as ceu
from folder_locations import get_experiments_path, get_imagenet_val_data_path

if __name__ == "__main__":
    # Compute MMBS and MBShap attribution heatmaps for a range of ImageNet
    # validation images, plus one high-budget MMBS "reference" heatmap per
    # image, and write every image/heatmap as a TIFF into a freshly created
    # experiment folder. Plotting the comparison is done elsewhere.
    parser = argparse.ArgumentParser(description="Compute the MMBS-MBShap comparison heatmaps.")
    parser.add_argument("--gpu", type=int, default=0, help="GPU index.")
    parser.add_argument("--range_start", type=int, default=0, help="Starting image index.")
    parser.add_argument("--range_step", type=int, default=1, help="Step size of the image index.")
    parser.add_argument("--range_stop", type=int, default=10, help="Stopping image index.")
    parser.add_argument("--network_name", default="ResNet50_V2", help="Name of the neural network architecture.")
    args = parser.parse_args()

    device = f"cuda:{args.gpu}"

    # Iteration budget of each successive MMBS call per image / step count.
    num_iterations_per_image = [1]*10 + [10]*9 + [100]*9 + [1000]*9
    # MBShap is only run on the cheaper, shorter schedule.
    mbshap_iterations_per_image = [1]*10 + [10]*9
    # Budget of the single reference MMBS run (sum of the full schedule).
    total_iterations = np.sum(num_iterations_per_image)
    mmbs_step_nums = [1, 2, 4, 8, 16, 32, 64, 128]

    # Running totals used to label the output files. Each schedule gets its
    # own cumulative sum so the MBShap labels always match the MBShap schedule
    # (previously they were derived from the MMBS schedule and were only
    # correct because one list happened to be a prefix of the other).
    mmbs_cumulative_iterations = np.cumsum(num_iterations_per_image)
    mbshap_cumulative_iterations = np.cumsum(mbshap_iterations_per_image)

    problem = ImageNetValProblem(
        data_path = get_imagenet_val_data_path(),
        network_name = args.network_name,
        num_per_class=1,
        class_step=1,
        device = device)

    experiment_path = ceu.make_new_experiment_folder(get_experiments_path())
    reference_path = experiment_path / "reference"
    reference_path.mkdir(parents=True)

    model = problem.model

    for i in tqdm(range(args.range_start, args.range_stop, args.range_step)):
        img, label = problem.get_sample(i)
        tifffile.imwrite(str(experiment_path/f"img_{i}.tiff"), img.cpu().detach().numpy())

        # MMBS heatmaps: one folder per step count, one file per schedule entry.
        for steps in mmbs_step_nums:
            print(f"{steps} steps")
            heatmaps_path = experiment_path / f"MMBS_{steps}" / f"img_{i}"
            heatmaps_path.mkdir(parents=True)
            for iters, iterations_so_far in zip(num_iterations_per_image, mmbs_cumulative_iterations):
                # NOTE(review): each call runs `iters` iterations independently,
                # while the filename records the cumulative total; presumably a
                # downstream script combines the partial results — confirm.
                # [0, 0, :, :] assumes a (batch, channel, H, W) heatmap — TODO confirm.
                heatmap = mmbs(model, img, label, ConstantImputation(torch.zeros_like(img)), steps, iters, progress_bar=False).cpu().numpy()[0, 0, :, :]
                tifffile.imwrite(str(heatmaps_path/f"mmbs_steps_{steps}_total_iters={iterations_so_far}.tiff"), heatmap)

        # MBShap heatmaps on the shorter schedule.
        print("MBShap")
        heatmaps_path = experiment_path / "MBShap" / f"img_{i}"
        heatmaps_path.mkdir(parents=True)
        for iters, iterations_so_far in zip(mbshap_iterations_per_image, mbshap_cumulative_iterations):
            heatmap = mbshap(model, img, label, ConstantImputation(torch.zeros_like(img)), iters, progress_bar=True).cpu().numpy()[0, 0, :, :]
            tifffile.imwrite(str(heatmaps_path/f"mbshap_total_iters={iterations_so_far}.tiff"), heatmap)

        # Reference: MMBS with the largest step count and the full iteration budget.
        print("Reference")
        reference_heatmap = mmbs(model, img, label, ConstantImputation(torch.zeros_like(img)), mmbs_step_nums[-1], total_iterations, progress_bar=True).cpu().numpy()[0, 0, :, :]
        tifffile.imwrite(str(reference_path/f"reference_img_{i}.tiff"), reference_heatmap)