import os
import argparse
import pandas as pd
import logging
import ast
import numpy as np
import torch
import tiktoken
from model import GPT

# Opt in to the "expandable_segments" mode of PyTorch's CUDA caching allocator.
# NOTE(review): this is assigned after `import torch`; presumably it is still
# honored because it happens before any CUDA allocation — confirm.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# Token id 50256 — presumably the GPT-2 end-of-text token. It is not
# referenced by any code visible in this file.
EOS_TOKEN = 50256

# Emit INFO and above, prefixed with a timestamp and the severity level.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s"
)
# Module-level logger used by the functions below.
logger = logging.getLogger(__name__)

def load_data(path="test.tsv"):
    """Load the evaluation table from a tab-separated file.

    Every column is read as a string (``dtype=str``), so numeric-looking
    fields such as ``o_pop`` / ``s_pop`` are returned as text.

    Args:
        path: Location of the TSV file. Defaults to ``"test.tsv"`` relative
            to the current working directory.

    Returns:
        pandas.DataFrame containing at least the columns ``question``,
        ``possible_answers``, ``prop``, ``o_pop`` and ``s_pop``.

    Raises:
        FileNotFoundError: If ``path`` does not exist.
        ValueError: If one or more required columns are absent; the message
            names the missing columns.
        Exception: Any other error raised by ``pandas.read_csv`` is logged
            and re-raised unchanged.
    """
    try:
        df = pd.read_csv(path, sep="\t", dtype=str)
    except FileNotFoundError:
        logger.error(f"{path} file not found")
        raise
    except Exception as e:
        logger.error(f"Failed to load {path}: {e}")
        raise
    logger.info(f"Successfully loaded {path}")

    required_cols = ["question", "possible_answers", "prop", "o_pop", "s_pop"]
    missing_cols = [c for c in required_cols if c not in df.columns]
    if missing_cols:
        # Name the columns that are actually absent so the file can be fixed
        # without diffing the header by hand.
        logger.error(
            f"{path} lacks required columns, please include "
            + ",".join(required_cols)
            + " (missing: " + ",".join(missing_cols) + ")"
        )
        raise ValueError("Missing required columns: " + ", ".join(missing_cols))
    return df

def tokenize(text, tokenizer):
    """Encode ``text`` with ``tokenizer.encode`` and return the token IDs."""
    token_ids = tokenizer.encode(text)
    return token_ids

def eval_loss(model, text, tokenizer=None):
    """
    Score the answer part of a single text, i.e. the tokens after the last "A:".

    The whole text is tokenized and run through the model once. For every
    position the log-probability of the actual next token is collected, and
    the entries belonging to the answer are averaged.

    Known limitation: the answer is located with ``text.rfind("A:")``, so an
    answer that itself contains "A:" is only scored from that inner occurrence.

    Args:
        model: GPT model instance. It is called as ``model(input_ids, target_ids)``
            and must return a ``(logits, _)`` pair with one logits row per
            input position.
        text: Single string
        tokenizer: Tokenizer instance, if None, load the "gpt2" encoding from tiktoken

    Returns:
        tuple: (score as float, list of log probabilities for each token)
            The score is the mean log-probability of the answer tokens, so
            the average loss is its negation. It is ``float("inf")`` when
            nothing could be scored (no "A:" marker, no tokens after it, or a
            text shorter than two tokens).
            The list holds one entry per predicted token (token ``i + 1``
            given tokens ``0..i``); it is empty for texts shorter than two tokens.
    """
    # Use the provided tokenizer or load from tiktoken
    if tokenizer is None:
        tokenizer = tiktoken.get_encoding("gpt2")

    encoded_text = tokenize(text, tokenizer)
    seq_len = len(encoded_text)
    if seq_len < 2:
        # There is no (input, target) pair to score; do not call the model
        # with an empty sequence.
        logger.warning("Text is too short to evaluate")
        return float("inf"), []

    # Locate the first answer token. rfind() returns -1 when the marker is
    # missing; slicing with that value would silently treat almost the whole
    # text as the answer, so handle it explicitly.
    marker_pos = text.rfind("A:")
    if marker_pos == -1:
        answer_start_idx = seq_len
    else:
        prompt_without_answer = text[:marker_pos + 2]  # Include "A:"
        answer_start_idx = len(tokenize(prompt_without_answer, tokenizer))

    # Put the inputs on the device the model lives on; fall back to the
    # CUDA-if-available guess when the model exposes no parameters.
    try:
        device = next(model.parameters()).device
    except (AttributeError, StopIteration, TypeError):
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    full_ids = torch.tensor([encoded_text], dtype=torch.long, device=device)
    input_ids = full_ids[:, :-1]
    target_ids = full_ids[:, 1:].clone()

    model.eval()
    with torch.no_grad():
        logits, _ = model(input_ids, target_ids)
        log_probs = torch.nn.functional.log_softmax(logits, dim=-1)

        # token_log_probs[i] is the log-probability of token i + 1.
        token_log_probs = log_probs[0].gather(1, target_ids[0].unsqueeze(1)).squeeze(1)

        if answer_start_idx < seq_len:
            # Token k is predicted at position k - 1; clamp at 0 so that a
            # start index of 0 cannot turn into the negative index -1.
            first_answer_pos = max(answer_start_idx - 1, 0)
            answer_loss = token_log_probs[first_answer_pos:].mean().item()
        else:
            logger.warning("Answer part not found")
            answer_loss = float("inf")

        # Release GPU memory
        del full_ids, input_ids, target_ids, logits, log_probs
        torch.cuda.empty_cache()

    return answer_loss, token_log_probs.tolist()

def _format_qa(row, prop):
    """Render one row as ``"Q: <question>\\nA: <answer>"``, or None if its answers cannot be parsed.

    ``possible_answers`` is parsed with ``ast.literal_eval``. The first
    element is used when it is a list or tuple, otherwise the parsed value
    is converted with ``str``. Parsing problems (including an empty list)
    are logged as a warning and reported by returning None.
    """
    try:
        possible_answers = ast.literal_eval(row["possible_answers"])
        if isinstance(possible_answers, (list, tuple)):
            answer = possible_answers[0]
        else:
            answer = str(possible_answers)
    except (ValueError, TypeError, SyntaxError, IndexError, MemoryError, RecursionError) as e:
        logger.warning(f"Prop {prop} possible_answers parsing failed: {row['possible_answers']}, error: {e}")
        return None
    return f"Q: {row['question']}\nA: {answer}"


def generate_sample_text(df, prop, random_state=None):
    """Build a few-shot prompt for ``prop`` and the texts to evaluate with it.

    Up to three rows of the given prop are drawn as worked examples and
    placed in the prompt. Every remaining row of that prop becomes one
    evaluation text: the prompt followed by that row's question and answer.
    Rows whose ``possible_answers`` cannot be parsed are skipped with a warning.

    Args:
        df: DataFrame with the columns ``question``, ``possible_answers``,
            ``prop``, ``o_pop`` and ``s_pop``.
        prop: Value of the ``prop`` column to select.
        random_state: Forwarded to ``DataFrame.sample``. Pass a fixed seed for
            a reproducible choice of examples; the default None keeps the
            selection random.

    Returns:
        tuple: ``(prompt, eval_data)``. ``eval_data`` is a list of dicts with
        the keys ``text``, ``o_pop`` and ``s_pop``. When the prop has no rows
        the result is ``(None, [])``. When the prop has three rows or fewer,
        all of them are used as examples and ``eval_data`` is empty.
    """
    group = df[df["prop"] == prop]
    if group.empty:
        logger.warning(f"Prop {prop} has no data")
        return None, []

    # Draw three examples; if fewer than 3 rows exist, use all of them.
    if len(group) >= 3:
        sample_rows = group.sample(n=3, random_state=random_state)
    else:
        sample_rows = group

    # Construct the prompt, including the examples
    prompt_parts = [
        f"I am preparing training data for a large language model to enhance its ability to extract knowledge from Wikipedia paragraphs. The goal is to teach the model to identify and understand relationships between entities (e.g., {prop}) rather than memorizing text. Below, I provide three example questions and answers related to {prop}:\n\n"
    ]
    for _, row in sample_rows.iterrows():
        qa_text = _format_qa(row, prop)
        if qa_text is not None:
            prompt_parts.append(qa_text + "\n\n")
    prompt_parts.append("Based on the examples above, please answer the following question:\n\n")
    prompt = "".join(prompt_parts)

    # Every row that was not used as an example is evaluated
    eval_rows = group[~group.index.isin(sample_rows.index)]
    eval_data = []
    for _, row in eval_rows.iterrows():
        qa_text = _format_qa(row, prop)
        if qa_text is None:
            continue
        eval_data.append({
            "text": prompt + qa_text,
            "o_pop": row["o_pop"],
            "s_pop": row["s_pop"],
        })

    return prompt, eval_data

def main():
    """Evaluate the answer loss for every prop in the test data and write a report.

    Steps:
        1. Parse ``--model-path`` (required) and ``--output-dir``
           (default ``./fineweb10B``).
        2. Load the data with ``load_data()`` and the model with
           ``GPT.from_pretrained(model_path, "cuda")``. If either fails, the
           error is logged and the function returns without writing output.
        3. For each distinct ``prop``, build a few-shot prompt with
           ``generate_sample_text`` and score every evaluation text with
           ``eval_loss``.
        4. Write one record per evaluated text plus a summary to
           ``<output-dir>/out.txt``.

    ``eval_loss`` yields the mean log-probability of the answer tokens, so
    reported losses are the negation of that score. Texts that could not be
    scored are written as ``inf`` and are excluded from all averages.
    """
    import math  # only needed for the finite / infinite checks below

    # Parse command-line arguments
    parser = argparse.ArgumentParser(description="Generate sample texts and evaluate loss for each prop")
    parser.add_argument("--model-path", required=True, help="Path to the pre-trained GPT model")
    parser.add_argument("--output-dir", default="./fineweb10B", help="Directory that receives out.txt")
    args = parser.parse_args()

    # Load data
    try:
        df = load_data()
    except Exception as e:
        logger.error(f"Data loading failed: {e}")
        return

    # Load model
    try:
        model = GPT.from_pretrained(args.model_path, "cuda")
        model.eval()
        tokenizer = tiktoken.get_encoding("gpt2")
        logger.info(f"Successfully loaded model from {args.model_path}")
    except Exception as e:
        logger.error(f"Model loading failed: {e}")
        return

    # Create output directory
    output_dir = args.output_dir
    os.makedirs(output_dir, exist_ok=True)
    output_file = os.path.join(output_dir, "out.txt")

    losses = {}          # prop -> average answer loss (inf if nothing was scored)
    prop_eval_data = {}  # prop -> evaluated items, each extended with a "loss" entry

    # Generate sample_text and calculate loss
    for prop in df["prop"].unique():
        sample_text, eval_data = generate_sample_text(df, prop)
        if sample_text is None or not eval_data:
            logger.warning(f"Prop {prop} has no valid evaluation texts")
            continue
        prop_eval_data[prop] = eval_data

        # Evaluate each text individually; a failure only affects that text
        for item in eval_data:
            try:
                item["loss"], _ = eval_loss(model, item["text"], tokenizer)
            except Exception as e:
                logger.error(f"Prop {prop} single text loss calculation failed: {e}")
                item["loss"] = float("inf")

        # Average over the texts that produced a finite score (drops the inf
        # sentinel as well as NaN, which would otherwise poison the mean)
        valid_scores = [item["loss"] for item in eval_data if math.isfinite(item["loss"])]
        if valid_scores:
            mean_score = sum(valid_scores) / len(valid_scores)
            losses[prop] = -mean_score  # score is a log-probability; the loss is its negation
            logger.info(f"Prop {prop} average answer loss: {-mean_score}")
        else:
            losses[prop] = float("inf")
            logger.warning(f"Prop {prop} all evaluation text losses are inf")

    # Write to output file
    with open(output_file, "w", encoding="utf-8") as f:
        # Combine evaluation results for all props
        for eval_data in prop_eval_data.values():
            for item in eval_data:
                score = item["loss"]
                # Negate finite scores to report a loss. An infinite value
                # (failure sentinel or zero-probability answer) is reported as
                # +inf instead of being flipped to a misleading -inf.
                reported_loss = float("inf") if math.isinf(score) else -score
                f.write(f"Prompt: {item['text']}\n")
                f.write(f"loss: {reported_loss}\n")
                f.write(f"o_pop: {item['o_pop']}\n")
                f.write(f"s_pop: {item['s_pop']}\n")
                f.write("-" * 20 + "\n")

        # Calculate and write statistics over the props with a finite loss
        finite_losses = {prop: loss for prop, loss in losses.items() if math.isfinite(loss)}
        if finite_losses:
            mean_loss = sum(finite_losses.values()) / len(finite_losses)
            max_prop = max(finite_losses, key=finite_losses.get)
            min_prop = min(finite_losses, key=finite_losses.get)

            f.write("Summary:\n")
            f.write(f"Mean Answer Loss: {mean_loss}\n")
            f.write(f"Max Answer Loss: {finite_losses[max_prop]} (Prop: {max_prop})\n")
            f.write(f"Min Answer Loss: {finite_losses[min_prop]} (Prop: {min_prop})\n")
        else:
            f.write("Summary: No valid losses calculated.\n")

    logger.info(f"Results written to {output_file}")

# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()