from ast import literal_eval

import matplotlib.animation as animation
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn.functional as F
import volimg.utils
from matplotlib.colors import Normalize
from pdisvae import evaluate, inference, utils
from pdisvae.models.linear import LinearDecoder, LinearEncoder
from sklearn.decomposition import PCA, FastICA
from tqdm import tqdm

# Module-level alias of volimg.utils.session_list; presumably the available
# session identifiers (cf. prepare_data's ``session`` argument) — TODO confirm.
session_list = volimg.utils.session_list


def prepare_data(
    session: str,
    outcome: str,
    trial_side: str,
    *,
    frames: slice = slice(400, 550),
) -> torch.Tensor:
    """Build the spatial-ICA data matrix from one session's trial-averaged images.

    Selects the trials of ``session`` whose ``outcome`` and ``trial_side``
    columns equal the given values, averages their ``images``, applies the
    volimg mask, keeps the time bins in ``frames`` and transposes the result
    so that every masked pixel is one sample.

    Parameters
    ----------
    session : str
        Session identifier passed to ``volimg.utils.load_session``.
    outcome : str
        Value the ``outcome`` column must equal (e.g. ``"success"``).
    trial_side : str
        Value the ``trial_side`` column must equal (e.g. ``"left"``).
    frames : slice, default ``slice(400, 550)``
        Time bins (rows of the masked array) to keep.

    Returns
    -------
    torch.Tensor
        ``x`` of shape ``(n_pixels, n_time_bins)``, i.e.
        ``(n_total_samples, obs_dim)``.

    Raises
    ------
    ValueError
        If no trial matches ``outcome`` and ``trial_side``.
    """
    df_session = volimg.utils.load_session(session)
    selected = (df_session["outcome"] == outcome) & (
        df_session["trial_side"] == trial_side
    )
    if not selected.any():
        # Averaging an empty selection would give NaN and fail later in apply_mask.
        raise ValueError(
            f"No trials in session {session!r} with outcome={outcome!r} "
            f"and trial_side={trial_side!r}"
        )
    mean_images = df_session["images"][selected].mean()
    # After masking, rows are time bins and columns are pixels.
    images_masked = torch.from_numpy(volimg.utils.apply_mask(mean_images)[frames])
    # (n_time_bins, n_pixels) -> (n_pixels, n_time_bins): pixels are samples
    # and time bins are observed dimensions (spatial ICA).
    return images_masked.T


def make_video(
    tag: str,
    method: str,
    x: torch.Tensor,
):
    """Save a GIF comparing the data with each group's reconstruction over time.

    Loads the decoder weights, the inferred sources and the metrics CSV for
    ``method`` from ``results_{tag}/`` and animates one frame per observed
    dimension ``t`` (column of ``x``). Row 0 shows the raw data ``x[:, t]``
    under every group column. Row 1 shows each group's contribution
    ``source[:, group] @ mixing[t, group]`` and, in column ``n_groups``, the
    decoder bias for ``t``. The animation is written to
    ``{tag}_{method}.gif`` in the current working directory.

    Parameters
    ----------
    tag : str
        Suffix of the results directory ``results_{tag}`` and prefix of the
        output GIF name.
    method : str
        Python literal holding ``(n_components, n_groups)``; also the stem of
        the result files inside the results directory.
    x : torch.Tensor
        Data matrix of shape ``(n_total_samples, obs_dim)``.
    """
    # https://stackoverflow.com/questions/17853680/animation-using-matplotlib-with-subplots-and-artistanimation
    results_file = f"results_{tag}"
    n_components, n_groups = literal_eval(method)
    # NOTE(review): assumes n_components is divisible by n_groups; any
    # remainder components are left out of every group.
    group_rank = int(n_components / n_groups)

    n_total_samples, obs_dim = x.shape

    decoder = LinearDecoder(obs_dim=obs_dim, n_components=n_components)
    # Everything below is combined with CPU tensors / numpy arrays, so load on CPU.
    decoder.load_state_dict(
        torch.load(f"{results_file}/{method}_decoder.pth", map_location="cpu")
    )

    with torch.no_grad():
        source = torch.load(
            f"{results_file}/{method}_z_pred_mean.pth", map_location="cpu"
        )
        mixing = decoder.mixing_and_bias.weight.detach()  # (obs_dim, n_components)
        bias = decoder.mixing_and_bias.bias.detach()  # (obs_dim,)

        # Per-group reconstruction; computed under no_grad so the numpy
        # assignment works even if the loaded source tracks gradients.
        x_group = np.zeros((n_groups, n_total_samples, obs_dim))
        for group in range(n_groups):
            cols = slice(group * group_rank, (group + 1) * group_rank)
            x_group[group] = source[:, cols] @ mixing[:, cols].T

    # Title: model summary plus metrics read once from the results CSV.
    metrics = pd.read_csv(f"{results_file}/{method}.csv")
    metric_labels = {
        "partial correlation": "PC",
        "reconstruction $R^2$": "recon $R^2$",
        "reconstruction RMSE": "recon RMSE",
    }
    title = (
        f"{n_components} components, {n_groups} groups, each group is rank-{group_rank}\n"
        + ", ".join(
            f"{label}: {metrics.at[0, metric]:.3f}"
            for metric, label in metric_labels.items()
        )
    )

    n_cols = max(n_groups + 1, 4)
    fig, axs = plt.subplots(
        2,
        n_cols,
        figsize=(1.5 * n_cols, 3),
        layout="constrained",
    )
    fig.suptitle(title)

    for ax in axs.flat:
        sns.despine(ax=ax, top=True, bottom=True, right=True, left=True)
        ax.set(xticks=[], yticks=[])
    for group in range(n_groups):
        axs[0, group].set_title(f"group {group + 1}")
    axs[0, n_groups].set_title("bias")
    # Nothing is drawn above the bias image; hide that cell once, up front.
    axs[0, n_groups].axis("off")
    axs[0, 0].set_ylabel("raw")
    axs[1, 0].set_ylabel("group\nrecon")

    imshow_kwargs = dict(
        vmin=-3,
        vmax=3,
        cmap="seismic",
        interpolation="bilinear",
        animated=True,
    )
    # The bias is one scalar per time bin, shown as a uniform image.
    ones = torch.ones(n_total_samples)

    ims = []
    for t in range(obs_dim):
        frame_artists = []
        raw_image = volimg.utils.remove_mask(x[:, t])
        for group in range(n_groups):
            frame_artists.append(axs[0, group].imshow(raw_image, **imshow_kwargs))
            frame_artists.append(
                axs[1, group].imshow(
                    volimg.utils.remove_mask(x_group[group, :, t]), **imshow_kwargs
                )
            )
        frame_artists.append(
            axs[1, n_groups].imshow(
                volimg.utils.remove_mask(bias[t] * ones), **imshow_kwargs
            )
        )
        ims.append(frame_artists)

    ani = animation.ArtistAnimation(fig, ims, interval=50, blit=True, repeat_delay=1000)
    # NOTE(review): fps=200 (5 ms/frame) disagrees with interval=50 (20 fps) and
    # is finer than GIF's 10 ms delay resolution, so viewers may play the GIF
    # slower than intended. Kept unchanged; confirm the intended frame rate.
    writer = animation.PillowWriter(fps=200, bitrate=1800)
    ani.save(f"{tag}_{method}.gif", writer=writer)


def make_sequence(tag: str, method: str, n_frames: int, *, spacing: float = 50):
    """Draw each group's reconstruction at ``n_frames`` time points as a 3D stack.

    For every group, one 3D figure is created in which the reconstructed
    images ``source[:, group] @ mixing[t, group]`` (flipped vertically) are
    placed as planes normal to the x axis, one plane per selected time bin.
    The time bins are spread evenly over ``[0, obs_dim - 1]``.

    Parameters
    ----------
    tag : str
        Suffix of the results directory ``results_{tag}``.
    method : str
        Python literal holding ``(n_components, n_groups)``; also the stem of
        the result files inside the results directory.
    n_frames : int
        Number of time bins to display per group.
    spacing : float, default 50
        Distance along the 3D x axis between consecutive planes.

    Returns
    -------
    list of matplotlib.figure.Figure
        One figure per group, in group order.
    """
    results_file = f"results_{tag}"
    n_components, n_groups = literal_eval(method)
    # NOTE(review): assumes n_components is divisible by n_groups.
    group_rank = int(n_components / n_groups)

    # obs_dim is recovered from the saved mixing weight, whose rows index the
    # observed dimensions, so no data matrix is required here.
    state_dict = torch.load(
        f"{results_file}/{method}_decoder.pth", map_location="cpu"
    )
    obs_dim = state_dict["mixing_and_bias.weight"].shape[0]

    decoder = LinearDecoder(obs_dim=obs_dim, n_components=n_components)
    decoder.load_state_dict(state_dict)

    with torch.no_grad():
        source = torch.load(
            f"{results_file}/{method}_z_pred_mean.pth", map_location="cpu"
        )
        mixing = decoder.mixing_and_bias.weight.detach()

    # Fixed symmetric colour scale shared by every plane; out-of-range values clip.
    norm = Normalize(vmin=-1, vmax=1, clip=True)
    frame_times = np.linspace(0, obs_dim - 1, n_frames).astype(int)

    figures = []
    with torch.no_grad():
        for group in range(n_groups):
            cols = slice(group * group_rank, (group + 1) * group_rank)
            fig = plt.figure(figsize=(12, 12))
            ax = fig.add_subplot(projection="3d")
            for t_idx, t in enumerate(frame_times):
                image = volimg.utils.remove_mask(source[:, cols] @ mixing[t, cols])
                imshow3d(
                    ax,
                    np.flip(image, 0),
                    "x",
                    pos=spacing * t_idx,
                    cmap="seismic",
                    norm=norm,
                )

            ax.set_aspect("equal")
            ax.axis("off")
            figures.append(fig)
    return figures


#         for t_idx, t in enumerate(np.linspace(0, obs_dim-1, n_frames).astype(int)):
#             ax = axs[group, t_idx]
#             ax.axis("off")
#             ax.imshow(
#                 volimg.utils.remove_mask(
#                     source[:, group * group_rank : (group + 1) * group_rank]
#                     @ mixing[t, group * group_rank : (group + 1) * group_rank]
#                 ),
#                 vmin=-3,
#                 vmax=3,
#                 cmap="seismic",
#                 interpolation="bilinear",
#             )
#     return fig, axs


def imshow3d(ax, array, value_direction="z", pos=0, norm=None, cmap=None):
    """
    Display a 2D array as a color-coded 2D image embedded in 3D.

    The image will be in a plane perpendicular to the coordinate axis *value_direction*.

    Parameters
    ----------
    ax : Axes3D
        The 3D Axes to plot into.
    array : 2D numpy array
        The image values.
    value_direction : {'x', 'y', 'z'}
        The axis normal to the image plane.
    pos : float
        The numeric value on the *value_direction* axis at which the image plane is
        located. Non-integer values are kept exactly (not truncated).
    norm : `~matplotlib.colors.Normalize`, default: Normalize
        The normalization method used to scale scalar data. See `imshow()`.
    cmap : str or `~matplotlib.colors.Colormap`, default: :rc:`image.cmap`
        The Colormap instance or registered colormap name used to map scalar data
        to colors.

    Raises
    ------
    ValueError
        If *value_direction* is not one of 'x', 'y' or 'z'.
    """
    if norm is None:
        norm = Normalize()
    colors = plt.get_cmap(cmap)(norm(array))

    # np.mgrid yields integer grids; dtype=float stops full_like from
    # inheriting that dtype and truncating a fractional *pos*.
    if value_direction == "x":
        nz, ny = array.shape
        zi, yi = np.mgrid[0 : nz + 1, 0 : ny + 1]
        xi = np.full_like(yi, pos, dtype=float)
    elif value_direction == "y":
        nx, nz = array.shape
        xi, zi = np.mgrid[0 : nx + 1, 0 : nz + 1]
        yi = np.full_like(zi, pos, dtype=float)
    elif value_direction == "z":
        ny, nx = array.shape
        yi, xi = np.mgrid[0 : ny + 1, 0 : nx + 1]
        zi = np.full_like(xi, pos, dtype=float)
    else:
        raise ValueError(f"Invalid value_direction: {value_direction!r}")
    ax.plot_surface(xi, yi, zi, rstride=1, cstride=1, facecolors=colors, shade=False)
