from tqdm import tqdm

import numpy as np
import matplotlib.pyplot as plt
from IPython.core.display_functions import clear_output
import time

from theory.theory import IGNORED_VALUE
from theory.simplified_linear_mamba import SimplifiedLinearMamba, set_simplified_linear_model_ideal_weights
from mqar.train import evaluate_split
from utils.common import set_seed


def evaluate_ideal_model_accuracy(V, D, N, dataloader, run_config, model_variant='mamba_tiny'):
    """Build a model with hand-set "ideal" weights and return its accuracy on ``dataloader``.

    Args:
        V, D, N: Size arguments forwarded to the model constructor
            (for ``'mamba_linear'``: ``SimplifiedLinearMamba(V=V, D=D, N=N, device=...)``).
        dataloader: Evaluation data, passed straight to ``evaluate_split``.
        run_config: Run configuration. ``run_config['runtime']['device']`` selects the
            device the model is built on; the whole dict is forwarded to ``evaluate_split``.
        model_variant: Which model to build. Only ``'mamba_linear'`` is implemented;
            the default ``'mamba_tiny'`` is not.

    Returns:
        The ``'accuracy'`` entry of the results dict returned by ``evaluate_split``.

    Raises:
        NotImplementedError: If ``model_variant == 'mamba_tiny'``.
        ValueError: If ``model_variant`` is any other unknown value.
    """
    device = run_config['runtime']['device']

    # prepare model
    if model_variant == 'mamba_linear':
        model = SimplifiedLinearMamba(V=V, D=D, N=N, device=device)
        set_simplified_linear_model_ideal_weights(model=model)

    elif model_variant == 'mamba_tiny':
        # Still the default for compatibility, but not implemented: tell the caller
        # what to pass instead of raising an empty NotImplementedError.
        raise NotImplementedError(
            f"{model_variant=} is not implemented yet; use model_variant='mamba_linear'"
        )

    else:
        raise ValueError(f'Invalid {model_variant=}')

    # evaluate
    results, _ = evaluate_split(model, dataloader, run_config)
    return results['accuracy']


def plot_accuracy_grid(D_axis, N_axis, accuracy_grid, threshold_accuracy: float, metadata_title='', cmap='inferno'):
    """Show two side-by-side heatmaps of an accuracy grid: raw accuracy and accuracy >= threshold.

    Args:
        D_axis: Coordinates along the x-axis (labelled ``D``).
        N_axis: Coordinates along the y-axis (labelled ``N``).
        accuracy_grid: 2-D array indexed ``[D, N]``; it is transposed before plotting so
            that D runs along x and N along y. Cells equal to ``IGNORED_VALUE`` are masked
            and drawn blank in both panels.
        threshold_accuracy: Threshold for the right-hand (binary) panel.
        metadata_title: Extra text shown as a second title line on both panels.
        cmap: Matplotlib colormap used for both panels (colour range fixed to [0, 1]).
    """
    x_label, y_label = 'D', 'N'

    fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(10, 4))

    # Mask "not applicable" cells so they are drawn blank rather than as a colour.
    masked_accuracy_grid = np.ma.masked_equal(accuracy_grid, IGNORED_VALUE)
    panels = (
        (masked_accuracy_grid.T, f"Accuracy\n{metadata_title}"),
        # Comparing a masked array keeps its mask, so ignored cells stay blank here too.
        (masked_accuracy_grid.T >= threshold_accuracy, f"Accuracy >= {threshold_accuracy}\n{metadata_title}"),
    )
    for ax, (values, title) in zip(axes, panels):
        mesh = ax.pcolormesh(D_axis, N_axis, values, vmin=0, vmax=1, cmap=cmap)
        fig.colorbar(mesh, ax=ax)
        ax.set_title(title)
        ax.set_xlabel(x_label)
        ax.set_ylabel(y_label)

    plt.show()
    # Release the figure explicitly: this is called repeatedly (once per evaluated grid
    # cell), and plt.show() does not close figures under every backend (e.g. Agg), so
    # they would otherwise accumulate in memory.
    plt.close(fig)


def evaluate_and_plot_ideal_model_accuracy_grid(
        V, D_axis, N_axis,
        threshold_accuracy,
        dataloader,
        run_config,
        n_seeds=1,
        metadata_title='',
        cmap='inferno',
        model_variant='mamba_tiny',
):
    """Evaluate the ideal-weight model at every (D, N) pair and live-plot the growing grid.

    Cells with ``N > D`` are not evaluated and are set to ``IGNORED_VALUE``. Every other
    cell holds the best accuracy over ``n_seeds`` runs of ``evaluate_ideal_model_accuracy``,
    each preceded by ``set_seed(seed)`` for ``seed`` in ``range(n_seeds)``. After each
    evaluated cell, the previous notebook output is cleared and the grid is re-plotted.

    Args:
        V: Forwarded to ``evaluate_ideal_model_accuracy``.
        D_axis: Values of D to sweep (grid rows, plotted on x).
        N_axis: Values of N to sweep (grid columns, plotted on y).
        threshold_accuracy: Threshold for the binary panel of the plot.
        dataloader: Evaluation data, forwarded to ``evaluate_ideal_model_accuracy``.
        run_config: Run configuration, forwarded to ``evaluate_ideal_model_accuracy``.
        n_seeds: Number of seeds to try per cell; the maximum accuracy is kept. Must be >= 1.
        metadata_title: Extra title line for the plots.
        cmap: Matplotlib colormap for the plots.
        model_variant: Forwarded to ``evaluate_ideal_model_accuracy``. The default
            ``'mamba_tiny'`` matches that function's default, which is not implemented;
            pass ``'mamba_linear'`` to actually run the sweep.

    Returns:
        Array of shape ``(len(D_axis), len(N_axis))`` indexed ``[D, N]``.

    Raises:
        ValueError: If ``n_seeds < 1``.
        NotImplementedError / ValueError: Propagated from ``evaluate_ideal_model_accuracy``
            for an unsupported ``model_variant``.
    """
    if n_seeds < 1:
        raise ValueError(f'{n_seeds=} must be >= 1')

    accuracy_grid = np.zeros((len(D_axis), len(N_axis)))

    # Mark every N > D cell as not applicable up front, so every intermediate plot
    # (including the last one) shows them masked instead of as 0 accuracy.
    is_ignored = np.asarray(N_axis)[np.newaxis, :] > np.asarray(D_axis)[:, np.newaxis]
    accuracy_grid[is_ignored] = IGNORED_VALUE

    for i, D in enumerate(D_axis):
        for j, N in enumerate(N_axis):
            if is_ignored[i, j]:
                continue

            # best of n seeds
            accuracy_per_seed = np.empty(n_seeds)
            for seed in range(n_seeds):
                set_seed(seed)
                accuracy_per_seed[seed] = evaluate_ideal_model_accuracy(
                    V, D, N, dataloader, run_config, model_variant=model_variant,
                )
            accuracy_grid[i, j] = accuracy_per_seed.max()

            # Redraw the progress plot. Clearing *before* plotting with wait=True defers
            # the clear until the new figure arrives (no flicker) and, unlike clearing
            # after plotting, leaves the final plot visible when the sweep ends.
            clear_output(wait=True)
            plot_accuracy_grid(D_axis, N_axis, accuracy_grid, threshold_accuracy, metadata_title, cmap)

    return accuracy_grid