#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import torch
from xai.attribution import mbshap, mmbs, ig, sample_order_map
from xai.imputation import ConstantImputation
from xai.problems import FashionMnistProblem, FastMnistNet
import tifffile

import ct_experiment_utils as ceu
from folder_locations import get_experiments_path, get_fashion_mnist_data_path

if __name__ == "__main__":
    # Compute attribution heatmaps for Fashion-MNIST sample 0 and write each
    # one as a TIFF into <new experiment folder>/heatmaps:
    #   * integrated gradients (ig) with an all-zeros baseline,
    #   * mbshap for 4**n paths, n in range(NUM_PATH_EXPONENTS),
    #   * mmbs for the same path counts and each step count in MMBS_STEP_COUNTS.

    # Prefer the first GPU, but fall back to CPU so the script still runs on
    # machines without CUDA (identical behaviour when CUDA is available).
    device = "cuda:0" if torch.cuda.is_available() else "cpu"

    IG_STEPS = 32  # passed as the 5th positional argument to ig(...)
    NUM_PATH_EXPONENTS = 7  # mbshap/mmbs use num_paths = 4**n for n in range(7)
    MMBS_STEP_COUNTS = [1, 2, 4, 8, 16, 32]  # values of m passed to mmbs(...)
    # Number of order maps pre-sampled once and shared by every mbshap/mmbs
    # call so all path counts draw from the same sequence.
    # NOTE(review): this is 4**7 while the largest num_paths used below is
    # 4**6 — confirm whether mmbs consumes more order maps than num_paths or
    # whether this is simply over-provisioned.
    NUM_ORDER_MAPS = 4**7

    problem = FashionMnistProblem(
        weights_path="fashion_mnist_weights_paper.pt",
        data_path=get_fashion_mnist_data_path(),
        device=device)

    experiment_path = ceu.make_new_experiment_folder(get_experiments_path())
    heatmaps_path = experiment_path / "heatmaps"
    heatmaps_path.mkdir()

    img, label = problem.get_sample(0)
    model = problem.model

    def save_heatmap(filename, attribution):
        """Move *attribution* to CPU, take element [0, 0, :, :], and write it
        to ``heatmaps_path / filename`` as a TIFF."""
        heatmap = attribution.cpu().numpy()[0, 0, :, :]
        tifffile.imwrite(str(heatmaps_path / filename), heatmap)

    def zero_imputation():
        """Return a fresh ConstantImputation built from an all-zeros tensor
        shaped like ``img`` (a new object per call, as in the original)."""
        return ConstantImputation(torch.zeros_like(img))

    save_heatmap("heatmap_ig.tiff",
                 ig(model, img, label, torch.zeros_like(img), IG_STEPS))

    # Seed before sampling so the shared order maps are reproducible.
    torch.manual_seed(1)
    order_maps = [sample_order_map(img) for _ in range(NUM_ORDER_MAPS)]

    for n in range(NUM_PATH_EXPONENTS):
        num_paths = 4**n
        save_heatmap(
            f"heatmap_mbshap_paths_{num_paths}.tiff",
            mbshap(model, img, label, zero_imputation(), num_paths,
                   order_maps=order_maps))
        for m in MMBS_STEP_COUNTS:
            save_heatmap(
                f"heatmap_mmbs_paths_{num_paths}_steps_{m}.tiff",
                mmbs(model, img, label, zero_imputation(), m, num_paths,
                     order_maps=order_maps))
