from pathlib import Path
import traceback
import argparse

from experiments.exp_biclass import Exp_BiClass
from experiments.exp_extrap import Exp_Extrap
from experiments.exp_synthetic import Exp_Synthetic

# Command-line interface shared by every experiment entry point.
# Each option becomes an attribute on the parsed namespace, which is handed
# as-is to the selected experiment class.
parser = argparse.ArgumentParser(
    description="Run all experiments within the Leit framework.")

# ---- Run / environment ----
parser.add_argument("--random-state", type=int, default=0, help="Random seed")
# Defaults to the directory that contains this script.
parser.add_argument("--proj-path", type=str,
                    default=str(Path(__file__).parents[0]))
parser.add_argument("--test-info", default="testing")
parser.add_argument("--leit-model", default="ivp_vae")
parser.add_argument("--num-dl-workers", type=int, default=4)
parser.add_argument("--device", type=str, default="cuda:1")
parser.add_argument("--exp-name", type=str, default="")

# ---- Training loop / optimisation ----
parser.add_argument("--epochs-min", type=int, default=1)
parser.add_argument("--epochs-max", type=int, default=1000,
                    help="Max training epochs")
parser.add_argument("--patience", type=int, default=5,
                    help="Early stopping patience")
parser.add_argument("--weight-decay", type=float, default=0.0001,
                    help="Weight decay (regularization)")
parser.add_argument("--lr", type=float, default=1e-3, help="Learning rate")
parser.add_argument("--lr-scheduler-step", type=int, default=20,
                    help="Every how many steps to perform lr decay")
parser.add_argument("--lr-decay", type=float, default=0.5,
                    help="Multiplicative lr decay factor")
# NOTE: store_false => the value is True unless the flag is given.
parser.add_argument("--clip-gradient", action="store_false",
                    help="Pass this flag to set clip_gradient=False "
                         "(it is True by default)")
parser.add_argument("--clip", type=float, default=1)
parser.add_argument("--log-tool", default="logging",
                    choices=["logging", "wandb"])

# ---- Data ----
parser.add_argument("--data", default="physionet12", help="Dataset name",
                    choices=[
                        "m4_full",
                        "p12",
                        "physionet12_full_hour",
                        "physionet12",
                        "physionet12_full",
                        "P19",
                        "eicu",
                        "sine",
                        "square",
                        "triangle",
                        "sawtooth",
                        "sink",
                        "sink_g",
                        "ellipse",
                    ])
parser.add_argument("--num-samples", type=int, default=-1)
parser.add_argument("--variable-num", type=int, default=37,
                    choices=[96, 37, 41, 14, 10, 5])
parser.add_argument("--ts-full", action="store_true")
parser.add_argument("--del-std5", action="store_true")
parser.add_argument("--time-scale", default="time_max",
                    choices=["time_max", "self_max", "constant", "none",
                             "max"])
parser.add_argument("--time-constant", type=float, default=2880)
parser.add_argument("--first-dim", default="batch",
                    choices=["batch", "time_series"])
parser.add_argument("--batch-size", type=int, default=50)
parser.add_argument("--t-offset", type=float, default=0.1)
# Selects which experiment class the entry point instantiates.
parser.add_argument("--ml-task", default="biclass",
                    choices=["biclass", "extrap", "synthetic"])
parser.add_argument("--extrap-full", action="store_true")
parser.add_argument("--down-times", type=int, default=1,
                    help="downsampling timestamps")
parser.add_argument("--time-max", type=int, default=2880)
parser.add_argument("--next-start", type=float, default=1440)
parser.add_argument("--next-end", type=float, default=2880)
parser.add_argument("--next-headn", type=int, default=0)
parser.add_argument("--mask-drop-rate", type=float, default=0.0)
# NOTE: store_false => the value is True unless the flag is given.
parser.add_argument("--norm", action="store_false",
                    help="Pass this flag to set norm=False "
                         "(it is True by default)")

# ---- IVP solver / network architecture ----
# "invergnn" used to be listed twice; each choice now appears exactly once.
parser.add_argument("--ivp-solver", default="gnn",
                    choices=["resnetflow", "couplingflow", "gruflow", "gnn",
                             "invergnn", "graph", "ode"])
parser.add_argument("--hidden-layers", type=int, default=3,
                    help="Number of hidden layers")
parser.add_argument("--hidden-dim", type=int, default=128,
                    help="Size of hidden layer")
parser.add_argument("--activation", type=str, default="ELU",
                    help="Hidden layer activation")
parser.add_argument("--final-activation", type=str, default="Tanh",
                    help="Last layer activation")
parser.add_argument("--atol", type=float, default=1e-4,
                    help="Absolute tolerance")
parser.add_argument("--rtol", type=float, default=1e-3,
                    help="Relative tolerance")
parser.add_argument("--flow-layers", type=int, default=2,
                    help="Number of flow layers")
parser.add_argument("--time-net", type=str, default="TimeTanh",
                    help="Name of time net",
                    choices=["TimeFourier", "TimeFourierBounded",
                             "TimeLinear", "TimeTanh"])
parser.add_argument("--time-hidden-dim", type=int, default=8,
                    help="Number of time features (only for Fourier)")

# ---- Latent model / loss weighting ----
parser.add_argument("--k-iwae", type=int, default=1)
parser.add_argument("--kl-coef", type=float, default=1.0)
parser.add_argument("--latent-dim", type=int, default=20)
parser.add_argument("--classifier-input", default="z0")
# NOTE: store_false => the value is True unless the flag is given.
parser.add_argument("--train-w-reconstr", action="store_false",
                    help="Pass this flag to set train_w_reconstr=False "
                         "(it is True by default)")
parser.add_argument("--ratio-ce", type=float, default=1000)
parser.add_argument("--ratio-nl", type=float, default=1)
parser.add_argument("--ratio-zz", type=float, default=0)
parser.add_argument("--prior-mu", type=float, default=0.0)
parser.add_argument("--prior-std", type=float, default=1.0)
parser.add_argument("--obsrv-std", type=float, default=0.01)
parser.add_argument("--combine-methods", default="average",
                    choices=["average", "kl_weighted"])

# ---- Synthetic / DAG options ----
# These options keep their underscore spelling so existing command lines
# continue to work.
parser.add_argument("--n_ts", type=int, default=3,
                    help="Number of time series (for synth dataset)")
parser.add_argument("--dag_data", type=int, default=1,
                    help="boolean for create interacting Time series data")
parser.add_argument("--log_metrics", type=int, default=0,
                    help="boolean for logging metrics")
parser.add_argument("--scale", type=float, default=0.4,
                    help="Noise to add in A truth init")
parser.add_argument("--max_iter", type=int, default=20, help="DAG learning")
# The empty-string default is not one of the choices; argparse only validates
# values given on the command line, so '' means "not specified".
parser.add_argument("--training_scheme", type=str, default="",
                    help="to save ckpts", choices=["lgnf", "tgnf"])
# No default: the attribute is None when the option is omitted.
parser.add_argument("--experiment", type=str, help="Which experiment to run",
                    choices=["synthetic", "ivp-vae"])
def write_error_log(proj_path, exp_name):
    """Write the traceback of the exception being handled to the error log.

    Must be called from inside an ``except`` block so that
    ``traceback.format_exc()`` has an active exception to format.

    Args:
        proj_path: Project root; the file is written to ``<proj_path>/log``.
        exp_name: Experiment name used in the file name ``err_<exp_name>.log``.

    Returns:
        The ``Path`` of the log file that was written.
    """
    log_dir = Path(proj_path) / "log"
    # Create the directory first so a missing "log" folder cannot mask the
    # original failure with a FileNotFoundError.
    log_dir.mkdir(parents=True, exist_ok=True)
    err_file = log_dir / "err_{}.log".format(exp_name)
    with open(err_file, "w", encoding="utf-8") as fout:
        print(traceback.format_exc(), file=fout)
    return err_file


def main():
    """Parse the command line, build the selected experiment and run it.

    The experiment class is chosen from ``--ml-task``. If ``run()`` or
    ``finish()`` raises, the traceback is saved to the project's error log
    and the exception is re-raised so the process exits with a failure status.

    Raises:
        ValueError: If ``--ml-task`` names no known experiment.
        Exception: Whatever the experiment raised, after it has been logged.
    """
    args = parser.parse_args()

    # Built inside the function so the experiment classes are only looked up
    # when the script is actually executed.
    experiment_classes = {
        "extrap": Exp_Extrap,
        "biclass": Exp_BiClass,
        "synthetic": Exp_Synthetic,
    }
    if args.ml_task not in experiment_classes:
        raise ValueError("Unknown")
    experiment = experiment_classes[args.ml_task](args)

    try:
        experiment.run()
        experiment.finish()
    except Exception:
        write_error_log(experiment.proj_path, experiment.args.exp_name)
        # Re-raise: previously the failure was swallowed and the script
        # exited with status 0, hiding crashed runs from callers.
        raise


if __name__ == "__main__":
    main()
