from .modules import *
from .models import *
from .finde import *


# Solver names handed straight to ``odeint`` as its ``method`` argument.
# NOTE(review): 'test' is not a stock torchdiffeq method name; presumably it is
# provided by the locally imported ``odeint`` — confirm.
SOLVER_LIST = ['dopri5', 'rk4', 'midpoint', 'euler', 'test']

# Explicit solvers stepped by hand in ``PhysicsModel.discrete_time_derivative``.
SOLVER_LIST_ADDITIONAL_EXPLICIT = ['leapfrog', 'leapfrog2']

# Implicit solvers: they need the next state x2, which is found by a nonlinear
# solve in ``PhysicsModel.discrete_time_derivative`` when it is not supplied.
SOLVER_LIST_ADDITIONAL_IMPLICIT = ['dg', 'implicit_midpoint']


class PhysicsModel(torch.nn.Module):
    """Learned dynamics network bundled with the solver used to step it in time.

    The network is built by ``get_nn``. When ``finde`` is given, a second module
    built by ``get_finde`` post-processes the predicted derivatives, either
    continuously (``project_to_TuM``) or discretely (``project_to_TvuM``).

    ``solver`` is either one solver name, used for both training and evaluation,
    or ``"train_solver,eval_solver"``. The second name is used whenever
    ``eval_mode=True``, and always by :meth:`get_orbit`.
    """

    def __init__(self, input_dim, hidden_dim, act='tanh', model='hnn', solver='dopri5', data_mean=None, data_std=None, finde=None):
        super(PhysicsModel, self).__init__()
        self.odeint = odeint
        self.model = model
        if ',' in solver:
            # "train,eval" form; whitespace around the names is tolerated.
            self.solver, self.solver_eval = (name.strip() for name in solver.split(','))
        else:
            self.solver = self.solver_eval = solver.strip()
        self.check_combination()
        self.net = get_nn(input_dim, hidden_dim, act, model=model, data_mean=data_mean, data_std=data_std)
        self.finde = get_finde(finde, input_dim, hidden_dim, act, model=self.net, data_mean=data_mean, data_std=data_std) if finde else None
        self.reset_parameters(model, hidden_dim)

    def check_combination(self):
        """Check that both the training and the evaluation solver are known.

        Raises:
            AssertionError: if ``self.solver`` or ``self.solver_eval`` is not in
                any of the module-level solver lists.
        """
        # NOTE: ``self.model`` is not validated here; presumably ``get_nn``
        # rejects unknown model names — confirm.
        valid_solvers = SOLVER_LIST + SOLVER_LIST_ADDITIONAL_EXPLICIT + SOLVER_LIST_ADDITIONAL_IMPLICIT
        assert self.solver in valid_solvers, 'no valid solver {}'.format(self.solver)
        assert self.solver_eval in valid_solvers, 'no valid solver {}'.format(self.solver_eval)

    def reset_parameters(self, model, hidden_dim):
        """Re-initialise every ``nn.Linear`` layer in this module tree.

        All linear layers (including any inside ``self.finde``) get orthogonal
        weights and zero biases. For ``model == 'lnn'``, the linear layers of
        ``self.net.lagrangian`` are then re-drawn from zero-mean normals whose
        std depends on the layer index and on ``hidden_dim``.
        """
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.orthogonal_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
        if model == 'lnn':
            # Layer-dependent scales; these appear to follow the initialisation
            # suggested for Lagrangian Neural Networks — confirm against the paper.
            linears = [m for m in self.net.lagrangian if isinstance(m, nn.Linear)]
            sqrt_n = np.sqrt(hidden_dim)
            for i, m in enumerate(linears):
                if i == 0:
                    nn.init.normal_(m.weight, std=2.2 / sqrt_n)
                elif i == len(linears) - 1:
                    nn.init.normal_(m.weight, std=sqrt_n)
                else:
                    nn.init.normal_(m.weight, std=0.58 * i / sqrt_n)

    def forward(self, x1, *, dt=None, x2=None, func=None):
        """Dispatch to one of the model's evaluation methods.

        Args:
            x1: state passed to the selected method.
            dt: time step; only used by ``func='discrete_time_derivative'``.
            x2: optional second state; used by ``func=None``, ``'discrete_grad'``
                and ``'discrete_time_derivative'``.
            func: ``None`` (call ``self.net(x1, x2)``), ``'grad'``,
                ``'discrete_grad'``, ``'time_derivative'`` or
                ``'discrete_time_derivative'``.

        Raises:
            ValueError: if ``func`` is not one of the names above.
        """
        if func is None:
            return self.net(x1, x2)
        if func == 'grad':
            return self.grad(x1)
        if func == 'discrete_grad':
            return self.discrete_grad(x1, x2)
        if func == 'time_derivative':
            return self.time_derivative(x1)
        if func == 'discrete_time_derivative':
            return self.discrete_time_derivative(x1, dt=dt, x2=x2)
        # Fail loudly rather than returning None for a mistyped ``func``.
        raise ValueError('unknown func {!r}'.format(func))

    def hamiltonian(self, x1, x2=None):
        """Evaluate the wrapped network at ``x1`` (and ``x2`` if given).

        Presumably this is the learned energy for Hamiltonian-type models —
        confirm against ``get_nn``.
        """
        return self.net(x1, x2)

    def grad(self, x1):
        """Delegate to ``self.net.grad(x1)``."""
        return self.net.grad(x1)

    def discrete_grad(self, x1, x2):
        """Delegate to ``self.net.discrete_grad(x1, x2)``."""
        return self.net.discrete_grad(x1, x2)

    def time_derivative(self, x1):
        """Return the network's time derivative at ``x1``.

        If a continuous FINDE module is attached, the result is passed through
        ``self.finde.project_to_TuM`` first.
        """
        dxdt = self.net.time_derivative(x1)
        if self.finde is not None and self.finde.is_continuous:
            return self.finde.project_to_TuM(x1, dxdt)
        return dxdt

    def time_derivative_q(self, q, p):
        """Return the first half (along the last dim) of the time derivative at ``[q, p]``."""
        return self.time_derivative(torch.cat([q, p], dim=-1)).chunk(2, dim=-1)[0]

    def time_derivative_p(self, q, p):
        """Return the second half (along the last dim) of the time derivative at ``[q, p]``."""
        return self.time_derivative(torch.cat([q, p], dim=-1)).chunk(2, dim=-1)[1]

    def _leapfrog_step(self, x1, dt):
        """One 'leapfrog' step: half-step p, full-step q, half-step p."""
        dt2 = dt / 2
        q, p = x1.chunk(2, dim=-1)
        p = p + dt2 * self.time_derivative_p(q, p)
        q = q + dt * self.time_derivative_q(q, p)
        p = p + dt2 * self.time_derivative_p(q, p)
        return torch.cat([q, p], dim=-1)

    def _leapfrog2_step(self, x1, dt):
        """One 'leapfrog2' step: half-step q, full-step p, half-step q."""
        dt2 = dt / 2
        q, p = x1.chunk(2, dim=-1)
        q = q + dt2 * self.time_derivative_q(q, p)
        p = p + dt * self.time_derivative_p(q, p)
        q = q + dt2 * self.time_derivative_q(q, p)
        return torch.cat([q, p], dim=-1)

    def discrete_time_derivative(self, x1, dt, *, x2=None, xtol=1e-7, eval_mode=False):
        """Return the discrete derivative ``psi`` for stepping ``x1`` by ``dt``.

        Callers advance the state as ``x1 + dt * psi`` (see :meth:`get_orbit`).
        For the explicit solvers, ``psi`` is ``(x2_ - x1) / dt`` for a single
        step ``x1 -> x2_``. For the implicit solvers ('dg', 'implicit_midpoint'),
        ``psi`` is evaluated at the pair ``(x1, x2)``. When ``x2`` is omitted,
        it is first solved for so that ``psi(x1, x2) == (x2 - x1) / dt``.

        Args:
            x1: current state. The leapfrog solvers split its last dim into (q, p).
            dt: time step.
            x2: next state for the implicit solvers. If None, it is found with
                ``fsolve_gpu``, and autograd must be disabled. It is also
                forwarded to ``finde.project_to_TvuM``.
            xtol: tolerance passed to ``fsolve_gpu``.
            eval_mode: use ``self.solver_eval`` instead of ``self.solver``.

        Returns:
            ``psi``, passed through ``self.finde.project_to_TvuM`` when a
            discrete FINDE module is attached.

        Raises:
            NotImplementedError: if the selected solver has no implementation.
        """
        solver = self.solver_eval if eval_mode else self.solver
        if solver in SOLVER_LIST_ADDITIONAL_IMPLICIT and x2 is None:
            # The implicit scheme needs x2. Find it as the root of
            # psi(x1, xp) - (xp - x1) / dt, starting from an explicit midpoint
            # step. The root is the next STATE, so it is stored in x2; the
            # branches below then evaluate psi at (x1, x2).
            assert not torch.is_grad_enabled(), 'implicit solve requires torch.no_grad()'
            x2_guess = self.odeint(OdeintWrapper(self), x1, torch.tensor([0, dt]).to(x1), method='midpoint')[-1]
            x2 = fsolve_gpu(
                lambda xp: self.discrete_time_derivative(x1, x2=xp, dt=dt, eval_mode=eval_mode) - (xp - x1) / dt,
                x2_guess,
                xtol=xtol,
            )

        psi = None
        if solver == 'implicit_midpoint':
            psi = self.time_derivative((x1 + x2) / 2)
        elif solver == 'dg':
            psi = self.net.discrete_time_derivative(x1, x2=x2)
        elif solver in SOLVER_LIST:
            x2_ = self.odeint(OdeintWrapper(self), x1, torch.tensor([0, dt]).to(x1), method=solver)[-1]
            psi = (x2_ - x1) / dt
        elif solver == 'leapfrog':
            psi = (self._leapfrog_step(x1, dt) - x1) / dt
        elif solver == 'leapfrog2':
            psi = (self._leapfrog2_step(x1, dt) - x1) / dt
        if psi is None:
            raise NotImplementedError(solver)
        if self.finde is not None and self.finde.is_discrete:
            return self.finde.project_to_TvuM(x1=x1, psi_hat=psi, dt=dt, x2=x2)   # type: ignore
        return psi

    def get_orbit(self, x0, t_eval, xtol=1e-7):
        """Roll the model out from ``x0`` over the time grid ``t_eval``.

        Both paths use the evaluation solver (``self.solver_eval``).

        - If that solver is an ``odeint`` method and no discrete FINDE module is
          attached, the whole grid is integrated with a single ``odeint`` call.
        - Otherwise the state is advanced step by step as ``x1 + dt * psi``,
          using :meth:`discrete_time_derivative` with ``eval_mode=True``.

        Returns:
            States along dim 0, starting with ``x0``. The stepwise path builds
            this with ``torch.stack(..., dim=0)``; the ``odeint`` path presumably
            follows the same time-first layout (torchdiffeq convention) — confirm.
        """
        solver = self.solver_eval
        if solver in SOLVER_LIST and (self.finde is None or not self.finde.is_discrete):
            return self.odeint(OdeintWrapper(self), x0, t_eval.to(x0), method=solver)
        orbit_list = [x0]
        x1 = x0
        dts = t_eval[1:] - t_eval[:-1]
        for itr, dt in enumerate(dts):
            print(itr, '/', len(dts), end='\r')
            dudt = self.discrete_time_derivative(x1=x1, dt=dt, xtol=xtol, eval_mode=True)
            x1 = dt * dudt + x1
            orbit_list.append(x1)
        return torch.stack(orbit_list, dim=0)
