import matplotlib.pyplot as plt
import numpy as np
import torch as th
import torchsde
from einops import rearrange
from ANONYMOUStorch.utils import as_numpy, no_grad_func


# pylint: disable=invalid-name
def traj_plot(ts, samples, xlabel, ylabel, title="", fsave="img.png"):
    """Plot every sample trajectory against time and save the figure.

    Args:
        ts: tensor of time points; its length must match the T axis of
            ``samples``.
        samples: time-major tensor of shape (T, B, D). Each of the B
            trajectories is drawn as its own line. When D > 1, each coordinate
            becomes a separate line sharing that trajectory's label.
        xlabel: x-axis label.
        ylabel: y-axis label.
        title: figure title.
        fsave: output path handed to ``plt.savefig``.

    Note:
        Lines are labelled ``"sample {i}"``, but no legend is drawn here.
    """
    ts = ts.detach().cpu().numpy()
    # (T, B, D) -> (B, T, D) so iteration yields one trajectory at a time.
    samples = rearrange(samples, "t b d -> b t d").detach().cpu().numpy()
    fig = plt.figure()
    try:
        for i, sample in enumerate(samples):
            # ``sample`` is (T, D). matplotlib draws one line per column,
            # which matches the old ``sample.flatten()`` output for D == 1.
            # Flattening, by contrast, breaks when D > 1.
            plt.plot(ts, sample, marker="x", label=f"sample {i}")
        plt.title(title)
        plt.xlabel(xlabel)
        plt.ylabel(ylabel)
        plt.savefig(fsave)
    finally:
        # Release the figure even if plotting or saving fails.
        plt.close(fig)


@no_grad_func
def drfit_surface(
    model,
    x_range=(-3.0, 3.0),
    n_x=120,
    t_range=(0.0, 0.99),
    n_t=100,
    device="cuda",
):
    """Evaluate ``model.f_func`` on a regular (x, t) grid for surface plots.

    The defaults reproduce the original fixed grid: 120 points in [-3, 3]
    and 100 time points in [0, 0.99].

    Args:
        model: object exposing ``f_func(t, xs)``. It is moved to ``device``
            in place.
        x_range: (low, high) bounds of the spatial grid.
        n_x: number of spatial grid points.
        t_range: (low, high) bounds of the time grid. The default stops at
            0.99, presumably to avoid evaluating at t = 1; confirm.
        n_t: number of time grid points.
        device: device for the model and the grid tensors.

    Returns:
        Tuple ``(tt, xx, zz)`` of numpy arrays. ``tt`` and ``xx`` come from
        ``np.meshgrid(t, x)`` (default "xy" indexing), so both have shape
        (n_x, n_t). ``zz`` stacks one column of ``f_func`` outputs per time
        point.
    """
    model.to(device)
    # Spatial points as a column vector, shape (n_x, 1).
    xs = th.linspace(x_range[0], x_range[1], n_x, device=device).view(-1, 1)
    ts = th.linspace(t_range[0], t_range[1], n_t, device=device)

    # Concatenate along dim 1: one column per time point. This assumes
    # f_func(t, xs) returns shape (n_x, 1), giving zz (n_x, n_t). TODO confirm.
    values = th.cat([model.f_func(cur_t, xs) for cur_t in ts], dim=1)

    x, t, zz = as_numpy([xs, ts, values])
    tt, xx = np.meshgrid(t, x)
    return tt, xx, zz


@no_grad_func
def generate_samples(
    model, n_samples=2000, dim=2, t1=1.0, dt=0.01, device="cuda"
):
    """Integrate ``model`` from a zero initial state and return terminal samples.

    The defaults reproduce the original behaviour: 2000 two-dimensional
    samples, integrated over [0, 1] with step 0.01 on CUDA.

    Args:
        model: SDE object passed to ``torchsde.sdeint``. It is moved to
            ``device`` in place.
        n_samples: number of independent samples (batch size).
        dim: state dimension of the initial condition.
        t1: final integration time (integration starts at 0).
        dt: fixed solver step size.
        device: device for the model and the state tensors.

    Returns:
        numpy array of shape (n_samples, 1) holding the first state
        coordinate at time ``t1``.
    """
    # Put the model on the same device as the initial state; otherwise
    # sdeint fails with a device mismatch.
    model.to(device)
    # Force float32 explicitly, as the original ``.float()`` calls did.
    x0 = th.zeros((n_samples, dim), dtype=th.float32, device=device)
    ts = th.tensor([0.0, t1], dtype=th.float32, device=device)
    ys = torchsde.sdeint(model, x0, ts, dt=dt)
    # The last entry along dim 0 is the state at the final requested time t1.
    y1 = ys[-1]
    return as_numpy(y1[:, :1])
