import torch
import torch.nn as nn


def init_identity(model):
    """Initialise ``model`` in place with identity weights and fixed biases.

    Every parameter whose name contains ``'weight'`` is filled with an
    identity pattern via ``nn.init.eye_`` and every parameter whose name
    contains ``'bias'`` is zeroed. Afterwards the bias of the first
    ``nn.Linear`` found by ``model.modules()`` is set to ``2.0`` and the
    bias of the last one to ``-2.0``. If the model contains a single linear
    layer, both assignments hit the same bias and ``-2.0`` wins.

    Args:
        model: An ``nn.Module`` whose parameters are all named ``*weight*``
            or ``*bias*``.

    Raises:
        ValueError: If a parameter name contains neither ``'weight'`` nor
            ``'bias'``. ``nn.init.eye_`` also rejects weights that are not
            2-dimensional.
    """
    for name, param in model.named_parameters():
        if 'weight' in name:
            nn.init.eye_(param)
        elif 'bias' in name:
            nn.init.zeros_(param)
        else:
            raise ValueError(f"unexpected parameter name: {name}")

    linear_layers = [
        module for module in model.modules()
        if isinstance(module, nn.Linear)
    ]
    if not linear_layers:
        return

    # Layers built with ``bias=False`` expose ``bias is None``; skip those
    # instead of crashing after the weights have already been rewritten.
    first_bias = linear_layers[0].bias
    last_bias = linear_layers[-1].bias
    if first_bias is not None:
        nn.init.constant_(first_bias, 2.0)
    if last_bias is not None:
        nn.init.constant_(last_bias, -2.0)


class MLP(nn.Module):
    """Fully connected ReLU network with an optional additive context branch.

    The input is projected to ``hidden_dim``, optionally summed with an
    embedding of ``context``, and then passed through ``num_layers - 1``
    hidden ``Linear + ReLU`` blocks followed by a final linear projection to
    ``output_dim``.

    Args:
        input_dim: Number of input features (excluding the time feature).
        hidden_dim: Width of every hidden layer.
        output_dim: Number of output features; defaults to ``input_dim``.
        num_layers: ``num_layers - 1`` hidden blocks are stacked before the
            output projection.
        time_varying: If true, the input layer expects one extra feature
            (``input_dim + 1`` in total).
        context_dim: Feature size of the optional context input. When
            ``None`` no context branch is built and ``context_embedding`` is
            ``None``.
    """

    def __init__(
        self, input_dim, hidden_dim=64, output_dim=None, 
        num_layers=3, time_varying=False, context_dim=None
        ) -> None:
        super().__init__()
        # Resolve the default first so the stored attribute reflects the
        # dimension that is actually used by the output projection.
        if output_dim is None:
            output_dim = input_dim

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.num_layers = num_layers
        self.time_dim = 1 if time_varying else 0

        if context_dim is not None:
            # nn.Sequential takes modules as positional arguments, not a list.
            self.context_embedding = nn.Sequential(
                nn.Linear(context_dim, hidden_dim),
                nn.ReLU(),
                *self._hidden_blocks(hidden_dim, num_layers - 1),
                nn.Linear(hidden_dim, hidden_dim),
            )
        else:
            self.context_embedding = None

        self.input_layer = nn.Sequential(
            nn.Linear(input_dim + self.time_dim, hidden_dim),
            nn.ReLU(),
        )

        self.hidden_layers = nn.Sequential(
            *self._hidden_blocks(hidden_dim, num_layers - 1),
            nn.Linear(hidden_dim, output_dim),
        )

    @staticmethod
    def _hidden_blocks(hidden_dim, count):
        """Return ``count`` fresh ``Linear(hidden, hidden) + ReLU`` blocks."""
        return [
            nn.Sequential(nn.Linear(hidden_dim, hidden_dim), nn.ReLU())
            for _ in range(count)
        ]

    def forward(self, x, context=None):
        """Run the network.

        Args:
            x: Input tensor whose last dimension is
                ``input_dim + time_dim``.
            context: Optional tensor whose last dimension is ``context_dim``;
                its embedding is added to the projected input.

        Returns:
            Tensor whose last dimension is ``output_dim``.

        Raises:
            ValueError: If ``context`` is given but the model was built
                without ``context_dim``.
        """
        emb = self.input_layer(x)
        if context is not None:
            if self.context_embedding is None:
                raise ValueError(
                    "context was provided but the model was built "
                    "without context_dim"
                )
            # Out-of-place add: ``emb`` is the ReLU output that autograd
            # keeps for the backward pass, so it must not be mutated.
            emb = emb + self.context_embedding(context)
        return self.hidden_layers(emb)
    

class FlowModelWrapper(nn.Module):
    """Adapter exposing an ``MLP`` through a ``forward(t, x, context)`` call.

    The constructor arguments are passed unchanged to ``MLP``.

    Args:
        input_dim: Number of features in ``x``.
        hidden_dim: Width of the hidden layers.
        output_dim: Number of output features; ``None`` is passed through to
            ``MLP``.
        num_layers: Depth argument passed to ``MLP``.
        time_varying: Whether ``t`` is appended to ``x`` as an extra feature.
        context_dim: Feature size of the optional context input.
    """

    def __init__(
        self, input_dim, hidden_dim=64, output_dim=None, 
        num_layers=3, time_varying=False, context_dim=None,
        ) -> None:
        super().__init__()
        # Remembered so forward() only appends ``t`` when the wrapped input
        # layer was sized for the extra time feature.
        self.time_varying = time_varying
        self.model = MLP(
            input_dim, hidden_dim, output_dim, num_layers, time_varying, context_dim
        )

    def forward(self, t, x, context):
        """Evaluate the wrapped model at time ``t`` and state ``x``.

        Args:
            t: Time tensor, either 0-dim (broadcast over the batch) or 1-D
                with one entry per row of ``x``. Unused when the wrapper is
                not time-varying.
            x: State tensor; assumed to be 2-D ``(batch, input_dim)`` since
                ``t`` is concatenated as a column.
            context: Optional conditioning tensor, or ``None``. It is
                forwarded only when the wrapped model has a context
                embedding; otherwise it is ignored.

        Returns:
            The wrapped model's output for the assembled input.
        """
        if self.time_varying:
            if t.dim() == 0:
                # A scalar time tensor has no batch axis; broadcast it.
                t = t.expand(x.shape[0])
            y = torch.cat([x, t[:, None]], dim=-1)
        else:
            y = x

        has_context_branch = (
            getattr(self.model, "context_embedding", None) is not None
        )
        if context is not None and has_context_branch:
            return self.model(y, context)
        return self.model(y)