from model.ivp_solvers import CouplingFlow, ODEModel, ResNetFlow, GRUFlow, GNNFlow, InverGNNFlow,GraphFlow
from model.components import Embedding_MLP, Embedding_MLP_GNN, Reconst_Mapper_MLP
from model.ivp_vae_biclass import IVPVAE_BiClass
from model.ivp_vae_extrap import IVPVAE_Extrap
from model.ivp_vae_synthetic import IVPVAE_Synthetic
from utils import SolverWrapper, SolverWrapper_GNN

class ModelFactory:
    """Build IVP-VAE model variants from a single configuration object.

    The factory keeps a reference to ``args`` and reads its attributes lazily
    whenever a component is built.  Attributes read by this class:

    * solver selection: ``ivp_solver``
    * shared network shape: ``hidden_dim``, ``hidden_layers``
    * ``'ode'`` solver only: ``odenet``, ``activation``, ``final_activation``,
      ``ode_solver``, ``solver_step``, ``atol``, ``rtol``
    * flow solvers only: ``flow_layers``, ``time_net``, ``time_hidden_dim``
    * embedding / reconstruction: ``variable_num``, ``latent_dim``

    The same ``args`` object is also forwarded unchanged to the model
    constructors.
    """

    # Solver names whose flow receives an extra leading ``1`` argument and is
    # wrapped in ``SolverWrapper_GNN`` instead of ``SolverWrapper``.
    _GRAPH_SOLVERS = ('gnn', 'invergnn', 'graph')

    # Every value of ``args.ivp_solver`` that ``build_ivp_solver`` accepts;
    # only used to produce a helpful error message.
    _SUPPORTED_SOLVERS = (
        'ode', 'couplingflow', 'resnetflow', 'gruflow',
        'gnn', 'invergnn', 'graph',
    )

    def __init__(self, args):
        """Store the configuration object; nothing is built yet."""
        self.args = args

    @classmethod
    def _flow_class(cls, name):
        """Return the flow class registered for the solver name ``name``.

        Written as guard returns (rather than a dict literal) so that a flow
        class is only looked up when its own name is requested.

        Raises:
            NotImplementedError: if ``name`` is not a known flow solver.
        """
        if name == 'couplingflow':
            return CouplingFlow
        if name == 'resnetflow':
            return ResNetFlow
        if name == 'gruflow':
            return GRUFlow
        if name == 'gnn':
            return GNNFlow
        if name == 'invergnn':
            return InverGNNFlow
        if name == 'graph':
            return GraphFlow
        # Same exception type as before, but now it says what went wrong.
        raise NotImplementedError(
            'Unsupported ivp_solver {!r}; expected one of: {}'.format(
                name, ', '.join(cls._SUPPORTED_SOLVERS)))

    def build_ivp_solver(self, states_dim):
        """Create the wrapped IVP solver selected by ``args.ivp_solver``.

        Args:
            states_dim: passed to the underlying ``ODEModel`` / flow
                constructor as its state-dimension argument.

        Returns:
            A ``SolverWrapper`` around an ``ODEModel`` (``'ode'``) or a flow
            (``'couplingflow'``, ``'resnetflow'``, ``'gruflow'``), or a
            ``SolverWrapper_GNN`` around a flow for the graph solvers
            (``'gnn'``, ``'invergnn'``, ``'graph'``).

        Raises:
            NotImplementedError: if ``args.ivp_solver`` is none of the names
                listed above.
        """
        cfg = self.args
        hidden_dims = [cfg.hidden_dim] * cfg.hidden_layers
        solver_name = cfg.ivp_solver

        if solver_name == 'ode':
            ode_model = ODEModel(
                states_dim, cfg.odenet, hidden_dims, cfg.activation,
                cfg.final_activation, cfg.ode_solver, cfg.solver_step,
                cfg.atol, cfg.rtol)
            return SolverWrapper(ode_model)

        flow = self._flow_class(solver_name)
        # Positional arguments shared by every flow constructor.
        flow_args = (states_dim, cfg.flow_layers, hidden_dims,
                     cfg.time_net, cfg.time_hidden_dim)

        if solver_name in self._GRAPH_SOLVERS:
            # Graph flows take a leading ``1`` before the shared arguments.
            return SolverWrapper_GNN(flow(1, *flow_args))
        return SolverWrapper(flow(*flow_args))

    def init_components(self):
        """Instantiate the building blocks shared by all model variants.

        Components are created in a fixed order (embedding, GNN embedding,
        solver, reconstruction mapper); the order is kept deliberately in case
        construction consumes random state.

        Returns:
            A 4-tuple ``(embedding_nn, embedding_nn_gnn, ivp_solver,
            reconst_mapper)``.
        """
        cfg = self.args

        embedding_nn = Embedding_MLP(cfg.variable_num, cfg.latent_dim)
        # The GNN embedding is always built with fixed (1, 1) arguments.
        embedding_nn_gnn = Embedding_MLP_GNN(1, 1)
        ivp_solver = self.build_ivp_solver(cfg.latent_dim)
        reconst_mapper = Reconst_Mapper_MLP(cfg.latent_dim, cfg.variable_num)

        return embedding_nn, embedding_nn_gnn, ivp_solver, reconst_mapper

    def initialize_biclass_model(self):
        """Return an ``IVPVAE_BiClass`` wired with freshly built components."""
        embedding_nn, embedding_nn_gnn, solver, reconst_mapper = \
            self.init_components()

        # NOTE: this model takes the solver as ``diffeq_solver`` whereas the
        # other two variants expect ``diffeq_solver1``.
        return IVPVAE_BiClass(
            args=self.args,
            embedding_nn=embedding_nn,
            embedding_nn_gnn=embedding_nn_gnn,
            reconst_mapper=reconst_mapper,
            diffeq_solver=solver)

    def initialize_extrap_model(self):
        """Return an ``IVPVAE_Extrap`` wired with freshly built components."""
        embedding_nn, embedding_nn_gnn, solver, reconst_mapper = \
            self.init_components()

        return IVPVAE_Extrap(
            args=self.args,
            embedding_nn=embedding_nn,
            embedding_nn_gnn=embedding_nn_gnn,
            reconst_mapper=reconst_mapper,
            diffeq_solver1=solver)

    def initialize_synthetic_model(self):
        """Return an ``IVPVAE_Synthetic`` wired with freshly built components."""
        embedding_nn, embedding_nn_gnn, solver, reconst_mapper = \
            self.init_components()

        return IVPVAE_Synthetic(
            args=self.args,
            embedding_nn=embedding_nn,
            embedding_nn_gnn=embedding_nn_gnn,
            reconst_mapper=reconst_mapper,
            diffeq_solver1=solver)

