#!/usr/bin/env python3
import argparse
import base64
import csv
import json
import os
import random
import re
import time
from collections import defaultdict
from datetime import datetime
from pathlib import Path

import openai

def load_ground_truth(ground_truth_dir):
    """Load per-image ground-truth grid cells from the CSV annotation files.

    Reads every ``*.csv`` file under
    ``<ground_truth_dir>/test_gt_csvs_32x32/<view>/`` for the ``frontal`` and
    ``lateral`` views. The file name (minus ``.csv``) is used as the condition.
    Each CSV is expected to have the columns ``Image Name``,
    ``Significant Overlapping Cells`` and ``All Overlapping Cells``; the two
    cell columns hold whitespace-separated cell identifiers.

    Rows with no significant cells (empty or missing column value) are skipped.
    Files are visited in sorted order so that, when the same image name occurs
    in more than one CSV, the surviving entry is deterministic (the last file
    in sorted order wins).

    Args:
        ground_truth_dir: Directory containing ``test_gt_csvs_32x32``.

    Returns:
        dict mapping image name to a dict with the keys ``view``,
        ``condition``, ``significant_cells`` (list of str) and ``all_cells``
        (list of str). Unreadable files are reported with a warning and
        skipped; a missing view directory is reported and skipped.
    """
    ground_truth = {}
    suffix = '.csv'

    for view in ["frontal", "lateral"]:
        view_path = os.path.join(ground_truth_dir, "test_gt_csvs_32x32", view)

        if not os.path.exists(view_path):
            print(f"Warning: Ground truth directory not found: {view_path}")
            continue

        # Sorted so the load order does not depend on the filesystem.
        for csv_file in sorted(os.listdir(view_path)):
            if not csv_file.endswith(suffix):
                continue

            # Strip only the trailing extension (str.replace would also remove
            # any ".csv" occurring in the middle of the name).
            condition = csv_file[:-len(suffix)]
            csv_path = os.path.join(view_path, csv_file)

            try:
                # newline='' is what the csv module expects for file objects.
                with open(csv_path, 'r', newline='') as f:
                    for row in csv.DictReader(f):
                        image_name = row['Image Name']
                        # DictReader fills missing trailing columns with None;
                        # treat that like an empty cell list instead of
                        # aborting the remainder of the file.
                        significant_cells = (row['Significant Overlapping Cells'] or '').split()
                        if not significant_cells:
                            continue
                        ground_truth[image_name] = {
                            'view': view,
                            'condition': condition,
                            'significant_cells': significant_cells,
                            'all_cells': (row['All Overlapping Cells'] or '').split(),
                        }
            except Exception as e:
                # Deliberately best-effort: one bad file (e.g. missing column)
                # must not prevent the other annotations from loading.
                print(f"Warning: Could not load ground truth from {csv_path}: {e}")

    print(f"Loaded ground truth for {len(ground_truth)} images")
    return ground_truth

# A grid coordinate is one or more letters immediately followed by digits
# (matched against upper-cased text), not embedded in a longer alphanumeric run.
_GRID_CELL_RE = re.compile(r"(?<![A-Z0-9])[A-Z]+[0-9]+(?![A-Z0-9])")


def check_answer_correctness(image_file, response, ground_truth):
    """Score a model response against the ground truth for one image.

    The image's ground-truth entry is looked up by ``image_file`` with its
    extension removed. Matching is case-insensitive.

    The whole (stripped) response is compared first. If that does not match
    any ground-truth cell, the response is searched for grid coordinates
    (letters followed by digits); when exactly one distinct coordinate is
    present it is used as the answer. This tolerates answers such as
    ``"K14."`` or ``"Cell: K14"``. Responses naming several different cells
    are treated as given (and therefore scored as incorrect).

    Args:
        image_file: Image file name, with extension.
        response: Raw text returned by the model.
        ground_truth: Mapping of image name to a dict holding
            ``significant_cells`` and ``all_cells`` lists.

    Returns:
        Tuple ``(status, info)`` where status is one of ``"no_gt"``,
        ``"correct"`` (hit on a significant cell), ``"partial"`` (hit on an
        overlapping but not significant cell) or ``"incorrect"``, and info is
        a human-readable description of the ground truth.
    """
    # Remove file extension to match ground truth format
    image_name_base = os.path.splitext(image_file)[0]

    if image_name_base not in ground_truth:
        return "no_gt", "No ground truth available"

    gt_data = ground_truth[image_name_base]
    significant_cells = gt_data['significant_cells']
    all_cells = gt_data['all_cells']

    significant_upper = {cell.upper() for cell in significant_cells}
    all_upper = {cell.upper() for cell in all_cells}

    answer = response.strip().upper()

    # Fall back to extracting the coordinate only when the exact text does not
    # match, so every previously accepted answer is still accepted.
    if answer not in significant_upper and answer not in all_upper:
        found = set(_GRID_CELL_RE.findall(answer))
        if len(found) == 1:
            answer = found.pop()

    # Full hit: the answer is one of the significant cells
    if answer in significant_upper:
        return "correct", f"GT: {' '.join(significant_cells)}"

    # Partial hit: the answer overlaps the finding but is not significant
    if answer in all_upper:
        return "partial", f"Partial hit - GT significant: {' '.join(significant_cells)}, all: {' '.join(all_cells)}"

    # Completely wrong
    return "incorrect", f"GT: {' '.join(significant_cells)}"

def encode_image(image_path):
    """Return the contents of the file at ``image_path`` as a base64 string.

    Args:
        image_path: Path of the file to read (opened in binary mode).

    Returns:
        The base64 encoding of the file's bytes, decoded to ``str``.

    Raises:
        OSError: If the file cannot be opened or read.
    """
    with open(image_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode('utf-8')

def make_openai_call(image_path, view, condition, max_retries=3, base_delay=2, model="gpt-5"):
    """Ask the OpenAI chat API to localize ``condition`` on a gridded X-ray.

    The image is sent inline as a base64 data URL together with a fixed
    system/user prompt. Failed attempts are retried with exponential backoff
    (``base_delay * 2**attempt`` seconds plus up to one second of jitter).

    Args:
        image_path: Path to the image file to send.
        view: View name interpolated into the prompts (e.g. "frontal").
        condition: Abnormality name interpolated into the user prompt.
        max_retries: Total number of attempts.
        base_delay: Base backoff delay in seconds.
        model: ``"gpt-5"`` selects the ``gpt-5`` API model. Any other value
            selects the pinned ``gpt-4o-2024-05-13`` snapshot with
            ``temperature=0.0``.

    Returns:
        The stripped text of the first choice on success. On failure the
        function does not raise; it returns ``"ERROR_AFTER_<n>_RETRIES"``
        once all attempts failed, or ``"ERROR"`` if ``max_retries`` allows no
        attempt at all.
    """
    # Label the data URL with the real image type; only .png differs from the
    # JPEG default among the extensions this script collects.
    extension = os.path.splitext(image_path)[1].lower()
    mime_type = "image/png" if extension == ".png" else "image/jpeg"

    # NOTE(review): every model name other than "gpt-5" is routed to the
    # pinned gpt-4o snapshot, so the name the caller passes is not the name
    # sent to the API - confirm this is intended.
    request_options = {}
    if model != "gpt-5":
        api_model = "gpt-4o-2024-05-13"
        request_options["temperature"] = 0.0
    else:
        # The gpt-5 request is sent without a temperature setting.
        api_model = "gpt-5"

    system_prompt = f"""You are an expert chest radiologist specializing in analyzing {view} chest X-rays. Your task is to precisely localize abnormalities using a grid overlay."""
    user_prompt = f"""This is a gridded {view} view of a chest X-ray. The abnormality '{condition}' is confirmed to be present in this image.

    Your task:
    1. Identify the single grid cell where this abnormality - '{condition}' is the MOST prominent.
    2. Provide only the grid coordinate for this most representative cell. If the abnormality spans multiple cells, choose the cell that is most representative. A grid coordinate is defined as a letter followed by a number.
    3. Do not include any explanations or additional text in your response."""

    base64_image = None

    for attempt in range(max_retries):
        try:
            # Encode inside the try so read errors are reported like any other
            # failure, but only once: retries reuse the encoded payload.
            if base64_image is None:
                base64_image = encode_image(image_path)

            response = openai.chat.completions.create(
                model=api_model,
                messages=[
                    {"role": "system", "content": system_prompt},
                    {
                        "role": "user",
                        "content": [
                            {"type": "text", "text": user_prompt},
                            {
                                "type": "image_url",
                                "image_url": {
                                    "url": f"data:{mime_type};base64,{base64_image}"
                                },
                            },
                        ],
                    },
                ],
                **request_options,
            )

            return response.choices[0].message.content.strip()

        except Exception as e:
            # Best-effort boundary: any failure is logged and retried.
            error_msg = str(e)
            print(f"    Attempt {attempt + 1}/{max_retries} failed: {error_msg}")

            if attempt < max_retries - 1:  # Don't sleep on the last attempt
                # Exponential backoff with jitter
                delay = base_delay * (2 ** attempt) + random.uniform(0, 1)
                print(f"    Waiting {delay:.1f} seconds before retry...")
                time.sleep(delay)
            else:
                print(f"    All {max_retries} attempts failed for {image_path}")
                return f"ERROR_AFTER_{max_retries}_RETRIES"

    return "ERROR"

def load_progress(progress_file):
    """Load saved progress from a JSON file.

    Args:
        progress_file: Path of the progress JSON file.

    Returns:
        The parsed JSON object when the file exists and contains a JSON
        object. Otherwise a fresh default
        ``{"processed_files": [], "results": [], "last_batch": 0}``; a warning
        is printed when the file exists but cannot be read, parsed, or does
        not hold a JSON object.
    """
    default_progress = {"processed_files": [], "results": [], "last_batch": 0}

    if not os.path.exists(progress_file):
        return default_progress

    try:
        with open(progress_file, 'r') as f:
            loaded = json.load(f)
    except (ValueError, OSError) as e:
        # ValueError covers json.JSONDecodeError and undecodable bytes.
        print(f"Warning: Could not load progress file: {e}")
        return default_progress

    # Callers use dict access on the result, so reject any other JSON value.
    if not isinstance(loaded, dict):
        print(f"Warning: Could not load progress file: expected a JSON object, got {type(loaded).__name__}")
        return default_progress

    return loaded

def save_progress(progress_file, progress_data):
    """Write ``progress_data`` to ``progress_file`` as indented JSON.

    The data is written to ``<progress_file>.tmp`` first and then moved over
    the target with ``os.replace``, so an interruption in the middle of a
    write does not leave a truncated progress file behind.

    Args:
        progress_file: Destination path.
        progress_data: JSON-serializable object to store.

    I/O errors are reported with a warning and otherwise ignored, so a failed
    save does not abort the caller.
    """
    temp_path = f"{progress_file}.tmp"
    try:
        with open(temp_path, 'w') as f:
            json.dump(progress_data, f, indent=2)
        os.replace(temp_path, progress_file)
    except OSError as e:
        print(f"Warning: Could not save progress: {e}")
    finally:
        # After a successful replace the temp file is gone; this only cleans
        # up after a failed or interrupted write.
        if os.path.exists(temp_path):
            try:
                os.remove(temp_path)
            except OSError:
                pass

def get_all_images(base_dir):
    """Collect every image under ``<base_dir>/<view>/<condition>/``.

    Only the ``frontal`` and ``lateral`` view directories are scanned, and
    only files ending in ``.jpg``, ``.jpeg`` or ``.png`` (case-insensitive)
    directly inside a condition directory are returned. Conditions and files
    are visited in sorted order so repeated runs see the images in the same
    order regardless of the filesystem.

    Args:
        base_dir: Root directory containing the view directories.

    Returns:
        List of dicts, one per image, with the keys:
          - file_name: basename of the image
          - image_path: ``base_dir`` joined with view, condition and file name
          - rel_path: normalized path relative to base_dir (canonical id)
          - view
          - condition
    """
    image_suffixes = ('.jpg', '.jpeg', '.png')
    all_images = []

    for view in ["frontal", "lateral"]:
        view_dir = os.path.join(base_dir, view)

        if not os.path.exists(view_dir):
            print(f"Directory not found: {view_dir}")
            continue

        print(f"Scanning {view} directory...")
        # Every sub-directory of the view is treated as a condition
        conditions = sorted(
            entry for entry in os.listdir(view_dir)
            if os.path.isdir(os.path.join(view_dir, entry))
        )
        print(f"Found conditions: {conditions}")

        for condition in conditions:
            condition_dir = os.path.join(view_dir, condition)

            try:
                image_files = sorted(
                    name for name in os.listdir(condition_dir)
                    if name.lower().endswith(image_suffixes)
                )
            except OSError as e:
                # Keep scanning the other conditions if one is unreadable.
                print(f"  Warning: Could not list directory {condition_dir}: {e}")
                image_files = []

            print(f"  {condition}: {len(image_files)} images")

            for image_file in image_files:
                image_path = os.path.join(condition_dir, image_file)
                # rel_path relative to base_dir (portable canonical ID)
                rel_path = os.path.relpath(image_path, base_dir)
                all_images.append({
                    'file_name': image_file,
                    'image_path': image_path,
                    'rel_path': os.path.normpath(rel_path),
                    'view': view,
                    'condition': condition,
                })

    print(f"Total images found: {len(all_images)}")
    return all_images

def _percentage(part, total):
    """Return ``part`` as a percentage of ``total`` (0 when ``total`` is 0)."""
    return (part / total * 100) if total > 0 else 0


def _tally_results(results):
    """Count stored results per scored status (correct / partial / incorrect)."""
    counts = {"correct": 0, "partial": 0, "incorrect": 0}
    for entry in results:
        status = entry.get('correctness_status') if isinstance(entry, dict) else None
        if status in counts:
            counts[status] += 1
    return counts


def _build_progress_payload(processed_files, results, last_batch, model, counts):
    """Assemble the dict that is written to the progress JSON file."""
    total_with_gt = counts["correct"] + counts["partial"] + counts["incorrect"]
    return {
        "processed_files": sorted(processed_files),
        "results": results,
        "last_batch": last_batch,
        "model": model,
        "accuracy_stats": {
            "correct_answers": counts["correct"],
            "partial_answers": counts["partial"],
            "incorrect_answers": counts["incorrect"],
            "total_with_ground_truth": total_with_gt,
            "correct_percentage": _percentage(counts["correct"], total_with_gt),
            "partial_percentage": _percentage(counts["partial"], total_with_gt),
        },
        "last_updated": datetime.now().isoformat(),
    }


def process_images(debug_mode=False, max_images=100, delay=1.0, max_retries=3, model="gpt-4o"):
    """Run the localization benchmark over all images and write the results.

    Images are read from ``./test_images_gridded_32x32``, ground truth from
    ``./ground_truth``. Output goes to ``./results_openai_gpt5`` when
    ``model == "gpt-5"`` and to ``./results_openai_gpt4o_51324`` otherwise:
    a progress JSON file (used to resume) and a results CSV, both named after
    the model.

    Images already listed in the progress file are skipped. Entries are
    matched by normalized relative path, or by bare file name for legacy
    progress files.

    Accuracy counters are seeded from the results loaded from the progress
    file, so the printed and saved statistics always describe the same
    (cumulative) set of results.

    An API call that ends in an error marker (a response starting with
    ``"ERROR"``) is not scored, not stored and not marked as processed, so
    the image is attempted again on the next run.

    Args:
        debug_mode: If true, stop after ``max_images`` attempts and save
            progress after every image instead of every 10.
        max_images: Attempt limit per run; only applied in debug mode.
        delay: Seconds to sleep between API calls (0 disables the pause).
        max_retries: Passed through to ``make_openai_call``.
        model: Model name; selects the output directory and file names and
            is passed through to ``make_openai_call``.
    """
    # Base directory
    base_dir = "./test_images_gridded_32x32"
    ground_truth_dir = "./ground_truth"
    results_dir = "./results_openai_gpt5"
    if model != "gpt-5":
        results_dir = "./results_openai_gpt4o_51324"

    # Load ground truth data
    ground_truth = load_ground_truth(ground_truth_dir)

    # Create results directory if it doesn't exist
    os.makedirs(results_dir, exist_ok=True)

    # Output files
    model_tag = model.replace('-', '_')
    csv_path = os.path.join(results_dir, f"localization_results_{model_tag}.csv")
    progress_file = os.path.join(results_dir, f"progress_{model_tag}.json")

    # Get all images to process
    all_images = get_all_images(base_dir)
    total_images = len(all_images)

    print(f"Found {total_images} total images to process")
    print(f"Using model: {model}")

    # Load existing progress
    progress_data = load_progress(progress_file)
    if not isinstance(progress_data, dict):
        progress_data = {}
    raw_processed = progress_data.get("processed_files", [])
    results = progress_data.get("results", [])
    last_batch = progress_data.get("last_batch", 0)

    # Normalize stored keys (handles mixed path separators). Legacy
    # basename-only entries are part of this set too and are therefore
    # written back on save.
    processed_files = {os.path.normpath(p) for p in raw_processed if isinstance(p, str)}

    # Backwards compatibility: an entry without any path separator is a
    # legacy basename and is matched against file names.
    processed_basenames = {
        p for p in raw_processed
        if isinstance(p, str) and ('/' not in p) and ('\\' not in p) and (os.sep not in p)
    }

    if processed_files or processed_basenames:
        sample_keys = list(processed_files)[:3] + list(processed_basenames)[:3]
        print("Sample processed keys (rel_paths or legacy basenames):", sample_keys)

    # An image counts as processed if its rel_path is recorded (new format)
    # or its basename is recorded (legacy format).
    remaining_images = [
        img for img in all_images
        if (os.path.normpath(img['rel_path']) not in processed_files) and (img['file_name'] not in processed_basenames)
    ]

    # Derived from the dataset itself: adding the sizes of the two key sets
    # would count every legacy basename twice.
    already_processed = total_images - len(remaining_images)

    if processed_files or processed_basenames:
        print(f"Resuming from previous run: {already_processed} already processed")
        print(f"Remaining images: {len(remaining_images)}")

    # Seed the counters from earlier results so the statistics are cumulative.
    counts = _tally_results(results)

    processed_in_current_run = 0
    failed_in_current_run = 0
    current_batch = last_batch
    save_frequency = 1 if debug_mode else 10
    run_target = min(max_images, len(remaining_images)) if debug_mode else len(remaining_images)
    status_indicators = {
        "correct": "✓",
        "partial": "◐",
        "incorrect": "✗",
        "no_gt": "?"
    }

    for index, image_info in enumerate(remaining_images):
        attempted = processed_in_current_run + failed_in_current_run
        if debug_mode and attempted >= max_images:
            print(f"\nDebug mode: Reached limit of {max_images} images")
            print(f"Processed {processed_in_current_run} images in this run")
            break

        image_path = image_info['image_path']
        view = image_info['view']
        condition = image_info['condition']
        image_file = image_info['file_name']
        image_rel = os.path.normpath(image_info['rel_path'])

        print(f"Processing [{attempted + 1}/{run_target}] (overall: {already_processed + attempted + 1}/{total_images}): {view} - {condition} - {image_file}")

        # Make OpenAI call with retry mechanism
        response = make_openai_call(image_path, view, condition, max_retries=max_retries, model=model)

        if response.startswith("ERROR"):
            # A failed call is not a model answer: leave the image unprocessed
            # so it is picked up again on the next run.
            failed_in_current_run += 1
            print(f"    Result: {response} - API call failed; image will be retried on the next run")
        else:
            # Check correctness against ground truth
            correctness_status, gt_info = check_answer_correctness(image_file, response, ground_truth)
            if correctness_status in counts:
                counts[correctness_status] += 1

            results.append({
                'file_name': image_file,
                'output': response,
                'view': view,
                'condition': condition,
                'model': model,
                'correctness_status': correctness_status,
                'ground_truth_info': gt_info,
                'processed_at': datetime.now().isoformat()
            })

            # Mark this image as processed using rel_path (canonical format)
            processed_files.add(image_rel)
            processed_in_current_run += 1

            correctness_indicator = status_indicators.get(correctness_status, "?")
            print(f"    Result: {response} [{correctness_indicator}] {gt_info}")

            total_with_gt = counts["correct"] + counts["partial"] + counts["incorrect"]
            correct_pct = _percentage(counts["correct"], total_with_gt)
            partial_pct = _percentage(counts["partial"], total_with_gt)

            # Print running accuracy
            if total_with_gt > 0:
                print(f"    Running stats: {counts['correct']} correct, {counts['partial']} partial, {counts['incorrect']} wrong / {total_with_gt} ({correct_pct:.1f}% correct, {partial_pct:.1f}% partial)")

            # Save progress every 10 images, or after every image in debug mode
            if processed_in_current_run % save_frequency == 0:
                save_progress(
                    progress_file,
                    _build_progress_payload(processed_files, results, current_batch, model, counts),
                )
                if debug_mode:
                    print(f"    Progress saved ({processed_in_current_run} images completed, {correct_pct:.1f}% correct, {partial_pct:.1f}% partial)")

        # Pause between requests to prevent rate limiting, but not after the
        # final request of this run.
        is_last_image = index == len(remaining_images) - 1
        reached_debug_limit = debug_mode and (attempted + 1) >= max_images
        if delay > 0 and not is_last_image and not reached_debug_limit:
            time.sleep(delay)

    # Final save
    attempted_total = processed_in_current_run + failed_in_current_run
    final_total_with_gt = counts["correct"] + counts["partial"] + counts["incorrect"]
    final_correct_pct = _percentage(counts["correct"], final_total_with_gt)
    final_partial_pct = _percentage(counts["partial"], final_total_with_gt)
    final_batch = current_batch + 1 if debug_mode and attempted_total >= max_images else current_batch

    save_progress(
        progress_file,
        _build_progress_payload(processed_files, results, final_batch, model, counts),
    )

    # Write results to CSV (all stored results, including earlier runs)
    with open(csv_path, 'w', newline='') as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(['file_name', 'output', 'correctness_status', 'ground_truth_info'])

        for result in results:
            writer.writerow([
                result['file_name'],
                result['output'],
                result.get('correctness_status', ''),
                result.get('ground_truth_info', '')
            ])

    print(f"\nResults written to: {csv_path}")
    print(f"Progress saved to: {progress_file}")
    total_processed_now = already_processed + processed_in_current_run
    print(f"Total images processed so far: {total_processed_now}")
    print(f"Images processed in this run: {processed_in_current_run}")
    if failed_in_current_run:
        print(f"Images that failed in this run (will be retried): {failed_in_current_run}")
    print(f"Final stats: {counts['correct']} correct ({final_correct_pct:.1f}%), {counts['partial']} partial ({final_partial_pct:.1f}%), {counts['incorrect']} wrong")

    still_remaining = len(remaining_images) - processed_in_current_run
    if debug_mode and still_remaining > 0:
        print(f"\nDebug mode: Processed {processed_in_current_run} images in this batch.")
        print(f"Remaining images: {still_remaining}")
        print(f"To continue processing, run the script again.")
    elif still_remaining == 0:
        print(f"\nAll remaining images from this session have been processed!")
        if total_processed_now < total_images:
            print(f"Total progress: {total_processed_now}/{total_images} images completed overall.")
        else:
            print(f"All images in the dataset have been processed!")
    else:
        print(f"\nStopped processing. Run again to continue.")
        print(f"Remaining images: {still_remaining}")

def main():
    """Parse the command line and run the benchmark.

    Exits early with an error message when ``OPENAI_API_KEY`` is not set.
    With ``--reset`` the saved progress file for the selected model is
    deleted before processing starts.
    """
    parser = argparse.ArgumentParser(description='Medical image analysis with OpenAI GPT models')
    parser.add_argument('--debug', action='store_true', help='Enable debug mode (process limited number of images)')
    parser.add_argument('-n', '--max-images', type=int, default=100, help='Maximum number of images to process in debug mode (default: 100)')
    parser.add_argument('--delay', type=float, default=1.0, help='Delay between API calls in seconds (default: 1.0)')
    parser.add_argument('--retries', type=int, default=3, help='Max retries for failed API calls (default: 3)')
    parser.add_argument('--reset', action='store_true', help='Reset progress and start from beginning')
    parser.add_argument('--model', type=str, default='gpt-5',
                        help='OpenAI model to use (default: gpt-5)')

    args = parser.parse_args()

    # Check if OpenAI API key is set
    if not os.getenv("OPENAI_API_KEY"):
        print("Error: OPENAI_API_KEY environment variable not set.")
        print("Please set your OpenAI API key: export OPENAI_API_KEY='your-key-here'")
        return

    if args.reset:
        # Must mirror the directory choice made in process_images, otherwise
        # the progress file that is actually in use is never found.
        if args.model == "gpt-5":
            results_dir = "./results_openai_gpt5"
        else:
            results_dir = "./results_openai_gpt4o_51324"

        progress_name = f"progress_{args.model.replace('-', '_')}.json"
        removed_any = False
        # "./results_openai" is the location this option used to look in;
        # it is still cleaned so stale files there do not linger.
        for directory in (results_dir, "./results_openai"):
            progress_file = os.path.join(directory, progress_name)
            if os.path.exists(progress_file):
                os.remove(progress_file)
                removed_any = True

        if removed_any:
            print("Progress reset. Starting fresh.")
        else:
            print("No saved progress found. Starting fresh.")

    print("Starting medical image analysis with OpenAI...")
    print(f"Using model: {args.model}")
    if args.debug:
        print(f"Debug mode enabled - will process up to {args.max_images} images")
        print(f"Delay between calls: {args.delay}s, Max retries: {args.retries}")

    process_images(debug_mode=args.debug, max_images=args.max_images,
                  delay=args.delay, max_retries=args.retries, model=args.model)
    print("Analysis complete!")

# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()