import argparse
import datetime
import json
import os
import subprocess
import sys


def glue_main(args):
    """Launch one ``run_glue.py`` training run and record its settings.

    The training script is started as a child process (``python run_glue.py
    ...``) with ``CUDA_VISIBLE_DEVICES``, ``WANDB_MODE`` and ``WANDB_PROJECT``
    set from *args*.  Once the child has finished, the main hyper-parameters
    are written to ``<wandb_project>/<timestamp>/cli_run_args.json``; the
    directory is created first if the training script did not create it.

    A failing or unlaunchable child process is reported on stderr but does not
    raise, so the settings file is always written.

    Args:
        args: Namespace (normally from ``argparse``) providing the attributes
            ``target_task``, ``adapter_name``, ``regularization_lambda``,
            ``parametrize_S``, ``gradient_type``, ``wandb_project``,
            ``wandb_mode``, ``device``, ``epochs``, ``lr``,
            ``init_lora_weights``, ``seed``, ``model_name``, ``rank``,
            ``alpha``, ``use_dora`` and ``batch_size``.
    """
    task = args.target_task  # should be one of COLA, SST2 and QNLI tasks
    adapter_name = args.adapter_name
    reg_lambda = args.regularization_lambda
    parametrize_S = args.parametrize_S
    gradient_type = args.gradient_type
    wandb_project = args.wandb_project
    wandb_mode = args.wandb_mode
    cuda_device = args.device
    epoch = args.epochs
    lr = args.lr
    init_lora_weights = args.init_lora_weights
    seed = args.seed
    # Example values: "microsoft/deberta-v3-base", "roberta-large",
    # "microsoft/deberta-v2-xxlarge".
    model_name = args.model_name
    rank = args.rank
    lora_alpha = args.alpha
    use_dora = args.use_dora
    # The classifier learning rate is taken from the same --lr value.
    cls_lr = args.lr
    batch_size = args.batch_size

    # The launch timestamp names both the output directory and the wandb run.
    formatted_time = datetime.datetime.now().strftime("%Y-%m-%dT%H:%M:%S.%f")
    results_dir = os.path.join(wandb_project, formatted_time)
    wandb_name = f"{formatted_time}_{task}_{adapter_name}_i{init_lora_weights}_r{rank}_l{reg_lambda}"

    # An argument list (no shell) keeps values containing spaces or shell
    # metacharacters intact instead of letting a shell re-split them.
    command = [
        "python", "run_glue.py",
        "--adapter_name", adapter_name,
        "--regularization_lambda", reg_lambda,
        "--parametrize_S", parametrize_S,
        "--gradient_type", gradient_type,
        "--use_dora", use_dora,
        "--run_name", wandb_name,
        "--init_lora_weights", init_lora_weights,
        "--model_name_or_path", model_name,
        "--lora_rank", rank,
        "--lora_alpha", lora_alpha,
        "--task_name", task,
        "--do_train",
        "--do_eval",
        "--seed", seed,
        "--max_seq_length", 128,
        "--per_device_train_batch_size", batch_size,
        "--learning_rate", lr,
        "--cls_learning_rate", cls_lr,
        "--num_train_epochs", epoch,
        "--save_steps", -1,
        "--save_strategy", "no",
        "--evaluation_strategy", "epoch",
        "--logging_steps", 300,
        "--overwrite_output_dir",
        "--output_dir", results_dir,
    ]
    command = [str(part) for part in command]

    child_env = {
        **os.environ,
        "CUDA_VISIBLE_DEVICES": str(cuda_device),
        "WANDB_MODE": str(wandb_mode),
        "WANDB_PROJECT": str(wandb_project),
    }

    try:
        exit_status = subprocess.run(command, env=child_env, check=False).returncode
    except OSError as err:
        # Stay best-effort when the interpreter cannot be started at all:
        # report it and still record the settings below.
        print(f"WARNING: could not launch run_glue.py: {err}", file=sys.stderr)
    else:
        if exit_status != 0:
            print(
                f"WARNING: run_glue.py exited with status {exit_status}",
                file=sys.stderr,
            )

    cli_run_args = {
        "epoch": epoch,
        "task": task,
        "model_name": model_name,
        "adapter": adapter_name,
        "lr": lr,
        # NOTE(review): "cls_ls" looks like a typo for "cls_lr"; kept as-is
        # because readers of this file may depend on the existing key.
        "cls_ls": cls_lr,
        "seed": seed,
        "rank": rank,
        "reg_lambda": reg_lambda,
        "parametrize_S": parametrize_S,
        "gradient_type": gradient_type,
        "wandb_name": wandb_name,
    }
    # The training script may have failed before creating its output
    # directory; make sure it exists so the settings can always be saved.
    os.makedirs(results_dir, exist_ok=True)
    with open(os.path.join(results_dir, "cli_run_args.json"), "w", encoding="utf-8") as file:
        json.dump(cli_run_args, file, indent=4)

    # Optional post-run analysis step, currently disabled:
    # analyze_str = f"""CUDA_VISIBLE_DEVICES="{cuda_device}" \
    #     WANDB_MODE=disabled \
    #     python analyze_lora.py --path_to_exp_dir {results_dir} --ext safetensors"""
    # os.system(analyze_str)


if __name__ == "__main__":
    # Example invocations, one per GLUE task.
    # NOTE(review): these examples pass only a few flags, while every flag
    # registered below is required -- extend them before copy-pasting.
    #
    #   MNLI (too large)
    #     python scripts/run_glue.py --target_task mnli --device 2 --epochs 10 --init_lora_weights default
    #   SST-2
    #     python scripts/run_glue.py --target_task sst2 --device 1 --epochs 20 --init_lora_weights default
    #   MRPC
    #     python scripts/run_glue.py --target_task mrpc --device 0 --epochs 50 --init_lora_weights default
    #   CoLA
    #     python scripts/run_glue.py --target_task cola --device 0 --epochs 20 --init_lora_weights default
    #   QNLI
    #     python scripts/run_glue.py --target_task qnli --device 3 --epochs 5 --init_lora_weights default

    def str2bool(value):
        """Turn "true"/"false" (any letter case) into a bool for argparse.

        Values that are already ``bool`` are returned unchanged; any other
        string raises ``argparse.ArgumentTypeError``.
        """
        if isinstance(value, bool):
            return value
        literals = {"true": True, "false": False}
        lowered = value.lower()
        if lowered not in literals:
            raise argparse.ArgumentTypeError("Boolean value expected.")
        return literals[lowered]

    # (flag, converter) pairs, registered in this order; all are mandatory.
    # A converter of None leaves the raw command-line string untouched.
    required_flags = (
        ("--target_task", None),
        ("--wandb_project", str),
        ("--wandb_mode", str),
        ("--adapter_name", str),
        ("--model_name", str),
        ("--regularization_lambda", float),
        ("--gradient_type", str),
        ("--parametrize_S", str),
        ("--rank", int),
        ("--use_dora", str2bool),
        ("--alpha", int),
        ("--device", int),
        ("--epochs", int),
        ("--lr", float),
        ("--batch_size", int),
        ("--init_lora_weights", str),
        ("--seed", int),
    )

    parser = argparse.ArgumentParser()
    for flag, converter in required_flags:
        parser.add_argument(flag, type=converter, required=True)

    args = parser.parse_args()

    glue_main(args)
