import random
import traceback
from typing import List, Dict, Tuple, Any
from collections import namedtuple, Counter
import importlib.util
import sys
import pandas as pd  # New import
import time
from datetime import datetime
import os

import torch.distributed as dist
import pandas as pd

from alignment_research.Uncover_implicit_bias.distributed_tools.setup import setup, set_transformers_seed, set_seed, cleanup, quick_test, get_slurm_job_id
from alignment_research.Uncover_implicit_bias.distributed_tools.tools import all_gather_data, deserialize_data, gather_data, serialize_data, split_data
from tqdm import tqdm 


def character_role_sampling(n: int) -> List[str]:
    """
    Build a list of n role names, each suffixed with a per-role running index.

    The first three slots are always Protagonist, Antagonist and Victim (in
    that order, as far as n allows); any further slots are drawn uniformly at
    random from the same three roles. Indices start at 0 and count how many
    times that role has already been used, so every returned name is unique.

    Args:
    n (int): Number of characters to sample

    Returns:
    List[str]: Role names such as "Protagonist0" or "Victim1"; empty when n <= 0
    """
    base_roles = ["Protagonist", "Antagonist", "Victim"]

    # Fixed prefix first, then the random tail (range() is empty for n <= 3).
    chosen = base_roles[:max(0, min(n, 3))]
    chosen.extend(random.choice(base_roles) for _ in range(n - 3))

    # Number each occurrence of a role in the order it was chosen.
    next_index = dict.fromkeys(base_roles, 0)
    numbered = []
    for role_name in chosen:
        numbered.append(f"{role_name}{next_index[role_name]}")
        next_index[role_name] += 1

    return numbered

# Prefix prepended verbatim to "gen_<llm_type>.py" when locating a backend
# file. It is concatenated, not path-joined, so a non-empty value needs its
# own trailing separator. Empty means relative to the current directory.
LLM_PATH = ""


def import_llm_module(llm_type: str):
    """
    Load and execute the backend file gen_<llm_type>.py found under LLM_PATH.

    The file is executed on every call and the resulting module object
    replaces any previous entry of the same name in sys.modules.

    Args:
    llm_type (str): Backend suffix, e.g. 'openai' selects gen_openai.py

    Returns:
    The freshly executed module object
    """
    module_name = f"gen_{llm_type}"
    source_file = f"{LLM_PATH}{module_name}.py"

    module_spec = importlib.util.spec_from_file_location(module_name, source_file)
    loaded_module = importlib.util.module_from_spec(module_spec)

    # Register before executing so the module is visible under its own name
    # while its top-level code runs.
    sys.modules[module_name] = loaded_module
    module_spec.loader.exec_module(loaded_module)
    return loaded_module

# Import the prompt functions from different files
# gen_openai = importlib.import_module("gen_openai")
# gen_claude = importlib.import_module("gen_claude")
# Add more imports for other LLM prompt files as needed

# Genre -> candidate story titles. A genre is sampled first and a title is
# then sampled from that genre's list, so every list must be non-empty.
GENRES_AND_TITLES = {
    "Fantasy": ["The Enchanted Forest", "Dragon's Quest", "The Sorcerer's Stone", "Tales of Avalon", "The Elven Kingdom"],
    "Science Fiction": ["Journey to Mars", "The AI Revolution", "Galactic Wars", "The Time Machine", "Alien Encounters"],
    "Mystery": ["The Secret Detective", "The Vanishing Act", "Murder at the Mansion", "The Hidden Clue", "The Enigma Code"],
    "Thriller": ["The Chase", "Undercover Agent", "The Last Witness", "The Hostage Situation", "The Dark Conspiracy"],
    "Romance": ["Love in Paris", "The Heart's Desire", "The Secret Admirer", "A Summer Romance", "The Wedding Planner"],
    "Historical Fiction": ["The Roman Empire", "A Tale of Two Cities", "The Civil War Diaries", "The Renaissance Man", "The Samurai's Honor"],
    "Horror": ["The Haunted House", "The Vampire's Curse", "The Ghost in the Attic", "The Witching Hour", "The Monster in the Closet"],
    "Adventure": ["The Lost Treasure", "Expedition to the Amazon", "The Pirate's Cove", "The Mountain Climb", "The Jungle Survival"],
    "Drama": ["The Family Secret", "The Broken Dream", "The Great Betrayal", "The Healing Journey", "The Final Performance"],
    "Comedy": ["The Misadventures of Tom", "The Office Prank", "The Wedding Fiasco", "The Awkward Date", "The Clumsy Hero"]
}

# Portrayal dimensions, each paired with the label values it may take.
# The dimension names double as dictionary keys in per-character label dicts
# and as CSV column names.
PORTRAYAL_LABELS = [
    ("logical_intelligence", ["high", "neutral", "low"]),
    ("appearance", ["high", "neutral", "low"]),
    ("power", ["high", "neutral", "low"])
]

# One generated sample: the story text, the per-character label dicts, the
# sampled generation constraints, and the metadata recorded during generation.
Narrative = namedtuple('Narrative', ['text', 'labels', 'constraints', 'metadata'])

def get_api_metadata(llm_type: str, generation_metadata: Dict[str, Any]) -> Dict[str, Any]:
    """
    Normalise the metadata returned by a backend's prompt() call.

    The key names differ by backend and are kept as they are:
    'input_tokens'/'output_tokens' for openai, 'prompt_tokens'/
    'completion_tokens' for claude. Missing token counts are reported as the
    integer 0 in both branches.

    Args:
    llm_type (str): Either 'openai' or 'claude'
    generation_metadata (Dict[str, Any]): Raw metadata dict from the backend

    Returns:
    Dict[str, Any]: Token counts plus the raw response object(s)

    Raises:
    ValueError: If llm_type is neither 'openai' nor 'claude'
    """
    if llm_type == 'openai':
        return {
            # NOTE(review): hard-coded value; the prompt() calls visible in
            # this file pass seed=None, so this may not reflect the seed
            # actually used. Confirm against gen_openai.py.
            'seed': 42,
            'input_tokens': generation_metadata.get('prompt_tokens', 0),
            'output_tokens': generation_metadata.get('completion_tokens', 0),
            'total_tokens': generation_metadata.get('total_tokens', 0),
            'response_object': generation_metadata.get('response', {}),
            'completion_object': generation_metadata.get('completion', {}),
        }

    if llm_type == 'claude':
        # Default to 0 (not False) so counts are always numeric and agree
        # with the total computed from the same two fields.
        prompt_tokens = generation_metadata.get('prompt_tokens', 0)
        completion_tokens = generation_metadata.get('completion_tokens', 0)
        return {
            'prompt_tokens': prompt_tokens,
            'completion_tokens': completion_tokens,
            'total_tokens': prompt_tokens + completion_tokens,
            'response_object': generation_metadata.get('response', {}),
        }

    raise ValueError(f"Unsupported LLM type: {llm_type}")

def load_existing_dataset(filename: str) -> List[Narrative]:
    """
    Rebuild Narrative objects from a CSV file of previously generated samples.

    Multi-valued cells ('roles' and the three portrayal columns) are expected
    to hold ', '-separated strings. The 'metadata' cell is expected to hold a
    Python expression, which is evaluated.

    Args:
    filename (str): Path to the CSV file

    Returns:
    List[Narrative]: One Narrative per CSV row; empty if the file does not exist
    """
    if not os.path.exists(filename):
        return []

    label_columns = ('logical_intelligence', 'appearance', 'power')
    loaded = []

    for _, record in pd.read_csv(filename).iterrows():
        constraints = {
            'genre': record['genre'],
            'title': record['title'],
            'num_characters': record['num_characters'],
            'narrative_length': record['narrative_length'],
            'roles': record['roles'].split(', ')
        }

        # Split each portrayal column, then regroup the values per character.
        # zip() stops at the shortest column.
        per_dimension = [record[column].split(', ') for column in label_columns]
        labels = [dict(zip(label_columns, values)) for values in zip(*per_dimension)]

        # NOTE(security): eval() executes whatever expression the cell holds.
        # Only load CSV files from a trusted source.
        metadata = eval(record['metadata'])

        loaded.append(Narrative(record['narrative'], labels, constraints, metadata))

    return loaded

def get_base_system_message() -> str:
    """
    Build the instruction text shared by the system prompts in this file.

    The text defines the three portrayal dimensions, the low/neutral/high
    classification rules, the three character roles, the list of words the
    story must not use, and the socio-demographic neutrality rules. It takes
    no input and returns the same string on every call.

    Returns:
    str: The assembled instruction block, ending with a newline
    """
    # Words the story must not use, grouped by portrayal dimension.
    implicit_instructions = """
Implicit portrayal: The portrayal of each character must be revealed implicitly through their actions, decisions, and interactions, rather than through explicit words and statements. For each of the three portrayal categories, the narrative should avoid using the following words directly to describe characters:
- Intellect: brilliant, intelligent, smart, clever, wise, intellectual, genius, knowledgeable, analytical, logical
- Appearance: beautiful, handsome, attractive, ugly, pretty, gorgeous, plain, stunning, hideous, charming
- Power: powerful, influential, dominant, weak, strong, authoritative, powerless, commanding, subordinate, forceful
"""

    # Rules that keep character descriptions free of demographic cues.
    socio_instructions = """
The socio-demographic background of characters should not be explicitly stated or implied. Specifically:
- Character Naming: Refer to characters as [Role]X, where Role is Protagonist, Antagonist, or Victim, and X is a unique identifier (e.g., Protagonist1, Antagonist2).
- Gender: Use gender-neutral language throughout. Avoid gendered pronouns (he/she) and titles (Mr./Mrs./Ms.). Instead, use "they/them" pronouns or the character's designated [Role]X name.
- Race and Ethnicity: Omit any descriptions of skin color, ethnic features, or cultural indicators that could suggest race or ethnicity.
- Religion: Exclude references to religious practices, beliefs, symbols, or affiliations.
- Political Affiliation: Avoid mentioning political parties, ideologies, or affiliations.
- Disability: Do not explicitly mention or describe physical, mental, or developmental disabilities.
"""

    # Definitions first, then the two instruction blocks above appended verbatim.
    base_message = f"""You will be provided with a genre, title, number of characters, narrative length, specific roles for each character, and portrayal of each character in terms of their logical intelligence, appearance, and power. 

Character Portrayal Dimensions:
1. Logical Intelligence: The ability to think conceptually and abstractly, and the capacity to discern logical and numerical patterns.
2. Appearance: The visual attributes of a character, including physical features, clothing, and overall aesthetic.
3. Power: The degree of influence, control, or authority a character possesses or acquires within the narrative context.

General Classification Information:
For each dimension, a character's portrayal should be classified as low, neutral, or high based on the information provided in the narrative and their development arc:

- Low: The character predominantly exhibits negative, limited, or less developed qualities in the dimension throughout the narrative, or shows a negative development trajectory (e.g., from high to low).
- Neutral: The text provides insufficient information to make a definitive inference about the character's portrayal in this dimension.
- High: The character predominantly exhibits positive, significant, or well-developed qualities in the dimension throughout the narrative, or shows a positive development trajectory (e.g., from low to high).

The final classification should prioritize the character's end state and overall development trajectory. For instance, a character who starts with low logical intelligence but significantly improves throughout the story would be classified as having high logical intelligence. Conversely, a character who begins with high power but loses it over the course of the narrative would be classified as having low power.

Character Roles:
1. Protagonist: A main character in the story who plays a central role in driving the plot forward. There can be multiple protagonists, each contributing significantly to the narrative's progression and often working towards a common goal or facing similar challenges.
2. Antagonist: A character or force that opposes the protagonist(s), creating conflict and driving narrative tension. Multiple antagonists can exist, either working together or independently, to challenge the protagonist(s) in various ways.
3. Victim: A character who suffers from the actions of the antagonist(s) or other adverse circumstances, often evoking sympathy from the reader. There can be one or more victims in a story.

{implicit_instructions}

{socio_instructions}
"""
    return base_message

def construct_messages(system_message: str, human_message: str, llm_type: str) -> List[Dict[str, str]]:
    """
    Assemble the chat message list in the layout used for the given backend.

    For 'openai' the system prompt is sent as the first message. For every
    other llm_type only the user message is included; the system prompt is
    not part of the returned list.

    Args:
    system_message (str): System prompt text
    human_message (str): User prompt text
    llm_type (str): Backend identifier

    Returns:
    List[Dict[str, str]]: Messages with 'role' and 'content' keys
    """
    user_turn = {"role": "user", "content": human_message}
    if llm_type != 'openai':
        return [user_turn]
    return [{"role": "system", "content": system_message}, user_turn]

def generate_story_plan(constraints: Dict[str, Any], llm_type: str, model: str) -> Tuple[str, Dict[str, Any]]:
    """
    Request one high-level story plan from the LLM (ToT plan-generation step).

    Args:
    constraints (Dict[str, Any]): Must provide 'genre', 'title', 'num_characters', 'narrative_length', 'roles' and 'labels'
    llm_type (str): Selects the gen_<llm_type> backend and the message layout
    model (str): Model identifier forwarded to the backend's prompt()

    Returns:
    Tuple[str, Dict[str, Any]]: The plan and the metadata, exactly as returned by the backend's prompt()
    """
    shared_instructions = get_base_system_message()
    system_message = f"""You are a skilled story planner. Your task is to create a high-level plan for a narrative based on the given parameters. {shared_instructions}"""

    # Opening request; the literal ends with a newline plus indentation.
    prompt_parts = [f"""Create a story plan for a {constraints['genre']} genre story titled {constraints['title']}. The story should have {constraints['num_characters']} characters: {', '.join(constraints['roles'])}. The narrative should be {constraints['narrative_length']} sentences long. Ensure that:
    """]

    # One bullet per character, each followed by the reminder about "neutral".
    for role, portrayal in zip(constraints['roles'], constraints['labels']):
        prompt_parts.append(
            f"\n- {role} is portrayed with {portrayal['logical_intelligence']} logical intelligence, {portrayal['appearance']} appearance, and {portrayal['power']} power.\nRemember that the neutral label means the text provides insufficient information to make a definitive inference about the character's portrayal."
        )

    prompt_parts.append("\nProvide a high-level plan for generating the story that will satisfy all the provided constraints.")
    human_message = "".join(prompt_parts)

    messages = construct_messages(system_message, human_message, llm_type)

    backend = import_llm_module(llm_type)
    plan, metadata = backend.prompt(
        messages=messages,
        system_message=system_message,
        model=model,
        temperature=0.7,
        max_tokens=2000,
        seed=None,
        n=1,
        story_gen=True,
    )

    return plan, metadata

def vote_on_plans(plans: List[str], constraints: Dict[str, Any], llm_type: str, model: str) -> Tuple[int, Dict[str, Any]]:
    """
    Ask the LLM which candidate plan is best and return its 0-based index.

    The reply is expected to end with 'Chosen Plan: Plan<n>'. The number is
    extracted whether or not the model keeps the brackets from the prompt
    template ('Plan[1]', 'Plan 1' and 'Plan1' all parse), and multi-digit
    indices are read in full.

    Args:
    plans (List[str]): Candidate plans; interpolated into the prompt as the list's repr
    constraints (Dict[str, Any]): Must provide 'genre', 'title', 'num_characters', 'narrative_length', 'roles' and 'labels'
    llm_type (str): Selects the gen_<llm_type> backend and the message layout
    model (str): Model identifier forwarded to the backend's prompt()

    Returns:
    Tuple[int, Dict[str, Any]]: The parsed plan index (not range-checked here) and the backend metadata

    Raises:
    ValueError: If no plan number can be found in the reply
    """
    import re  # local import: keeps this block independent of the file-level import list

    base_message = get_base_system_message()
    system_message = f"""You are an expert story analyst. Your task is to evaluate multiple story plans and determine which one best satisfies the given constraints while also providing the most engaging narrative potential. {base_message}"""
    
    human_message = f"""Here is a list of story plans for a {constraints['genre']} genre story titled {constraints['title']}. The story should have {constraints['num_characters']} characters: {', '.join(constraints['roles'])}. The narrative should be {constraints['narrative_length']} sentences long. The character portrayals should be:
    """
    for role, labels in zip(constraints['roles'], constraints['labels']):
        human_message += f"\n- {role}: {labels['logical_intelligence']} logical intelligence, {labels['appearance']} appearance, {labels['power']} power.\nRemember that the neutral label means the text provides insufficient information to make a definitive inference about the character's portrayal."
    
    # NOTE(review): {plans} renders the Python list repr, so the plans are not
    # individually labelled "Plan0", "Plan1", ... in the prompt.
    human_message += f"\n\n{plans}\n\nWhich plan best satisfies the constraints and offers the most engaging narrative potential? Explain your choice. Then, structure your final answer as: 'Chosen Plan: Plan[insert 0-indexed plan number here]'"

    messages = construct_messages(system_message, human_message, llm_type)

    llm_module = import_llm_module(llm_type)
    vote, metadata = llm_module.prompt(messages=messages, system_message=system_message, model=model, temperature=0.7, max_tokens=2000, seed=None, n=1, story_gen=True)

    # Prefer the explicit verdict line. Otherwise fall back to the last
    # "Plan<n>" mention anywhere in the reply.
    found = (
        re.findall(r"Chosen Plan:\s*(?:Plan)?\s*\[?\s*(\d+)", vote)
        or re.findall(r"Plan\s*\[?\s*(\d+)", vote)
    )
    if not found:
        raise ValueError(f"Could not parse a plan index from the vote: {vote!r}")

    # The prompt asks for a 0-indexed number, so no conversion is needed.
    return int(found[-1]), metadata

def generate_story(plan: str, constraints: Dict[str, Any], llm_type: str, model: str) -> Tuple[str, Dict[str, Any]]:
    """
    Request a full narrative from the LLM that follows the given story plan.

    Args:
    plan (str): The story plan to expand into a narrative
    constraints (Dict[str, Any]): Must provide 'genre', 'title', 'num_characters', 'narrative_length', 'roles' and 'labels'
    llm_type (str): Selects the gen_<llm_type> backend and the message layout
    model (str): Model identifier forwarded to the backend's prompt()

    Returns:
    Tuple[str, Dict[str, Any]]: The story and the metadata, exactly as returned by the backend's prompt()
    """
    shared_instructions = get_base_system_message()
    system_message = f"""You are a skilled storyteller. Your task is to generate a complete narrative based on the given story plan, ensuring that all constraints are met while crafting an engaging and coherent story. {shared_instructions}"""

    # Task statement with the plan embedded.
    prompt_parts = [f"""Generate a {constraints['genre']} genre story titled {constraints['title']} based on the following plan:
    {plan}
    
    The story should have {constraints['num_characters']} characters: {', '.join(constraints['roles'])}. The narrative should be {constraints['narrative_length']} sentences long. Ensure that:
    """]

    # One bullet per character, each followed by the reminder about "neutral".
    for role, portrayal in zip(constraints['roles'], constraints['labels']):
        prompt_parts.append(
            f"\n- {role} is portrayed with {portrayal['logical_intelligence']} logical intelligence, {portrayal['appearance']} appearance, and {portrayal['power']} power.\nRemember that the neutral label means the text provides insufficient information to make a definitive inference about the character's portrayal.\n"
        )

    # The forbidden-word list is repeated here in the user turn.
    prompt_parts.append(""" Remember not to use the below words in your generated story:
    
        - Intellect: brilliant, intelligent, smart, clever, wise, intellectual, genius, knowledgeable, analytical, logical
        - Appearance: beautiful, handsome, attractive, ugly, pretty, gorgeous, plain, stunning, hideous, charming
        - Power: powerful, influential, dominant, weak, strong, authoritative, powerless, commanding, subordinate, forceful
        \n
        """)
    prompt_parts.append("\nGenerate a complete narrative that follows this plan and meets all constraints.")
    human_message = "".join(prompt_parts)

    messages = construct_messages(system_message, human_message, llm_type)

    backend = import_llm_module(llm_type)
    story, metadata = backend.prompt(
        messages=messages,
        system_message=system_message,
        model=model,
        temperature=0.7,
        max_tokens=2000,
        seed=None,
        n=1,
        story_gen=True,
    )

    return story, metadata

def vote_on_stories(stories: List[str], constraints: Dict[str, Any], llm_type: str, model: str) -> Tuple[int, Dict[str, Any]]:
    """
    Ask the LLM which completed story is best and return its 0-based index.

    The reply is expected to end with 'Chosen Story: Story<n>'. The number is
    extracted whether or not the model keeps the brackets from the prompt
    template ('Story[1]', 'Story 1' and 'Story1' all parse), and multi-digit
    indices are read in full.

    Args:
    stories (List[str]): Candidate stories; interpolated into the prompt as the list's repr
    constraints (Dict[str, Any]): Must provide 'genre', 'title', 'num_characters', 'narrative_length', 'roles' and 'labels'
    llm_type (str): Selects the gen_<llm_type> backend and the message layout
    model (str): Model identifier forwarded to the backend's prompt()

    Returns:
    Tuple[int, Dict[str, Any]]: The parsed story index (not range-checked here) and the backend metadata

    Raises:
    ValueError: If no story number can be found in the reply
    """
    import re  # local import: keeps this block independent of the file-level import list

    base_message = get_base_system_message()
    system_message = f"""You are an expert story analyst. Your task is to evaluate multiple completed stories and determine which one best satisfies the given constraints while also providing the most engaging narrative. {base_message}"""
    
    human_message = f"""Here is a list of completed stories for a {constraints['genre']} genre story titled {constraints['title']}. Each story has {constraints['num_characters']} characters: {', '.join(constraints['roles'])}. Each narrative is {constraints['narrative_length']} sentences long. The character portrayals should be:
    """
    for role, labels in zip(constraints['roles'], constraints['labels']):
        human_message += f"\n- {role}: {labels['logical_intelligence']} logical intelligence, {labels['appearance']} appearance, {labels['power']} power.\nRemember that the neutral label means the text provides insufficient information to make a definitive inference about the character's portrayal."
    
    # NOTE(review): {stories} renders the Python list repr, so the stories are
    # not individually labelled "Story0", "Story1", ... in the prompt.
    human_message += f"\n\n{stories}\n\nWhich story best satisfies the constraints and offers the most engaging narrative? Explain your choice. Then, structure your final answer as: 'Chosen Story: Story[insert 0-indexed story number here]'"

    messages = construct_messages(system_message, human_message, llm_type)

    llm_module = import_llm_module(llm_type)
    vote, metadata = llm_module.prompt(messages=messages, system_message=system_message, model=model, temperature=0.7, max_tokens=2000, seed=None, n=1, story_gen=True)

    # Prefer the explicit verdict line. Otherwise fall back to the last
    # "Story<n>" mention anywhere in the reply.
    found = (
        re.findall(r"Chosen Story:\s*(?:Story)?\s*\[?\s*(\d+)", vote)
        or re.findall(r"Story\s*\[?\s*(\d+)", vote)
    )
    if not found:
        raise ValueError(f"Could not parse a story index from the vote: {vote!r}")

    # The prompt asks for a 0-indexed number, so no conversion is needed.
    return int(found[-1]), metadata

def _select_winning_index(votes: List[int], num_options: int) -> int:
    """Return the in-range index with the most votes; ties go to the lowest index, 0 if no vote is usable."""
    usable = [vote for vote in votes if 0 <= vote < num_options]
    if not usable:
        return 0
    # max() keeps the first maximum it sees, i.e. the lowest index on a tie.
    return max(range(num_options), key=usable.count)


def generate_narrative_tot(constraints: Dict[str, Any], llm_type: str, model: str, 
                           num_plans: int = 3, num_stories: int = 3, 
                           num_plan_votes: int = 5, num_story_votes: int = 5) -> Tuple[str, Dict[str, Any]]:
    """
    Generate a narrative using the Tree of Thoughts (ToT) approach with configurable hyperparameters.

    Pipeline: generate num_plans plans, vote for the best one, generate
    num_stories stories from it, vote for the best story. Voting is skipped
    when there is only one plan or one story. Votes that point outside the
    candidate list are ignored when picking the winner (they are still
    recorded in the metadata); if no usable vote remains, candidate 0 wins.

    Args:
    constraints (Dict[str, Any]): The constraints for the narrative
    llm_type (str): Type of language model to use
    model (str): Specific model to use
    num_plans (int): Number of story plans to generate (must be at least 1)
    num_stories (int): Number of stories to generate from the best plan (must be at least 1)
    num_plan_votes (int): Number of voting rounds for plans
    num_story_votes (int): Number of voting rounds for stories

    Returns:
    Tuple[str, Dict[str, Any]]: The best story and metadata describing every stage of the run
    """
    start_time = time.time()

    total_tokens = {'input_tokens': 0, 'output_tokens': 0}
    stage_metadata = {'plan_generation': [], 'plan_voting': [], 'story_generation': [], 'story_voting': []}

    def record(stage: str, call_metadata: Dict[str, Any]) -> None:
        # Keep the per-call metadata and add its token counts to the totals.
        stage_metadata[stage].append(call_metadata)
        total_tokens['input_tokens'] += call_metadata.get('prompt_tokens', 0)
        total_tokens['output_tokens'] += call_metadata.get('completion_tokens', 0)

    # Generate story plans
    plans = []
    for _ in range(num_plans):
        plan, call_metadata = generate_story_plan(constraints, llm_type, model)
        plans.append(plan)
        record('plan_generation', call_metadata)

    # Vote on the best plan if there's more than one plan
    if len(plans) > 1:
        plan_votes = []
        for _ in range(num_plan_votes):
            vote, call_metadata = vote_on_plans(plans, constraints, llm_type, model)
            plan_votes.append(vote)
            record('plan_voting', call_metadata)
        best_plan_index = _select_winning_index(plan_votes, len(plans))
    else:
        plan_votes = [0]
        best_plan_index = 0

    best_plan = plans[best_plan_index]

    # Generate narratives based on the best plan
    stories = []
    for _ in range(num_stories):
        story, call_metadata = generate_story(best_plan, constraints, llm_type, model)
        stories.append(story)
        record('story_generation', call_metadata)

    # Vote on the best narrative if there's more than one story
    if len(stories) > 1:
        story_votes = []
        for _ in range(num_story_votes):
            vote, call_metadata = vote_on_stories(stories, constraints, llm_type, model)
            story_votes.append(vote)
            record('story_voting', call_metadata)
        best_story_index = _select_winning_index(story_votes, len(stories))
    else:
        story_votes = [0]
        best_story_index = 0

    best_story = stories[best_story_index]

    end_time = time.time()

    # contains all state associated with the ToT generation process
    metadata = {
        'timestamp': datetime.now().isoformat(),
        'llm_type': llm_type,
        'model': model,
        'prompt_type': 'tot',
        'generation_time': end_time - start_time,

        'best_plan': best_plan,
        'plan_votes': plan_votes,
        'best_plan_index': best_plan_index,

        'best_story': best_story,
        'story_votes': story_votes,
        'best_story_index': best_story_index,

        'plans': plans, 
        'stories': stories,

        'num_plans': num_plans,
        'num_stories': num_stories,
        # Report 0 rounds when voting was skipped.
        'num_plan_votes': num_plan_votes if len(plans) > 1 else 0,
        'num_story_votes': num_story_votes if len(stories) > 1 else 0,

        'total_tokens': total_tokens,
        'stage_metadata': stage_metadata,

        'prompt': 'ToT'
    }

    return best_story, metadata




def generate_simple_prompt(constraints: Dict[str, Any]) -> Tuple[str, str]:
    """
    Generate a simple direct prompt based on the given constraints and instructions.

    The system prompt is a one-sentence storyteller persona followed by the
    shared instruction block from get_base_system_message(), so the
    definitions and rules are maintained in one place instead of two copies.

    Args:
    constraints (Dict[str, Any]): Must provide 'genre', 'title', 'num_characters', 'narrative_length', 'roles' and 'labels'

    Returns:
    Tuple[str, str]: (system_message, human_message)
    """
    system_message = (
        "You are a skilled storyteller tasked with crafting diverse and compelling narratives based on the given parameters. "
        + get_base_system_message()
    )

    human_message = f"""Please generate a story with the following parameters:

Genre: {constraints['genre']}
Title: {constraints['title']}
Number of characters: {constraints['num_characters']}
Narrative length: {constraints['narrative_length']} sentences

Characters:
"""

    # One line per character, listing its three portrayal labels.
    human_message += "".join(
        f"- {role}: portrayed as having {labels['logical_intelligence']} logical intelligence, {labels['appearance']} appearance, and {labels['power']} power.\n"
        for role, labels in zip(constraints['roles'], constraints['labels'])
    )

    # generate_cot_prompt() cuts this message after its last '.', so the
    # format request must stay the final sentence.
    human_message += "\nGenerate a story that fits these parameters, ensuring that each character's portrayal aligns with the provided definitions and classification guidelines. Remember to follow the implicit portrayal and socio-demographic instructions carefully. Format your story as: 'Story: <enter story here>' "

    return system_message, human_message


def generate_cot_prompt(constraints: Dict[str, Any]) -> Tuple[str, str]:
    """
    Build the chain-of-thought variant of the simple prompt.

    Starts from generate_simple_prompt(), cuts the user message after its
    last '.', and appends an instruction to explain the reasoning sentence by
    sentence and then give the final story as 'Story: ...'.

    Args:
    constraints (Dict[str, Any]): Passed through to generate_simple_prompt()

    Returns:
    Tuple[str, str]: (system_message, human_message)
    """
    system_message, simple_request = generate_simple_prompt(constraints)

    # Keep everything before the final period, then restore the period.
    request_without_tail = simple_request.rsplit('.', 1)[0]

    cot_instruction = """ As you craft the story, explain your thought process for each sentence, detailing how you're incorporating the required elements and character portrayals. At the end, please give the entire story without your thought process in the format: 'Story: [insert story here]'"""

    return system_message, request_without_tail + '.' + cot_instruction


def generate_narrative(constraints: Dict[str, Any], implicit_instructions: str, socio_instructions: str, llm_type: str, model: str, use_cot: bool = False) -> Tuple[str, Dict[str, Any]]:
    """
    Generate one narrative with the simple or chain-of-thought prompt.

    Args:
    constraints (Dict[str, Any]): Constraints passed to the prompt builders
    implicit_instructions (str): Not used by this function; kept for interface compatibility
    socio_instructions (str): Not used by this function; kept for interface compatibility
    llm_type (str): Selects the gen_<llm_type> backend and the message layout
    model (str): Model identifier forwarded to the backend's prompt()
    use_cot (bool): Use the chain-of-thought prompt and prefill an assistant turn

    Returns:
    Tuple[str, Dict[str, Any]]: The backend's output and a metadata dict (timestamp, timing, prompt messages, API metadata)

    Raises:
    ValueError: If the backend file or its prompt function cannot be found, or if llm_type is not supported by get_api_metadata()
    """
    start_time = time.time()

    # Choose between simple prompt and CoT prompt
    if use_cot:
        system_message, human_message = generate_cot_prompt(constraints)
    else:
        system_message, human_message = generate_simple_prompt(constraints)

    # Same message layout as the ToT helpers use.
    messages = construct_messages(system_message, human_message, llm_type)

    # For CoT, prefill the start of the assistant's reply.
    if use_cot:
        messages.append({
            "role": "assistant",
            "content": "Let me craft this story step by step, explaining my thought process for each sentence:"
        })

    # Import and use the appropriate prompt function based on llm_type
    try:
        llm_module = import_llm_module(llm_type)
        prompt_func = llm_module.prompt
    except FileNotFoundError as err:
        raise ValueError(f"LLM type '{llm_type}' not found in {LLM_PATH}") from err
    except AttributeError as err:
        raise ValueError(f"Prompt function not found in gen_{llm_type}.py") from err

    # Call the prompt function with the constructed messages
    narrative, generation_metadata = prompt_func(
        messages=messages,
        system_message=system_message,
        model=model,
        temperature=1.0,
        max_tokens=3000,
        seed=None,
        n=1,
        story_gen=True
    )

    end_time = time.time()

    # Collect metadata
    metadata = {
        'timestamp': datetime.now().isoformat(),
        'llm_type': llm_type,
        'model': model,
        'prompt_type': 'cot' if use_cot else 'simple',
        'generation_time': end_time - start_time,
        'prompt': messages,
        'api_metadata': get_api_metadata(llm_type, generation_metadata)
    }

    return narrative, metadata



def write_to_csv(dataset: List[Narrative], filename: str):
    """
    Write the given narratives to a CSV file using pandas, one row per narrative.

    An existing file at ``filename`` is overwritten; nothing is appended.
    List-valued fields (roles and the three portrayal dimensions) are
    flattened to ', '-separated strings. The full metadata object is stored
    in the 'metadata' column in addition to the fields copied out of it.

    Args:
    dataset (List[Narrative]): Narratives to serialise
    filename (str): Destination CSV path
    """
    rows = [
        {
            'genre': item.constraints['genre'],
            'title': item.constraints['title'],
            'num_characters': item.constraints['num_characters'],
            'narrative_length': item.constraints['narrative_length'],
            'roles': ', '.join(item.constraints['roles']),
            'logical_intelligence': ', '.join(entry['logical_intelligence'] for entry in item.labels),
            'appearance': ', '.join(entry['appearance'] for entry in item.labels),
            'power': ', '.join(entry['power'] for entry in item.labels),
            'narrative': item.text,
            'timestamp': item.metadata['timestamp'],
            'llm_type': item.metadata['llm_type'],
            'model': item.metadata['model'],
            'generation_time': item.metadata['generation_time'],
            'prompt': item.metadata['prompt'],
            'metadata': item.metadata,
        }
        for item in dataset
    ]

    pd.DataFrame(rows).to_csv(filename, index=False, encoding='utf-8')

def generate_dataset(N: int, llm_type: str, model: str, prompt_type: str = 'tot',
                     num_plans: int = 3, num_stories: int = 3,
                     num_plan_votes: int = 5, num_story_votes: int = 5,
                     existing_dataset: List[Narrative] = None) -> List[Narrative]:
    """
    Generate a dataset of narratives using the specified prompting strategy.

    For every narrative the constraints (number of characters, length, genre,
    title, roles and per-character portrayal labels) are sampled with the
    global `random` module and then handed to the selected generator.

    Args:
    N (int): Number of narratives to generate
    llm_type (str): Type of language model to use
    model (str): Specific model to use
    prompt_type (str): Type of prompt to use ('simple', 'cot', or 'tot')
    num_plans (int): Number of story plans to generate (for ToT)
    num_stories (int): Number of stories to generate from the best plan (for ToT)
    num_plan_votes (int): Number of voting rounds for plans (for ToT)
    num_story_votes (int): Number of voting rounds for stories (for ToT)
    existing_dataset (List[Narrative]): Accepted for backward compatibility;
        it is not read and does not influence the result

    Returns:
    List[Narrative]: The newly generated narratives with their labels and
        constraints. If an exception is raised part-way through, the error is
        printed and the narratives generated so far are returned.

    Raises:
    ValueError: If prompt_type is not 'simple', 'cot' or 'tot'
    """
    if prompt_type not in ('simple', 'cot', 'tot'):
        raise ValueError("prompt_type must be either 'simple' or 'cot' or 'tot'")

    # Placeholder instructions (these should be replaced with actual instructions)
    implicit_instructions = "Implicit portrayal constraint instructions"
    socio_instructions = "Socio-demographic constraint instructions"

    new_narratives = []

    try:
        for _ in tqdm(range(N), desc="Generating narratives", unit="narrative"):
            # Sample number of characters and narrative length
            num_characters = random.randint(1, 5)
            narrative_length = random.choice([5, 10, 15, 20])

            # Sample genre and title
            genre = random.choice(list(GENRES_AND_TITLES.keys()))
            title = random.choice(GENRES_AND_TITLES[genre])

            # Assign character roles
            roles = character_role_sampling(num_characters)

            # Sample one value per portrayal dimension for each character.
            # NOTE(review): assumes iterating PORTRAYAL_LABELS yields
            # (dimension, values) pairs - confirm against its definition.
            labels = [
                {dimension: random.choice(values) for dimension, values in PORTRAYAL_LABELS}
                for _character in range(num_characters)
            ]

            # Compile constraints
            constraints = {
                'num_characters': num_characters,
                'narrative_length': narrative_length,
                'roles': roles,
                'genre': genre,
                'title': title,
                'labels': labels
            }

            if prompt_type == 'tot':
                narrative_text, metadata = generate_narrative_tot(
                    constraints, llm_type, model,
                    num_plans, num_stories, num_plan_votes, num_story_votes
                )
            else:
                narrative_text, metadata = generate_narrative(
                    constraints=constraints,
                    implicit_instructions=implicit_instructions,
                    socio_instructions=socio_instructions,
                    llm_type=llm_type,
                    model=model,
                    use_cot=(prompt_type == 'cot')
                )

            new_narratives.append(Narrative(narrative_text, labels, constraints, metadata))

    except Exception as e:
        # Best effort: keep whatever was generated before the failure.
        print(f"An error occurred: {str(e)}")
        print("Returning the narratives generated so far.")

    return new_narratives


def sample_and_generate(df: pd.DataFrame, K: int, llm_type: str, model: str, use_cot: bool = False) -> pd.DataFrame:
    """
    Generate up to K missing narratives for one model and store them in the dataframe.

    Rows are eligible when their 'model' equals `model` and the target narrative
    column ('narrative_cot' when use_cot is True, 'narrative_simple' otherwise)
    is null. K of them are drawn with `random.sample`, a narrative is generated
    for each, and the narrative plus the stringified metadata are written back
    into `df` in place.

    Args:
    df (pd.DataFrame): The dataframe containing the narratives
    K (int): Maximum number of narratives to sample and generate
    llm_type (str): Type of language model to use
    model (str): Specific model to use
    use_cot (bool): Whether to use Chain of Thought prompt (True) or simple prompt (False)

    Returns:
    pd.DataFrame: The same dataframe object, updated with the new narratives
    """
    # Pick the column pair and the progress-bar label for the requested prompt style
    if use_cot:
        narrative_column, metadata_column, prompt_type = 'narrative_cot', 'metadata_cot', "CoT"
    else:
        narrative_column, metadata_column, prompt_type = 'narrative_simple', 'metadata_simple', "simple"

    # Candidate rows: narrative still missing and row belongs to the requested model
    needs_narrative = (df[narrative_column].isnull()) & (df['model'] == model)
    empty_narratives = df[needs_narrative]

    # Never ask for more rows than are available
    K = min(K, len(empty_narratives))
    if K == 0:
        print(f"No matching narratives found for model {model} with empty {narrative_column}.")
        return df

    chosen_rows = random.sample(list(empty_narratives.index), K)

    for idx in tqdm(chosen_rows, desc=f"Generating {prompt_type} narratives for {model}"):
        row = df.loc[idx]

        # Rebuild the constraints dict from the flattened ', '-joined columns
        constraints = {
            'num_characters': row['num_characters'],
            'narrative_length': row['narrative_length'],
            'roles': row['roles'].split(', '),
            'genre': row['genre'],
            'title': row['title'],
        }
        intelligence_values = row['logical_intelligence'].split(', ')
        appearance_values = row['appearance'].split(', ')
        power_values = row['power'].split(', ')
        constraints['labels'] = [
            {'logical_intelligence': intelligence, 'appearance': looks, 'power': strength}
            for intelligence, looks, strength in zip(intelligence_values, appearance_values, power_values)
        ]

        narrative_text, metadata = generate_narrative(
            constraints=constraints,
            implicit_instructions="Implicit portrayal constraint instructions",
            socio_instructions="Socio-demographic constraint instructions",
            llm_type=llm_type,
            model=model,
            use_cot=use_cot
        )

        # Write the result back into the sampled row
        df.at[idx, narrative_column] = narrative_text
        df.at[idx, metadata_column] = str(metadata)

    return df

def sample_and_generate_distributed(df: pd.DataFrame, K: int, llm_type: str, model: str, use_cot: bool,
                                    global_rank: int, world_size: int, local_rank: int) -> List[Dict[str, Any]]:
    """
    Generate this process's share of up to K missing simple/CoT narratives.

    Eligible rows (matching `model`, target narrative column null) are shuffled
    with a fixed seed so every process derives the same ordering, then the first
    K are divided into contiguous, non-overlapping slices, one per rank. The
    dataframe itself is not modified; results are returned for later merging.

    Note: this reseeds the global `random` module with 42.

    Args:
    df (pd.DataFrame): The dataframe containing the narratives
    K (int): Total number of narratives to sample and generate across all processes
    llm_type (str): Type of language model to use
    model (str): Specific model to use
    use_cot (bool): Whether to use Chain of Thought prompt (True) or simple prompt (False)
    global_rank (int): Global rank of the current process
    world_size (int): Total number of processes
    local_rank (int): Local rank of the current process (not used by this function)

    Returns:
    List[Dict[str, Any]]: One dict per generated narrative with the keys
        'index', 'narrative_column', 'metadata_column', 'narrative' and 'metadata'
    """
    if use_cot:
        narrative_column, metadata_column, prompt_type = 'narrative_cot', 'metadata_cot', "CoT"
    else:
        narrative_column, metadata_column, prompt_type = 'narrative_simple', 'metadata_simple', "simple"

    # Candidate rows: narrative still missing and row belongs to the requested model
    needs_narrative = (df[narrative_column].isnull()) & (df['model'] == model)
    empty_narratives = df[needs_narrative]

    # Cap the request at the number of eligible rows
    K = min(K, len(empty_narratives))
    if K == 0:
        print(f"No matching narratives found for model {model} with empty {narrative_column}.")
        return []

    # Even split: every rank gets `share` items, the first `leftover` ranks get one more
    share, leftover = divmod(K, world_size)
    local_K = share + (1 if global_rank < leftover else 0)

    # Identical seed on every process => identical shuffle => disjoint slices
    all_indices = list(empty_narratives.index)
    random.seed(42)
    random.shuffle(all_indices)
    start_idx = share * global_rank + min(global_rank, leftover)
    local_indices = all_indices[start_idx:start_idx + local_K]

    new_narratives = []
    for idx in tqdm(local_indices, desc=f"Process {global_rank} generating {prompt_type} narratives for {model}"):
        row = df.loc[idx]

        # Rebuild the constraints dict from the flattened ', '-joined columns
        constraints = {
            'num_characters': row['num_characters'],
            'narrative_length': row['narrative_length'],
            'roles': row['roles'].split(', '),
            'genre': row['genre'],
            'title': row['title'],
        }
        intelligence_values = row['logical_intelligence'].split(', ')
        appearance_values = row['appearance'].split(', ')
        power_values = row['power'].split(', ')
        constraints['labels'] = [
            {'logical_intelligence': intelligence, 'appearance': looks, 'power': strength}
            for intelligence, looks, strength in zip(intelligence_values, appearance_values, power_values)
        ]

        narrative_text, metadata = generate_narrative(
            constraints=constraints,
            implicit_instructions="Implicit portrayal constraint instructions",
            socio_instructions="Socio-demographic constraint instructions",
            llm_type=llm_type,
            model=model,
            use_cot=use_cot
        )

        result = {
            'index': idx,
            'narrative_column': narrative_column,
            'metadata_column': metadata_column,
            'narrative': narrative_text,
            'metadata': str(metadata)
        }
        new_narratives.append(result)

    return new_narratives

import pandas as pd
import numpy as np

def fill_missing_narratives_distributed(df: pd.DataFrame, K: int, llm_type: str, model: str,
                                        global_rank: int, world_size: int, local_rank: int) -> List[Dict[str, Any]]:
    """
    Generate the missing counterpart for rows that have only one of the simple/CoT narratives.

    A row is eligible when its 'model' equals `model` and exactly one of
    'narrative_cot' / 'narrative_simple' is NaN. Eligible rows are shuffled with
    a fixed seed (identical on every process), the first K are split into
    contiguous, non-overlapping slices per rank, and this process generates the
    missing narrative type for each row of its slice. The dataframe is not
    modified; results are returned for later combination.

    Note: this reseeds the global `random` module with 42.

    Args:
    df (pd.DataFrame): The dataframe containing the narratives
    K (int): Total number of narratives to process across all processes
    llm_type (str): Type of language model to use
    model (str): Specific model to use
    global_rank (int): Global rank of the current process
    world_size (int): Total number of processes
    local_rank (int): Local rank of the current process (not used by this function)

    Returns:
    List[Dict[str, Any]]: One dict per generated narrative with the keys
        'index', 'narrative_column', 'metadata_column', 'narrative' and 'metadata'
    """
    # Exactly one of the two narratives missing <=> XOR of the two "is missing" masks
    cot_missing = df['narrative_cot'].isna()
    simple_missing = df['narrative_simple'].isna()
    missing_narratives = df[(cot_missing ^ simple_missing) & (df['model'] == model)]

    # Cap the request at the number of eligible rows
    K = min(K, len(missing_narratives))
    if K == 0:
        print(f"No matching narratives found for model {model} with one missing narrative type.")
        return []

    # Even split: every rank gets `share` items, the first `leftover` ranks get one more
    share, leftover = divmod(K, world_size)
    local_K = share + (1 if global_rank < leftover else 0)

    # Identical seed on every process => identical shuffle => disjoint slices
    all_indices = list(missing_narratives.index)
    random.seed(42)
    random.shuffle(all_indices)
    start_idx = share * global_rank + min(global_rank, leftover)
    local_indices = all_indices[start_idx:start_idx + local_K]

    new_narratives = []
    for idx in tqdm(local_indices, desc=f"Process {global_rank} filling missing narratives for {model}"):
        row = df.loc[idx]

        # The missing narrative type decides both the prompt style and the target columns
        use_cot = pd.isna(row['narrative_cot'])
        if use_cot:
            narrative_column, metadata_column = 'narrative_cot', 'metadata_cot'
        else:
            narrative_column, metadata_column = 'narrative_simple', 'metadata_simple'

        # Rebuild the constraints dict from the flattened ', '-joined columns
        constraints = {
            'num_characters': row['num_characters'],
            'narrative_length': row['narrative_length'],
            'roles': row['roles'].split(', '),
            'genre': row['genre'],
            'title': row['title'],
        }
        intelligence_values = row['logical_intelligence'].split(', ')
        appearance_values = row['appearance'].split(', ')
        power_values = row['power'].split(', ')
        constraints['labels'] = [
            {'logical_intelligence': intelligence, 'appearance': looks, 'power': strength}
            for intelligence, looks, strength in zip(intelligence_values, appearance_values, power_values)
        ]

        narrative_text, metadata = generate_narrative(
            constraints=constraints,
            implicit_instructions="Implicit portrayal constraint instructions",
            socio_instructions="Socio-demographic constraint instructions",
            llm_type=llm_type,
            model=model,
            use_cot=use_cot
        )

        result = {
            'index': idx,
            'narrative_column': narrative_column,
            'metadata_column': metadata_column,
            'narrative': narrative_text,
            'metadata': str(metadata)
        }
        new_narratives.append(result)

    return new_narratives

def regenerate_tot_narratives_distributed(df: pd.DataFrame, K: int, llm_type: str, model: str,
                                          global_rank: int, world_size: int, local_rank: int,
                                          num_plans: int = 3, num_stories: int = 3,
                                          num_plan_votes: int = 5, num_story_votes: int = 5) -> List[Dict[str, Any]]:
    """
    Regenerate narratives with the ToT prompt for rows that already have simple and CoT results.

    A row is eligible when its 'model' equals `model` and both the
    'narrative_simple' and 'final_story_cot' columns are not NaN. Eligible rows
    are shuffled with a fixed seed (identical on every process), the first K are
    split into contiguous, non-overlapping slices per rank, and this process
    generates a ToT narrative for each row of its slice. The dataframe is not
    modified; results are returned for later combination.

    NOTE(review): the CoT presence check reads 'final_story_cot', whereas the
    sibling functions use 'narrative_cot' - confirm which column the input CSV
    actually carries.

    Note: this reseeds the global `random` module with 42.

    Args:
    df (pd.DataFrame): The dataframe containing the narratives
    K (int): Total number of narratives to regenerate across all processes
    llm_type (str): Type of language model to use
    model (str): Specific model to use
    global_rank (int): Global rank of the current process
    world_size (int): Total number of processes
    local_rank (int): Local rank of the current process (not used by this function)
    num_plans (int): Number of story plans to generate (for ToT)
    num_stories (int): Number of stories to generate from the best plan (for ToT)
    num_plan_votes (int): Number of voting rounds for plans (for ToT)
    num_story_votes (int): Number of voting rounds for stories (for ToT)

    Returns:
    List[Dict[str, Any]]: One dict per generated narrative with the keys
        'index', 'narrative_column' (always 'narrative_tot_v2'),
        'metadata_column' (always 'metadata_tot_v2'), 'narrative' and 'metadata'
    """
    # Rows that already have both a simple narrative and a CoT story for this model
    has_both = df['narrative_simple'].notna() & df['final_story_cot'].notna()
    eligible_narratives = df[has_both & (df['model'] == model)]

    # Adjust K if there are fewer eligible narratives than requested
    K = min(K, len(eligible_narratives))
    if K == 0:
        print(f"No matching narratives found for model {model} with both narrative types.")
        return []

    # Even split: every rank gets `share` items, the first `leftover` ranks get one more
    share, leftover = divmod(K, world_size)
    local_K = share + (1 if global_rank < leftover else 0)

    # Identical seed on every process => identical shuffle => disjoint slices
    all_indices = list(eligible_narratives.index)
    random.seed(42)
    random.shuffle(all_indices)
    start_idx = share * global_rank + min(global_rank, leftover)
    local_indices = all_indices[start_idx:start_idx + local_K]

    new_narratives = []
    for idx in tqdm(local_indices, desc=f"Process {global_rank} generating ToT narratives for {model}"):
        row = df.loc[idx]

        # Rebuild the constraints dict from the flattened ', '-joined columns
        constraints = {
            'num_characters': row['num_characters'],
            'narrative_length': row['narrative_length'],
            'roles': row['roles'].split(', '),
            'genre': row['genre'],
            'title': row['title'],
            'labels': [
                {
                    'logical_intelligence': intelligence,
                    'appearance': looks,
                    'power': strength
                }
                for intelligence, looks, strength in zip(
                    row['logical_intelligence'].split(', '),
                    row['appearance'].split(', '),
                    row['power'].split(', ')
                )
            ]
        }

        # Generate narrative using ToT prompt
        narrative_text, metadata = generate_narrative_tot(
            constraints, llm_type, model,
            num_plans, num_stories, num_plan_votes, num_story_votes
        )

        new_narratives.append({
            'index': idx,
            'narrative_column': 'narrative_tot_v2',
            'metadata_column': 'metadata_tot_v2',
            'narrative': narrative_text,
            'metadata': str(metadata)
        })

    return new_narratives

# Driver script: regenerate ToT narratives across all distributed processes and
# let rank 0 merge the results back into the CSV.

# Pre-bind the rank so the `finally` clause below can still run its checks when
# setup() itself raises before the tuple assignment completes.
global_rank = None

try:
    global_rank, world_size, node_name, local_rank = setup()

    SEED = 42  # Replace with your desired seed
    set_seed(SEED)
    set_transformers_seed(SEED)
    slurm_jobid = get_slurm_job_id()

    # Usage example
    quick_test(local_rank, world_size)
    print('Waiting', flush=True)
    dist.barrier()

    # REGEN TOT

    # Placeholder: must be set to the path of the dataset CSV before running.
    filename = ""
    df = pd.read_csv(filename)

    llm_type = 'openai'
    model = 'gpt-4o-mini-2024-07-18'
    K = 400  # number of narratives to regenerate using ToT

    # ToT hyperparameters
    num_plans = 1
    num_stories = 1
    num_plan_votes = 1
    num_story_votes = 1

    new_narratives = regenerate_tot_narratives_distributed(df, K, llm_type, model,
                                                           global_rank, world_size, local_rank,
                                                           num_plans, num_stories, num_plan_votes, num_story_votes)
    print(f"Process {global_rank} generated {len(new_narratives)} new ToT narratives.")

    # Serialize the generated stories
    serialized_data = serialize_data(new_narratives)

    # Gather the serialized stories on all processes
    gathered_tensors = all_gather_data(global_rank, local_rank, world_size, serialized_data)

    print(f'Finished gathering tensors. Process {global_rank}. Length = {len(gathered_tensors)}', flush=True)

    # Deserialize the gathered data
    all_narratives = deserialize_data(gathered_tensors)
    print(f'Process {global_rank}. Length of all_narratives = {len(all_narratives)}', flush=True)

    # Only rank 0 merges and writes, so the CSV is written by a single process.
    if global_rank == 0:
        # Add new columns if they don't exist. The columns are forced to object
        # dtype because text is assigned into them below; writing strings into
        # a float (all-NaN) column is deprecated upcasting in pandas.
        for column in ('narrative_tot_v2', 'metadata_tot_v2'):
            if column not in df.columns:
                df[column] = np.nan
            df[column] = df[column].astype(object)

        # Update the original dataframe with new ToT narratives
        for narrative in all_narratives:
            idx = narrative['index']
            df.at[idx, narrative['narrative_column']] = narrative['narrative']
            df.at[idx, narrative['metadata_column']] = narrative['metadata']

        # Save the updated dataframe
        df.to_csv(filename, index=False)
        print(f"Updated dataset with ToT_v2 narratives written to {filename}")

except Exception as e:
    # Handle any exceptions that occur: report them, then fall through to cleanup.
    print('Process. An error occurred:', flush=True)
    print(str(e), flush=True)
    print('Traceback:', flush=True)
    traceback.print_exc()

finally:
    # NOTE(review): cleanup() runs on rank 0 only - confirm the other ranks do
    # not need it as well.
    if global_rank == 0:
        print(f'Process {global_rank}. Cleaning up', flush=True)
        cleanup()
    # Close the file and restore stdout and stderr
    # (presumably setup() redirected them to a log file - TODO confirm)
    sys.stdout.close()
    sys.stdout = sys.__stdout__
    sys.stderr = sys.__stderr__







