# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

"""GPT zero-shot evaluation for LAMBADA, RACE, ARC-E, ARC-C, BOOLQ, HELLASWAG, PIQA.

This file is adapted to work with MCore and DirVAE MoE router.
"""

import math
import os
import json
from tqdm import tqdm

import torch

from megatron.training import get_args
from megatron.training import print_rank_0, is_last_rank
from megatron.training import get_tokenizer
from megatron.core import parallel_state, tensor_parallel
from megatron.training.checkpointing import load_checkpoint
from megatron.training.utils import get_ltor_masks_and_position_ids, unwrap_model
from megatron.training.arguments import core_transformer_config_from_args
from megatron.training import get_model


def _iter_race_samples(race_root, split):
    file_path = os.path.join(race_root, split, f"{split}.json")
    if not os.path.isfile(file_path):
        file_path = os.path.join(race_root, f"{split}.json")
    if not os.path.isfile(file_path):
        raise FileNotFoundError(f"Cannot locate RACE split file: {file_path}")
    with open(file_path, "r", encoding="utf-8") as f:
        dataset = json.load(f)
    for data in dataset:
        art = data["article"]
        for i, q in enumerate(data["questions"]):
            yield {
                "article": art,
                "question": q,
                "choices": data["options"][i],
                "label": ord(data["answers"][i]) - ord("A"),
            }


@torch.no_grad()
def _choice_score(model, prefix_ids, choice_ids, eod_id):
    ids = prefix_ids + choice_ids
    tokens = torch.tensor(ids, device="cuda").unsqueeze(0)
    # Build masks/positions with current Megatron signature
    # pad_token set equal to eod_id in zero-shot contexts
    pad_token = eod_id
    reset_position_ids = False
    reset_attention_mask = False
    eod_mask_loss = False
    pad_mask_loss = False
    attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
        tokens, eod_id, pad_token, reset_position_ids, reset_attention_mask, eod_mask_loss, pad_mask_loss
    )
    logits = model(tokens, position_ids=position_ids, attention_mask=attention_mask)
    logp = 0.0
    for idx, tid in enumerate(choice_ids, start=len(prefix_ids) - 1):
        logp += torch.log_softmax(logits[0, idx], -1)[tid].item()
    return logp


def _predict_race_sample(model, tokenizer, eod_id, sample, max_ctx=1024):
    """Return the index of the highest-scoring answer option for a RACE sample.

    The prompt is ``"Article: ...\\nQuestion: ...\\nAnswer:"``; only its last
    ``max_ctx`` tokens are kept.  Each option is tokenized with a leading
    space and scored with ``_choice_score``.

    Args:
        model: model passed through to ``_choice_score``.
        tokenizer: object whose ``tokenize`` returns a list of token ids.
        eod_id: end-of-document token id passed through to ``_choice_score``.
        sample: dict with the keys ``article``, ``question`` and ``choices``.
        max_ctx: maximum number of prompt tokens kept (taken from the end).

    Returns:
        int: zero-based index of the best option (first one wins on ties).

    Raises:
        ValueError: if the sample has no choices.
    """
    art = sample["article"]
    q = sample["question"]
    choices = sample["choices"]
    prefix = f"Article: {art}\nQuestion: {q}\nAnswer:"
    prefix_ids = tokenizer.tokenize(prefix)[-max_ctx:]
    scores = [
        _choice_score(model, prefix_ids, tokenizer.tokenize(" " + ch), eod_id)
        for ch in choices
    ]
    # Rank over however many options the sample really has rather than
    # assuming exactly four, which raised IndexError for shorter lists.
    return int(max(range(len(scores)), key=scores.__getitem__))


def _evaluate_race(model, tokenizer, race_root, split="test"):
    """Compute multiple-choice accuracy on one RACE split.

    Args:
        model: model to evaluate; switched to eval mode.
        tokenizer: tokenizer providing ``eod`` and ``tokenize``.
        race_root: directory passed to ``_iter_race_samples``.
        split: name of the split to load.

    Returns:
        float: fraction of correctly answered questions, or ``0.0`` when the
        split contains no samples.
    """
    model.eval()
    eod_id = tokenizer.eod
    correct = torch.tensor(0.0, device="cuda")
    total = torch.tensor(0.0, device="cuda")
    sample_iter = tqdm(_iter_race_samples(race_root, split), desc=f"RACE-{split}")
    for sample in sample_iter:
        if _predict_race_sample(model, tokenizer, eod_id, sample) == sample["label"]:
            correct += 1
        total += 1
    # Only reduce when a process group exists, matching the BOOLQ, HELLASWAG
    # and PIQA evaluators; both counters are summed, so the ratio is the
    # accuracy over everything the participating ranks processed.
    if torch.distributed.is_initialized():
        torch.distributed.all_reduce(correct)
        torch.distributed.all_reduce(total)
    acc = (correct / total).item() if total.item() > 0 else 0.0
    return acc


# -----------------------------
# Helpers for parquet/jsonl loading
# -----------------------------
def _load_parquet(path):
    try:
        import pandas as pd  # type: ignore
        df = pd.read_parquet(path)
        return df.to_dict(orient="records")
    except Exception:
        try:
            import pyarrow.parquet as pq  # type: ignore
            table = pq.read_table(path)
            return table.to_pylist()
        except Exception as e:
            raise RuntimeError(f"Failed to read parquet at {path}: {e}")


def _load_jsonl(path):
    if not os.path.isfile(path):
        raise FileNotFoundError(f"Cannot locate jsonl: {path}")
    items = []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            items.append(json.loads(line))
    return items


def _to_list_or_empty(value):
    """Robustly convert array-like to a Python list; return [] if None."""
    if value is None:
        return []
    if isinstance(value, list):
        return value
    if isinstance(value, tuple):
        return list(value)
    # numpy/pyarrow/series objects
    if hasattr(value, 'tolist'):
        try:
            return value.tolist()
        except Exception:
            pass
    # Fallback: wrap scalar as single-element list
    return [value]


def _iter_arc_samples(root, subset):
    file_path = os.path.join(root, f"{subset}-test.json")
    if not os.path.isfile(file_path):
        raise FileNotFoundError(f"Cannot locate ARC file: {file_path}")
    with open(file_path, "r", encoding="utf-8") as f:
        dataset = json.load(f)
    for item in dataset:
        question = item["question"]
        choices = item["choices"]
        answer_key = item["answerKey"]
        label = ord(answer_key) - ord("A")
        yield {"question": question, "choices": choices, "label": label}


def _predict_arc_sample(model, tokenizer, eod_id, sample, max_ctx=1024):
    """Return the index of the best-scoring answer choice for an ARC sample.

    The prompt is ``"Question: ...\\nAnswer:"``, truncated to its last
    ``max_ctx`` tokens; every choice is tokenized with a leading space and
    scored via ``_choice_score``.  Ties resolve to the lowest index.
    """
    question = sample["question"]
    options = sample["choices"]
    prompt_ids = tokenizer.tokenize(f"Question: {question}\nAnswer:")[-max_ctx:]
    scores = [
        _choice_score(model, prompt_ids, tokenizer.tokenize(" " + option), eod_id)
        for option in options
    ]
    best = max(range(len(options)), key=scores.__getitem__)
    return int(best)


def _evaluate_arc(model, tokenizer, root, subset="ARC-Easy"):
    """Compute multiple-choice accuracy on one ARC subset.

    Args:
        model: model to evaluate; switched to eval mode.
        tokenizer: tokenizer providing ``eod`` and ``tokenize``.
        root: directory containing ``<subset>-test.json``.
        subset: subset name used to build the file name.

    Returns:
        float: fraction of correctly answered questions, or ``0.0`` when the
        subset contains no samples.
    """
    model.eval()
    eod_id = tokenizer.eod
    correct = torch.tensor(0.0, device="cuda")
    total = torch.tensor(0.0, device="cuda")
    # The callers in this file pass subset names that already start with
    # "ARC-", so use the name as the progress label directly instead of
    # rendering "ARC-ARC-Easy".
    sample_iter = tqdm(_iter_arc_samples(root, subset), desc=subset)
    for sample in sample_iter:
        if _predict_arc_sample(model, tokenizer, eod_id, sample) == sample["label"]:
            correct += 1
        total += 1
    # Only reduce when a process group exists, matching the BOOLQ, HELLASWAG
    # and PIQA evaluators.
    if torch.distributed.is_initialized():
        torch.distributed.all_reduce(correct)
        torch.distributed.all_reduce(total)
    acc = (correct / total).item() if total.item() > 0 else 0.0
    return acc


# -----------------------------
# BOOLQ (parquet/jsonl): binary QA
# -----------------------------
def _iter_boolq_samples(path):
    """Load BoolQ from a directory of parquet files, preferring validation.

    Accepted inputs:
    - Directory containing parquet files named validation-*.parquet (preferred),
      otherwise test-*.parquet, otherwise train-*.parquet.
    - Direct path to a single .parquet file.
    - Generic jsonl.
    """
    if os.path.isdir(path):
        def _pick_parquet(dir_path, prefer=("validation-", "test-", "train-")):
            try:
                files = sorted(os.listdir(dir_path))
            except Exception as e:
                raise FileNotFoundError(f"Cannot list directory: {dir_path}: {e}")
            for prefix in prefer:
                for fn in files:
                    if fn.startswith(prefix) and fn.endswith(".parquet"):
                        return os.path.join(dir_path, fn)
            raise FileNotFoundError(f"No parquet file found in {dir_path} with prefixes {prefer}")
        parquet_path = _pick_parquet(path)
        rows = _load_parquet(parquet_path)
    elif path.endswith(".parquet"):
        rows = _load_parquet(path)
    else:
        rows = _load_jsonl(path)
    for r in rows:
        q = r.get("question") or r.get("query")
        p = r.get("passage") or r.get("context") or r.get("paragraph")
        # labels may be 'answer' (bool/str) or 'label'
        ans = r.get("answer", r.get("label"))
        if isinstance(ans, str):
            ans = ans.strip().lower() in ("true", "1", "yes")
        yield {"question": q, "passage": p, "label": ans}


def _predict_boolq_sample(model, tokenizer, eod_id, sample, max_ctx=1024):
    """Return 0 if " true" outscores " false" for the sample, else 1.

    The prompt is ``"Passage: ...\\nQuestion: ...\\nAnswer:"`` truncated to its
    last ``max_ctx`` tokens.  On a tie index 0 (" true") is returned.
    """
    question = sample["question"]
    passage = sample["passage"]
    prompt_ids = tokenizer.tokenize(f"Passage: {passage}\nQuestion: {question}\nAnswer:")[-max_ctx:]
    scores = [
        _choice_score(model, prompt_ids, tokenizer.tokenize(answer), eod_id)
        for answer in (" true", " false")
    ]
    return int(max(range(2), key=scores.__getitem__))


def _evaluate_boolq(model, tokenizer, path):
    """Compute accuracy on BoolQ, counting only samples that carry a label.

    A truthy label corresponds to prediction index 0 (" true") and a falsy one
    to index 1 (" false").  Returns ``0.0`` when no labelled sample was seen.
    """
    model.eval()
    eod_id = tokenizer.eod
    correct = torch.tensor(0.0, device="cuda")
    total = torch.tensor(0.0, device="cuda")
    for sample in tqdm(_iter_boolq_samples(path), desc="BOOLQ"):
        predicted = _predict_boolq_sample(model, tokenizer, eod_id, sample)
        gold = sample.get("label")
        if gold is None:
            # Unlabelled rows are still run through the model but not counted.
            continue
        expected = 0 if bool(gold) else 1
        if predicted == expected:
            correct += 1
        total += 1
    if torch.distributed.is_initialized():
        for counter in (correct, total):
            torch.distributed.all_reduce(counter)
    if total.item() > 0:
        return (correct / total).item()
    return 0.0


# -----------------------------
# HellaSwag (parquet/jsonl): 4-way MCQ with context+endings
# -----------------------------
def _iter_hellaswag_samples(path):
    """Load HellaSwag from a directory of parquet files, preferring validation.

    Accepted inputs:
    - Directory containing parquet files named validation-*.parquet (preferred),
      otherwise test-*.parquet, otherwise train-*.parquet.
    - Direct path to a single .parquet file.
    """
    def _pick_parquet(dir_path, prefer=("validation-", "test-", "train-")):
        try:
            files = sorted(os.listdir(dir_path))
        except Exception as e:
            raise FileNotFoundError(f"Cannot list directory: {dir_path}: {e}")
        for prefix in prefer:
            for fn in files:
                if fn.startswith(prefix) and fn.endswith(".parquet"):
                    return os.path.join(dir_path, fn)
        raise FileNotFoundError(f"No parquet file found in {dir_path} with prefixes {prefer}")

    if os.path.isdir(path):
        parquet_path = _pick_parquet(path)
        rows = _load_parquet(parquet_path)
    elif path.endswith(".parquet"):
        rows = _load_parquet(path)
    else:
        raise FileNotFoundError(f"HELLASWAG expects a directory of parquet files or a .parquet file. Got: {path}")
    for r in rows:
        ctx = r.get("ctx") if r.get("ctx") is not None else (r.get("context") if r.get("context") is not None else r.get("passage"))
        ends_raw = r.get("endings")
        if ends_raw is None:
            ends_raw = r.get("endings_list")
        ends = _to_list_or_empty(ends_raw)
        label = r.get("label", r.get("gold"))
        try:
            label = int(label) if label is not None else None
        except Exception:
            label = None
        yield {"ctx": ctx, "endings": ends, "label": label}


def _predict_hellaswag_sample(model, tokenizer, eod_id, sample, max_ctx=1024):
    """Return the index of the best-scoring ending for a HellaSwag sample.

    The prompt is ``"Context: ...\\nEnding:"`` truncated to its last
    ``max_ctx`` tokens; each ending is stringified, prefixed with a space,
    tokenized and scored via ``_choice_score``.  Returns 0 when the sample has
    no endings; ties resolve to the lowest index.
    """
    context = sample["ctx"]
    candidates = sample["endings"]
    prompt_ids = tokenizer.tokenize(f"Context: {context}\nEnding:")[-max_ctx:]
    scores = [
        _choice_score(model, prompt_ids, tokenizer.tokenize(" " + str(ending)), eod_id)
        for ending in candidates
    ]
    if not scores:
        return 0
    return int(max(range(len(scores)), key=scores.__getitem__))


def _evaluate_hellaswag(model, tokenizer, path):
    """Compute accuracy on HellaSwag, counting only samples that carry a label.

    Returns ``0.0`` when no labelled sample was seen.
    """
    model.eval()
    eod_id = tokenizer.eod
    correct = torch.tensor(0.0, device="cuda")
    total = torch.tensor(0.0, device="cuda")
    for sample in tqdm(_iter_hellaswag_samples(path), desc="HELLASWAG"):
        predicted = _predict_hellaswag_sample(model, tokenizer, eod_id, sample)
        gold = sample.get("label")
        if gold is None:
            # Unlabelled rows are still run through the model but not counted.
            continue
        if predicted == int(gold):
            correct += 1
        total += 1
    if torch.distributed.is_initialized():
        for counter in (correct, total):
            torch.distributed.all_reduce(counter)
    if total.item() > 0:
        return (correct / total).item()
    return 0.0


# -----------------------------
# PIQA (jsonl preferred): 2-way MCQ goal + sol1/sol2
# -----------------------------
def _iter_piqa_samples(path):
    """Iterate PIQA samples from local tests.jsonl only.

    Accepted inputs:
    - Directory containing tests.jsonl, or dev.jsonl + dev-labels.lst
    - Direct path to tests.jsonl
    """

    def _yield_with_labels(jsonl_fp, labels_fp):
        rows_local = _load_jsonl(jsonl_fp)
        labels = []
        with open(labels_fp, "r", encoding="utf-8") as f:
            for line in f:
                s = line.strip()
                if s == "":
                    continue
                try:
                    labels.append(int(s))
                except Exception:
                    labels.append(None)
        for i, r in enumerate(rows_local):
            goal = r.get("goal")
            s1 = r.get("sol1") or r.get("choice1")
            s2 = r.get("sol2") or r.get("choice2")
            lab = labels[i] if i < len(labels) else None
            yield {"goal": goal, "choices": [s1, s2], "label": lab}

    if os.path.isdir(path):
        dev_json = os.path.join(path, "dev.jsonl")
        dev_lbl = os.path.join(path, "dev-labels.lst")
        if os.path.isfile(dev_json) and os.path.isfile(dev_lbl):
            for item in _yield_with_labels(dev_json, dev_lbl):
                yield item
            return
        tests_json = os.path.join(path, "tests.jsonl")
        if not os.path.isfile(tests_json):
            raise FileNotFoundError(f"PIQA expects tests.jsonl or dev.jsonl+dev-labels.lst under directory: {path}")
        rows = _load_jsonl(tests_json)
    else:
        if not path.endswith(".jsonl"):
            raise FileNotFoundError(f"PIQA expects tests.jsonl file or directory containing it. Got: {path}")
        rows = _load_jsonl(path)
    def _norm_label(val):
        try:
            v = int(val)
        except Exception:
            return None
        if v in (0, 1):
            return v
        if v in (1, 2):  # some dumps use 1/2
            return v - 1
        return None
    for r in rows:
        goal = r.get("goal")
        s1 = r.get("sol1") or r.get("choice1")
        s2 = r.get("sol2") or r.get("choice2")
        lab_raw = r.get("label")
        lab = _norm_label(lab_raw) if lab_raw is not None else None
        yield {"goal": goal, "choices": [s1, s2], "label": lab}


def _predict_piqa_sample(model, tokenizer, eod_id, sample, max_ctx=1024):
    """Return the index of the best-scoring solution for a PIQA sample.

    The prompt is ``"Goal: ...\\nSolution:"`` truncated to its last ``max_ctx``
    tokens; each choice is stringified, prefixed with a space, tokenized and
    scored via ``_choice_score``.  Returns 0 when the sample has no choices;
    ties resolve to the lowest index.
    """
    goal = sample["goal"]
    solutions = sample["choices"]
    prompt_ids = tokenizer.tokenize(f"Goal: {goal}\nSolution:")[-max_ctx:]
    scores = [
        _choice_score(model, prompt_ids, tokenizer.tokenize(" " + str(solution)), eod_id)
        for solution in solutions
    ]
    if not scores:
        return 0
    return int(max(range(len(scores)), key=scores.__getitem__))


def _evaluate_piqa(model, tokenizer, path):
    """Compute accuracy on PIQA, counting only samples that carry a label.

    Returns NaN (rather than a misleading zero) when no labelled sample was
    seen.
    """
    model.eval()
    eod_id = tokenizer.eod
    correct = torch.tensor(0.0, device="cuda")
    total = torch.tensor(0.0, device="cuda")
    for sample in tqdm(_iter_piqa_samples(path), desc="PIQA"):
        predicted = _predict_piqa_sample(model, tokenizer, eod_id, sample)
        gold = sample.get("label")
        if gold is None:
            # Unlabelled rows are still run through the model but not counted.
            continue
        if predicted == int(gold):
            correct += 1
        total += 1
    if torch.distributed.is_initialized():
        for counter in (correct, total):
            torch.distributed.all_reduce(counter)
    if total.item() > 0:
        return (correct / total).item()
    return float('nan')


def _iter_lambada_samples(jsonl_path):
    if not os.path.isfile(jsonl_path):
        raise FileNotFoundError(f"Cannot locate LAMBADA jsonl: {jsonl_path}")
    with open(jsonl_path, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            obj = json.loads(line)
            # Common fields across dumps
            text = obj.get("text") or obj.get("context") or obj.get("passage")
            target = obj.get("target") or obj.get("answer") or obj.get("label")
            if text is None:
                continue
            if target is None:
                # Fallback: last whitespace-delimited token
                s = text.rstrip()
                if not s:
                    continue
                parts = s.split()
                target = parts[-1]
                text = s[: -len(target)].rstrip()
            yield {"context": text, "target": target}


@torch.no_grad()
def _evaluate_lambada(model, tokenizer, jsonl_path, max_ctx=1024):
    """Compute strict last-word accuracy on a LAMBADA jsonl file.

    For each sample the context (its last ``max_ctx`` tokens) and the target
    (tokenized with a leading space) are concatenated and run through the
    model once.  The sample counts as correct only if the argmax prediction at
    every position preceding a target token equals that token.

    Args:
        model: model to evaluate; switched to eval mode.
        tokenizer: tokenizer providing ``eod`` and ``tokenize``.
        jsonl_path: path passed to ``_iter_lambada_samples``.
        max_ctx: maximum number of context tokens kept (taken from the end).

    Returns:
        float: fraction of correct samples, or ``0.0`` when none were scored.
    """
    model.eval()
    eod_id = tokenizer.eod
    correct = torch.tensor(0.0, device="cuda")
    total = torch.tensor(0.0, device="cuda")

    for sample in tqdm(_iter_lambada_samples(jsonl_path), desc='LAMBADA'):
        ctx_ids = tokenizer.tokenize(sample["context"])[-max_ctx:]
        ans_ids = tokenizer.tokenize(" " + sample["target"])
        if not ctx_ids:
            # Without any context token there is no position that predicts
            # the first answer token; the start index would be -1 and wrap
            # around to the last position, so the sample is skipped.
            continue

        tokens = torch.tensor(ctx_ids + ans_ids, device="cuda").unsqueeze(0)
        # Build masks/positions with current Megatron signature.
        pad_token = eod_id
        attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
            tokens, eod_id, pad_token, False, False, False, False
        )
        logits = model(tokens, position_ids=position_ids, attention_mask=attention_mask)

        # Compare all next-token predictions against the target in one
        # transfer instead of one host/device sync per token.
        start = len(ctx_ids) - 1
        predicted = torch.argmax(logits[0, start:start + len(ans_ids)], dim=-1).tolist()
        if predicted == list(ans_ids):
            correct += 1
        total += 1

    # Only reduce when a process group exists, matching the BOOLQ, HELLASWAG
    # and PIQA evaluators.
    if torch.distributed.is_initialized():
        torch.distributed.all_reduce(correct)
        torch.distributed.all_reduce(total)
    acc = (correct / total).item() if total.item() > 0 else 0.0
    return acc


def get_model_provider(eval_metric):
    """Return a ``model_provider`` callable that builds an MCore GPT model.

    ``parallel_output`` is enabled only when ``eval_metric`` is ``'loss'``.
    The layer spec comes from ``args.spec`` when it is set; otherwise it is
    the decoder block spec when ``args.num_experts`` is truthy, or the
    Transformer-Engine / local layer spec depending on
    ``args.transformer_impl``.
    """
    def model_provider(pre_process=True, post_process=True):
        config = core_transformer_config_from_args(get_args())
        wants_parallel_output = eval_metric == 'loss'
        args = get_args()
        from megatron.core.models.gpt import GPTModel as MCoreGPTModel
        from importlib import import_module
        te_enabled = args.transformer_impl == "transformer_engine"

        if args.spec is None:
            from megatron.core.models.gpt.gpt_layer_specs import (
                get_gpt_layer_with_transformer_engine_spec,
                get_gpt_layer_local_spec,
                get_gpt_decoder_block_spec,
            )
            if args.num_experts:
                layer_spec = get_gpt_decoder_block_spec(config, use_transformer_engine=te_enabled)
            else:
                if te_enabled:
                    build_layer_spec = get_gpt_layer_with_transformer_engine_spec
                else:
                    build_layer_spec = get_gpt_layer_local_spec
                layer_spec = build_layer_spec(
                    args.num_experts,
                    args.moe_grouped_gemm,
                    args.qk_layernorm,
                    args.multi_latent_attention,
                    moe_use_legacy_grouped_gemm=args.moe_use_legacy_grouped_gemm,
                )
        else:
            layer_spec = import_module(args.spec)

        return MCoreGPTModel(
            config=config,
            transformer_layer_spec=layer_spec,
            vocab_size=args.padded_vocab_size,
            max_sequence_length=args.max_position_embeddings,
            pre_process=pre_process,
            post_process=post_process,
            parallel_output=wants_parallel_output,
            share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights,
            position_embedding_type=args.position_embedding_type,
            rotary_percent=args.rotary_percent,
            rotary_base=args.rotary_base,
        )

    return model_provider


def process_batch(batch):
    """Turn a data-loader batch into model inputs, labels and a loss mask.

    ``batch['text']`` is split into inputs (all but the last token) and labels
    (all but the first token); ``batch['pad_mask']`` becomes a byte loss mask.
    Attention mask and position ids are built with the tokenizer's ``eod`` id
    as both the end-of-document and the pad token.

    Returns:
        tuple: ``(tokens, labels, attention_mask, position_ids, loss_mask)``.
    """
    args = get_args()
    tokenizer = get_tokenizer()

    loss_mask = batch['pad_mask'].long().cuda().contiguous().byte()
    sequence = batch['text'].long().cuda().contiguous()
    # Shift by one position: the model sees tokens[:-1] and predicts tokens[1:].
    labels = sequence[:, 1:].contiguous()
    tokens = sequence[:, :-1].contiguous()

    eod = tokenizer.eod
    mask_pad_loss = getattr(args, 'pad_mask_loss', False)
    attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
        tokens, eod, eod, args.reset_position_ids, args.reset_attention_mask, args.eod_mask_loss, mask_pad_loss
    )
    return tokens, labels, attention_mask, position_ids, loss_mask


def forward_step(batch, model, eval_metric, config):
    """Run one evaluation batch and return its summed metric.

    For ``'loss'`` the result is the sum of the vocab-parallel cross-entropy
    losses weighted by the loss mask; for ``'accuracy'`` it is the number of
    sequences whose masked-in positions are all predicted correctly.  Returns
    None on ranks that are not the last pipeline stage.

    Note: ``args.micro_batch_size`` is overwritten with the batch size of the
    current batch.

    Raises:
        NotImplementedError: on the last pipeline stage for any other metric.
    """
    tokens, labels, attention_mask, position_ids, loss_mask = process_batch(batch)
    get_args().micro_batch_size = len(labels)

    # Pipeline MP size is 1 for our eval; no P2P recv/send needed.
    inner_model = unwrap_model(model)
    try:
        inner_model.set_input_tensor(None)
    except Exception:
        pass

    output = model(tokens, position_ids, attention_mask)
    if not parallel_state.is_pipeline_last_stage():
        return None

    if eval_metric == 'loss':
        token_losses = tensor_parallel.vocab_parallel_cross_entropy(output.contiguous().float(), labels.contiguous())
        flat_mask = loss_mask.contiguous().view(-1).float()
        return torch.sum(token_losses.view(-1) * flat_mask)

    if eval_metric == 'accuracy':
        predictions = torch.argmax(output, -1)
        hits = (predictions == labels).float()
        # Masked-out positions must not veto a sequence: force them to 1 so
        # the per-sequence product depends only on masked-in positions.
        hits[(1 - loss_mask).bool()] = 1
        return hits.prod(-1).sum()

    raise NotImplementedError


def evaluate(data_loader, model, eval_metric):
    """Accumulate the metric returned by ``forward_step`` over a data loader.

    On the last pipeline stage each batch result is all-reduced over the
    data-parallel group before being added to the running total.  Ranks on
    other stages return the initial ``0.0``.
    """
    args = get_args()
    config = core_transformer_config_from_args(args)
    model.eval()
    running_total = 0.0
    with torch.no_grad():
        for step, batch in enumerate(data_loader):
            if step % args.log_interval == 0:
                print_rank_0('> working on iteration: {}'.format(step))
            batch_result = forward_step(batch, model, eval_metric, config)
            if not parallel_state.is_pipeline_last_stage():
                continue
            torch.distributed.all_reduce(batch_result, group=parallel_state.get_data_parallel_group())
            running_total += batch_result
    return running_total


def evaluate_and_print_results(task, data_loader, model, eval_metric):
    """Evaluate ``model`` on ``data_loader`` and print a one-line summary.

    For ``'loss'`` the summary holds the average loss, perplexity, adjusted
    perplexity and token ratio, computed from the dataset's
    ``num_tokenized_tokens`` and ``num_original_tokens``; the exponent is
    capped at 20.  For ``'accuracy'`` it holds the number of correct examples,
    the number of examples and their ratio.

    Args:
        task: task name shown in the summary.
        data_loader: loader passed to ``evaluate``; its ``dataset`` supplies
            the normalisation counts.
        model: model passed to ``evaluate``.
        eval_metric: ``'loss'`` or ``'accuracy'``.
    """
    output = evaluate(data_loader, model, eval_metric)
    string = ' validation results on {} | '.format(task)
    if eval_metric == 'loss':
        num_tokenized_tokens = data_loader.dataset.num_tokenized_tokens
        num_original_tokens = data_loader.dataset.num_original_tokens
        val_loss = output / (num_tokenized_tokens - 1)
        ppl = math.exp(min(20, val_loss))
        token_ratio = (num_tokenized_tokens - 1) / (num_original_tokens - 1)
        adjusted_ppl = math.exp(min(20, val_loss * token_ratio))
        string += 'avg loss: {:.4E} | '.format(val_loss)
        string += 'ppl: {:.4E} | '.format(ppl)
        string += 'adjusted ppl: {:.4E} | '.format(adjusted_ppl)
        string += 'token ratio: {} |'.format(token_ratio)
    elif eval_metric == 'accuracy':
        num_examples = len(data_loader.dataset)
        acc = output / num_examples
        string += 'number correct: {:.4E} | '.format(output)
        string += 'total examples: {:.4E} | '.format(num_examples)
        string += 'avg accuracy: {:.4E}'.format(acc)

    # Print for every metric.  These lines used to sit inside the 'accuracy'
    # branch, so loss/perplexity results were computed but never shown.
    rule = '-' * (len(string) + 1)
    print(rule)
    print(string)
    print(rule)


def main():
    """Entry point: build the model, load the checkpoint and run ``args.task``.

    RACE, ARC-E, ARC-C, BOOLQ, HELLASWAG and PIQA are evaluated directly from
    local files, using ``args.valid_data[0]`` when given and a per-task
    default path otherwise.  Every other supported task goes through the
    dataset/data-loader path; if that path raises and the task is LAMBADA with
    ``args.valid_data`` set, LAMBADA is evaluated directly from the jsonl file
    instead.

    Raises:
        NotImplementedError: if ``args.task`` is not a supported task.
    """
    args = get_args()
    if args.num_layers_per_virtual_pipeline_stage is not None:
        print("Interleaved pipeline schedule is not yet supported for text generation.")
        exit()

    accuracy_tasks = ('LAMBADA', 'RACE', 'ARC-E', 'ARC-C', 'BOOLQ', 'HELLASWAG', 'PIQA')
    if args.task in accuracy_tasks:
        eval_metric = 'accuracy'
    elif args.task == 'WIKITEXT103':
        eval_metric = 'loss'
    else:
        raise NotImplementedError('{} task is not implemented.'.format(args.task))

    model = get_model(get_model_provider(eval_metric), wrap_with_ddp=False)
    if args.load is not None:
        _ = load_checkpoint(model, None, None)
    assert len(model) == 1
    model = model[0]

    # task -> (default data path, evaluator, label used in the result line)
    file_based_tasks = {
        'RACE': ('./RACE', lambda m, t, p: _evaluate_race(m, t, p, split='test'), 'RACE test accuracy'),
        'ARC-E': ('./ARC', lambda m, t, p: _evaluate_arc(m, t, p, subset='ARC-Easy'), 'ARC-E accuracy'),
        'ARC-C': ('./ARC', lambda m, t, p: _evaluate_arc(m, t, p, subset='ARC-Challenge'), 'ARC-C accuracy'),
        'BOOLQ': ('./boolq/validation.parquet', _evaluate_boolq, 'BOOLQ accuracy'),
        # HELLASWAG accepts a directory with parquet shards; prefers validation-*.parquet
        'HELLASWAG': ('./hellaswag/data', _evaluate_hellaswag, 'HELLASWAG accuracy'),
        # PIQA accepts a tests.jsonl file or a directory containing it
        'PIQA': ('./piqa', _evaluate_piqa, 'PIQA accuracy'),
    }

    if args.task in file_based_tasks:
        default_path, run_eval, result_label = file_based_tasks[args.task]
        tokenizer = get_tokenizer()
        data_path = args.valid_data[0] if args.valid_data else default_path
        acc = run_eval(model, tokenizer, data_path)
        print_rank_0(f'{result_label}: {acc:.4%}')
    else:
        # Default path handles WIKITEXT103 loss eval using Megatron's dataset builders
        from tasks.finetune_utils import build_data_loader
        try:
            from .datasets import build_dataset
            dataset = build_dataset(args.task)
            dataloader = build_data_loader(dataset, args.micro_batch_size, args.num_workers, drop_last=False)
            evaluate_and_print_results(args.task, dataloader, model, eval_metric)
        except Exception:
            # If dataset helpers are absent, support direct LAMBADA jsonl eval
            if args.task != 'LAMBADA' or not args.valid_data:
                raise
            tokenizer = get_tokenizer()
            acc = _evaluate_lambada(model, tokenizer, args.valid_data[0])
            print_rank_0(f'LAMBADA accuracy: {acc:.4%}')

    print_rank_0('done :-)')


