from config.structured import ProjectConfig
from .model import ConditionalPointCloudDiffusionModel
from .model_utils import set_requires_grad


def get_model(cfg: ProjectConfig):
    """Build the conditional point-cloud diffusion model described by *cfg*.

    Every entry of ``cfg.model`` is forwarded to the constructor as a keyword
    argument, together with ``category`` (read from ``cfg.dataset.category``)
    and the full ``cfg.dataset`` object as ``dataset``.

    When ``cfg.run.freeze_feature_model`` is truthy, the model's
    ``feature_model`` is handed to ``set_requires_grad`` with ``False``
    before the model is returned.

    Args:
        cfg: Project configuration; ``cfg.model``, ``cfg.dataset`` and
            ``cfg.run.freeze_feature_model`` are read.

    Returns:
        The constructed ``ConditionalPointCloudDiffusionModel`` instance.
    """
    # NOTE(review): prior_point_radius / prior_frame_num / prior_use_depth
    # were once passed explicitly here; presumably they now arrive through
    # ``**cfg.model`` -- confirm against the model's constructor.
    diffusion_model = ConditionalPointCloudDiffusionModel(
        **cfg.model,
        category=cfg.dataset.category,
        dataset=cfg.dataset,
    )

    if not cfg.run.freeze_feature_model:
        return diffusion_model

    set_requires_grad(diffusion_model.feature_model, False)
    return diffusion_model


def get_coloring_model(cfg: ProjectConfig):
    """Build the point-cloud coloring model described by *cfg*.

    Every entry of ``cfg.model`` is forwarded to ``PointCloudColoringModel``
    as a keyword argument. When ``cfg.run.freeze_feature_model`` is truthy,
    the model's ``feature_model`` is handed to ``set_requires_grad`` with
    ``False`` before the model is returned.

    Args:
        cfg: Project configuration; ``cfg.model`` and
            ``cfg.run.freeze_feature_model`` are read.

    Returns:
        The constructed ``PointCloudColoringModel`` instance.
    """
    # Imported at call time rather than at module import time.
    # NOTE(review): presumably to keep the coloring module optional or to
    # avoid an import cycle -- confirm before hoisting to the top of the file.
    from .model_coloring import PointCloudColoringModel

    coloring_model = PointCloudColoringModel(**cfg.model)

    if not cfg.run.freeze_feature_model:
        return coloring_model

    set_requires_grad(coloring_model.feature_model, False)
    return coloring_model
