#!/usr/bin/env python
"""
Creativity Experiment Template
ICLR 2026 Submission Code

This script implements parallel vs sequential comparison for creative tasks.
Uses joke generation to measure creativity scaling paradigms.
"""

import requests
import json
import numpy as np
from sentence_transformers import SentenceTransformer, util
import os
import sys
from collections import Counter
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from nltk.tokenize import word_tokenize
import nltk
import time

# --- Prerequisites ---
# NLTK requires one-time downloads of tokenizer data.
REQUIRED_NLTK_RESOURCES = ['punkt', 'punkt_tab']
for resource in REQUIRED_NLTK_RESOURCES:
    try:
        # Raises LookupError when the tokenizer data is not installed locally.
        nltk.data.find(f'tokenizers/{resource}')
    except LookupError:
        print(f"NLTK '{resource}' resource not found. Downloading...")
        # nltk.download() reports failure through its boolean return value
        # instead of raising, so check it before announcing success.
        if nltk.download(resource, quiet=True):
            print("Download complete.")
        else:
            print(
                f"WARNING: NLTK '{resource}' download failed; continuing without it.",
                file=sys.stderr,
            )

# --- Configuration ---
# The key is taken from the OPENROUTER_API_KEY environment variable when it is
# set, so the secret does not have to be written into this file. The original
# placeholder remains the fallback; replacing it in place still works as long
# as the environment variable is unset.
API_KEY = os.environ.get("OPENROUTER_API_KEY", "YOUR_OPENROUTER_KEY")
API_URL = "https://openrouter.ai/api/v1/chat/completions"
MODEL_ID = "openai/gpt-4o"  # Default model for creativity experiments
SENTENCE_TRANSFORMER_MODEL = 'all-MiniLM-L6-v2'
STEPS_TO_RUN = [3, 6, 9]  # Different chain lengths to test
NUM_ITERATIONS = 10  # Number of iterations for statistical significance

# ==============================================================================
# 1. ENHANCED PROMPTS
# ==============================================================================

# System message sent with every API call in this script (both the parallel
# and the sequential runs pass it to call_openrouter_api). It fixes the
# comedian persona and asks for the bare joke text only.
SYSTEM_PROMPT = """You are a highly original comedian, a master of wit and surprise, performing at a prestigious comedy festival. Your goal is to generate a single, high-concept joke. Avoid clichés, puns, and simple one-liners. Aim for clever observational humor, intellectual absurdity, or a short narrative with an unexpected twist. Your humor should be sophisticated. Your response MUST be ONLY the joke text, with no preamble, explanation, or quotation marks."""

# User message for every parallel sample and for the first step of a
# sequential chain; later sequential steps use get_refinement_prompt() instead.
INITIAL_USER_PROMPT = "A genuinely original joke, please."

def get_refinement_prompt(previous_joke, history):
    """
    Build the user prompt for one step of sequential refinement.

    The prompt quotes the most recent joke, lists four refinement strategies
    that push for structural change rather than paraphrasing, and appends the
    whole refinement chain so far for context.

    Args:
        previous_joke: Text of the joke to be improved.
        history: Ordered iterable of all jokes produced so far in the chain;
            rendered as "- Step N: <joke>", one per line, numbered from 1.

    Returns:
        The complete prompt string.
    """
    # Join with a real newline: the previous "\\n" separator put a literal
    # backslash-n between steps, collapsing the chain onto a single line.
    history_str = "\n".join(
        f"- Step {step}: {text}" for step, text in enumerate(history, start=1)
    )
    return f"""An earlier version of your joke was:
"{previous_joke}"

Now, perform a creative leap. Your task is to radically improve it, not just rephrase it. Consider one of the following refinement strategies:
1.  **Change the Premise:** Keep the punchline's core idea, but invent a completely new setup for it.
2.  **Invert the Logic:** Flip the joke's core assumption on its head.
3.  **Introduce an Unexpected Element:** Add a character, object, or concept that sends the joke in a completely new direction.
4.  **Punch-Up the Punchline:** Keep the setup, but write a punchline that is far more surprising or clever.

Your goal is novelty and a significant creative evolution. For context, here is the refinement chain so far:
{history_str}

Generate the next, dramatically improved version of the joke. Your response must be ONLY the new joke text."""

# ==============================================================================
# 2. OPENROUTER API INTERACTION
# ==============================================================================

def call_openrouter_api(prompt, system_prompt, retries=3, delay=5):
    """
    Send one chat-completion request to OpenRouter and return the reply text.

    Args:
        prompt: Content of the user message.
        system_prompt: Content of the system message.
        retries: Maximum number of attempts. Values below 1 are treated as 1,
            so the function never falls through and returns None.
        delay: Seconds to sleep between failed attempts.

    Returns:
        The first choice's message content, stripped of surrounding whitespace.

    Exits:
        Calls sys.exit(1) after the final failed attempt.
    """
    headers = {
        "Authorization": f"Bearer {API_KEY}",
        "Content-Type": "application/json"
    }
    payload = {
        "model": MODEL_ID,
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": prompt}
        ],
        "temperature": 0.8
    }

    # Besides transport/HTTP errors, a 2xx reply can still carry a body that
    # lacks the expected choices/message/content path (or is not JSON at all).
    # Treat those as failed attempts too instead of letting them escape as an
    # uncaught KeyError/IndexError/TypeError.
    retryable_errors = (
        requests.exceptions.RequestException,
        KeyError,
        IndexError,
        TypeError,
        ValueError,
    )
    attempts = max(1, retries)

    for attempt in range(1, attempts + 1):
        try:
            response = requests.post(API_URL, headers=headers, json=payload, timeout=60)
            response.raise_for_status()
            content = response.json()['choices'][0]['message']['content']
            if not isinstance(content, str):
                # e.g. a null content field; .strip() would otherwise raise
                # an AttributeError outside the retry handling.
                raise ValueError(f"unexpected completion content: {content!r}")
            return content.strip()
        except retryable_errors as e:
            if attempt < attempts:
                print(f"  API call failed, retrying in {delay}s... ({e})")
                time.sleep(delay)
            else:
                print(f"ERROR: Could not connect to OpenRouter after {attempts} attempts.")
                print(f"Details: {e}")
                sys.exit(1)

# ==============================================================================
# 3. DIVERSITY METRICS
# ==============================================================================

def calculate_semantic_diversity(jokes, model):
    """
    Return the mean pairwise cosine similarity between joke embeddings.

    Despite the name, the value is a similarity: lower numbers mean the jokes
    are semantically further apart. Fewer than two jokes yields 0.0.

    `model` must expose `encode(texts, convert_to_tensor=True)` (e.g. a
    SentenceTransformer instance).
    """
    n_jokes = len(jokes)
    if n_jokes < 2:
        return 0.0

    vectors = model.encode(jokes, convert_to_tensor=True)
    sim_matrix = util.cos_sim(vectors, vectors)

    # Walk the strict upper triangle so each unordered pair counts once.
    pair_sum = 0
    for row in range(n_jokes):
        for col in range(row + 1, n_jokes):
            pair_sum += sim_matrix[row][col].item()

    pair_count = n_jokes * (n_jokes - 1) / 2
    return pair_sum / pair_count

def calculate_lexical_diversity(jokes):
    """
    Return the mean segmental type-token ratio (MSTTR) of the jokes.

    All jokes are joined with spaces, lower-cased and tokenized. When there
    are fewer than 100 tokens the plain type-token ratio is returned;
    otherwise the ratio is averaged over consecutive full 100-token segments
    (a trailing partial segment is ignored). Empty input yields 0.0.
    """
    if not jokes:
        return 0.0

    tokens = word_tokenize(" ".join(jokes).lower())
    if not tokens:
        return 0.0

    segment_len = 100
    token_count = len(tokens)
    if token_count < segment_len:
        return len(set(tokens)) / token_count

    # At least one full segment exists here; leftover tokens are dropped.
    full_segments = token_count // segment_len
    ttr_sum = 0
    for seg in range(full_segments):
        window = tokens[seg * segment_len:(seg + 1) * segment_len]
        ttr_sum += len(set(window)) / segment_len
    return ttr_sum / full_segments

def calculate_self_bleu(jokes):
    """
    Return the Self-BLEU score of the jokes.

    Each lower-cased, tokenized joke is scored as a BLEU hypothesis against
    all the other jokes as references (smoothing method 1), and the scores
    are averaged. Fewer than two jokes yields 0.0.
    """
    if len(jokes) < 2:
        return 0.0

    smoother = SmoothingFunction().method1
    token_lists = [word_tokenize(text.lower()) for text in jokes]

    scores = []
    for idx, candidate in enumerate(token_lists):
        # Every joke except the candidate itself serves as a reference.
        others = [toks for pos, toks in enumerate(token_lists) if pos != idx]
        scores.append(sentence_bleu(others, candidate, smoothing_function=smoother))
    return sum(scores, 0.0) / len(jokes)

# ==============================================================================
# 4. EXPERIMENT EXECUTION LOGIC
# ==============================================================================

def run_experiment(steps):
    """
    Run one parallel-vs-sequential comparison for a given step count.

    Parallel: `steps` independent calls with the initial prompt.
    Sequential: one initial call, then `steps - 1` refinement calls, each
    prompted with the latest joke plus the whole chain so far.

    Args:
        steps: Number of jokes to produce per paradigm. The sequential chain
            always contains at least the initial joke, even when steps < 1.

    Returns:
        A `(parallel_jokes, sequential_jokes)` tuple of lists of strings.
    """
    # The banners use real newlines; the earlier "\\n" printed a literal
    # backslash-n in front of each heading.
    print(f"\n{'='*25} RUNNING EXPERIMENT FOR {steps} STEPS {'='*25}")

    # Parallel
    print(f"\n--- Starting Parallel Generation ({steps} jokes) ---")
    parallel_jokes = [
        call_openrouter_api(INITIAL_USER_PROMPT, SYSTEM_PROMPT) for _ in range(steps)
    ]
    print("--- Parallel Generation Complete ---")

    # Sequential
    print(f"\n--- Starting Sequential Refinement ({steps} steps) ---")
    # The chain doubles as the history given to the refinement prompt, so no
    # separate history list is needed.
    sequential_jokes = [call_openrouter_api(INITIAL_USER_PROMPT, SYSTEM_PROMPT)]
    for _ in range(1, steps):
        prompt = get_refinement_prompt(sequential_jokes[-1], sequential_jokes)
        sequential_jokes.append(call_openrouter_api(prompt, SYSTEM_PROMPT))
    print("--- Sequential Refinement Complete ---")

    return parallel_jokes, sequential_jokes

# ==============================================================================
# 5. MAIN SCRIPT
# ==============================================================================

def main():
    """
    Run the full experiment, print a summary table and save the raw scores.

    For each of NUM_ITERATIONS iterations and each chain length in
    STEPS_TO_RUN, generates parallel and sequential jokes, scores both sets
    with the three diversity metrics, then reports mean ± standard deviation
    per metric and writes every per-iteration score to a timestamped JSON
    file in the current working directory.
    """
    print(f"Loading sentence transformer model ('{SENTENCE_TRANSFORMER_MODEL}')...")
    st_model = SentenceTransformer(SENTENCE_TRANSFORMER_MODEL)

    # Metric key -> scoring callable; insertion order is the evaluation order.
    scorers = {
        "sem_div": lambda jokes: calculate_semantic_diversity(jokes, st_model),
        "msttr": calculate_lexical_diversity,
        "self_bleu": calculate_self_bleu,
    }

    # results[steps][paradigm][metric] -> list with one score per iteration.
    results = {
        steps: {
            paradigm: {key: [] for key in scorers}
            for paradigm in ("parallel", "sequential")
        }
        for steps in STEPS_TO_RUN
    }

    for i in range(1, NUM_ITERATIONS + 1):
        print(f"\n\n{'#'*30} STARTING ITERATION {i}/{NUM_ITERATIONS} {'#'*30}")
        for steps in STEPS_TO_RUN:
            p_jokes, s_jokes = run_experiment(steps)

            print(f"\n--- Analyzing Diversity for {steps} steps (Iteration {i}) ---")
            for key, score in scorers.items():
                results[steps]["parallel"][key].append(score(p_jokes))
                results[steps]["sequential"][key].append(score(s_jokes))
            print("--- Analysis Complete ---")

    # --- Print Final Report with Averaged Results ---
    # Escape sequences are written with a single backslash so the terminal
    # receives real newlines and ANSI codes; the doubled form printed them
    # as literal text.
    print("\n\n" + "="*80)
    print(" " * 15 + f"Creative Scaling Experiment: Final Results Summary ({NUM_ITERATIONS} Iterations)")
    print("="*80)
    print(f"Model: {MODEL_ID}\n")
    print(f"{'Metric':<25} {'Steps':<10} {'Parallel (μ ± σ)':<25} {'Sequential (μ ± σ)':<25} {'Winner'}")
    print("-" * 80)

    metrics = [
        ("Semantic Diversity (Cosine Sim)", "sem_div", "Lower is better"),
        ("Lexical Diversity (MSTTR)", "msttr", "Higher is better"),
        ("N-Gram Diversity (Self-BLEU)", "self_bleu", "Lower is better")
    ]

    for metric_name, key, note in metrics:
        print(f"\033[1m{metric_name}\033[0m ({note})")
        for steps in STEPS_TO_RUN:
            p_vals = results[steps]["parallel"][key]
            s_vals = results[steps]["sequential"][key]

            p_mean, p_std = np.mean(p_vals), np.std(p_vals)
            s_mean, s_std = np.mean(s_vals), np.std(s_vals)

            # Ties go to Parallel, since the comparison is strict.
            is_seq_winner = (s_mean < p_mean if "Lower" in note else s_mean > p_mean)
            winner = "\033[92mSequential\033[0m" if is_seq_winner else "\033[94mParallel\033[0m"

            p_str = f"{p_mean:.4f} ± {p_std:.4f}"
            s_str = f"{s_mean:.4f} ± {s_std:.4f}"

            print(f"{'':<25} {steps:<10} {p_str:<25} {s_str:<25} {winner}")
        print("-" * 80)

    # Save results to file. json.dump turns the integer step keys into strings.
    timestamp = time.strftime("%Y%m%d_%H%M%S")
    results_file = f"creativity_results_{MODEL_ID.replace('/', '_')}_{timestamp}.json"
    with open(results_file, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2)
    print(f"\nResults saved to: {results_file}")

# Run the experiment only when this file is executed as a script, so that
# importing the module does not start any API calls.
if __name__ == "__main__":
    main()