import math

import torch
from torch import nn, Tensor
from torchsde import SDEIto

from .brownianbridge import BrownianBridge, WienerProcess


class PathInt:
    """Monte Carlo estimator of an SDE's observation log-likelihood.

    Sample paths are drawn from ``bridge`` (which carries the observations
    ``obs`` at ``obs_times``). Each path is reweighted by the log Girsanov
    weight of the drift ``f`` relative to the driftless Wiener process with
    scale ``bridge.wiener.sigma``.
    """

    def __init__(self, bridge: BrownianBridge, f: nn.Module):
        self.bridge = bridge
        self.f = f

        # HACK: cache the Wiener process's step sizes without their first
        # entry. They are multiplied against the drift evaluated on
        # ``times[:-1]`` in ``generate_samples``, so presumably ``wiener.dts``
        # has one leading entry per grid point. TODO confirm against
        # BrownianBridge.
        self.dts = bridge.wiener.dts[1:]

    def log_prob(self):
        """Estimate the log-likelihood of the bridge's observations.

        Adds the Wiener-process log-density of the observation increments
        (taken relative to the first observation) to ``log(mean(exp(s)))``.
        Here ``s`` are the per-path log weights and the mean runs over the
        leading sample dimension.
        """
        _, log_weights = self.generate_samples()

        obs_times = self.bridge.obs_times
        obs = self.bridge.obs
        wiener_term = self.bridge.wiener.log_prob(
            obs_times[1:] - obs_times[0],
            obs[..., 1:, :] - obs[..., 0:1, :],
        )
        n_samples = log_weights.size(0)
        # logsumexp - log(n) == log of the mean importance weight.
        return wiener_term + log_weights.logsumexp(dim=0) - math.log(n_samples)

    def generate_samples(self):
        """Draw bridge paths together with their log importance weights.

        Returns:
            ``(ys, s)``. ``ys`` is the sampled paths reshaped to
            ``bridge._extended_shape()``. ``s`` is reshaped to
            ``bridge.batch_shape`` and equals
            ``sum(f / sigma**2 * dy) - 0.5 * sum(f**2 / sigma**2 * dt)``.
        """
        precision = 1 / (self.bridge.wiener.sigma ** 2)
        batch_shape = self.bridge.batch_shape

        # Collapse all batch dims into one, so dim 1 is the step axis paired
        # with ``times`` below.
        last_batch_dim = max(len(batch_shape) - 1, 0)
        paths = self.bridge.rsample().flatten(end_dim=last_batch_dim)
        increments = paths[:, 1:] - paths[:, :-1]

        # Drift evaluated at the left end of every step (Ito convention).
        drift = self.f(self.bridge.times[:-1], paths[:, :-1])
        scaled_drift = drift * precision
        stochastic = (scaled_drift * increments).flatten(start_dim=1).sum(dim=-1)
        quadratic = (scaled_drift * drift * self.dts).flatten(start_dim=1).sum(dim=-1)

        log_weights = (stochastic - 0.5 * quadratic).reshape(batch_shape)
        return paths.reshape(self.bridge._extended_shape()), log_weights


class PathIntLoss(nn.Module):
    """Negative path-integral log-likelihood of fixed observations under an SDE.

    A ``PathInt`` estimator is built once, on a ``BrownianBridge`` whose
    simulation grid is ``torch.arange(obs_times.min(), obs_times.max(), dt)``.
    That estimator is evaluated on every ``forward`` call.

    Args:
        sde: model providing the drift module ``sde.f`` and the scale
            ``sde.sigma``.
        obs_times: observation times.
        obs: observations. Dims 0 and 1 are swapped before being handed to
            the bridge; presumably (time, batch, dim) -> (batch, time, dim).
            TODO confirm.
        dt: simulation step size.
        n_samples: number of bridge paths per estimate.
    """

    def __init__(self, sde, obs_times, obs, dt, n_samples=100):
        super().__init__()
        # NOTE(review): torch.arange excludes its end point, so the grid stops
        # short of obs_times.max(). Confirm BrownianBridge accounts for that.
        self.piis = PathInt(
            BrownianBridge(
                times=torch.arange(obs_times.min(), obs_times.max(), dt).to(obs_times),
                obs_times=obs_times,
                obs=obs.transpose(0, 1),
                batch_shape=torch.Size([n_samples]),
            ),
            sde.f,
        )

    def forward(self, sde, obs_times, obs):
        """Return ``-sum(log_prob) / obs.numel()`` for the stored observations.

        The observations fixed at construction time are the ones scored.
        ``obs_times`` is unused, and ``obs`` only supplies the normaliser.
        Both the scale and the drift are read from ``sde``, so the estimate
        always reflects the model that is passed in.
        """
        # The estimator holds plain references (not registered submodules),
        # so refresh both model-dependent pieces from ``sde``. Previously only
        # sigma was refreshed, and the construction-time drift was kept.
        self.piis.bridge.wiener.sigma = sde.sigma
        self.piis.f = sde.f
        return -self.piis.log_prob().sum() / obs.numel()


if __name__ == "__main__":
    # Experiment: compare PathInt log-likelihood estimates with the exact
    # Ornstein-Uhlenbeck transition log-density as the state dimension grows.
    # Results go to .data/ou-<N>.pt, where N is the number of observation
    # times.
    from pathlib import Path

    import numpy as np
    from torch.distributions import Normal
    from torchsde import sdeint, SDEIto

    class OrnUhl(nn.Module):
        """OU drift ``f(t, x) = -x``."""

        def forward(self, t, x):
            return -x

    class OUSde(SDEIto):
        """OU SDE with drift ``-x`` and unit diagonal noise.

        Only used by the commented-out sdeint baseline below.
        """

        def __init__(self):
            super().__init__("diagonal")
            self.f = OrnUhl()

        def g(self, t, x):
            return torch.ones_like(x)

    times = torch.linspace(0, 10, 1000)
    obs_times = torch.linspace(0, 10, 40)

    n, prob, means, stddevs = [], [], [], []

    for dims in [1] + list(range(10, 101, 10)):
        obs = torch.sin(math.pi * obs_times / 10)
        obs = torch.stack([obs for _ in range(dims)], dim=-1)

        # Exact log-likelihood: sum of OU transition densities
        # N(x0 * e^{-dt}, (1 - e^{-2 dt}) / 2) between consecutive observations.
        log_prob = sum(
            Normal(
                x0 * torch.exp(-(t1 - t0)),
                torch.sqrt(1 / 2 * (1 - torch.exp(-2 * (t1 - t0)))),
            )
            .log_prob(x1)
            .sum()
            .item()
            for t0, t1, x0, x1 in zip(obs_times[:-1], obs_times[1:], obs[:-1], obs[1:])
        )
        # Alternative experiment (squared error vs. sdeint), kept for reference:
        # sqr_dev = sum(
        #     dims / 2 * (1 - torch.exp(-2 * (t))) + (x * x).sum()
        #     for t, x in zip(obs_times[1:], obs[1:])
        # )

        # def err(samples, batches):
        #     ys = sdeint(OUSde(), obs[:1].repeat(samples * batches, 1), obs_times)
        #     delta = obs[1:].unsqueeze(1) - ys[1:]
        #     return (
        #         (delta * delta)
        #         .reshape(-1, samples, batches, dims)
        #         .sum(-1)
        #         .sum(0)
        #         .mean(-1)
        #     ).numpy()

        # sqr_error = err(100, 100)

        bridge = BrownianBridge(
            times,
            obs_times,
            obs,
            batch_shape=torch.Size([100]),
        )

        pi = PathInt(bridge, OrnUhl())
        p = [pi.log_prob().item() for _ in range(100)]

        n.append(dims)
        prob.append(log_prob)
        # Store plain Python floats. np.float64 scalars are rejected by
        # torch.load's default weights_only=True unpickler (PyTorch >= 2.6).
        means.append(float(np.mean(p)))
        stddevs.append(float(np.std(p)))
        # prob.append(sqr_dev)
        # means.append(np.mean(sqr_error))
        # stddevs.append(np.std(sqr_error))
        print(f"{dims} & {prob[-1]:.2f} & {means[-1]:.2f} & {stddevs[-1]:.2f} \\\\")

    out_path = Path(".data") / f"ou-{len(obs_times)}.pt"
    # torch.save does not create missing parent directories. Without this,
    # the whole (long) run would be lost to an error at the very end.
    out_path.parent.mkdir(parents=True, exist_ok=True)
    torch.save(
        dict(
            n=n,
            prob=prob,
            means=means,
            stddevs=stddevs,
        ),
        out_path,
    )


def plotall():
    """Plot saved OU log-prob estimates against the exact values for several N.

    For each N in (4, 10, 20, 30, 40), reads ``.data/ou-N.pt``. It draws the
    exact log-prob (dashed), the mean estimate (solid), and a band of
    ``nstds`` standard deviations around the mean, all against dimension.
    """
    import numpy as np
    import matplotlib.pyplot as plt

    def plot(fname, color="blue", label=None, nstds=3):
        """Add one result file's curves to the current figure."""
        # weights_only=False: files written by this module's __main__ block
        # may hold numpy float64 scalars (np.mean / np.std results). The
        # restricted unpickler that torch.load uses by default since
        # PyTorch 2.6 rejects those. These are locally generated files; do
        # not point this at untrusted data.
        obj = torch.load(fname, weights_only=False)
        n = np.array(obj["n"])
        p = np.array(obj["prob"])
        m = np.array(obj["means"])
        s = np.array(obj["stddevs"])

        plt.plot(n, p, color=color, linestyle="dashed")
        plt.plot(n, m, color=color, linestyle="solid", label=label)
        plt.fill_between(
            n,
            m + nstds * s,
            m - nstds * s,
            color=color,
            alpha=0.2,
        )

    for f, c in ((4, "blue"), (10, "green"), (20, "orange"), (30, "cyan"), (40, "red")):
        plot(f".data/ou-{f}.pt", c, label=f"N = {f}")

    # plt.xscale("log")
    # plt.yscale("log")
    plt.xlabel("dimension")
    plt.ylabel("log-prob")
    plt.title("Log probability estimates vs ground truth")
    plt.legend()
    plt.show()
    # plt.savefig(".data/figures/ou-logprobs.png", bbox_inches="tight")
