import torch
import numpy as np
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import random
import time
import os
import pandas as pd

# --- 1. Global Parameter Setup ---
# Base seed; main() seeds each trial with MASTER_SEED + trial_id.
MASTER_SEED = 2025
# Number of new trials executed per invocation of main().
NUM_TRIALS = 10
# Generate -> measure -> update rounds per simulation run (per mode, per trial).
NUM_ITERATIONS = 20
# Passed as max_new_tokens to model.generate on every iteration.
GENERATION_LENGTH = 75
# Upper bound, in tokens, on the prompt carried into the next iteration: longer
# prompts keep only their most recent CONTEXT_WINDOW tokens. The generation step
# applies its own, model-dependent length cap as well.
CONTEXT_WINDOW = 1000
# Hugging Face checkpoint name used to load both the tokenizer and the model.
MODEL_NAME = "distilgpt2"
# Sentences drawn with random.choice and appended to the prompt in "stable" mode
# (use_reservoir=True); "collapse" mode never reads this list.
ENTROPY_RESERVOIR = [
    "The study of information geometry provides a powerful lens for understanding machine learning.",
    "Quantum computing promises to revolutionize fields from medicine to materials science.",
    "The Amazon rainforest is a vital carbon sink for the entire planet.",
    "Ancient Roman aqueducts were marvels of civil engineering.",
    "A well-balanced diet should include a variety of fruits and vegetables.",
    "Jazz music often features complex improvisation and syncopated rhythms.",
    "The theory of relativity fundamentally changed our understanding of space and time.",
    "Machine learning models can sometimes amplify biases present in their training data.",
    "Photosynthesis is the process plants use to convert light into chemical energy.",
    "Literary modernism broke from traditional styles of prose and verse.",
]
# Starting prompt for every simulation run. Lower-case despite being a constant;
# main() references it by exactly this name, so it is not renamed here.
initial_prompt = "Tell me about Artificial intelligence."
# Results file. Rows are appended, and new trial IDs continue after the largest
# trial_id already stored in it.
CSV_FILENAME = "experiment_results_v2.csv"


# --- 2. Helper Functions ---
def get_ngrams(text, n):
    """Return the set of distinct word-level n-grams in *text*.

    Tokens come from ``str.split()`` (any run of whitespace separates tokens),
    and each n-gram is ``n`` consecutive tokens joined by a single space.

    Args:
        text (str): Input text.
        n (int): N-gram order; must be at least 1.

    Returns:
        set[str]: The unique n-grams. Empty when ``text`` has fewer than ``n``
        tokens.

    Raises:
        ValueError: If ``n`` is less than 1 (no meaningful n-grams exist).
    """
    if n < 1:
        raise ValueError(f"n must be >= 1, got {n}")
    tokens = text.split()
    # When len(tokens) < n the range below is empty, so short input yields set().
    return {" ".join(tokens[i : i + n]) for i in range(len(tokens) - n + 1)}


def calculate_diversity_metrics(text):
    """Measure lexical diversity of *text* via distinct n-gram counts.

    Args:
        text (str): Text whose whitespace-separated tokens are analysed.

    Returns:
        tuple[int, int]: ``(unique_bigrams, unique_trigrams)``.
    """
    bigram_count, trigram_count = (len(get_ngrams(text, order)) for order in (2, 3))
    return bigram_count, trigram_count


# --- 3. Core Simulation Function ---
def run_llm_simulation_final(
    model, tokenizer, use_reservoir, initial_prompt, num_iterations, device
):
    """
    Runs the LLM self-referential loop simulation to validate the Entropy-Reservoir
    Bregman Projection (ERBP) framework. Each iteration performs one pass through the
    four core ERBP steps.

    Mapping of ERBP Framework Steps to Code Implementation:
    --------------------------------------------------------------------------------
    1. Empirical Sampling:
       - Theory: Sample m points from the current state distribution P_t to get the
         empirical distribution P_hat_t.
       - Implementation: `model.generate(...)` samples up to GENERATION_LENGTH new
         tokens conditioned on `current_prompt` (the context defining P_t). Only those
         new tokens are decoded into `newly_generated_text` (the empirical samples
         P_hat_t).

    2. Measurement:
       - Theory: Evaluate the diversity/entropy of the samples.
       - Implementation: `calculate_diversity_metrics(...)` computes n-gram diversity
         of `newly_generated_text` as a proxy for entropy.

    3. Mixing with the Reservoir:
       - Theory: Form the mixed target distribution Y_bar_t = (1 - λ) * P_hat_t + λ * P_res_t.
       - Implementation:
         - Collapse Mode (λ=0): `next_prompt` is the decoded prompt + continuation
           (`full_text`) only.
         - Stable Mode (λ>0): `next_prompt` is `full_text` followed by one sentence
           drawn with `random.choice(ENTROPY_RESERVOIR)`, i.e. a single sample from
           the reservoir distribution P_res_t.

    4. Projection Update:
       - Theory: Update the model state to P_{t+1} = argmin_P B_F(P, Y_bar_t).
       - Implementation: the "projection" is implicit. Setting
         `current_prompt = next_prompt` changes the model's conditional distribution
         for the next round. This induces the transition from P_t to P_{t+1} towards
         the mixed target Y_bar_t.
    --------------------------------------------------------------------------------

    Context handling:
        - Before generation, the prompt keeps only its last
          (model max length - GENERATION_LENGTH) tokens, so prompt + continuation
          fit inside the model's positional limit.
        - After mixing, the next prompt keeps only its last CONTEXT_WINDOW tokens.

    Args:
        model: The GPT-2 model used for text generation.
        tokenizer: The corresponding tokenizer.
        use_reservoir (bool): True for Stable Mode (λ>0, reservoir injection every
            iteration), False for Collapse Mode (λ=0).
        initial_prompt (str): The prompt that starts the simulation.
        num_iterations (int): Total number of iterations to run.
        device: 'cuda' or 'cpu' (anything accepted by ``Tensor.to``).

    Returns:
        np.ndarray: Integer array of shape (num_iterations, 2). Row t holds
        (unique_bigrams, unique_trigrams) of the text generated in iteration t.

    Raises:
        ValueError: If the model's maximum length leaves no room for a prompt
            once GENERATION_LENGTH new tokens are reserved.
    """
    # The positional limit is loop-invariant, so resolve it once. Prefer the model's
    # own limit (GPT-2 configs expose ``n_positions``). Fall back to the tokenizer's
    # ``model_max_length``, then to 1024.
    max_len = getattr(getattr(model, "config", None), "n_positions", None)
    if max_len is None:
        max_len = getattr(tokenizer, "model_max_length", 1024)
    max_prompt_tokens = max_len - GENERATION_LENGTH
    if max_prompt_tokens <= 0:
        raise ValueError(
            f"GENERATION_LENGTH ({GENERATION_LENGTH}) must be smaller than the "
            f"model's maximum length ({max_len})"
        )

    current_prompt = initial_prompt
    diversity_history = []

    for _ in range(num_iterations):
        # --- 1. Generation: The model samples from P_t to get empirical samples P_hat_t ---
        encoded = tokenizer(current_prompt, return_tensors="pt", truncation=False)
        # Keep only the most recent tokens. Slicing is a no-op when the prompt is
        # already short enough.
        input_ids = encoded["input_ids"][:, -max_prompt_tokens:].to(device)
        attention_mask = encoded.get("attention_mask")
        if attention_mask is not None:
            # pad_token == eos_token, so generate() cannot infer the mask; pass it
            # explicitly, truncated exactly like input_ids.
            attention_mask = attention_mask[:, -max_prompt_tokens:].to(device)

        output_sequences = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=GENERATION_LENGTH,
            do_sample=True,
            top_k=50,
            top_p=0.95,
            pad_token_id=tokenizer.eos_token_id,
        )

        prompt_len = input_ids.shape[1]
        full_text = tokenizer.decode(output_sequences[0], skip_special_tokens=True)
        # Decode the continuation from its own token slice. Slicing the decoded
        # string by len(decoded prompt) is fragile: detokenization clean-up and
        # byte-level merges at the prompt/continuation boundary can shift
        # character offsets.
        newly_generated_text = tokenizer.decode(
            output_sequences[0, prompt_len:], skip_special_tokens=True
        ).strip()

        # --- 2. Measurement: Evaluate the diversity of the empirical samples ---
        diversity_history.append(calculate_diversity_metrics(newly_generated_text))

        # --- 3. Update: Form the mixed target Y_bar_t for the next round (t+1) ---
        if use_reservoir:
            # STABLE MODE: constant coupling λ > 0 — inject a reservoir sample every time.
            next_prompt = full_text + " " + random.choice(ENTROPY_RESERVOIR)
        else:
            # COLLAPSE MODE: coupling λ = 0 — the target is the empirical samples alone.
            next_prompt = full_text

        # --- 4. Projection: Update the state to P_{t+1} ---
        # The "projection" is simply the prompt for the next iteration, capped at
        # the last CONTEXT_WINDOW tokens.
        all_tokens = tokenizer.encode(next_prompt)
        if len(all_tokens) > CONTEXT_WINDOW:
            current_prompt = tokenizer.decode(all_tokens[-CONTEXT_WINDOW:])
        else:
            current_prompt = next_prompt

    # reshape keeps the documented (num_iterations, 2) shape even when num_iterations == 0.
    return np.array(diversity_history, dtype=int).reshape(-1, 2)


# --- 4. Main Experiment Flow ---
def _seed_everything(seed):
    """Seed torch, Python's ``random`` and NumPy's global RNG with ``seed``."""
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)


def _find_last_trial_id(csv_path):
    """Return the largest ``trial_id`` stored in ``csv_path``, or 0 if there is none.

    A missing file, an empty file, a file without a ``trial_id`` column, or a
    column with no values all count as "no trials yet".
    """
    if not os.path.exists(csv_path):
        return 0
    try:
        df = pd.read_csv(csv_path)
    except pd.errors.EmptyDataError:
        return 0
    if df.empty or "trial_id" not in df.columns:
        return 0
    last_id = df["trial_id"].max()
    return 0 if pd.isna(last_id) else int(last_id)


def _ensure_csv_header(csv_path):
    """Create ``csv_path`` with the header row if it is missing or empty.

    An existing non-empty file is never overwritten.
    """
    if not os.path.exists(csv_path) or os.path.getsize(csv_path) == 0:
        with open(csv_path, "w", encoding="utf-8") as f:
            f.write("trial_id,iteration,mode,unique_bigrams,unique_trigrams\n")
        print(f"Created new results file: {CSV_FILENAME}")


def main():
    """Run NUM_TRIALS paired collapse/stable simulations and append results to CSV.

    Trial numbering resumes after the largest ``trial_id`` already present in
    ``CSV_FILENAME``. Each trial uses the seed ``MASTER_SEED + trial_id`` and
    re-seeds before both modes, so the two modes start from the same RNG state.
    Rows are appended after every completed trial. An interrupted run therefore
    keeps all finished trials, and re-running resumes after them.
    """
    print("Loading model and tokenizer...")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = GPT2Tokenizer.from_pretrained(MODEL_NAME)
    model = GPT2LMHeadModel.from_pretrained(MODEL_NAME).to(device)
    tokenizer.pad_token = tokenizer.eos_token
    print(f"Model loaded on {device}.")

    start_trial_id = _find_last_trial_id(CSV_FILENAME)
    _ensure_csv_header(CSV_FILENAME)

    print(f"Master Seed is set to: {MASTER_SEED}")
    print(f"Starting from Trial ID {start_trial_id + 1}...")

    trials_written = 0
    for current_trial_id in range(start_trial_id + 1, start_trial_id + NUM_TRIALS + 1):
        print(f"\n{'='*25} Starting Trial {current_trial_id} {'='*25}\n")

        derived_seed = MASTER_SEED + current_trial_id

        # --- Mode 1: Collapse ---
        print(f"Seeding with {derived_seed} for 'collapse' mode...")
        _seed_everything(derived_seed)
        diversity_collapse = run_llm_simulation_final(
            model,
            tokenizer,
            use_reservoir=False,
            initial_prompt=initial_prompt,
            num_iterations=NUM_ITERATIONS,
            device=device,
        )

        # --- Mode 2: Stable (same seed, for a paired comparison) ---
        print(f"Re-seeding with {derived_seed} for 'stable' mode...")
        _seed_everything(derived_seed)
        diversity_stable = run_llm_simulation_final(
            model,
            tokenizer,
            use_reservoir=True,
            initial_prompt=initial_prompt,
            num_iterations=NUM_ITERATIONS,
            device=device,
        )

        # --- Collect Results (collapse/stable rows interleaved per iteration) ---
        trial_rows = []
        for t in range(NUM_ITERATIONS):
            trial_rows.append(
                f"{current_trial_id},{t+1},collapse,{diversity_collapse[t, 0]},{diversity_collapse[t, 1]}\n"
            )
            trial_rows.append(
                f"{current_trial_id},{t+1},stable,{diversity_stable[t, 0]},{diversity_stable[t, 1]}\n"
            )
            print(
                f"Trial {current_trial_id}, Iter {t+1}, collapse: bigrams={diversity_collapse[t, 0]}, trigrams={diversity_collapse[t, 1]}"
            )
            print(
                f"Trial {current_trial_id}, Iter {t+1}, stable:   bigrams={diversity_stable[t, 0]}, trigrams={diversity_stable[t, 1]}"
            )

        # Persist each trial immediately so a crash later in the run does not lose it.
        with open(CSV_FILENAME, "a", encoding="utf-8") as f:
            f.writelines(trial_rows)
        trials_written += 1

        print(f"Trial {current_trial_id} completed.")

    if trials_written:
        print(f"\nAll trial data appended to {CSV_FILENAME}.")

    print("\n--- Data generation complete. ---")


if __name__ == "__main__":
    # Run the experiment only when executed as a script, not when imported.
    main()