import argparse
import os
from datetime import datetime

import torch
from easydict import EasyDict

PROJECT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
DEFAULT_LOAD_DIR = os.path.join(PROJECT_DIR, "dump")
DEFAULT_OUTPUT_DIR = os.path.join(PROJECT_DIR, "dump")


def _str2bool(value):
    if isinstance(value, bool):
        return value
    value = value.lower()
    if value in ("yes", "true", "t", "1", "y"):
        return True
    if value in ("no", "false", "f", "0", "n"):
        return False
    raise argparse.ArgumentTypeError("Expected a boolean value.")


def build_parser():
    parser = argparse.ArgumentParser(description="US-States ILI configuration")

    parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    parser.add_argument("--debug", type=_str2bool, default=False)

    parser.add_argument("--load-dir", default=DEFAULT_LOAD_DIR)
    parser.add_argument("--output-dir", default=DEFAULT_OUTPUT_DIR)
    parser.add_argument("--log-path", default=None)

    parser.add_argument("--dataset", default="ILI")
    parser.add_argument("--data-dir", default="./data")
    parser.add_argument("--num-states", type=int, default=49)
    parser.add_argument("--states-per-client", type=int, default=7)
    parser.add_argument("--sequence-length", type=int, default=10)
    parser.add_argument("--num-features", type=int, default=7)
    parser.add_argument("--pin-memory", type=_str2bool, default=True)
    parser.add_argument("--num-workers", type=int, default=4)

    parser.add_argument("--train-rows", type=int, default=300)
    parser.add_argument("--test-rows", type=int, default=60)

    parser.add_argument("--use-g-encode", type=_str2bool, default=True)

    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--lambda-gan", type=float, default=0.5)
    parser.add_argument("--batch-size", type=int, default=32)
    parser.add_argument("--lr-d", type=float, default=1e-4)
    parser.add_argument("--lr-f", type=float, default=1e-4)
    parser.add_argument("--lr-e", type=float, default=1e-4)
    parser.add_argument("--gamma", type=float, default=1000)
    parser.add_argument("--beta1", type=float, default=0.9)
    parser.add_argument("--beta2", type=float, default=0.999)
    parser.add_argument("--weight-decay", type=float, default=5e-4)
    parser.add_argument("--shuffle", type=_str2bool, default=True)

    parser.add_argument("--gat-rounds", type=int, default=10)
    parser.add_argument("--gat-epochs", type=int, default=20)
    parser.add_argument("--gat-lr", type=float, default=1e-5)

    parser.add_argument("--num-clients", type=int, default=7)
    parser.add_argument("--num-task", type=int, default=6)
    parser.add_argument("--weeks-per-task", type=int, default=50)
    parser.add_argument("--num-local-epochs", type=int, default=10)
    parser.add_argument("--num-rounds", type=int, default=20)

    parser.add_argument("--use-visdom", type=_str2bool, default=False)
    parser.add_argument("--outf", default=DEFAULT_LOAD_DIR)

    parser.add_argument("--nt", type=int, default=7)
    parser.add_argument("--nh", type=int, default=512)
    parser.add_argument("--ni", type=int, default=512)
    parser.add_argument("--nc", type=int, default=7)
    parser.add_argument("--nd-out", type=int, default=7)
    parser.add_argument("--p", type=float, default=0.2)
    parser.add_argument("--no-bn", type=_str2bool, default=True)

    parser.add_argument("--dp", type=_str2bool, default=True)
    parser.add_argument("--sensitivity", type=float, default=1.0)
    parser.add_argument("--epsilon", type=float, default=1.0)

    parser.add_argument("--ray-num-gpus-per-task", type=float, default=None)
    parser.add_argument("--ray-num-cpus-per-task", type=float, default=1.0)
    parser.add_argument("--ray-max-in-flight", type=int, default=None)

    parser.add_argument(
        "--ablation",
        default="none",
        choices=["none", "no_graph", "no_temporal", "no_dp"],
        help="Optional ablation switches.",
    )

    parser.add_argument("--replay", type=_str2bool, default=True)

    return parser


def finalize_opt(opt):
    opt.train = not opt.debug
    if opt.device == "cuda" and not torch.cuda.is_available():
        opt.device = "cpu"

    if opt.output_dir == DEFAULT_OUTPUT_DIR:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        opt.output_dir = os.path.join(opt.output_dir, timestamp)
        opt.output_dir_is_default = True
    else:
        opt.output_dir_is_default = False

    if not os.path.exists(opt.load_dir):
        os.makedirs(opt.load_dir, exist_ok=True)

    if not os.path.exists(opt.output_dir):
        os.makedirs(opt.output_dir, exist_ok=True)

    if opt.log_path is None:
        opt.log_path = os.path.join(opt.output_dir, "run.log")

    opt.b = opt.sensitivity / opt.epsilon if opt.epsilon != 0 else 0.0

    if opt.ray_num_gpus_per_task is None:
        if opt.device == "cpu":
            opt.ray_num_gpus_per_task = 0.0
        else:
            opt.ray_num_gpus_per_task = 1.0 / max(1, opt.num_clients)

    if opt.ray_max_in_flight is None:
        opt.ray_max_in_flight = min(opt.num_clients, 4)

    return opt


def parse_args(args=None):
    parser = build_parser()
    parsed = parser.parse_args(args=args)
    opt = EasyDict(vars(parsed))
    return finalize_opt(opt)


opt = parse_args(args=[])
