from .models import UNet
from .tabular_diffusion import TabularDiffusionMLP
from .transformer_diffusion import TabularDiffusionTransformer


def _flat_input_dim(data_shape):
    """Return the product of every dimension of ``data_shape`` except the first.

    ``data_shape[0]`` is skipped — presumably the batch dimension; confirm
    against callers. A shape with no trailing dimensions yields 1.
    """
    input_dim = 1
    for dim in data_shape[1:]:
        input_dim *= dim
    return input_dim


def _archetype_settings(config):
    """Read the archetype-conditioning options from ``config``.

    Returns a tuple ``(use_archetype_conditioning, condition_dim, num_archetypes)``.
    ``use_archetype_conditioning`` is the raw config value (default ``False``).
    When it is truthy, ``condition_dim`` defaults to 64 and ``num_archetypes``
    defaults to 5, and a log line is printed. Otherwise ``condition_dim`` is 0
    and ``num_archetypes`` is ``None``.
    """
    use_archetype_conditioning = config.get("use_archetype_conditioning", False)
    if not use_archetype_conditioning:
        return use_archetype_conditioning, 0, None
    condition_dim = config.get("condition_dim", 64)
    num_archetypes = config.get("num_archetypes", 5)
    print(f"Using archetype conditioning with {num_archetypes} archetypes")
    return use_archetype_conditioning, condition_dim, num_archetypes


def get_model(config, data_shape, device):
    """Build the model named by ``config["model"]`` and move it to ``device``.

    Args:
        config: Mapping of model settings. ``config["model"]`` selects one of
            ``"UNet"``, ``"TabularDiffusionMLP"`` or
            ``"TabularDiffusionTransformer"``. Other keys read depend on the
            model: ``noise_step`` (UNet); ``hidden_dim``, ``num_blocks`` and the
            optional archetype/conditioning keys (tabular models).
        data_shape: Shape of the training data. For UNet, indices 1 and 2 are
            used as channel count and image size. For the tabular models, all
            dimensions after index 0 are multiplied into the flat input size.
        device: Passed to the UNet constructor and to ``model.to``.

    Returns:
        The constructed model, after ``model.to(device)``.

    Raises:
        ValueError: If ``config["model"]`` is not a supported model name.
        KeyError: If a required config key is missing.
    """
    if config["model"] == "UNet":
        channel = data_shape[1]
        image_size = data_shape[2]
        model = UNet(
            c_in=channel,
            c_out=channel,
            # NOTE(review): the number of diffusion steps is passed as the time
            # embedding dimension — confirm this is intended.
            time_dim=config["noise_step"],
            device=device,
            image_size=image_size,
        )
    elif config["model"] == "TabularDiffusionMLP":
        input_dim = _flat_input_dim(data_shape)

        # Archetype conditioning takes precedence over class-label concat
        # conditioning; only one of them sets condition_dim.
        use_archetype_conditioning, condition_dim, num_archetypes = (
            _archetype_settings(config)
        )
        if not use_archetype_conditioning and config.get("condition") == "concat":
            condition_dim = config["num_classes"]

        print(f"Input dimension for TabularDiffusionMLP: {input_dim}")
        print(f"Hidden dimension for TabularDiffusionMLP: {config['hidden_dim']}")
        print(f"Condition dimension: {condition_dim}")

        model = TabularDiffusionMLP(
            config,
            input_dim=input_dim,
            hidden_dim=config["hidden_dim"],
            num_blocks=config["num_blocks"],
            condition_dim=condition_dim,
            use_archetype_conditioning=use_archetype_conditioning,
            num_archetypes=num_archetypes,
        )
    elif config["model"] == "TabularDiffusionTransformer":
        input_dim = _flat_input_dim(data_shape)

        # NOTE(review): unlike the MLP branch, condition == "concat" is not
        # honoured here, so condition_dim stays 0 without archetype
        # conditioning. Confirm whether the transformer should support it.
        use_archetype_conditioning, condition_dim, num_archetypes = (
            _archetype_settings(config)
        )

        print(f"Input dimension for TabularDiffusionTransformer: {input_dim}")
        print(f"Hidden dimension for TabularDiffusionTransformer: {config['hidden_dim']}")
        print(f"Condition dimension: {condition_dim}")

        model = TabularDiffusionTransformer(
            config,
            input_dim=input_dim,
            hidden_dim=config["hidden_dim"],
            num_blocks=config["num_blocks"],
            use_archetype_conditioning=use_archetype_conditioning,
            num_archetypes=num_archetypes,
            condition_dim=condition_dim,
        )
    else:
        raise ValueError(f"Model {config['model']} not supported.")
    return model.to(device)
