import argparse
import logging
import os

import torch
from sklearn.metrics import accuracy_score

from tableshift import get_dataset
from tableshift.models.training import train
from tableshift.models.utils import get_estimator
from tableshift.models.default_hparams import get_default_config

# Threshold handed to ``logging.basicConfig`` below.
LOG_LEVEL = logging.DEBUG

# Configure the root logger: timestamp, padded level name, and the
# file/line that emitted the record, followed by the message itself.
logging.basicConfig(
    level=LOG_LEVEL,
    datefmt='%Y-%m-%d %H:%M:%S',
    format='%(asctime)s %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s',
)

# Handle on the root logger (``getLogger()`` with no name returns the root).
logger = logging.getLogger()


# Experiment names accepted by ``main``.
_SUPPORTED_EXPERIMENTS = (
    'anes',
    'acsincome',
    'brfss_blood_pressure',
    'acspubcov',
)

# Pairs of (split name passed to ``get_pandas``, suffix used in file names).
_SPLIT_SUFFIXES = (
    ("train", "train"),
    ("id_test", "idtest"),
    ("validation", "val"),
    ("ood_test", "ood"),
)


def main(experiment: str, cache_dir: str, debug: bool) -> int:
    """Export every split of a dataset to CSV files under ``./<experiment>/``.

    For each of the train, id_test, validation and ood_test splits, the
    first two objects returned by ``dset.get_pandas(split)`` are written to
    ``<experiment>_X<suffix>.csv`` and ``<experiment>_y<suffix>.csv``
    (without the index column).

    Args:
        experiment: Name of the experiment to export. Must be one of the
            supported names, even when ``debug`` is set.
        cache_dir: Directory forwarded to ``get_dataset``.
        debug: If True, the experiment name is replaced by ``"_debug"``
            after validation, so the ``"_debug"`` dataset is loaded and
            output goes to ``./_debug/``.

    Returns:
        Always 0.

    Raises:
        NotImplementedError: If ``experiment`` is not a supported name.
    """
    if experiment not in _SUPPORTED_EXPERIMENTS:
        raise NotImplementedError(
            f"Unsupported experiment {experiment!r}; expected one of: "
            f"{', '.join(_SUPPORTED_EXPERIMENTS)}.")
    if debug:
        print("[INFO] running in debug mode.")
        experiment = "_debug"

    dset = get_dataset(experiment, cache_dir)

    out_dir = f'./{experiment}/'
    # exist_ok: a previous run may already have created the directory;
    # its CSV files are simply overwritten instead of crashing here.
    os.makedirs(out_dir, exist_ok=True)

    for split, suffix in _SPLIT_SUFFIXES:
        # get_pandas returns a 4-tuple; only the first two elements
        # (written as the X and y files) are exported.
        features, labels, _, _ = dset.get_pandas(split)
        features.to_csv(
            os.path.join(out_dir, f'{experiment}_X{suffix}.csv'),
            index=False)
        labels.to_csv(
            os.path.join(out_dir, f'{experiment}_y{suffix}.csv'),
            index=False)
    return 0


if __name__ == "__main__":
    # Command-line entry point: collect the three options and hand them to
    # ``main`` under the parameter names it expects.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--cache_dir",
        default="tmp",
        help="Directory to cache raw data files to.",
    )
    cli.add_argument(
        "--debug",
        default=False,
        action="store_true",
        help="Whether to run in debug mode. If True, various "
             "truncations/simplifications are performed to "
             "speed up experiment.",
    )
    cli.add_argument(
        "--experiment",
        default="anes",
        help="Experiment to run. Overridden when debug=True.",
    )
    options = cli.parse_args()
    main(
        cache_dir=options.cache_dir,
        debug=options.debug,
        experiment=options.experiment,
    )
