import os
import shlex, subprocess
import argparse
import torch
import json
from rich.console import Console
# Shared rich console; the driver loop uses it to echo each command before running it.
console = Console()


# Everything below is resolved against the current working directory, so the
# script presumably has to be launched from the directory that contains
# `configs/` and the `self_learn_*.py` entry points.
ROOT_DIR = os.getcwd()
_CFG_DIR = os.path.join(ROOT_DIR, 'configs')

# Per-stage YAML configs passed as --config_path.
CFG_PATH_BOOTSTRAP = os.path.join(_CFG_DIR, 'selflearn_3b_bootstrap.yaml')
CFG_PATH_FINETUNE = os.path.join(_CFG_DIR, 'selflearn_3b_finetune.yaml')
CFG_PATH_EVAL = os.path.join(_CFG_DIR, 'selflearn_3b_eval.yaml')

# Per-stage entry-point scripts started through torch.distributed.launch.
SCRIPT_PATH_BOOTSTRAP, SCRIPT_PATH_FINETUNE, SCRIPT_PATH_EVAL = (
    os.path.join(ROOT_DIR, 'self_learn_bootstrap.py'),
    os.path.join(ROOT_DIR, 'self_learn_finetune.py'),
    os.path.join(ROOT_DIR, 'self_learn_eval.py'),
)

# Flags appended verbatim to every stage's command line.
DEFAULT_PARAMS = "--knowledge_conditioning=combined --memory_decision=compute --debug=false"


def get_bootstrap_cmd(args, num_loop: int):
    """Build the shell command that launches the bootstrap stage of one loop.

    Args:
        args: parsed CLI namespace. Besides the CLI options it must carry
            ``num_gpu_bootstrap``, which the driver sets before calling.
        num_loop: index of the current self-learning loop. Loop 0 passes
            ``args.initial_model_path`` as the model file; later loops only
            pass ``--model.model_init_from_hf=true``.

    Returns:
        One command string meant for ``subprocess.run(..., shell=True)``.
        Path- and name-like values go through ``shlex.quote`` so that spaces
        or shell metacharacters (for example in the working directory) cannot
        split or inject arguments. Ordinary values are left unchanged.
    """
    q = shlex.quote
    items = [
        "python -m torch.distributed.launch",
        f"--nproc_per_node={args.num_gpu_bootstrap} --master_port={args.master_port}",
        q(SCRIPT_PATH_BOOTSTRAP),
        f"--config_path={q(CFG_PATH_BOOTSTRAP)}",
    ]
    if num_loop == 0:
        items.append(f"--model.model_file={q(str(args.initial_model_path))}")
        items.append("--model.model_init_from_hf=false")
    else:
        items.append("--model.model_init_from_hf=true")
    items += [
        DEFAULT_PARAMS,
        "--do_eval=False",
        f"--server_port={args.server_port}",
        f"--name_space={q(str(args.name_space))}",
        f"--num_bootstrap={args.num_bootstrap}",
        f"--num_loop={num_loop}",
        # Offset the seed by the loop index so every loop gets a distinct seed.
        f"--seed={args.seed + num_loop}",
        f"--bt_inc_rate={args.bt_inc_rate}",
        f"--inc_rate={args.inc_rate}",
        f"--base_threshold={args.base_threshold}",
        f"--use_cos_sim={args.use_cos_sim}",
    ]
    return " ".join(items)


def get_finetune_cmd(args, num_loop: int):
    """Build the shell command that launches the finetune stage of one loop.

    Args:
        args: parsed CLI namespace. Besides the CLI options it must carry
            ``num_gpu_finetune``, which the driver sets before calling.
        num_loop: index of the current self-learning loop. Loop 0 passes
            ``args.initial_model_path`` as the model file; later loops only
            pass ``--model.model_init_from_hf=true``.

    Returns:
        One command string meant for ``subprocess.run(..., shell=True)``. The
        command is prefixed with a ``CUDA_VISIBLE_DEVICES=...`` assignment,
        which only works through a shell. Path- and name-like values go
        through ``shlex.quote``; ordinary values are left unchanged.
    """
    q = shlex.quote
    items = [
        # Restrict finetuning to the GPUs listed in --visible_gpus.
        f"CUDA_VISIBLE_DEVICES={q(str(args.visible_gpus))} python -m torch.distributed.launch",
        f"--nproc_per_node={args.num_gpu_finetune} --master_port={args.master_port}",
        q(SCRIPT_PATH_FINETUNE),
        f"--config_path={q(CFG_PATH_FINETUNE)}",
    ]
    if num_loop == 0:
        items.append(f"--model.model_file={q(str(args.initial_model_path))}")
        items.append("--model.model_init_from_hf=false")
    else:
        items.append("--model.model_init_from_hf=true")
    items += [
        DEFAULT_PARAMS,
        "--do_eval=False",
        f"--server_port={args.server_port}",
        f"--name_space={q(str(args.name_space))}",
        f"--num_bootstrap={args.num_bootstrap}",
        f"--num_loop={num_loop}",
        f"--trainer.learning_rate={args.lr}",
        f"--trainer.per_device_train_batch_size={args.bs}",
        f"--trainer.gradient_accumulation_steps={args.acc}",
        f"--scheme=bt --finetune_num_epoch={args.finetune_num_epoch}",
    ]
    return " ".join(items)


def get_eval_cmd(args, num_loop):
    """Build the shell command that launches the evaluation stage of one loop.

    Args:
        args: parsed CLI namespace. The process count comes from
            ``args.num_gpu_eval`` when present, falling back to
            ``args.num_gpu_bootstrap`` for namespaces that lack it.
        num_loop: index of the loop being evaluated. Loop 0 passes
            ``args.initial_model_path`` as the model file; later loops only
            pass ``--model.model_init_from_hf=true``.

    Returns:
        One command string meant for ``subprocess.run(..., shell=True)``.
        Path- and name-like values go through ``shlex.quote``; ordinary values
        are left unchanged.
    """
    q = shlex.quote
    nproc = getattr(args, "num_gpu_eval", None)
    if nproc is None:
        nproc = args.num_gpu_bootstrap
    items = [
        "python -m torch.distributed.launch",
        f"--nproc_per_node={nproc} --master_port={args.master_port}",
        q(SCRIPT_PATH_EVAL),
        f"--config_path={q(CFG_PATH_EVAL)}",
    ]
    if num_loop == 0:
        items.append(f"--model.model_file={q(str(args.initial_model_path))}")
        items.append("--model.model_init_from_hf=false")
    else:
        items.append("--model.model_init_from_hf=true")
    items += [
        DEFAULT_PARAMS,
        "--do_train=False --trainer.per_device_eval_batch_size=1",
        f"--server_port={args.server_port}",
        f"--name_space={q(str(args.name_space))}",
        f"--num_bootstrap={args.num_bootstrap}",
        f"--num_loop={num_loop}",
        f"--scheme=bt --finetune_num_epoch={args.finetune_num_epoch}",
        f"--max_num_entries={args.max_num_entries}",
    ]
    return " ".join(items)


def get_result_score(args, num_loop):
    """Read the validation score that the eval stage wrote for one loop.

    Args:
        args: parsed CLI namespace. Uses ``name_space``, ``num_bootstrap``,
            ``finetune_num_epoch`` and ``exp_dir``.
        num_loop: index of the loop whose eval log should be read.

    Returns:
        The ``'score/valid'`` entry of ``selflearn_eval_log.json``.

    Raises:
        FileNotFoundError: if the eval log does not exist.
        KeyError: if the log has no ``'score/valid'`` entry.
    """
    # NOTE(review): this naming presumably mirrors the output directory chosen
    # by the eval script ("scheme-bt" matches the --scheme=bt flag passed to
    # it). Keep the two in sync.
    dir_name = (
        f"selflearn_{args.name_space}_{num_loop}_{args.num_bootstrap}"
        f"_scheme-bt_epoch-{args.finetune_num_epoch}"
    )
    result_path = os.path.join(ROOT_DIR, args.exp_dir, dir_name, 'selflearn_eval_log.json')
    with open(result_path, 'r', encoding='utf-8') as fin:
        data = json.load(fin)
    try:
        return data['score/valid']
    except KeyError:
        raise KeyError(f"'score/valid' not found in {result_path}") from None


def _run_stage(cmd: str) -> None:
    """Echo *cmd* to the console and execute it through the shell.

    Raises:
        subprocess.CalledProcessError: if the command exits with a non-zero
            status. The exception carries both the return code and the command.
    """
    console.print(cmd, style="bold yellow")
    # shell=True is required: the finetune command starts with a
    # ``CUDA_VISIBLE_DEVICES=...`` environment assignment.
    subprocess.run(cmd, shell=True, check=True)


def main() -> None:
    """Drive the self-learning loop: bootstrap, then finetune, then optional eval.

    Loops run from ``--start_num_loop`` up to, but excluding, ``--max_iter``.
    A stage that exits with a non-zero status aborts the whole run.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument('--num_bootstrap', type=int, default=4000)
    ap.add_argument('--max_iter', type=int, default=10)
    ap.add_argument('--master_port', type=int, default=1111)
    ap.add_argument('--visible_gpus', type=str, default='0,1,2,3',
                    help='comma-separated GPU ids used for the finetune stage')
    ap.add_argument('--server_port', type=int, default=8080)
    ap.add_argument('--name_space', type=str, default='hexa-highest-SR')
    ap.add_argument('--initial_model_path', type=str, default=None, required=True,
                    help='model file used by loop 0')
    ap.add_argument('--seed', type=int, default=9)
    ap.add_argument('--exp_dir', type=str, default='experiment/bb3_3b')
    ap.add_argument('--finetune_num_epoch', type=int, default=1)
    ap.add_argument('--max_num_entries', type=int, default=10000)
    ap.add_argument('--bt_inc_rate', type=float, default=.1)
    ap.add_argument('--inc_rate', type=float, default=0.0)
    ap.add_argument('--lr', type=float, default=2e-6)
    ap.add_argument('--start_num_loop', type=int, default=0,
                    help='loop index to resume from')
    ap.add_argument('--use_cos_sim', action='store_true', default=False)
    ap.add_argument('--base_threshold', type=float, default=0.25)
    ap.add_argument('--early_stop', action='store_true', default=False,
                    help='with --do_eval, stop as soon as the validation score drops')
    ap.add_argument('--bs', type=int, default=1)
    ap.add_argument('--acc', type=int, default=4)
    ap.add_argument('--skip', action='store_true', default=False,
                    help='skip the bootstrap stage of the first (resumed) loop')
    ap.add_argument('--do_eval', action='store_true', default=False)
    args = ap.parse_args()

    num_gpu_available = torch.cuda.device_count()
    if num_gpu_available == 0:
        # Otherwise bootstrap/eval would be launched with --nproc_per_node=0.
        ap.error('torch.cuda.device_count() returned 0; cannot launch '
                 'bootstrap/eval with --nproc_per_node=0')

    # Bootstrap and eval use every GPU torch can see. Finetune is restricted
    # to --visible_gpus; int() also validates each listed id.
    args.num_gpu_bootstrap = num_gpu_available
    args.num_gpu_eval = num_gpu_available
    args.num_gpu_finetune = len([int(gpu) for gpu in args.visible_gpus.split(',')])

    # Start from a baseline of 0 so the first evaluated score is always kept.
    # The early-stop check treats a higher score as better.
    eval_scores = [0]
    # Loop indices never go below 0, even if --start_num_loop is negative.
    for i in range(max(args.start_num_loop, 0), args.max_iter):
        # --skip only affects the first (resumed) loop; later loops always bootstrap.
        if i > args.start_num_loop or not args.skip:
            _run_stage(get_bootstrap_cmd(args, i))

        _run_stage(get_finetune_cmd(args, i))

        if not args.do_eval:
            continue

        _run_stage(get_eval_cmd(args, i))
        score = get_result_score(args, i)
        if args.early_stop and score < eval_scores[-1]:
            break
        eval_scores.append(score)
        print(eval_scores)


if __name__ == "__main__":
    main()

