import pathlib
import google.generativeai as genai
import json
import tiktoken
import textwrap
import openai
import os
from PIL import Image
import base64
import time
import random
import logging
from dotenv import load_dotenv
from pydantic import BaseModel, Field
from typing import List, Dict, Optional, Any, Union, Callable


# Logging setup: records at INFO and above go through a stream handler,
# prefixed with timestamp, logger name and level.
logging.basicConfig(
    handlers=[logging.StreamHandler()],
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger("cluster_assignment")

# Credentials: call load_dotenv() first, then read the key from the
# environment. The script refuses to start without a key.
load_dotenv()
openai.api_key = os.environ.get("OPENAI_API_KEY")
if not openai.api_key:
    raise ValueError("No OpenAI API key found. Please set the OPENAI_API_KEY.")


# Get all trajectory data with filename
def get_trajectory_data(file_path):
    """Load trajectory records from a JSON Lines file.

    Every non-blank line is parsed as JSON. Only JSON objects that contain
    both a ``trajectory`` and a ``filename`` key are kept; any other value
    is skipped.

    Args:
        file_path: Path of the JSONL file to read.

    Returns:
        A list of dicts with the keys ``filename``, ``trajectory`` and
        ``failure_reason`` (an empty string when the record has none), in
        file order.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    trajectory_data = []
    # Read as UTF-8 explicitly instead of relying on the platform default.
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                # Blank lines (e.g. a trailing newline) hold no record and
                # would otherwise make json.loads raise.
                continue
            data = json.loads(line)
            if isinstance(data, dict) and 'trajectory' in data and 'filename' in data:
                trajectory_data.append({
                    "filename": data['filename'],
                    "trajectory": data['trajectory'],
                    "failure_reason": data.get('failure_reason', '')
                })
    return trajectory_data


# Retry decorator with exponential backoff
def retry_with_exponential_backoff(
    initial_delay: float = 1,
    exponential_base: float = 2,
    jitter: bool = True,
    max_retries: int = 10,
    errors: tuple = (openai.error.Timeout, openai.error.APIError, openai.error.APIConnectionError,
                    openai.error.RateLimitError, openai.error.ServiceUnavailableError)
):
    """Build a decorator that retries a function with exponential backoff.

    The decorated function is called until it returns. When it raises one of
    ``errors`` the call is repeated after a pause; any other exception
    propagates immediately.

    Args:
        initial_delay: Base pause in seconds before the first retry.
        exponential_base: Factor applied to the base pause after every retry.
        jitter: When true, each pause is scaled by a random factor in
            [0.5, 1.5).
        max_retries: Number of retries allowed after the first attempt.
        errors: Exception types that trigger a retry.

    Returns:
        A decorator that wraps a callable in the retry loop.

    Raises:
        Exception: Raised by the wrapped call once more than ``max_retries``
            retryable failures have occurred; the last error is attached as
            ``__cause__``.
    """
    # Imported here so the block is self-contained with respect to the
    # file's import section.
    import functools

    def decorator(func):
        # Keep the wrapped function's name/docstring visible on the wrapper.
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            num_retries = 0
            delay = initial_delay

            # Loop until the call succeeds or the retry budget is exhausted.
            while True:
                try:
                    return func(*args, **kwargs)
                except errors as e:
                    num_retries += 1

                    if num_retries > max_retries:
                        # Chain the original error so its traceback is kept.
                        raise Exception(
                            f"Maximum number of retries ({max_retries}) exceeded: {str(e)}"
                        ) from e

                    # Spread retries out so parallel clients do not retry in
                    # lockstep: scale the pause by a factor in [0.5, 1.5).
                    delay_with_jitter = delay
                    if jitter:
                        delay_with_jitter = delay * (random.random() + 0.5)

                    logger.warning("Request failed with error: %s", e)
                    logger.warning(
                        "Retrying in %.2f seconds... (Attempt %s/%s)",
                        delay_with_jitter, num_retries, max_retries,
                    )

                    time.sleep(delay_with_jitter)

                    # The un-jittered base grows geometrically.
                    delay *= exponential_base

        return wrapper

    return decorator


# Wrapper for OpenAI API calls with retry logic
@retry_with_exponential_backoff(initial_delay=4, max_retries=8)
def call_openai_api(messages, model="o4-mini", functions=None, function_call=None, timeout=900):
    """Send one chat-completion request and return the raw response.

    The call is wrapped by ``retry_with_exponential_backoff`` (4 s initial
    delay, up to 8 retries).

    Args:
        messages: Chat messages to send.
        model: Name of the model to query.
        functions: Optional function definitions; forwarded only when truthy.
        function_call: Optional function-call directive; forwarded only when
            ``functions`` is also given.
        timeout: Value written to ``openai.api_requestor.TIMEOUT_SECS``.

    Returns:
        Whatever ``openai.ChatCompletion.create`` returns.
    """
    # This assigns a module attribute, so it affects every later request
    # made through the library in this process, not just this call.
    openai.api_requestor.TIMEOUT_SECS = timeout

    logger.info(f"Making API request to OpenAI model: {model}")

    request_params = {"model": model, "messages": messages}

    # Function calling is opt-in; a directive without definitions is ignored.
    if functions:
        function_params = {"functions": functions}
        if function_call:
            function_params["function_call"] = function_call
        request_params.update(function_params)

    return openai.ChatCompletion.create(**request_params)


# Define Pydantic models for structured output
# A single cluster label chosen for a trajectory.
# NOTE: documented with `#` comments on purpose. Pydantic copies a model's
# class docstring into the generated JSON schema as "description", and this
# model is nested in the schema built from TrajectoryClusterAssignment.
class ClusterAssignment(BaseModel):
    # Required: name of the cluster the trajectory is assigned to.
    cluster_name: str = Field(..., description="Name of the cluster this trajectory belongs to")
    # Optional. The 0-1 range is stated in the description only; no bound is declared on the field.
    confidence: Optional[float] = Field(None, description="Confidence score for this assignment (0-1)")
    # Optional free-text justification for the assignment.
    reasoning: Optional[str] = Field(None, description="Brief explanation for why this assignment was made")


# Structured result for one trajectory: its identifier, its description and
# the cluster labels assigned to it. All three fields are required.
# The JSON schema of this model is used as the "parameters" of `fn_def`, and
# API replies are validated against it.
# NOTE: documented with `#` comments on purpose. A class docstring would be
# copied into that JSON schema as "description" and change the request.
class TrajectoryClusterAssignment(BaseModel):
    # Identifier of the trajectory; the processing loop overwrites it with the input record's filename.
    filename: str = Field(..., description="The filename/ID of the trajectory")
    # Trajectory description as returned in the function-call arguments.
    trajectory: str = Field(..., description="The trajectory description")
    # One entry per cluster the trajectory was assigned to.
    assignments: List[ClusterAssignment] = Field(..., description="The cluster assignments for this trajectory")


# Define models for cluster data structure
# One cluster definition. Instances are built from the JSON objects read from
# the clusters file; every field is required.
class Cluster(BaseModel):
    # Label used when listing the cluster in the prompt.
    cluster_name: str = Field(..., description="Name of the cluster")
    # Kept as free text (string), not parsed as a number.
    occurrence: str = Field(..., description="Percentage or frequency information if available")
    # Joined with ", " after the cluster name when the prompt is built.
    keywords: List[str] = Field(..., description="A list of keywords associated with the cluster")
    # Appended to the prompt entry only when non-empty.
    notes: str = Field(..., description="Any additional information or descriptions about the cluster")


# Container for a list of Cluster entries.
# NOTE(review): not referenced anywhere else in the visible part of this file.
class ClustersResponse(BaseModel):
    # Required list of cluster definitions.
    clusters: List[Cluster]



# Input selection. To run on the driving dataset instead, swap in the
# commented-out pair below.
# clusters_path = "../results/clustering/driving/clusters_prompt_ensemble/aggregated_clusters_text.jsonl"
# trajectory_data_path = "../results/clustering/driving/failure_description_resoning_gemini25pro_combined.jsonl"
clusters_path = "../results/clustering/waypointnav/clusters_prompt_ensemble/aggregated_clusters_text.jsonl"
trajectory_data_path = "../results/clustering/waypointnav/failure_description_resoning_gemini25pro_combined.jsonl"

# Assignments are written next to the clusters file.
output_filepath = clusters_path.replace(".jsonl", "_item_assignment.jsonl")

# Read the clusters and parse into Pydantic models.
clusters = []
with open(clusters_path, 'r', encoding='utf-8') as f:
    for line in f:
        line = line.strip()
        if not line:
            # Blank lines (e.g. a trailing newline) hold no cluster and
            # would otherwise make json.loads raise.
            continue
        clusters.append(Cluster(**json.loads(line)))

# One prompt entry per cluster: "<name> : <keywords>", plus the notes when
# there are any.
cluster_options = []
for cluster in clusters:
    cluster_info = f"{cluster.cluster_name} : {', '.join(cluster.keywords)}"
    if cluster.notes:
        cluster_info += f" — {cluster.notes}"
    cluster_options.append(cluster_info)

# Add an "Other" option for trajectories that don't fit existing clusters
cluster_options.append("Other: For trajectories that don't clearly fit into any of the defined clusters above")

# Get trajectory data with filenames
trajectory_data = get_trajectory_data(trajectory_data_path)

# Render the cluster list once; both prompts embed the same bullet list.
cluster_bullets = "\n".join(f"- {option}" for option in cluster_options)

# System prompt for the waypoint-navigation (robot) dataset.
prompt_template_waypoint = f"""
You are classifying robot trajectory descriptions into predefined clusters based on failure types.
Assign the trajectory to one or more of the following clusters:

{cluster_bullets}

Analyze the trajectory description and identify which type(s) of failures occurred.
Consider the keywords and notes for each cluster to help with your classification.
Use the "Other" category only when the trajectory doesn't reasonably fit into any of the existing clusters.
"""

# System prompt for the driving (car) dataset.
prompt_template_driving = f"""
You are classifying car trajectory descriptions into predefined clusters based on failure types.
Assign the trajectory to one or more of the following clusters:

{cluster_bullets}

Analyze the trajectory description and identify which type(s) of failures occurred.
Consider the keywords and notes for each cluster to help with your classification.
Use the "Other" category only when the trajectory doesn't reasonably fit into any of the existing clusters.
"""

# Pick the prompt from the dataset directory in the trajectory path. The
# waypoint dataset lives under "waypointnav"; testing only for "wayptnav"
# never matches that path and would hand the car prompt to robot data. The
# short spelling is still accepted for paths that use it.
_WAYPOINT_PATH_TAGS = ("waypointnav", "wayptnav")
prompt_template = (
    prompt_template_waypoint
    if any(tag in trajectory_data_path for tag in _WAYPOINT_PATH_TAGS)
    else prompt_template_driving
)

# Function definition handed to the API for structured output; its parameters
# are the JSON schema generated from TrajectoryClusterAssignment.
fn_def = dict(
    name="assign_trajectory_to_clusters",
    description="Assign a single trajectory to one or more appropriate clusters",
    parameters=TrajectoryClusterAssignment.model_json_schema(),
)

def _extract_assignment(message, filename):
    """Turn one API response message into a validated assignment.

    Args:
        message: The ``message`` of the first response choice.
        filename: Identifier of the trajectory that was classified; it
            replaces whatever filename came back in the arguments.

    Returns:
        The validated ``TrajectoryClusterAssignment``, or ``None`` when the
        message carries no function call or its arguments fail validation.
    """
    if "function_call" not in message:
        logger.warning("No function call in response for %s; trajectory skipped", filename)
        return None
    try:
        assignment = TrajectoryClusterAssignment.model_validate_json(
            message["function_call"]["arguments"]
        )
    except ValueError as exc:
        # Pydantic's ValidationError is a ValueError. One malformed reply
        # should not abort the whole batch.
        logger.error("Invalid function-call arguments for %s: %s", filename, exc)
        return None
    # Ensure filename is set to the input record's identifier.
    assignment.filename = filename
    return assignment


# Process each trajectory individually and write to file
results = []
processed_filenames = set()
unassigned_filenames = []

# Resume support: collect the filenames already present in the output file.
if os.path.exists(output_filepath):
    with open(output_filepath, 'r', encoding='utf-8') as infile:
        for line in infile:
            try:
                data = json.loads(line.strip())
            except json.JSONDecodeError:
                # Blank or truncated lines are ignored.
                continue
            # Guard against non-object JSON values, where `in` would raise
            # or mean something else.
            if isinstance(data, dict) and 'filename' in data:
                processed_filenames.add(data['filename'])
    print(f"Found {len(processed_filenames)} already processed trajectories")

# Append when resuming; otherwise start a fresh file.
with open(output_filepath, 'a' if processed_filenames else 'w', encoding='utf-8') as outfile:
    for item in trajectory_data:
        # Skip if this filename has already been processed
        if item['filename'] in processed_filenames:
            print(f"Skipping already processed trajectory: {item['filename']}")
            continue

        # Create messages for this specific trajectory
        messages = [
            {"role": "system", "content": prompt_template},
            {"role": "user", "content": f"Classify this trajectory: {item['trajectory']}"}
        ]

        # Call the OpenAI API for this trajectory, forcing the function call
        response = call_openai_api(
            messages=messages,
            model="o4-mini",
            functions=[fn_def],
            function_call={"name": "assign_trajectory_to_clusters"}
        )

        assignment = _extract_assignment(response.choices[0].message, item["filename"])
        if assignment is None:
            # Not added to processed_filenames, so a later run retries it.
            unassigned_filenames.append(item["filename"])
        else:
            processed_filenames.add(item["filename"])

            # Write and flush per record so an interrupted run loses nothing
            # that was already classified.
            outfile.write(assignment.model_dump_json() + "\n")
            outfile.flush()

            results.append(assignment)

            print(f"Processed trajectory {item['filename']}: {[a.cluster_name for a in assignment.assignments]}")

        # Small delay to avoid rate limiting
        time.sleep(0.5)

print(f"Output saved to {output_filepath}")
print(f"Processed {len(results)} trajectories in total")
if unassigned_filenames:
    print(f"{len(unassigned_filenames)} trajectories had no usable response; re-run to retry them")