import argparse
import glob
import json
import os
import re
import shutil
import subprocess
import sys

from lightning.pytorch.loggers import WandbLogger
from tqdm import tqdm

def _str2bool(value):
    """Parse a boolean-like CLI value ("True"/"false"/"1"/"no"/...) into a bool.

    With ``type=str`` any non-empty string, including "False", is truthy, so
    passing ``--ddp False`` would silently enable the option. This converter
    makes the textual value mean what it says.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in {"true", "t", "yes", "y", "1"}:
        return True
    # the empty string was falsy under the old str-typed flags; keep it so
    if normalized in {"false", "f", "no", "n", "0", ""}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


# command line interface: where the model/checkpoints live, where results go,
# the eval settings forwarded to eval/mteb_eval.py and the wandb options
parser = argparse.ArgumentParser()
parser.add_argument("--result_dir", type=str, required=True)
parser.add_argument("--model_dir", type=str, required=True)
# boolean-valued options take an explicit value, e.g. `--no_instruction True`
parser.add_argument("--no_instruction", type=_str2bool, default=False)
parser.add_argument("--include_long_prompt", type=_str2bool, default=False)
parser.add_argument("--include_meta_tokens", type=_str2bool, default=False)
parser.add_argument("--overwrite_results", type=_str2bool, default=False)
# comma separated substrings used to select which "step-<N>" checkpoints to process
parser.add_argument("--keystrings", type=str, default=None)
# only checkpoints strictly after this step are processed
parser.add_argument("--resume_from", type=str, default=None)
parser.add_argument("--prompt_style", type=str, required=False)
parser.add_argument("--task", type=str, default="retrieval_tiny")
parser.add_argument("--task_list", type=str, default=None)
parser.add_argument("--task_locks", type=_str2bool, default=False)
parser.add_argument("--batch_size", type=int, default=1)
parser.add_argument("--ddp", type=_str2bool, default=False)
parser.add_argument("--pooling_method", type=str, default="lasttoken")
parser.add_argument("--num_few_shot", type=int, default=0)
# --skip_combining / --skip_eval only print the commands instead of running them
parser.add_argument("--skip_combining", action="store_true")
parser.add_argument("--skip_eval", action="store_true")
parser.add_argument("--skip_wandb", action="store_true")
parser.add_argument("--push_to_wandb", action="store_true")
# comma separated list of extra tags for the wandb run
parser.add_argument("--additional_wandb_tags", type=str, default=None)
parser.add_argument("--wandb_dir", type=str, default=None)
args = parser.parse_args()

# Pick the checkpoint sub-directories of `model_dir` based on the training strategy:
# - the "raw" directory is the one that gets listed for step-*.pth files
# - the "eval" directory is the one whose checkpoints are handed to the eval script
# With DDP both are the same directory; otherwise the raw (axonn) checkpoints are
# combined into `combined_ckpts` first.
if args.ddp:
    _raw_ckpt_subdir = "checkpoints-DDPStrategy"
    _eval_ckpt_subdir = "checkpoints-DDPStrategy"
else:
    _raw_ckpt_subdir = "checkpoints-AxonnStrategy/tp_row_0_col_0_depth_0"
    _eval_ckpt_subdir = "combined_ckpts"

AXONN_DDP_CKPT_DIR = os.path.join(args.model_dir, _raw_ckpt_subdir)
CKPTS_TO_EVAL_DIR = os.path.join(args.model_dir, _eval_ckpt_subdir)

# Name describing the prompt configuration of this run.
# NOTE: it is computed here but not appended to MTEB_DIR below.
if args.no_instruction:
    sub_dir_name = "wo_instruction"
else:
    sub_dir_name = "w_long_prompt" if args.include_long_prompt else "w_short_prompt"
    if args.include_meta_tokens:
        sub_dir_name = sub_dir_name + "_meta_tokens"

# eval results are written below <result_dir>/mteb
MTEB_DIR = os.path.join(args.result_dir, "mteb")
# directory scanned for result json files when logging to wandb; defaults to MTEB_DIR
WANDB_DIR = MTEB_DIR if args.wandb_dir is None else args.wandb_dir

# Discover the step-*.pth checkpoint files in the source checkpoint directory
# and order them by their optimizer step number.
_STEP_RE = re.compile(r"step-(\d+)")

# Only keep .pth files whose name actually carries a "step-<N>" token; any other
# .pth file in the directory would otherwise break the numeric sort below.
ckpts = [f for f in os.listdir(AXONN_DDP_CKPT_DIR) if f.endswith(".pth") and _STEP_RE.search(f)]
sorted_ckpts = sorted(ckpts, key=lambda name: int(_STEP_RE.search(name).group(1)))

# "step-<N>" identifiers (digits kept exactly as they appear in the file name),
# in the same order as sorted_ckpts
keystring = ["step-" + _STEP_RE.search(ckpt).group(1) for ckpt in sorted_ckpts]

# optionally restrict to the checkpoints matching any of the user supplied substrings
if args.keystrings is not None:
    passed_keystring = args.keystrings.split(",")
    keystring = [key for key in keystring if any(k in key for k in passed_keystring)]

################### CKPT COMBINING ############################################
# keep only the checkpoints strictly after `--resume_from`
if args.resume_from is not None:
    resume_match = re.search(r"(\d+)", args.resume_from)
    if resume_match is not None:
        # compare step numbers numerically: a plain string comparison is only
        # correct when every step number is zero-padded to the same width
        resume_step = int(resume_match.group(1))
        keystring = [key for key in keystring if int(key[len("step-") :]) > resume_step]
    else:
        # no step number in the argument, fall back to comparing the raw strings
        keystring = [key for key in keystring if key > args.resume_from]

# we'll check if the ckpts are already converted to .pth and skip them
# (DDP checkpoints are evaluated directly, so there is nothing to combine)
if not args.ddp:
    for key in tqdm(keystring, desc="Converting axonn ckpts to .pth"):
        combined_ckpt_path = os.path.join(CKPTS_TO_EVAL_DIR, f"{key}_ckpt.pth")
        if os.path.exists(combined_ckpt_path):
            continue
        command = [
            "python",
            "scripts/combine_tensor_parallel_checkpoints.py",
            "--out_dir",
            args.model_dir,
            "--keystring",
            key,
            "--combined_ckpt_subdir",
            "combined_ckpts",
            "--combined_ckpt_name",
            f"{key}_ckpt",
        ]
        if args.skip_combining:
            # dry run: only show what would be executed
            print("Would run command:", " ".join(command))
            continue
        print("Running command:", " ".join(command))
        completed = subprocess.run(command)
        if completed.returncode != 0:
            # keep going with the remaining checkpoints, but do not fail silently
            print(f"WARNING: combining checkpoint {key} exited with code {completed.returncode}")

################### EVAL ######################################################
# now we'll evaluate the model ckpts, one eval/mteb_eval.py subprocess per checkpoint
# NOTE(review): with --ddp the loop runs over every file in sorted_ckpts, so the
# --keystrings / --resume_from filters (applied to `keystring` only) have no
# effect in that mode -- confirm this is intended.
for ckpt in sorted_ckpts if args.ddp else keystring:
    # DDP entries are file names already; otherwise `ckpt` is a "step-<N>" key
    # and the combined checkpoint is named "<key>_ckpt.pth"
    ckpt_file = ckpt if args.ddp else f"{ckpt}_ckpt.pth"
    command = [
        "python",
        "eval/mteb_eval.py",
        "--model_path",
        args.model_dir,
        "--checkpoint_dir",
        os.path.join(CKPTS_TO_EVAL_DIR, ckpt_file),
        "--no_instruction",
        str(args.no_instruction),
        "--include_long_prompt",
        str(args.include_long_prompt),
        "--include_meta_tokens",
        str(args.include_meta_tokens),
        "--overwrite_results",
        str(args.overwrite_results),
        "--result_dir",
        MTEB_DIR,
        "--prompt_style",
        str(args.prompt_style),
        "--task",
        str(args.task),
        "--task_list",
        str(args.task_list),
        "--task_locks",
        str(args.task_locks),
        "--batch_size",
        str(args.batch_size),
        "--pooling_method",
        str(args.pooling_method),
        "--num_few_shot",
        str(args.num_few_shot),
    ]
    if args.skip_eval:
        # dry run: only show what would be executed
        print("Would run command:", " ".join(command))
        continue
    print("Running command:", " ".join(command))
    completed = subprocess.run(command)
    if completed.returncode != 0:
        # keep evaluating the remaining checkpoints, but do not fail silently
        print(f"WARNING: evaluation of {ckpt_file} exited with code {completed.returncode}")

################### WANDB LOGGING #############################################
# finally we'll log the results to wandb
if args.skip_wandb:
    print("Skipping wandb logging")
    sys.exit()

# the run configuration provides the wandb project, run name and tags
run_config_path = os.path.join(args.model_dir, "run_config.json")
with open(run_config_path, "r") as f:
    cfg = json.load(f)

# add "mteb" to the tags (tolerating a config without any tags)
cfg["wandb_tags"] = list(cfg.get("wandb_tags") or [])
cfg["wandb_tags"].append("mteb")

# add the additional tags if they exist, ignoring empty entries such as a trailing comma
if args.additional_wandb_tags is not None:
    cfg["wandb_tags"].extend(tag.strip() for tag in args.additional_wandb_tags.split(",") if tag.strip())

# and change the save_dir to the result dir
cfg["out_dir"] = args.result_dir

# we clear out a wandb run if it exists since we log in bulk under the run_name not incrementally
stale_wandb_dir = os.path.join(cfg["out_dir"], "wandb")
if os.path.exists(stale_wandb_dir):
    try:
        # remove without going through a shell, so unusual characters in the
        # path cannot be misinterpreted
        shutil.rmtree(stale_wandb_dir)
    except OSError as err:
        print(f"WARNING: could not remove existing wandb dir {stale_wandb_dir}: {err}")

logger = WandbLogger(
    entity="XXXX-6",
    project=cfg["logger_project"],
    name=cfg["run_name"],
    save_dir=cfg["out_dir"],
    tags=cfg["wandb_tags"],
    offline=True,  # manual sync after
)

# now we grab all of the result files in the WANDB_DIR and log them to wandb

result_files = glob.glob(os.path.join(WANDB_DIR, "**/*.json"), recursive=True)
# NOTE: the "wandb" check is a substring test on the whole path, so it drops every
# file whose path contains "wandb" anywhere (including in WANDB_DIR itself)
parents_and_files = [(os.path.dirname(f), os.path.basename(f)) for f in result_files if "wandb" not in f]


# NOTE this is brittle but
# the expected path structure is some prefix and then
# mteb/<run_name>/<prompt_style>/<ckpt_step_name>/<name_of_task>.json
def split_path(path):
    """Return the last three "/"-separated components of ``path``.

    For the parent directory of a result file these are
    ``(run_name, prompt_style, ckpt_step_name)``.
    """
    parts = path.split("/")
    return parts[-3], parts[-2], parts[-1]


for parent, file in tqdm(parents_and_files, desc="Logging results to wandb", total=len(parents_and_files)):
    # model_meta.json is not a task result file
    if file == "model_meta.json":
        continue
    with open(os.path.join(parent, file), "r") as f:
        results = json.load(f)

    path_run_name, prompt_style, ckpt_step_name = split_path(parent)
    task_name = file.split(".")[0]
    # isolate the integer ckpt step number, e.g. "step-00001000_ckpt" -> 1000
    ckpt_optim_step = int(ckpt_step_name.split("-")[-1].split("_")[0])

    print(
        f"Logging results for {path_run_name} ckpt {ckpt_optim_step} for task {task_name} with prompt style {prompt_style}"
    )

    # NOTE the path parsing is ultra brittle, so validate explicitly
    # (explicit raises rather than asserts so the checks survive `python -O`)
    if path_run_name != cfg["run_name"]:
        raise ValueError(
            f"run name {path_run_name} does not match run name in run_config.json {cfg['run_name']}"
        )

    # exactly one split ("test") holding exactly one score entry is expected
    if len(results["scores"]) != 1 or len(results["scores"]["test"]) != 1:
        raise ValueError(
            f"expected a single 'test' score entry in {os.path.join(parent, file)}"
        )

    scores = results["scores"]["test"][0]

    if task_name != results["task_name"]:
        raise ValueError(f"task name {task_name} does not match task name in results {results['task_name']}")

    # drop the metadata entries so that only the metrics are logged
    scores.pop("hf_subset", None)
    scores.pop("languages", None)

    results_to_log = {f"mteb/{task_name}/{prompt_style}/{k}": v for k, v in scores.items()}

    results_to_log["optimizer_step"] = ckpt_optim_step

    logger.experiment.log(results_to_log)

# finalize the logging
logger.experiment.finish()

# optionally, launch the wandb sync command from the parent of the resulting wandb dir
# which will be args.result_dir, if you have internet
if args.push_to_wandb:
    # sleep for 5 seconds to let the wandb process finish writing the files
    import time

    time.sleep(5)

    # run `wandb sync` directly with result_dir as working directory instead of
    # building a `cd ... && ...` shell string, so the path is never shell-interpreted
    try:
        sync_result = subprocess.run(["wandb", "sync", "--sync-all", "."], cwd=args.result_dir)
    except OSError as err:
        # e.g. result_dir does not exist or the wandb executable is not installed
        print(f"WARNING: could not launch wandb sync: {err}")
    else:
        if sync_result.returncode != 0:
            print(f"WARNING: wandb sync exited with code {sync_result.returncode}")

print("All done!")
