#!/usr/bin/env python3
import argparse
import os
import json

import yaml

from toolkit.utils import set_config_args, set_seed


def _str2bool(value):
    """Convert a command-line token to ``bool``.

    ``type=bool`` is an argparse pitfall: argparse calls ``bool("False")``,
    which is ``True`` because every non-empty string is truthy. This
    converter parses the text instead. Values that are already ``bool``
    pass through unchanged.

    Args:
        value: the raw token (normally ``str``) or an existing ``bool``.

    Returns:
        ``True`` for 1/true/t/yes/y/on. ``False`` for 0/false/f/no/n/off and
        for the empty string. Matching is case-insensitive and ignores
        surrounding whitespace.

    Raises:
        argparse.ArgumentTypeError: for any other value, so argparse reports
            a usage error instead of silently treating it as ``True``.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ("1", "true", "t", "yes", "y", "on"):
        return True
    if text in ("0", "false", "f", "no", "n", "off", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def _build_parser():
    """Return the main ``ArgumentParser`` with every supported option defined."""
    parser = argparse.ArgumentParser()
    # miscellaneous args
    parser.add_argument(
        "--results",
        type=str,
        default="./results",
        help="Results path (base dir) (default=%(default)s)",
    )
    parser.add_argument(
        "--exp_name",
        default=None,
        type=str,
        help="Experiment name (default=%(default)s)",
    )
    parser.add_argument(
        "--seed", type=int, default=0, help="Random seed (default=%(default)s)"
    )
    parser.add_argument(
        "--save_models",
        action="store_true",
        help="Save trained models (default=%(default)s)",
    )

    # benchmark args
    parser.add_argument("--dataset", default="cifar10", choices=["cifar10", "cifar100", "mini"])
    parser.add_argument("--device", default="cuda", choices=["cuda", "cpu"])
    parser.add_argument("--gpu", type=int, default=0)
    parser.add_argument(
        "--scenario",
        default="new_classes_incremental_with_labels",
        choices=[
            "new_instances",
            "new_classes_multitask",
            "new_classes_multitask_unbalanced",
            "new_classes_incremental",
            "new_classes_incremental_with_labels"
        ],
    )
    parser.add_argument(
        "--num_workers",
        default=4,
        type=int,
        required=False,
        help="Number of subprocesses to use for dataloader (default=%(default)s)",
    )
    # NOTE(review): this help text is a copy of --num_workers' and is almost
    # certainly wrong for --dynamic; confirm what the option controls.
    parser.add_argument(
        "--dynamic",
        default=0,
        type=int,
        help="Number of subprocesses to use for dataloader (default=%(default)s)",
    )
    parser.add_argument(
        "--pin_memory",
        default=False,
        type=_str2bool,  # was type=bool, which made "False" parse as True
        help="Copy Tensors into CUDA pinned memory before returning them (default=%(default)s)",
    )
    parser.add_argument(
        "--batch_size",
        default=64,
        type=int,
        required=False,
        help="Number of samples per batch to load (default=%(default)s)",
    )
    parser.add_argument(
        "--num_tasks",
        default=5,
        type=int,
        required=False,
        help="Number of tasks per dataset (default=%(default)s)",
    )
    # training args
    parser.add_argument(
        "--strategy",
        type=str,
        default="naive",
    )
    parser.add_argument(
        "--optimizer",
        type=str,
        default="SGD",
        help="Optimizer",
        choices=["Adam", "SGD"],
    )
    parser.add_argument(
        "--model",
        type=str,
        default="resnet18",
        help="Model type",
    )
    parser.add_argument(
        "--val_size",
        default=0.05,
        type=float,
        required=False,
        help="Validation size (default=%(default)s)",
    )

    parser.add_argument(
        "--nepochs",
        default=1,
        type=int,
        required=False,
        help="Number of epochs per training session (default=%(default)s)",
    )
    parser.add_argument(
        "--lr",
        default=0.05,
        type=float,
        required=False,
        help="Starting learning rate (default=%(default)s)",
    )
    parser.add_argument(
        "--clipping",
        default=10000,
        type=float,
        required=False,
        help="Clip gradient norm (default=%(default)s)",
    )
    parser.add_argument(
        "--momentum",
        default=0.95,
        type=float,
        required=False,
        help="Momentum factor (default=%(default)s)",
    )
    parser.add_argument(
        "--weight_decay",
        default=0.0002,
        type=float,
        required=False,
        help="Weight decay (L2 penalty) (default=%(default)s)",
    )
    parser.add_argument(
        "--source_results_path",
        default=None,
        type=str,
        required=False,
        help="Sometimes used to load model from another training run",
    )

    parser.add_argument("--stop_after", type=int, default=-1, help="Stop training at task")

    # Dataset: the value is parsed as JSON text.
    parser.add_argument("--restricted", type=json.loads, default={}, required=False)

    # Logging and evaluation
    parser.add_argument("--tensorboard", type=int, default=0)
    parser.add_argument("--wandblog", type=int, default=0)
    parser.add_argument("--eval_on_previous", action="store_true", required=False)

    # Plugins
    parser.add_argument("--schedule", type=_str2bool, default=False)  # was type=bool
    parser.add_argument("--mean_evaluation", type=int, default=0)
    parser.add_argument("--retain_best", action="store_true", required=False)
    parser.add_argument("--parallel_evaluation", type=int, default=0)
    parser.add_argument("--eval_mode", type=str, default="iteration", required=False)
    parser.add_argument("--use_transforms", action="store_true", required=False, help="activate input transforms")
    # NOTE(review): help text copied from --use_transforms; confirm meaning of --milestone.
    parser.add_argument("--milestone", type=int, default=1, help="activate input transforms")

    # Hyperparameters for plugins
    # --lmb also selects the results sub-directory (see modify_and_create_path).
    parser.add_argument("--lmb", type=float, default=0.4)
    parser.add_argument("--lmb2", type=float, default=0.4)
    parser.add_argument("--every", type=int, default=-1, help="Evaluate every k training iterations")
    parser.add_argument("--memory_size", type=int, default=500, required=False)
    parser.add_argument("--start_from_pretrained", type=str, default=None, required=False)
    # NOTE(review): help text copied from --use_transforms; confirm meaning of --expstep.
    parser.add_argument("--expstep", type=_str2bool, default=False, help="activate input transforms")  # was type=bool
    return parser


def parse_arguments(save_config=False):
    """Parse command-line options and prepare the results directory tree.

    A ``--config`` file, if given, is read first by a small pre-parser. It is
    handed to ``set_config_args`` (from ``toolkit.utils``) before the main
    parse. Presumably that installs the file's values as parser defaults, so
    explicit command-line flags still win; confirm in toolkit.utils.

    After parsing:

    * ``args.results`` is rewritten by ``modify_and_create_path``, which
      inserts a directory chosen from ``--lmb``.
    * ``<args.results>/<exp_name or "default">/<seed>`` is created and stored
      in the new attribute ``args.results_path``.

    Args:
        save_config: when True, also dump the parsed arguments to
            ``<results_path>/config.yml``.

    Returns:
        argparse.Namespace: the parsed arguments, with ``results`` updated
        and ``results_path`` added.
    """
    parser = _build_parser()

    # Pre-parse only --config; every other token is left for the main parser.
    config_parser = argparse.ArgumentParser(add_help=False)
    config_parser.add_argument("--config", required=False, default=None)
    args, remaining_argv = config_parser.parse_known_args()

    if args.config:
        set_config_args(parser, args.config)

    args = parser.parse_args(remaining_argv)

    # Create the results directory tree.
    exp_name = args.exp_name if args.exp_name is not None else "default"
    print(os.getcwd())
    args.results = modify_and_create_path(args.results, args)
    exp_path = os.path.join(args.results, exp_name)
    args.results_path = os.path.join(exp_path, str(args.seed))
    # makedirs creates exp_path as an intermediate directory, and exist_ok=True
    # makes separate isdir() pre-checks unnecessary.
    os.makedirs(args.results_path, exist_ok=True)

    if save_config:
        # The leading tag tags the root mapping as an argparse.Namespace,
        # presumably so an unsafe YAML loader can rebuild a Namespace from this
        # file. NOTE(review): that requires yaml.unsafe_load / UnsafeLoader;
        # never load such a file from an untrusted source.
        config_file = os.path.join(args.results_path, "config.yml")
        with open(config_file, "w", encoding="utf-8") as f:
            f.write("!!python/object:argparse.Namespace\n")
            yaml.dump(vars(args), f)

    return args

def modify_and_create_path(path, args):
    """Insert an ``args.lmb``-dependent directory before the last path component.

    The inserted directory is chosen from ``args.lmb``:

    * ``"four"`` when ``args.lmb > 1.0``
    * ``"normal"`` when it equals ``1.0``
    * ``"half"`` otherwise

    For example, ``./results`` with ``lmb=0.4`` becomes ``./half/results``.
    The resulting directory is created, including parents, if missing.

    Trailing path separators are ignored, so ``./results/`` maps to the same
    directory as ``./results``. The previous '/'-splitting version produced
    ``./results/half/`` in that case.

    Args:
        path: base results directory (``str`` or path-like).
        args: object exposing a numeric ``lmb`` attribute.

    Returns:
        str: the modified path, which now exists on disk.
    """
    if args.lmb > 1.0:
        bucket = "four"
    elif args.lmb == 1.0:
        bucket = "normal"
    else:
        bucket = "half"

    path = os.fspath(path)
    # Strip trailing separators so os.path.split() yields the real last
    # component. Fall back to the original string if nothing would remain
    # (e.g. "/").
    separators = os.sep + (os.altsep or "")
    trimmed = path.rstrip(separators) or path
    parent, leaf = os.path.split(trimmed)
    modified_path = os.path.join(parent, bucket, leaf)
    os.makedirs(modified_path, exist_ok=True)
    return modified_path