from .ar_diffusion import *
from .attention import *
from .avit import *
from .fno import *
from .hypernetwork import *
from .ic_neuralPDE import *
from .tokenizer import *
from .unet import *


def build_model(params):
    """Construct the network selected by ``params.model``.

    Recognised values of ``params.model`` are ``"avit"``, ``"icnpde"``,
    ``"ardiff"``, ``"fno"`` and ``"unet"``. Each one reads the
    hyper-parameters it needs as attributes of ``params`` and forwards them
    to the matching model class.

    Args:
        params: Configuration object exposing hyper-parameters as attributes.
            For ``"avit"``, ``"icnpde"`` and ``"ardiff"`` it must also support
            the ``in`` operator. That operator decides whether the optional
            ``finetune`` entry is present.

    Returns:
        The newly constructed model instance.

    Raises:
        ValueError: If ``params.model`` matches none of the recognised names.
    """

    def finetune_flag():
        # ``finetune`` is optional in the config; default to False when absent.
        return params.finetune if 'finetune' in params else False

    def make_avit():
        tune = finetune_flag()
        return AViT(
            tokenizer_type="CNN",
            padding_mode=params.padding_mode,
            in_channels=params.in_channels,
            spatial_ndims=params.spatial_ndims,
            patch_size=params.patch_size,
            num_heads=params.num_heads,
            embed_dim=params.embed_dim,
            processor_blocks=params.processor_blocks,
            bias_type=params.bias_type,
            drop_path=params.drop_path,
            mpp_norm=params.pretrained_MPP,
            finetune=tune,
        )

    def make_icnpde():
        tune = finetune_flag()
        return ICneuralPDE(
            in_channels=params.in_channels,
            spatial_ndims=params.spatial_ndims,
            embed_dim=params.embed_dim,
            num_heads=params.num_heads,
            processor_blocks=params.processor_blocks,
            bias_type=params.bias_type,
            drop_path=params.drop_path,
            scheme_type=params.scheme_type,
            timesteps_integrator=params.timesteps_integrator,
            padding_mode=params.padding_mode,
            hidden_dim=params.hidden_dim,
            ksize=params.ksize,
            finetune=tune,
        )

    def make_ardiff():
        tune = finetune_flag()
        n_fields = params.in_channels
        if n_fields is None:
            # Channel counts are left unspecified when in_channels is unset.
            cond_channels = data_channels = None
        else:
            # Conditioning input: 2 past steps x n_fields; output: n_fields.
            cond_channels, data_channels = 2 * n_fields, n_fields
        # DiffusionModel is called positionally, so argument order matters.
        return DiffusionModel(
            params.hidden_dim,
            params.diffusion_steps,
            cond_channels,
            data_channels,
            tune,
        )

    def make_fno():
        # The same ``params.modes`` value is used for every spatial dimension.
        n_modes = (params.modes,) * params.spatial_ndims
        return FNO(
            n_modes=n_modes,
            in_channels=params.in_channels,
            out_channels=params.in_channels,
            hidden_channels=params.hidden_channels,
        )

    def make_unet():
        # Output channel count mirrors the input channel count.
        return UNet(
            dim_in=params.in_channels,
            dim_out=params.in_channels,
            n_spatial_dims=params.spatial_ndims,
            init_features=params.init_features,
        )

    # Ordered (name, builder) pairs. Matching uses ``==`` in the same order as
    # an equivalent if/elif chain would.
    registry = (
        ("avit", make_avit),
        ("icnpde", make_icnpde),
        ("ardiff", make_ardiff),
        ("fno", make_fno),
        ("unet", make_unet),
    )
    model = params.model
    for name, make in registry:
        if model == name:
            return make()
    raise ValueError(f"Unknown model {model}")