"""
Continuous normalizing flow (CNF) modules built on torchdiffeq.

Code adapted from: https://github.com/RameenAbdal/StyleFlow
"""
import torch
import torch.nn as nn
from torchdiffeq import odeint_adjoint
from torchdiffeq import odeint as odeint_normal
from torch.autograd import Variable
#from TorchDiffEqPack import odesolve_adjoint
#from TorchDiffEqPack import odesolve_adjoint_sym12
#from diffrax import diffeqsolve, Dopri5, ODETerm, SaveAt, PIDController

# Public API of this module (`from ... import *` exports only these names).
__all__ = ["CNF", "SequentialFlow"]


class SequentialFlow(nn.Module):
    """Apply a list of flow layers one after another (``nn.Sequential`` for flows).

    Every layer is invoked with the keyword arguments ``x``, ``context``,
    ``logpx``, ``integration_times``, ``reverse``, ``deBug``,
    ``parent_forcing`` and ``addNoise``.
    """

    def __init__(self, layer_list):
        super().__init__()
        self.chain = nn.ModuleList(layer_list)

    def forward(self, x, context, logpx=None, reverse=False, inds=None, integration_times=None, deBug=False, parent_forcing=False, addNoise=True):
        """Run ``x`` through the selected layers of the chain.

        Args:
            x: Input passed to the first selected layer; each layer's output
                feeds the next.
            context: Forwarded unchanged to every layer.
            logpx: If None, each layer must return only the new ``x``.
                Otherwise each layer must return ``(x, logpx, regTerm1, regTerm2)``
                and the updated ``logpx`` is threaded through the chain.
            reverse: Forwarded to every layer; when ``inds`` is None it also
                makes the chain run from the last layer to the first.
            inds: Optional explicit iterable of layer indices to apply, in order.
            integration_times, deBug, parent_forcing, addNoise: Forwarded to
                every layer.

        Returns:
            ``x`` when ``logpx`` is None, else ``(x, logpx, regTerm1, regTerm2)``.
        """
        if inds is None:
            n_layers = len(self.chain)
            inds = reversed(range(n_layers)) if reverse else range(n_layers)

        # Arguments that are identical for every layer call.
        shared_kwargs = dict(
            context=context,
            integration_times=integration_times,
            reverse=reverse,
            deBug=deBug,
            parent_forcing=parent_forcing,
            addNoise=addNoise,
        )

        track_density = logpx is not None
        for idx in inds:
            layer = self.chain[idx]
            if not track_density:
                x = layer(x=x, logpx=logpx, **shared_kwargs)
            else:
                # NOTE(review): regTerm1/regTerm2 are overwritten per layer, so
                # only the last applied layer's values are returned — confirm
                # this is intended for chains longer than one layer.
                x, logpx, regTerm1, regTerm2 = layer(x=x, logpx=logpx, **shared_kwargs)

        if not track_density:
            return x
        return x, logpx, regTerm1, regTerm2


class CNF(nn.Module):
    """Continuous normalizing flow layer that integrates ``odefunc`` with torchdiffeq.

    Args:
        odefunc: ``nn.Module`` used as the ODE right-hand side. It must also
            provide ``before_odeint(...)`` (called once per forward pass) and
            ``_num_evals`` (used by :meth:`num_evals`).
        obs_dim: Stored as ``self.obs_dim``; not used inside this class.
        conditional: Stored as ``self.conditional``; not used inside this class.
        T: End time of the default integration interval ``[0, T]``.
        train_T: If True, the end time is learned through the parameter
            ``sqrt_end_time`` (end time = ``sqrt_end_time ** 2``).
        regularization_fns: Must be None or empty; otherwise
            ``NotImplementedError`` is raised.
        solver: torchdiffeq ``method`` name.
        atol, rtol: Base tolerances; :meth:`forward` scales them up.
        use_adjoint: Use ``odeint_adjoint`` instead of plain ``odeint``.
    """

    def __init__(self, odefunc, obs_dim, conditional=True, T=1.0, train_T=False, regularization_fns=None,
                 solver='dopri8', atol=1e-4, rtol=1e-4, use_adjoint=True):
        super(CNF, self).__init__()
        self.train_T = train_T
        self.T = T
        if train_T:
            # Parameterize as sqrt so the learned end time (its square) stays >= 0.
            self.register_parameter("sqrt_end_time", nn.Parameter(torch.sqrt(torch.tensor(T))))
            print("Training T :", self.T)

        if regularization_fns is not None and len(regularization_fns) > 0:
            raise NotImplementedError("Regularization not supported")
        self.use_adjoint = use_adjoint
        self.odefunc = odefunc
        self.solver = solver
        self.atol = atol
        self.rtol = rtol
        self.test_solver = solver
        self.test_atol = atol
        self.test_rtol = rtol
        self.solver_options = {}
        self.conditional = conditional
        self.obs_dim = obs_dim

    def make_norm(self, state):
        """Build a ``norm`` callable suitable for torchdiffeq's ``adjoint_options``.

        The returned function slices the flattened augmented adjoint state:
        it skips index 0, takes the next ``state_size`` entries as ``y`` and the
        following ``state_size`` entries as ``adj_y``. It returns the larger of
        their RMS norms and ignores everything after them. This follows the
        seminorm recipe in the torchdiffeq adjoint documentation.

        Args:
            state: Forward-pass state, either a tensor or a tuple/list of tensors.
                Only its total number of elements is used.

        Returns:
            A function mapping a 1-D augmented-state tensor to a 0-d tensor.
        """
        if isinstance(state, (tuple, list)):
            state_size = sum(s.numel() for s in state)
        else:
            state_size = state.numel()

        def rms_norm(tensor):
            # Root-mean-square norm (previously referenced but never defined).
            return tensor.pow(2).mean().sqrt()

        def norm(aug_state):
            y = aug_state[1:1 + state_size]
            adj_y = aug_state[1 + state_size:1 + 2 * state_size]
            return max(rms_norm(y), rms_norm(adj_y))

        return norm

    def _default_integration_times(self, x):
        """Return ``[0, T]`` (or ``[0, sqrt_end_time**2]`` when ``train_T``) on ``x``'s device/dtype."""
        if self.train_T:
            return torch.stack(
                [torch.tensor(0.0).to(x), self.sqrt_end_time * self.sqrt_end_time]
            ).to(x)
        return torch.tensor([0., self.T], requires_grad=False).to(x)

    def forward(self, x, context=None, logpx=None, integration_times=None, reverse=False, deBug=False, parent_forcing=False, addNoise=True):
        """Integrate the ODE starting from ``x`` over ``integration_times``.

        Args:
            x: Initial state tensor.
            context: Forwarded to ``odefunc.before_odeint``.
            logpx: If None, only the state is integrated. Otherwise density
                tracking is switched on. Its *value* is only forwarded to
                ``odefunc.before_odeint``; the integrated log-density always
                starts from zeros.
            integration_times: 1-D tensor of time points. Defaults to ``[0, T]``
                (or ``[0, sqrt_end_time**2]`` when ``train_T``).
            reverse: Integrate over the time points in reverse order.
            deBug: Print the integration times; also forwarded to ``before_odeint``.
            parent_forcing, addNoise: Forwarded to ``odefunc.before_odeint``.

        Returns:
            If ``logpx`` is None, the state trajectory ``s[0]`` with its first two
            axes swapped. Otherwise ``(state_traj, logpx_traj, regTerm1, regTerm2)``.
            The two trajectories are transposed the same way, and the reg terms
            are their values at the last time point.
        """
        if integration_times is None:
            print('OBS!!! integration times None')
            integration_times = self._default_integration_times(x)

        # before_odeint receives the times in their original (un-flipped) order.
        self.odefunc.before_odeint(context, integration_times, parent_forcing=parent_forcing, deBug=deBug, logpx=logpx, addNoise=addNoise)
        if reverse:
            integration_times = _flip(integration_times.clone(), 0)

        odeint = odeint_adjoint if self.use_adjoint else odeint_normal

        if deBug:
            print(integration_times)

        # Force the adaptive solver to step exactly to (and re-evaluate at) the
        # requested time points (torchdiffeq `jump_t` option).
        self.solver_options["jump_t"] = integration_times
        # Kept as an attribute for backward compatibility; not passed to odeint.
        # (The "seminorm" adjoint norm was found to be very slow here.)
        self.adjoint_options = {}

        if logpx is None:
            # State-only solve; tolerances loosened by 2x.
            s = odeint(
                self.odefunc,
                (x,),
                integration_times,
                atol=self.atol * 2,
                rtol=self.rtol * 2,
                method=self.solver,
                options=self.solver_options,
            )
            return s[0].transpose(0, 1)

        # Density-tracking solve; the augmented state is
        # (x, logpx, regTerm1, regTerm2), all extra components starting at zero.
        # Tolerances are loosened by 5x.
        _logpx = torch.zeros(*x.shape[:-1], 1).to(x)
        regTerm1 = torch.zeros(*x.shape[:-1], 1).to(x)
        regTerm2 = torch.zeros(*x.shape[:-1], 1).to(x)
        s = odeint(
            self.odefunc,
            (x, _logpx, regTerm1, regTerm2),
            integration_times,
            atol=self.atol * 5,
            rtol=self.rtol * 5,
            method=self.solver,
            options=self.solver_options,
        )

        # Indexing [-1, :, :] assumes 3-D solution tensors (time first) — TODO confirm.
        regTerm1 = s[-2][-1, :, :]
        regTerm2 = s[-1][-1, :, :]
        _logpx = s[-3]
        return s[0].transpose(0, 1), _logpx.transpose(0, 1), regTerm1, regTerm2

    def num_evals(self):
        """Return the odefunc's evaluation count via ``odefunc._num_evals()``."""
        # NOTE(review): some upstream CNF implementations store ``_num_evals`` as
        # a tensor buffer (read with ``.item()``) rather than a method — confirm
        # the odefunc used here defines it as a callable.
        return self.odefunc._num_evals()


def _flip(x, dim):
    indices = [slice(None)] * x.dim()
    indices[dim] = torch.arange(x.size(dim) - 1, -1, -1, dtype=torch.long, device=x.device)
    return x[tuple(indices)]
