import os
import boto3
import pandas as pd
import json
import random
from botocore.exceptions import ClientError
import utils
from tqdm import tqdm
import datetime
from botocore.config import Config

# Sampling pools and weights used by the profile generators in this file.

# Client age brackets as inclusive (low, high) pairs.
age_groups = [(18, 30), (31, 40), (41, 50), (51, 60), (61, 70), (71, 80)]
genders = ['male', 'female', 'non-binary']
# Therapist age brackets as inclusive (low, high) pairs; adjacent brackets
# share their boundary age (e.g. 35 belongs to both of the first two).
therapist_ages = [(25, 35), (35, 45), (45, 55), (55, 65)]
therapist_genders = ['male', 'female', 'non-binary']
trauma_types = ['physical abuse', 'combat or war experiences', 'loss of a loved one'] # Only 3 representative examples are included for review; the full code will be released upon acceptance.
behaviors = ['avoidance', 'hypervigilance', 'substance abuse'] # Only 3 representative examples are included for review; the full code will be released upon acceptance.
session_formats = ['individual']
session_topics = ['car accident trauma', 'domestic violence', 'military combat experience'] # Only 3 representative examples are included for review; the full code will be released upon acceptance.
ethnicities = ['South Asian']
co_occurring_conditions = ['Anxiety']

# Relative weights, paired positionally with `genders` / `therapist_genders`
# through zip() before being handed to weighted_choice().
gender_weights = [0.5, 0.49, 0.01]
# Relative weights, paired positionally with `co_occurring_conditions` through
# zip(). zip() stops at the shorter list, so with the single condition listed
# above only the first weight is ever used.
# NOTE(review): five weights vs. one condition -- presumably the full condition
# list has five entries; confirm the two lists stay the same length.
co_occurring_weights = [0.25, 0.25, 0.30, 0.1, 0.1]

def weighted_choice(choices):
    """Pick one item at random with probability proportional to its weight.

    Args:
        choices: Iterable of ``(item, weight)`` pairs. Weights are relative
            and need not sum to 1. One-shot iterables such as ``zip(...)``
            are accepted.

    Returns:
        The ``item`` of the selected pair.

    Raises:
        ValueError: If ``choices`` is empty, any weight is negative, or the
            weights do not sum to a positive value.
    """
    # The pairs are traversed twice (once to total the weights, once to scan),
    # so materialize them first; otherwise a generator or zip object would be
    # exhausted by the first pass and nothing could ever be selected.
    pairs = list(choices)
    if not pairs:
        raise ValueError("weighted_choice() requires at least one (item, weight) pair")
    if any(weight < 0 for _, weight in pairs):
        raise ValueError("weighted_choice() weights must be non-negative")

    total = sum(weight for _, weight in pairs)
    if total <= 0:
        raise ValueError("weighted_choice() weights must sum to a positive value")

    # Exactly one draw from the RNG, so seeded runs consume the random stream
    # the same way as before.
    threshold = random.uniform(0, total)
    cumulative = 0
    for item, weight in pairs:
        cumulative += weight
        # Skip zero-weight entries so they cannot win when the draw is exactly 0.
        if weight > 0 and cumulative >= threshold:
            return item

    # `cumulative` ends equal to `total` and `threshold` never exceeds it, so
    # this is a true invariant violation. Raise explicitly rather than assert,
    # because asserts are removed when Python runs with -O.
    raise RuntimeError("weighted_choice() failed to select an item")

def validate_client_profile(profile):
    """Return True unless the profile pairs an age with an implausible attribute.

    A profile is rejected when the client is under 30 and widowed, under 60
    and retired, or under 70 and in assisted living.

    Args:
        profile: Dict with an "age" key plus "relationship_status",
            "occupation" and "living_situation".

    Returns:
        False if any rejection rule matches, True otherwise.
    """
    # (exclusive age ceiling, profile field, value that is implausible below it)
    implausible_below = (
        (30, "relationship_status", "Widowed"),
        (60, "occupation", "Retired"),
        (70, "living_situation", "Assisted living"),
    )
    return not any(
        profile["age"] < ceiling and profile[field] == value
        for ceiling, field, value in implausible_below
    )

def generate_valid_client_profile():
    """Draw client profiles until one passes validate_client_profile().

    Each attempt picks a fresh age uniformly from 18 to 80 (inclusive) and
    builds a full profile for it.

    Returns:
        The first generated profile dict that validates.
    """
    candidate = generate_client_profile(random.randint(18, 80))
    while not validate_client_profile(candidate):
        # Rejected: redraw both the age and every other attribute.
        candidate = generate_client_profile(random.randint(18, 80))
    return candidate

def generate_client_profile(client_age):
    """Build a random client profile for a client of the given age.

    Relationship status, occupation and living situation are drawn from
    weighted pools chosen by the client's age bracket. The living-situation
    pool is then replaced by a status-specific one for every relationship
    status except "In a relationship".

    Args:
        client_age: Client age in years.

    Returns:
        Dict with keys "age", "age_group", "gender", "relationship_status",
        "occupation", "living_situation", "ethnicity",
        "co_occurring_condition" and "exhibited_behaviors" (a list of 1-3
        entries sampled from `behaviors`).
    """
    # First bracket containing the age; ages outside every bracket fall back
    # to the last (oldest) bracket.
    for bracket in age_groups:
        if bracket[0] <= client_age <= bracket[1]:
            client_age_group = bracket
            break
    else:
        client_age_group = age_groups[-1]

    gender = weighted_choice(list(zip(genders, gender_weights)))

    # Weighted pools keyed on the lower bound of the age bracket.
    bracket_start = client_age_group[0]
    if bracket_start > 60:  # 61+
        relationship_statuses = [("Married", 40), ("Widowed", 30), ("Divorced", 20), ("Single", 10)]
        occupations = [("Retired", 80), ("Part-time work", 15), ("Full-time work", 5)]
        living_situations = [("Alone", 40), ("With partner", 30), ("With family", 20), ("Assisted living", 10)]
    elif bracket_start > 40:
        relationship_statuses = [("Married", 50), ("Divorced", 20), ("Single", 15), ("Widowed", 15)]
        occupations = [("Professional", 70), ("Self-employed", 20), ("Unemployed", 5), ("Retired", 5)]
        living_situations = [("With family", 50), ("Alone", 30), ("With partner", 20)]
    elif bracket_start > 25:
        relationship_statuses = [("Single", 30), ("In a relationship", 30), ("Married", 35), ("Divorced", 5)]
        occupations = [("Professional", 60), ("Student", 10), ("Unemployed", 5), ("Self-employed", 25)]
        living_situations = [("Alone", 30), ("With partner", 40), ("Shared accommodation", 20), ("With family", 10)]
    else:
        relationship_statuses = [("Single", 65), ("In a relationship", 30), ("Married", 5)]
        occupations = [("Student", 60), ("Entry-level job", 30), ("Unemployed", 10)]
        living_situations = [("With parents", 40), ("Shared accommodation", 40), ("Alone", 20)]

    relationship_status = weighted_choice(relationship_statuses)

    # Status-specific living pools take precedence over the bracket-based one;
    # statuses not listed here keep the bracket-based pool.
    living_by_status = {
        "Married": [("With partner", 60), ("With family", 30), ("Alone", 10)],
        "Widowed": [("Alone", 50), ("With family", 40), ("Assisted living", 10)],
        "Single": [("Alone", 40), ("Shared accommodation", 40), ("With parents", 20)],
        "Divorced": [("Alone", 50), ("With family", 30), ("Shared accommodation", 20)],
    }
    living_situations = living_by_status.get(relationship_status, living_situations)

    occupation = weighted_choice(occupations)
    living_situation = weighted_choice(living_situations)

    # Between one and three distinct behaviors.
    behavior_count = random.randint(1, 3)
    exhibited_behaviors = random.sample(behaviors, behavior_count)

    ethnicity = random.choice(ethnicities)
    co_condition = weighted_choice(list(zip(co_occurring_conditions, co_occurring_weights)))

    low, high = client_age_group[0], client_age_group[1]
    return {
        "age": client_age,
        "age_group": f"{low}-{high}",
        "gender": gender,
        "relationship_status": relationship_status,
        "occupation": occupation,
        "living_situation": living_situation,
        "ethnicity": ethnicity,
        "co_occurring_condition": co_condition,
        "exhibited_behaviors": exhibited_behaviors
    }

def generate_therapist_profile():
    """Draw a random therapist profile.

    The age is picked uniformly inside a randomly chosen `therapist_ages`
    bracket (bounds inclusive); the gender is drawn from `therapist_genders`
    weighted by `gender_weights`.

    Returns:
        Dict with keys "age" and "gender".
    """
    youngest, oldest = random.choice(therapist_ages)
    age = random.randint(youngest, oldest)
    gender = weighted_choice(list(zip(therapist_genders, gender_weights)))
    return {"age": age, "gender": gender}

def generate_trauma_context():
    """Pick a trauma type and a session topic independently at random.

    Returns:
        A ``(trauma_type, session_topic)`` tuple drawn uniformly from
        `trauma_types` and `session_topics` respectively.
    """
    # Tuple elements are evaluated left to right: trauma type first, then topic.
    return random.choice(trauma_types), random.choice(session_topics)

def validate_therapist_client_match(client_profile, therapist_profile):
    """Check whether a therapist is an acceptable match for a client.

    The only rule currently enforced is the age gap: the therapist must be at
    least 5 years older than the client.

    Args:
        client_profile: Dict with an "age" key.
        therapist_profile: Dict with an "age" key.

    Returns:
        True if the pairing is acceptable, False otherwise.
    """
    age_gap = therapist_profile["age"] - client_profile["age"]
    # Further rules (e.g. matching trauma type to a therapist specialization)
    # could be combined with this check.
    return not (age_gap < 5)

def generate_therapy_session(max_client_attempts=100, max_therapist_attempts=100):
    """Generate a complete therapy session profile.

    A client profile is drawn first, then therapist profiles are drawn until
    one passes validate_therapist_client_match(). The search is bounded: with
    the tables defined in this file therapists are at most 65 and must be at
    least 5 years older than the client, so a client older than 60 can never
    be matched and an unbounded retry loop would never terminate. When no
    therapist matches within ``max_therapist_attempts`` draws, a new client is
    drawn instead.

    Args:
        max_client_attempts: Maximum number of client profiles to try.
        max_therapist_attempts: Maximum number of therapist draws per client.

    Returns:
        Dict with keys "client", "therapist", "therapy_context" and
        "timestamp" (ISO-8601 string of the local generation time).

    Raises:
        RuntimeError: If no valid client/therapist pairing is found within
            the attempt limits.
    """
    for _ in range(max_client_attempts):
        client_profile = generate_valid_client_profile()

        for _ in range(max_therapist_attempts):
            therapist_profile = generate_therapist_profile()
            if validate_therapist_client_match(client_profile, therapist_profile):
                return {
                    "client": client_profile,
                    "therapist": therapist_profile,
                    "therapy_context": generate_therapy_context(),
                    "timestamp": datetime.datetime.now().isoformat()
                }
        # No therapist fit this client; fall through and draw another client.

    raise RuntimeError(
        "Could not find a valid therapist-client match after "
        f"{max_client_attempts} client profiles"
    )

def print_therapy_session(session):
    """Print a human-readable summary of a therapy session to stdout.

    Args:
        session: Dict with "client", "therapist" and "therapy_context"
            sub-dicts, shaped like the value returned by
            generate_therapy_session().
    """
    print("\n=== Therapy Session Profile ===")

    # Client section.
    print("\nClient Information:")
    client = session['client']
    print(f"Age: {client['age']} (Age Group: {client['age_group']})")
    print(f"Gender: {client['gender']}")
    print(f"Ethnicity: {client['ethnicity']}")
    print(f"Relationship Status: {client['relationship_status']}")
    print(f"Occupation: {client['occupation']}")
    print(f"Living Situation: {client['living_situation']}")
    print(f"Co-occurring Condition: {client['co_occurring_condition']}")
    print(f"Exhibited Behaviors: {', '.join(client['exhibited_behaviors'])}")

    # Therapist section.
    print("\nTherapist Information:")
    therapist = session['therapist']
    print(f"Age: {therapist['age']}")
    print(f"Gender: {therapist['gender']}")

    # Therapy context section.
    print("\nTherapy Context:")
    context = session['therapy_context']
    print(f"Session Format: {context['session_format']}")
    print(f"Trauma Type: {context['trauma_type']}")
    print(f"Session Topic: {context['session_topic']}")

def generate_therapy_context():
    """Pick the trauma type, session topic and session format at random.

    Returns:
        Dict with keys "trauma_type", "session_topic" and "session_format",
        each drawn uniformly from the matching module-level list.
    """
    # Dict values are evaluated in source order, so the three draws happen as
    # trauma type, then topic, then format.
    return {
        "trauma_type": random.choice(trauma_types),
        "session_topic": random.choice(session_topics),
        "session_format": random.choice(session_formats),
    }


# Prompt templates keyed by prompt ID. Each template carries a
# `{profile_info}` placeholder that generate_transcript() fills in with
# str.format() before the prompt is sent to the model.
# Only 1 representative prompt is included for review; the full set of prompts can be viewed in the paper and the full code will be released upon acceptance.
PROMPTS = {
    "P5": """You are an expert in medicine and NLP. Generate a clinical transcript for the following profiles:
    <profiles>
    {profile_info}
    </profiles>
    
    Based on these expectations:
    Generate a detailed creative dialogue where a therapist orients the client to the imaginal exposure planned for a Prolonged Exposure (PE) therapy session.
    Key Features:

    The therapist explains the purpose and rationale behind imaginal exposure in a clear and empathetic manner. 
    The therapist addresses the client’s concerns (if any), hesitations (if any), or questions about the exercise. 
    Include the therapist setting expectations for the session, including what the client might feel and how they will be supported throughout. 
    The dialogue should include the client's responses, such as questions, emotional reactions, or expressions of understanding. 
    The therapist reinforces the importance of the exercise in addressing PTSD symptoms and validates the client’s courage in participating. 
    Ensure the conversation flows naturally, with pauses, realistic emotional exchanges, and detailed descriptions of both the therapist's and client's perspectives.
    Avoid repetitive patterns like using the same emotions or phrases across responses.
    Ensure the therapist's responses are concise, and very short. The Client speaks elaborately.

    Use "Therapist:" for the therapist's lines and "Client:" for the client's responses.

    BEGIN TRANSCRIPT: Therapist:
    """
}

#def setup_bedrock_client():
#    session = boto3.Session(profile_name='default')
#    return session.client("bedrock-runtime", region_name="us-east-1")

def setup_bedrock_client():
    """Create a Bedrock runtime client for us-east-1 with adaptive retries.

    Only the region and the retry policy are configured here; no explicit
    timeouts are set.

    Returns:
        A "bedrock-runtime" client created from a boto3 session that uses the
        'default' profile.
    """
    retry_policy = {"max_attempts": 5, "mode": "adaptive"}
    client_config = Config(region_name="us-east-1", retries=retry_policy)
    # The session is built from the 'default' named profile.
    return boto3.Session(profile_name='default').client(
        "bedrock-runtime", config=client_config
    )


def generate_transcript(profile_info, prompt_template, brt_client):
    """Render a prompt template and request a transcript through utils.call_bedrock().

    Args:
        profile_info: Text substituted for the template's ``{profile_info}``
            placeholder.
        prompt_template: Template string formatted with str.format().
        brt_client: Bedrock runtime client. It is accepted but not used in
            this function; the call goes through utils.call_bedrock() with
            the rendered prompt only.
            NOTE(review): confirm whether call_bedrock should receive it.

    Returns:
        Whatever utils.call_bedrock() returns, or None if that call raised.
    """
    rendered_prompt = prompt_template.format(profile_info=profile_info)
    try:
        return utils.call_bedrock(rendered_prompt)
    except Exception as err:
        # Best effort: report the failure and let the caller skip this prompt.
        print(f"Error calling Bedrock: {err}")
    return None

def main():
    """Generate transcripts for scenarios 3-5 and store them as JSON files.

    For every scenario a client profile, a therapist profile and a trauma
    context are drawn, and each prompt in PROMPTS is rendered with them and
    sent to Bedrock. Each successful response yields two files under the
    output directory:

    * ``conversations/{i}_{prompt_name}_conversation.json`` -- the lines that
      start with "Therapist:" or "Client:" plus every window of three
      consecutive such lines;
    * ``metadata/{i}_{prompt_name}_metadata.json`` -- the profiles and trauma
      information used for the prompt.

    Failures for a single prompt are printed and do not stop the run.
    """
    brt_client = setup_bedrock_client()  # Bedrock runtime client
    output_dir = '/GenerateTranscript/New'  # Replace with your desired path
    for subdir in ('conversations', 'metadata'):
        os.makedirs(os.path.join(output_dir, subdir), exist_ok=True)

    prompt_table = PROMPTS

    for i in tqdm(range(3, 6), desc="Generating scenarios"):
        client_profile = generate_valid_client_profile()
        therapist_profile = generate_therapist_profile()
        trauma_type, session_topic = generate_trauma_context()

        # Re-draw the client's gender with the weights hard-coded here.
        # NOTE(review): the profile already has a gender drawn with
        # `gender_weights`; this overwrites it using different weights --
        # confirm that is intended.
        client_profile['gender'] = weighted_choice([('male', 0.5), ('female', 0.45), ('non-binary', 0.05)])

        profile_info = f"""Client Info: Age: {client_profile['age']}, Gender: {client_profile['gender']}, 
        Relationship Status: {client_profile['relationship_status']}, 
        Occupation: {client_profile['occupation']}, 
        Living Situation: {client_profile['living_situation']},
        Ethnicity: {client_profile['ethnicity']},
        Co-occurring Condition: {client_profile['co_occurring_condition']}
        Trauma Info: Trauma Type: {trauma_type}, Session Topic: {session_topic}
        Therapist Info: Age: {therapist_profile['age']}, Gender: {therapist_profile['gender']}"""

        for prompt_name, prompt in prompt_table.items():
            try:
                conversation = generate_transcript(profile_info, prompt, brt_client)
                if conversation is None:
                    # generate_transcript() already reported the error.
                    print(f"Skipping scenario {i}, prompt {prompt_name} due to Bedrock error.")
                    continue

                # Keep only speaker-prefixed lines; anything else (blank
                # lines, unprefixed continuation text) is dropped.
                stripped_lines = (raw.strip() for raw in conversation.split('\n'))
                turns = [
                    text for text in stripped_lines
                    if text.startswith(('Therapist:', 'Client:'))
                ]

                # Sliding windows of three consecutive turns; none when there
                # are fewer than three turns.
                window_count = max(0, len(turns) - 2)
                three_turn_sequences = [
                    turns[start:start + 3] for start in range(window_count)
                ]

                outputs = (
                    (
                        'conversations',
                        f'{i}_{prompt_name}_conversation.json',
                        {
                            "full_conversation": turns,
                            "three_turn_sequences": three_turn_sequences
                        },
                    ),
                    (
                        'metadata',
                        f'{i}_{prompt_name}_metadata.json',
                        {
                            "client_profile": client_profile,
                            "therapist_profile": therapist_profile,
                            "trauma_info": {
                                "type": trauma_type,
                                "session_topic": session_topic
                            }
                        },
                    ),
                )
                # Conversation file first, then metadata.
                for subdir, filename, payload in outputs:
                    with open(os.path.join(output_dir, subdir, filename), 'w') as f:
                        json.dump(payload, f, indent=2)

            except Exception as e:
                print(f"Error processing scenario {i}, prompt {prompt_name}: {e}")

# Run the generation pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
