#!/usr/bin/env python
import os
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"

import re
import ast
import pandas as pd
import logging
import traceback
import yaml  # Import YAML parsing library
from model import GPT, inference  # Import the new GPT and inference functions

# Logging setup: emit INFO and above as "<timestamp> - <level> - <message>".
logging.basicConfig(format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO)
# Module-level logger shared by every function in this script.
logger = logging.getLogger(__name__)

def load_config(config_path="config.yaml"):
    """Load and validate the YAML configuration file.

    Args:
        config_path: Path to the YAML file. Defaults to ``config.yaml``
            relative to the current working directory.

    Returns:
        The parsed configuration mapping, containing non-empty ``models``
        and ``out_dir`` entries.

    Raises:
        FileNotFoundError: If ``config_path`` does not exist.
        ValueError: If the top-level YAML value is not a mapping, or if
            ``models`` or ``out_dir`` is missing or empty.
        Exception: Any other error raised while opening or parsing the file
            is logged and re-raised unchanged.
    """
    try:
        with open(config_path, "r", encoding="utf-8") as f:
            config = yaml.safe_load(f)
        # safe_load yields None for an empty document and a list/scalar for
        # non-mapping YAML; calling .get() on those would surface as a
        # confusing AttributeError instead of a validation error.
        if not isinstance(config, dict):
            raise ValueError(
                "YAML file must contain a top-level mapping with 'models' and 'out_dir' fields"
            )
        if not config.get("models") or not config.get("out_dir"):
            raise ValueError("YAML file is missing 'models' or 'out_dir' fields")
        logger.info(f"Successfully loaded configuration file {config_path}")
        return config
    except FileNotFoundError:
        logger.error(f"Configuration file {config_path} not found")
        raise
    except Exception as e:
        logger.error(f"Failed to load YAML configuration file: {e}")
        raise

def load_data():
    """Read ``test.tsv`` from the current directory and sanity-check it.

    Every column is read as a string. The ``possible_answers`` column is
    expected to hold Python literals; rows that fail to parse are only
    reported with a warning, they are not removed.

    Returns:
        The loaded DataFrame, with all original rows.

    Raises:
        FileNotFoundError: If ``test.tsv`` does not exist.
        ValueError: If any of the required columns is missing.
        Exception: Any other read/parse error from pandas is logged and
            re-raised unchanged.
    """
    try:
        df = pd.read_csv("test.tsv", sep="\t", dtype=str)
        logger.info("Successfully loaded test.tsv")
    except FileNotFoundError:
        logger.error("test.tsv file not found in the current directory")
        raise
    except Exception as e:
        logger.error(f"Failed to load test.tsv: {e}")
        raise

    required_cols = ["question", "possible_answers", "prop", "o_pop", "s_pop"]
    if not all(c in df.columns for c in required_cols):
        logger.error("test.tsv is missing required columns, please include " + ",".join(required_cols))
        raise ValueError("Missing required columns")

    # Check possible_answers. Only the parse errors literal_eval is documented
    # to raise are caught, so Ctrl-C and SystemExit are no longer swallowed.
    # Missing cells are NaN (a float) and are reported via ValueError.
    for idx, pa in df["possible_answers"].items():
        try:
            ast.literal_eval(pa)
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
            logger.warning(f"Row {idx} possible_answers has format issues: {pa}")
    return df

def clean_text(text: str) -> str:
    """Normalize text for answer matching.

    The input is lower-cased, stripped of surrounding whitespace, and every
    possessive ``'s`` that ends at a word boundary is dropped.
    """
    normalized = text.lower().strip()
    possessive = re.compile(r"'s\b")
    return possessive.sub("", normalized)

def is_correct_answer(model_output: str, possible_answers) -> bool:
    """Return True if any accepted answer occurs in the model output.

    Both sides are normalized with ``clean_text`` and compared by substring
    containment.

    Args:
        model_output: Raw text generated by the model.
        possible_answers: Accepted answers. Normally a list/tuple of strings;
            a single scalar (e.g. one bare string) is treated as a
            one-element list, and ``None`` means "no accepted answers".

    Returns:
        True when at least one non-empty normalized answer is a substring of
        the normalized output, otherwise False.
    """
    if possible_answers is None:
        return False
    # A bare string would otherwise be iterated character by character, and
    # almost any single character appears in the output (false positive).
    if isinstance(possible_answers, str) or not hasattr(possible_answers, "__iter__"):
        possible_answers = [possible_answers]

    mo = clean_text(model_output)
    for ans in possible_answers:
        if ans is None:
            continue
        # str() lets numeric literals (e.g. a year parsed as int) be matched
        # instead of crashing inside clean_text.
        candidate = clean_text(str(ans))
        # An empty answer is a substring of everything; never count it.
        if candidate and candidate in mo:
            return True
    return False

# Errors that mean "this possible_answers cell is not a usable Python literal".
_ANSWER_PARSE_ERRORS = (ValueError, TypeError, SyntaxError, MemoryError, RecursionError)


def _build_demos(df):
    """Pick up to three few-shot (question, answer) demos per ``prop`` value.

    Returns:
        A ``(demo_dict, demo_questions)`` pair. ``demo_dict`` maps each
        ``prop`` to its list of ``(question, answer)`` tuples;
        ``demo_questions`` is the set of questions used as demos, so that
        they can be excluded from evaluation.
    """
    demo_dict = {}
    demo_questions = set()
    for prop, group in df.groupby("prop"):
        demos = []
        for _, r in group.head(3).iterrows():
            try:
                lst = ast.literal_eval(r["possible_answers"])
                demo_ans = lst[0] if isinstance(lst, (list, tuple)) else str(lst)
            except _ANSWER_PARSE_ERRORS + (IndexError,):
                # Unparseable cell or empty answer list: keep the demo
                # question but show an empty answer.
                demo_ans = ""
            demos.append((r["question"], demo_ans))
            demo_questions.add(r["question"])
        demo_dict[prop] = demos
    return demo_dict, demo_questions


def _build_prompt(prop, demos, question):
    """Assemble the few-shot prompt: instruction, demo Q/A pairs, then the question."""
    parts = [
        f"I am preparing training data for a large language model to enhance its ability to extract knowledge from Wikipedia paragraphs. The goal is to teach the model to identify and understand relationships between entities (e.g., {prop}) rather than memorizing text. Below are example questions and answers related to {prop}:\n\n"
    ]
    parts.extend(f"Q: {dq}\nA: {da}\n\n" for dq, da in demos)
    parts.append(f"Q: {question}\nA:")
    return "".join(parts)


def evaluate_model(model_name: str, df: pd.DataFrame, out_dir: str):
    """Run few-shot QA evaluation for one model and write a per-question log.

    For every ``prop`` group the first three rows serve as in-context demos
    and are skipped during evaluation. Each remaining question is sent to the
    model and the generation is checked against ``possible_answers`` with
    ``is_correct_answer``.

    The log is written to ``<out_dir>/<model>/output.txt`` so that several
    models evaluated with the same ``out_dir`` do not overwrite each other.

    Args:
        model_name: Identifier passed to ``GPT.from_pretrained``.
        df: Benchmark rows with the columns ``question``,
            ``possible_answers``, ``prop``, ``o_pop`` and ``s_pop``.
        out_dir: Root directory for results; created if missing.

    Returns:
        None. Any exception raised during evaluation is logged with its
        traceback and swallowed, so one failing model does not abort the
        caller's loop over models.
    """
    model = None
    try:
        model_short = model_name.replace("EleutherAI/", "").replace("Qwen/", "")
        # Flatten remaining path separators so the name is a single directory
        # component and cannot escape out_dir (e.g. for absolute local paths).
        model_dir_name = re.sub(r"[\\/]+", "_", model_short).strip("_") or "model"
        model_dir = os.path.join(out_dir, model_dir_name)
        os.makedirs(model_dir, exist_ok=True)

        # Build the tokenizer once; it does not change between rows.
        import tiktoken
        tokenizer = tiktoken.get_encoding("gpt2")

        # Load GPT model
        model = GPT.from_pretrained(model_name, device="cuda")  # Assuming CUDA device

        # Construct few-shot example set
        demo_dict, demo_questions = _build_demos(df)

        # Output file (for debugging)
        fout_path = os.path.join(model_dir, "output.txt")
        with open(fout_path, "w", encoding="utf-8") as fout:
            fout.write(f"Model: {model_name}\n===\n\n")

            for _, row in df.iterrows():
                q = row["question"]
                # Skip questions that are used as few-shot demos
                if q in demo_questions:
                    continue
                # Parse possible_answers
                try:
                    pa = ast.literal_eval(row["possible_answers"])
                except _ANSWER_PARSE_ERRORS:
                    logger.warning(f"Parse failure: {row['possible_answers']}")
                    continue

                # Construct prompt from the relation (prop), its demos and the question
                prop = row["prop"]
                prompt = _build_prompt(prop, demo_dict.get(prop, []), q)

                # Generate using GPT model.
                # NOTE(review): 198 is the GPT-2 BPE id for "\n", so this
                # presumably stops generation at the first newline, and
                # temperature=0 presumably means greedy decoding -- confirm
                # against model.inference.
                response = inference(
                    model=model,
                    input_text=prompt,
                    tokenizer=tokenizer,
                    max_new_tokens=100,
                    stop_token=198,
                    temperature=0,
                )
                out_text = response.strip()

                # Determine correctness
                corr = is_correct_answer(out_text, pa)

                # Write log
                fout.write(f"Q: {q}\nOut: {out_text}\nAns: {pa}\nCorr: {corr}\n")
                fout.write(f"o_pop: {row['o_pop']}\n")
                fout.write(f"s_pop: {row['s_pop']}\n")
                fout.write("-"*20 + "\n")

    except Exception:
        logger.error(f"Model {model_name} crashed:\n" + traceback.format_exc())
    finally:
        # Release GPU memory even when the evaluation crashed, so the next
        # model starts from a clean allocator state.
        del model
        import torch
        torch.cuda.empty_cache()

def main():
    """Evaluate every model listed in the configuration on the benchmark.

    Loads ``config.yaml`` and ``test.tsv``, then calls ``evaluate_model`` for
    each configured model in turn. If setup fails, the error is logged and
    the function returns without evaluating anything. A failure in one model
    is logged and does not stop the remaining models.
    """
    try:
        # Load YAML configuration file
        config = load_config()
        models = config["models"]
        out_dir = config["out_dir"]
        df = load_data()
    except Exception as e:
        # Catch Exception rather than everything, so Ctrl-C and SystemExit
        # still propagate; the loaders have already logged the details.
        logger.error(f"Setup failed, nothing evaluated: {e}")
        return

    # A scalar `models: name` entry would otherwise be iterated per character.
    if isinstance(models, str):
        models = [models]

    # Iterate through models
    for m in models:
        logger.info(f"=== Evaluating {m} ===")
        try:
            evaluate_model(m, df, out_dir)
        except Exception as e:
            logger.error(f"Model {m} fails: {e}")

    print("All done. See output.txt in model directories.")

# Run the evaluation only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()