from unsloth import FastLanguageModel
import json
from vllm import SamplingParams
import re
import os
import numpy as np
from collections import defaultdict
from ..data.loaders import load_aime_dataset, load_amc_dataset, load_math500_dataset
from ..utils.eval import *
from ..utils.rewards import reward
from ..utils.seed import set_seed
import argparse
import yaml

def _get_json_path(cfg):
    """Return where evaluation results for ``cfg["model_name"]`` are stored.

    Only the text after the last ``/`` of the model name is used, so
    ``"Qwen/Qwen2.5-Math-7B"`` maps to
    ``outputs/eval/Qwen2.5-Math-7B/evaluation_results.json`` (joined with
    ``os.path.join``).
    """
    short_name = cfg["model_name"].rpartition("/")[2]
    parts = ("outputs", "eval", short_name, "evaluation_results.json")
    return os.path.join(*parts)

def _parse_step(s: str) -> int:
    """Extract a step number from a checkpoint directory name.

    A run of digits at the very end of *s* wins (``"checkpoint-500"`` -> 500).
    Otherwise the first run of digits anywhere in *s* is used. Returns -1
    when *s* contains no digits at all.
    """
    for pattern in (r'(\d+)$', r'(\d+)'):
        found = re.search(pattern, s)
        if found is not None:
            return int(found.group(1))
    return -1

def quick_eval(
    model,
    tokenizer,
    lora_request=None,
    eval_dataset=None,
    temperature=0.3,
    n_sampling=16,              # k in pass@k
    max_tokens=768,
    top_p=0.95,
    seed=42
):
    """Sample ``n_sampling`` responses per problem and score them with ``reward``.

    Args:
        model: model exposing ``fast_generate`` (vLLM-backed generation).
        tokenizer: tokenizer whose ``apply_chat_template`` renders each prompt.
        lora_request: adapter to generate with, or None for the bare model.
        eval_dataset: iterable of items with a ``"prompt"`` entry (chat
            messages passed to ``apply_chat_template``) and an ``"answer"``
            entry (ground truth handed to ``reward``). None is treated as empty.
        temperature, top_p, max_tokens, seed: forwarded to ``SamplingParams``.
        n_sampling: number of samples per problem (the k of pass@k).

    Returns:
        ``(accuracy, pass_at_k)`` as floats, where ``accuracy`` is the fraction
        of problems whose FIRST sample is rewarded, and ``pass_at_k`` is the
        fraction with at least one rewarded sample. Both are 0.0 for an empty
        dataset.
    """
    if eval_dataset is None:
        return 0.0, 0.0

    eos = tokenizer.eos_token
    sampling_params = SamplingParams(
        temperature=temperature,
        top_p=top_p,
        max_tokens=max_tokens,
        # A missing EOS token would otherwise produce the stop list [None],
        # which is not a usable stop string.
        stop=[eos] if eos else None,
        n=n_sampling,
        seed=seed,
    )

    n_total = 0
    n_acc_first = 0
    n_pass_at_k = 0
    for item in eval_dataset:
        prompt_text = tokenizer.apply_chat_template(
            item["prompt"],
            add_generation_prompt=True,
            tokenize=False,
        )

        outs = model.fast_generate(
            [prompt_text],
            sampling_params=sampling_params,
            lora_request=lora_request,
            use_tqdm=False,
        )[0].outputs

        gt = item["answer"]
        # Correctness is the truthiness of reward(); assumed to be a bool or
        # a score that is 0 for wrong answers — confirm in utils.rewards.
        correct = [reward(o.text, gt) for o in outs]

        # Guard against an empty output list (would IndexError on [0]).
        if correct and correct[0]:
            n_acc_first += 1
        if any(correct):
            n_pass_at_k += 1
        n_total += 1

    accuracy = (n_acc_first / n_total) if n_total else 0.0
    pass_at_k = (n_pass_at_k / n_total) if n_total else 0.0

    return float(accuracy), float(pass_at_k)



def _discover_ckpt_dirs(model_root, algos_whitelist):
    """Return ``(algo_name, ckpts_dir)`` pairs for algorithms under *model_root*.

    An algorithm is any sub-directory of *model_root* other than ``baseline``
    that contains a ``ckpts/`` directory. When *algos_whitelist* is not None,
    only algorithm names contained in it are kept.
    """
    found = []
    if not os.path.isdir(model_root):
        print(f"⚠️ Model root directory not found: {model_root}")
        return found
    for algo_name in sorted(os.listdir(model_root)):
        algo_dir = os.path.join(model_root, algo_name)
        if not os.path.isdir(algo_dir) or algo_name == "baseline":
            continue
        if algos_whitelist is not None and algo_name not in algos_whitelist:
            continue
        ckpts_dir = os.path.join(algo_dir, "ckpts")
        if os.path.isdir(ckpts_dir):
            found.append((algo_name, ckpts_dir))
        else:
            print(f"⚠️ No ckpts/ found for algo '{algo_name}' at {ckpts_dir}")
    return found


def _load_previous_records(json_out_path, k):
    """Load per-seed records from an earlier run so finished work can be skipped.

    Returns ``(records, completed_keys)``, where each key is
    ``(dataset, algo_type, step, seed, k)``. Old records are reused only when
    the stored ``k`` equals *k*. On a k mismatch or any read/parse failure,
    both are empty, and the old file will be overwritten on the next save.
    """
    if not os.path.isfile(json_out_path):
        return [], set()
    try:
        with open(json_out_path, "r", encoding="utf-8") as f:
            old_payload = json.load(f)
        old_k = old_payload.get("k", k)
        old_per_seed = old_payload.get("per_seed_records", [])

        if old_k != k:
            print(f"⚠️ Existing JSON has k={old_k}, but current config uses k={k}.")
            print("   To avoid mixing different pass@k, ignoring old file and starting fresh.")
            return [], set()

        # Build every key BEFORE accepting any record, so a malformed record
        # cannot leave records and keys half-populated.
        completed = {
            (r["dataset"], r["algo_type"], int(r["step"]), int(r["seed"]), int(k))
            for r in old_per_seed
        }
    except Exception as e:
        print(f"⚠️ Failed to load existing JSON at {json_out_path}: {e}")
        print("   Proceeding as if no previous results exist.")
        return [], set()

    print(f"📂 Loaded existing JSON with {len(old_per_seed)} per-seed records.")
    print(f"   Will skip already evaluated (dataset, algo_type, step, seed, k={k}) combos.")
    return list(old_per_seed), completed


def _load_eval_datasets(requested):
    """Load every known benchmark whose name appears in *requested*.

    A loader that raises yields an empty entry, as does one that returns an
    empty dataset; either case prints a warning. Empty entries are skipped
    later, so one broken benchmark does not abort the run. A bare string is
    treated as a one-element list.
    """
    loaders = {
        "MATH-500": load_math500_dataset,
        "AIME": load_aime_dataset,
        "AMC23": load_amc_dataset,
    }
    if isinstance(requested, str):
        requested = [requested]
    unknown = [name for name in requested if name not in loaders]
    if unknown:
        print(f"⚠️ Unknown dataset name(s) in config, ignored: {unknown}. Known: {list(loaders)}")

    datasets = {}
    for name, loader in loaders.items():
        if name not in requested:
            continue
        try:
            ds = loader()
        except Exception as e:
            print(f"⚠️ Failed to load dataset '{name}': {e}")
            ds = []
        else:
            if not ds:
                print(f"⚠️ Dataset '{name}' is empty.")
        datasets[name] = ds
    return datasets


def _save_results(per_seed_records, k, json_out_path):
    """Aggregate per-seed records and write the results JSON.

    Aggregation is per ``(algo_type, step, dataset)`` and produces mean and
    sample std (ddof=1; 0.0 for a single seed) of pass@1 and pass@k. The file
    is first written to ``<path>.tmp`` and then moved into place with
    ``os.replace``. An interrupted write therefore never corrupts the file
    that resume logic reads.
    """
    key2rows = defaultdict(list)
    for r in per_seed_records:
        key2rows[(r["algo_type"], int(r["step"]), r["dataset"])].append(r)

    aggregates = []
    for (algo, step, ds_name), rows in key2rows.items():
        p1s = np.array([r["pass_at_1"] for r in rows], dtype=float)
        pks = np.array([r["pass_at_k"] for r in rows], dtype=float)
        aggregates.append({
            "dataset": ds_name,
            "algo_type": algo,
            "step": int(step),
            "n": int(len(rows)),
            "pass_at_1_mean": float(p1s.mean()) if len(p1s) else 0.0,
            "pass_at_1_std":  float(p1s.std(ddof=1)) if len(p1s) > 1 else 0.0,
            "pass_at_k_mean": float(pks.mean()) if len(pks) else 0.0,
            "pass_at_k_std":  float(pks.std(ddof=1)) if len(pks) > 1 else 0.0,
        })

    payload = {
        "k": int(k),
        "seeds": sorted({int(r["seed"]) for r in per_seed_records}),
        "datasets": sorted({r["dataset"] for r in per_seed_records}),
        "per_seed_records": sorted(
            per_seed_records,
            key=lambda r: (r["dataset"], r["algo_type"], int(r["step"]), int(r["seed"]))
        ),
        "aggregates": sorted(
            aggregates,
            key=lambda r: (r["dataset"], r["algo_type"], int(r["step"]))
        ),
    }

    dirpath = os.path.dirname(json_out_path)
    if dirpath:
        os.makedirs(dirpath, exist_ok=True)
    tmp_path = json_out_path + ".tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(payload, f, ensure_ascii=False, indent=2)
    os.replace(tmp_path, json_out_path)
    print(f"\n💾 Saved JSON to {json_out_path}")


def eval_per_step(
    model,
    tokenizer,
    cfg,
):
    """Evaluate the baseline adapter and every discovered checkpoint, resumably.

    Layout read (relative to the working directory):
        outputs/models/<model_short>/baseline                  evaluated as step 0
        outputs/models/<model_short>/<algo>/ckpts/<ckpt_dir>   one adapter each

    Results go to ``_get_json_path(cfg)``. They are saved after the baseline,
    after every checkpoint, and once at the end. Any
    ``(dataset, algo_type, step, seed, k)`` combination already present in an
    existing file with the same ``k`` is skipped.

    cfg keys read: model_name, seeds, k, datasets, temperature, max_tokens,
    top_p. Optional keys:
        eval_steps  default True (every checkpoint); False keeps only the
                    highest-step checkpoint.
        algorithms  whitelist of algo directory names.
    """
    model_short = cfg["model_name"].split("/")[-1]    # "Qwen2.5-Math-7B"
    model_root = os.path.join("outputs", "models", model_short)

    baseline_path = os.path.join(model_root, "baseline")
    baseline_label = "baseline"

    seeds = cfg["seeds"]
    k = cfg["k"]
    json_out_path = _get_json_path(cfg)
    eval_steps = cfg.get("eval_steps", True)

    algos_whitelist = cfg.get("algorithms", None)
    if algos_whitelist is not None:
        algos_whitelist = {str(a) for a in algos_whitelist}

    algo_dirs = _discover_ckpt_dirs(model_root, algos_whitelist)
    print(f"📁 Auto-discovered ckpt dirs: {[d for _, d in algo_dirs]}")

    per_seed_records, completed_keys = _load_previous_records(json_out_path, k)
    datasets = _load_eval_datasets(cfg["datasets"])

    def evaluate_adapter(algo_type, ckpt_name, step, lora_path, tag):
        """Run quick_eval for every non-empty dataset × seed not yet completed.

        The adapter is loaded lazily, at most once per call. A checkpoint that
        is already fully evaluated is therefore never loaded, and a failed
        load is retried on the next (dataset, seed). *tag* prefixes log lines.
        """
        lora_request = None
        loaded = False
        for ds_name, ds in datasets.items():
            if not ds:
                continue
            for sd in seeds:
                key = (ds_name, algo_type, int(step), int(sd), int(k))
                if key in completed_keys:
                    print(f"    ↩️ Skipping {tag} / {ds_name} / seed={sd} (already evaluated).")
                    continue
                try:
                    if not loaded:
                        lora_request = model.load_lora(lora_path)
                        loaded = True
                    p1, pk = quick_eval(
                        model,
                        tokenizer,
                        lora_request=lora_request,
                        eval_dataset=ds,
                        n_sampling=k,
                        seed=sd,
                        temperature=cfg['temperature'],
                        max_tokens=cfg['max_tokens'],
                        top_p=cfg['top_p']
                    )
                except Exception as e:
                    print(f"    ⚠️ Error @ {tag} / {ds_name} / seed={sd}: {e}")
                    continue
                per_seed_records.append({
                    "dataset": ds_name,
                    "algo_type": algo_type,
                    "ckpt": ckpt_name,
                    "step": int(step),
                    "seed": int(sd),
                    "pass_at_1": float(p1),
                    "pass_at_k": float(pk),
                })
                completed_keys.add(key)

    if os.path.isdir(baseline_path):
        print(f"\n🧩 Evaluating baseline: {baseline_label}")
        base_ckpt_name = os.path.basename(os.path.normpath(baseline_path))
        evaluate_adapter(baseline_label, base_ckpt_name, 0, baseline_path, baseline_label)
        if per_seed_records:
            _save_results(per_seed_records, k, json_out_path)
    else:
        print(f"\n⚠️ Baseline path not found: {baseline_path}. Skipping baseline.")

    for algo_type, ckpts_dir in algo_dirs:
        print(f"\n🧩 Evaluating algorithm type: {algo_type}")

        ckpt_names = sorted(
            (d for d in os.listdir(ckpts_dir) if os.path.isdir(os.path.join(ckpts_dir, d))),
            key=_parse_step,
        )
        # eval_steps=False keeps only the highest-step checkpoint.
        target_ckpts = ckpt_names if eval_steps else ckpt_names[-1:]

        for ckpt in target_ckpts:
            ckpt_path = os.path.join(ckpts_dir, ckpt)
            step = _parse_step(ckpt)
            if step < 0:
                # Every unparseable name maps to step -1 and so shares resume keys.
                print(f"  ⚠️ Could not parse a step number from '{ckpt}'; recording it as step -1.")
            print(f"  → Loading {ckpt_path}")
            evaluate_adapter(algo_type, ckpt, step, ckpt_path, f"{algo_type} / {ckpt}")
            # Save after every checkpoint so an interrupted run can resume.
            if per_seed_records:
                _save_results(per_seed_records, k, json_out_path)

    if per_seed_records:
        _save_results(per_seed_records, k, json_out_path)
    else:
        print("⚠️ No per-seed records were produced. Nothing to save.")


def main():
    """Script entry point: load the model with vLLM + LoRA support, then evaluate.

    Usage: ``--config path/to/config.yaml``. Config keys read here:
        seed, max_seq_length, lora_rank, model_name, random_state, instruct
        gpu_memory_utilization  optional; default 0.9
    ``eval_per_step`` reads further keys from the same config.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True)
    args = parser.parse_args()
    with open(args.config, "r", encoding="utf-8") as f:
        cfg = yaml.safe_load(f)
    set_seed(cfg["seed"])

    max_seq_length = cfg["max_seq_length"]  # Can increase for longer reasoning traces
    lora_rank = cfg["lora_rank"]  # Larger rank = smarter, but slower

    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=cfg["model_name"],
        max_seq_length=max_seq_length,
        load_in_4bit=False,  # False for LoRA 16bit
        fast_inference=True,  # Enable vLLM fast inference
        max_lora_rank=lora_rank,
        # Reduce (via config) if out of memory; previously hard-coded to 0.9.
        gpu_memory_utilization=cfg.get("gpu_memory_utilization", 0.9),
    )

    model = FastLanguageModel.get_peft_model(
        model,
        r=lora_rank,  # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
        target_modules=[
            "q_proj", "k_proj", "v_proj", "o_proj",
            "gate_proj", "up_proj", "down_proj",
        ],
        lora_alpha=lora_rank * 2,  # *2 speeds up training
        use_gradient_checkpointing="unsloth",  # Reduces memory usage
        random_state=cfg["random_state"],
    )

    reasoning_start = "<start_working_out>"  # Acts as <think>
    reasoning_end   = "<end_working_out>"    # Acts as </think>
    solution_start  = "<SOLUTION>"
    solution_end    = "</SOLUTION>"

    # NOTE(review): the continuation lines of this literal carry 4 leading
    # spaces into the prompt text. Presumably this must match the prompt used
    # at training time; confirm before "fixing" the indentation. The prompt
    # must also not contain a single quote, because it is spliced into a
    # single-quoted Jinja string below.
    system_prompt = \
    f"""You are given a problem.
    Think about the problem and provide your working out.
    Place it between {reasoning_start} and {reasoning_end}.
    Then, provide your solution between {solution_start}{solution_end}"""

    chat_template = \
        "{% if messages[0]['role'] == 'system' %}"\
            "{{ messages[0]['content'] + eos_token }}"\
            "{% set loop_messages = messages[1:] %}"\
        "{% else %}"\
            "{{ '{system_prompt}' + eos_token }}"\
            "{% set loop_messages = messages %}"\
        "{% endif %}"\
        "{% for message in loop_messages %}"\
            "{% if message['role'] == 'user' %}"\
                "{{ message['content'] }}"\
            "{% elif message['role'] == 'assistant' %}"\
                "{{ message['content'] + eos_token }}"\
            "{% endif %}"\
        "{% endfor %}"\
        "{% if add_generation_prompt %}{{ '{reasoning_start}' }}"\
        "{% endif %}"

    # Replace the placeholders with our specific prompt/markers. Only
    # non-instruct models get this template; instruct models keep the
    # tokenizer's own chat template.
    if not cfg["instruct"]:
        chat_template = chat_template\
            .replace("'{system_prompt}'",   f"'{system_prompt}'")\
            .replace("'{reasoning_start}'", f"'{reasoning_start}'")
        tokenizer.chat_template = chat_template

    eval_per_step(
        model=model,
        tokenizer=tokenizer,
        cfg=cfg
    )



if __name__ == "__main__":
    # Run the evaluation only when executed directly, not when imported.
    main()

