# Arguments
import hydra
from omegaconf import OmegaConf

# Our imports
import gp.svgp.train
import gp.sgpr.train
import gp.softki.train
import gp.exact.train
# import gp.simplex_gp.train
import gp.ski.train
import gp.skip.train
from gp.util import *

from data.get_uci import (
    PoleteleDataset,
    ElevatorsDataset,
    BikeDataset,
    Kin40KDataset,
    ProteinDataset,
    KeggDirectedDataset,
    CTSlicesDataset,
    KeggUndirectedDataset,
    RoadDataset,
    SongDataset,
    BuzzDataset,
    HouseElectricDataset,
)
from data.get_md22 import (
    MD22_AcAla3NHME_Dataset,
    MD22_DHA_Dataset,
    MD22_DNA_AT_AT_CG_CG_Dataset,
    MD22_DNA_AT_AT_Dataset,
    MD22_Stachyose_Dataset,
    MD22_Buckyball_Catcher_Dataset,
    MD22_DoubleWalledNanotube_Dataset,
)


def _select_trainer(cli_config):
    """Return ``(train_gp, model_config)`` for the model named by ``cli_config.model``.

    Raises:
        ValueError: if ``cli_config.model`` is not a supported model name.
    """
    model_name = cli_config.model
    # Map to the *module* so ``train_gp`` is only looked up for the selected model.
    train_modules = {
        "svgp": gp.svgp.train,
        "softki": gp.softki.train,
        "sgpr": gp.sgpr.train,
        "exact": gp.exact.train,
        "ski": gp.ski.train,
        "skip": gp.skip.train,
    }
    if model_name not in train_modules:
        raise ValueError(
            f"Name not found {model_name}; expected one of {sorted(train_modules)}"
        )
    # Model-specific settings live under ``gp.<model_name>`` in the Hydra config.
    return train_modules[model_name].train_gp, cli_config.gp[model_name]


def _load_dataset(name, data_dir, minmax):
    """Instantiate the dataset called ``name`` from files under ``data_dir``.

    UCI datasets are read from ``{data_dir}/{name}/data.csv`` and receive
    ``minmax``; MD22 datasets are read from their ``.npz`` file directly under
    ``data_dir`` and are constructed without ``minmax``.

    Raises:
        ValueError: if ``name`` is not a supported dataset.
    """
    uci_datasets = {
        "pol": PoleteleDataset,
        "elevators": ElevatorsDataset,
        "bike": BikeDataset,
        "kin40k": Kin40KDataset,
        "protein": ProteinDataset,
        "keggdirected": KeggDirectedDataset,
        "slice": CTSlicesDataset,
        "keggundirected": KeggUndirectedDataset,
        "3droad": RoadDataset,
        "song": SongDataset,
        "buzz": BuzzDataset,
        "houseelectric": HouseElectricDataset,
    }
    # File names do not follow a single pattern (note the underscore in the
    # nanotube file), so they are listed explicitly.
    md22_datasets = {
        "Ac-Ala3-NHMe": (MD22_AcAla3NHME_Dataset, "md22_Ac-Ala3-NHMe.npz"),
        "AT-AT": (MD22_DNA_AT_AT_Dataset, "md22_AT-AT.npz"),
        "AT-AT-CG-CG": (MD22_DNA_AT_AT_CG_CG_Dataset, "md22_AT-AT-CG-CG.npz"),
        "stachyose": (MD22_Stachyose_Dataset, "md22_stachyose.npz"),
        "DHA": (MD22_DHA_Dataset, "md22_DHA.npz"),
        "buckyball-catcher": (MD22_Buckyball_Catcher_Dataset, "md22_buckyball-catcher.npz"),
        "double-walled-nanotube": (
            MD22_DoubleWalledNanotube_Dataset,
            "md22_double-walled_nanotube.npz",
        ),
    }

    if name in uci_datasets:
        return uci_datasets[name](f"{data_dir}/{name}/data.csv", minmax=minmax)
    if name in md22_datasets:
        dataset_cls, filename = md22_datasets[name]
        return dataset_cls(f"{data_dir}/{filename}")
    raise ValueError(f"Dataset {name} not supported ...")


@hydra.main(version_base=None, config_path="./", config_name="config")
def main(cli_config):
    """Train the GP model selected by the Hydra config on the selected dataset.

    Reads ``model``, ``gp.<model>``, ``dataset``, ``wandb``, ``minmax`` and
    ``data_dir`` from ``cli_config``. The merged per-model config must also
    provide ``training.seed``, ``dataset.name``, ``dataset.train_frac`` and
    ``dataset.val_frac``.

    Raises:
        ValueError: if the model or dataset name is not supported.
    """
    # Imported here so the seeding below does not depend on ``gp.util``'s
    # star-import re-exporting these names.
    import numpy as np
    import torch

    # Disable struct mode so keys absent from the schema can be added.
    OmegaConf.set_struct(cli_config, False)
    print(cli_config)

    train_gp, config = _select_trainer(cli_config)
    # The per-model config gets the shared dataset / wandb sections merged in.
    config = OmegaConf.merge(config, {"dataset": cli_config.dataset, "wandb": cli_config.wandb})

    minmax = cli_config.minmax
    dataset = _load_dataset(config.dataset.name, cli_config.data_dir, minmax)

    # Seed before splitting; presumably split_dataset draws from these RNGs
    # so the split is reproducible — confirm against gp.util.
    np.random.seed(config.training.seed)
    torch.manual_seed(config.training.seed)

    train_dataset, _val_dataset, test_dataset = split_dataset(
        dataset,
        train_frac=config.dataset.train_frac,
        val_frac=config.dataset.val_frac,
    )

    # NOTE(review): the validation split is computed but not used; train_gp
    # receives the test split as its evaluation set — confirm this is intended.
    train_gp(config, train_dataset, test_dataset)
    

if __name__ == "__main__":
    # Called without arguments: the @hydra.main decorator on main() composes
    # the config (config_name="config" in this directory, plus CLI overrides)
    # and passes it in as cli_config.
    main()