#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import numpy as np
import ct_experiment_utils as ceu
from matplotlib import pyplot as plt
import seaborn as sns
import pandas as pd
from pathlib import Path

from folder_locations import get_experiments_path

def calc_rel_mse_curve(heatmap_iters, references, num_iterations_per_image):
    """Return the relative MSE of the running weighted mean of heatmaps.

    For every index ``t`` along axis 1, the heatmaps ``0..t`` are averaged
    with the weights ``num_iterations_per_image[0..t]``. The squared error of
    that running mean against the reference is averaged over the last two
    axes, divided by the mean squared reference, and finally averaged over
    axis 0.

    Args:
        heatmap_iters: 4-D array. Axis 0 must line up with axis 0 of
            ``references``, axis 1 with ``num_iterations_per_image``, and the
            last two axes with the last two axes of ``references``.
            Presumably laid out as (image, checkpoint, height, width) --
            TODO confirm against the caller.
        references: 3-D array with one reference map per entry of axis 0.
        num_iterations_per_image: 1-D array of weights for the entries along
            axis 1 of ``heatmap_iters`` (presumably the number of sampling
            iterations each stored heatmap represents -- verify).

    Returns:
        1-D array holding one relative-MSE value per index along axis 1.

    Note:
        A reference that is identically zero has zero mean square, so the
        division yields ``inf``/``nan`` for that entry.
    """
    # Reshape the 1-D weights / cumulative weights so they broadcast along
    # axis 1 of the 4-D heatmap array.
    weights = num_iterations_per_image[None, :, None, None]
    total_weight = np.cumsum(num_iterations_per_image)[None, :, None, None]
    running_mean = np.cumsum(heatmap_iters * weights, axis=1) / total_weight

    # Insert a singleton axis 1 so each reference is compared against every
    # running mean of the same axis-0 entry.
    refs = references[:, None, :, :]
    ref_power = np.square(refs).mean(axis=(2, 3))
    squared_error = np.square(running_mean - refs).mean(axis=(2, 3))

    return (squared_error / ref_power).mean(axis=0)

def load_stack_of_stacks(path, num_stacks, sorted_key_outer=None, sorted_key_inner=None):
    """Load the first ``num_stacks`` entries of ``path`` and stack them.

    The entries of ``path`` are sorted with ``sorted_key_outer``; each of the
    first ``num_stacks`` entries is passed to ``ceu.load_stack`` and the
    results are combined with ``np.stack`` along a new leading axis.

    Args:
        path: ``pathlib.Path``-like directory whose entries are the
            individual stacks.
        num_stacks: Number of entries (in sorted order) to load.
        sorted_key_outer: Optional ``key`` function used to sort the entries
            of ``path``; ``None`` sorts the paths themselves.
        sorted_key_inner: Forwarded to ``ceu.load_stack`` as ``sorted_key``
            (presumably orders the files inside each stack -- verify against
            ``ct_experiment_utils``).

    Returns:
        Array with ``num_stacks`` entries along axis 0.

    Raises:
        IndexError: If ``path`` contains fewer than ``num_stacks`` entries.
    """
    stack_names = sorted(path.iterdir(), key=sorted_key_outer)
    # Use the ``num_stacks`` argument; the previous code read the global
    # ``num_imgs``, which only exists when this file runs as a script.
    return np.stack([
        ceu.load_stack(stack_names[i], sorted_key=sorted_key_inner)
        for i in range(num_stacks)
    ])

if __name__ == "__main__":
    # Script entry point: load the stored heatmaps, compute relative-MSE
    # convergence curves for every MMBS step count against two references,
    # and save a two-panel plot.
    num_imgs = 16
    num_steps = list(reversed([4, 8, 16, 32, 64, 128]))
    # Weight of each stored heatmap along the checkpoint axis; the weights sum
    # to 10000 (MMBS) and 100 (MBShap), matching the plot titles below.
    num_iterations_per_image = np.array([1]*10 + [10]*9 + [100]*9 + [1000]*9)
    mbshap_iterations_per_image = np.array([1]*10 + [10]*9)
    in_experiment_path = get_experiments_path() / "2026-05-19-calc_imagenet_convergence_combined"
    # NOTE(review): presumably creates a fresh output folder -- confirm in
    # ct_experiment_utils.
    experiment_path = ceu.make_new_experiment_folder(get_experiments_path())

    # Sort paths numerically by the integer after the last "_" / "=" in the
    # stem instead of lexicographically.
    sorted_key = lambda p : int(p.stem.split("_")[-1])
    sorted_key2 = lambda p : int(p.stem.split("=")[-1])
    references_MMBS = ceu.load_stack(in_experiment_path / "reference", range_stop=num_imgs, sorted_key=sorted_key)
    mbshap_data = load_stack_of_stacks(in_experiment_path / "MBShap", num_imgs, sorted_key, sorted_key2)
    # The MBShap reference is the weighted mean over all stored MBShap heatmaps.
    references_MBShap = np.sum(mbshap_data * mbshap_iterations_per_image[None,:,None,None], axis=1)/np.sum(mbshap_iterations_per_image)

    mmbs_curves_MMBS_ref = {}
    mmbs_curves_MBShap_ref = {}
    for m in num_steps:
        mmbs_data = load_stack_of_stacks(in_experiment_path / f"MMBS_{m}", num_imgs, sorted_key, sorted_key2)
        mmbs_curves_MMBS_ref[m] = calc_rel_mse_curve(mmbs_data, references_MMBS, num_iterations_per_image)
        mmbs_curves_MBShap_ref[m] = calc_rel_mse_curve(mmbs_data, references_MBShap, num_iterations_per_image)

    # Skip the first nine checkpoints: index 9 is the first one whose
    # cumulative iteration count (10) reaches the lower x-axis limit.
    first_plotted = 9
    cum_iterations = np.cumsum(num_iterations_per_image)[first_plotted:]

    # One long-format DataFrame per reference, one "Method" group per step count.
    frames = {}
    for ref_name, curves in (("MMBS", mmbs_curves_MMBS_ref),
                             ("MBShap", mmbs_curves_MBShap_ref)):
        df_list = []
        for m in num_steps:
            df_list.append(pd.DataFrame(
                {"Number of iterations" : cum_iterations,
                 "Method" : [f"{m} steps", ] * len(cum_iterations),
                 "Relative MSE" : curves[m][first_plotted:]}))
        frames[ref_name] = pd.concat(df_list)
    df_MMBS_ref = frames["MMBS"]
    df_MBShap_ref = frames["MBShap"]

    panels = (
        (121, df_MMBS_ref, "MMBS (10000 iterations, 128 steps) reference"),
        (122, df_MBShap_ref, "MBShap (100 iterations) reference"),
    )

    plt.figure(figsize=(10.5, 3.8))
    for position, frame, title in panels:
        plt.subplot(position)
        plot = sns.lineplot(data=frame, x="Number of iterations", y="Relative MSE", hue="Method", style="Method", legend=True, linewidth = 2)
        plot.set(xscale='log')
        plt.ylim((0, 1.1))
        plt.yticks(np.arange(0, 1.11, step=0.1))
        plt.xlim((10, 12000))
        plt.title(title)
        # Switch the grid on explicitly for each subplot. The original had a
        # bare ``plt.grid`` (no call, a no-op) for the second subplot and only
        # got its grid from a stray toggle after ``tight_layout``.
        plt.grid(True)
        plt.legend(loc="upper left")
    plt.tight_layout()
    plt.savefig(experiment_path / "plot.svg")
