from typing import Optional, Literal
from models import GCN

from utils.loader import SubgraphLoader

def get_model(model_type: str, in_channels: int, hidden_channels: int,
              out_channels: int, num_layers: int, loader: SubgraphLoader,
              dropout: float, device: str,
              use_cache: bool, layer_wise_cache: bool,
              checkpointing_strategy: Literal['scattered', 'cpu', 'storage'],
              storage_offload: bool, storage_path: Optional[str] = None,
              optimize_dataloader: Optional[bool] = False,
              drop_input: Optional[bool] = False,
              batch_norm: Optional[bool] = False, residual: Optional[bool] = False,
              linear: Optional[bool] = False
              ):
    """Build the model selected by ``model_type`` and move it to ``device``.

    Only ``'gcn'`` (matched case-insensitively) is implemented. Unless noted
    otherwise, every argument is forwarded unchanged, by keyword, to the
    ``GCN`` constructor.

    Args:
        model_type: Name of the model to build; compared after lower-casing.
        in_channels: Forwarded to ``GCN``.
        hidden_channels: Forwarded to ``GCN``.
        out_channels: Forwarded to ``GCN``.
        num_layers: Forwarded to ``GCN``.
        loader: Forwarded to ``GCN``.
        dropout: Forwarded to ``GCN``.
        device: Forwarded to ``GCN`` and also passed to ``.to(...)`` on the
            constructed model.
        use_cache: Forwarded to ``GCN``.
        layer_wise_cache: Forwarded to ``GCN``.
        checkpointing_strategy: Forwarded to ``GCN``.
        storage_offload: Forwarded to ``GCN``.
        storage_path: Forwarded to ``GCN``.
        optimize_dataloader: Forwarded to ``GCN``.
        drop_input: Forwarded to ``GCN``.
        batch_norm: Accepted but not forwarded to ``GCN`` at present.
        residual: Accepted but not forwarded to ``GCN`` at present.
        linear: Accepted but not forwarded to ``GCN`` at present.

    Returns:
        The result of ``GCN(...).to(device)``.

    Raises:
        NotImplementedError: If ``model_type`` is anything other than
            ``'gcn'`` (ignoring case).
    """
    if model_type.lower() != 'gcn':
        raise NotImplementedError(
            f"Unsupported model_type {model_type!r}; only 'gcn' is implemented."
        )

    # NOTE(review): batch_norm, residual and linear are intentionally left out
    # of this call because the GCN constructor signature is not visible from
    # this module -- confirm whether GCN supports them before forwarding.
    return GCN(
        in_channels=in_channels,
        hidden_channels=hidden_channels,
        out_channels=out_channels,
        num_layers=num_layers,
        loader=loader,
        dropout=dropout,
        # Honour the caller's choice; this was previously hard-coded to False,
        # which silently ignored the drop_input argument.
        drop_input=drop_input,
        optimize_dataloader=optimize_dataloader,
        device=device,
        use_cache=use_cache,
        layer_wise_cache=layer_wise_cache,
        checkpointing_strategy=checkpointing_strategy,
        storage_offload=storage_offload,
        storage_path=storage_path,
    ).to(device)