import os
import boto3
import json
import utils
import csv
from tqdm import tqdm
from botocore.config import Config

def setup_bedrock_client(*, region_name="us-east-1", profile_name='default', max_attempts=5):
    """Create a Bedrock runtime client configured with adaptive retries.

    Args:
        region_name: AWS region the client targets.
        profile_name: Named profile from the local AWS configuration used to
            build the boto3 session.
        max_attempts: Retry ceiling passed to botocore's retry configuration.

    Returns:
        A boto3 ``bedrock-runtime`` client.
    """
    # Only retries are configured here; connect/read timeouts are left at
    # botocore's defaults because no timeout options are passed to Config.
    config = Config(
        region_name=region_name,
        retries={"max_attempts": max_attempts, "mode": "adaptive"},
    )
    # Create a session from the named profile and build the runtime client
    session = boto3.Session(profile_name=profile_name)
    return session.client("bedrock-runtime", config=config)

def _parse_scores_response(response):
    """Extract the list of emotional scores from a raw model response string.

    The prompt's wording asks for a bare JSON array, but its example shows an
    object with an ``emotional_scores`` key, so both shapes are accepted. The
    first JSON value starting at the first ``{`` is decoded; if that yields
    nothing usable, the first ``[`` is tried. Text after the decoded value is
    ignored.

    Args:
        response: Raw text returned by the model.

    Returns:
        The list of scores if one is found and every element is a number,
        otherwise None.
    """
    decoder = json.JSONDecoder()
    for opener in ('{', '['):
        start_idx = response.find(opener)
        if start_idx == -1:
            continue
        try:
            # raw_decode stops at the end of the first JSON value, so trailing
            # prose (even prose containing braces) does not break parsing.
            parsed, _ = decoder.raw_decode(response, start_idx)
        except json.JSONDecodeError as e:
            print(f"Error parsing JSON response: {e}")
            continue
        if isinstance(parsed, dict):
            parsed = parsed.get('emotional_scores')
        # Reject non-list payloads and non-numeric entries so callers get None
        # instead of a value that breaks CSV row construction later.
        # bool is excluded because it is a subclass of int.
        if isinstance(parsed, list) and all(
            isinstance(score, (int, float)) and not isinstance(score, bool)
            for score in parsed
        ):
            return parsed
    return None


def analyze_emotional_scores(transcript_text, brt_client):
    """Analyzes transcript using Bedrock to get emotional intensity scores.

    Args:
        transcript_text: Full transcript text, substituted into the prompt.
        brt_client: Bedrock runtime client. NOTE(review): currently unused;
            the request is sent through ``utils.call_bedrock``, which
            presumably manages its own client. Confirm before relying on it.

    Returns:
        A list of numeric scores parsed from the model response, or None if the
        call raises, returns nothing, or the response cannot be parsed. The
        list length is not checked here (the caller checks for 6).
    """
    prompt = """You are analyzing a therapy session that consists of six distinct conversations. Each conversation represents a specific stage of Prolonged Exposure (PE) therapy. For each conversation, rate the client's emotional intensity on a scale from 1 (calm) to 10 (extreme distress). Consider expressions of anxiety, fear, sadness, or relief. Provide only the six numerical scores in the order presented below.

    Here's the transcript:
    {transcript}

    Conversation Breakdown:

    Conversation 1 (Orientation to Imaginal Exposure):
    The therapist prepares the client for imaginal exposure, discussing expectations and addressing initial anxiety.

    # For brevity, Conversations 2–5 are omitted. Full content will be shared upon paper acceptance.


    Conversation 6 (Processing the Imaginal):
    The therapist and client process the experience together, discussing insights and emotional reactions.

    Respond only with a JSON array containing exactly 6 numbers between 1 and 10, like this:
    {{
        "emotional_scores": [3, 8, 7, 6, 5, 2]
    }}
    """

    full_prompt = prompt.format(transcript=transcript_text)
    # Keep the try narrow so parsing problems are not misreported as call failures.
    try:
        response = utils.call_bedrock(full_prompt)
    except Exception as e:
        print(f"Error calling Bedrock: {e}")
        return None

    if not response:
        return None
    if not isinstance(response, str):
        # NOTE(review): utils.call_bedrock is assumed to return text. Confirm.
        print(f"Unexpected Bedrock response type: {type(response).__name__}")
        return None
    return _parse_scores_response(response)

def process_json_file(file_path):
    """Load a combined-conversation JSON file and flatten it into one string.

    Args:
        file_path: Path of the JSON file to read (decoded as UTF-8).

    Returns:
        The entries of the file's "full_conversation" value joined with
        newlines, or None when that key is absent. Any error while opening,
        parsing or joining is printed and also yields None.
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as handle:
            payload = json.load(handle)
        if "full_conversation" not in payload:
            return None
        # One line per conversation entry
        return "\n".join(payload["full_conversation"])
    except Exception as err:
        print(f"Error reading file {file_path}: {err}")
        return None

def get_processed_files(csv_path):
    """Get the set of already processed filenames from an existing results CSV.

    The first row is treated as a header and skipped. The first column of every
    remaining non-empty row is taken as a filename.

    Args:
        csv_path: Path to the results CSV. A missing file is not an error.

    Returns:
        A set of filename strings. It is empty if the file is missing, empty,
        or contains only a header.
    """
    processed = set()
    if not os.path.exists(csv_path):
        return processed
    with open(csv_path, 'r', newline='', encoding='utf-8') as csvfile:
        reader = csv.reader(csvfile)
        # Skip the header. The None default matters for an empty file (e.g. one
        # truncated before any row was flushed): a bare next() would raise
        # StopIteration there.
        next(reader, None)
        for row in reader:
            if row:  # csv.reader yields [] for blank lines
                processed.add(row[0])  # First column is filename
    return processed

def main():
    """Score every combined-conversation transcript and record the results.

    Reads ``*_combined_conversation.json`` files from ``Synthetic_Combined/``,
    ordered by their leading integer prefix. Each transcript is scored with
    ``analyze_emotional_scores``, and one row per successfully scored file is
    written to ``Emotional_Scores/emotional_scores.csv`` and to a backup CSV.
    Filenames already present in the main CSV are skipped after the user
    confirms, so an interrupted run can be resumed.
    """
    # Initialize Bedrock client
    brt_client = setup_bedrock_client()

    # Directory paths
    input_dir = 'Synthetic_Combined/'
    output_dir = 'Emotional_Scores/'
    output_file = os.path.join(output_dir, 'emotional_scores.csv')
    backup_file = os.path.join(output_dir, 'emotional_scores_backup.csv')

    # Create output directory if it doesn't exist
    os.makedirs(output_dir, exist_ok=True)

    # CSV header
    headers = ['filename', 'orientation_score',  'processing_score'] # For brevity, headers are omitted. Full content will be shared upon paper acceptance.

    try:
        # Collect the combined-conversation files, ordered by numeric prefix
        json_files = [f for f in os.listdir(input_dir) if f.endswith('_combined_conversation.json')]
        json_files.sort(key=lambda x: int(x.split('_')[0]))

        if not json_files:
            print("No combined conversation JSON files found in the input directory.")
            return

        # Check for existing progress
        processed_files = get_processed_files(output_file)
        if processed_files:
            print(f"Found {len(processed_files)} previously processed files")
            proceed = input("Continue processing remaining files? (y/n): ")
            # Tolerate stray whitespace such as "y " in the answer
            if proceed.strip().lower() != 'y':
                print("Exiting...")
                return

        # Filter out already processed files
        remaining_files = [f for f in json_files if f not in processed_files]
        if not remaining_files:
            print("All files have been processed already!")
            return

        print(f"Processing {len(remaining_files)} remaining files...")

        # Append when resuming; otherwise start both files fresh
        file_mode = 'a' if processed_files else 'w'
        with open(output_file, file_mode, newline='', encoding='utf-8') as csvfile, \
             open(backup_file, file_mode, newline='', encoding='utf-8') as backup_csvfile:

            writer = csv.writer(csvfile)
            backup_writer = csv.writer(backup_csvfile)

            # Write headers only if starting fresh
            if not processed_files:
                writer.writerow(headers)
                backup_writer.writerow(headers)

            # Track progress and success/failure counts
            processed = 0
            failed = 0
            batch = []

            def save_batch(rows):
                """Write rows to both CSVs and flush so progress survives a crash."""
                for batch_row in rows:
                    writer.writerow(batch_row)
                    backup_writer.writerow(batch_row)
                csvfile.flush()
                backup_csvfile.flush()
                print(f"\nSaved batch of {len(rows)} files. Total processed: {processed}")

            try:
                for file_name in tqdm(remaining_files, desc="Processing transcripts"):
                    try:
                        # Read the transcript
                        file_path = os.path.join(input_dir, file_name)
                        transcript_text = process_json_file(file_path)

                        if transcript_text is None:
                            print(f"Skipping {file_name} due to file reading error.")
                            failed += 1
                            continue

                        # Generate analysis
                        scores = analyze_emotional_scores(transcript_text, brt_client)

                        if scores is None or len(scores) != 6:
                            print(f"Skipping {file_name} due to invalid scores.")
                            failed += 1
                            continue

                        # Add to batch
                        batch.append([file_name] + scores)
                        processed += 1

                        # Write a batch every 10 successfully scored files
                        if len(batch) >= 10:
                            save_batch(batch)
                            batch = []  # Clear batch after saving

                    except Exception as e:
                        print(f"Error processing {file_name}: {e}")
                        failed += 1
            finally:
                # Save leftover rows whether the final file succeeded, failed,
                # or the loop was interrupted, so no scored rows are lost.
                if batch:
                    save_batch(batch)

            total_processed = len(processed_files) + processed
            print(f"\nProcessing complete:")
            print(f"Previously processed: {len(processed_files)} files")
            print(f"Newly processed: {processed} files")
            print(f"Total processed: {total_processed} files")
            print(f"Failed to process: {failed} files")
            print(f"Results saved to: {output_file}")
            print(f"Backup saved to: {backup_file}")

    except Exception as e:
        print(f"Fatal error: {e}")

# Script entry point: run the scoring pipeline only when executed directly, not on import.
if __name__ == "__main__":
    main()