import abc
import torch
import torchcde
import torchsde

from . import common


# Identity name mapping passed as `names=` to every torchsde.sdeint/sdeint_adjoint call in this file, so the solver
# uses the `drift` and `diffusion` attributes of the SDE object it is given.
_names = {'drift': 'drift', 'diffusion': 'diffusion'}


###############################################
#  Vector fields for generator+discriminator  #
#  These each define a drift and diffusion.   #
###############################################

class _SDEFunc(torch.nn.Module):
    """SDE object handed to torchsde: a drift callable and a diffusion callable plus the module that made them.

    Args:
        spec: the module the two callables were built from. It is stored as a submodule purely so that
            ``.parameters()`` etc. on this object propagate to it.
        drift: callable ``(t, h)`` returning the drift.
        diffusion: callable ``(t, h)`` returning the diffusion.
    """

    def __init__(self, spec, drift, diffusion):
        super().__init__()
        # Attributes torchsde reads to decide how to treat this SDE.
        self.noise_type, self.sde_type = 'general', 'stratonovich'
        # Assigning a Module registers it as a submodule, which propagates .parameters() etc.
        self.base = spec
        self.drift, self.diffusion = drift, diffusion


class _SDESpec(torch.nn.Module, metaclass=abc.ABCMeta):
    """Abstract description of an SDE's drift and diffusion in terms of some contextual inputs.

    Subclasses implement `drift` and `diffusion`: each takes the contextual inputs and returns a callable of
    ``(t, h)``. `contextualise` binds both to the same inputs and wraps the pair in an `_SDEFunc`.
    """

    @abc.abstractmethod
    def drift(self, *args, **kwargs):
        """Return the drift callable for the given contextual inputs."""
        raise NotImplementedError

    @abc.abstractmethod
    def diffusion(self, *args, **kwargs):
        """Return the diffusion callable for the given contextual inputs."""
        raise NotImplementedError

    def contextualise(self, *args, **kwargs):
        """Set (or reset) any inputs that do not come from the evolving state, e.g. class labels for a conditional GAN.

        Note that when using sdeint_adjoint, gradients will _not_ be computed with respect to any arguments passed
        here. That is fine for our purposes; it is just a limitation of how this is currently set up.
        """
        drift_fn = self.drift(*args, **kwargs)
        diffusion_fn = self.diffusion(*args, **kwargs)
        return _SDEFunc(self, drift_fn, diffusion_fn)


# Generator is a Neural SDE, and has both drift and diffusion
class _GeneratorFunc(_SDESpec):
    """Vector fields of the generator, a Neural SDE with both a drift and a diffusion.

    The state ``h`` has shape ``(batch_size, input_channels + hidden_channels)``. Its first ``input_channels``
    channels are the generated sample. The remaining ``hidden_channels`` channels are an evolving hidden state, which
    lets the generated sample be non-Markovian. Both vector fields read only time, the hidden state and the label
    ``y``, never the generated sample itself.
    """

    def __init__(self, input_channels, hidden_channels, hidden_hidden_channels, num_layers, noise_channels,
                 label_channels, **kwargs):
        super().__init__(**kwargs)

        self._input_channels = input_channels
        self._hidden_channels = hidden_channels
        self._noise_channels = noise_channels

        # Drift and diffusion are MLPs of the same size, both fed (time, hidden state, label).
        in_features = 1 + hidden_channels + label_channels
        self._drift = common.MLP(in_features, hidden_channels, hidden_hidden_channels, num_layers, tanh=True)
        self._diffusion = common.MLP(in_features, (input_channels + hidden_channels) * noise_channels,
                                     hidden_hidden_channels, num_layers, tanh=True)
        # Linear readout from the hidden state into sample space. Formally this is part of the drift: the sample's
        # drift is this readout applied to the hidden state's drift.
        self._readout = torch.nn.Linear(hidden_channels, input_channels)

        # Multiply the readout's gradients by 100, intended to increase the learning rate of this final linear map.
        for param in (self._readout.weight, self._readout.bias):
            param.register_hook(lambda grad: 100 * grad)

    def _time_hidden_label(self, t, h, y):
        # Build the MLP input: time (shape () broadcast to (batch_size, 1)), then the hidden part of h (the generated
        # sample channels are dropped), then the label y.
        batch_size = h.size(0)
        hidden = h[:, self._input_channels:]
        return torch.cat([t.expand(batch_size, 1), hidden, y], dim=1)

    def drift(self, y):
        """Return the drift ``(t, h) -> (batch_size, input_channels + hidden_channels)`` conditioned on ``y``."""
        def drift_(t, h):
            hidden_drift = self._drift(self._time_hidden_label(t, h, y))
            return torch.cat([self._readout(hidden_drift), hidden_drift], dim=1)
        return drift_

    def diffusion(self, y):
        """Return the diffusion ``(t, h) -> (batch_size, input_channels + hidden_channels, noise_channels)``.

        The diffusion is conditioned on ``y``.
        """
        def diffusion_(t, h):
            batch_size = h.size(0)
            flat = self._diffusion(self._time_hidden_label(t, h, y))
            return flat.view(batch_size, self._input_channels + self._hidden_channels, self._noise_channels)
        return diffusion_


# Discriminator is actually a Neural CDE.
class _DiscriminatorFunc(_SDESpec):
    """Vector fields of the discriminator, which is actually a Neural CDE (its diffusion is identically zero).

    The state ``h`` has shape ``(batch_size, 1 + hidden_channels)``. Channel 0 is the real/fake score for the sample,
    so its terminal value is the score for the whole sample. The remaining ``hidden_channels`` channels are the
    hidden state, as in a usual Neural CDE or RNN.
    """

    def __init__(self, input_channels, hidden_channels, hidden_hidden_channels, num_layers, noise_channels,
                 label_channels, **kwargs):
        super().__init__(**kwargs)

        self._noise_channels = noise_channels

        # The drift MLP sees (time, hidden state, control, label). There is no diffusion network: this is a Neural CDE.
        self._drift = common.MLP(1 + hidden_channels + input_channels + label_channels, hidden_channels,
                                 hidden_hidden_channels, num_layers, tanh=True)

        # Linear readout from the hidden state. Formally this is part of the drift; the terminal value of the channel
        # it drives is the real/fake score.
        self._readout = torch.nn.Linear(hidden_channels, 1)

        # Multiply the readout's gradients by 100, intended to increase the learning rate of this final linear map.
        for param in (self._readout.weight, self._readout.bias):
            param.register_hook(lambda grad: 100 * grad)

    def drift(self, X, y):
        """Return the drift ``(t, h, x=None) -> (batch_size, 1 + hidden_channels)``.

        ``X`` is the driving control as a function of time, used when ``x`` is not supplied. On real data, the control
        is computed from ``X``. On fake data (generated as we go along), it is passed in directly as ``x``, and ``X``
        may be None.
        """
        def drift_(t, h, x=None):
            hidden = h[:, 1:]  # drop the score channel
            if x is None:
                x = X(t)

            # (*)
            # Note that x enters this NCDE as
            #     dz = f(..., x(t)) dt
            # rather than as
            #     dz = f(...) dt + g(...) dx(t)
            # as in the paper. (The former is a special case of the latter; see Appendix C of the neural CDE paper.)
            # There's no deep reason for this; it is noted here for anyone comparing this code against the paper.
            features = torch.cat([t.expand(hidden.size(0), 1), hidden, x, y], dim=1)

            hidden_drift = self._drift(features)
            return torch.cat([self._readout(hidden_drift), hidden_drift], dim=1)
        return drift_

    def diffusion(self, X, y):
        """Return an identically-zero diffusion of shape ``(batch_size, 1 + hidden_channels, noise_channels)``.

        Zero diffusion does not contradict the paper; see the note marked (*) in `drift`.
        """
        def diffusion_(t, h):
            zeros = torch.zeros_like(h).unsqueeze(-1)
            return zeros.repeat(1, 1, self._noise_channels)
        return diffusion_


# To be able to evaluate the function composition discriminator(generator(noise)) in a memory efficient manner, we
# combine the generator and discriminator into a joint system of differential equations.
class _GeneratorDiscriminatorFunc(_SDESpec):
    """Joint vector fields that evaluate discriminator(generator(noise)) as a single system of SDEs.

    The generator and discriminator are combined into one system so that the composition can be evaluated in a
    memory-efficient manner. The joint state is the generator's state followed by the discriminator's state. The
    generator's current sample (its first ``input_channels`` channels) drives the discriminator as its control.
    """

    def __init__(self, input_channels, generator_hidden_channels, generator_func, discriminator_func, **kwargs):
        super().__init__(**kwargs)

        self._input_channels = input_channels
        self._generator_hidden_channels = generator_hidden_channels

        self._generator_func = generator_func
        self._discriminator_func = discriminator_func

    def split(self, h):
        """Split the joint state into ``(generated sample, generator state, discriminator state)``.

        The generated sample is the leading ``input_channels`` channels of the generator state. It is returned
        separately because it is the discriminator's control.
        """
        boundary = self._input_channels + self._generator_hidden_channels
        sample = h[:, :self._input_channels]
        return sample, h[:, :boundary], h[:, boundary:]

    def drift(self, y):
        """Return the joint drift ``(t, h)``, conditioned on the label ``y``."""
        generator_drift = self._generator_func.drift(y)
        # No control function: the discriminator's control is supplied explicitly as the generated sample.
        discriminator_drift = self._discriminator_func.drift(None, y)

        def drift_(t, h):
            sample, generator_h, discriminator_h = self.split(h)
            generator_out = generator_drift(t, generator_h)
            discriminator_out = discriminator_drift(t, discriminator_h, sample)
            return torch.cat([generator_out, discriminator_out], dim=1)
        return drift_

    def diffusion(self, y):
        """Return the joint diffusion ``(t, h)``, stacking both diffusions along the state dimension."""
        generator_diffusion = self._generator_func.diffusion(y)
        discriminator_diffusion = self._discriminator_func.diffusion(None, y)

        def diffusion_(t, h):
            _, generator_h, discriminator_h = self.split(h)
            generator_out = generator_diffusion(t, generator_h)
            discriminator_out = discriminator_diffusion(t, discriminator_h)
            return torch.cat([generator_out, discriminator_out], dim=1)
        return diffusion_


###############################################
# Wrap up the drift/diffusions into something #
#       that evaluates torchsde.sdeint        #
###############################################


class _SDE(torch.nn.Module, metaclass=abc.ABCMeta):
    """Abstract base for modules that build an initial state and then solve an SDE from it."""

    @abc.abstractmethod
    def initial(self, *args, **kwargs):
        """Create the initial state of the system."""
        raise NotImplementedError

    @abc.abstractmethod
    def forward(self, *args, **kwargs):
        """Solve the SDE."""
        raise NotImplementedError


class _Generator(_SDE):
    """Neural SDE generator: maps labels and a random seed to generated paths.

    The solved state has ``input_channels + hidden_channels`` channels: the generated sample followed by the
    generator's hidden state.
    """

    def __init__(self, input_channels, hidden_channels, hidden_hidden_channels, num_layers, noise_channels,
                 initial_noise_channels, label_channels, adjoint, method, adaptive, **kwargs):
        super().__init__(**kwargs)

        self._hidden_channels = hidden_channels
        self._noise_channels = noise_channels
        self._initial_noise_channels = initial_noise_channels
        self._adjoint = adjoint
        self._method = method
        self._adaptive = adaptive

        # Maps initial random noise to the initial value of the generated sample.
        self._initial = common.MLP(initial_noise_channels, input_channels, hidden_hidden_channels, num_layers)
        self._func = _GeneratorFunc(input_channels, hidden_channels, hidden_hidden_channels, num_layers, noise_channels,
                                    label_channels)

    def initial(self, y, seed):
        """Build the initial state, shape ``(batch_size, input_channels + hidden_channels)``.

        The sample part is an MLP applied to Gaussian noise drawn from a ``torch.Generator`` seeded with ``seed``. The
        hidden part starts at zero.
        """
        batch_size = y.size(0)
        options = dict(dtype=y.dtype, device=y.device)
        rng = torch.Generator(y.device).manual_seed(seed)
        noise = torch.randn(batch_size, self._initial_noise_channels, generator=rng, **options)
        initial_out = self._initial(noise)
        initial_hidden = torch.zeros(batch_size, self._hidden_channels, **options)
        return torch.cat([initial_out, initial_hidden], dim=1)

    def forward(self, t, y, seed):
        """Solve the generator SDE.

        Args:
            t: times of shape (seq_len,) at which the solution is returned.
            y: labels of shape (batch_size, label_channels). They must be 2-D because the vector fields concatenate
                them along dim 1.
            seed: seeds both the initial noise and the Brownian motion.

        Returns:
            Tensor of shape (batch_size, seq_len, input_channels + hidden_channels).
        """
        batch_size = y.size(0)

        y0 = self.initial(y, seed)
        # BrownianInterval is only deterministic with respect to its seed _and_ the choice and order of its query
        # points: querying in a different order yields a different Brownian sample. A non-adaptive method queries in
        # a single fixed order, so that is fine. If you use an adaptive method and need determinism with respect to
        # the Brownian motion, use BrownianTree instead. (It is much slower, which is why it isn't used here.)
        levy_area_approximation = 'foster' if self._method == 'log_ode' else 'none'
        bm = torchsde.BrownianInterval(t0=t[0], t1=t[-1], shape=(batch_size, self._noise_channels), dtype=y.dtype,
                                       device=y.device, entropy=seed, levy_area_approximation=levy_area_approximation)
        func = self._func.contextualise(y)
        solve = torchsde.sdeint_adjoint if self._adjoint else torchsde.sdeint
        step_options = {'adaptive': True} if self._adaptive else {'dt': 1.0}
        solution = solve(func, y0, t, bm=bm, names=_names, method=self._method, **step_options)

        # The solver's output is time-major; switch to batch-major.
        return solution.transpose(0, 1)


class _Discriminator(_SDE):
    """Neural CDE discriminator that scores observed paths (e.g. real data)."""

    def __init__(self, input_channels, hidden_channels, hidden_hidden_channels, num_layers, noise_channels,
                 label_channels, adjoint, adaptive, **kwargs):
        super().__init__(**kwargs)

        self._noise_channels = noise_channels
        self._adjoint = adjoint
        self._adaptive = adaptive

        # Embeds the first observation into the initial hidden state.
        self._initial = common.MLP(input_channels, hidden_channels, hidden_hidden_channels, num_layers)
        self._func = _DiscriminatorFunc(input_channels, hidden_channels, hidden_hidden_channels, num_layers,
                                        noise_channels, label_channels)

    def initial(self, X0):
        """Return the initial state: a zero score channel followed by an MLP embedding of ``X0``."""
        batch_size = X0.size(0)
        zero_score = torch.zeros(batch_size, 1, dtype=X0.dtype, device=X0.device)
        return torch.cat([zero_score, self._initial(X0)], dim=1)

    def forward(self, t, x, y):
        """Score the paths ``x`` observed at times ``t`` with labels ``y``.

        Returns the terminal score, of shape (batch_size,).
        """
        # The driving control X is the linear interpolation of the observed data.
        X = torchcde.LinearInterpolation(x, t).evaluate

        y0 = self.initial(X(t[0]))
        # With zero diffusion this isn't really solving an SDE, and torchcde could equally well be used. torchsde is
        # used instead to be certain this does exactly the same thing as the combined generator-discriminator.
        func = self._func.contextualise(X, y)
        solve = torchsde.sdeint_adjoint if self._adjoint else torchsde.sdeint
        step_options = {'adaptive': True} if self._adaptive else {'dt': 1.0}
        # midpoint = log-ODE for zero diffusion. Only the endpoints are requested, since only the terminal score is
        # used.
        solution = solve(func, y0, t[[0, -1]], names=_names, method='midpoint', **step_options)

        # Final time only, then the score channel only (the hidden state is discarded).
        return solution[1][:, 0]


class _GeneratorDiscriminator(_SDE):
    """Evaluates discriminator(generator(noise)) by solving the generator and discriminator as one joint SDE.

    The generator and discriminator passed in, and their vector fields, are shared rather than copied.
    """

    def __init__(self, input_channels, generator_hidden_channels, noise_channels, generator, discriminator, adjoint,
                 method, adaptive, **kwargs):
        super().__init__(**kwargs)

        self._input_channels = input_channels
        self._noise_channels = noise_channels
        self._adjoint = adjoint
        self._method = method
        self._adaptive = adaptive

        self._generator = generator
        self._discriminator = discriminator
        self._func = _GeneratorDiscriminatorFunc(input_channels, generator_hidden_channels,
                                                 generator._func, discriminator._func)

    def __repr__(self):
        # Deliberately terse. Presumably this avoids printing the shared generator/discriminator submodules a second
        # time when an enclosing module is printed.
        return type(self).__name__ + '()'

    def initial(self, y, seed):
        """Concatenate the generator's initial state with the discriminator's initial state.

        The discriminator is initialised from the generator's initial sample (its first ``input_channels`` channels).
        """
        generator_state = self._generator.initial(y, seed)
        discriminator_state = self._discriminator.initial(generator_state[:, :self._input_channels])
        return torch.cat([generator_state, discriminator_state], dim=1)

    def forward(self, t, y, seed):
        """Return the discriminator's terminal score for freshly generated samples, of shape (batch_size,)."""
        batch_size = y.size(0)

        y0 = self.initial(y, seed)
        levy_area_approximation = 'foster' if self._method == 'log_ode' else 'none'
        bm = torchsde.BrownianInterval(t0=t[0], t1=t[-1], shape=(batch_size, self._noise_channels), dtype=y.dtype,
                                       device=y.device, entropy=seed, levy_area_approximation=levy_area_approximation)
        func = self._func.contextualise(y)
        solve = torchsde.sdeint_adjoint if self._adjoint else torchsde.sdeint
        step_options = {'adaptive': True} if self._adaptive else {'dt': 1.0}
        # Only the endpoints are requested, since only the terminal score is needed.
        solution = solve(func, y0, t[[0, -1]], bm=bm, names=_names, method=self._method, **step_options)

        # Final time -> discriminator part of the joint state -> score channel (hidden state discarded).
        _, _, discriminator_state = self._func.split(solution[1])
        return discriminator_state[:, 0]


###############################################
#   Wrap up the generator and discriminator   #
#          into an overall interface          #
###############################################

class NeuralSDE(common.GAN):
    """SDE-GAN: a Neural SDE generator trained against a Neural CDE discriminator.

    Args:
        input_channels: number of channels of the data paths.
        generator_hidden_channels, generator_hidden_hidden_channels, generator_num_layers: generator sizes.
        discriminator_hidden_channels, discriminator_hidden_hidden_channels, discriminator_num_layers:
            discriminator sizes.
        noise_channels: dimension of the Brownian motion driving the generator.
        initial_noise_channels: dimension of the noise used to produce the generator's initial value.
        label_channels: number of label channels (for conditional generation).
        lipschitz: dict of Lipschitz-regularisation options. If it has a 'spectral' key, ``common.spectral_norm`` is
            applied to the discriminator with that value. If it has a 'gp' key, that value multiplies the gradient
            penalty in `train_discriminator`.
        adjoint: if true, solve with torchsde.sdeint_adjoint instead of torchsde.sdeint.
        method: SDE solver for the generator; must be 'midpoint' or 'log_ode'.
        adaptive: if true use adaptive step sizes, otherwise a fixed step of dt=1.0.

    Raises:
        ValueError: if ``method`` is not 'midpoint' or 'log_ode'.
    """

    def __init__(self, input_channels,
                 generator_hidden_channels, generator_hidden_hidden_channels, generator_num_layers,
                 discriminator_hidden_channels, discriminator_hidden_hidden_channels, discriminator_num_layers,
                 noise_channels, initial_noise_channels, label_channels, lipschitz, adjoint, method, adaptive):
        super(NeuralSDE, self).__init__()
        # Validate explicitly rather than with `assert`, which is stripped when Python runs with -O.
        if method not in ('midpoint', 'log_ode'):
            raise ValueError("method must be 'midpoint' or 'log_ode', got {!r}".format(method))

        self._input_channels = input_channels

        self._generator = _Generator(input_channels, generator_hidden_channels, generator_hidden_hidden_channels,
                                     generator_num_layers, noise_channels, initial_noise_channels, label_channels,
                                     adjoint, method, adaptive)
        self._discriminator = _Discriminator(input_channels, discriminator_hidden_channels,
                                             discriminator_hidden_hidden_channels, discriminator_num_layers,
                                             noise_channels, label_channels, adjoint, adaptive)
        # Shares the generator and discriminator above, so it sees the same parameters.
        self._combined = _GeneratorDiscriminator(input_channels, generator_hidden_channels, noise_channels,
                                                 self._generator, self._discriminator, adjoint, method, adaptive)

        if 'spectral' in lipschitz:
            common.spectral_norm(self._discriminator, lipschitz['spectral'])

        self._lipschitz = lipschitz
        self._adjoint = adjoint
        self._method = method
        self._adaptive = adaptive
        self._noise_channels = noise_channels

    def extra_repr(self):
        return ("lipschitz={}, adjoint={}, method={}, adaptive={}, noise_channels={}"
                .format(repr(self._lipschitz), repr(self._adjoint), repr(self._method), repr(self._adaptive),
                        self._noise_channels))

    def train_generator(self, t, y, seed=None):
        """Return the mean discriminator score of generated samples (a scalar tensor).

        A fresh seed is drawn via ``generate_seed()`` when ``seed`` is None.
        """
        if seed is None:
            seed = self.generate_seed()

        # score_fake has shape (batch_size,) and is the discriminator's score for each generated sample.
        # NOTE(review): returned as-is, not negated; presumably the caller handles the sign of the generator's
        # objective — confirm against the training loop.
        score_fake = self._combined(t, y, seed)
        return score_fake.mean()

    def train_discriminator(self, t, x, y, penalty, seed=None):
        """Return ``(loss, penalty_)`` for a discriminator step on real paths ``x``.

        ``loss`` is the mean fake score minus the mean real score. ``penalty_`` is 0 unless ``penalty`` is truthy and
        ``lipschitz`` has a 'gp' entry. In that case it is that multiplier times ``common.gp_penalty``, evaluated on a
        newly generated sample and ``x``.
        """
        if seed is None:
            seed = self.generate_seed()

        score_real = self._discriminator(t, x, y)
        score_fake = self._combined(t, y, seed)
        loss = score_fake.mean() - score_real.mean()

        penalty_ = 0
        if penalty and 'gp' in self._lipschitz:
            mult = self._lipschitz['gp']
            # NOTE(review): generate_sample draws its own seed via generate_seed() rather than reusing `seed`.
            fake = self.generate_sample(t, y)
            gp_penalty = common.gp_penalty((fake,), (x,), lambda x: self._discriminator(t, x, y))
            penalty_ = mult * gp_penalty

        return loss, penalty_

    def generate_sample(self, t, y, seed=None):
        """Generate paths of shape (batch_size, seq_len, input_channels); the generator's hidden state is dropped."""
        if seed is None:
            seed = self.generate_seed()

        return self._generator(t, y, seed)[..., :self._input_channels]
