from typing import Tuple, NamedTuple, Optional
import numpy as np

from .demand_design_image import generate_test_demand_design_image, generate_train_demand_design_image
from .demand_design import generate_test_demand_design, generate_train_demand_design
from .dsprine import generate_train_dsprite, generate_test_dsprite
from .sin import generate_test_sin, generate_train_sin
from .imca import generate_train_imca, generate_test_imca
from .dsprites import generate_train_dsprites_latent, generate_test_dsprites_latent
from .data_class import TrainDataSet, TestDataSet


def generate_train_data(data_name: str, rand_seed: int, **args) -> TrainDataSet:
    """Build the training set for the dataset selected by ``data_name``.

    Args:
        data_name: Dataset identifier. Recognized values are ``"demand"``,
            ``"demand_old"``, ``"demand_image"``, ``"dsprite"``, ``"sin"``,
            ``"imca"`` and ``"dsprites_latent"``.
        rand_seed: Value forwarded to the selected generator.
        **args: Generator settings. ``args["data_size"]`` is read for every
            recognized dataset; ``args["rho"]`` is read for every dataset
            except ``"dsprite"`` and ``"sin"``.

    Returns:
        Whatever the selected generator returns.

    Raises:
        ValueError: If ``data_name`` is not one of the recognized values.
        KeyError: If a key needed by the selected dataset is absent from
            ``args``.
    """
    # Each builder is deferred behind a lambda so that only the keys of
    # ``args`` required by the chosen dataset are ever looked up.
    builders = (
        ("demand", lambda: generate_train_demand_design(
            args["data_size"], args["rho"], rand_seed, False)),
        # Demand design for no covariate (deprecated)
        ("demand_old", lambda: generate_train_demand_design(
            args["data_size"], args["rho"], rand_seed, True)),
        ("demand_image", lambda: generate_train_demand_design_image(
            args["data_size"], args["rho"], rand_seed)),
        ("dsprite", lambda: generate_train_dsprite(
            args["data_size"], rand_seed)),
        ("sin", lambda: generate_train_sin(
            args["data_size"], rand_seed)),
        ("imca", lambda: generate_train_imca(
            args["data_size"], args["rho"], rand_seed)),
        ("dsprites_latent", lambda: generate_train_dsprites_latent(
            args["data_size"], args["rho"], rand_seed)),
    )

    for known_name, build in builders:
        if data_name == known_name:
            return build()

    raise ValueError(f"data name {data_name} is not valid")


def generate_test_data(data_name: str, **args) -> TestDataSet:
    """Build the test set for the dataset selected by ``data_name``.

    Args:
        data_name: Dataset identifier. Recognized values are ``"demand"``,
            ``"demand_old"``, ``"demand_image"``, ``"dsprite"``, ``"sin"``,
            ``"imca"`` and ``"dsprites_latent"``.
        **args: Accepted but not read by this function.

    Returns:
        Whatever the selected test-set generator returns.

    Raises:
        ValueError: If ``data_name`` is not one of the recognized values.
    """
    # Builders are deferred so that only the generator matching
    # ``data_name`` is invoked.
    builders = (
        ("demand", lambda: generate_test_demand_design(False)),
        # Demand design for no covariate (deprecated)
        ("demand_old", lambda: generate_test_demand_design(True)),
        ("demand_image", lambda: generate_test_demand_design_image()),
        ("dsprite", lambda: generate_test_dsprite()),
        ("sin", lambda: generate_test_sin()),
        ("imca", lambda: generate_test_imca()),
        ("dsprites_latent", lambda: generate_test_dsprites_latent()),
    )

    for known_name, build in builders:
        if data_name == known_name:
            return build()

    raise ValueError(f"data name {data_name} is not valid")

