import numpy as np
import torch
from torch import Tensor, Size
from torch.distributions import Distribution, Normal


class WienerProcess(Distribution):
    """Scaled Wiener process (Brownian motion) evaluated on a grid of times.

    A sample is built by drawing independent Gaussian increments with
    standard deviation ``sigma * sqrt(dt)`` and cumulatively summing them
    along the (last) time axis.  The first increment spans from time zero to
    ``times[..., 0]``, so a grid that starts at ``0`` yields paths that start
    at exactly ``0``.

    Args:
        times: Grid of time points; increments are taken along its last
            dimension.  Negative increments produce NaNs in samples.
        sigma: Scale multiplying the unit-variance noise.
        batch_shape: Batch shape forwarded to ``Distribution``.
        event_shape: Shape of the features carried at every time point.  The
            full event shape of the distribution is
            ``times.shape + event_shape``.
        validate_args: Forwarded to ``Distribution``.
    """

    # There are no constrained parameters.  Declaring this explicitly avoids
    # the "does not define `arg_constraints`" warning at construction time and
    # keeps ``repr()`` (which reads ``arg_constraints``) from raising.
    arg_constraints = {}
    has_rsample = True

    def __init__(
        self,
        times: Tensor,
        sigma=1.0,
        batch_shape=torch.Size(),
        event_shape=torch.Size(),
        validate_args=None,
    ):
        super().__init__(
            batch_shape=batch_shape,
            event_shape=times.shape + event_shape,
            validate_args=validate_args,
        )
        self.feature_dims = len(event_shape)
        # Trailing singleton dims let ``dts`` broadcast against the features.
        self.dts = self._increments(times).reshape(
            times.shape + Size([1] * self.feature_dims)
        )
        self.sigma = sigma

    @staticmethod
    def _increments(x: Tensor, dim: int = -1) -> Tensor:
        """Return first differences of ``x`` along ``dim``, keeping entry 0 as is."""
        out = x.clone()
        n = x.size(dim)
        if n > 1:
            out.narrow(dim, 1, n - 1).sub_(x.narrow(dim, 0, n - 1))
        return out

    def rsample(self, sample_shape=torch.Size()):
        """Draw reparameterized paths.

        Returns:
            Tensor of shape ``sample_shape + batch_shape + event_shape`` with
            the same dtype and device as the stored time increments.
        """
        noise = (
            self.dts.new_empty(self._extended_shape(sample_shape))
            .normal_(0, 1)
            .mul(self.sigma)
        )
        steps = self.dts.sqrt() * noise
        # The time axis sits just in front of the feature dimensions.
        return steps.cumsum(dim=-(self.feature_dims + 1))

    def log_prob(self, times: Tensor, values: Tensor):
        """Log-density of ``values`` observed at ``times``.

        Note that, unlike ``Distribution.log_prob``, this takes the time grid
        as an explicit first argument instead of using the grid given at
        construction.

        Steps whose time increment is not strictly positive (for example the
        first point of a grid starting at zero) carry no density and are
        ignored, as are steps whose log-density evaluates to NaN.

        Args:
            times: Time points; increments are taken along the last dimension.
            values: Process values whose time axis is dimension
                ``-(feature_dims + 1)``.

        Returns:
            Log-probabilities summed over the time axis and all feature
            dimensions.
        """
        assert len(times.shape) <= len(values.shape) - self.feature_dims
        time_dim = -(self.feature_dims + 1)

        dts = self._increments(times)
        dts = dts.reshape(dts.shape + Size([1] * self.feature_dims))
        deltas = self._increments(values, time_dim)

        # A zero increment would give ``Normal`` a zero scale, which is
        # rejected by argument validation and otherwise yields NaNs (and NaN
        # gradients).  Swap in a harmless scale and mask those steps out.
        positive = dts > 0
        safe_dts = torch.where(positive, dts, torch.ones_like(dts))
        step_log_prob = Normal(
            torch.zeros((), device=values.device),
            self.sigma * safe_dts.sqrt(),
            validate_args=False,
        ).log_prob(deltas)
        step_log_prob = torch.where(
            positive, step_log_prob, torch.zeros_like(step_log_prob)
        )

        return step_log_prob.reshape(*deltas.shape[:time_dim], -1).nansum(-1)


class BrownianBridge(Distribution):
    """Brownian bridge pinned to a sequence of observations.

    ``times`` and ``obs_times`` should be one dimensional tensors
    and ``obs`` should have shape ``[batches, len(obs_times), features]``.

    Each grid time is assigned to the pair of consecutive observations that
    brackets it.  A Wiener path is then shifted and linearly corrected on
    every such interval so that it takes the observed values at the
    observation times.  Grid times outside the observed range reuse the first
    or last interval.

    Args:
        times: Time grid on which paths are produced.
        obs_times: Observation times (at least two).
        obs: Observed values; the second-to-last dimension indexes
            ``obs_times`` and the last one the features.
        sigma: Scale of the driving Wiener process; only used when ``wiener``
            is not supplied.
        wiener: Optional pre-built driving process.  Its ``rsample`` must
            return paths whose trailing dimensions are
            ``[len(times), features]``.
        batch_shape: Extra leading batch shape, prepended to the batch
            dimensions of ``obs``.
        validate_args: Forwarded to ``Distribution``.

    Raises:
        ValueError: If fewer than two observation times are given.
    """

    # There are no constrained parameters.  Declaring this explicitly avoids
    # the "does not define `arg_constraints`" warning at construction time and
    # keeps ``repr()`` (which reads ``arg_constraints``) from raising.
    arg_constraints = {}
    has_rsample = True

    def __init__(
        self,
        times: Tensor,
        obs_times: Tensor,
        obs: Tensor,
        sigma=1.0,
        wiener: "WienerProcess | None" = None,
        batch_shape=Size(),
        validate_args=None,
    ):
        feature_shape = Size([obs.size(-1)])
        batch_shape = batch_shape + obs.shape[:-2]
        super().__init__(
            batch_shape=batch_shape,
            event_shape=times.shape + feature_shape,
            validate_args=validate_args,
        )
        if obs_times.size(0) < 2:
            raise ValueError(
                "BrownianBridge needs at least two observation times, got "
                f"{obs_times.size(0)}"
            )

        if wiener is None:
            wiener = WienerProcess(
                times,
                sigma,
                batch_shape=batch_shape,
                event_shape=feature_shape,
            )
        self.wiener = wiener

        # Trailing singleton dim so that times broadcast against features.
        self.times = times.reshape(times.shape + Size([1]))
        self.obs_times = obs_times
        self.obs = obs

        # Grid index of the last time point at or before each observation.
        # ``bucketize(..., right=True)`` returns the first index strictly
        # *after* the observation time, hence the ``- 1``; without it the
        # Wiener path is read one step too late and samples miss the
        # observations.
        time_idx = (torch.bucketize(obs_times, times, right=True) - 1).clamp(
            min=0,
            max=times.size(0) - 1,
        )
        # For every grid time: index of the observation closing its interval
        # (``upper``) and of the one opening it (``lower``).
        upper = torch.bucketize(times, obs_times, right=True).clamp(
            min=1,
            max=obs_times.size(0) - 1,
        )
        lower = upper - 1

        self.t0 = self.obs_times[lower].reshape(self.times.shape)
        self.deltaT = self.obs_times[upper].reshape(self.times.shape) - self.t0
        self.y0 = self.obs.index_select(-2, lower)
        self.deltaY = self.obs.index_select(-2, upper) - self.y0
        self.w0_indices = time_idx[lower]
        self.w1_indices = time_idx[upper]

    def rsample(self, sample_shape=torch.Size()):
        """Draw reparameterized bridge paths.

        Returns:
            Tensor whose trailing dimensions are ``[len(times), features]``,
            preceded by the sample and batch dimensions of the driving
            Wiener sample.
        """
        w: Tensor = self.wiener.rsample(sample_shape)

        time_dim = -len(self.event_shape)
        w0 = w.index_select(time_dim, self.w0_indices)
        wT = w.index_select(time_dim, self.w1_indices)

        # Shift the free path to start at y0, then remove, proportionally to
        # the elapsed fraction of the interval, the mismatch between where
        # the free path ends and where the next observation lies.
        elapsed = (self.times - self.t0) / self.deltaT
        return self.y0 + (w - w0) - elapsed * (wT - w0 - self.deltaY)


if __name__ == "__main__":
    # Visual demo: draw a handful of observations from a Wiener process,
    # sample many bridge paths through them and plot both.
    import matplotlib.pyplot as plt

    n_paths = 1000
    grid = torch.linspace(0, 1, 1000)
    knots = torch.linspace(0, 1, 4)
    knot_values = WienerProcess(knots).sample()

    # The bridge expects a trailing feature dimension on the observations.
    bridge = BrownianBridge(grid, knots, knot_values.unsqueeze(-1))
    paths = bridge.sample(Size([n_paths])).squeeze(-1)

    # One curve per sampled path; observations are overlaid as red dots.
    plt.plot(grid.numpy(), paths.t().numpy())
    plt.plot(knots.numpy(), knot_values.numpy(), "ro")
    plt.show()
