import argparse
from pathlib import Path
import wandb
import sys
sys.path.append("~pythia_replicate")
from lib.model_setup import load_model_and_tokenizer
from lib.translation import BilingualFewShotDataset
from lib.translation import evaluate_translation_accuracy

def compute_translation_metrics(
    model,
    tokenizer,
    lang1,
    lang2,
    batch_size=32,
    max_new_tokens=50,
    n_shots=10,
    step=None,
    random_pairs=False,
    debug=False,
):
    """Evaluate translation accuracy from each source language into ``lang2``.

    For every source language a ``BilingualFewShotDataset`` is built from the
    parallel-concepts CSV, its size is written to ``wandb.summary``, and
    ``evaluate_translation_accuracy`` is run on it. When ``step`` is given,
    the accuracy is also logged to wandb at that step.

    Args:
        model: Model to evaluate; its device is taken from its first parameter.
        tokenizer: Tokenizer forwarded to ``evaluate_translation_accuracy``.
        lang1: Source language codes. ``None`` selects the built-in default
            list; a single string is accepted and split on commas (so
            ``"spa"`` and ``"spa,jpn"`` both work); otherwise any iterable of
            codes.
        lang2: Target language code.
        batch_size: Forwarded to ``evaluate_translation_accuracy``.
        max_new_tokens: Forwarded to ``evaluate_translation_accuracy``.
        n_shots: Forwarded to ``BilingualFewShotDataset`` and embedded in the
            logged metric name.
        step: wandb step to log under. If ``None``, nothing is passed to
            ``wandb.log`` (the summary entry is still written).
        random_pairs: Forwarded to both the dataset constructor and the
            evaluation helper.
        debug: If true, print the dataset's prompts and targets.

    Returns:
        dict mapping each evaluated source language code to the accuracy
        value returned by ``evaluate_translation_accuracy``.
    """
    if lang1 is None:
        lang1 = ["spa", "jpn", "arb", "swe", "cmn", "eng", "fra"]
    elif isinstance(lang1, str):
        # A bare string would otherwise be iterated character by character,
        # producing one bogus "language" per letter.
        lang1 = [code.strip() for code in lang1.split(",") if code.strip()]

    # The model's device does not change between languages; look it up once.
    device = next(model.parameters()).device

    accuracies = {}
    for lang in lang1:
        # NOTE(review): the leading "~" is not expanded by Path(); presumably
        # the dataset class or the working directory resolves it -- confirm.
        dataset = BilingualFewShotDataset(
            Path("~pythia_replicate/dataset/parallel_concepts.csv"),
            lang,
            lang2,
            n_shots,
            random_pairs,
        )
        wandb.summary[f"samples_for_{lang}_to_{lang2}"] = len(dataset)
        if debug:
            print(dataset.prompts)
            print(dataset.targets)

        accuracy = evaluate_translation_accuracy(
            model, tokenizer, dataset, device, batch_size, max_new_tokens, random_pairs
        )
        accuracies[lang] = accuracy
        if step is not None:
            wandb.log(
                {f"{lang}_to_{lang2}_accuracy(num_of_shots={n_shots})": accuracy},
                step=step,
            )

    return accuracies

def parse_lang_codes(values):
    """Flatten ``--lang1`` command-line values into a list of language codes.

    The option is declared with ``nargs="*"``, so argparse yields a list of
    strings. Each entry may itself be a comma-separated string, which is what
    the option's help text advertises; such entries are split here.

    Args:
        values: List of raw argument strings, or ``None`` when the option was
            not given.

    Returns:
        ``None`` if ``values`` is ``None`` (so the caller can fall back to its
        defaults); otherwise a list of non-empty, whitespace-stripped codes in
        their original order.
    """
    if values is None:
        return None
    codes = []
    for value in values:
        codes.extend(part.strip() for part in value.split(",") if part.strip())
    return codes


def main():
    """Parse CLI arguments and evaluate translation accuracy per checkpoint.

    Initialises a wandb run, then for every checkpoint step loads the model
    and tokenizer, puts the model in eval mode, and calls
    ``compute_translation_metrics`` so the results are logged at that step.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_type", default=None)
    parser.add_argument("--non_local", action="store_true")
    parser.add_argument("--first_step", type=int, default=100)
    parser.add_argument("--last_step", type=int, default=20000)
    parser.add_argument(
        "--lang1",
        nargs="*",
        default=None,
        help="List of source languages (e.g., --lang1 spa jpn arb) or comma-separated string",
    )
    parser.add_argument("--lang2", default="eng")
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--max_new_tokens", type=int, default=8)
    parser.add_argument("--n_shots", type=int, default=5)
    parser.add_argument("--random_pairs", action="store_true")
    parser.add_argument("--debug", action="store_true")
    args = parser.parse_args()

    # Honour the comma-separated form promised by the --lang1 help text.
    lang1 = parse_lang_codes(args.lang1)

    wandb.init(
        project="pythia_replicate_all_benchmark",
        name=f"{args.model_type}",
    )
    wandb.config.update(vars(args))

    # NOTE(review): range() excludes --last_step itself, so with the defaults
    # the final evaluated checkpoint is step 19900 -- confirm this is intended.
    for step in range(args.first_step, args.last_step, 100):
        if args.non_local:
            # Checkpoints are selected by revision name on the given model id.
            model, tokenizer = load_model_and_tokenizer(
                args.model_type, revision=f"step{step}"
            )
        else:
            # NOTE(review): the leading "~" is not expanded here; presumably
            # load_model_and_tokenizer resolves it -- confirm.
            model_name = f"~pythia_replicate/hf_output/{args.model_type}/step={step}"
            model, tokenizer = load_model_and_tokenizer(model_name, revision=None)
        model.eval()
        compute_translation_metrics(
            model,
            tokenizer,
            lang1,
            args.lang2,
            args.batch_size,
            args.max_new_tokens,
            args.n_shots,
            step,
            args.random_pairs,
            args.debug,
        )
        # Drop our references so this checkpoint is not kept alive while the
        # next one is being loaded.
        del model, tokenizer


if __name__ == "__main__":
    main()
