import torch
import torch.nn as nn
from torchdiffeq import odeint_adjoint as odeint

from enflows.CNF.neural_odes.wrappers.cnf_regularization import RegularizedODEfunc

# Public API for ``from <module> import *``: only ``CNF`` is listed, so the
# Compact* classes and the helper functions below must be imported by name.
__all__ = ["CNF"]


class CNF(nn.Module):
    """Continuous normalizing flow integrated with torchdiffeq's adjoint ``odeint``.

    The integrated state is ``(z, logpz)``. In training mode one scalar
    accumulator per regularization function is appended to it.

    Args:
        odefunc: Dynamics module that the solver calls as ``odefunc(t, states)``.
            It must provide ``before_odeint()`` and a ``_num_evals`` tensor.
        T: End time of the default integration interval ``[0, T]``.
        train_T: If True, ``sqrt(T)`` is registered as a learnable parameter.
            Otherwise it is a non-trainable buffer.
        regularization_fns: Optional sequence of regularization functions. When
            given, ``odefunc`` is wrapped in ``RegularizedODEfunc``.
        solver: ``odeint`` method used in training mode.
        atol: Absolute tolerance used in training mode.
        rtol: Relative tolerance used in training mode.

    The ``test_solver`` / ``test_atol`` / ``test_rtol`` attributes start out
    equal to the training settings and are used in eval mode.
    ``solver_options`` is forwarded as ``options`` in training mode only.
    """

    def __init__(self, odefunc, T=1.0, train_T=False, regularization_fns=None, solver='dopri5', atol=1e-5, rtol=1e-5):
        super(CNF, self).__init__()
        # The end time is stored as sqrt(T) and squared when used, so even a
        # learned end time can never become negative.
        if train_T:
            self.register_parameter("sqrt_end_time", nn.Parameter(torch.sqrt(torch.tensor(T))))
        else:
            self.register_buffer("sqrt_end_time", torch.sqrt(torch.tensor(T)))

        nreg = 0
        if regularization_fns is not None:
            odefunc = RegularizedODEfunc(odefunc, regularization_fns)
            nreg = len(regularization_fns)
        self.odefunc = odefunc
        self.nreg = nreg
        # Filled by forward(); handed out (and cleared) by get_regularization_states().
        self.regularization_states = None
        self.solver = solver
        self.atol = atol
        self.rtol = rtol
        self.test_solver = solver
        self.test_atol = atol
        self.test_rtol = rtol
        self.solver_options = {}

    def _default_integration_times(self, z):
        """Return the tensor ``[0, T]`` with ``z``'s dtype and device.

        The result stays connected to ``sqrt_end_time`` in the autograd graph.
        """
        end_time = self.sqrt_end_time * self.sqrt_end_time
        # torch.stack keeps the autograd graph. torch.tensor([...]) would copy
        # the value into a fresh leaf, so a learnable end time (train_T=True)
        # would never receive a gradient.
        return torch.stack([torch.zeros_like(end_time), end_time]).to(z)

    def forward(self, z, logpz=None, integration_times=None, reverse=False):
        """Integrate ``(z, logpz)`` through the flow.

        Args:
            z: Input samples. The first dimension is the batch dimension.
            logpz: Optional log-density. If omitted, zeros of shape
                ``(batch, 1)`` are integrated and only ``z_t`` is returned.
            integration_times: Optional 1-D tensor of times. Defaults to ``[0, T]``.
            reverse: If True, integrate over the flipped time grid.

        Returns:
            ``(z_t, logpz_t)`` if ``logpz`` was given, otherwise ``z_t``. With
            exactly two integration times the state at the second time is
            returned. With more times, the leading time axis is kept.
        """
        _logpz = torch.zeros(z.shape[0], 1).to(z) if logpz is None else logpz

        if integration_times is None:
            integration_times = self._default_integration_times(z)
        if reverse:
            integration_times = _flip(integration_times, 0)

        # Refresh the odefunc statistics.
        self.odefunc.before_odeint()

        if self.training:
            # Regularization accumulators start at zero and are integrated
            # alongside (z, logpz).
            reg_states = tuple(torch.tensor(0).to(z) for _ in range(self.nreg))
            initial_state = (z, _logpz) + reg_states
            solver_kwargs = dict(
                atol=self.atol,
                rtol=self.rtol,
                method=self.solver,
                options=self.solver_options,
            )
        else:
            # Eval mode integrates no regularization states and passes no
            # solver options.
            initial_state = (z, _logpz)
            solver_kwargs = dict(
                atol=self.test_atol,
                rtol=self.test_rtol,
                method=self.test_solver,
            )

        state_t = odeint(
            self.odefunc,
            initial_state,
            integration_times.to(z),
            adjoint_options={"norm": "seminorm"},
            **solver_kwargs
        )

        if len(integration_times) == 2:
            # Drop the time axis and keep only the state at the end time.
            state_t = tuple(s[1] for s in state_t)

        z_t, logpz_t = state_t[:2]
        # Empty tuple in eval mode or when no regularization is configured.
        self.regularization_states = state_t[2:]

        if logpz is not None:
            return z_t, logpz_t
        return z_t

    def get_regularization_states(self):
        """Return the regularization states from the last forward pass and clear them."""
        reg_states = self.regularization_states
        self.regularization_states = None
        return reg_states

    def num_evals(self):
        """Return ``odefunc._num_evals`` as a Python number.

        Presumably this is the function-evaluation count of the last solve;
        confirm against the odefunc implementation.
        """
        return self.odefunc._num_evals.item()


class CompactCNF(nn.Module):
    """Self-contained CNF: this module is both the ODE dynamics and its integrator.

    ``forward(t, states)`` evaluates the augmented dynamics ``(dy, -divergence)``
    and is what the solver calls. ``integrate`` runs the solve.

    Args:
        dynamics_network: Module called as ``dynamics_network(t, y)``.
        solver: ``odeint`` method used in training mode.
        atol: Absolute tolerance used in training mode.
        rtol: Relative tolerance used in training mode.
        divergence_fn: ``"approximate"`` (noise-based estimate) or
            ``"brute_force"``. It is used in training mode only; eval mode
            always computes the exact divergence.
        T: End time of the default interval ``[0, T]`` that ``integrate`` uses
            when no ``integration_times`` are given.

    The ``test_*`` attributes start equal to the training settings and are
    used in eval mode.
    """

    def __init__(self, dynamics_network, solver='dopri5', atol=1e-5, rtol=1e-5,
                 divergence_fn="approximate", T=1.0):
        super(CompactCNF, self).__init__()
        assert divergence_fn in ("brute_force", "approximate")

        self.diffeq = dynamics_network
        # This class integrates no regularization states.
        self.nreg = 0
        self.solver = solver
        self.atol = atol
        self.rtol = rtol
        self.test_solver = solver
        self.test_atol = atol
        self.test_rtol = rtol
        self.solver_options = {}
        # Noise for the approximate divergence: Rademacher if True, else Gaussian.
        self.rademacher = True
        # Plain attribute rather than a buffer, so existing state_dicts stay loadable.
        self.end_time = float(T)

        if divergence_fn == "brute_force":
            self.divergence_fn = divergence_bf
        elif divergence_fn == "approximate":
            self.divergence_fn = divergence_approx

        self.register_buffer("_num_evals", torch.tensor(0.))
        self.before_odeint()

    def before_odeint(self, e=None):
        """Reset the evaluation counter and set (or clear) the cached noise ``e``."""
        self._e = e
        self._num_evals.fill_(0)

    def num_evals(self):
        """Return the number of ``forward`` calls since the last ``before_odeint``."""
        return self._num_evals.item()

    def forward(self, t, states):
        """Evaluate the augmented dynamics at time ``t``.

        Args:
            t: Current time, either a tensor or a Python number.
            states: Tuple whose first entry is the current sample ``y``. The
                second entry (the log-density) is not read here.

        Returns:
            ``(dy, -divergence)``, where ``divergence`` has shape ``(batch, 1)``.
        """
        assert len(states) >= 2
        y = states[0]

        # increment num evals
        self._num_evals += 1

        # convert to tensor
        if not torch.is_tensor(t):
            t = torch.tensor(t).type_as(y)
        else:
            t = t.type_as(y)
        batchsize = y.shape[0]

        # Sample the noise once per solve so every evaluation reuses the same e.
        if self._e is None:
            if self.rademacher:
                self._e = sample_rademacher_like(y)
            else:
                self._e = sample_gaussian_like(y)

        # Gradients are needed for the divergence even if the caller disabled them.
        with torch.set_grad_enabled(True):
            y.requires_grad_(True)
            t.requires_grad_(True)

            dy = self.diffeq(t, y)
            if self.training:
                divergence = self.divergence_fn(dy, y, e=self._e).view(batchsize, 1)
            else:
                # Eval mode always uses the exact (brute-force) divergence.
                divergence = divergence_bf(dy, y).view(batchsize, 1)

        return tuple([dy, -divergence])

    def _default_integration_times(self, z):
        """Return ``[0, T]`` as a tensor with ``z``'s dtype and device."""
        # The getattr fallback keeps instances created without ``end_time``
        # (e.g. unpickled from an older version) working with the old default.
        end_time = getattr(self, "end_time", 1.0)
        return torch.tensor([0.0, end_time]).to(z)

    def integrate(self, z, logpz=None, integration_times=None, reverse=False):
        """Solve the flow ODE starting from ``(z, logpz)``.

        Args:
            z: Samples. The first dimension is the batch.
            logpz: Optional log-density. Defaults to zeros of shape ``(batch, 1)``.
            integration_times: Optional 1-D tensor of times. Defaults to ``[0, T]``.
            reverse: If True, integrate over the flipped time grid.

        Returns:
            ``(z_t, logpz_t)`` at the last integration time.
        """
        _logpz = torch.zeros(z.shape[0], 1).to(z) if logpz is None else logpz

        if integration_times is None:
            integration_times = self._default_integration_times(z)
        if reverse:
            integration_times = _flip(integration_times, 0)

        # Reset the evaluation counter and the cached noise before solving.
        self.before_odeint()

        if self.training:
            solver_kwargs = dict(
                atol=self.atol,
                rtol=self.rtol,
                method=self.solver,
                options=self.solver_options,
            )
        else:
            # Eval mode uses the test settings and passes no solver options.
            solver_kwargs = dict(
                atol=self.test_atol,
                rtol=self.test_rtol,
                method=self.test_solver,
            )

        state_t = odeint(
            self,
            (z, _logpz),
            integration_times.to(z),
            adjoint_options={"norm": "seminorm"},
            **solver_kwargs
        )

        # odeint stacks each state along a leading time axis; keep the final time.
        z_t, logpz_t = (s[-1] for s in state_t)
        return z_t, logpz_t


class CompactTimeVariableCNF(nn.Module):
    """CNF whose integration interval ``[t0, t1]`` is carried in the ODE state.

    The solver integrates a surrogate time ``s`` over the fixed interval
    ``[start_time, end_time]``. ``forward`` maps ``s`` to the real time
    ``t = t0 + (s - start_time) * (t1 - t0) / (end_time - start_time)`` and
    rescales the dynamics by the same ratio. Each sample is therefore
    effectively integrated from its own ``t0`` to its own ``t1``.
    """

    start_time = 0.0
    end_time = 1.0

    def __init__(self, dynamics_network, solver='dopri5', atol=1e-5, rtol=1e-5,
                 divergence_fn="approximate"):
        super(CompactTimeVariableCNF, self).__init__()
        assert divergence_fn in ("brute_force", "approximate")

        nreg = 0

        self.diffeq = dynamics_network
        self.nreg = nreg
        self.solver = solver
        self.atol = atol
        self.rtol = rtol
        self.test_solver = solver
        self.test_atol = atol
        self.test_rtol = rtol
        self.solver_options = {}
        # Noise for the approximate divergence: Rademacher if True, else Gaussian.
        self.rademacher = True

        if divergence_fn == "brute_force":
            self.divergence_fn = divergence_bf
        elif divergence_fn == "approximate":
            self.divergence_fn = divergence_approx

        self.register_buffer("_num_evals", torch.tensor(0.))
        self.before_odeint()

        # NOTE: these kwargs are captured once, here. Reassigning self.atol,
        # self.solver, self.test_* or rebinding self.solver_options later does
        # not affect them; edit self.odeint_kwargs directly instead.
        self.odeint_kwargs = dict(
            train=dict(
                atol=self.atol,
                rtol=self.rtol,
                method=self.solver,
                options=self.solver_options,
                adjoint_options={"norm": "seminorm"}
            ),
            test=dict(
                atol=self.test_atol,
                rtol=self.test_rtol,
                method=self.test_solver,
                adjoint_options={"norm": "seminorm"}
            )
        )

    def integrate(self, t0, t1, z, logpz=None):
        """Integrate ``(z, logpz)`` from ``t0`` to ``t1``.

        Args:
            t0: Start times. ``forward`` reshapes ``(t1 - t0)`` to
                ``(-1, 1, ...)``, so these presumably hold one value per
                sample — confirm with callers.
            t1: End times, with the same layout as ``t0``.
            z: Samples. The first dimension is the batch.
            logpz: Optional log-density. Defaults to zeros of shape ``(batch, 1)``.

        Returns:
            ``(z_t, logpz_t)`` at the end of the interval.
        """
        _logpz = torch.zeros(z.shape[0], 1).to(z) if logpz is None else logpz
        initial_state = (t0, t1, z, _logpz)

        integration_times = torch.tensor([self.start_time, self.end_time]).to(t0)

        # Reset the evaluation counter and fix one noise sample for the whole solve.
        self.before_odeint(e=self.sample_e_like(z))

        state_t = odeint(
            func=self,
            y0=initial_state,
            t=integration_times,
            **self.get_odeint_kwargs()
        )
        # Keep the final time point of each state; t0/t1 are not returned.
        _, _, z_t, logpz_t = (s[-1] for s in state_t)

        return z_t, logpz_t

    def forward(self, s, states):
        """Evaluate the rescaled augmented dynamics at surrogate time ``s``.

        Args:
            s: Surrogate time in ``[start_time, end_time]``.
            states: ``(t0, t1, y, logp)``.

        Returns:
            Zero derivatives for ``t0`` and ``t1`` (they stay constant), the
            rescaled ``dy``, and ``-divergence`` of shape ``(-1, 1)``.
        """
        # The state layout is fixed by integrate(): (t0, t1, y, logp).
        assert len(states) == 4
        t0, t1, y, _ = states
        ratio = (t1 - t0) / (self.end_time - self.start_time)

        # increment num evals
        self._num_evals += 1

        # Sample and fix the noise.
        if self._e is None:
            self._e = self.sample_e_like(y)

        # Gradients are needed for the divergence even if the caller disabled them.
        with torch.set_grad_enabled(True):
            y.requires_grad_(True)
            t = (s - self.start_time) * ratio + t0
            dy = self.diffeq(t, y)
            # Chain rule for t(s): dy/ds = dy/dt * dt/ds, with dt/ds = ratio.
            dy = dy * ratio.reshape(-1, *([1] * (y.ndim - 1)))

            divergence = self.calculate_divergence(y, dy)

        return tuple([torch.zeros_like(t0), torch.zeros_like(t1), dy, -divergence])

    def sample_e_like(self, y):
        """Sample divergence-estimator noise shaped like ``y``.

        The noise is Rademacher if ``self.rademacher`` is True, else Gaussian.
        """
        if self.rademacher:
            return sample_rademacher_like(y)
        else:
            return sample_gaussian_like(y)

    def calculate_divergence(self, y, dy):
        """Return the divergence of ``dy`` w.r.t. ``y``, reshaped to ``(-1, 1)``.

        Training mode uses ``self.divergence_fn`` with the fixed noise
        ``self._e``. Eval mode always uses the exact brute-force divergence.
        """
        if self.training:
            return self.divergence_fn(dy, y, e=self._e).view(-1, 1)
        return divergence_bf(dy, y).view(-1, 1)

    def get_odeint_kwargs(self):
        """Return the ``odeint`` keyword arguments for the current train/eval mode."""
        if self.training:
            return self.odeint_kwargs["train"]
        else:
            return self.odeint_kwargs["test"]

    def before_odeint(self, e=None):
        """Reset the evaluation counter and set (or clear) the cached noise ``e``."""
        self._e = e
        self._num_evals.fill_(0)

    def num_evals(self):
        """Return the number of ``forward`` calls since the last ``before_odeint``."""
        return self._num_evals.item()


def _flip(x, dim):
    indices = [slice(None)] * x.dim()
    indices[dim] = torch.arange(x.size(dim) - 1, -1, -1, dtype=torch.long, device=x.device)
    return x[tuple(indices)]


def sample_rademacher_like(y):
    """Draw i.i.d. Rademacher noise (values -1 or +1) shaped like ``y``.

    The result has ``y``'s shape, dtype and device and does not require grad.
    """
    # randint_like samples directly in y's dtype/device. This avoids building
    # an int64 tensor on the CPU and copying it to y's device on every call.
    return torch.randint_like(y, 0, 2) * 2 - 1


def sample_gaussian_like(y):
    """Draw i.i.d. standard-normal noise with the same shape, dtype and device as ``y``."""
    noise = torch.randn_like(y)
    return noise


def divergence_bf(dx, y, **unused_kwargs):
    """Compute the exact per-sample divergence (Jacobian trace) of ``dx`` w.r.t. ``y``.

    One backward pass is made per feature, so the cost grows linearly with the
    number of features. Inputs of any rank are supported: every dimension after
    the first (batch) dimension is flattened into a single feature axis. For
    2-D ``(batch, D)`` inputs this matches the plain column-by-column trace.

    Args:
        dx: Output of the dynamics evaluated at ``y``; expected to have the
            same shape as ``y``.
        y: Input tensor that requires grad, with the batch on dimension 0.
        **unused_kwargs: Ignored. Accepted so calls can pass the same keywords
            as to ``divergence_approx`` (e.g. ``e``).

    Returns:
        Tensor of shape ``(batch,)``. It is built with ``create_graph=True`` and
        is therefore itself differentiable.
    """
    batch_size = y.shape[0]
    dx_flat = dx.reshape(batch_size, -1)
    # Start from a zero tensor, so an empty feature axis still yields a tensor.
    sum_diag = y.new_zeros(batch_size)
    for i in range(dx_flat.shape[1]):
        grad_i = torch.autograd.grad(dx_flat[:, i].sum(), y, create_graph=True)[0]
        # d dx_i / d y_i for every sample: the i-th diagonal Jacobian entry.
        sum_diag = sum_diag + grad_i.reshape(batch_size, -1)[:, i]
    return sum_diag.contiguous()


def divergence_approx(f, y, e=None):
    """Estimate the per-sample divergence of ``f`` w.r.t. ``y`` with noise ``e``.

    Computes ``e^T (df/dy) e`` for every sample using a single
    vector-Jacobian product.

    Args:
        f: Output of the dynamics evaluated at ``y``.
        y: Input tensor (requires grad), with the batch on dimension 0.
        e: Noise tensor shaped like ``f``. Despite the ``None`` default it is
            required: it is both the VJP vector and a factor of the product.

    Returns:
        Tensor of shape ``(batch,)``, differentiable (``create_graph=True``).
    """
    (vjp,) = torch.autograd.grad(f, y, e, create_graph=True)
    weighted = vjp * e
    return weighted.view(y.shape[0], -1).sum(dim=1)
