import argparse
import os
from .core import *


def parse_args(argv=None):
    """Parse the command-line options for the MSE experiments.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, so existing
            no-argument callers behave exactly as before.

    Returns:
        argparse.Namespace: The parsed options. ``a``, ``T``, ``B``,
        ``theta0`` and ``alpha`` are lists (``nargs="+"``); every other
        option is a single scalar.

    Raises:
        SystemExit: Raised by argparse on invalid or unknown arguments,
            including a ``--mode`` outside the supported choices.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        # The default must be one of the dispatched modes: argparse does
        # not validate defaults against `choices`, so a stale default
        # would only fail later when the mode is dispatched.
        "--mode", type=str, default="compute",
        choices=("compute", "compute_LSTD"),
        help="determine which method to call"
    )
    parser.add_argument(
        "--a", nargs="+", type=float, default=[-1.0],
        help="list of parameter a < 0"
    )
    parser.add_argument(
        "--T", nargs="+", type=float, default=[8.0],
        help="list of data horizon T"
    )
    parser.add_argument(
        "--max_T", type=float, default=8.0,
        help="max data horizon T"
    )
    parser.add_argument(
        "--B", nargs="+", type=int, default=[4096],
        help="list of data budgets B"
    )
    parser.add_argument(
        "--gamma", type=float, default=0.9,
        help="LQR discount factor"
    )
    parser.add_argument(
        "--runs", type=int, default=50,
        help="number of runs"
    )
    parser.add_argument(
        "--nb_steps", type=int, default=1,
        help="number of gradient updates"
    )
    parser.add_argument(
        "--theta0", nargs="+", type=float, default=[0.0],
        help="list of init parameter theta"
    )
    parser.add_argument(
        "--alpha", nargs="+", type=float, default=[1.0],
        help="list of base learning rate"
    )
    parser.add_argument(
        "--lr_gamma", type=float, default=1.0,
        help="discount factor for exponential lr decay",
    )
    parser.add_argument(
        "--lr_scheduler", type=str, default="constant",
        help="lr decay mode: constant, reciprocal, exponential",
    )
    parser.add_argument(
        "--verbose", action="store_true",
        help="verbose flag"
    )
    return parser.parse_args(argv)


if __name__ == "__main__":
    args = parse_args()

    # Echo the full configuration so each run's output records its settings.
    print("Received arguments:")
    for name, value in vars(args).items():
        print(f"{name}: {value}")

    # NOTE: --a, --T, --B, --theta0 and --alpha accept lists, but both modes
    # below pass only the first value of each; any extra values are ignored.
    if args.mode == "compute":
        # compute the MSE for mean-path TD using a specific parameter
        compute_MSE(
            args.a[0],
            args.T[0],
            args.B[0],
            gamma=args.gamma,
            theta0=args.theta0[0],
            nb_steps=args.nb_steps,
            num_runs=args.runs,
            alpha=args.alpha[0],
            lr_scheduler=args.lr_scheduler,
            lr_gamma=args.lr_gamma,
            max_T=args.max_T,
            verbose=args.verbose
        )
    elif args.mode == "compute_LSTD":
        # compute the MSE for LSTD
        compute_MSE_LSTD(
            args.a[0],
            args.T[0],
            args.B[0],
            gamma=args.gamma,
            num_runs=args.runs,
            max_T=args.max_T,
            verbose=args.verbose
        )
    else:
        # Exit explicitly rather than `assert False`: assertions are removed
        # under `python -O`, which would let an unknown mode exit silently
        # with status 0. SystemExit with a string prints it to stderr and
        # exits with status 1.
        raise SystemExit(f"mode {args.mode} not implemented")

