import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

from torchsde import SDEIto
import torch
from torch import nn

from .mlpsde import MlpSde
from .training import train


# Device string used when moving the model and data in ``main``: "cuda" when
# PyTorch reports an available CUDA device, otherwise "cpu".
device = "cuda" if torch.cuda.is_available() else "cpu"


def prepare_data():
    """Load the chickenpox CSV and return ``(t, x)`` as float32 tensors.

    Reads ``.data/chicken/hungary_chickenpox.csv``, parses its ``Date``
    column day-first and converts it into offsets, measured in weeks, from
    the first row.  Every remaining column is treated as an observed series.

    Returns:
        A tuple ``(t, x)``.  ``t`` is a 1-D tensor with one entry per CSV
        row.  ``x`` is a 2-D tensor (CSV rows x remaining columns) in which
        every row has been divided by that row's standard deviation.
    """
    df: pd.DataFrame = pd.read_csv(".data/chicken/hungary_chickenpox.csv")

    dates: pd.Series = pd.to_datetime(df["Date"], dayfirst=True)
    # Positional access: the offset is always relative to the first row,
    # whatever the index labels happen to be.
    weeks: pd.Series = (dates - dates.iloc[0]) / pd.Timedelta(weeks=1)

    # ``np.float`` was merely an alias of the builtin ``float`` and has been
    # removed from recent NumPy releases; the builtin yields the same float64.
    values = df.drop(columns="Date").to_numpy().astype(float)

    t = torch.tensor(weeks.to_numpy(), dtype=torch.float32)
    x = torch.tensor(values, dtype=torch.float32)
    # NOTE(review): dim=1 scales each row (one time step across all columns),
    # not each column over time -- confirm this is the intended normalisation.
    return t, x / x.std(dim=1, keepdim=True)


class Model(SDEIto):
    """Ito SDE with a bias-free linear drift and learned diagonal noise.

    The drift is ``lin(x)``.  The diffusion does not depend on the state or
    on time: it is ``exp(log_sigma)`` broadcast to the shape of ``x``.
    """

    def __init__(self, features: int):
        """Create the drift layer and the log-scale noise parameter.

        Args:
            features: Size of the last dimension of the state.
        """
        super().__init__(noise_type="diagonal")
        self.lin = nn.Linear(features, features, bias=False)
        # Kept in log-space so the exponentiated scale is always positive;
        # the zero initialisation corresponds to a scale of one.
        self.log_sigma = nn.Parameter(torch.zeros(features))

    @property
    def sigma(self):
        """Per-feature noise scale, ``exp(log_sigma)``."""
        return torch.exp(self.log_sigma)

    def f(self, t, x):
        """Drift term: the linear layer applied to ``x`` (``t`` is unused)."""
        drift = self.lin(x)
        return drift

    def g(self, t, x):
        """Diffusion term: ``sigma`` broadcast to the shape of ``x``."""
        return self.sigma * torch.ones_like(x)


def main():
    """Train an ``MlpSde`` on the prepared data and save model and results.

    The trained module is written to ``.data/chicken/checkpoint.pt`` (after
    being moved back to the CPU) and the value returned by ``train`` to
    ``.data/chicken/resdata.pt``.
    """
    times, series = prepare_data()
    series = series.unsqueeze(1)  # add a singleton axis at position 1

    # The linear ``Model(series.size(-1))`` can be swapped in here instead.
    model = MlpSde([20, 128, 128, 20])

    def report(step, res):
        # Print the most recent loss entry for this step.
        print(f"[{step}] loss=[{res.loss[-1]}]")

    # Step size: one thousandth of the largest time stamp.
    step_size = float(times.max() * 1e-3)

    outcome = train(
        model.to(device),
        times.to(device),
        series.to(device),
        dt=step_size,
        callback=report,
        use_sdeint=False,
        steps=5_000,
    )

    torch.save(model.cpu(), ".data/chicken/checkpoint.pt")
    torch.save(outcome, ".data/chicken/resdata.pt")


def sample_grads(
    sde: SDEIto,
    deltatime=1e-1,
    use_sdeint=False,
    dt=1e-2,
    n_paths=1,
    n_samples=100,
    **kwargs,
):
    """Collect repeated samples of the loss gradient w.r.t. ``sde``'s parameters.

    Only the first ten observations returned by ``prepare_data`` are used,
    and their time stamps are rescaled so that the second one equals
    ``deltatime``.

    Args:
        sde: Module whose parameter gradients are sampled.
        deltatime: Value the second time stamp is rescaled to.
        use_sdeint: Use ``SdeIntLoss`` (wrapping an MSE loss) when true,
            otherwise ``PathIntLoss``.
        dt: Step size handed to the chosen loss.
        n_paths: Number of paths/samples handed to the chosen loss.
        n_samples: Number of gradient samples to draw.
        **kwargs: Accepted but not used.

    Returns:
        A 2-D tensor with one row per sample; each row concatenates the
        flattened gradients of all parameters that received a gradient.
    """
    times, obs = prepare_data()
    times = times[:10]
    obs = obs[:10].unsqueeze(1)
    times = times / times[1] * deltatime

    if not use_sdeint:
        from .piis import PathIntLoss

        loss_fn = PathIntLoss(sde, times, obs, dt, n_paths)
    else:
        from .sdeintloss import SdeIntLoss

        loss_fn = SdeIntLoss(nn.MSELoss(), n_samples=n_paths, dt=dt, use_adjoint=False)

    samples = []
    for draw in range(n_samples):
        loss = loss_fn(sde, times, obs)
        # Reset gradients to ``None`` so parameters untouched by this
        # backward pass are skipped by the filter below.
        sde.zero_grad(True)
        loss.backward()
        pieces = [
            param.grad.clone().reshape(-1)
            for param in sde.parameters()
            if param.grad is not None
        ]
        samples.append(torch.cat(pieces))
        # Lightweight progress indicator, printed every tenth draw.
        if draw % 10 == 0:
            print(draw, end=", ", flush=True)
    print()
    return torch.stack(samples)


def _load_grad_variances(directory, n):
    """Return the clamped per-column gradient variance of runs ``0 .. n-1``."""
    variances = []
    for i in range(n):
        grads = torch.load(f"{directory}/{i}.pt")["grads"]
        # Clamp from below so zero variances stay drawable on a log axis.
        variances.append(grads.var(0).clamp(1e-20).numpy())
    return variances


def _save_variance_violins(variances, n, title, path, ylabel=None):
    """Draw one log-scale violin plot of ``variances`` and save it to ``path``."""
    fig, ax = plt.subplots(1, 1)
    fig.tight_layout()

    ax.violinplot(variances, points=10000)
    ax.set_title(title)
    ax.set_yscale("log")
    ax.set_xlabel("time between observations")
    if ylabel is not None:
        ax.set_ylabel(ylabel)
    ax.set_ylim(1e-17, 1e20)
    # Violins sit at positions 1..n; labels start at 10^-2 and grow by one
    # decade per position.
    ax.set_xticks(list(range(1, n + 1)))
    ax.set_xticklabels(["$10^{" + f"{-2 + i}" + "}$" for i in range(n)])
    try:
        fig.savefig(path, bbox_inches="tight")
    finally:
        # Always release the figure, even when saving fails.
        plt.close(fig)


def plotvars(n):
    """Save violin plots of the gradient variance for both estimators.

    Loads ``{i}.pt`` for ``i`` in ``range(n)`` from
    ``.data/chicken/var-adj`` and ``.data/chicken/var-path``; each file must
    hold a dict with a ``"grads"`` tensor.  The variance over the first axis
    is clamped to at least ``1e-20`` and shown as one violin per file.

    Writes ``.data/figures/gradient-var-adj.png`` and
    ``.data/figures/gradient-var-pathint.png``.  Both figures are closed
    after saving.

    Args:
        n: Number of runs (files) to load and plot per estimator.
    """
    # Load everything first so a missing file fails before any figure is
    # written.
    adj_vars = _load_grad_variances(".data/chicken/var-adj", n)
    path_vars = _load_grad_variances(".data/chicken/var-path", n)

    _save_variance_violins(
        adj_vars,
        n,
        "SDE integration",
        ".data/figures/gradient-var-adj.png",
        ylabel="gradient variance",
    )
    # The second plot deliberately carries no y-axis label.
    _save_variance_violins(
        path_vars,
        n,
        "Importance sampling",
        ".data/figures/gradient-var-pathint.png",
    )


def plot_fit():
    """Plot simulated trajectories of the saved model against the data.

    Uses the first 60 observations from ``prepare_data``.  Every fifth
    observation serves as a start state from which 100 paths are simulated
    with ``torchsde.sdeint`` over the first five time stamps; the pieces are
    then laid end to end so they line up with the 60 observed time stamps.

    Each feature gets its own panel in a 4 x 5 grid: the data in red, the
    mean over paths in blue and a band of one standard deviation around it.
    The figure is shown with ``plt.show()``.
    """
    from torchsde import sdeint

    n_obs = 60  # observations to plot
    stride = 5  # a fresh simulation starts at every ``stride``-th observation
    n_paths = 100  # simulated paths per start state

    with torch.no_grad():
        t, x = prepare_data()
        t, x = t[:n_obs], x[:n_obs]
        features = x.size(-1)

        # NOTE(review): the checkpoint is a whole pickled module (see
        # ``main``).  Only load trusted files; newer PyTorch releases may
        # need ``weights_only=False`` here -- confirm for the installed version.
        sde = torch.load(".data/chicken/checkpoint.pt")

        # One copy of each start state per simulated path, flattened into
        # the batch axis.
        y0 = (
            x[::stride]
            .reshape(-1, 1, features)
            .repeat(1, n_paths, 1)
            .reshape(-1, features)
        )
        paths = sdeint(sde, y0, t[:stride], dt=t[1] * 1e-1)
        # (time, start, path, feature) -> (start, time, path, feature), then
        # merge start and time so the segments follow one another.
        y = (
            paths.reshape(stride, -1, n_paths, features)
            .transpose(0, 1)
            .reshape(-1, n_paths, features)
        )

        fig, axs = plt.subplots(4, 5)
        for i, ax in zip(range(features), axs.flat):
            yi = y[:, :, i]
            mean = yi.mean(dim=1).numpy()
            std = yi.std(dim=1).numpy()

            ax.plot(t, x[:, i], "r")
            ax.plot(t, mean, "b")
            ax.fill_between(t, mean + 1 * std, mean - 1 * std, color="blue", alpha=0.3)
        plt.show()


if __name__ == "__main__":
    import os

    # main()  # uncomment to train and save a checkpoint first
    torch.manual_seed(1)
    sde = MlpSde([20, 128, 128, 20])

    # (output directory, label in the progress message, stored method tag,
    # use_sdeint).  The order matters: both estimators draw from the same
    # seeded random stream, path-integral first.
    runs = (
        (".data/chicken/var-path", "path", "pathint", False),
        (".data/chicken/var-adj", "adj", "sdeint", True),
    )

    # Create every output directory up front so a missing folder cannot
    # throw away an expensive gradient computation when saving.
    for out_dir, *_ in runs:
        os.makedirs(out_dir, exist_ok=True)
    os.makedirs(".data/figures", exist_ok=True)

    for i, deltatime in enumerate((1e-2, 1e-1, 1e0, 1e1, 1e2)):
        for out_dir, label, method, use_sdeint in runs:
            fname = f"{out_dir}/{i}.pt"
            # Results are cached on disk; skip anything already computed.
            if os.path.exists(fname):
                continue
            print(f"computing grads -- {label}; delta = {deltatime}")
            grads = sample_grads(sde, deltatime=deltatime, use_sdeint=use_sdeint)
            torch.save(
                {
                    "grads": grads,
                    "deltatime": deltatime,
                    "method": method,
                },
                fname,
            )
            print("done")

    # NOTE(review): five spacings are computed above but only the first four
    # are plotted -- confirm that leaving out the last one is intended.
    plotvars(4)
