from .modules import *
from .node import *
from .hnn import *
from .chnn import *


class Normalizer(nn.Module):
    """Affine (de)normalization layer with fixed, non-trainable statistics.

    In the default mode the layer maps ``x`` to ``(x - mean) / std``. With
    ``reverse=True`` it applies the inverse map ``x * std + mean``.

    Args:
        mean: Offset. May be a Python scalar, a (nested) sequence of numbers,
            or anything else ``torch.as_tensor`` accepts.
        std: Scale, same accepted kinds as ``mean``. It is not checked for
            zeros, so a zero entry yields ``inf``/``nan`` in forward mode.
        reverse: If true, ``forward`` de-normalizes instead of normalizing.

    Attributes are stored as ``nn.Parameter`` objects with
    ``requires_grad=False`` under the names ``shift`` and ``scale``, so they
    follow the module across devices and appear in ``state_dict()``.
    """

    def __init__(self, mean, std, reverse=False):
        super().__init__()
        self.shift = nn.parameter.Parameter(self._as_constant(mean), requires_grad=False)
        self.scale = nn.parameter.Parameter(self._as_constant(std), requires_grad=False)
        self.reverse = reverse

    @staticmethod
    def _as_constant(value):
        """Convert ``value`` to an independent tensor of the default float dtype."""
        # The legacy ``torch.Tensor(value)`` constructor treats an int as a
        # *size* (returning uninitialised memory) and rejects a float, so
        # scalar statistics were silently wrong or crashed. ``as_tensor``
        # interprets every input as data; ``detach().clone()`` guarantees the
        # stored statistic never aliases the caller's tensor/array.
        return torch.as_tensor(value, dtype=torch.get_default_dtype()).detach().clone()

    def _transform(self, x):
        """Apply the configured direction of the affine map to one tensor."""
        if self.reverse:
            return x * self.scale + self.shift
        return (x - self.shift) / self.scale

    def forward(self, x1, x2=None):
        """Transform ``x1`` (and ``x2`` if given) with the same statistics.

        Returns:
            The transformed ``x1`` when ``x2`` is ``None``; otherwise a tuple
            ``(transformed_x1, transformed_x2)``.
        """
        if x2 is None:
            return self._transform(x1)
        return self._transform(x1), self._transform(x2)


def get_MLP(input_dim, hidden_dim, output_dim, act='tanh', bias=True, data_mean=None, data_std=None):
    """Assemble a three-layer perceptron, optionally preceded by a Normalizer.

    The network is ``Linear -> act -> Linear -> act -> Linear`` with widths
    ``input_dim -> hidden_dim -> hidden_dim -> output_dim``.

    Args:
        input_dim: Width of the first layer's input.
        hidden_dim: Width of both hidden layers.
        output_dim: Width of the final layer's output.
        act: Name handed to ``get_activation_from_name``; the returned object
            is called with no arguments once per activation slot.
        bias: Forwarded as the third positional argument of the *last*
            ``Linear`` only; the first two layers use ``Linear``'s default.
        data_mean: If not ``None``, a ``Normalizer(data_mean, data_std)`` is
            placed in front of the first layer.
        data_std: Must be given whenever ``data_mean`` is given.

    Returns:
        A ``Sequential`` wrapping the layers in order.
    """
    activation = get_activation_from_name(act)

    # Layers are instantiated in network order so that any random
    # initialisation inside ``Linear`` is consumed in the same sequence.
    layers = [Linear(input_dim, hidden_dim), activation()]
    layers += [Linear(hidden_dim, hidden_dim), activation()]
    layers.append(Linear(hidden_dim, output_dim, bias))

    if data_mean is not None:
        assert data_std is not None
        layers.insert(0, Normalizer(data_mean, data_std))

    return Sequential(*layers)


def get_nn(input_dim, hidden_dim, act, model, data_mean=None, data_std=None):
    """Construct the model identified by the string ``model``.

    Recognised names are ``'node'``, ``'sonode'``, ``'sepsonode'``, ``'hnn'``,
    ``'sephnn'``, ``'kinhnn'``, ``'massspring'``, ``'chnn2pend'`` and
    ``'chnn2body'``. Most of them wrap one or two networks built by
    ``get_MLP`` with ``hidden_dim`` and ``act``.

    Args:
        input_dim: Full state width. Several variants use ``input_dim // 2``
            as the input and/or output width of their sub-networks.
        hidden_dim: Hidden width forwarded to ``get_MLP``.
        act: Activation name forwarded to ``get_MLP``.
        model: Name selecting which model to build.
        data_mean: Optional normalization statistics. When given, they are
            also split at index ``input_dim // 2``; the first half is used by
            the half-width networks of ``'sepsonode'``, ``'sephnn'`` (``netV``)
            and ``'kinhnn'``, and the second half by ``'sephnn'``'s ``netT``.
        data_std: Must be given whenever ``data_mean`` is given; split the
            same way.

    Returns:
        The constructed model instance.

    Raises:
        NotImplementedError: If ``model`` matches none of the known names.
    """
    if data_mean is None:
        lower_mean = lower_std = None
        upper_mean = upper_std = None
    else:
        assert data_std is not None
        split = input_dim // 2
        lower_mean, upper_mean = data_mean[:split], data_mean[split:]
        lower_std, upper_std = data_std[:split], data_std[split:]

    def mlp(in_dim, out_dim, use_bias, mean, std):
        # Every variant shares hidden_dim/act; only widths, bias and stats vary.
        return get_MLP(input_dim=in_dim, hidden_dim=hidden_dim, output_dim=out_dim, act=act, bias=use_bias, data_mean=mean, data_std=std)

    # Ordered (name, builder) table. Builders are lazy so that only the
    # selected model's classes and sub-networks are ever touched.
    builders = (
        ('node', lambda: NODE(
            net=mlp(input_dim, input_dim, True, data_mean, data_std),
        )),
        ('sonode', lambda: SONODE(
            net=mlp(input_dim, input_dim // 2, True, data_mean, data_std),
        )),
        ('sepsonode', lambda: SepSONODE(
            net=mlp(input_dim // 2, input_dim // 2, True, lower_mean, lower_std),
        )),
        ('hnn', lambda: HNN(
            net=mlp(input_dim, 1, False, data_mean, data_std),
        )),
        ('sephnn', lambda: SepHNN(
            netV=mlp(input_dim // 2, 1, False, lower_mean, lower_std),
            netT=mlp(input_dim // 2, 1, False, upper_mean, upper_std),
        )),
        ('kinhnn', lambda: SepHNN(
            netV=mlp(input_dim // 2, 1, False, lower_mean, lower_std),
            netT=Sequential(KineticEnergy(input_dim // 2)),
        )),
        # The remaining entries are special-purpose models for the paper.
        ('massspring', lambda: HamiltonianMassSpring()),
        ('chnn2pend', lambda: CHNN2Pend(
            net=mlp(input_dim, 1, False, data_mean, data_std),
        )),
        ('chnn2body', lambda: CHNN2Body(
            net=mlp(input_dim, 1, False, data_mean, data_std),
        )),
    )

    for name, build in builders:
        if model == name:
            return build()
    raise NotImplementedError(model)
