import os
import re
import subprocess
import argparse

# Command-line interface: two required directories (--model_dir and
# --checkpoint_dir) plus four optional flags whose values are later forwarded,
# as strings, to the evaluation script.
parser = argparse.ArgumentParser()
for required_option in ("--model_dir", "--checkpoint_dir"):
    parser.add_argument(required_option, type=str, required=True)
# The defaults below are Python booleans even though type=str: argparse only
# applies ``type`` to string values, so an omitted flag stays True/False.
for passthrough_option, fallback in (
    ("--prefix_add_eos", True),
    ("--suffix_add_eos", True),
    ("--siglip_loss", False),
    ("--include_meta_tokens", False),
):
    parser.add_argument(passthrough_option, type=str, default=fallback)
args = parser.parse_args()

# base_model_path = "/XXXX-30/XXXX-29/XXXX-31/proj-shared/language_models/external/TinyLlama-1.1B-intermediate-step-1431k-3T"
# finetuned_path = "/XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/XXXX-40/output/fixed_cosmopedia-retrieval-dual-causal-llama-1.1b-bsz-2-ctx-rand-batch_negative_ddp_PP_lr_1e-5"
# finetuned_path = "/XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/XXXX-40/output/fixed_cosmopedia-retrieval-dual-causal-llama-1.1b-bsz-2-ctx-rand-batch_negative_ddp_RR_lr_1e-5"
# finetuned_path = "/XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/XXXX-40/output/concatedOrca-retrieval-dual-causal-llama-1.1b-bsz-2-ctx-rand-batch_negative_ddp_PP_lr_1e-5"
# Directory containing the checkpoint files, taken from --checkpoint_dir.
checkpoint_dir = args.checkpoint_dir
# Commented-out alternative locations:
# finetuned_path = "/XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/output/orca_finetune_rpj_v2_100B-retrieval-dual-causal-pythia-160m-mbsz-23-wbsz-736-ctx-var-batch_negative_ddp_RR_lr_1e-3/checkpoints-ddp"
# finetuned_path = "/XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/output/dolma-lm-pythia-160m-mbsz-24-wbsz-192_ddp_lr_6e-4/checkpoints-ddp"

# Commented-out alternative step lists:
# ckpt_steps = [2000, 20000, 38000, 54000]
# ckpt_steps = [2000, 4000, 6000, 8000, 10000, 12000, 16000, 16585]
# ckpt_steps = [1000, 3000, 5000, 7000, 9000, 11000, 13000, 14000]
# ckpt_steps = [1000, 3000, 5000, 7000, 9000, 10000]
# ckpt_steps = [500, 1000, 1500, 2000, 4000, 6000, 8000, 10000]
# ckpt_steps = [1000, 2000, 3000, 4000, 5000, 5739]
# ckpt_steps = [1000, 2000, 3000, 4000, 5000, 6000]
# ckpt_steps = [10000]

# Step numbers to evaluate: every 2000 steps from 2000 through 44000, plus
# step 45000.
ckpt_steps = list(range(2000, 44001, 2000)) + [45000]

# Dataset directories to evaluate on; the cosmopedia entry is disabled.
data_dirs = [
    # "/XXXX-30/XXXX-29/XXXX-31/proj-shared/language_datasets/processed/cosmopedia_retrieval_val_data_pythia",
    "/XXXX-30/XXXX-29/XXXX-31/proj-shared/language_datasets/processed/orca_retrieval_val_data_10k_pythia",
]

# Select the ``.pth`` files in ``checkpoint_dir`` whose step number (the
# digits following ``step-`` in the file name) is listed in ``ckpt_steps``.
step_pattern = re.compile(r'step-(\d+)')
wanted_steps = set(ckpt_steps)  # O(1) membership tests

steps_and_names = []
for fname in os.listdir(checkpoint_dir):
    if not fname.endswith('.pth'):
        continue
    step_match = step_pattern.search(fname)
    if step_match is None:
        # A .pth file without "step-<digits>" in its name cannot be matched
        # against ckpt_steps; skip it instead of failing with IndexError.
        continue
    step = int(step_match.group(1))
    if step in wanted_steps:
        steps_and_names.append((step, fname))

# Matching files in directory-listing order.
ckpts = [fname for _, fname in steps_and_names]

# Same files ordered by ascending step number. Sorting on the step alone keeps
# the listing order for files that share a step (sorted() is stable).
sorted_ckpts = [
    fname for _, fname in sorted(steps_and_names, key=lambda pair: pair[0])
]
print("Checkpoints to evaluate:", sorted_ckpts)

# Run eval/eval_retrieval_anticausal.py once per (dataset, split setting,
# checkpoint) combination, sequentially, as a child process.
failed_runs = []  # (command line, return code) for every run that exited non-zero
for data_dir in data_dirs:
    for split_sequence in [
        False,
        # True,
        ]:
        if split_sequence and "cosmopedia" in data_dir:
            continue    # cosmopedia is already split and doesn't have a natural split
        for ckpt in sorted_ckpts:
            # launch the evaluation script
            # NOTE: the example below is an older invocation; its flags
            # (--finetuned_path, --precision, --attn_type) are not the ones
            # passed by the command built here.
            # python eval/eval_retrieval_anticausal.py --checkpoint_dir /XXXX-30/XXXX-29/XXXX-31/proj-shared/language_models/external/TinyLlama-1.1B-intermediate-step-1431k-3T --finetuned_path /XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/XXXX-40/output/cosmopedia-retrieval-dual-causal-llama-1.1b-bsz-16-ctx-rand-batch_negative_ddp_PP_lr_1e-5/step-00052000-cosmopedia-retrieval-dual-causal-llama-1.1b-bsz-16-ctx-rand-batch_negative_ddp_PP_lr_1e-5.pth --precision bf16-mixed --attn_type causal_attn --data_dir /XXXX-30/XXXX-29/XXXX-31/proj-shared/language_datasets/processed/cosmopedia_retrieval_val_data
            command = [
                'python', 'eval/eval_retrieval_anticausal.py',
                '--model_path', args.model_dir,
                # the eval script's --checkpoint_dir receives the full path
                # of a single checkpoint file
                '--checkpoint_dir', os.path.join(checkpoint_dir, ckpt),
                '--data_dir', data_dir,
                '--random_split', str(split_sequence),
                '--prefix_add_eos', str(args.prefix_add_eos),
                '--suffix_add_eos', str(args.suffix_add_eos),
                '--siglip_loss', str(args.siglip_loss),
                '--include_meta_tokens', str(args.include_meta_tokens),
            ]
            command_line = ' '.join(command)
            print("Running command:", command_line)
            # subprocess.run does not raise on a non-zero exit status unless
            # check=True, so inspect the return code explicitly. A failed run
            # is reported but does not stop the remaining evaluations.
            completed = subprocess.run(command)
            if completed.returncode != 0:
                print("WARNING: command exited with status", completed.returncode)
                failed_runs.append((command_line, completed.returncode))

if failed_runs:
    print("WARNING:", len(failed_runs), "evaluation run(s) failed:")
    for command_line, returncode in failed_runs:
        print("  [exit status {}] {}".format(returncode, command_line))
print("Done!")