#!/usr/bin/env python3
import argparse
import csv
import json
import os
import random
import re
import time
from collections import defaultdict
from datetime import datetime
from pathlib import Path

import ollama  # Changed from openai to ollama

"""
Alternative prompt wording (step 2): Provide only the grid coordinate for this most representative cell. If the abnormality spans multiple cells, choose the cell that is most representative. A grid coordinate is defined as a letter followed by a number.
"""

def load_ground_truth(ground_truth_dir):
    """Load per-image ground-truth grid cells from the annotation CSV files.

    For each of the ``frontal`` and ``lateral`` views, every ``*.csv`` file in
    ``<ground_truth_dir>/test_gt_csvs_32x32/<view>/`` is read; the file name
    (without ``.csv``) is used as the condition name. Each file must provide
    the columns ``Image Name``, ``Significant Overlapping Cells`` and
    ``All Overlapping Cells``; the cell columns are whitespace-separated lists.

    Rows with no significant cell are skipped. A row that is shorter than the
    header (so a cell column is missing) is treated as having an empty list
    for that column instead of aborting the rest of the file.

    Args:
        ground_truth_dir: Directory containing ``test_gt_csvs_32x32``.

    Returns:
        dict mapping image name to a dict with the keys ``view``,
        ``condition``, ``significant_cells`` and ``all_cells``. When the same
        image name occurs more than once, the last entry read wins; files are
        read in sorted order so the winner is the same on every run.
    """
    ground_truth = {}

    for view in ("frontal", "lateral"):
        view_path = os.path.join(ground_truth_dir, "test_gt_csvs_32x32", view)

        if not os.path.isdir(view_path):
            print(f"Warning: Ground truth directory not found: {view_path}")
            continue

        # os.listdir() order is arbitrary; sort so duplicate image names are
        # resolved identically on every run and platform.
        for csv_file in sorted(os.listdir(view_path)):
            if not csv_file.endswith('.csv'):
                continue

            # Drop only the trailing extension (not every ".csv" occurrence).
            condition = csv_file[:-len('.csv')]
            csv_path = os.path.join(view_path, csv_file)

            try:
                # newline='' lets the csv module handle embedded newlines itself.
                with open(csv_path, 'r', newline='') as f:
                    for row in csv.DictReader(f):
                        image_name = row['Image Name']
                        # DictReader fills missing trailing fields with None.
                        # str.split() without arguments already trims and
                        # discards empty items.
                        significant_cells = (row['Significant Overlapping Cells'] or '').split()
                        if not significant_cells:
                            continue
                        ground_truth[image_name] = {
                            'view': view,
                            'condition': condition,
                            'significant_cells': significant_cells,
                            'all_cells': (row['All Overlapping Cells'] or '').split(),
                        }
            except Exception as e:
                # Best effort: one unreadable file must not stop the whole run.
                print(f"Warning: Could not load ground truth from {csv_path}: {e}")

    print(f"Loaded ground truth for {len(ground_truth)} images")
    return ground_truth

# One grid label as described in the prompt ("a letter followed by a number").
# Two letters are tolerated in case wide grids use labels such as "AA1".
_GRID_COORD_RE = re.compile(r"\b[A-Z]{1,2}\d{1,3}\b")


def check_answer_correctness(image_file, response, ground_truth):
    """Score the model's response for one image against the ground truth.

    The response is compared case-insensitively. The whole (stripped) response
    is tried first; if that is not a known cell and the text contains exactly
    one distinct grid-coordinate token (e.g. ``"A1."`` or ``"Cell **B7**"``),
    that token is used instead. Responses naming several different
    coordinates are ambiguous and are not rescued.

    Args:
        image_file: Image file name; its extension is removed before the
            ground-truth lookup.
        response: Raw text returned by the model.
        ground_truth: Mapping of image name to a dict with the keys
            ``significant_cells`` and ``all_cells`` (lists of cell labels).

    Returns:
        Tuple ``(status, info)`` where ``status`` is ``"no_gt"``,
        ``"correct"`` (hit in the significant cells), ``"partial"`` (hit only
        in the overlapping cells) or ``"incorrect"``, and ``info`` is a
        human-readable description of the ground truth.
    """
    # Remove file extension to match ground truth format
    image_name_base = os.path.splitext(image_file)[0]

    if image_name_base not in ground_truth:
        return "no_gt", "No ground truth available"

    gt_data = ground_truth[image_name_base]
    significant_cells = gt_data['significant_cells']
    all_cells = gt_data['all_cells']

    significant_upper = {cell.upper() for cell in significant_cells}
    all_upper = {cell.upper() for cell in all_cells}

    cleaned_response = (response or "").strip().upper()

    # Only fall back to token extraction when the plain text does not match,
    # so every answer that matched before still matches the same way.
    if cleaned_response not in significant_upper and cleaned_response not in all_upper:
        candidates = set(_GRID_COORD_RE.findall(cleaned_response))
        if len(candidates) == 1:
            cleaned_response = candidates.pop()

    # Full hit: the answer is one of the significant cells
    if cleaned_response in significant_upper:
        return "correct", f"GT: {' '.join(significant_cells)}"

    # Partial hit: the answer only touches an overlapping cell
    if cleaned_response in all_upper:
        return "partial", f"Partial hit - GT significant: {' '.join(significant_cells)}, all: {' '.join(all_cells)}"

    # Completely wrong
    return "incorrect", f"GT: {' '.join(significant_cells)}"

# Removed encode_image function as Ollama handles images directly

def make_ollama_call(image_path, view, condition, model, max_retries=3, base_delay=2):
    """Query the Ollama model for the grid cell where *condition* is most prominent.

    The request is attempted up to ``max_retries`` times. After a failed
    attempt the function waits ``base_delay * 2**k`` seconds plus up to one
    second of random jitter (``k`` counts the failures so far, starting at 0)
    before trying again.

    Returns:
        The stripped text of the model's reply, ``"ERROR_AFTER_<n>_RETRIES"``
        when every attempt raised, or ``"ERROR"`` when no attempt was made
        (``max_retries`` of zero or less).
    """
    system_prompt = f"""You are an expert chest radiologist specializing in analyzing {view} chest X-rays. Your task is to precisely localize abnormalities using a grid overlay."""
    user_prompt = f"""This is a gridded {view} view of a chest X-ray. The abnormality '{condition}' is confirmed to be present in this image.

Your task:
1. Identify the single grid cell where this abnormality - '{condition}' is the MOST prominent.
2. Provide only the grid coordinate for this most representative cell. If the abnormality spans multiple cells, choose the cell that is most representative. A grid coordinate is defined as a letter followed by a number.
3. Do not include any explanations or additional text in your response."""

    for attempt_number in range(1, max_retries + 1):
        try:
            reply = ollama.chat(
                model=model,
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_prompt, "images": [image_path]},
                ],
                options={"temperature": 0},
            )
            return reply["message"]["content"].strip()
        except Exception as exc:
            failure_text = str(exc)
            print(f"    Attempt {attempt_number}/{max_retries} failed: {failure_text}")

            # Out of attempts: report and give up without sleeping.
            if attempt_number >= max_retries:
                print(f"    All {max_retries} attempts failed for {image_path}")
                return f"ERROR_AFTER_{max_retries}_RETRIES"

            # Exponential backoff with jitter before the next attempt.
            pause = base_delay * (2 ** (attempt_number - 1)) + random.uniform(0, 1)
            print(f"    Waiting {pause:.1f} seconds before retry...")
            time.sleep(pause)

    # Only reached when the loop body never ran.
    return "ERROR"

def load_progress(progress_file):
    """Load saved progress from a JSON file.

    Args:
        progress_file: Path of the progress JSON file.

    Returns:
        The stored progress dict, or a fresh
        ``{"processed_files": [], "results": [], "last_batch": 0}`` when the
        file is missing, unreadable, not valid JSON, or does not contain a
        JSON object (callers use ``.get`` on the result).
    """
    empty_progress = {"processed_files": [], "results": [], "last_batch": 0}

    if not os.path.exists(progress_file):
        return empty_progress

    try:
        with open(progress_file, 'r', encoding='utf-8') as f:
            loaded = json.load(f)
    except (ValueError, OSError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        print(f"Warning: Could not load progress file: {e}")
        return empty_progress

    if not isinstance(loaded, dict):
        print(
            "Warning: Could not load progress file: "
            f"expected a JSON object, got {type(loaded).__name__}"
        )
        return empty_progress

    return loaded

def save_progress(progress_file, progress_data):
    """Save progress to a JSON file without ever leaving it half-written.

    The data is first written to ``<progress_file>.tmp`` and then moved over
    the real file with ``os.replace``. If the run is interrupted or the dump
    fails part-way, the previous progress file is left untouched, so the
    resume state is not lost.

    Args:
        progress_file: Destination path of the progress JSON file.
        progress_data: JSON-serialisable progress structure.

    I/O errors are reported as a warning and otherwise ignored (saving is
    best effort). Serialisation errors raised by ``json.dump`` propagate.
    """
    tmp_path = f"{progress_file}.tmp"
    try:
        with open(tmp_path, 'w', encoding='utf-8') as f:
            json.dump(progress_data, f, indent=2)
        os.replace(tmp_path, progress_file)
    except OSError as e:
        print(f"Warning: Could not save progress: {e}")
    finally:
        # After a successful replace the temp file is already gone; after a
        # failure this removes the partial file.
        try:
            os.remove(tmp_path)
        except OSError:
            pass

def get_all_images(base_dir):
    """Collect every image under ``<base_dir>/<view>/<condition>/``.

    Only the ``frontal`` and ``lateral`` view directories are scanned, and
    only files ending in ``.jpg``, ``.jpeg`` or ``.png`` (case-insensitive)
    are returned. Conditions and files are visited in sorted order so the
    processing order is the same on every run.

    Args:
        base_dir: Root directory holding the view sub-directories.

    Returns:
        List of dicts, one per image, with the keys:
          - file_name: basename of the image
          - image_path: ``base_dir`` joined with view, condition and file name
            (absolute only if ``base_dir`` is absolute)
          - rel_path: normalised path relative to ``base_dir`` (canonical id)
          - view: ``"frontal"`` or ``"lateral"``
          - condition: name of the condition directory
    """
    all_images = []

    for view in ("frontal", "lateral"):
        view_dir = os.path.join(base_dir, view)

        if not os.path.isdir(view_dir):
            print(f"Directory not found: {view_dir}")
            continue

        print(f"Scanning {view} directory...")
        # os.listdir() order is arbitrary; sort for a reproducible run order.
        conditions = sorted(
            d for d in os.listdir(view_dir)
            if os.path.isdir(os.path.join(view_dir, d))
        )
        print(f"Found conditions: {conditions}")

        for condition in conditions:
            condition_dir = os.path.join(view_dir, condition)

            try:
                image_files = sorted(
                    f for f in os.listdir(condition_dir)
                    if f.lower().endswith(('.jpg', '.jpeg', '.png'))
                )
            except OSError as e:
                print(f"  Warning: Could not list directory {condition_dir}: {e}")
                image_files = []

            print(f"  {condition}: {len(image_files)} images")

            for image_file in image_files:
                image_path = os.path.join(condition_dir, image_file)
                # rel_path relative to base_dir (portable canonical ID)
                rel_path = os.path.relpath(image_path, base_dir)
                all_images.append({
                    'file_name': image_file,
                    'image_path': image_path,
                    'rel_path': os.path.normpath(rel_path),
                    'view': view,
                    'condition': condition,
                })

    print(f"Total images found: {len(all_images)}")
    return all_images

def _tally_results(results):
    """Count the correct / partial / incorrect entries in a list of result dicts."""
    counts = {"correct": 0, "partial": 0, "incorrect": 0}
    for entry in results:
        status = entry.get('correctness_status') if isinstance(entry, dict) else None
        if status in counts:
            counts[status] += 1
    return counts


def _build_progress_payload(processed_files, results, last_batch, model, counts):
    """Assemble the dict that is written to the progress JSON file."""
    total_with_gt = sum(counts.values())
    correct_pct = (counts["correct"] / total_with_gt * 100) if total_with_gt > 0 else 0
    partial_pct = (counts["partial"] / total_with_gt * 100) if total_with_gt > 0 else 0
    return {
        "processed_files": sorted(processed_files),
        "results": results,
        "last_batch": last_batch,
        "model": model,
        "accuracy_stats": {
            "correct_answers": counts["correct"],
            "partial_answers": counts["partial"],
            "incorrect_answers": counts["incorrect"],
            "total_with_ground_truth": total_with_gt,
            "correct_percentage": correct_pct,
            "partial_percentage": partial_pct
        },
        "last_updated": datetime.now().isoformat()
    }


def process_images(debug_mode=False, max_images=100, delay=1.0, max_retries=3, model="amsaravi/medgemma-4b-it:q6"):
    """Run the model over every unprocessed image and write the results.

    Images are read from ``./test_images_gridded_32x32``, ground truth from
    ``./ground_truth``; a progress JSON file and a results CSV are written to
    ``./results_ollama-t0`` (file names include the model name). Images
    already listed in the progress file are skipped, identified either by
    their path relative to the image directory or, for legacy progress
    files, by bare file name.

    Accuracy statistics are cumulative: they are seeded from the results
    stored by earlier runs, so they always describe the same set of images as
    the saved ``results`` list.

    Args:
        debug_mode: When true, stop after ``max_images`` images and save
            progress after every image instead of every 10.
        max_images: Per-run image limit; only applied in debug mode.
        delay: Seconds to sleep between consecutive requests.
        max_retries: Attempts per image, passed to ``make_ollama_call``.
        model: Ollama model name.
    """
    # Base directory
    base_dir = "./test_images_gridded_32x32"
    ground_truth_dir = "./ground_truth"
    results_dir = "./results_ollama-t0"  # Changed results directory

    # Load ground truth data
    ground_truth = load_ground_truth(ground_truth_dir)

    # Create results directory if it doesn't exist
    os.makedirs(results_dir, exist_ok=True)

    # Output files - updated for model naming
    model_safe = model.replace('/', '_').replace(':', '_')
    csv_path = os.path.join(results_dir, f"localization_results_{model_safe}.csv")
    progress_file = os.path.join(results_dir, f"progress_{model_safe}.json")

    # Get all images to process
    all_images = get_all_images(base_dir)
    total_images = len(all_images)

    print(f"Found {total_images} total images to process")
    print(f"Using model: {model}")

    # Load existing progress
    progress_data = load_progress(progress_file)
    raw_processed = progress_data.get("processed_files", [])
    results = progress_data.get("results", [])
    last_batch = progress_data.get("last_batch", 0)

    # Normalize stored keys (handles mixed path separators)
    processed_files = {os.path.normpath(p) for p in raw_processed if isinstance(p, str)}

    # Backwards compatibility: entries without any path separator are legacy
    # basename-only keys. They are also contained in processed_files, because
    # normpath leaves a bare file name unchanged.
    processed_basenames = {
        p for p in raw_processed
        if isinstance(p, str) and ('/' not in p) and ('\\' not in p) and (os.sep not in p)
    }

    if processed_files or processed_basenames:
        sample_keys = list(processed_files)[:3] + list(processed_basenames)[:3]
        print("Sample processed keys (rel_paths or legacy basenames):", sample_keys)

    # Filter out already processed images:
    # Consider an image processed if its rel_path is in processed_files (new format)
    # OR if its basename is in processed_basenames (legacy)
    remaining_images = [
        img for img in all_images
        if (os.path.normpath(img['rel_path']) not in processed_files) and (img['file_name'] not in processed_basenames)
    ]
    # Number of dataset images that earlier runs already covered.
    already_done = total_images - len(remaining_images)

    if processed_files or processed_basenames:
        # Legacy basenames are part of processed_files, so counting both sets
        # would count them twice.
        print(f"Resuming from previous run: {len(processed_files)} already processed")
        print(f"Remaining images: {len(remaining_images)}")

    # How many images this run will handle (max_images applies in debug mode only)
    run_limit = len(remaining_images)
    if debug_mode:
        run_limit = min(run_limit, max_images)

    # Process images
    processed_in_current_run = 0
    current_batch = last_batch
    # Seed from stored results so the saved statistics match the saved results.
    counts = _tally_results(results)

    status_indicators = {
        "correct": "✓",
        "partial": "◐",
        "incorrect": "✗",
        "no_gt": "?"
    }
    save_frequency = 1 if debug_mode else 10

    for image_info in remaining_images:
        if debug_mode and processed_in_current_run >= max_images:
            print(f"\nDebug mode: Reached limit of {max_images} images")
            print(f"Processed {processed_in_current_run} images in this run")
            break

        image_path = image_info['image_path']
        view = image_info['view']
        condition = image_info['condition']
        image_file = image_info['file_name']
        image_rel = os.path.normpath(image_info['rel_path'])

        # processed_files grows during the run, so it must not be used here.
        current_position = already_done + processed_in_current_run + 1
        current_run_position = processed_in_current_run + 1

        print(f"Processing [{current_run_position}/{run_limit}] (overall: {current_position}/{total_images}): {view} - {condition} - {image_file}")

        # Make Ollama call with retry mechanism
        response = make_ollama_call(image_path, view, condition, model, max_retries=max_retries)

        # Check correctness against ground truth ("no_gt" is not counted)
        correctness_status, gt_info = check_answer_correctness(image_file, response, ground_truth)
        if correctness_status in counts:
            counts[correctness_status] += 1

        # Store result
        results.append({
            'file_name': image_file,
            'output': response,
            'view': view,
            'condition': condition,
            'model': model,
            'correctness_status': correctness_status,
            'ground_truth_info': gt_info,
            'processed_at': datetime.now().isoformat()
        })

        # Mark this image as processed using rel_path (canonical new format)
        processed_files.add(image_rel)
        processed_in_current_run += 1

        # Print result with correctness info
        correctness_indicator = status_indicators.get(correctness_status, "?")
        print(f"    Result: {response} [{correctness_indicator}] {gt_info}")

        # Print running accuracy
        total_with_gt = sum(counts.values())
        if total_with_gt > 0:
            correct_pct = (counts['correct'] / total_with_gt) * 100
            partial_pct = (counts['partial'] / total_with_gt) * 100
            print(f"    Running stats: {counts['correct']} correct, {counts['partial']} partial, {counts['incorrect']} wrong / {total_with_gt} ({correct_pct:.1f}% correct, {partial_pct:.1f}% partial)")

        # Add delay between requests to prevent rate limiting (not after the
        # last image this run will handle)
        if processed_in_current_run < run_limit and delay > 0:
            time.sleep(delay)

        # Save progress every 10 images or in debug mode every image
        if processed_in_current_run % save_frequency == 0:
            progress_data = _build_progress_payload(processed_files, results, current_batch, model, counts)
            save_progress(progress_file, progress_data)

            if debug_mode:
                stats = progress_data["accuracy_stats"]
                print(f"    Progress saved ({processed_in_current_run} images completed, {stats['correct_percentage']:.1f}% correct, {stats['partial_percentage']:.1f}% partial)")

    # Final save
    final_last_batch = current_batch + 1 if debug_mode and processed_in_current_run >= max_images else current_batch
    progress_data = _build_progress_payload(processed_files, results, final_last_batch, model, counts)
    save_progress(progress_file, progress_data)

    final_stats = progress_data["accuracy_stats"]
    final_correct_pct = final_stats["correct_percentage"]
    final_partial_pct = final_stats["partial_percentage"]

    # Write results to CSV (UTF-8 so model output with non-ASCII text cannot
    # fail on a platform whose default encoding is narrower)
    with open(csv_path, 'w', newline='', encoding='utf-8') as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(['file_name', 'output', 'correctness_status', 'ground_truth_info'])

        for result in results:
            writer.writerow([
                result['file_name'],
                result['output'],
                result.get('correctness_status', ''),
                result.get('ground_truth_info', '')
            ])

    print(f"\nResults written to: {csv_path}")
    print(f"Progress saved to: {progress_file}")
    total_processed_now = len(processed_files)
    print(f"Total images processed so far: {total_processed_now}")
    print(f"Images processed in this run: {processed_in_current_run}")
    print(f"Final stats: {counts['correct']} correct ({final_correct_pct:.1f}%), {counts['partial']} partial ({final_partial_pct:.1f}%), {counts['incorrect']} wrong")

    if debug_mode and len(remaining_images) > processed_in_current_run:
        print(f"\nDebug mode: Processed {processed_in_current_run} images in this batch.")
        print(f"Remaining images: {len(remaining_images) - processed_in_current_run}")
        print(f"To continue processing, run the script again.")
    elif len(remaining_images) == processed_in_current_run:
        print(f"\nAll remaining images from this session have been processed!")
        if total_processed_now < total_images:
            print(f"Total progress: {total_processed_now}/{total_images} images completed overall.")
        else:
            print(f"All images in the dataset have been processed!")
    else:
        print(f"\nStopped processing. Run again to continue.")
        print(f"Remaining images: {len(remaining_images) - processed_in_current_run}")

def main():
    """Command-line entry point.

    Parses the options, verifies that the Ollama server answers, optionally
    deletes the stored progress file for the chosen model (``--reset``) and
    then hands over to ``process_images``. Returns early, after printing an
    error, when Ollama cannot be reached.
    """
    cli = argparse.ArgumentParser(description='Medical image analysis with Ollama MedGemma models')
    cli.add_argument('--debug', action='store_true',
                     help='Enable debug mode (process limited number of images)')
    cli.add_argument('-n', '--max-images', type=int, default=100,
                     help='Maximum number of images to process in debug mode (default: 100)')
    cli.add_argument('--delay', type=float, default=1.0,
                     help='Delay between API calls in seconds (default: 1.0)')
    cli.add_argument('--retries', type=int, default=3,
                     help='Max retries for failed API calls (default: 3)')
    cli.add_argument('--reset', action='store_true',
                     help='Reset progress and start from beginning')
    cli.add_argument('--model', type=str, default='amsaravi/medgemma-4b-it:q6',
                     help='Ollama model to use (default: amsaravi/medgemma-4b-it:q6)')
    opts = cli.parse_args()

    # Basic liveness check: listing the models fails when no server is running.
    try:
        ollama.list()
        print("✓ Ollama connection successful")
    except Exception as exc:
        print("Error: Could not connect to Ollama.")
        print("Please make sure Ollama is running: ollama serve")
        print(f"Error details: {exc}")
        return

    if opts.reset:
        # Same file name scheme that process_images uses for its progress file.
        safe_model_name = opts.model.translate(str.maketrans({'/': '_', ':': '_'}))
        stale_progress = os.path.join("./results_ollama-t0", f"progress_{safe_model_name}.json")
        if os.path.exists(stale_progress):
            os.remove(stale_progress)
            print("Progress reset. Starting fresh.")

    print("Starting medical image analysis with Ollama MedGemma...")
    print(f"Using model: {opts.model}")
    if opts.debug:
        print(f"Debug mode enabled - will process up to {opts.max_images} images")
        print(f"Delay between calls: {opts.delay}s, Max retries: {opts.retries}")

    process_images(
        debug_mode=opts.debug,
        max_images=opts.max_images,
        delay=opts.delay,
        max_retries=opts.retries,
        model=opts.model,
    )
    print("Analysis complete!")

# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()