#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import numpy as np
import ct_experiment_utils as ceu
from matplotlib import pyplot as plt
import seaborn as sns
import pandas as pd
from pathlib import Path

from folder_locations import get_experiments_path

def calc_rel_mse_curve(heatmap_iters, references):
    """Return the image-averaged relative MSE of a running-mean estimate.

    For every image, the per-iteration estimates are averaged cumulatively:
    the estimate after ``t`` iterations is the mean of the first ``t`` entries
    along axis 1. Its mean squared error against that image's reference is
    divided by the reference's mean square. These ratios are then averaged
    over images.

    Parameters
    ----------
    heatmap_iters : array_like, shape (n_images, n_iters, *spatial)
        Per-iteration estimates for each image. The usual layout is
        ``(n_images, n_iters, H, W)``, but any number of trailing axes is
        accepted.
    references : array_like, shape (n_images, *spatial)
        One reference per image. Its trailing axes must match those of
        ``heatmap_iters``.

    Returns
    -------
    numpy.ndarray, shape (n_iters,)
        ``curve[t]`` is the relative MSE after ``t + 1`` iterations, averaged
        over images. An all-zero reference yields an inf/nan ratio for that
        image.

    Raises
    ------
    ValueError
        If ``heatmap_iters`` has fewer than 2 axes, or if ``references`` does
        not have shape ``(n_images, *spatial)``.
    """
    heatmap_iters = np.asarray(heatmap_iters)
    references = np.asarray(references)
    if heatmap_iters.ndim < 2:
        raise ValueError(
            f"heatmap_iters must have shape (n_images, n_iters, ...), got {heatmap_iters.shape}")
    expected_ref_shape = heatmap_iters.shape[:1] + heatmap_iters.shape[2:]
    if references.shape != expected_ref_shape:
        # Without this check NumPy would silently broadcast mismatched shapes.
        raise ValueError(
            f"references shape {references.shape} does not match expected {expected_ref_shape}")

    n_iters = heatmap_iters.shape[1]
    spatial_ndim = heatmap_iters.ndim - 2

    # Iteration counts shaped (1, n_iters, 1, ..., 1) so they broadcast over images and pixels.
    counts = np.arange(1, n_iters + 1).reshape((1, n_iters) + (1,) * spatial_ndim)
    cum_mean = np.cumsum(heatmap_iters, axis=1) / counts

    spatial_axes = tuple(range(2, 2 + spatial_ndim))
    refs = references[:, None]  # insert the iteration axis -> (n_images, 1, *spatial)
    rels = np.mean(np.square(refs), axis=spatial_axes)             # (n_images, 1)
    mses = np.mean(np.square(cum_mean - refs), axis=spatial_axes)  # (n_images, n_iters)
    return np.mean(mses / rels, axis=0)

def stem_index(path):
    """Return the integer after the last underscore in ``path``'s stem.

    Used as the sort key for the reference stack, so files are ordered by
    that number (``..._2`` before ``..._10``).
    """
    return int(path.stem.split("_")[-1])


def curve_to_frame(curve, method, skip=10):
    """Return a long-format DataFrame for one convergence curve.

    ``curve[i]`` is treated as the value after ``i + 1`` iterations. The first
    ``skip`` entries are dropped, so the frame starts at iteration
    ``skip + 1``. Columns: "Number of iterations", "Method" (constant
    ``method``), "Relative MSE".
    """
    iterations = range(1, len(curve) + 1)[skip:]
    return pd.DataFrame({
        "Number of iterations": iterations,
        "Method": [method] * len(iterations),
        "Relative MSE": curve[skip:],
    })


if __name__ == "__main__":
    num_imgs = 40
    num_steps = [128, 64, 32, 16, 8, 4, 2]
    in_experiment_path = get_experiments_path() / "2026-05-13_calc_fashion_mnist_convergence_combined"
    experiment_path = ceu.make_new_experiment_folder(get_experiments_path())

    references = ceu.load_stack(in_experiment_path / "reference_MBShap_10000",
                                range_stop=num_imgs, sorted_key=stem_index)

    def load_runs(method_dir):
        """Stack ``ceu.load_stack(<method_dir>/img_{i})`` over all images (axis 0 = image)."""
        # NOTE(review): unlike the reference stack, these are loaded without
        # ``sorted_key``. The cumulative mean in calc_rel_mse_curve depends on
        # iteration order, so confirm ceu.load_stack's default ordering is numeric.
        return np.stack([ceu.load_stack(in_experiment_path / method_dir / f"img_{i}")
                         for i in range(num_imgs)])

    # Compute each curve straight from its stack, so one method's raw data is
    # released before the next method's data is loaded.
    mmbs_curves = {m: calc_rel_mse_curve(load_runs(f"MMBS_{m}"), references) for m in num_steps}
    mbshap_curve = calc_rel_mse_curve(load_runs("MBShap"), references)

    df = pd.concat(
        [curve_to_frame(mbshap_curve, "MBShap")]
        + [curve_to_frame(mmbs_curves[m], f"{m} steps") for m in num_steps],
        ignore_index=True,  # every piece starts at index 0; keep the combined index unique
    )

    plt.figure(figsize=(7.7, 3.8))
    ax = sns.lineplot(data=df, x="Number of iterations", y="Relative MSE",
                      hue="Method", style="Method", legend=True, linewidth=2)
    ax.set(xscale="log")
    plt.xlabel("Number of iterations")
    plt.ylabel("Relative MSE")
    plt.ylim((0, 1.1))
    plt.xlim((10, 12000))
    plt.yticks(np.arange(0, 1.11, step=0.1))
    plt.legend(loc="upper left")
    plt.grid()
    plt.tight_layout()
    plt.savefig(experiment_path / "plot.svg")
    plt.close()
