#!/usr/bin/env python
import os
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"

import re
import ast
import pandas as pd
import logging
import traceback
import yaml  # Import YAML parsing library
from vllm import LLM, SamplingParams

# Root logging setup: INFO and above, each record prefixed with a timestamp
# and its severity level.
_LOG_FORMAT = "%(asctime)s - %(levelname)s - %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Module-level logger shared by every function in this script.
logger = logging.getLogger(__name__)

def load_config(config_path: str = "config_existing_model.yaml") -> dict:
    """Load the YAML config file and check that it declares models.

    Args:
        config_path: Path of the YAML file to read.

    Returns:
        The parsed config mapping; its "models" entry is guaranteed to be
        present and non-empty.

    Raises:
        FileNotFoundError: If ``config_path`` does not exist.
        ValueError: If the document is empty, is not a mapping, or has no
            (or an empty) "models" entry.
        Exception: Any other read/parse error is logged and re-raised as is.
    """
    # Keep the try body limited to the I/O and parsing that can actually fail.
    try:
        with open(config_path, "r", encoding="utf-8") as f:
            config = yaml.safe_load(f)
    except FileNotFoundError:
        logger.error(f"Config file not found: {config_path}")
        raise
    except Exception as e:
        logger.error(f"Failed to load YAML config: {e}")
        raise

    # safe_load yields None for an empty file and a scalar/list for non-mapping
    # documents; calling .get() on those would raise a confusing
    # AttributeError, so treat them the same as a missing 'models' field.
    if not isinstance(config, dict) or not config.get("models"):
        err = ValueError("YAML file missing 'models' field")
        logger.error(f"Failed to load YAML config: {err}")
        raise err

    logger.info(f"Successfully loaded config file {config_path}")
    return config

def load_data(data_path: str = "test.tsv") -> pd.DataFrame:
    """Read the tab-separated test set and validate its structure.

    Every column is read as a string. Rows whose ``possible_answers`` value
    is not a valid Python literal are reported with a warning but kept in the
    returned frame.

    Args:
        data_path: Path of the TSV file to read. Defaults to ``test.tsv`` in
            the current directory.

    Returns:
        The loaded DataFrame, containing at least the columns ``question``,
        ``possible_answers``, ``prop``, ``o_pop`` and ``s_pop``.

    Raises:
        FileNotFoundError: If ``data_path`` does not exist.
        ValueError: If any required column is absent.
        Exception: Any other read/parse error is logged and re-raised as is.
    """
    try:
        df = pd.read_csv(data_path, sep="\t", dtype=str)
        logger.info(f"Successfully loaded {data_path}")
    except FileNotFoundError:
        logger.error(f"{data_path} not found in current directory")
        raise
    except Exception as e:
        logger.error(f"Failed to load {data_path}: {e}")
        raise

    required_cols = ["question", "possible_answers", "prop", "o_pop", "s_pop"]
    missing_cols = [c for c in required_cols if c not in df.columns]
    if missing_cols:
        logger.error(f"{data_path} missing required columns, please include: " + ",".join(required_cols))
        # Name the absent columns so the failure can be fixed without
        # inspecting the file by hand.
        raise ValueError("Missing required columns: " + ", ".join(missing_cols))

    # Validate possible_answers format. Only this column is needed, so iterate
    # it directly instead of materializing every row with iterrows().
    for idx, pa in df["possible_answers"].items():
        try:
            ast.literal_eval(pa)
        except Exception:
            # Also reached for missing cells, which pandas reads as NaN.
            logger.warning(f"Row {idx} has malformed possible_answers: {pa}")
    return df

# Matches a possessive "'s" that ends at a word boundary.
_POSSESSIVE_SUFFIX = re.compile(r"'s\b")


def clean_text(text: str) -> str:
    """Normalize text for answer matching.

    Lowercases the input, trims surrounding whitespace, and then removes
    every possessive "'s" suffix.
    """
    lowered = text.lower()
    trimmed = lowered.strip()
    return _POSSESSIVE_SUFFIX.sub("", trimmed)

def is_correct_answer(model_output: str, possible_answers) -> bool:
    """Return True if any accepted answer occurs in the model output.

    Both sides are normalized with ``clean_text`` and compared by substring
    containment.

    Args:
        model_output: Raw text generated by the model.
        possible_answers: Collection of accepted answers. A single bare value
            (e.g. one string) is treated as a one-element collection, and
            ``None`` as "no accepted answers".

    Returns:
        True when at least one non-empty normalized answer is a substring of
        the normalized model output, otherwise False.
    """
    if possible_answers is None:
        return False
    # A bare string would otherwise be iterated character by character, making
    # almost any output count as correct; a non-iterable would raise.
    if isinstance(possible_answers, str) or not hasattr(possible_answers, "__iter__"):
        possible_answers = [possible_answers]

    mo = clean_text(model_output)
    for ans in possible_answers:
        if ans is None:
            continue
        # Parsed literals may contain numbers; compare their textual form.
        needle = clean_text(str(ans))
        # "" is a substring of every string, so an empty answer must never
        # count as a match.
        if needle and needle in mo:
            return True
    return False

def _build_demos(df: pd.DataFrame):
    """Select up to three few-shot demos per ``prop`` value.

    Returns:
        A pair ``(demo_dict, demo_questions)``: ``demo_dict`` maps each prop to
        a list of ``(question, answer)`` tuples taken from the first three rows
        of that prop's group, and ``demo_questions`` is the set of all question
        texts used as demos.
    """
    demo_dict = {}
    demo_questions = set()
    for prop, group in df.groupby("prop"):
        demos = []
        for _, r in group.head(3).iterrows():
            try:
                lst = ast.literal_eval(r["possible_answers"])
                demo_ans = lst[0] if isinstance(lst, (list, tuple)) else str(lst)
            except Exception:
                # Unparseable value or empty answer list: keep the question as
                # a demo, but with a blank answer.
                demo_ans = ""
            demos.append((r["question"], demo_ans))
            demo_questions.add(r["question"])
        demo_dict[prop] = demos
    return demo_dict, demo_questions


def _build_prompt(prop, demos, question) -> str:
    """Assemble the few-shot prompt: instructions, demo Q/A pairs, then the question."""
    header = f"I am preparing training data for a large language model to enhance its ability to extract knowledge from Wikipedia paragraphs. The goal is to teach the model to identify and understand relationships between entities (e.g., {prop}) rather than memorizing text. Below are example questions and answers related to {prop}:\n\n"
    shots = "".join(f"Q: {dq}\nA: {da}\n\n" for dq, da in demos)
    return f"{header}{shots}Q: {question}\nA:"


def evaluate_model(model_name: str, df: pd.DataFrame, sampling_params=None):
    """Run few-shot QA evaluation for one model and write ``output.txt``.

    The report is written to ``./<model_short>/output.txt``, where
    ``model_short`` is ``model_name`` with the "EleutherAI/" and "Qwen/"
    prefixes removed. Questions that serve as few-shot demos, and rows whose
    ``possible_answers`` cannot be parsed, are skipped.

    All prompts are submitted to vLLM in a single ``generate`` call so the
    engine can batch them; results are written only after generation
    finishes.

    Args:
        model_name: Model identifier passed to ``vllm.LLM``.
        df: Test data with columns ``question``, ``possible_answers``,
            ``prop``, ``o_pop`` and ``s_pop``.
        sampling_params: Optional sampling settings forwarded to
            ``LLM.generate``. ``None`` (the default) leaves vLLM's default
            sampling in effect, as when no sampling arguments are given.

    Returns:
        None. Any exception is logged with its traceback and suppressed so
        that the caller can continue with the next model.
    """
    import gc  # local import: only needed for the cleanup in ``finally``

    llm = None
    try:
        model_short = model_name.replace("EleutherAI/", "").replace("Qwen/", "")
        out_dir = f"./{model_short}"
        os.makedirs(out_dir, exist_ok=True)

        # Load vLLM model
        llm = LLM(
            model=model_name,
            trust_remote_code=True,
        )

        demo_dict, demo_questions = _build_demos(df)

        # Collect every prompt first so generation can happen in one batch.
        pending = []  # (row, parsed possible_answers, prompt)
        for _, row in df.iterrows():
            q = row["question"]
            # Skip demo questions
            if q in demo_questions:
                continue
            # Parse possible_answers
            try:
                pa = ast.literal_eval(row["possible_answers"])
            except Exception:
                logger.warning(f"Parse failure: {row['possible_answers']}")
                continue
            prop = row["prop"]
            prompt = _build_prompt(prop, demo_dict.get(prop, []), q)
            pending.append((row, pa, prompt))

        fout_path = os.path.join(out_dir, "output.txt")
        with open(fout_path, "w", encoding="utf-8") as fout:
            fout.write(f"Model: {model_name}\n===\n\n")

            # Generate with vLLM (one batched call instead of one per question)
            responses = []
            if pending:
                responses = llm.generate(
                    [prompt for _, _, prompt in pending],
                    sampling_params=sampling_params,
                )
            # Results are paired with prompts by position, so a length
            # mismatch would silently misattribute answers.
            if len(responses) != len(pending):
                raise RuntimeError(
                    f"Expected {len(pending)} generations, got {len(responses)}"
                )

            for (row, pa, _), response in zip(pending, responses):
                out_text = response.outputs[0].text.strip()
                corr = is_correct_answer(out_text, pa)

                fout.write(f"Q: {row['question']}\nOut: {out_text}\nAns: {pa}\nCorr: {corr}\n")
                fout.write(f"o_pop: {row['o_pop']}\n")
                fout.write(f"s_pop: {row['s_pop']}\n")
                fout.write("-"*20 + "\n")
    except Exception:
        logger.error(f"Model {model_name} crashed:\n" + traceback.format_exc())
    finally:
        # Best effort: drop the engine reference on success and failure alike,
        # then collect reference cycles so the model can be freed before the
        # next one is loaded.
        llm = None
        gc.collect()

def main():
    """Evaluate every model listed in the YAML config against the test set.

    Returns:
        0 when the config and data were loaded and the model loop ran,
        1 when loading the config or the data failed (nothing is evaluated).
    """
    try:
        # Load YAML config
        config = load_config()
        models = config["models"]
        df = load_data()
    except Exception:
        # The loaders log the specific cause before re-raising. Catching
        # Exception rather than everything lets Ctrl-C and SystemExit
        # propagate instead of being silently discarded.
        logger.error("Startup failed; no models were evaluated.")
        return 1

    for m in models:
        logger.info(f"=== Evaluating {m} ===")
        try:
            evaluate_model(m, df)
        except Exception as e:
            logger.error(f"Model {m} fails: {e}")

    print("All done. See output.txt in model directories.")
    return 0


if __name__ == "__main__":
    # Propagate main()'s status so a failed startup yields a non-zero exit code.
    raise SystemExit(main())