import random
import torch
import logging
import multiprocessing
import numpy as np

# Module-level logger; set_dist() reports the device / distributed-training setup through it.
logger = logging.getLogger(__name__)


def add_args(parser, argv=None):
    """Register all command-line options on ``parser``, parse them and derive language fields.

    Args:
        parser: an ``argparse.ArgumentParser`` to populate with the script's options.
        argv: optional list of argument strings to parse. ``None`` (the default)
            parses the process command line, exactly as before.

    Returns:
        The parsed namespace, with ``lang``, ``ip_lang`` and ``op_lang`` filled in
        from ``--task`` / ``--sub_task`` (see ``_set_task_langs``).

    Raises:
        AssertionError: if ``--task avatar`` is used with a ``--sub_task`` other than
            ``java-python`` or ``python-java``.
        SystemExit: raised by argparse on invalid or missing command-line options.
    """
    _register_args(parser)
    args = parser.parse_args(argv)
    _set_task_langs(args)
    return args


def _register_args(parser):
    """Add every option understood by the training scripts to ``parser``."""
    # Task / data selection
    parser.add_argument(
        "--task",
        type=str,
        required=True,
        choices=[
            "summarize",
            "concode",
            "translate",
            "refine",
            "defect",
            "clone",
            "multi_task",
            "mathqa",
            "fixeval",
            "mbpp",
            "conala",
            "avatar",
        ],
    )
    parser.add_argument("--sub_task", type=str, default="")
    parser.add_argument("--lang", type=str, default="")
    parser.add_argument("--eval_task", type=str, default="")
    parser.add_argument(
        "--model_type",
        default="codet5",
        type=str,
        choices=[
            "roberta",
            "bart",
            "codet5",
            "codet5_custom",
            "plbart",
            "codebert",
            "graphcodebert",
            "unixcoder",
            "codegen",
        ],
    )
    parser.add_argument("--add_lang_ids", action="store_true")
    parser.add_argument("--data_num", default=-1, type=int)
    parser.add_argument("--start_epoch", default=0, type=int)
    parser.add_argument("--num_train_epochs", default=100, type=int)
    parser.add_argument("--patience", default=5, type=int)
    parser.add_argument("--cache_path", type=str, required=True)
    parser.add_argument("--summary_dir", type=str, required=True)
    parser.add_argument("--data_dir", type=str, required=True)
    parser.add_argument("--res_dir", type=str, required=True)
    parser.add_argument("--res_fn", type=str, default="")
    parser.add_argument(
        "--add_task_prefix",
        action="store_true",
        help="Whether to add task prefix for t5 and codet5",
    )
    parser.add_argument("--save_last_checkpoints", action="store_true")
    parser.add_argument("--always_save_model", action="store_true")
    parser.add_argument(
        "--do_eval_bleu",
        action="store_true",
        help="Whether to evaluate bleu on dev set.",
    )

    # Model and output paths (only --output_dir is mandatory)
    parser.add_argument(
        "--model_name_or_path",
        default="roberta-base",
        type=str,
        help="Path to pre-trained model: e.g. roberta-base",
    )
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--load_model_path",
        default=None,
        type=str,
        help="Path to trained model: Should contain the .bin files",
    )

    # Data files
    parser.add_argument(
        "--train_filename",
        default=None,
        type=str,
        help="The train filename. Should contain the .jsonl files for this task.",
    )
    parser.add_argument(
        "--dev_filename",
        default=None,
        type=str,
        help="The dev filename. Should contain the .jsonl files for this task.",
    )
    parser.add_argument(
        "--test_filename",
        default=None,
        type=str,
        help="The test filename. Should contain the .jsonl files for this task.",
    )

    # Config / tokenizer and sequence lengths
    parser.add_argument(
        "--config_name",
        default="",
        type=str,
        help="Pretrained config name or path if not the same as model_name",
    )
    parser.add_argument(
        "--tokenizer_name",
        default="roberta-base",
        type=str,
        help="Pretrained tokenizer name or path if not the same as model_name",
    )
    parser.add_argument(
        "--max_source_length",
        default=64,
        type=int,
        help="The maximum total source sequence length after tokenization. Sequences longer "
        "than this will be truncated, sequences shorter will be padded.",
    )
    parser.add_argument(
        "--max_target_length",
        default=32,
        type=int,
        help="The maximum total target sequence length after tokenization. Sequences longer "
        "than this will be truncated, sequences shorter will be padded.",
    )

    # Run modes
    parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
    parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
    parser.add_argument("--do_test", action="store_true", help="Whether to run eval on the test set.")
    parser.add_argument(
        "--do_lower_case",
        action="store_true",
        help="Set this flag if you are using an uncased model.",
    )
    parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")

    # Optimisation
    parser.add_argument(
        "--train_batch_size",
        default=8,
        type=int,
        help="Batch size per GPU/CPU for training.",
    )
    parser.add_argument(
        "--eval_batch_size",
        default=8,
        type=int,
        help="Batch size per GPU/CPU for evaluation.",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--learning_rate",
        default=5e-5,
        type=float,
        help="The initial learning rate for Adam.",
    )
    parser.add_argument("--beam_size", default=10, type=int, help="beam size for beam search")
    parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
    parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")

    # Step scheduling
    parser.add_argument(
        "--save_steps",
        default=-1,
        type=int,
    )
    parser.add_argument(
        "--log_steps",
        default=-1,
        type=int,
    )
    parser.add_argument(
        "--max_steps",
        default=-1,
        type=int,
        help="If > 0: set total number of training steps to perform. Override num_train_epochs.",
    )
    parser.add_argument("--eval_steps", default=-1, type=int, help="")
    parser.add_argument("--train_steps", default=-1, type=int, help="")
    parser.add_argument("--warmup_steps", default=100, type=int, help="Linear warmup over warmup_steps.")
    parser.add_argument(
        "--local_rank",
        type=int,
        default=-1,
        help="For distributed training: local_rank",
    )
    parser.add_argument("--seed", type=int, default=1234, help="random seed for initialization")

    # Tree parameters
    parser.add_argument("--parse_as_tree", action="store_true")


def _set_task_langs(args):
    """Set ``args.lang``, ``args.ip_lang`` and ``args.op_lang`` from ``args.task``/``args.sub_task``.

    ``ip_lang`` / ``op_lang`` are the input / output programming languages
    (``None`` when that side is not code). ``lang`` is the single language
    associated with the task: the output language for generation/translation
    tasks, the input language for classification-style tasks.

    Tasks not listed here (``multi_task``) keep ``--lang`` as given and get no
    ``ip_lang`` / ``op_lang`` attributes.

    Raises:
        AssertionError: for ``avatar`` with a sub_task other than
            ``java-python`` / ``python-java``.
    """
    task, sub_task = args.task, args.sub_task
    if task == "summarize":
        args.lang = sub_task
        args.ip_lang, args.op_lang = sub_task, None
    elif task == "refine":
        args.lang = "java"
        args.ip_lang, args.op_lang = "java", "java"
    elif task == "concode":
        args.lang = "java"
        args.ip_lang, args.op_lang = None, "java"
    elif task == "clone":
        args.lang = "java"
        args.ip_lang, args.op_lang = "java", None
    elif task == "defect":
        args.lang = "c"
        args.ip_lang, args.op_lang = "c", None
    elif task == "translate":
        if sub_task == "java-cs":
            args.ip_lang, args.op_lang = "java", "c_sharp"
        else:
            args.ip_lang, args.op_lang = "c_sharp", "java"
        args.lang = args.op_lang
    elif task in ("mathqa", "mbpp", "conala"):
        args.lang = "python"
        args.ip_lang, args.op_lang = None, "python"
    elif task == "fixeval":
        args.lang = sub_task  # java or python
        args.ip_lang, args.op_lang = sub_task, sub_task
    elif task == "avatar":
        if sub_task == "java-python":
            args.ip_lang, args.op_lang = "java", "python"
        elif sub_task == "python-java":
            args.ip_lang, args.op_lang = "python", "java"
        else:
            raise AssertionError("Incorrect subtask")
        # Derived from the validated pair so it cannot disagree with op_lang
        # (previously compared against "java-py", which never matched).
        args.lang = args.op_lang


def set_dist(args):
    """Configure the compute device and, optionally, distributed training on ``args``.

    Reads ``args.local_rank`` and ``args.no_cuda``; writes ``args.device``,
    ``args.n_gpu`` and ``args.cpu_cont`` (``multiprocessing.cpu_count()``).

    * ``local_rank == -1`` or ``no_cuda``: single-process mode. The device is
      CUDA when available and not disabled, otherwise CPU. ``n_gpu`` is the
      number of visible CUDA devices, or 0 when ``--no_cuda`` is set.
    * otherwise: this process is bound to ``cuda:<local_rank>``, the default
      process group is initialised with the NCCL backend, and ``n_gpu`` is 1.
    """
    # Setup CUDA, GPU & distributed training
    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        # A --no_cuda run is CPU-only: report zero GPUs so code branching on
        # args.n_gpu (e.g. CUDA seeding in set_seed) does not treat it as a GPU run.
        args.n_gpu = 0 if args.no_cuda else torch.cuda.device_count()
    else:
        # Setup for distributed data parallel
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend="nccl")
        args.n_gpu = 1
    cpu_cont = multiprocessing.cpu_count()
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, cpu count: %d",
        args.local_rank,
        device,
        args.n_gpu,
        bool(args.local_rank != -1),
        cpu_cont,
    )
    args.device = device
    args.cpu_cont = cpu_cont


def set_seed(args):
    """Seed every RNG used during training so that runs are reproducible.

    Python's ``random``, NumPy's global RNG and torch's default generator are
    all seeded with ``args.seed``. The CUDA generators on every device are
    seeded too, but only when ``args.n_gpu`` is positive.
    """
    seed_value = args.seed
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed_value)
    if args.n_gpu > 0:
        torch.cuda.manual_seed_all(seed_value)
