import matplotlib
import numpy as np
import torch
from IPython.core.display_functions import display
from matplotlib import pyplot as plt

from config import get_model_from_config
from experiments.grid_runs_plot import results_dir
from mamba_tiny.scans import compute_attention, compute_recurrence
from mqar_zoology.associative_recall import IGNORED_TOKEN
from scripts.scaling_laws_utils import set_quarter_ticks_from_arrays
from theory.simplified_linear_mamba import extract_model_weights, compute_effective_embeddings, compute_projections, \
    compute_Gram_matrices
from utils import recorder
from utils.common import print_model_parameter_count
from utils.recorder import collect_model_recordings

import matplotlib.ticker as mticker


# --- Global, paper-friendly figure styling (adjust to taste) ---

# Resolution applied both on screen and when saving (100 is a lighter alternative).
DPI = 300

# 'text.usetex' is intentionally not overridden here.

# Resolution
matplotlib.rcParams["figure.dpi"] = DPI
matplotlib.rcParams["savefig.dpi"] = DPI

# Font sizes
matplotlib.rcParams["font.size"] = 14  # base font size
matplotlib.rcParams["axes.labelsize"] = 25  # x/y labels (16 is a smaller alternative)
matplotlib.rcParams["axes.titlesize"] = 18  # title
matplotlib.rcParams["xtick.labelsize"] = 25
matplotlib.rcParams["ytick.labelsize"] = 25
matplotlib.rcParams["legend.fontsize"] = 14


def load_model_from_grid_run(
        run_name, V, D, N,
        model_class, model_variant,
        verbose=False,
        load_best_seed=False,
        device='cpu',
):
    """
    Build a model and load the weights that a grid run saved for it.

    Checkpoints are searched in ``results_dir / run_name / 'best_models'``. Every ``*.pt``
    file whose name starts with ``D_<D>_N_<N>`` (both zero-padded to 4 characters) is a
    candidate; whatever follows that prefix in the file name (e.g. the seed) is ignored.
    When several candidates exist, the alphabetically first one is loaded so that the
    choice does not depend on directory listing order.

    Args:
        run_name: Name of the grid-run directory below ``results_dir``.
        V, D, N: Forwarded to ``get_model_from_config``; ``D`` and ``N`` also form the
            checkpoint file-name prefix.
        model_class, model_variant: Forwarded to ``get_model_from_config``.
        verbose: If true, display the model and print its parameter count.
        load_best_seed: Not implemented; passing a truthy value raises.
        device: Device the model is moved to and the checkpoint is mapped to.

    Returns:
        The model with the checkpoint's state dict loaded.

    Raises:
        FileNotFoundError: If no checkpoint matches the prefix.
        NotImplementedError: If ``load_best_seed`` is truthy.
    """

    # find model (ignores seed)
    best_models_path = results_dir / run_name / 'best_models'
    prefix_for_loaded_model = f"D_{str(D).zfill(4)}_N_{str(N).zfill(4)}"

    # Directory entries come back in arbitrary order; sort for a reproducible pick.
    matches = sorted(
        path.name
        for path in best_models_path.glob("*.pt")
        if path.name.startswith(prefix_for_loaded_model)
    )
    if not matches:
        raise FileNotFoundError(
            f"no checkpoint named '{prefix_for_loaded_model}*.pt' in {best_models_path}"
        )

    if load_best_seed:
        raise NotImplementedError("load_best_seed=True is not supported yet")

    # TODO: also support loading a model trained from here instead of grid results
    saved_model_path = best_models_path / matches[0]

    # model
    model = get_model_from_config(V=V, D=D, N=N, model_class=model_class, model_variant=model_variant)
    model.to(device)

    if verbose:
        display(model)
        print_model_parameter_count(model)

    checkpoint = torch.load(saved_model_path, map_location=device)
    model.load_state_dict(checkpoint)

    return model


def run_model_on_single_sequence(model, dataloaders, ssm_mode='attention', split='val', device='cpu', print_ids=False):
    """
    Run the model on the first sequence of the first batch of a split and collect recordings.

    The model is put in eval mode, moved to ``device``, and the ``ssm_mode`` entry of the
    first layer's mixer config is overwritten. The recorder is emptied and enabled before
    the forward pass. Accuracy over the positions whose target differs from
    ``IGNORED_TOKEN`` is printed.

    Args:
        model: Model exposing ``layers[0].mixer.config`` and returning an object with ``.logits``.
        dataloaders: Mapping from split name to an iterable of batches with ``x_ids`` and
            ``y_true_ids`` attributes.
        ssm_mode: One of ``'recurrent'``, ``'attention'``, ``'selective_scan'``.
        split: Key into ``dataloaders``.
        device: Device for the model and the inputs.
        print_ids: If true, also print the input, predicted and target ids.

    Returns:
        ``(x_ids, y_true_ids, y_pred_ids, records)``; the id tensors keep a leading batch
        axis of size 1, and ``records`` is the first value returned by
        ``collect_model_recordings``.

    Raises:
        AssertionError: If ``ssm_mode`` is not one of the supported modes.
        ValueError: If the chosen split yields no batches.
    """

    # Fail before touching the model or recorder when the mode is invalid.
    assert ssm_mode in ['recurrent', 'attention', 'selective_scan'], f"unsupported ssm_mode: {ssm_mode!r}"

    # Only the first batch is needed.
    batch = next(iter(dataloaders[split]), None)
    if batch is None:
        raise ValueError(f"dataloader for split '{split}' yielded no batches")

    # Keep a batch axis of size 1 around the selected sequence.
    x_ids = batch.x_ids[0].unsqueeze(0).to(device)
    y_true_ids = batch.y_true_ids[0].unsqueeze(0).to(device)

    model.eval()
    model.to(device)

    model.layers[0].mixer.config['ssm_mode'] = ssm_mode
    print(f"setting SSM mode to '{ssm_mode}'\n")

    # Start from a clean recorder so the records belong to this forward pass only.
    recorder.empty()
    recorder.enable_recording()

    # predict
    y_pred_logits = model(x_ids).logits  # (B, L, V)

    # collect predictions, scoring only positions that are not ignored
    y_pred_ids = y_pred_logits.argmax(dim=2)
    y_correct = y_pred_ids.eq(y_true_ids)
    y_correct = y_correct.masked_select(y_true_ids.ne(IGNORED_TOKEN)).detach().cpu()

    y_response_ids = y_pred_ids.squeeze().detach().cpu().numpy()

    total_correct = int(y_correct.sum())

    accuracy = float(y_correct.to(float).mean())

    if print_ids:
        print(f"\n{x_ids = }")
        print("\ny_response_ids = \n", y_response_ids, "\n")
        print("\ny_response_gt = \n", y_true_ids, "\n")

    print(f"\n{accuracy = }")
    print(f"{total_correct = } / {len(y_correct)}")

    # collect records (the ordering returned alongside is not needed here)
    records, _ = collect_model_recordings(model, verbose=False)

    return x_ids, y_true_ids, y_pred_ids, records


def normalize(mat: np.ndarray) -> np.ndarray:
    """
    Rescale `mat` so that its entry of largest magnitude maps to +1.

    When the most negative entry outweighs the most positive one in magnitude, the
    matrix is negated before scaling, so all results lie in [-1, 1].
    An all-zero matrix is returned as is. The input is never modified.
    """
    lowest, highest = mat.min(), mat.max()

    # Negative side dominates: negate, which turns -lowest into the new maximum.
    negative_dominates = abs(lowest) > abs(highest)
    oriented = -mat if negative_dominates else mat
    scale = -lowest if negative_dominates else highest

    # Nothing to scale by (all zeros) -> avoid dividing by zero.
    if scale == 0:
        return oriented

    return oriented / scale


def compute_ideal_x_B_C(x_ids, V):
    """
    Build the idealised SSM input and B/C sequences from one-hot token encodings.

    ``x_ids[0]`` is taken as the token sequence (length L). Each step is described by the
    one-hot vector of the previous token stacked on the one-hot vector of the current
    token; the first step has no predecessor, so its "previous" half is all zeros.

    Returns:
        ``(x_ssm_ideal, B_ideal, C_ideal)`` with a leading axis of size 1:
        ``x_ssm_ideal`` is (1, L, 2V) holding [prev, curr] per step, ``B_ideal`` is
        (1, L, V) holding the previous-token half and ``C_ideal`` is (1, L, V) holding the
        current-token half.
    """
    identity = torch.eye(V)
    zeros = torch.zeros((V, V))

    # Selectors over the stacked [prev; curr] vector.
    select_prev = torch.cat([identity, zeros], dim=1)  # (V, 2V)
    select_curr = torch.cat([zeros, identity], dim=1)  # (V, 2V)

    # One-hot columns, one per sequence position.
    one_hot_curr = torch.eye(V)[:, x_ids[0]]

    # The same columns delayed by one position; column 0 stays zero.
    one_hot_prev = torch.zeros_like(one_hot_curr)
    one_hot_prev[:, 1:] = one_hot_curr[:, :-1]

    stacked = torch.cat([one_hot_prev, one_hot_curr], dim=0)  # (2V, L)

    def with_batch_axis(columns):
        # (features, L) -> (1, L, features)
        return columns.T.unsqueeze(0)

    x_ssm_ideal = with_batch_axis(stacked)
    B_ideal = with_batch_axis(select_prev @ stacked)
    C_ideal = with_batch_axis(select_curr @ stacked)

    return x_ssm_ideal, B_ideal, C_ideal


def compute_ideal_attention_map(x_ids, V):
    """
    Return the attention map produced by the idealised one-hot x/B/C sequences.

    The inputs come from ``compute_ideal_x_B_C`` and are passed to ``compute_attention``
    with ``A=None``; only the first element of the returned attention (index 0 along its
    leading axis) is kept, the second output of ``compute_attention`` is discarded.
    """
    x_ssm_ideal, B_ideal, C_ideal = compute_ideal_x_B_C(x_ids, V)

    attention, _ = compute_attention(x=x_ssm_ideal, A=None, B=B_ideal, C=C_ideal)

    return attention[0]


def compute_attention_maps(x_ids, records, V):
    """
    Return the idealised attention map together with the one recorded from the model.

    The recorded map is read from ``records['model.selective_scan__attention_alpha']``
    (first element only) and converted to a NumPy array on the CPU.
    """
    ideal_map = compute_ideal_attention_map(x_ids, V)

    recorded = records['model.selective_scan__attention_alpha']
    recorded_map = recorded[0].cpu().numpy()

    return ideal_map, recorded_map

def compute_ideal_hidden_state(x_ids, V):
    """
    Return the hidden states produced by the idealised one-hot x/B/C sequences.

    The inputs come from ``compute_ideal_x_B_C`` and are passed to ``compute_recurrence``
    with ``A=None``; only the first element of the returned hidden states (index 0 along
    the leading axis) is kept, the second output of ``compute_recurrence`` is discarded.
    """
    x_ssm_ideal, B_ideal, C_ideal = compute_ideal_x_B_C(x_ids, V)

    hidden, _ = compute_recurrence(x=x_ssm_ideal, A=None, B=B_ideal, C=C_ideal)

    return hidden[0]


def compute_hidden_states(x_ids, model, model_class, records, V, D, N_facts):
    """
    Extract ideal, recorded and projected hidden-state matrices at index ``2 * N_facts``.

    The ideal state comes from ``compute_ideal_hidden_state``; the recorded one is read
    from ``records["model.selective_scan__h_ssm"]`` (first element). Both are indexed at
    ``int(N_facts * 2)`` — presumably the step right after the context part; confirm
    against the data layout. The recorded state is additionally mapped through the
    model's ``Pi_v_out`` and ``Pi_q_in`` projections.

    Returns:
        ``(H_ideal, H_noisy, H_noisy_projected)``: the ideal state with its first ``V``
        columns dropped and transposed, the recorded state with its first ``D`` columns
        dropped and transposed, and the projected state with its first ``V`` columns dropped.
    """
    context_len = int(N_facts * 2)  # length of context part

    # Only the query-input and value-output projections are used here.
    Pi_q_in, _, _, Pi_v_out = compute_model_projection_matrices(model, model_class)

    ideal_state = compute_ideal_hidden_state(x_ids, V)[context_len]

    recorded_states = records["model.selective_scan__h_ssm"][0].cpu().numpy()
    noisy_state = recorded_states[context_len]
    projected_state = Pi_v_out @ noisy_state.T @ Pi_q_in.numpy()

    return ideal_state[:, V:].T, noisy_state[:, D:].T, projected_state[:, V:]


def pow2_ticks(vmax):
    """
    Return the powers of two ``[1, 2, 4, ...]`` that do not exceed ``vmax``.

    The exponent is derived with exact integer arithmetic, so large values cannot be
    pushed past the next power of two by floating-point rounding.

    Args:
        vmax: Upper bound for the ticks (int or float).

    Returns:
        A list of ints in increasing order; empty when ``vmax`` is below 1
        (including zero and negative values).
    """
    # No power of two with a non-negative exponent fits below 1.
    if vmax < 1:
        return []

    # For vmax >= 1, floor(log2(vmax)) equals bit_length(floor(vmax)) - 1.
    nmax = int(vmax).bit_length() - 1
    return [2 ** n for n in range(nmax + 1)]


def plot_hidden_states(H_ideal, H_noisy, H_noisy_projected, V, cmap='inferno', vmin=0, vmax=1, no_title=True):
    """
    Draw each of the three hidden-state matrices as a heat map in its own figure.

    A matrix of shape ``(V, V)`` is cropped to x in ``[0, V // 2]`` and y in
    ``[V // 2, V]``; any other shape is shown in full. Major ticks are placed at
    quarters of the visible range. Titles are drawn only when ``no_title`` is false.
    Every figure is shown immediately with ``plt.show()``.
    """
    add_dim_label = False  # switch on to append $[V_k]$ / $[V_v]$ to the axis labels

    # (matrix, x label, y label, title) per figure
    panels = (
        (H_ideal, r"$k$", r"$v$", r"$H$"),
        (H_noisy, r"$k'$", r"$v'$", r"$H'$"),
        (H_noisy_projected, r"$k''$", r"$v''$", r"$H''$"),
    )

    for M, x_label, y_label, title in panels:
        shape = np.array(M.shape)
        n_rows, n_cols = shape[0], shape[1]

        fig, ax = plt.subplots(figsize=(6, 6))  # bigger figure

        ax.imshow(
            M,
            cmap=cmap,
            interpolation="none",
            vmin=vmin,
            vmax=vmax,
            origin="lower",
            extent=(0, n_cols, 0, n_rows),
            aspect=1,
        )

        if M.shape == (V, V):
            # Show only the key columns / value rows; the visible range is halved.
            V_k = V // 2
            x_range, y_range = (0, V_k), (V_k, V)
            visible_rows, visible_cols = n_rows // 2, n_cols // 2
        else:
            x_range, y_range = (0, n_cols), (0, n_rows)
            visible_rows, visible_cols = n_rows, n_cols

        ax.set_xlim(*x_range)  # key
        ax.set_ylim(*y_range)  # value

        if add_dim_label:
            x_label = f"{x_label}\n$[V_k]$"
            y_label = f"{y_label}\n$[V_v]$"

        ax.set_xlabel(x_label, labelpad=8)
        ax.set_ylabel(y_label, rotation=0, labelpad=16)

        if not no_title:
            ax.set_title(title, pad=10)

        # Sharper ticks and frame for print
        ax.tick_params(axis="both", which="major", length=4, width=1.2)

        # Major ticks every quarter of the visible range.
        ax.xaxis.set_major_locator(mticker.MultipleLocator(visible_cols / 4))
        ax.yaxis.set_major_locator(mticker.MultipleLocator(visible_rows / 4))

        for spine in ax.spines.values():
            spine.set_linewidth(1.2)

        fig.tight_layout()
        plt.show()


def plot_attention_map(alpha, no_title=True, no_cbar=True, cmap='inferno', vmin=0, vmax=1):
    """
    Plot an attention map as a colour mesh with equal axis scaling.

    ``alpha`` is rescaled with ``normalize`` and drawn transposed, so its first axis runs
    along x (labelled t) and its second along y (labelled tau). Both axes use the index
    range of ``alpha``'s first dimension. A colour bar and a title are added only when
    ``no_cbar`` / ``no_title`` are false.
    """
    steps = np.arange(0, alpha.shape[0])
    scaled = normalize(alpha)

    fig, ax = plt.subplots(figsize=(6, 5))

    mesh = ax.pcolormesh(
        steps, steps,
        scaled.T,
        cmap=cmap,
        vmin=vmin, vmax=vmax,
    )

    show_cbar = not no_cbar
    if show_cbar:
        plt.colorbar(mesh, ax=ax)

    show_title = not no_title
    if show_title:
        ax.set_title(r"SSM Attention $\alpha$")

    # axis labels
    ax.set_xlabel(r'$t$')
    ax.set_ylabel(r'$\tau$')

    ax.set_aspect('equal')

    fig.tight_layout()
    plt.show()


def compute_model_projection_matrices(model, model_class):
    """
    Derive the four projection matrices of a trained model.

    Chains ``extract_model_weights`` -> ``compute_effective_embeddings`` ->
    ``compute_projections`` and returns ``(Pi_q_in, Pi_k_in, Pi_v_in, Pi_v_out)``.
    """
    # trained weights, without printing their shapes
    weights = extract_model_weights(
        model=model,
        model_class=model_class,
        print_output_shapes=False,
    )
    E_in, P_in, W, S_B, S_C, P_out, E_out = weights

    # effective embeddings
    effective = compute_effective_embeddings(E_in, P_in, W, S_B, S_C, P_out, E_out)
    E_hat_in, E_hat_out, E_tilde_q_in, E_tilde_k_in = effective

    # projections
    Pi_q_in, Pi_k_in, Pi_v_in, Pi_v_out = compute_projections(
        E_hat_in, E_hat_out, E_tilde_q_in, E_tilde_k_in, P_out, E_out)

    return Pi_q_in, Pi_k_in, Pi_v_in, Pi_v_out


def compute_model_Gram_matrices(model, model_class):
    """
    Compute the model's Gram matrices and two slices of them.

    Returns:
        ``(G_kq, G_vv, G_E_tilde, G_E)`` where ``V`` is half of ``G_kq``'s second
        dimension, ``G_E_tilde = G_kq[:V, V:]`` and ``G_E = G_vv[:V, :]``.
    """
    projections = compute_model_projection_matrices(model=model, model_class=model_class)
    Pi_q_in, Pi_k_in, Pi_v_in, Pi_v_out = projections

    G_kq, G_vv = compute_Gram_matrices(Pi_q_in, Pi_k_in, Pi_v_in, Pi_v_out)

    # Half of G_kq's column count.
    half_width = G_kq.shape[1] // 2

    prev_curr_block = G_kq[:half_width, half_width:]  # prev→curr block

    # NOTE(review): described upstream as the curr→curr block, yet every column is kept
    # here — confirm whether a column slice was intended.
    leading_rows = G_vv[:half_width, :]

    return G_kq, G_vv, prev_curr_block, leading_rows

def plot_embedding_Gram_matrices(model, model_class, V):
    """
    Compute the model's Gram matrices, plot them and their statistics, and return them.

    NOTE(review): ``plot_Gram_matrices`` appears to multiply its matrices by an estimated
    sign in place (for CPU tensors), so the statistics plot and the returned matrices
    may be sign-flipped relative to what was computed — confirm this is intended.

    Returns:
        ``(G_kq, G_vv)``.
    """
    # compute (the sliced blocks are not needed here)
    G_kq, G_vv, _, _ = compute_model_Gram_matrices(model, model_class)

    # plot: full matrices first, then block statistics
    plot_Gram_matrices(G_kq, G_vv, V)
    plot_Gram_stats(G_kq, G_vv, V)

    return G_kq, G_vv


def plot_Gram_matrices(G_kq, G_vv, V, cmap='inferno', cp_line_color='w', kv_line_color='gray', vmin=0, vmax=None, figsize=(12, 12), no_title=False):
    """
    Show ``G_kq`` and ``G_vv`` as heat maps, each in its own figure.

    Each matrix is multiplied by a sign estimated from the mean of part of a block
    diagonal and is drawn transposed with the origin at the top-left. Lines mark the
    ``V`` boundary (``cp_line_color``) and, unless ``kv_line_color`` is ``None``, the
    ``0.5 * V`` and ``1.5 * V`` boundaries.

    Args:
        G_kq: Torch tensor, expected to be (2V, 2V).
        G_vv: Torch tensor, expected to be (V, 2V).
        V: Block size used for slicing, extents, guide lines and tick positions.
        cmap, vmin, vmax: Forwarded to ``imshow``.
        cp_line_color: Colour of the lines at ``V``.
        kv_line_color: Colour of the lines at ``0.5 * V`` and ``1.5 * V``; ``None`` disables them.
        figsize: Size of each figure.
        no_title: If true, the matrix name is not drawn as a title.

    NOTE(review): the sign is applied in place on the NumPy view of each matrix. For CPU
    tensors that view shares memory with the caller's tensor, so the caller's data is
    modified as well. This is kept as is because later plots may rely on it — confirm.
    """

    # estimate signs
    G_E_tilde = G_kq[:V, V:]  # prev→curr block
    G_E = G_vv[:V, :]  # all columns kept; torch.diag below only reads the leading square
    half_V = V // 2

    sign_G_kq = np.sign(float(torch.diag(G_E_tilde)[:half_V].mean()))
    sign_G_vv = np.sign(float(torch.diag(G_E)[half_V:].mean()))

    # Explicit (name, matrix, sign) triples instead of looking matrices up by name.
    panels = (
        ('G_kq', G_kq, sign_G_kq),
        ('G_vv', G_vv, sign_G_vv),
    )

    for name, G, sign in panels:

        plt.figure(figsize=figsize)

        if isinstance(G, torch.Tensor):
            G = G.cpu().numpy()

        # In-place on purpose (see NOTE(review) in the docstring).
        G *= sign

        ax = plt.gca()
        ax.imshow(
            G.T,
            interpolation='none',
            cmap=cmap,
            origin='upper',  # (0,0) at top-left
            extent=(0, (2*V if name == 'G_kq' else V), 2*V, 0),
            vmin=vmin,
            vmax=vmax,
            aspect=1,
        )
        if not no_title:
            ax.set_title(name)

        if kv_line_color is not None:
            ax.axvline(0.5 * V, color=kv_line_color, linestyle="-", alpha=0.5)
            ax.axhline(0.5 * V, color=kv_line_color, linestyle="-", alpha=0.5)
            ax.axvline(1.5 * V, color=kv_line_color, linestyle="-", alpha=0.5)
            ax.axhline(1.5 * V, color=kv_line_color, linestyle="-", alpha=0.5)

        ax.axvline(V, color=cp_line_color, linestyle="-", alpha=0.75)
        ax.axhline(V, color=cp_line_color, linestyle="-", alpha=0.75)

        # One tick label per half, centred on that half.
        if name == 'G_vv':
            ax.set_xlim([0, V])
            ax.set_aspect(1)
            ax.set_xticks([V / 2])
            ax.set_yticks([V / 2, 3 * V / 2])
            ax.set_xticklabels([r'$t$'])
            ax.set_yticklabels([r'${\tau-1}$', r'$\tau$'])

        elif name == 'G_kq':
            ax.set_xlim([0, 2 * V])
            ax.set_xticks([V/2, 3*V/2])
            ax.set_yticks([V / 2, 3 * V / 2])
            ax.set_xticklabels([r'${\tau-1}$', r'$\tau$'])
            ax.set_yticklabels([r'${t-1}$', r'$t$'])

        ax.tick_params(axis='y', labelrotation=90)

        # Ticks and labels on the bottom/left edges only.
        ax.tick_params(axis='x', bottom=True, top=False, labelbottom=True)
        ax.tick_params(axis='y', left=True, right=False, labelleft=True)

        plt.tight_layout()
        plt.show()


def plot_Gram_stats(G_kq, G_vv, V, cmap='inferno', kv_line_color='gray', figsize=(12, 5.5), share_y=True):
    """
    Plot two Gram-matrix blocks as heat maps and their diagonals as bar charts.

    The blocks are ``G_vv[:, :V]`` (titled ``$G_E$``) and ``G_kq[:V, V:]`` (titled
    ``$G_{\\tilde{E}}$``). One figure holds both heat maps, a second one both bar charts.
    For each block the mean of the first and second half of the diagonal is printed.
    """
    # blocks to inspect
    G_E_tilde = G_kq[:V, V:]  # prev-curr block
    G_E = G_vv[:, :V]  # NOTE(review): labelled curr-curr upstream — confirm the column range

    half_V = V // 2

    panels = (
        (r'$G_E$', G_E),
        (r'$G_{\tilde{E}}$', G_E_tilde),
    )

    n_panels = len(panels)  # always 2
    fig_0, axes_0 = plt.subplots(ncols=n_panels, nrows=1, figsize=figsize, sharey=share_y)
    fig_1, axes_1 = plt.subplots(ncols=n_panels, nrows=1, figsize=figsize, sharey=share_y)

    for heat_ax, bar_ax, (M_name, M) in zip(axes_0, axes_1, panels):

        diagonal = torch.diag(M)
        print(f"{M_name} diag mean k tokens = ", float(diagonal[:half_V].mean()))
        print(f"{M_name} diag mean v tokens = ", float(diagonal[half_V:].mean()))

        if isinstance(M, torch.Tensor):
            M = M.cpu().numpy()

        title = M_name

        x_axis = np.arange(V + 1)
        y_axis = np.arange(V + 1)

        # heat map of the block, with lines at the half-way index
        heat_ax.imshow(M, interpolation='none', cmap=cmap, vmin=0)
        heat_ax.axvline(0.5 * V, color=kv_line_color, linestyle="-", alpha=0.5)
        heat_ax.axhline(0.5 * V, color=kv_line_color, linestyle="-", alpha=0.5)
        heat_ax.set_title(title)
        heat_ax.set_aspect('auto')
        set_quarter_ticks_from_arrays(heat_ax, x_axis, y_axis)  # temp

        # bar chart of the block diagonal, with a red line at the half-way index
        bar_ax.bar(range(V), np.diag(M))
        bar_ax.axvline(0.5 * V, color='r', linestyle="-", alpha=1)
        bar_ax.set_title(title + "\nDiagonal Elements")
        bar_ax.set_aspect('auto')
        set_quarter_ticks_from_arrays(bar_ax, x_axis, y_axis)  # temp

    plt.show()


def plot_conv_identity_shift(x_ids, model, model_class, V, N_facts):
    """
    Plot the model's Gram matrices evaluated on the one-hot encoding of a token sequence.

    ``x_ids[0]`` is one-hot encoded; every position is described by the previous-token
    one-hot stacked on the current-token one-hot (the first position has a zero
    "previous" half). Two position-by-position maps are shown, each transposed and in
    its own ``plt.show()`` call:

    1. ``x_curr.T @ G_vv @ x_hat``
    2. ``x_hat.T @ G_kq.T @ x_hat``

    Both axes are limited to indices ``0 .. 2 * N_facts - 1``.

    Args:
        x_ids: Token ids; only ``x_ids[0]`` is used.
        model, model_class: Forwarded to ``compute_model_projection_matrices``.
        V: Size of the one-hot encoding.
        N_facts: Sets the displayed index range.
    """
    # The projection helper extracts the model weights itself, so no separate extraction is needed.
    Pi_q_in, Pi_k_in, Pi_v_in, Pi_v_out = compute_model_projection_matrices(model=model, model_class=model_class)
    G_kq, G_vv = compute_Gram_matrices(Pi_q_in, Pi_k_in, Pi_v_in, Pi_v_out)

    # One-hot columns for the current token and, delayed by one position, the previous token.
    E_oh = torch.eye(V)
    x_curr = E_oh[:, x_ids[0]]

    x_prev = torch.zeros_like(x_curr)
    x_prev[:, 1:] = x_curr[:, :-1]

    x_hat = torch.vstack([x_prev, x_curr])

    # Last displayed index — presumably the final position of the context part; confirm.
    max_t = int(N_facts * 2) - 1

    M_arrays = [
        x_curr.T @ G_vv @ x_hat,
        x_hat.T @ G_kq.T @ x_hat,
    ]

    # Half a cell of padding so that the first and last cells are drawn in full.
    lims = np.array([0 - 0.5, max_t + 0.5])

    for M in M_arrays:

        plt.imshow(M.T, interpolation='none', origin='upper', cmap='inferno', vmin=0)
        plt.xlim(lims)
        plt.ylim(lims[::-1])  # reversed to keep index 0 at the top
        plt.show()