import os
import pandas as pd
import logging
import ast
import numpy as np
import torch
import tiktoken
import yaml  # Import YAML parsing library
from model import GPT

# Configure PyTorch's CUDA allocator via its environment variable.
# NOTE(review): this assignment runs after `import torch` and
# `from model import GPT`; presumably the allocator only reads the variable
# when CUDA memory is first used — confirm, or move it above the imports.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# NOTE(review): presumably the GPT-2 end-of-text token ID; it is not
# referenced anywhere else in the visible code.
EOS_TOKEN = 50256

# Root logging setup: INFO level, "<timestamp> - <level> - <message>" lines.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s"
)
# Module-level logger shared by every function in this script.
logger = logging.getLogger(__name__)

def load_config(config_path="config.yaml"):
    """Load and validate the YAML configuration file.

    Args:
        config_path: Path of the YAML file to read.

    Returns:
        dict: The parsed configuration; guaranteed to be a mapping with
        non-empty 'models' and 'out_dir' entries.

    Raises:
        FileNotFoundError: If ``config_path`` does not exist.
        ValueError: If the file is empty / not a mapping, or if 'models'
            or 'out_dir' is missing or empty.
        Exception: Any other error raised while reading or parsing the
            file is logged and re-raised unchanged.
    """
    try:
        with open(config_path, "r", encoding="utf-8") as f:
            config = yaml.safe_load(f)
        # safe_load returns None for an empty file and may return a list or
        # scalar; reject those explicitly instead of failing on `.get`.
        if not isinstance(config, dict):
            raise ValueError(
                "YAML file must contain a top-level mapping with 'models' and 'out_dir' fields"
            )
        if not config.get("models") or not config.get("out_dir"):
            raise ValueError("YAML file missing 'models' or 'out_dir' field")
        logger.info(f"Successfully loaded config file {config_path}")
        return config
    except FileNotFoundError:
        logger.error(f"Config file not found: {config_path}")
        raise
    except Exception as e:
        logger.error(f"Failed to load YAML config file: {e}")
        raise

def load_data():
    """Load ``test.tsv`` from the current directory as an all-string DataFrame.

    Returns:
        pandas.DataFrame: The parsed table, with every column read as ``str``.

    Raises:
        FileNotFoundError: If ``test.tsv`` is absent.
        ValueError: If any of the required columns is missing.
        Exception: Any other read/parse error is logged and re-raised.
    """
    try:
        frame = pd.read_csv("test.tsv", sep="\t", dtype=str)
        logger.info("Successfully loaded test.tsv")
    except FileNotFoundError:
        logger.error("test.tsv file not found in current directory")
        raise
    except Exception as err:
        logger.error(f"Failed to load test.tsv: {err}")
        raise

    # Columns the rest of the script reads from each row.
    required_cols = ["question", "possible_answers", "prop", "o_pop", "s_pop"]
    absent = [name for name in required_cols if name not in frame.columns]
    if absent:
        logger.error("test.tsv missing required columns, must include " + ",".join(required_cols))
        raise ValueError("Missing required columns")

    return frame

def tokenize(text, tokenizer):
    """Return the token IDs produced by ``tokenizer.encode`` for ``text``."""
    token_ids = tokenizer.encode(text)
    return token_ids

def eval_loss(model, text, tokenizer=None):
    """
    Score the answer part of a single text, i.e. the tokens after the last "A:".

    Note that, despite the name, the first returned value is the *mean
    log-probability* of the answer tokens (a value <= 0), not a positive
    loss; callers negate it to obtain the average negative log-likelihood.

    Args:
        model: GPT model instance; called as ``model(input_ids, target_ids)``
            and expected to return ``(logits, ...)`` with one logit row per
            input position.
        text: A single string containing an "A:" marker followed by the answer
        tokenizer: Tokenizer instance, if None, loads from tiktoken

    Returns:
        tuple: (mean answer log-probability as float, list of log
        probabilities for every predicted token of ``text``). The first
        element is ``float("inf")`` when the answer cannot be located or is
        empty; for texts shorter than two tokens the list is empty as well.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Use provided tokenizer or load from tiktoken
    if tokenizer is None:
        tokenizer = tiktoken.get_encoding("gpt2")

    encoded_text = tokenize(text, tokenizer)
    seq_len = len(encoded_text)
    if seq_len < 2:
        # Next-token prediction needs at least one (input, target) pair.
        logger.warning("Text too short to evaluate")
        return float("inf"), []

    # Locate the first answer token. rfind returns -1 when the marker is
    # absent; that case must be handled explicitly, otherwise text[:1] would
    # be treated as the prompt and the whole text scored as the "answer".
    marker_pos = text.rfind("A:")
    if marker_pos == -1:
        answer_start_idx = None
    else:
        prompt_without_answer = text[:marker_pos + 2]  # Include "A:"
        answer_start_idx = len(tokenize(prompt_without_answer, tokenizer))

    # Shift by one: position i of input_ids predicts token i + 1 of the text.
    all_ids = torch.tensor([encoded_text], dtype=torch.long, device=device)
    input_ids = all_ids[:, :-1]
    target_ids = all_ids[:, 1:].clone()

    model.eval()
    with torch.no_grad():
        logits, _ = model(input_ids, target_ids)
        log_probs = torch.nn.functional.log_softmax(logits, dim=-1)

        # Log-probability assigned to each actual next token, shape (seq_len - 1,).
        token_log_probs = log_probs[0].gather(-1, target_ids[0].unsqueeze(-1)).squeeze(-1)

        if answer_start_idx is None:
            logger.warning("Answer section not found")
            answer_loss = float("inf")
        elif answer_start_idx >= seq_len:
            # Marker found but no tokens follow it.
            logger.warning("Answer section is empty")
            answer_loss = float("inf")
        else:
            # Token k of the text is scored at index k - 1 of token_log_probs.
            first_scored = max(answer_start_idx - 1, 0)
            answer_log_probs = token_log_probs[first_scored:]
            answer_loss = answer_log_probs.float().mean().item()

        per_token_log_probs = token_log_probs.tolist()

        # Release GPU memory
        del all_ids, input_ids, target_ids, logits, log_probs, token_log_probs
        torch.cuda.empty_cache()

    return answer_loss, per_token_log_probs

def _format_qa_pair(row, prop):
    """Render one row as a "Q: ... / A: ..." pair, or return None if its answers cannot be parsed."""
    try:
        possible_answers = ast.literal_eval(row["possible_answers"])
        if isinstance(possible_answers, (list, tuple)):
            answer = possible_answers[0]  # First listed answer is used as the reference
        else:
            answer = str(possible_answers)
    except Exception as e:
        # Deliberate best-effort: malformed or empty answer lists skip the row.
        logger.warning(f"Failed to parse possible_answers for prop {prop}: {row['possible_answers']}, error: {e}")
        return None
    return f"Q: {row['question']}\nA: {answer}"


def generate_sample_text(df, prop, random_state=None):
    """Build a few-shot prompt for ``prop`` and the evaluation texts that use it.

    Up to three rows of ``prop`` are sampled as in-prompt examples; every
    remaining row of that ``prop`` becomes one evaluation text consisting of
    the prompt followed by its own question and reference answer.

    Args:
        df: DataFrame with "question", "possible_answers", "prop", "o_pop"
            and "s_pop" columns; "possible_answers" holds a Python literal
            (typically a list) as a string.
        prop: Value of the "prop" column to build the prompt for.
        random_state: Forwarded to ``DataFrame.sample`` when choosing the
            examples. The default ``None`` keeps the sampling
            non-deterministic; pass a seed for reproducible prompts.

    Returns:
        tuple: ``(prompt, eval_data)``. ``prompt`` is the few-shot prompt
        string, or ``None`` when ``df`` has no rows for ``prop``.
        ``eval_data`` is a list of dicts with keys "text", "o_pop" and
        "s_pop"; it is empty when the group has three rows or fewer.
    """
    group = df[df["prop"] == prop]
    if group.empty:
        logger.warning(f"No data for prop {prop}")
        return None, []

    # Randomly select three examples; use all rows if fewer than 3
    if len(group) >= 3:
        sample_rows = group.sample(n=3, random_state=random_state)
    else:
        sample_rows = group

    # Build prompt with the sampled examples
    parts = [
        "I am preparing training data for a large language model to enhance "
        "its ability to extract knowledge from Wikipedia paragraphs. "
        "The goal is to teach the model to identify and understand "
        f"relationships between entities (e.g., {prop}) rather than memorizing text. "
        f"Below, I provide three example questions and answers related to {prop}:\n\n"
    ]
    for _, row in sample_rows.iterrows():
        qa_pair = _format_qa_pair(row, prop)
        if qa_pair is not None:
            parts.append(f"{qa_pair}\n\n")
    parts.append("Based on the examples above, please answer the following question:\n\n")
    prompt = "".join(parts)

    # Evaluation questions are the rows of this prop not used as examples
    eval_rows = group[~group.index.isin(sample_rows.index)]
    eval_data = []
    for _, row in eval_rows.iterrows():
        qa_pair = _format_qa_pair(row, prop)
        if qa_pair is None:
            continue
        eval_data.append({
            "text": prompt + qa_pair,
            "o_pop": row["o_pop"],
            "s_pop": row["s_pop"],
        })

    return prompt, eval_data

def _evaluate_props(model, tokenizer, df, props):
    """Score every prop's evaluation texts with one model.

    Returns:
        tuple: ``(prop_eval_data, losses)``. ``prop_eval_data`` maps each
        evaluated prop to its list of eval items, each extended with a
        "loss" key holding the raw ``eval_loss`` value (``float("inf")`` on
        failure). ``losses`` maps each prop to its average positive answer
        loss, or ``float("inf")`` when no item could be scored.
    """
    prop_eval_data = {}
    losses = {}
    for prop in props:
        # NOTE(review): generate_sample_text samples its examples randomly on
        # every call, so each model is evaluated on different prompts.
        sample_text, eval_data = generate_sample_text(df, prop)
        if sample_text is None or not eval_data:
            logger.warning(f"No valid eval text for prop {prop}")
            continue
        prop_eval_data[prop] = eval_data

        # Evaluate each text one by one; a failure marks only that item.
        for item in eval_data:
            try:
                item_loss, _ = eval_loss(model, item["text"], tokenizer)
            except Exception as e:
                logger.error(f"Failed to compute loss for single text (prop {prop}): {e}")
                item_loss = float("inf")
            item["loss"] = item_loss

        valid_losses = [item["loss"] for item in eval_data if item["loss"] != float("inf")]
        if valid_losses:
            # eval_loss yields a mean log-probability, so negate it to get a
            # positive loss value.
            mean_loss = sum(valid_losses) / len(valid_losses)
            losses[prop] = -mean_loss
            logger.info(f"Average answer loss for prop {prop}: {-mean_loss}")
        else:
            losses[prop] = float("inf")
            logger.warning(f"All eval text losses for prop {prop} are inf")

    return prop_eval_data, losses


def _write_results(output_file, prop_eval_data, losses):
    """Write per-item results followed by summary statistics to ``output_file``."""
    with open(output_file, "w", encoding="utf-8") as f:
        for eval_data in prop_eval_data.values():
            for item in eval_data:
                # Failed items keep the "inf" marker instead of being negated
                # into a misleading "-inf" loss.
                if item["loss"] == float("inf"):
                    reported_loss = float("inf")
                else:
                    reported_loss = -item["loss"]
                f.write(f"Prompt: {item['text']}\n")
                f.write(f"loss: {reported_loss}\n")
                f.write(f"o_pop: {item['o_pop']}\n")
                f.write(f"s_pop: {item['s_pop']}\n")
                f.write("-" * 20 + "\n")

        # Summary over props that produced at least one valid loss.
        valid = {prop: loss for prop, loss in losses.items() if loss != float("inf")}
        if valid:
            mean_loss = sum(valid.values()) / len(valid)
            # max/min return the first extreme key in insertion order.
            max_prop = max(valid, key=valid.get)
            min_prop = min(valid, key=valid.get)

            f.write("Summary:\n")
            f.write(f"Mean Answer Loss: {mean_loss}\n")
            f.write(f"Max Answer Loss: {valid[max_prop]} (Prop: {max_prop})\n")
            f.write(f"Min Answer Loss: {valid[min_prop]} (Prop: {min_prop})\n")
        else:
            f.write("Summary: No valid losses calculated.\n")


def main():
    """Evaluate every configured model on test.tsv and write one report per model.

    Reads ``models`` and ``out_dir`` from the config, loads the dataset, and
    for each model path writes ``<out_dir>/<model_name>/out.txt``. Config or
    data loading failures are logged and end the run; a model that fails to
    load is logged and skipped.
    """
    # Load YAML config file
    try:
        config = load_config()
        models = config["models"]
        output_dir = config["out_dir"]
    except Exception as e:
        logger.error(f"Failed to load config file: {e}")
        return

    # Load data
    try:
        df = load_data()
    except Exception as e:
        logger.error(f"Failed to load data: {e}")
        return

    props = df["prop"].unique()

    for model_path in models:
        # Model name is the file name without a trailing ".pt"; only the
        # suffix is stripped so names merely containing ".pt" stay intact.
        model_name = os.path.basename(model_path)
        if model_name.endswith(".pt"):
            model_name = model_name[:-len(".pt")]
        logger.info(f"Processing model: {model_name}")

        # NOTE(review): the model is always requested on "cuda", while
        # eval_loss falls back to CPU tensors when CUDA is unavailable.
        try:
            model = GPT.from_pretrained(model_path, "cuda")
            model.eval()
            tokenizer = tiktoken.get_encoding("gpt2")
            logger.info(f"Successfully loaded model from {model_path}")
        except Exception as e:
            logger.error(f"Failed to load model: {e}")
            continue

        try:
            # Create model-specific output directory
            model_output_dir = os.path.join(output_dir, model_name)
            os.makedirs(model_output_dir, exist_ok=True)
            output_file = os.path.join(model_output_dir, "out.txt")

            prop_eval_data, losses = _evaluate_props(model, tokenizer, df, props)
            _write_results(output_file, prop_eval_data, losses)
            logger.info(f"Results written to {output_file}")
        finally:
            # Release GPU memory even if evaluation or writing failed, so the
            # next model does not start with this one still resident.
            del model
            torch.cuda.empty_cache()

# Run the evaluation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
