import numpy as np
import matplotlib.pyplot as plt
from matplotlib import animation as anim
from matplotlib.gridspec import GridSpec, GridSpecFromSubplotSpec
from matplotlib.collections import LineCollection
from mpl_toolkits.mplot3d.art3d import Line3D
from tqdm import tqdm
from tabulate import tabulate
from .error import mse, l2_err
from .colors import ADAM_HGN, SWIM_RF_HGN, ELM_RF_HGN

def get_canvas():
    """
    Create the main figure layout.

    The figure uses a 3-row outer grid with height ratios 3:1:1. The top row
    holds an equal-aspect axis for the animation. The second row is split into
    four side-by-side axes (kinetic energy, potential energy, relative error,
    Hamiltonian, judging by their names). The third row is left unused.

    Returns:
        fig, ax_anim, (ax_ke, ax_pe, ax_rel2, ax_h)
    """
    fig = plt.figure(figsize=(15, 20))
    outer = GridSpec(nrows=3, ncols=1, height_ratios=[3, 1, 1], hspace=0.2)

    ax_anim = fig.add_subplot(outer[0])
    ax_anim.set_aspect("equal")

    inner = GridSpecFromSubplotSpec(nrows=1, ncols=4, subplot_spec=outer[1], wspace=0.3)
    property_axes = tuple(fig.add_subplot(inner[col]) for col in range(4))

    return fig, ax_anim, property_axes

def is_3d_axis(ax) -> bool:
    """Return True if *ax* has a ``zaxis`` attribute; this is the test used here for mplot3d axes."""
    try:
        ax.zaxis
    except AttributeError:
        return False
    return True

def plot_nodes(ax, positions, scatter_obj=None, **kwargs):
    """
    Draw node markers, or move an existing scatter artist to new positions.

    Args:
        ax              Axis of the plot (2D or mplot3d)
        positions       of shape (N,2) or (N,3); columns beyond the axis dimension are ignored
        scatter_obj     Existing scatter artist to update; if None a new one is created
        **kwargs        Forwarded to ``ax.scatter`` only when creating a new artist
    Returns:
        The created or updated scatter artist
    """
    pts = np.asarray(positions)
    three_d = is_3d_axis(ax)

    if scatter_obj is not None:
        # Reuse the artist: 3D scatters keep their data in the private
        # `_offsets3d` attribute, 2D ones via set_offsets.
        if three_d:
            scatter_obj._offsets3d = (pts[:, 0], pts[:, 1], pts[:, 2])
        else:
            scatter_obj.set_offsets(pts[:, :2])
        return scatter_obj

    n_coords = 3 if three_d else 2
    coords = [pts[:, k] for k in range(n_coords)]
    return ax.scatter(*coords, **kwargs)

def plot_edges(ax, positions, edge_index, line_collection=None, **kwargs):
    """
    Draw graph edges as line segments, or update previously drawn ones.

    On a 2D axis the edges form a single ``LineCollection``. On a 3D axis each
    edge is its own ``Line3D``, returned as a list.

    Args:
        ax              Axis of the plot (2D or mplot3d)
        positions       of shape (N,2) or (N,3)
        edge_index      of shape (2,E) array of edges (source row, destination row)
        line_collection Artist(s) returned by a previous call to update; None to create new ones
        **kwargs        Line properties, used only when creating new artists
    Returns:
        The LineCollection (2D) or list of Line3D objects (3D)
    """
    pts = np.asarray(positions)
    src, dst = edge_index

    if not is_3d_axis(ax):
        # One (2, dim) segment per edge
        segments = np.stack([pts[src], pts[dst]], axis=1)
        if line_collection is not None:
            line_collection.set_segments(segments)
            return line_collection
        collection = LineCollection(segments, **kwargs)
        ax.add_collection(collection)
        return collection

    if line_collection is not None:
        for line, i, j in zip(line_collection, src, dst):
            a, b = pts[i], pts[j]
            line.set_data([a[0], b[0]], [a[1], b[1]])
            line.set_3d_properties([a[2], b[2]])
        return line_collection

    lines = []
    for i, j in zip(src, dst):
        a, b = pts[i], pts[j]
        segment = Line3D([a[0], b[0]],
                         [a[1], b[1]],
                         [a[2], b[2]], **kwargs)
        ax.add_line(segment)
        lines.append(segment)
    return lines

def animate_2D(q, edge_index=None, framing_length=1, filename="sim2d.mp4",
               q_preds=None, pred_labels=None):
    """
    Args:
        q               array of shape (n_steps, n_obj, 2)
        edge_index      (2,E) array of edges
        framing_length  Frames to animate
        q_preds and pred_labels should be given as arrays of predictions (e.g. from multiple models)
    """
    q = q[::framing_length]
    n_steps, n_obj, n_dim = q.shape
    assert n_dim == 2, "2D animation needs (n_steps, n_obj, 2) input"

    num_animations = 1 if q_preds is None else len(q_preds)
    fig, axes = plt.subplots(1, num_animations, figsize=(5, 3))

    # Auto limits
    margin = 0.1 * np.max(np.abs(q))

    # Objects that will be updated per frame
    node_scatter_list = []
    edge_lc_list = []

    for ax in axes:
        ax.set_aspect("equal")
        ax.set_xlim(np.min(q[...,0]) - margin, np.max(q[...,0]) + margin)
        ax.set_ylim(np.min(q[...,1]) - margin, np.max(q[...,1]) + margin)
        node_scatter_list.append(None)
        edge_lc_list.append(None)

    if q_preds is not None:
        assert len(q_preds) == len(pred_labels)
        q_pred_dict = {}
        node_scatter_pred_dict = {}
        edge_lc_pred_dict = {}
        for q_pred, label in zip(q_preds, pred_labels):
            q_pred_dict[label] = q_pred[::framing_length]
            assert (
                q.shape == q_pred_dict[label].shape
            ), f"q and q_pred are mismatching: {q.shape} != {q_pred_dict[label].shape}"
            node_scatter_pred_dict[label] = None
            edge_lc_pred_dict[label] = None

    def get_color(label):
        match label.lower():
            case "(adam) hgn":
                return ADAM_HGN
            case "(elm) rf-hgn":
                return ELM_RF_HGN
            case "(swim) rf-hgn":
                return SWIM_RF_HGN
            case _:
                raise ValueError(f"Unknown model label: {label}")

    def init():
        nonlocal node_scatter_list
        for idx, ax in enumerate(axes):
            node_scatter_list[idx] = plot_nodes(
                ax, np.zeros((n_obj, 2)), scatter_obj=None, s=35, facecolors="none", edgecolors="black", zorder=3, label=r"Using true $\mathcal{H}$"
            )
        if q_preds is not None:
            nonlocal node_scatter_pred_dict
            for idx, label in enumerate(pred_labels):
                node_scatter_pred_dict[label] = plot_nodes(
                    axes[idx], np.zeros((n_obj, 2)), scatter_obj=None, s=30, c=get_color(label), label=label, zorder=4,
                )

        if edge_index is not None:
            nonlocal edge_lc_list
            for idx, ax in enumerate(axes):
                edge_lc_list[idx] = plot_edges(
                    ax, np.zeros((n_obj, 2)), edge_index, line_collection=None, color="gray", linewidth=1.0, zorder=1
                )
            if q_preds is not None:
                nonlocal edge_lc_pred_dict
                for idx, label in enumerate(pred_labels):
                    edge_lc_pred_dict[label] = plot_edges(
                        axes[idx], np.zeros((n_obj, 2)), edge_index, line_collection=None, color="gray", linewidth=1.0, zorder=2,
                    )
        return []

    def update(frame):
        nonlocal node_scatter_list
        pos = q[frame]
        for idx, ax in enumerate(axes):
            node_scatter_list[idx] = plot_nodes(ax, pos, scatter_obj=node_scatter_list[idx])
        if q_preds is not None:
            nonlocal node_scatter_pred_dict
            for idx, label in enumerate(pred_labels):
                pos_pred = q_pred_dict[label][frame]
                node_scatter_pred_dict[label] = plot_nodes(
                    axes[idx], pos_pred, scatter_obj=node_scatter_pred_dict[label]
                )

        if edge_index is not None:
            nonlocal edge_lc_list
            for idx, ax in enumerate(axes):
                edge_lc_list[idx] = plot_edges(ax, pos, edge_index, line_collection=edge_lc_list[idx])
            if q_preds is not None:
                nonlocal edge_lc_pred_dict
                for idx, label in enumerate(pred_labels):
                    pos_pred = q_pred_dict[label][frame]
                    edge_lc_pred_dict[label] = plot_edges(
                        axes[idx], pos_pred, edge_index, line_collection=edge_lc_pred_dict[label]
                    )

        # place legends
        if frame == 0:
            for ax in axes:
                ax.legend(loc="lower right", fontsize=6)

        return []

    print("Animating...")
    a = anim.FuncAnimation(fig, update, frames=n_steps, init_func=init, interval=20, blit=False)

    pbar = tqdm(total=len(q), desc="Rendering") # total=num_frames
    def progress_callback(i, _):
        pbar.update(i - pbar.n)

    a.save(filename, writer="ffmpeg", fps=30, progress_callback=progress_callback)
    print("Saved:", filename)

def animate_3D(q, edge_index=None, framing_length=1, filename="sim3d.mp4"):
    """
    Render a 3D trajectory of nodes (and optionally edges) to a video file.

    Args:
        q               array of shape (n_steps, n_obj, 3)
        edge_index      (2,E) array of edges; if None only nodes are drawn
        framing_length  Keep every ``framing_length``-th step of the trajectory
        filename        Output path, written with matplotlib's "ffmpeg" writer at 30 fps
    """
    q = q[::framing_length]
    n_steps, n_obj, n_dim = q.shape
    assert n_dim == 3, "3D animation requires (n_steps, n_obj, 3) data"

    fig = plt.figure(figsize=(6,6))
    ax = fig.add_subplot(111, projection="3d")

    # Cubic view box centred on the data's bounding box; its half-width is
    # 60% of the largest per-axis extent.
    lower = np.min(q, axis=(0,1))
    upper = np.max(q, axis=(0,1))
    middle = 0.5 * (lower + upper)
    half_width = np.max(upper - lower) * 0.6
    for set_limits, c in zip((ax.set_xlim, ax.set_ylim, ax.set_zlim), middle):
        set_limits(c - half_width, c + half_width)

    # Artists created in init() and moved in update()
    artists = {"nodes": None, "edges": None}

    def init():
        artists["nodes"] = plot_nodes(ax, np.zeros((n_obj, 3)), scatter_obj=None, s=40, c="blue")
        if edge_index is not None:
            artists["edges"] = plot_edges(ax, np.zeros((n_obj, 3)), edge_index,
                                          line_collection=None, color="gray", linewidth=1.0)
        return []

    def update(frame):
        positions = q[frame]
        artists["nodes"] = plot_nodes(ax, positions, scatter_obj=artists["nodes"])
        if edge_index is not None:
            artists["edges"] = plot_edges(ax, positions, edge_index, line_collection=artists["edges"])
        return []

    movie = anim.FuncAnimation(fig, update, frames=n_steps, init_func=init, interval=20, blit=False)

    print("Animating...")
    movie.save(filename, writer="ffmpeg", fps=30)
    print("Saved:", filename)

def generate_results_table(q_true, q_pred, e_true, e_pred, start_idx=0, num_evals=5):
    """
    Build a table of position MSE, relative L2 error and true/predicted Hamiltonian values.

    The predicted Hamiltonian is first shifted by a constant so that its value
    at step 0 equals the true one. The code treats the Hamiltonian as defined
    up to an integration constant.

    Args:
        q_true, q_pred      Both of shape (n_steps, n_nodes, dof)
        e_true, e_pred      Both of shape (n_steps, 1)
        start_idx           First step index that is evaluated
        num_evals           Number of evenly spaced steps evaluated, from start_idx to the last step
    Returns:
        table_title, table  The title string and the table rendered by ``tabulate`` (a string)
    Raises:
        AssertionError      on mismatching shapes or if q_true is not 3-dimensional
    """
    assert len(q_true.shape) == 3
    assert q_true.shape == q_pred.shape
    assert e_true.shape == e_pred.shape
    # Evaluate at num_evals evenly spaced steps between start_idx and the final step
    test_indices = np.linspace(start_idx, len(q_true) - 1, num=num_evals, dtype=np.int64)

    # Use floating buffers. An integer dtype inherited from integer-valued
    # inputs would silently truncate errors and energies on assignment.
    # Float inputs keep their own dtype.
    err_dtype = np.result_type(q_true.dtype, np.float32)
    energy_dtype = np.result_type(e_true.dtype, np.float32)
    q_mse = np.zeros(num_evals, dtype=err_dtype)
    q_rel2 = np.zeros(num_evals, dtype=err_dtype)
    e_true_table = np.zeros(num_evals, dtype=energy_dtype)
    e_pred_table = np.zeros(num_evals, dtype=energy_dtype)

    # Set integration constant using the initial true Hamiltonian value
    e_pred = e_pred + (e_true[0] - e_pred[0])

    for idx, test_index in enumerate(test_indices):
        q_mse[idx] = mse(q_true[test_index], q_pred[test_index], verbose=False)
        q_rel2[idx] = l2_err(q_true[test_index], q_pred[test_index], verbose=False)
        e_true_table[idx] = e_true[test_index].item()
        e_pred_table[idx] = e_pred[test_index].item()

    table_title = "\nTable: Predicted trajectory evaluation, error values on positions (q) are displayed with true and predicted conserved (energy) values."
    arr_columns = [f"T={step_idx}" for step_idx in test_indices]
    results_table = tabulate(
        headers=[""] + arr_columns,
        tabular_data=[
            ["q MSE"] + list(q_mse),
            ["q L2 rel."] + list(q_rel2),
            ["True H"] + list(e_true_table),
            ["Pred H"] + list(e_pred_table),
        ],
        floatfmt=".3e"
    )
    return table_title, results_table

def plot_xy_traj(ax_x, ax_y, qx, qy, obj_idx, color, linestyle, linewidth=2.0):
    """
    Plot the two coordinate series of one object against the time step.

    Used for the first two panels of a row laid out as
    [x position of particle][y position of particle][MSE][Hamiltonian].
    Time steps on the x-axis start at 1.

    Args:
        ax_x, ax_y          Axes receiving the ``qx`` and ``qy`` series respectively
        qx, qy              y-axis data of identical shape, indexed as ``[:, obj_idx]``
        obj_idx             Object whose positions are plotted
        color, linestyle/width  matplotlib params
    """
    assert qx.shape == qy.shape
    steps = np.arange(len(qx)) + 1
    line_style = dict(c=color, linestyle=linestyle, linewidth=linewidth)
    for axis, series in ((ax_x, qx), (ax_y, qy)):
        axis.plot(steps, series[:, obj_idx], **line_style)

def plot_mse(ax_mse, qx_true, qy_true, qx_pred, qy_pred, label, color, linestyle, linewidth=2.0):
    """
    Plot, on a log scale, the per-step squared position error averaged over axis 1.

    The error is measured against the trajectory integrated with the true
    Hamiltonian. Time steps on the x-axis start at 1.

    Args:
        ax_mse              Axis to plot
        qx_true, qy_true    True positions
        qx_pred, qy_pred    Predicted positions
        label, color, linestyle/width   matplotlib params
    """
    dx = qx_true - qx_pred
    dy = qy_true - qy_pred
    # Named so as not to shadow the module-level `mse` helper
    step_error = np.mean(dx**2 + dy**2, axis=1)
    steps = np.arange(len(step_error)) + 1
    ax_mse.semilogy(steps, step_error, c=color, label=label, linestyle=linestyle, linewidth=linewidth)

def plot_hamiltonian(ax_energy, energy, label, color, linestyle, linewidth=2.0):
    """
    Plot energy values against the time step (starting at 1) on a log-scaled y-axis.

    Args:
        ax_energy           Axis to plot
        energy              y-axis data to plot
        label, color, linestyle/width   matplotlib params
    """
    steps = np.arange(len(energy)) + 1
    ax_energy.semilogy(steps, energy, label=label, c=color, linestyle=linestyle, linewidth=linewidth)

def plot_traj_energy_mse(q_true, e_true, q_preds, e_preds,
                         obj_idx,
                         model_labels, model_colors, model_linestyles,
                         filepath,
                         label_fontsize=12, legend_fontsize=14):
    """
    Plot a four-panel comparison of true and predicted trajectories and save it.

    Panels, left to right:
      1. The first half of node ``obj_idx``'s coordinates (q_1).
      2. The second half (q_2).
      3. The Hamiltonian per time step (log scale).
      4. Each model's per-step position MSE against the true trajectory (log scale).

    Positions are split into the two halves with ``np.array_split(..., 2, axis=-1)``.
    Each predicted Hamiltonian is shifted so that its first value matches ``e_true[0]``.

    Args:
        q_true                  True positions
        e_true                  True Hamiltonian values per step
        q_preds, e_preds        Per-model predicted positions / Hamiltonian values
        obj_idx                 Index of the node whose coordinates are plotted
        model_labels, model_colors, model_linestyles
                                Per-model legend label and matplotlib styles
        filepath                Where the figure is saved
        label_fontsize          Font size of the axis labels
        legend_fontsize         Font size of the shared legend
    Raises:
        AssertionError          if the per-model sequences differ in length
    """
    assert len(model_labels) == len(model_colors) == len(model_linestyles)
    assert len(model_labels) == len(q_preds) == len(e_preds)

    qx_true, qy_true = np.array_split(q_true, 2, axis=-1)
    # Per-model (q_1, q_2) halves, split the same way as the ground truth
    split_preds = [np.array_split(q_pred, 2, axis=-1) for q_pred in q_preds]

    fig, (ax_x, ax_y, ax_energy, ax_mse) = plt.subplots(1, 4, figsize=(14, 3), dpi=100)
    plot_xy_traj(ax_x, ax_y, qx_true, qy_true, obj_idx, "k", "solid", linewidth=4)
    plot_hamiltonian(ax_energy, e_true, r"using true $\mathcal{H}$", "k", "solid", linewidth=4)
    for (qx_pred, qy_pred), e_pred, model_label, model_color, model_linestyle in zip(
        split_preds, e_preds, model_labels, model_colors, model_linestyles
    ):
        plot_xy_traj(ax_x, ax_y, qx_pred, qy_pred, obj_idx, model_color, model_linestyle, linewidth=3)
        plot_mse(ax_mse, qx_true, qy_true, qx_pred, qy_pred, model_label, model_color, model_linestyle, linewidth=3)
        # Set integration constant using the initial true Hamiltonian value
        e_pred = e_pred + (e_true[0] - e_pred[0])
        plot_hamiltonian(ax_energy, e_pred, model_label, model_color, model_linestyle, linewidth=3)

    # The energy panel holds one labelled line per curve (true + every model),
    # so its handles form the shared figure legend.
    handles, labels = ax_energy.get_legend_handles_labels()
    fig.legend(handles, labels, loc='lower center', ncol=len(labels), fontsize=legend_fontsize, bbox_to_anchor=(0.5, 0.88), shadow=False)
    for ax in (ax_x, ax_y, ax_mse, ax_energy):
        ax.set_xlabel("Time step", fontsize=label_fontsize)
    ax_x.set_ylabel(fr"$q_1$ of node {obj_idx}", fontsize=label_fontsize)
    ax_y.set_ylabel(fr"$q_2$ of node {obj_idx}", fontsize=label_fontsize)
    ax_mse.set_ylabel("MSE", fontsize=label_fontsize)
    ax_mse.grid()
    ax_energy.set_ylabel(r"True $\mathcal{H}$", fontsize=label_fontsize)
    fig.tight_layout()

    fig.savefig(filepath)
    print(f"-> Plot saved at '{filepath}'")

def plot_snaps(q_true, q_preds, row_labels, colors, filepath,
               edge_index=None, num_snaps=5, legend_fontsize=12):
    """
    Plot a grid of position snapshots: one row per model, one column per time step.

    Each cell overlays the true node positions (hollow red markers) with the
    model's predicted positions at the same step. When ``edge_index`` is given,
    it also draws edges between the predicted positions. A shared legend is
    placed above the grid, and the figure is saved to ``filepath``.

    Args:
        q_true          True positions; ``q_true[t][..., 0]`` / ``[..., 1]`` are the x / y coordinates at step t
        q_preds         Predicted positions, one trajectory per row, e.g. [elm, adam, swim];
                        the first three are plotted
        row_labels      Label of each row (usually the model name)
        colors          Marker colour of each row's predictions
        filepath        Filepath to save the figure
        edge_index      If given, plot edges between the predicted positions too
        num_snaps       Number of snapshot columns (evenly spaced time steps)
        legend_fontsize Font size of the shared legend
    """
    n_rows = 3
    # The width scales with the number of columns (num_snaps=5 gives 16 x 5 inches).
    # squeeze=False keeps `axes` 2D even for a single column.
    fig, axes = plt.subplots(n_rows, num_snaps, figsize=(3.2 * num_snaps, 5), dpi=100,
                             sharex=True, sharey=True, squeeze=False)

    # Axis limits from the true trajectory, padded by 20% of its largest magnitude
    margin = 0.2 * np.max(np.abs(q_true))
    xlim = (np.min(q_true[..., 0] - margin).item(), np.max(q_true[..., 0] + margin).item())
    ylim = (np.min(q_true[..., 1] - margin).item(), np.max(q_true[..., 1] + margin).item())

    true_kwargs = {"facecolors": "none", "edgecolors": "red"}
    # NOTE(review): the last snapshot is step len(q_true) - 2, not the final step; confirm intent.
    test_indices = np.linspace(0, len(q_true) - 2, num=num_snaps, dtype=np.int64)

    handles, labels = [], []
    for row, (row_axes, q_pred, row_label, color) in enumerate(zip(axes, q_preds, row_labels, colors)):
        for col, (ax, t) in enumerate(zip(row_axes, test_indices)):
            ax.set_xlim(*xlim)
            ax.set_ylim(*ylim)
            # Label the true markers only in the first cell, so the legend lists them once
            true_label = {"label": r"using true $H$"} if row == 0 and col == 0 else {}
            ax.scatter(q_true[t][..., 0], q_true[t][..., 1], s=80, **true_kwargs, zorder=20, **true_label)
            ax.scatter(q_pred[t][..., 0], q_pred[t][..., 1], s=80, edgecolors="black", zorder=19, label=row_label, c=color)
            if edge_index is not None:
                plot_edges(ax, q_pred[t], edge_index, zorder=1)
        # The first cell of each row carries that row's legend entries
        row_handles, row_texts = row_axes[0].get_legend_handles_labels()
        handles.extend(row_handles)
        labels.extend(row_texts)

    for ax, t in zip(axes[-1], test_indices):
        ax.set_xlabel(f"Time step {t}")

    fig.legend(handles, labels, loc='lower center', ncol=len(labels), fontsize=legend_fontsize, bbox_to_anchor=(0.5, 0.95))
    fig.tight_layout()
    fig.savefig(filepath)
    print(f"-> Snaps are saved at '{filepath}'")
