import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import BoundaryNorm, ListedColormap, hsv_to_rgb


def _orient_for_plot(arr, name, seq_len):
    """Reduce a 3D array to 2D where possible and put the sequence axis last.

    Returns the (possibly sliced / re-oriented) array together with the
    (possibly annotated) display name. Arrays matching none of the rules are
    returned unchanged.
    """
    # peel off singleton-batch
    if arr.ndim == 3 and arr.shape[0] == 1:
        arr = arr[0]
    # for time-major 3D: keep only the last time step and record it in the name
    if arr.ndim == 3 and arr.shape[0] == seq_len:
        t_idx = seq_len - 1
        arr = arr[t_idx]
        name = f"{name}[t={t_idx}]"
    # shift the first seq_len-sized axis (other than the last) to the end.
    # Arrays whose two leading dims are equal are left alone so that a square
    # seq_len x seq_len matrix is never silently transposed.
    if arr.ndim >= 2 and arr.shape[0] != arr.shape[1]:
        for axis in range(arr.ndim - 1):
            if arr.shape[axis] == seq_len:
                arr = np.moveaxis(arr, axis, -1)
                break
    return arr, name


def _axis_labels(h, w, inner_dims, d_state, seq_len):
    """Pick (ylabel, xlabel) for an h-by-w heatmap from the known sizes."""
    if h in inner_dims and w == d_state:
        return 'Inner dimension', 'State dimension'
    if h == d_state and w in inner_dims:
        return 'State dimension', 'Inner dimension'
    if h in inner_dims and w == seq_len:
        return 'Inner dimension', 'Sequence index'
    if h == d_state and w == seq_len:
        return 'State dimension', 'Sequence index'
    return 'Row index', 'Column index'


def plot_tensors(
        tensors_dict,
        inner_dims, d_state, seq_len,
        share_x_axis=False):
    """Plot every 1D/2D entry of ``tensors_dict`` as a heatmap.

    Parameters
    ----------
    tensors_dict : mapping of name -> tensor-like
        Each value must provide ``.numpy()`` (and ``.squeeze()`` for the
        shared-axis mode) -- presumably torch tensors; entries that cannot be
        converted are skipped.
    inner_dims : container of int
        Sizes that are labelled "Inner dimension" on individual plots.
    d_state : int
        Size that is labelled "State dimension" on individual plots.
    seq_len : int
        Sequence length; used to find the sequence axis and to slice
        time-major 3D arrays at their last step.
    share_x_axis : bool
        When True, all arrays whose last axis has length ``seq_len`` are drawn
        as stacked subplots sharing one x axis. If no array qualifies, the
        function falls back to one figure per entry.

    Returns
    -------
    None. Figures are displayed with ``plt.show()``.
    """
    if share_x_axis:
        # collect the arrays that can share a sequence-length x axis
        shared = []
        for idx, (name, t) in enumerate(tensors_dict.items()):
            try:
                arr = t.squeeze().numpy()
            except Exception:
                # best-effort: silently skip anything that is not convertible,
                # but let KeyboardInterrupt / SystemExit propagate
                continue
            arr, name = _orient_for_plot(arr, name, seq_len)
            # keep only 1D or 2D arrays whose x axis matches seq_len
            if arr.ndim in (1, 2) and arr.shape[-1] == seq_len:
                shared.append((idx, name, arr))

        if shared:
            fig, axes = plt.subplots(
                nrows=len(shared), sharex=True,
                figsize=(6, 3 * len(shared)))
            if len(shared) == 1:
                # plt.subplots returns a bare Axes for a single row
                axes = [axes]
            for ax, (idx, name, arr) in zip(axes, shared):
                # 1D arrays are drawn as a single-row image
                image = arr if arr.ndim == 2 else arr[np.newaxis, :]
                im = ax.imshow(image, aspect='auto', interpolation='none')
                fig.colorbar(im, ax=ax)
                ax.set_title(f"{idx}: {name}")
                ax.set_ylabel('Index')
            # label only bottom x-axis
            axes[-1].set_xlabel('Sequence index')
            plt.show()
            return

    # fallback to individual plots
    for idx, (name, t) in enumerate(tensors_dict.items()):
        try:
            arr = t.numpy()
        except Exception:
            print("exception!")
            continue
        arr, name = _orient_for_plot(arr, name, seq_len)

        if arr.ndim == 2:
            h, w = arr.shape
            ylabel, xlabel = _axis_labels(h, w, inner_dims, d_state, seq_len)
            image, kind = arr, '2D'
        elif arr.ndim == 1:
            ylabel, xlabel = 'Singleton row', 'Inner dimension index'
            image, kind = arr[np.newaxis, :], '1D'
        else:
            # only 1D and 2D arrays can be drawn as a heatmap
            continue

        plt.figure()
        plt.imshow(image, aspect='auto', interpolation='none')
        plt.title(f'{idx}: {name} ({kind})')
        plt.xlabel(xlabel)
        plt.ylabel(ylabel)
        plt.colorbar()
        if share_x_axis and image.shape[1] == seq_len:
            plt.xlim(0, seq_len - 1)
        plt.show()


def plot_and_print_AR_matrix(
        ax, mat, title, L, V, fig,
        debug_plot=True, debug_print=False,
        axis_to_collapse=None, axis_collapse_fn=None):
    """
    Print a summary of a matrix and optionally plot it as a heatmap.

    The matrix may first be collapsed along one axis; it is then oriented so
    that its second dimension (the x axis of the plot) has length L.

    - ax: axes object to draw on (its imshow / set_xlabel / set_ylabel /
      set_title methods are used)
    - mat: 2D or 3D numpy array
    - title: string title
    - L: sequence length
    - V: vocabulary size (must differ from L so the axes can be told apart)
    - fig: parent figure for colorbar
    - debug_plot: when False, return right after printing, without plotting
    - debug_print: whether to print the matrix values (only for ndim <= 2)
    - axis_to_collapse: axis of mat to reduce before plotting, or None
    - axis_collapse_fn: np.sum or np.argmax; required when axis_to_collapse
      is given

    Raises ValueError if mat is not 2D/3D, or is not 2D after the optional
    collapse. Raises AssertionError if L == V, if axis_collapse_fn is not
    supported, or if neither dimension of the 2D matrix has length L.
    """

    assert L != V, "L and V must be different dimensions"

    # The title keeps the original (pre-collapse) shape.
    shape_str = "×".join(str(s) for s in mat.shape)
    print(f"\n{title} (shape {shape_str})")

    if debug_print:
        if mat.ndim > 2:
            # Only the raw dump is skipped; plotting below still proceeds.
            print("(print skipped for ndim > 2)\n")
        else:
            print(mat, "\n")

    if not debug_plot:
        return

    # Validate dimensions
    if mat.ndim not in (2, 3):
        raise ValueError(f"Matrix must be 2D or 3D, got {mat.ndim}D")

    # Collapse along specified axis if requested
    c_size = None
    if axis_to_collapse is not None:
        assert axis_collapse_fn in (np.sum, np.argmax), \
            "axis_collapse_fn must be np.sum or np.argmax"
        c_size = mat.shape[axis_to_collapse]
        # 3D input drops the collapsed axis to become 2D; 2D input keeps it
        # as a length-1 axis so the result is still 2D.
        mat = axis_collapse_fn(mat, axis=axis_to_collapse,
                               keepdims=(mat.ndim != 3))
        print(f"Collapsed axis {axis_to_collapse} → new shape: {mat.shape}")
    else:
        print("No axis collapsed")

    # Ensure the result is 2D
    if mat.ndim != 2:
        raise ValueError(f"Collapsed array must be 2D, got {mat.ndim}D")

    # Transpose if necessary so that X axis is length L
    if mat.shape[1] != L:
        mat = mat.T
    assert mat.shape[1] == L, (
        f"After orienting, second dim must be L={L}, got {mat.shape[1]}"
    )

    # Determine labels
    sequence_label = "Sequence (L)"
    tokens_label = "Tokens (V)"
    count_label = "Count"

    x_label = sequence_label

    y_size = mat.shape[0]
    if y_size == V:
        y_label = tokens_label
    elif y_size == L:
        y_label = sequence_label
    elif y_size == 1:
        y_label = "(Single Dim.)"
    else:
        y_label = ""

    c_label = "Value"
    cmap = "viridis"
    norm = None

    if axis_to_collapse is not None and axis_collapse_fn is np.sum:
        c_label = count_label

    elif c_size == V:
        # The collapsed axis had V entries, so cell values are treated as
        # token ids and get one discrete colour each.
        c_label = tokens_label

        if V <= 20:
            base = plt.get_cmap('tab20')  # still only 20 base entries, but we sample continuously
            colors = base(np.linspace(0, 1, V))

        else:
            # Fixed seed keeps the token colours stable between calls.
            rng = np.random.default_rng(42)
            hsv = np.column_stack([rng.random(V), rng.uniform(0.6, 0.9, V), rng.uniform(0.7, 0.9, V)])
            colors = np.c_[hsv_to_rgb(hsv), np.ones(V)]

        colors[0] = (0.0, 0.0, 0.5, 0.75)  # deep blue
        cmap = ListedColormap(colors)
        # One bin per integer id: [0, 1), [1, 2), ..., [V-1, V]
        norm = BoundaryNorm(np.arange(V + 1), V)

    # Plot heatmap
    pcm = ax.imshow(mat, cmap=cmap, norm=norm, aspect='auto', interpolation='none')
    ax.set_xlabel(x_label)
    ax.set_ylabel(y_label)
    ax.set_title(f"{title}: ({shape_str})")

    # Colorbar
    cbar = fig.colorbar(pcm, ax=ax)
    cbar.set_label(c_label)
