import argparse
import json
import os
from model.networks import *
import torch.utils.data as data_utils



def get_ood_shapetype(spec):
    """Switch an airway spec over to its out-of-distribution class.

    If ``spec["Class"]`` contains ``"Airway"``, it is overwritten in place
    with ``"Airway_SGS"``. Otherwise a message is printed and the spec is
    left untouched.

    Args:
        spec: Experiment specification; must provide a ``"Class"`` entry.

    Returns:
        The same ``spec`` object that was passed in (mutated in place when
        the class was rewritten).
    """
    shape_class = spec["Class"]
    if "Airway" not in shape_class:
        print("wrong dataset: there is no ood dataset for " + shape_class)
        return spec

    spec["Class"] = "Airway_SGS"
    return spec
def get_dataloader(dataset_, batch_size, which_split):
    """Wrap a dataset in a ``DataLoader``.

    Shuffling is enabled only when ``which_split`` contains ``'train'``;
    the last, possibly smaller, batch is always kept.

    Args:
        dataset_: Dataset object handed to ``DataLoader``.
        batch_size: Number of samples per batch.
        which_split: Split name; used only to decide whether to shuffle.

    Returns:
        The constructed ``DataLoader``.
    """
    shuffle_samples = 'train' in which_split
    return data_utils.DataLoader(
        dataset_,
        batch_size=batch_size,
        shuffle=shuffle_samples,
        drop_last=False,
    )
def load_dataset(specs: dict, which_split: str):
    """Build the dataset selected by ``specs["Class"]`` and its dataloader.

    Args:
        specs: Experiment specification. The keys ``"Class"``,
            ``"CovariateNames"``, ``"DataSource"``, ``"Split"``,
            ``"BatchSize"``, ``"TargetVariableName"``, ``"AllowMissingness"``
            and ``"AllowMuters"`` are always read. ``"TractName"`` is read
            for AFQ classes; ``"OODDataSource"``, ``"OODSplit"``,
            ``"ONLY_OOD"`` and ``"SampleSize"`` are optional and only read
            by the branches that use them.
        which_split: Split name forwarded to the dataset; the dataloader
            shuffles when it contains ``'train'``.

    Returns:
        A ``(dataset, dataloader)`` tuple.

    Raises:
        KeyError: If a required key is missing from ``specs``.
        ValueError: If ``specs["Class"]`` is not a supported dataset class.
    """
    shapetype = specs["Class"]
    covariate_names = specs["CovariateNames"]
    filename_datasource = specs["DataSource"]
    filename_split = specs["Split"]
    batch_size = specs["BatchSize"]
    tgt_var_name = specs["TargetVariableName"]
    allow_missingness = specs["AllowMissingness"]
    padding_muter = specs["AllowMuters"]

    # Keyword arguments passed to every dataset constructor below.
    shared_kwargs = dict(
        covariate_names=covariate_names,
        split=which_split,
        augtype='none',
        allow_missingness=allow_missingness,
        padding_muter=padding_muter,
    )
    # Extra keyword arguments for the datasets backed by a single
    # datasource/split file pair.
    file_kwargs = dict(
        filename_datasource=filename_datasource,
        filename_split=filename_split,
        tgt_var_name=tgt_var_name,
    )

    # NOTE: an 'Airway_CTL' branch (same arguments as 'Airway') used to be
    # present here as commented-out code; that class is not handled.
    if shapetype == 'Airway':
        atlas_dataset = PediatricAirwayDataset(**file_kwargs, **shared_kwargs)
    elif shapetype == 'AirwayLog':
        atlas_dataset = PediatricAirwayLogDataset(**file_kwargs, **shared_kwargs)
    elif shapetype in ('Airway_SGS', 'Airway_PRS', 'Airway_TBM'):
        atlas_dataset = AirwayOODDecDataset(
            filename_atlas_datasource=filename_datasource,
            filename_atlas_split=filename_split,
            filename_ood_datasource=specs.get("OODDataSource", ""),
            filename_ood_split=specs.get("OODSplit", ""),
            tgt_var_name=tgt_var_name,
            ONLY_OOD=specs.get("ONLY_OOD", False),
            **shared_kwargs,
        )
    elif shapetype == "OASISBrain":
        atlas_dataset = OASISBrainDataset(**file_kwargs, **shared_kwargs)
    elif "AFQ" in shapetype:
        atlas_dataset = AFQDataset(
            tract_name=specs["TractName"],
            **file_kwargs,
            **shared_kwargs,
        )
    elif shapetype == "ToyData":
        # The toy dataset is given empty file names and no target variable.
        atlas_dataset = ToySTDataset(
            filename_datasource="",
            filename_split="",
            training_sample_size=specs.get("SampleSize", 200000),
            **shared_kwargs,
        )
    else:
        # Fail with a clear message instead of hitting an unbound local
        # variable further down.
        raise ValueError(
            f"Unsupported dataset class {shapetype!r}: expected 'Airway', "
            "'AirwayLog', 'Airway_SGS', 'Airway_PRS', 'Airway_TBM', "
            "'OASISBrain', 'ToyData' or a name containing 'AFQ'."
        )

    dataloader = get_dataloader(atlas_dataset, batch_size, which_split)

    return atlas_dataset, dataloader


# Script entry point: parse the command line, load the experiment specs from
# JSON and build the training dataset/dataloader as a smoke test.
if __name__ == "__main__":
    arg_parser = argparse.ArgumentParser(description="Train a LucidAtlas autodecoder")
    arg_parser.add_argument(
        "--experiment",
        "-e",
        dest="experiment_directory",
        default='../../configs/airways/lucidatlas_1d_csa_v12_0115.json',
        help="The experiment directory. This directory should include "
             + "experiment specifications in 'specs.json', and logging will be "
             + "done in this directory as well.",
    )
    arg_parser.add_argument(
        "--checkpoint",
        "-c",
        dest="checkpoint",
        default="latest",
        help="The checkpoint weights to use. This can be a number indicated an epoch "
        + "or 'latest' for the latest weights (this is the default)",
    )

    args = arg_parser.parse_args()

    # Accept either a specs JSON file (as in the default) or a directory
    # containing 'specs.json' (as the help text describes).
    specs_path = args.experiment_directory
    if os.path.isdir(specs_path):
        specs_path = os.path.join(specs_path, "specs.json")

    # load_dataset expects the parsed specs dict, not a file name.
    with open(specs_path, encoding="utf-8") as specs_file:
        specs = json.load(specs_file)

    csa_dataset, csa_dataloader = load_dataset(specs, which_split='train')
    print('1')