import math

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch as th
from einops import rearrange
from ANONYMOUS.logging import wandb_img
from ANONYMOUStorch.data import num_to_groups
from ANONYMOUStorch.utils import as_numpy, no_grad_func
from ANONYMOUS_ABCviz.plt import jcolors


# pylint: disable=too-many-function-args, too-many-locals
def traj_plot(samples, xlabel, ylabel, title, fsave):
    """Draw every trajectory in a batch as a chain of arrows, save it, and log it.

    Args:
        samples: tensor laid out as ``(t, b, d)`` (time, batch, feature), as
            implied by the ``rearrange`` pattern below. Only features 0 and 1
            are drawn.
        xlabel: label for the x axis.
        ylabel: label for the y axis.
        title: figure title; also used as the ``wandb_img`` key.
        fsave: path the figure is written to before being logged.
    """
    # One row per trajectory: (b, t, d). Detach so tensors that still track
    # gradients can be converted to arrays by matplotlib.
    samples = rearrange(samples, "t b d -> b t d").detach().cpu()
    fig, axs = plt.subplots(1, 1, figsize=(7, 7))
    for i, sample in enumerate(samples):
        axs.quiver(
            sample[:-1, 0],
            sample[:-1, 1],
            sample[1:, 0] - sample[:-1, 0],
            sample[1:, 1] - sample[:-1, 1],
            # Data-coordinate arrows with scale=1 end exactly on the next point.
            scale_units="xy",
            angles="xy",
            scale=1,
            # Cycle the palette so batches larger than jcolors don't IndexError.
            color=jcolors[i % len(jcolors)],
        )
    axs.set_title(title)
    axs.set_xlabel(xlabel)
    axs.set_ylabel(ylabel)
    fig.savefig(fsave)
    plt.close(fig)
    wandb_img(title, fsave, fsave)


def viz_sample(sample, title, fsave, sample_num=50000, lim=5.0):
    """Scatter-plot the first two coordinates of a point cloud, save it, and log it.

    Args:
        sample: array-like accepted by ``as_numpy``; column 0 is drawn on the
            x axis and column 1 on the y axis.
        title: figure title; also used as the ``wandb_img`` key.
        fsave: path the figure is written to before being logged.
        sample_num: at most this many leading rows are drawn.
        lim: half-width of the square viewport ``[-lim, lim]^2``. The default
            reproduces the previously hard-coded ``±5`` window.
    """
    points = as_numpy(sample)[:sample_num]
    fig, axs = plt.subplots(1, 1, figsize=(7, 7))
    axs.set_title(title)
    # Markers only (linewidth=0): a lightweight scatter for many points.
    axs.plot(
        points[:, 0],
        points[:, 1],
        linewidth=0,
        marker=".",
        markersize=1,
    )
    axs.set_xlim(-lim, lim)
    axs.set_ylim(-lim, lim)
    fig.savefig(fsave)
    plt.close(fig)
    wandb_img(title, fsave, fsave)


def viz_kde(points, fname, lim=9.0, num_points=2000):
    """Plot a filled 2-D kernel density estimate of a point cloud, save it, and log it.

    Args:
        points: array-like accepted by ``as_numpy``, indexed as ``(N, >=2)``;
            columns 0 and 1 are used as x and y.
        fname: path the figure is written to; the image is logged under "kde".
        lim: half-width of the square viewport ``[-lim, lim]^2``.
        num_points: at most this many leading rows enter the KDE (the default
            keeps the previous fixed value of 2000).
    """
    points = as_numpy(points)[:num_points]
    fig, ax = plt.subplots(1, 1, figsize=(10, 10), dpi=200)
    # `fill=True` replaces the `shade=True` argument, deprecated since seaborn 0.11.
    sns.kdeplot(x=points[:, 0], y=points[:, 1], cmap="coolwarm", fill=True, ax=ax)
    ax.set_xlim(-lim, lim)
    ax.set_ylim(-lim, lim)
    ax.axis("off")
    fig.savefig(fname)
    plt.close(fig)
    wandb_img("kde", fname, fname)


@no_grad_func
def generate_samples_loss(model, dataset, dt=0.01, t_end=1.0, num_sample=2000):
    """Simulate the model's SDE from the origin and score the resulting samples.

    Integration is Euler-Maruyama, ``x <- x + f(t, x) dt + g(t, x) * sqrt(dt) * N(0, 1)``,
    over ``t`` in ``arange(0, t_end, dt)``. The simulated tensor has
    ``dataset.ndim + 1`` columns: the first ``dataset.ndim`` are the state, and
    the last is an extra channel whose final value is reported as "reg_loss".

    The per-sample total is the sum of:
      * the final value of the extra channel,
      * ``sum_t f(t, x)[:, :-1] . dW[:, :-1]``, a discretised stochastic integral,
      * ``dataset.get_disc(state)``, evaluated in chunks of 256 samples,
      * the log-density of ``N(0, t_end * I)`` at the final state.

    Args:
        model: object exposing ``f(t, x)`` and ``g(t, x)``; presumably both
            return tensors broadcastable against ``x`` (TODO confirm).
        dataset: object exposing ``ndim`` and ``get_disc(batch)``, which is
            assumed to return one value per row (TODO confirm).
        dt: integration step size.
        t_end: end of the integration interval; also the Gaussian variance.
        num_sample: number of trajectories simulated in parallel.

    Returns:
        ``(state, total_loss, info)``: the final states, the per-sample totals,
        and a dict of scalar means of each term.
    """
    n_dim = dataset.ndim
    traj = th.zeros((num_sample, n_dim + 1)).float().cuda()
    # Log normaliser of N(0, t_end * I) in n_dim dimensions.
    log_norm = n_dim / 2 * np.log(2 * np.pi) + 0.5 * n_dim * np.log(t_end)
    sqrt_dt = math.sqrt(dt)
    ito_sum = 0
    for t_cur in th.arange(0, t_end, dt).cuda():
        drift = model.f(t_cur, traj)
        diffusion = model.g(t_cur, traj)
        d_w = th.randn_like(diffusion) * sqrt_dt
        traj += drift * dt + diffusion * d_w
        # Only the state columns contribute to the stochastic integral.
        ito_sum += (drift[:, :-1] * d_w[:, :-1]).sum(dim=1)

    state = traj[:, :-1]
    # Evaluate the dataset term in bounded chunks to limit peak memory.
    pieces, offset = [], 0
    for size in num_to_groups(num_sample, 256):
        pieces.append(dataset.get_disc(state[offset : offset + size]))
        offset += size
    disc_loss = th.cat(pieces)

    quad_loss = -0.5 * state.pow(2).sum(dim=-1) / t_end - log_norm
    reg_term = traj[:, -1]
    total_loss = reg_term + ito_sum + disc_loss + quad_loss
    info = {
        "sample/loss": total_loss.mean().item(),
        "sample/disc_loss": disc_loss.mean().item(),
        "sample/quad_loss": quad_loss.mean().item(),
        "sample/uw_loss": ito_sum.mean().item(),
        "sample/reg_loss": reg_term.mean().item(),
    }
    return state, total_loss, info


@no_grad_func
def viz_2dfield_t(model, t, title, fsave, xlim=3):
    """Draw streamlines of the vector field ``model.f(t, .)`` on a square grid and log them.

    Args:
        model: object exposing ``f(t, points)``; ``points`` is an ``(N, 2)``
            tensor of grid coordinates, and output columns 0 and 1 are the
            vector components.
        t: time value passed straight through to ``model.f``.
        title: figure title; also used as the ``wandb_img`` key.
        fsave: path the figure is written to before being logged.
        xlim: half-width of the square window ``[-xlim, xlim]^2``.
    """
    n_grid = 100
    axis = th.linspace(-xlim, xlim, n_grid).cuda()
    # streamplot needs an "xy" layout (rows share y, columns share x). Torch's
    # default "ij" layout makes it raise "The rows of 'x' must be equal".
    xx, yy = th.meshgrid(axis, axis, indexing="xy")
    points = th.cat([xx.reshape(-1, 1), yy.reshape(-1, 1)], dim=1)
    f_vector = model.f(t, points)
    mag = th.norm(f_vector, dim=1)
    # streamplot requires u, v and color as 2-D arrays aligned with the grid.
    u = as_numpy(f_vector[:, 0].reshape(n_grid, n_grid))
    v = as_numpy(f_vector[:, 1].reshape(n_grid, n_grid))
    mag = as_numpy(mag.reshape(n_grid, n_grid))
    xx, yy = as_numpy(xx), as_numpy(yy)
    fig, axs = plt.subplots(1, 1, figsize=(7, 7))
    axs.streamplot(xx, yy, u, v, color=mag, cmap="autumn")
    axs.set_title(title)
    axs.set_xlim(-xlim, xlim)
    axs.set_ylim(-xlim, xlim)
    fig.savefig(fsave)
    plt.close(fig)
    wandb_img(title, fsave, fsave)


@no_grad_func
def viz_field(model, title, fsave, xlim=3.0, n_grid=20, num_t=10):  # pylint: disable=too-many-locals
    """Plot quiver snapshots of ``model.f_func`` at evenly spaced times in [0, 1] and log them.

    Each snapshot is one subplot in a single row. The arrows are coloured by
    their Euclidean length.

    Args:
        model: object exposing ``f_func(t, points)``; ``points`` is an
            ``(n_grid**2, 2)`` tensor of grid coordinates, and output columns
            0 and 1 are used as the vector components.
        title: key the image is logged under via ``wandb_img``.
        fsave: path the figure is written to before being logged.
        xlim: half-width of the square window ``[-xlim, xlim]^2``.
        n_grid: grid points per axis (the default keeps the previous 20).
        num_t: number of time snapshots and subplots (the default keeps the
            previous 10).
    """
    axis = th.linspace(-xlim, xlim, n_grid).cuda()
    xx, yy = th.meshgrid(axis, axis, indexing="xy")
    points = th.cat([xx.reshape(-1, 1), yy.reshape(-1, 1)], dim=1)
    ts = th.linspace(0.0, 1.0, num_t).cuda()
    xx_np, yy_np = as_numpy(xx), as_numpy(yy)
    with plt.style.context("img"):
        # squeeze=False keeps a 2-D Axes array even when num_t == 1.
        fig, axs = plt.subplots(1, num_t, figsize=(num_t * 7, 1 * 7), squeeze=False)
        for ax, cur_t in zip(axs[0], ts):
            # NOTE(review): the sibling functions call model.f(...); this one uses
            # model.f_func(...). Confirm that is intended for this model type.
            f_vector = as_numpy(model.f_func(cur_t, points))
            mag = np.hypot(f_vector[:, 0], f_vector[:, 1])
            ax.quiver(
                xx_np,
                yy_np,
                f_vector[:, 0],
                f_vector[:, 1],
                mag,
                units="x",
                pivot="tip",
                width=0.022,
            )
            # cur_t is a 0-d tensor; show its value rather than the tensor repr.
            ax.set_title(f"{cur_t.item():.2f}")
            ax.set_xlim(-xlim, xlim)
            ax.set_ylim(-xlim, xlim)
        # Save while the style is active so its savefig.* rcParams take effect.
        fig.savefig(fsave)
    plt.close(fig)
    wandb_img(title, fsave, fsave)
