#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import numpy as np
import torch
from xai.attribution import mmbs, mbshap
from xai.imputation import ConstantImputation
from xai.problems import ImageNetValProblem, FashionMnistProblem
import tifffile
from tqdm import tqdm
from pathlib import Path
import argparse
from functools import partial
from time import time
import sys

import ct_experiment_utils as ceu
from folder_locations import get_imagenet_val_data_path, get_experiments_path, get_fashion_mnist_data_path

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run MMBS on imagenet.")
    parser.add_argument("--gpu", type=int, default=0, help="Index of the GPU to use.")
    parser.add_argument("--network_name", help="Name of the network architecture.")
    parser.add_argument("--range_start", type=int, default=0, help="Starting image index.")
    parser.add_argument("--range_step", type=int, default=20, help="Step size of the image index.")
    parser.add_argument("--range_stop", type=int, default=1000, help="Stopping image index.")
    args = parser.parse_args()

    mmbs_samples = 1
    mbshap_samples = 1
    num_steps = [1, 2, 4, 8, 16, 32, 64, 128]

    if args.network_name == "Fashion":
        problem = FashionMnistProblem(
            weights_path="fashion_mnist_weights_paper.pt",
            data_path=get_fashion_mnist_data_path(),
            device = f"cuda:{args.gpu}")
    else:
        problem = ImageNetValProblem(
            data_path = get_imagenet_val_data_path(),
            network_name = args.network_name,
            num_per_class=1,
            class_step=1,
            class_offset=0,
            device = f"cuda:{args.gpu}")

    experiment_path = ceu.make_new_experiment_folder(get_experiments_path())
    results_path = experiment_path / "results.csv"
    with open(results_path, "w") as results_file:
        results_file.write("img_index,method,time_per_iteration\n")



    for i in tqdm(range(args.range_start, args.range_stop, args.range_step)):
        sys.stdout.flush()
        img, label = problem.get_sample(i)

        # Calculate one heatmap and discard it to avoiud possible startup effects
        heatmap = mmbs(problem.model, img, label, ConstantImputation(torch.zeros_like(img)), 1, 10, progress_bar=False)

        start_time = time()
        heatmap = mbshap(problem.model, img, label, ConstantImputation(torch.zeros_like(img)), mbshap_samples, progress_bar=False)
        end_time = time()
        time_per_iteration = (end_time-start_time)/mmbs_samples
        print(f"MBshap time per iteration = {time_per_iteration}", flush=True)
        with open(results_path, "a") as results_file:
            results_file.write(f"{i},MBShap,{time_per_iteration}\n")

        for steps in num_steps:
            start_time = time()
            heatmap = mmbs(problem.model, img, label, ConstantImputation(torch.zeros_like(img)), steps, mmbs_samples, progress_bar=False)
            end_time = time()
            time_per_iteration = (end_time-start_time)/mbshap_samples
            print(f"MMBS({steps}) time per iteration = {time_per_iteration}", flush=True)
            with open(results_path, "a") as results_file:
                results_file.write(f"{i},MMBS({steps}),{time_per_iteration}\n")