import os
import base64
from PIL import Image
import numpy as np
import cv2
import json
import time
from dotenv import load_dotenv
import google.generativeai as genai
from pathlib import Path
import re
import argparse


# Load variables from a local .env file (if one exists) before reading the
# environment. With its default arguments load_dotenv does not override
# variables that are already set in the process environment.
load_dotenv()

GOOGLE_API_KEY = os.getenv('GOOGLE_API_KEY')
if not GOOGLE_API_KEY:
    # Fail fast at import time: nothing below can work without a key.
    raise ValueError("No Gemini API key found. Please set the GOOGLE_API_KEY environment variable.")


# Configure the Gemini client once with the resolved key.
genai.configure(api_key=GOOGLE_API_KEY)

# Chain-of-thought prompt for the waypoint-navigation case study: asks for a
# trajectory description followed by the semantic reason for the collision.
waypointnav_cot_reason_prompt = (
    "Provide a description of the trajectory of a robot from the sequence of "
    "images it observed along its path, knowing that it collides in the last "
    "image. "
    "After that, provide the visual semantic reasons behind its failure in brief. "
    "Pay attention to the surrounding objects. "
    "You must provide your answer in the following format -- "
    "trajectory : <trajectory_description> \n failure_reason : <semantic_failure_reason>, "
    "where <trajectory_description> is the description of its trajectory and "
    "<semantic_failure_reason> is the semantic reason behind failure."
)

# Chain-of-thought prompt for the driving case study: asks for a trajectory
# description followed by the semantic reason for the collision.
driving_cot_reason_prompt = (
    "Describe the trajectory of a car from the sequence of images it observed "
    "along its path, knowing that it undergoes a collision. "
    "After that, provide the visual semantic reason behind its failure in brief. "
    "Pay attention to the surrounding object, other vehicles, and environment "
    "conditions. "
    "You must provide your answer in the following format -- "
    "trajectory : <trajectory_description> \n failure_reason : <semantic_failure_reason>, "
    "where <trajectory_description> is the description of its trajectory and "
    "<semantic_failure_reason> is the semantic reason behind failure."
)


def encode_image_to_bytes(image_path):
    """Load an image file and return it as a fully decoded PIL image.

    Despite its name, this function returns a ``PIL.Image.Image`` object, not
    raw bytes; the name is kept so existing callers continue to work.

    ``Image.open`` is lazy: it only reads the header and keeps the file open
    until the pixel data is needed. The image is therefore loaded eagerly
    inside a ``with`` block so that the file handle is released before
    returning and so that a truncated or corrupt file fails here rather than
    later when the image is first used.

    Args:
        image_path: Path of the image file to open.

    Returns:
        The loaded ``PIL.Image.Image``; its pixel data stays usable after the
        underlying file has been closed.

    Raises:
        OSError: If the file cannot be opened or decoded (Pillow's
            ``UnidentifiedImageError`` is a subclass of ``OSError``).
    """
    with Image.open(image_path) as img:
        # Force the decode now; leaving the block only closes the file.
        img.load()
    return img

def encode_image_to_base64(image_path):
    """Return the contents of the file at *image_path* as a base64 text string."""
    with open(image_path, 'rb') as handle:
        raw = handle.read()
    encoded = base64.b64encode(raw)
    return encoded.decode('utf-8')

def count_tokens(messages):
    """Roughly estimate the number of tokens in the text parts of *messages*.

    Only entries that are dicts containing a 'text' key contribute. Their
    text values are joined with single spaces and the resulting length is
    integer-divided by four.
    """
    texts = []
    for entry in messages:
        if not isinstance(entry, dict):
            continue
        if 'text' not in entry:
            continue
        texts.append(entry.get('text', ''))

    joined = ' '.join(texts)
    # Heuristic: about four characters of English text per token.
    return len(joined) // 4

def natural_sort_key(s):
    """Return a sort key that orders embedded numbers numerically.

    The string is split on runs of digits, so ``img1, img2, img10`` sort in
    that order instead of the lexicographic ``img1, img10, img2``.

    Args:
        s: The string (typically a filename) to build a key for.

    Returns:
        A list alternating lower-cased text pieces and ``int`` values,
        always starting with a (possibly empty) text piece.
    """
    pieces = re.split(r'(\d+)', s)
    # With one capturing group, re.split puts the matched digit runs at the
    # odd indices. Selecting by position instead of str.isdigit() avoids
    # calling int() on characters such as '²' that isdigit() accepts but
    # \d does not match, which would raise ValueError.
    return [int(piece) if idx % 2 else piece.lower()
            for idx, piece in enumerate(pieces)]

def process_image_with_vision(image_paths, prompt, max_images=10):
    """Describe one image sequence with the Gemini Vision API.

    Args:
        image_paths: Path of a single directory whose entries are the frames
            of one sequence (the plural name is historical).
        prompt: Text prompt sent ahead of the images.
        max_images: Maximum number of frames to send. When the directory has
            more entries, only the last ``max_images`` in natural sort order
            are used. Pass ``None`` to send every frame. Expected to be a
            positive integer otherwise.

    Returns:
        The response text, or ``None`` when the directory cannot be listed,
        when no image could be loaded, or when the API call fails.
    """
    print(image_paths)

    # A stray file or a missing directory must not abort a whole batch run,
    # so listing errors are reported and turned into a None result.
    try:
        entries = os.listdir(image_paths)
    except OSError as e:
        print(f"Error listing images in {image_paths}: {e}")
        return None

    # Sort files using natural sort order to ensure correct sequence ordering
    sorted_paths = sorted(entries, key=natural_sort_key)

    # Keep only the tail of the sequence when it is longer than the limit.
    if max_images is not None and len(sorted_paths) > max_images:
        print(f"Found {len(sorted_paths)} images, using only the last {max_images}.")
        # Slice from an explicit start index: [-0:] would keep everything.
        sorted_paths = sorted_paths[len(sorted_paths) - max_images:]

    images = []
    for path in sorted_paths:
        full_path = os.path.join(image_paths, path)
        try:
            images.append(encode_image_to_bytes(full_path))
        except Exception as e:
            # Best effort: skip entries that are not readable images.
            print(f"Error loading image {full_path}: {e}")

    # Without any image the model would answer from the prompt alone, which
    # would be recorded as if it described the sequence.
    if not images:
        print(f"Error: no images could be loaded from {image_paths}")
        return None

    try:
        # Configure the model
        model = genai.GenerativeModel('gemini-2.5-pro-preview-03-25')

        # Send the prompt followed by the frames in sequence order
        response = model.generate_content([prompt, *images])

        # Extract the response text
        output = response.text

        # Try to print some token information if available
        try:
            usage_metadata = response.usage_metadata
            if usage_metadata:
                print(f"Input tokens: {usage_metadata.prompt_token_count}")
                print(f"Output tokens: {usage_metadata.candidates_token_count}")
                print(f"Total tokens: {usage_metadata.total_token_count }")
        except (AttributeError, IndexError) as e:
            print(f"Could not get token usage data: {e}")

        return output

    except Exception as e:
        # Top-level boundary for one sequence: report and let the caller
        # continue with the next one.
        print(f"Error calling Gemini API: {e}")
        return None

def is_already_processed(filename, output_file):
    """Check whether *filename* already has a record in the output file.

    A record is recognised by a line that consists exactly of
    ``Filename: <filename>``. Comparing whole lines means that neither a
    longer name sharing the same prefix nor a description that happens to
    contain the text ``Filename: <filename>`` counts as a match. A final
    line without a trailing newline is matched as well.

    Args:
        filename: Name of the sequence to look for.
        output_file: Path of the results file.

    Returns:
        True if a matching record line exists, False otherwise (including
        when the output file does not exist).
    """
    if not os.path.exists(output_file):
        return False

    marker = f"Filename: {filename}"
    with open(output_file, "r") as f:
        # Stream line by line instead of loading the whole file into memory.
        return any(line.rstrip("\n") == marker for line in f)

def process_images_in_order(folder_path, num_samples, output_file, prompt, delay_seconds=10):
    """Process image sequences in sorted order and append results to a file.

    Each entry of *folder_path* is treated as one sequence directory.
    Sequences that already have a record in *output_file* are skipped, so an
    interrupted run can be resumed.

    Args:
        folder_path: Directory containing one sub-directory per sequence.
        num_samples: Maximum number of sequences to process, taken from the
            start of the sorted listing. ``None`` means all of them.
        output_file: Path of the text file that results are appended to.
        prompt: Prompt passed to the vision model for every sequence.
        delay_seconds: Pause after each processed sequence to avoid rate
            limiting. Values of 0 or less disable the pause.
    """
    # Get a sorted list of directories (each containing a sequence of images)
    samples = sorted(os.listdir(folder_path))[:num_samples]
    # Use the real count for progress output: num_samples may be None or
    # larger than the number of available sequences.
    total = len(samples)

    # Create output directory if it doesn't exist. dirname is '' for a bare
    # file name, and os.makedirs('') would raise.
    output_dir = os.path.dirname(output_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    print(f"Using prompt: {prompt}")
    print(f"Number of samples: {num_samples}")
    print(f"Output file: {output_file}")
    print(f"Samples: {samples}")
    print(f"Folder path: {folder_path}")

    # Open output file in append mode
    with open(output_file, "a") as f:
        for i, sub_dir in enumerate(samples):
            # Check if this sequence has already been processed
            if is_already_processed(sub_dir, output_file):
                print(f"Skipping {i+1}/{total}: {sub_dir} (already processed)")
                continue

            image_path = os.path.join(folder_path, sub_dir)
            print(f"Processing {i+1}/{total}: {sub_dir}")

            # Process the image sequence with Gemini Vision API
            result = process_image_with_vision(image_path, prompt)

            # Log results periodically
            if i % 5 == 0:
                print(f"Result for {sub_dir}:")
                print(result)
                print("")

            # Write results to file
            if result:
                f.write(f"Filename: {sub_dir}\n")
                f.write(f"Description: {result}\n\n")
                # Flush so a finished record is not left sitting in the
                # write buffer if the long-running job is killed.
                f.flush()
            else:
                print(f"Error: Failed to generate description for {sub_dir}")

            # Add a delay to avoid rate limiting
            if delay_seconds > 0:
                time.sleep(delay_seconds)

def parse_args():
    """Build the command line parser and return the parsed arguments.

    Options:
        --experiment / -e: 'waypointnav' (default) or 'nexar'.
        --samples / -s: integer number of samples; defaults to None.
    """
    cli = argparse.ArgumentParser(
        description='Process image sequences with Gemini Vision API for failure analysis'
    )

    experiment_options = dict(
        type=str,
        choices=['waypointnav', 'nexar'],
        default='waypointnav',
        help='Type of experiment to run (waypointnav, nexar)',
    )
    cli.add_argument('--experiment', '-e', **experiment_options)

    sample_options = dict(
        type=int,
        default=None,
        help='Number of samples to process (default: all available)',
    )
    cli.add_argument('--samples', '-s', **sample_options)

    return cli.parse_args()

if __name__ == "__main__":
    # Parse command line arguments
    args = parse_args()

    # Map experiment type to the folder holding its image sequences
    experiment_folders = {
        'waypointnav': "../failure_clustering_datasets/waypointnav/area1/patch0/positive/",
        'nexar': "../failure_clustering_datasets/nexar/train_croped_3fps_croped_top85_bottom60_imgs/",
    }

    # Map experiment type to prompt
    experiment_prompts = {
        'waypointnav': waypointnav_cot_reason_prompt,
        'nexar': driving_cot_reason_prompt,  # nexar uses the driving prompt
    }

    folder_path = experiment_folders[args.experiment]
    prompt = experiment_prompts[args.experiment]

    # The dataset paths are relative to the working directory; give a clear
    # message instead of a traceback when run from the wrong place.
    if not os.path.isdir(folder_path):
        raise SystemExit(f"Dataset folder not found: {folder_path}")

    # Determine number of samples. Compare against None explicitly so that
    # "--samples 0" is not silently turned into "all samples".
    total_samples = len(os.listdir(folder_path))
    num_samples = total_samples if args.samples is None else args.samples
    if num_samples < 0:
        # A negative value would be used as a slice bound and silently drop
        # sequences from the end of the listing.
        raise SystemExit("--samples must be a non-negative integer")

    # Determine output path
    output_path = os.path.join('../results/clustering', args.experiment)

    # NOTE: the 'resoning' spelling is kept on purpose; already-processed
    # sequences are detected by reading this exact file.
    output_file = os.path.join(output_path, 'failure_description_resoning_gemini25pro.txt')

    print(f"Experiment: {args.experiment}")
    print(f"Using prompt: {prompt}")
    print(f"Processing {num_samples} out of {total_samples} image sequences with Gemini API")
    print(f"Results will be saved to {output_file}")

    process_images_in_order(folder_path, num_samples, output_file, prompt)