import argparse
import os
import sys
from pathlib import Path
from typing import List, Tuple

import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import numpy as np


def _add_utils_to_path() -> None:
    """Ensure the local utils module is importable regardless of CWD."""
    script_dir = Path(__file__).resolve().parent
    utils_dir = script_dir.parent / 'source' / 'utils'
    if str(utils_dir) not in sys.path:
        sys.path.insert(0, str(utils_dir))


# Make ../source/utils (relative to this script) importable before loading the
# project-local plotting helpers; this must run before the import below.
_add_utils_to_path()
import plot_utils as putils  # noqa: E402

def plot_loss_hist(ax, loss_hist_all: np.ndarray, color) -> None:
    """Draw every segment of ``loss_hist_all`` end-to-end on ``ax``.

    ``loss_hist_all`` is indexed along its first axis by segment. Each
    segment is plotted against an integer x-range that begins where the
    previous segment ended, so the segments read as one continuous timeline.
    Every segment uses ``color`` with line width 2 and alpha 0.7.

    Args:
        ax: Matplotlib axes (anything with a ``plot`` method).
        loss_hist_all: Sequence of 1D loss arrays, one per segment.
        color: Line color passed through to ``ax.plot``.
    """
    offset = 0
    for segment in loss_hist_all:
        n_points = segment.shape[0]
        ax.plot(np.arange(offset, offset + n_points), segment, lw=2, color=color, alpha=0.7)
        offset += n_points


def collect_data(
    root_dir: str,
    base: str,
    seeds: int = 50,
) -> Tuple[List[List[np.ndarray]], List[List[np.ndarray]], List[np.ndarray], List[np.ndarray], List[np.ndarray]]:
    """Load metric and loss data for all seeds.

    Files are read from ``<root_dir>/res/<base>_seed_<k>_<SUFFIX>.npy`` for
    ``k`` in ``0 .. seeds-1``. Missing files are skipped silently. The metric
    pair (``DIST`` and ``test_DIST``) is only used when both files are present.
    Each metric file is expected to have its three metrics along axis 0, in
    the order [MMD, Wass, Vendi]; only rows 0-2 are read.

    NOTE(review): ``np.load`` is called with its default ``allow_pickle=False``.
    If the loss files are ragged object arrays, loading them will raise. Confirm
    how they are saved.

    Args:
        root_dir: Run directory containing the ``res`` sub-folder.
        base: File-name prefix shared by all seeds.
        seeds: Number of seed indices to probe.

    Returns:
        - real_vals[i]: list of arrays for metric i across seeds
        - test_vals[i]: list of arrays for metric i across seeds
        - total_losses: list of total loss arrays across seeds
        - dist_losses: list of distance loss arrays across seeds
        - vendi_losses: list of vendi loss arrays across seeds
    """
    real_vals: List[List[np.ndarray]] = [[], [], []]
    test_vals: List[List[np.ndarray]] = [[], [], []]
    total_losses: List[np.ndarray] = []
    dist_losses: List[np.ndarray] = []
    vendi_losses: List[np.ndarray] = []

    res_dir = os.path.join(root_dir, 'res')

    def _path(seed: int, suffix: str) -> str:
        # Build every result path the same way, using portable separators.
        return os.path.join(res_dir, f'{base}_seed_{seed}_{suffix}.npy')

    loss_buckets = (
        ('LOSS', total_losses),
        ('DLOSS', dist_losses),
        ('VLOSS', vendi_losses),
    )

    for seed in range(seeds):
        real_file = _path(seed, 'DIST')
        test_file = _path(seed, 'test_DIST')

        if os.path.isfile(real_file) and os.path.isfile(test_file):
            real_all = np.load(real_file)
            test_all = np.load(test_file)
            # Expect shape (3, steps) -> metrics: [MMD, Wass, Vendi]
            for i, (real_list, test_list) in enumerate(zip(real_vals, test_vals)):
                real_list.append(real_all[i])
                test_list.append(test_all[i])

        # Use isfile (not exists) so a same-named directory is skipped
        # instead of crashing np.load.
        for suffix, bucket in loss_buckets:
            loss_file = _path(seed, suffix)
            if os.path.isfile(loss_file):
                bucket.append(np.load(loss_file))

    return real_vals, test_vals, total_losses, dist_losses, vendi_losses


def plot_metric_with_ci(ax, values: List[np.ndarray], color, label: str) -> Tuple[np.ndarray, np.ndarray]:
    """Plot the across-seed mean of ``values`` with a shaded ±1 std band.

    Despite the name, the band is mean ± one standard deviation across seeds,
    not a statistical confidence interval.

    Args:
        ax: Matplotlib axes to draw on.
        values: Equal-length 1D arrays, one per seed. They are stacked into a
            2D array, so the lengths must match.
        color: Color used for both the mean line and the band.
        label: Legend label attached to the mean line.

    Returns:
        ``(mean, std)`` along the seed axis.
    """
    stacked = np.array(values)
    mean_curve, spread = stacked.mean(axis=0), stacked.std(axis=0)
    ax.plot(mean_curve, 'o-', mfc='white', markersize=9, c=color, lw=3, label=label)
    steps = np.arange(mean_curve.shape[0])
    ax.fill_between(steps, mean_curve - spread, mean_curve + spread, color=color, alpha=0.2)
    return mean_curve, spread


def _fill_base_name(base: str, dist: str, vd: float, steps: int, layers: int) -> str:
    """Substitute the DIST / VD / STEPS / LAYERS placeholders in a base file name.

    The replacements run in the order listed, and each one replaces every
    occurrence of its placeholder.
    """
    return (
        base.replace('_DIST_', f'_{dist}_')
        .replace('_vd_VD', f'_vd_{vd}')
        .replace('_steps_STEPS_', f'_steps_{steps}_')
        .replace('_lays_LAYERS_', f'_lays_{layers}_')
    )


def _plot_metrics(axes, real_vals, test_vals, colors) -> None:
    """Draw test (mean ± std) and true (dashed mean, gray ± std band) curves per metric.

    ``axes``, ``colors`` and the outer lists of ``real_vals`` / ``test_vals`` are
    indexed by metric in the order MMD, Wass, Vendi. A metric with no loaded
    seeds keeps its axis label but gets no curves.
    """
    ylabels = (r'MMD Dist.', r'Wass Dist.', 'Vendi Score')
    for i, (ax, ylabel, color) in enumerate(zip(axes, ylabels, colors)):
        ax.set_ylabel(ylabel)
        if i == 2:
            ax.set_yscale('log')  # Vendi score is shown on a log axis
        if not real_vals[i]:
            continue

        # Test-incremental curve
        avg_test, _ = plot_metric_with_ci(ax, test_vals[i], color, r'Test-incremental')

        # True curve: dashed mean with a gray ±1 std band
        data_real = np.array(real_vals[i])
        avg_real = data_real.mean(axis=0)
        std_real = data_real.std(axis=0)
        xs = np.arange(avg_real.shape[0])
        ax.plot(avg_real, '--', mfc='white', markersize=9, c=color, lw=3, label=r'True')
        ax.fill_between(xs, avg_real - std_real, avg_real + std_real, color='gray', alpha=0.2)

        if i == 2:
            # Preserve diagnostic print from original script
            print(f'Vendi Score: real={avg_real[0]}, test={avg_test[0]}')
            print(np.array(test_vals[i]).shape, avg_real.shape)


def _plot_losses(loss_axes, loss_runs) -> None:
    """Plot every seed's total / distance / Vendi loss history on its own axis."""
    colors = (putils.BLUE_m, putils.VERMILLION_m, putils.GREEN_m)
    for ax, runs, color in zip(loss_axes, loss_runs, colors):
        for arr in runs:
            plot_loss_hist(ax, arr, color)


def _save_figure(fig, fig_dir: str, base: str) -> None:
    """Save ``fig`` as PNG, SVG and PDF under ``<fig_dir>/figs_diff/eval_loss_<base>``."""
    fig_folder = os.path.join(fig_dir, 'figs_diff')
    os.makedirs(fig_folder, exist_ok=True)
    fig_file = os.path.join(fig_folder, f'eval_loss_{base}')
    for ftype in ('png', 'svg', 'pdf'):
        fig.savefig(f'{fig_file}.{ftype}', bbox_inches='tight', dpi=300)


def main() -> None:
    """Plot per-step distance metrics and per-seed loss histories for one run configuration.

    Loads ``<dir1>/res/<base>_seed_<k>_*.npy`` for ``k`` in ``0 .. --seeds-1``
    via ``collect_data``. Draws a 4x3 grid: three metric panels on the top row,
    then one full-width row for each of the total, distance and Vendi losses.
    The figure is saved (PNG/SVG/PDF under ``<dir1>/figs_diff``) only when
    ``dir1`` exists and at least one data file was loaded. It is then shown.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--dir1', type=str, default='../results/demo_train_multi_cluster/multi_cluster_qubits_4_2_dat_3000_1000')
    parser.add_argument('--base1', type=str, default='multi_cluster_ryzcz_qubits_4_2_steps_STEPS_DIST_lays_LAYERS_in_product_dat_1000_3000_epoch_1001_100_lr_0.001_init_1.0_vd_VD')
    parser.add_argument('--dist', type=str, default='wass')
    parser.add_argument('--vd', type=float, default=0.0)
    parser.add_argument('--steps', type=int, default=10)
    parser.add_argument('--layers', type=int, default=10)
    parser.add_argument('--seeds', type=int, default=50, help='number of seed indices (0..seeds-1) to load')

    # Retained (unused) for backward-compatibility of CLI
    parser.add_argument('--ymin', type=float, default=1e-1)
    parser.add_argument('--ymax', type=float, default=1e-1)
    args = parser.parse_args()

    dir1 = args.dir1
    base1 = _fill_base_name(args.base1, args.dist, args.vd, args.steps, args.layers)

    # Style
    putils.setPlot(fontsize=26, labelsize=24, lw=3)

    # Layout: 3 metrics on top row, then 3 full-width rows for the different losses
    fig = plt.figure(figsize=(24, 32))
    gs = gridspec.GridSpec(4, 3)
    metric_axes = [fig.add_subplot(gs[0, col]) for col in range(3)]
    loss_axes = [fig.add_subplot(gs[row, :]) for row in range(1, 4)]

    colors = [putils.BLUE_m, putils.RED_m, putils.GREEN_m]

    # Load all data once
    real_vals, test_vals, total_losses, dist_losses, vendi_losses = collect_data(dir1, base1, seeds=args.seeds)

    _plot_metrics(metric_axes, real_vals, test_vals, colors)
    _plot_losses(loss_axes, (total_losses, dist_losses, vendi_losses))

    # Axes cosmetics
    putils.set_axes_tick1(metric_axes, xlabel='Step $k$', legend=True, tick_minor=True, top_right_spine=True, w=3, tick_length_unit=5)
    putils.set_axes_tick1(loss_axes, xlabel=r'$\rm Epochs$', legend=False, tick_minor=True, top_right_spine=True, w=3, tick_length_unit=5)
    for ax, ylabel in zip(loss_axes, ('Total loss', 'Distance loss', 'Vendi loss')):
        ax.set_ylabel(ylabel, fontsize=30)
    loss_axes[2].set_yscale('log')
    for ax in metric_axes + loss_axes:
        # Loss curves carry no labels; calling legend() there only emits
        # "No artists with labels" warnings and empty legend boxes.
        if ax.get_legend_handles_labels()[0]:
            ax.legend(loc='best', fontsize=20, framealpha=0.8)

    plt.tight_layout()

    # The old check `len(test_vals) > 0` was always True, because test_vals
    # holds three (possibly empty) lists. Save only when something was loaded.
    has_data = any(test_vals) or bool(total_losses or dist_losses or vendi_losses)
    if os.path.exists(dir1) and has_data:
        _save_figure(fig, dir1, base1)

    plt.show()
    plt.clf()
    plt.close()


# Run main() only when executed directly, not when this module is imported.
if __name__ == "__main__":
    main()