from os.path import join, exists, dirname
from os import makedirs

import numpy as np
import torch
from torch import nn, optim, Size
from gpytorch import means, kernels, mlls
from gpytorch.models import ApproximateGP
from gpytorch.distributions import MultitaskMultivariateNormal, MultivariateNormal
from gpytorch.variational import (
    CholeskyVariationalDistribution,
    VariationalStrategy,
    IndependentMultitaskVariationalStrategy,
    GridInterpolationVariationalStrategy,
)

from .sde_systems import make_van_der_pol_paths, VanDerPol

# Default torch device for this module: the GPU when CUDA is available to
# PyTorch, otherwise the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"


class DriftField(ApproximateGP):
    """Variational GP model of a drift (vector) field on a bounded box.

    The model has one independent output task per input dimension. Its
    variational distribution lives on a regular grid of ``grid_size`` points
    per dimension spanning ``bounds``, and a learnable scalar ``log_sigma``
    is carried alongside the GP (exposed through :attr:`sigma`).
    """

    def __init__(self, bounds, grid_size=32):
        """Build the variational strategy, the kernel and the noise parameter.

        Args:
            bounds: sequence of ``[min, max]`` pairs, one per input dimension.
                Its length fixes both the input dimension and the number of
                output tasks.
            grid_size: number of grid points along each dimension.
        """
        dims = len(bounds)
        task_batch = Size([dims])

        # One Cholesky-parameterised variational distribution per task, each
        # over all grid_size ** dims grid points.
        grid_posterior = CholeskyVariationalDistribution(
            int(grid_size ** dims),
            batch_shape=task_batch,
        )
        # The strategy is given a reference to this model, so it is built
        # before super().__init__() is invoked.
        grid_strategy = GridInterpolationVariationalStrategy(
            self,
            grid_size=grid_size,
            grid_bounds=bounds,
            variational_distribution=grid_posterior,
        )
        super().__init__(
            IndependentMultitaskVariationalStrategy(
                grid_strategy,
                num_tasks=dims,
            )
        )
        self.bounds = bounds

        # Scaled RBF kernel (one set of hyperparameters per task), evaluated
        # through grid interpolation over the same box as the strategy.
        scaled_rbf = kernels.ScaleKernel(
            kernels.RBFKernel(batch_shape=task_batch),
            batch_shape=task_batch,
        )
        self.covar = kernels.GridInterpolationKernel(
            scaled_rbf,
            grid_size=grid_size,
            grid_bounds=bounds,
        )
        # Log of the scalar returned by `sigma`; starts at log(1) = 0.
        self.log_sigma = nn.Parameter(torch.zeros(()))

    def __call__(self, x):
        """Evaluate the model at ``x`` after clipping ``x`` into ``self.bounds``.

        Points outside the box are moved onto its boundary, coordinate by
        coordinate, before being passed to the parent ``__call__``.
        """
        limits = x.new_tensor(self.bounds)
        lower, upper = limits[:, 0], limits[:, 1]
        clipped = torch.max(torch.min(x, upper), lower)
        return super().__call__(clipped)

    @property
    def sigma(self):
        """Positive scalar ``exp(log_sigma)``."""
        return self.log_sigma.exp()

    def forward(self, x):
        """Return a zero-mean ``MultivariateNormal`` with covariance ``self.covar(x)``.

        The mean has shape ``(x.size(-1), x.size(-2))``, i.e. one row per
        input dimension and one column per point.
        """
        n_points, n_dims = x.size(-2), x.size(-1)
        zero_mean = x.new_zeros(n_dims, n_points)
        return MultivariateNormal(zero_mean, self.covar(x))


class Likelihood:
    """Path-integral likelihood of observations under a sampled GP drift.

    Wraps a ``PathInt`` built from a ``BrownianBridge`` over the observations
    and uses :meth:`sample_drift` as its drift callback.
    """

    def __init__(self, times, obs_times, obs, drift_func: "DriftField"):
        """Create the bridge and the path integral.

        Args:
            times: time grid handed to ``BrownianBridge``.
            obs_times: observation times handed to ``BrownianBridge``.
            obs: observed values handed to ``BrownianBridge``.
            drift_func: the drift model; called on points and expected to
                return a distribution with ``rsample()``, and to expose
                ``sigma``.
        """
        from .piis import PathInt
        from .brownianbridge import BrownianBridge

        # Assign before building PathInt: `sample_drift` reads
        # `self.drift_func`, and it is handed to PathInt as a callback below.
        self.drift_func = drift_func
        self.path_int = PathInt(
            BrownianBridge(
                times,
                obs_times,
                obs,
                batch_shape=Size([1]),
            ),
            self.sample_drift,
        )

    def sample_drift(self, t, x):
        """Draw one reparameterised drift sample at the states ``x``.

        Args:
            t: time argument of the drift callback; unused here.
            x: tensor of states whose last dimension is the state dimension.

        Returns:
            A tensor with the same shape as ``x``.
        """
        # Flatten every leading dimension into a list of points. The state
        # dimension is taken from `x` itself instead of being fixed to 2.
        points = x.reshape(-1, x.size(-1))
        mvn = self.drift_func(points)
        sample = mvn.rsample()
        return sample.reshape(x.shape)

    def log_prob(self):
        """Return the summed log-probability from the path integral.

        The current ``drift_func.sigma`` is copied onto
        ``path_int.bridge.wiener.sigma`` first, so the bridge uses the latest
        value of the learnable parameter.
        """
        self.path_int.bridge.wiener.sigma = self.drift_func.sigma
        return self.path_int.log_prob().sum()


def train(
    drift: "DriftField",
    obs_times,
    obs,
    *,
    callback=None,
    steps=2000,
    **kwargs,
):
    """Fit ``drift`` to the observations by minimising a variational loss.

    The loss is ``(-log_prob + KL) / obs.numel()``, where ``log_prob`` comes
    from a :class:`Likelihood` over a 1025-point time grid spanning the
    observation times, and ``KL`` is the variational strategy's KL
    divergence. Optimisation uses Adam (lr 0.1) with an exponential
    learning-rate decay of 0.999 per step.

    Args:
        drift: model to train; put in train mode on entry and eval mode on
            exit.
        obs_times: tensor of observation times.
        obs: tensor of observations.
        callback: optional ``callback(step, times, losses)`` invoked every
            10 steps under ``torch.no_grad()``. The timer is stopped while it
            runs, so callback time is not counted in ``times``.
        steps: number of optimisation steps.
        **kwargs: accepted for call-site compatibility and ignored.
    """
    from .training import Timer

    drift.train()
    # Build the integration grid on the device of the observations. The
    # module-level `device` may differ from where the caller put the data,
    # which would mix devices inside the bridge.
    grid = torch.linspace(
        obs_times.min(),
        obs_times.max(),
        1025,
        device=obs_times.device,
    )
    like = Likelihood(grid, obs_times, obs, drift)

    def loss_fn():
        # Negative ELBO, normalised by the number of observed scalars.
        return (
            -like.log_prob() + drift.variational_strategy.kl_divergence()
        ) / obs.numel()

    opt = optim.Adam(drift.parameters(), lr=1e-1)
    sched = optim.lr_scheduler.ExponentialLR(opt, gamma=0.999)

    times, losses = [], []
    timer = Timer()
    timer.start()
    for step in range(1, steps + 1):
        opt.zero_grad()

        loss = loss_fn()
        loss.backward()

        opt.step()
        sched.step()
        times.append(timer.current)
        losses.append(loss.item())

        # Gradients are still populated here (they are only cleared at the
        # top of the next iteration), so the callback may inspect them.
        if callback is not None and step % 10 == 0:
            with torch.no_grad():
                timer.stop()
                callback(step, times, losses)
                timer.start()

    drift.eval()


def make_data():
    """Return ``(times, paths)`` for the Van der Pol data set, cached on disk.

    If ``.data/van_der_pol.pt`` exists, its ``"t"`` and ``"x"`` entries are
    returned. Otherwise 64 paths are simulated on 100 time points in
    ``[0, 10]``, written to that file and returned.

    Returns:
        Tuple ``(times, x)``; ``x`` is the simulator output with its first
        two dimensions swapped.
    """
    fname = ".data/van_der_pol.pt"
    if exists(fname):
        data = torch.load(fname)
        return data["t"], data["x"]

    times = torch.linspace(0, 10, 100)
    # torchsde wants (time, batches, features)
    # but we need (batches, time, features)
    x = make_van_der_pol_paths(times, 64).transpose(0, 1)

    # Create the cache directory first; torch.save does not create missing
    # parent directories, so a fresh checkout would otherwise fail here.
    makedirs(dirname(fname), exist_ok=True)
    torch.save(
        {
            "t": times,
            "x": x,
        },
        fname,
    )

    return times, x


def main():
    """Train a ``DriftField`` on the Van der Pol data and log its progress.

    Every time the training callback fires, the relative squared error of
    the predictive mean against the true Van der Pol drift is recorded, the
    running log is written to ``.data/van_der_pol/1.pt``, and the model
    state is checkpointed whenever that error is at least as good as the
    best seen so far.
    """
    np.random.seed(1)
    torch.manual_seed(1)

    t, x = make_data()
    t = t.to(device)
    x = x.to(device)

    # Bounding box of the data in each coordinate, padded by one unit.
    bounds = [
        [x[..., d].min().item() - 1, x[..., d].max().item() + 1]
        for d in (0, 1)
    ]
    drift = DriftField(bounds).to(device)

    fname = ".data/van_der_pol/1.pt"
    out_dir = dirname(fname)
    checkpoint_path = join(out_dir, "checkpoint.pt")

    points = x.reshape(-1, 2)
    vdp = VanDerPol().f(None, points)
    makedirs(out_dir, exist_ok=True)

    norms = []
    best = None

    def callback(step, times, losses):
        nonlocal best

        # Report parameters whose gradient is missing or contains NaNs.
        for name, param in drift.named_parameters():
            grad = param.grad
            if grad is None or grad.isnan().any():
                print(name, "had no/nan grad")

        # Mean squared error of the predictive mean, relative to the mean
        # squared magnitude of the true drift.
        residual = vdp - drift(points).mean
        norms.append(residual.square().mean() / vdp.square().mean())

        print(
            f"step [{step}], time [{int(times[-1])}], loss [{losses[-1]:.2f}], l2diff [{norms[-1]}]"
        )
        history = {
            "times": times,
            "losses": losses,
            "l2diff": norms,
        }
        torch.save(history, fname)

        latest = norms[-1]
        if best is None or latest <= best:
            best = latest
            snapshot = {
                "norm": best,
                "checkpoint": drift.state_dict(),
            }
            torch.save(snapshot, checkpoint_path)

    train(
        drift,
        t,
        x,
        callback=callback,
        steps=100_000,
    )


def plot():
    """Render diagnostic figures for the checkpointed drift model.

    Loads the data and ``.data/van_der_pol/checkpoint.pt``, evaluates the
    model on a 50x50 grid spanning the data range, and writes two figures:

    * ``.data/figures/gp-err.png``: histogram of the standardised error of
      the predictive mean against the true Van der Pol drift, with the
      10/25/75/90 % quantiles marked in red.
    * ``.data/figures/vdp-gp.png``: streamlines of the predictive mean over
      a contour map of the negative log of the summed predictive standard
      deviation.

    Each figure is drawn on its own matplotlib figure so that the histogram
    does not leak into the streamline plot.
    """
    import matplotlib.pyplot as plt
    from matplotlib import cm

    _, paths = make_data()
    bounds = [
        [paths[..., 0].min().item() - 1, paths[..., 0].max().item() + 1],
        [paths[..., 1].min().item() - 1, paths[..., 1].max().item() + 1],
    ]
    print(bounds)
    drift = DriftField(bounds).cpu()

    fname = ".data/van_der_pol/checkpoint.pt"
    obj = torch.load(fname, map_location="cpu")
    drift.load_state_dict(obj["checkpoint"])
    grid = drift.variational_strategy.base_variational_strategy.grid
    print(grid.min(0)[0], grid.max(0)[0])

    plt.rcParams.update({'font.size': 15})

    # savefig does not create missing directories.
    makedirs(".data/figures", exist_ok=True)

    paths = paths.transpose(0, 1).cpu().numpy()
    # Evaluation grid over the range actually covered by the data
    # (rows index the second coordinate, columns the first).
    y, x = np.mgrid[
        paths[..., 1].min().item() : paths[..., 1].max().item() : 50j,
        paths[..., 0].min().item() : paths[..., 0].max().item() : 50j,
    ]

    # Everything that ends in .numpy() is computed under no_grad so no
    # tensor handed to numpy/matplotlib can carry autograd state.
    with torch.no_grad():
        grid_points = torch.FloatTensor(np.stack((x, y), axis=-1))
        XY = grid_points.reshape(-1, 2)
        print(XY.min(0)[0], XY.max(0)[0])
        mvn = drift(XY)
        print(mvn)
        print(mvn.loc.shape, mvn.covariance_matrix.shape)
        mean = mvn.mean
        stddev = mvn.stddev
        std = stddev.sum(-1)
        print(std.min(), std.max())

        vdp = VanDerPol().f(None, grid_points)
        # Error of the predictive mean in units of predictive std.
        errs = ((vdp.reshape(mean.shape) - mean) / stddev).reshape(-1).numpy()

    # --- Figure 1: standardised error histogram ---------------------------
    plt.figure()
    plt.xlim(
        np.quantile(errs, 0.05),
        np.quantile(errs, 0.95),
    )
    for q in (0.1, 0.25, 0.75, 0.90):
        plt.plot([np.quantile(errs, q)] * 2, [0, 0.25], color="red")
    plt.hist(
        errs,
        bins=500,
        density=True,
    )
    plt.title("Deviation of predictive mean from ground truth")
    plt.ylabel("rel. freq.")
    plt.xlabel("standard deviations from mean")
    plt.savefig(".data/figures/gp-err.png", bbox_inches="tight")

    # --- Figure 2: predictive mean streamlines over uncertainty -----------
    # Start a fresh figure; otherwise the histogram, its title and its axis
    # labels would also end up in this image.
    plt.figure()
    plt.contourf(
        x,
        y,
        -std.log().numpy().reshape(x.shape),
        cmap=cm.Blues,
        alpha=0.8,
        zorder=-1,
    )
    # Streamlines get thinner and paler where the model is less certain.
    plt.streamplot(
        x,
        y,
        mean[:, 0].reshape(x.shape).numpy(),
        mean[:, 1].reshape(y.shape).numpy(),
        color=-torch.log(1 + 0.8 * (std - std.min()) / std.max())
        .reshape(x.shape)
        .numpy(),
        linewidth=2 * (1 - std / std.max()).reshape(x.shape).numpy(),
        cmap=cm.Blues,
        zorder=1,
    )
    plt.xlim(x.min(), x.max())
    plt.ylim(y.min(), y.max())
    plt.savefig(".data/figures/vdp-gp.png")


# Script entry point: train the model first, then plot from the checkpoint
# that training wrote. The module uses relative imports, so it has to be run
# as part of its package (e.g. via ``python -m``) rather than as a bare file.
if __name__ == "__main__":
    main()
    plot()
