import os
import re
import json
import argparse
import random
from typing import Dict, Callable, List

import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig

from utils import extract_answer,  get_dataset_converter, load_hf_dataset_test

import signal
from contextlib import contextmanager
import contextlib
import io
import math


def _set_all_seeds(seed: int):
    random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
import re

def extract_answer(dataset_name: str, text: str) -> str:
    r"""Extract a normalised final answer from a model prediction or gold solution.

    Strategies are tried in order and the first hit wins:

    1. Contents of the first ``\boxed{...}`` that contains no nested braces.
    2. The number following ``####`` or "the answer is" (case-insensitive).
    3. A number at the very end of the text.
    4. The first ``\frac{a}{b}`` with integer parts, rendered as ``"a/b"``.
    5. Otherwise the whitespace-stripped text itself.

    Thousands separators (commas) are removed from numeric answers. A period
    is only treated as part of a number when digits follow it, so a
    sentence-ending "." is not captured ("The answer is 42." -> "42").

    NOTE: this module-level definition shadows the ``extract_answer`` imported
    from ``utils`` at the top of the file, and it never raises.

    Args:
        dataset_name: Unused; accepted so callers can pass the dataset name.
        text: Prediction or gold text to extract from.

    Returns:
        The extracted answer string (never ``None``).
    """
    boxed = re.search(r"\\boxed\s*\{([^{}]+)\}", text)
    if boxed:
        return boxed.group(1).strip()

    # ``(?:\.\d+)?`` only accepts a decimal point that is followed by digits.
    marked = re.search(r"(?:####|the answer is:?)\s*([-+]?\d[\d,]*(?:\.\d+)?)", text, re.IGNORECASE)
    if marked:
        return marked.group(1).replace(",", "").strip()

    # Allow one trailing period (end of sentence) after the final number.
    trailing = re.search(r"([-+]?\d[\d,]*(?:\.\d+)?)\.?\s*$", text)
    if trailing:
        return trailing.group(1).replace(",", "").strip()

    frac = re.search(r"\\frac\s*\{\s*([-+]?\d+)\s*\}\s*\{\s*([-+]?\d+)\s*\}", text)
    if frac:
        return f"{int(frac.group(1))}/{int(frac.group(2))}"

    return text.strip()

def _normalize_answer(text: str) -> str:
    return text.lower().strip().strip(".")

def check_mcqa_answer(prediction: str, gold: str) -> bool:
    """Loosely decide whether a multiple-choice *prediction* selects *gold*.

    Both strings are normalised with ``_normalize_answer``. An empty gold
    answer never matches. Otherwise the prediction matches when it equals the
    gold answer, or starts with "<gold>.", "<gold> " or "(<gold>)".
    """
    pred = _normalize_answer(prediction)
    key = _normalize_answer(gold)
    if not key:
        return False
    accepted_prefixes = (f"{key}.", f"{key} ", f"({key})")
    return pred.startswith(accepted_prefixes) or pred == key

def get_metric_function(dataset_name: str) -> Callable[[str, str, Dict], bool]:
    """Return a ``metric(pred, gold, example) -> bool`` correctness check for *dataset_name*.

    * ``gsm8k`` / ``metamathqa``: extract the final answer of both texts with
      ``extract_answer`` and compare them numerically (absolute tolerance
      1e-6). Integers, decimals and ``"a/b"`` fractions are understood.
      Anything unparsable counts as incorrect.
    * ``arc_c`` / ``arc_e`` / ``obqa``: multiple-choice match against the gold
      text, or against the example's ``answerKey`` field.
    * any other dataset: multiple-choice match against the gold text.
    """
    if dataset_name in ("gsm8k", "metamathqa"):
        from fractions import Fraction

        def _to_number(answer):
            # Fraction parses "42", "-3.5", "1e3" and "1/2" (the form
            # extract_answer produces for \frac); anything else -> no answer.
            try:
                return float(Fraction(answer))
            except (TypeError, ValueError, ZeroDivisionError, OverflowError):
                return None

        def math_metric(pred, gold, ex):
            # extract_answer takes (dataset_name, text) and returns a string,
            # so convert to numbers before taking the difference.
            pn = _to_number(extract_answer(dataset_name, pred))
            gn = _to_number(extract_answer(dataset_name, gold))
            return pn is not None and gn is not None and abs(pn - gn) < 1e-6

        return math_metric
    if dataset_name in ("arc_c", "arc_e", "obqa"):
        def arc_metric(pred, gold, ex):
            return check_mcqa_answer(pred, gold) or check_mcqa_answer(pred, ex.get("answerKey", ""))
        return arc_metric
    return lambda pred, gold, ex: check_mcqa_answer(pred, gold)


class TimeoutException(Exception):
    """Raised by the SIGALRM handler in ``time_limit`` when the time budget runs out."""

@contextmanager
def time_limit(seconds):
    def signal_handler(signum, frame):
        raise TimeoutException("Timed out!")
    signal.signal(signal.SIGALRM, signal_handler)
    signal.alarm(seconds)
    try:
        yield
    finally:
        signal.alarm(0)

def check_mbpp_code(code: str, test_list: list[str], timeout_seconds: int = 5) -> bool:
    """Return True if the function extracted from *code* passes all MBPP assertions.

    The longest regex match for a ``def ...:`` block is taken as the solution.
    The ``test_list`` assertions are appended to it, and the result is
    executed under ``time_limit(timeout_seconds)`` with stdout/stderr captured.

    NOTE(review): the pattern is matched with ``re.DOTALL``, so ``.*`` also
    spans newlines. In practice the match then typically runs from the first
    ``def`` to the end of *code*. Code before that ``def`` (e.g. imports) is
    dropped, while trailing prose after the code is kept and will usually
    cause a ``SyntaxError``. Confirm this is intended.

    Args:
        code: Raw model completion.
        test_list: MBPP assertion statements.
        timeout_seconds: Wall-clock budget for executing solution + tests.

    Returns:
        True when execution finishes without raising. False when no ``def`` is
        found, or when execution fails, times out or calls ``exit()``.
    """
    code_blocks = re.findall(r'def\s.*?:\n(?:\s+.*\n?)+', code, re.DOTALL)
    if not code_blocks:
        return False
    main_code = max(code_blocks, key=len)

    full_code = main_code + "\n" + "\n".join(test_list)

    # SECURITY: this exec()s untrusted model output inside the evaluator
    # process with its full privileges; only a wall-clock limit is applied.
    # Use an isolated sandbox when evaluating untrusted models.
    try:
        with time_limit(timeout_seconds):
            with io.StringIO() as buf, contextlib.redirect_stdout(buf), contextlib.redirect_stderr(buf):
                exec(full_code, {})
    except (Exception, SystemExit):
        # Exception covers failed asserts, runtime errors and TimeoutException.
        # SystemExit covers generated code calling exit()/sys.exit(), which
        # would otherwise abort the whole evaluation run.
        return False
    return True
    
def estimate_pass_at_k(num_samples: int, num_correct: int, k: int) -> float:
    """Unbiased pass@k estimate from *num_samples* samples, *num_correct* of them passing.

    Computes ``1 - C(n - c, k) / C(n, k)`` as the product
    ``1 - prod_{d = n-c+1}^{n} (1 - k / d)``, which avoids huge binomials.
    Returns 1.0 when fewer than *k* samples fail, because then every size-*k*
    draw contains a passing sample.
    """
    num_wrong = num_samples - num_correct
    if num_wrong < k:
        return 1.0
    prob_all_fail = 1.0
    # Denominators n, n-1, ..., n-c+1 -- the same order the product was taken in.
    for denom in range(num_samples, num_wrong, -1):
        prob_all_fail *= 1.0 - k / denom
    return 1.0 - prob_all_fail

def check_humaneval_code(
    prompt: str,
    generated_code: str,
    test: str,
    entry_point: str,
    timeout_seconds: int = 5
) -> bool:
    """Return True if *generated_code* passes the HumanEval test harness.

    The completion, the benchmark's *test* source and a ``check(<entry_point>)``
    call are concatenated and executed under ``time_limit(timeout_seconds)``
    with stdout/stderr captured.

    Args:
        prompt: Unused. The completion is executed without it, so it must
            contain the complete function definition. Kept for interface
            compatibility.
        generated_code: Raw model completion.
        test: HumanEval test source; expected to define ``check``.
        entry_point: Name of the function under test.
        timeout_seconds: Wall-clock budget for the whole execution.

    Returns:
        True when execution finishes without raising. False on failed
        assertions, errors, timeouts or ``exit()`` calls.
    """
    # Completions usually start with a space/newline left over from the
    # "Assistant:" prompt. Strip all leading whitespace rather than blindly
    # dropping the first character, which corrupted completions that begin
    # directly with code (e.g. "def ..." -> "ef ...").
    full_code = generated_code.lstrip() + '\n' + test + f'\ncheck({entry_point})'

    # SECURITY: this exec()s untrusted model output inside the evaluator
    # process with its full privileges; only a wall-clock limit is applied.
    try:
        with time_limit(timeout_seconds):
            with io.StringIO() as buf, contextlib.redirect_stdout(buf), contextlib.redirect_stderr(buf):
                exec(full_code, {})
    except (Exception, SystemExit):
        # SystemExit: generated code calling exit()/sys.exit() must count as a
        # failure instead of terminating the evaluation run.
        return False
    return True

# Number of sampled completions per problem for the code benchmarks. The
# reported problem accuracy is the empirical pass@_PASS_AT_K.
_PASS_AT_K = 10


def _load_model_and_tokenizer(args: argparse.Namespace, device: str):
    """Load the tokenizer and the model to evaluate, based on the checkpoint layout.

    * ``<checkpoint>/base_lora`` is a directory: load ``args.model_name`` in
      bf16 (local files only). Wrap it in a PEFT model with the ``base_lora``
      adapter, load every other sub-directory of the checkpoint as an extra
      adapter (sorted by name), and activate all of them.
    * ``<checkpoint>/model.safetensors.index.json`` exists: load the full
      fine-tuned weights from the checkpoint in bf16 with ``device_map="auto"``.
    * otherwise: evaluate the plain base model ``args.model_name``
      (default dtype).

    The tokenizer always comes from ``args.model_name``. It falls back to EOS
    as the pad token and pads on the left, as is usual for batched generation
    with decoder-only models.

    Returns:
        ``(model, tokenizer)`` with the model in eval mode.
    """
    peft_dir = os.path.join(args.checkpoint_path, "base_lora")
    fft_index = os.path.join(args.checkpoint_path, "model.safetensors.index.json")

    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    if os.path.isdir(peft_dir):
        model = AutoModelForCausalLM.from_pretrained(
            args.model_name, torch_dtype=torch.bfloat16, local_files_only=True
        ).to(device)
        from peft import PeftModel

        base_adapter_name = "base_lora"
        subdirs = [
            d for d in os.listdir(args.checkpoint_path)
            if os.path.isdir(os.path.join(args.checkpoint_path, d))
        ]
        growth_adapters = sorted(d for d in subdirs if d != base_adapter_name)

        print(f"Loading base PeftModel from: {peft_dir}")
        model = PeftModel.from_pretrained(model, peft_dir, adapter_name=base_adapter_name).to(device)

        if growth_adapters:
            print("Loading additional growth adapters...")
            for adapter_name in growth_adapters:
                model.load_adapter(os.path.join(args.checkpoint_path, adapter_name), adapter_name=adapter_name)
                print(f"  - Loaded: {adapter_name}")
        # The base adapter is always present, so this list is never empty.
        model.base_model.set_adapter([base_adapter_name] + growth_adapters)
    elif os.path.exists(fft_index):
        model = AutoModelForCausalLM.from_pretrained(
            args.checkpoint_path, torch_dtype=torch.bfloat16, device_map="auto"
        )
    else:
        model = AutoModelForCausalLM.from_pretrained(args.model_name).to(device)

    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"

    model.eval()
    return model, tokenizer


def _code_sample_passes(dataset: str, batch: Dict, idx: int, code: str) -> bool:
    """Run the tests of problem ``idx`` in ``batch`` against ``code`` (MBPP or HumanEval)."""
    if dataset == "mbpp":
        return check_mbpp_code(code, batch["test_list"][idx])
    return check_humaneval_code(batch["prompt"][idx], code, batch["test"][idx], batch["entry_point"][idx])


def _evaluate_code_dataset(args: argparse.Namespace, ds, model, tokenizer, device: str, k: int):
    """Sample ``k`` completions per MBPP/HumanEval problem and execute their tests.

    Returns:
        ``(correct_problems, samples, total_samples_generated, total_samples_passed)``.
        A problem counts as correct when at least one of its ``k`` samples passes.
    """
    correct = 0
    samples = []
    total_generated = 0
    total_passed = 0

    desc = f"Evaluating {args.dataset} (pass@{k} with batch_size={args.batch_size})"
    for start in tqdm(range(0, len(ds), args.batch_size), desc=desc):
        batch = ds[start : start + args.batch_size]
        prompts = [f"User: {messages[0]['content']}\n\nAssistant:" for messages in batch["messages"]]
        inputs = tokenizer(
            prompts, return_tensors="pt", padding=True, truncation=True, max_length=args.max_length
        ).to(device)

        with torch.no_grad():
            out = model.generate(
                **inputs,
                max_new_tokens=args.max_new_tokens,
                do_sample=True,
                temperature=0.7,
                pad_token_id=tokenizer.pad_token_id,
                num_return_sequences=k,
            )

        # Keep only the newly generated tokens. With left padding, all prompts
        # in the batch end at the same column.
        gen_tokens = out[:, inputs["input_ids"].shape[1]:]
        # generate() returns the k samples of each prompt as consecutive rows,
        # so they regroup as (num_prompts, k, new_len).
        per_problem = gen_tokens.reshape(len(prompts), k, -1)

        for idx, problem_outputs in enumerate(per_problem):
            generated_codes = tokenizer.batch_decode(problem_outputs, skip_special_tokens=True)
            # Execute each sample exactly once; the verdicts are reused below
            # instead of re-running the tests to find a passing example.
            verdicts = [_code_sample_passes(args.dataset, batch, idx, code) for code in generated_codes]
            num_passing_samples = sum(verdicts)
            is_correct = num_passing_samples > 0

            if is_correct:
                correct += 1
            total_passed += num_passing_samples
            total_generated += k

            log_entry = {
                "prompt": prompts[idx],
                "is_correct": is_correct,
                "num_passing_samples": num_passing_samples,
                "all_generated_codes": generated_codes,
            }
            if args.dataset == "mbpp":
                log_entry["gold_code_example"] = batch["messages"][idx][1]["content"]
                log_entry["test_list"] = batch["test_list"][idx]
            else:
                log_entry["gold_code_example"] = batch["canonical_solution"][idx]
                log_entry["test"] = batch["test"][idx]
            log_entry["successful_code_example"] = next(
                (code for code, passed in zip(generated_codes, verdicts) if passed), ""
            )
            samples.append(log_entry)

    return correct, samples, total_generated, total_passed


def _message_content(message) -> str:
    """Return ``message["content"]`` for a chat-message dict, else ``str(message)``."""
    return message.get("content", "") if isinstance(message, dict) else str(message)


def _safe_extract(ds_name: str, text: str, fallback: str = "") -> str:
    """Call ``extract_answer``, using the raw text if it raises ``NotImplementedError``."""
    try:
        return extract_answer(ds_name, text)
    except NotImplementedError:
        print(f"Warning: extract_answer not implemented for dataset '{ds_name}'. Using raw text.")
        return text or fallback


def _evaluate_text_dataset(args: argparse.Namespace, ds, model, tokenizer, device: str, metric_fn):
    """Greedy-decode one answer per example and score it (no scoring for ``prefeval``).

    Returns:
        ``(correct, samples)``.
    """
    correct = 0
    samples = []
    for start in tqdm(range(0, len(ds), args.batch_size), desc=f"Evaluating {args.dataset}"):
        batch = ds[start : start + args.batch_size]
        prompts: List[str] = []
        for msgs in batch["messages"]:
            if isinstance(msgs, (list, tuple)) and len(msgs) > 0:
                input_text = _message_content(msgs[0])
            else:
                input_text = str(msgs)
            prompts.append(f"User: {input_text}\n\nAssistant:")
        inputs = tokenizer(
            prompts, return_tensors="pt", padding=True, truncation=True, max_length=args.max_length
        ).to(device)

        with torch.no_grad():
            out = model.generate(
                **inputs, max_new_tokens=args.max_new_tokens, eos_token_id=tokenizer.eos_token_id, do_sample=False
            )

        preds = tokenizer.batch_decode(out[:, inputs["input_ids"].shape[1]:], skip_special_tokens=True)

        for i, pred in enumerate(preds):
            ex = {key: batch[key][i] for key in batch.keys()}

            # Gold answer: the second chat message if present, else the first
            # non-empty of the answer/label/target columns.
            messages = ex.get("messages", [])
            if isinstance(messages, (list, tuple)) and len(messages) > 1:
                gold_text = _message_content(messages[1])
            else:
                gold_text = ex.get("answer", "") or ex.get("label", "") or ex.get("target", "")

            pred_norm = _safe_extract(args.dataset, pred or "", fallback="")
            gold_norm = _safe_extract(args.dataset, gold_text or "", fallback="")

            if args.dataset == "prefeval":
                samples.append({
                    "prompt": prompts[i],
                    "pref_generation": "response_to_pref",
                    "output": pred,
                })
                continue

            # Exact match on extracted answers; the dataset metric is only
            # consulted when either extraction came back empty.
            if pred_norm and gold_norm:
                ok = pred_norm == gold_norm
            else:
                ok = bool(metric_fn(pred, gold_text, ex))
            if ok:
                correct += 1
            samples.append({
                "prompt": prompts[i],
                "output": pred,
                "prediction": pred_norm if pred_norm else (pred or "").strip(),
                "gold": gold_norm if gold_norm else (gold_text or ""),
                "is_correct": bool(ok),
            })

    return correct, samples


def run_evaluation(args: argparse.Namespace):
    """Evaluate a checkpoint on ``args.dataset`` and write a JSON report.

    MBPP/HumanEval are scored by executing ``_PASS_AT_K`` sampled completions
    per problem. All other datasets use one greedy completion per example.
    The report is written to ``evaluation_results_<dataset>.json`` inside
    ``args.checkpoint_path`` when it is a directory, else in the working
    directory.

    Raises:
        ValueError: if ``args.checkpoint_path`` is None or does not exist.
    """
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_id)
    _set_all_seeds(args.seed)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    dataset = load_hf_dataset_test(args.dataset)
    converter = get_dataset_converter(args.dataset)
    ds = dataset.map(converter, load_from_cache_file=False, desc="rebuild cache without 'List'")

    print(f"Start eval | base: {args.model_name} | ckpt: {args.checkpoint_path} | dataset: {args.dataset}")
    if args.checkpoint_path is None or not os.path.exists(args.checkpoint_path):
        raise ValueError("checkpoint_path is required and must exist.")

    # Not populated by any loading path; kept so the report schema is unchanged.
    missing_keys_list: List[str] = []
    unexpected_keys_list: List[str] = []

    model, tokenizer = _load_model_and_tokenizer(args, device)

    if 0 < args.sample_num < len(ds):
        idxs = random.sample(range(len(ds)), args.sample_num)
        ds = ds.select(idxs)

    print("Total samples:", len(ds))
    metric_fn = get_metric_function(args.dataset)
    num_problems = len(ds)

    if args.dataset in ("mbpp", "humaneval"):
        k = _PASS_AT_K
        correct, samples, total_generated, total_passed = _evaluate_code_dataset(
            args, ds, model, tokenizer, device, k
        )
        problem_accuracy = (correct / num_problems * 100) if num_problems > 0 else 0.0
        sample_accuracy = (total_passed / total_generated * 100) if total_generated > 0 else 0.0
        result = {
            "base_model": args.model_name,
            "checkpoint_path": args.checkpoint_path,
            "dataset": args.dataset,
            "num_problems": num_problems,
            f"problem_accuracy(pass_at_{k})": f"{problem_accuracy:.2f}%",
            "sample_accuracy": f"{sample_accuracy:.2f}%",
            "correct_problems": correct,
            "total_problems": num_problems,
            "total_samples_generated": total_generated,
            "total_samples_passed": total_passed,
            "samples": samples,
        }
        print(f"Problem Accuracy: {problem_accuracy:.2f}%")
        print(f"Sample Accuracy: {sample_accuracy:.2f}%")
    else:
        correct, samples = _evaluate_text_dataset(args, ds, model, tokenizer, device, metric_fn)
        accuracy = (correct / num_problems * 100) if num_problems > 0 else 0.0
        result = {
            "base_model": args.model_name,
            "checkpoint_path": args.checkpoint_path,
            "dataset": args.dataset,
            "accuracy": f"{accuracy:.2f}%",
            "correct": correct,
            "total": num_problems,
            "missing_keys_count": len(missing_keys_list),
            "unexpected_keys_count": len(unexpected_keys_list),
            "unexpected_keys": unexpected_keys_list,
            "samples": samples,
        }
        print(f"Accuracy: {accuracy:.2f}%")

    # When checkpoint_path is a file (e.g. a single weights file), joining the
    # file name onto it would make open() fail, so fall back to the cwd.
    out_dir = args.checkpoint_path if os.path.isdir(args.checkpoint_path) else os.getcwd()
    os.makedirs(out_dir, exist_ok=True)
    out_file = os.path.join(out_dir, f"evaluation_results_{args.dataset}.json")
    with open(out_file, "w", encoding="utf-8") as f:
        json.dump(result, f, indent=2, ensure_ascii=False)

    print("Done. Results saved to", out_file)
    
    
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Evaluate a base model or fine-tuned checkpoint on a test dataset.")
    parser.add_argument("--model_name", type=str, required=True, help="base model (e.g. facebook/opt-1.3b)")
    parser.add_argument("--dataset", type=str, required=True, help="dataset name")
    # run_evaluation raises ValueError when this is missing or does not exist,
    # so the help text must not advertise it as optional.
    parser.add_argument(
        "--checkpoint_path", type=str, default=None,
        help="checkpoint dir or weights file (required; must exist)",
    )
    parser.add_argument(
        "--sample_num", type=int, default=-1,
        help="evaluate a random subset of this size (<= 0 or >= dataset size: use all)",
    )
    parser.add_argument("--batch_size", type=int, default=100, help="prompts per generation batch")
    parser.add_argument("--max_length", type=int, default=512, help="max prompt tokens (longer prompts are truncated)")
    parser.add_argument("--max_new_tokens", type=int, default=512, help="max generated tokens per completion")
    parser.add_argument("--seed", type=int, default=0, help="seed for Python and PyTorch RNGs")
    parser.add_argument("--gpu_id", type=int, default=0, help="value exported as CUDA_VISIBLE_DEVICES")
    args = parser.parse_args()

    run_evaluation(args)
