from .models import *
from .finde import *
from .ode import PhysicsModel

def periodic_pad1d(x1, x2=None, padding=1):
    """Pad tensors along the last dimension for a periodic boundary condition.

    The last ``padding`` entries of the last dimension are prepended and the
    first ``padding`` entries are appended, i.e. the signal is wrapped around.

    Args:
        x1: Tensor to pad along its last dimension.
        x2: Optional second tensor, padded the same way.
        padding: Number of entries wrapped onto each side. Must be
            non-negative. Values larger than the last-dimension length do not
            wrap correctly (the slices then return the whole tensor).

    Returns:
        The padded ``x1`` when ``x2`` is None, otherwise the tuple
        ``(padded_x1, padded_x2)``.

    Raises:
        ValueError: if ``padding`` is negative.
    """
    if padding < 0:
        raise ValueError(f"padding must be non-negative, got {padding}")

    def _pad(x):
        # ``x[..., -0:]`` is the *whole* tensor rather than an empty slice,
        # so zero padding must be special-cased to stay a no-op.
        if padding == 0:
            return x
        return torch.cat([x[..., -padding:], x, x[..., :padding]], dim=-1)

    if x2 is None:
        return _pad(x1)
    return _pad(x1), _pad(x2)


class PeriodicPad1d(nn.Module):
    """Module form of ``periodic_pad1d``.

    Args:
        padding: Number of entries wrapped onto each side of the last
            dimension; stored as ``self.padding``.
    """

    def __init__(self, padding=1):
        super().__init__()
        self.padding = padding

    def forward(self, x1, x2=None):
        """Periodically pad ``x1`` (and ``x2`` when it is given)."""
        width = self.padding
        padded = periodic_pad1d(x1, x2, padding=width)
        return padded


class GlobalSummation1d(nn.Module):
    """Global summation over the last axis, e.g. to obtain a system energy.

    The reduced axis is kept with size 1, and the sum is multiplied by ``c``.

    Args:
        c: Scale factor applied after summation.
    """

    def __init__(self, c=1.):
        super().__init__()
        self.c = c

    def forward(self, x1, x2=None):
        """Return the scaled last-axis sum of ``x1``, or of both inputs."""
        total1 = x1.sum(dim=-1, keepdim=True) * self.c
        if x2 is not None:
            total2 = x2.sum(dim=-1, keepdim=True) * self.c
            return total1, total2
        return total1


def get_PDENODE(input_dim, hidden_dim, output_dim, act='tanh', bias=True, global_sum=False, data_mean=None, data_std=None):
    """Build a periodic 1D convolutional network with one width-3 stencil.

    Layer stack: ``PeriodicPad1d(1)`` -> ``Conv1d(kernel_size=3)`` -> act ->
    ``Conv1d(kernel_size=1)`` -> act -> ``Conv1d(kernel_size=1)`` to
    ``output_dim`` channels, followed by ``GlobalSummation1d`` when
    ``global_sum`` is True. A ``Normalizer`` is prepended when ``data_mean``
    is given.

    Args:
        input_dim: Number of input channels.
        hidden_dim: Number of hidden channels.
        output_dim: Number of output channels.
        act: Activation name resolved with ``get_activation_from_name``.
        bias: Whether the final convolution has a bias; earlier layers use
            ``Conv1d``'s default.
        global_sum: Append a ``GlobalSummation1d`` after the last layer.
        data_mean: Optional mean for ``Normalizer``; when None no
            normalizer is added and ``data_std`` is ignored.
        data_std: Standard deviation for ``Normalizer``; required when
            ``data_mean`` is given.

    Returns:
        A ``Sequential`` of the layers above.

    Raises:
        ValueError: if ``data_mean`` is given but ``data_std`` is None.
    """
    # Validate explicitly: an ``assert`` would be stripped under ``python -O``.
    if data_mean is not None and data_std is None:
        raise ValueError('data_std must be provided when data_mean is given')

    Act = get_activation_from_name(act)
    sequence = [
        PeriodicPad1d(1),
        Conv1d(input_dim, hidden_dim, kernel_size=3),
        Act(),
        Conv1d(hidden_dim, hidden_dim, kernel_size=1),
        Act(),
        # Only this final layer honours ``bias``.
        Conv1d(hidden_dim, output_dim, kernel_size=1, bias=bias),
    ]
    if global_sum:
        sequence.append(GlobalSummation1d())

    if data_mean is not None:
        sequence.insert(0, Normalizer(data_mean, data_std))
    return Sequential(*sequence)

def get_CNN(input_dim, hidden_dim, output_dim, act='tanh', bias=True, global_sum=False, data_mean=None, data_std=None):
    """Build a periodic 1D CNN with three width-3 convolutions.

    Each convolution is preceded by ``PeriodicPad1d(1)``:
    ``Conv1d(input_dim -> hidden_dim)`` -> act ->
    ``Conv1d(hidden_dim -> hidden_dim)`` -> act ->
    ``Conv1d(hidden_dim -> output_dim)``, followed by ``GlobalSummation1d``
    when ``global_sum`` is True. A ``Normalizer`` is prepended when
    ``data_mean`` is given.

    Args:
        input_dim: Number of input channels.
        hidden_dim: Number of hidden channels.
        output_dim: Number of output channels.
        act: Activation name resolved with ``get_activation_from_name``.
        bias: Whether the final convolution has a bias; earlier layers use
            ``Conv1d``'s default.
        global_sum: Append a ``GlobalSummation1d`` after the last layer.
        data_mean: Optional mean for ``Normalizer``; when None no
            normalizer is added and ``data_std`` is ignored.
        data_std: Standard deviation for ``Normalizer``; required when
            ``data_mean`` is given.

    Returns:
        A ``Sequential`` of the layers above.

    Raises:
        ValueError: if ``data_mean`` is given but ``data_std`` is None.
    """
    # Validate explicitly: an ``assert`` would be stripped under ``python -O``.
    if data_mean is not None and data_std is None:
        raise ValueError('data_std must be provided when data_mean is given')

    Act = get_activation_from_name(act)
    sequence = [
        PeriodicPad1d(1),
        Conv1d(input_dim, hidden_dim, kernel_size=3),
        Act(),
        PeriodicPad1d(1),
        Conv1d(hidden_dim, hidden_dim, kernel_size=3),
        Act(),
        PeriodicPad1d(1),
        # Only this final layer honours ``bias``.
        Conv1d(hidden_dim, output_dim, kernel_size=3, bias=bias),
    ]
    if global_sum:
        sequence.append(GlobalSummation1d())

    if data_mean is not None:
        sequence.insert(0, Normalizer(data_mean, data_std))
    return Sequential(*sequence)

def get_nn(input_dim, hidden_dim, act, model, data_mean=None, data_std=None):
    """Return a ``NODE`` wrapping a channel-preserving 1D convolutional net.

    ``model`` selects the layer builder: ``'node'`` uses ``get_PDENODE`` and
    ``'cnn'`` uses ``get_CNN``. In both cases the network maps ``input_dim``
    channels to ``input_dim`` channels, with a bias on the final layer.

    Raises:
        NotImplementedError: for any other ``model`` value.
    """
    if model == 'node':
        builder = get_PDENODE
    elif model == 'cnn':
        builder = get_CNN
    else:
        raise NotImplementedError(model)

    network = builder(
        input_dim=input_dim,
        hidden_dim=hidden_dim,
        output_dim=input_dim,
        act=act,
        bias=True,
        data_mean=data_mean,
        data_std=data_std,
    )
    return NODE(net=network)


class PhysicsModelPDE1d(PhysicsModel):
    """PhysicsModel whose network is a periodic 1D convolutional model.

    ``self.net`` is built by ``get_nn``, which wraps a ``get_PDENODE``
    (``model='node'``) or ``get_CNN`` (``model='cnn'``) network in ``NODE``.
    When ``finde`` is truthy, a second, globally summed network with
    ``finde.num`` output channels is built from the same layer family and
    passed to ``get_finde`` as ``quantities``.

    Args:
        input_dim: Number of input channels; ``self.net``'s convolutional
            network also outputs this many channels.
        hidden_dim: Number of hidden channels.
        act: Activation name forwarded to the layer builders.
        model: Architecture name; ``get_nn`` accepts only ``'node'`` and
            ``'cnn'``. NOTE(review): the default ``'hnn'`` is not accepted by
            ``get_nn`` and ends in ``NotImplementedError`` (unless
            ``check_combination`` raises first), so callers must pass
            ``model`` explicitly.
        solver: Either one solver name, used for both ``self.solver`` and
            ``self.solver_eval``, or two comma-separated names assigned to
            ``self.solver`` and ``self.solver_eval`` respectively.
        data_mean: Optional normalization mean forwarded to every builder.
        data_std: Optional normalization std forwarded to every builder.
        finde: Optional FINDE configuration; when truthy, ``finde.num`` sets
            the number of output channels of the quantities network.
    """

    def __init__(self, input_dim, hidden_dim, act='tanh', model='hnn', solver='dopri5', data_mean=None, data_std=None, finde=None):
        # Skip PhysicsModel.__init__ (the default initializer): only
        # nn.Module's own setup runs; attributes are assigned explicitly below.
        torch.nn.Module.__init__(self)
        self.odeint = odeint
        self.model = model
        if ',' in solver:
            self.solver, self.solver_eval = solver.split(',')
        else:
            self.solver = self.solver_eval = solver
        # NOTE(review): defined on PhysicsModel; presumably validates the
        # model/solver attributes set just above — confirm.
        self.check_combination()
        self.net = get_nn(input_dim, hidden_dim, act, model=model, data_mean=data_mean, data_std=data_std)
        if finde:
            # Same layer family as self.net, but reduced by GlobalSummation1d
            # and without a bias on the final convolution.
            build_quantities = get_CNN if model == 'cnn' else get_PDENODE
            quantities = build_quantities(input_dim=input_dim, hidden_dim=hidden_dim, output_dim=finde.num, act=act, bias=False, global_sum=True, data_mean=data_mean, data_std=data_std)
            self.finde = get_finde(finde, input_dim, hidden_dim, act, model=self.net, data_mean=data_mean, data_std=data_std, quantities=quantities)
        else:
            self.finde = None
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize every ``nn.Linear`` / ``nn.Conv1d`` submodule in place.

        Weights are set with ``nn.init.orthogonal_``; biases, when present,
        are zeroed. NOTE(review): the builders use the project's ``Conv1d``;
        those layers are covered here only if it subclasses ``nn.Conv1d`` —
        confirm.
        """
        for module in self.modules():
            if isinstance(module, (nn.Linear, nn.Conv1d)):
                nn.init.orthogonal_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
