import json
import pandas as pd
import os
from pathlib import Path

def load_ground_truth_data(gt_directory):
    """
    Read every ground truth CSV below ``gt_directory`` into memory.

    Each immediate subdirectory is treated as a view, and each ``*.csv`` file
    inside it as a condition keyed by the filename without its extension.
    Files that fail to parse are reported on stdout and skipped.

    Returns a nested dictionary: {view: {condition: dataframe}}
    Raises FileNotFoundError when ``gt_directory`` does not exist.
    """
    root = Path(gt_directory)
    if not root.exists():
        raise FileNotFoundError(f"Ground truth directory not found: {gt_directory}")

    gt_data = {}
    for entry in root.iterdir():
        # Only directories are views; stray top-level files are ignored.
        if not entry.is_dir():
            continue

        # Register the view up front so a view without CSVs still appears.
        per_condition = gt_data[entry.name] = {}

        for csv_path in entry.glob("*.csv"):
            try:
                frame = pd.read_csv(csv_path)
                per_condition[csv_path.stem] = frame
                print(f"Loaded GT data for {entry.name}/{csv_path.stem}: {len(frame)} entries")
            except Exception as e:
                # Best effort: one unreadable CSV must not abort the whole load.
                print(f"Error loading {csv_path}: {e}")

    return gt_data

def extract_image_name_without_extension(file_name):
    """
    Extract image name from filename like 'patient64761_study1_view1_frontal.jpg'
    Returns 'patient64761_study1_view1_frontal' (without extension)

    A trailing '.jpg', '.jpeg' or '.png' is removed regardless of case, so
    'scan.JPG' and 'scan.png' also yield 'scan'. Any other name is returned
    unchanged, which keeps dots that are part of the image name itself.
    """
    image_extensions = ('.jpg', '.jpeg', '.png')
    for extension in image_extensions:
        # Compare only the tail, lower-cased, so the original casing of the
        # stem is preserved in the returned value.
        if file_name[-len(extension):].lower() == extension:
            return file_name[:-len(extension)]
    return file_name

def _read_cell_field(image_row, column):
    """
    Return (display_text, tokens) for ``column`` of the first row in ``image_row``.

    A missing column or a NaN value yields ("None", []). Values are coerced to
    str before splitting because pandas parses all-numeric columns as numbers.
    """
    if column not in image_row.columns:
        return "None", []

    value = image_row[column].iloc[0]
    if pd.isna(value):
        return "None", []

    text = str(value)
    return text, text.split()


def get_ground_truth_for_image(gt_data, view, condition, image_name_no_ext):
    """
    Get ground truth information for a specific image from the CSV data.
    CSV structure: Image Name, All Overlapping Cells, Significant Overlapping Cells

    'Image Name' and 'Significant Overlapping Cells' are required; a missing
    'All Overlapping Cells' column is treated as "no cells" rather than an
    error. Cell values are split on whitespace. If several rows share the same
    image name, the first one is used.

    Args:
        gt_data: nested dictionary {view: {condition: dataframe}}.
        view: key into ``gt_data``.
        condition: key into ``gt_data[view]``.
        image_name_no_ext: value to match against the 'Image Name' column.

    Returns: (gt_info_string, significant_cells_list, all_cells_list)
        When the view/condition, a required column, or the image cannot be
        found, the string describes the problem and both lists are empty.
    """
    if view not in gt_data or condition not in gt_data[view]:
        return f"GT data not found for {view}/{condition}", [], []

    df = gt_data[view][condition]

    # Check if required columns exist
    required_columns = ['Image Name', 'Significant Overlapping Cells']
    missing_columns = [col for col in required_columns if col not in df.columns]
    if missing_columns:
        return f"Missing columns in {view}/{condition} CSV: {missing_columns}", [], []

    # Find the row for this image
    image_row = df[df['Image Name'] == image_name_no_ext]
    if image_row.empty:
        return f"Image {image_name_no_ext} not found in {view}/{condition} GT data", [], []

    significant_cells, significant_cells_list = _read_cell_field(
        image_row, 'Significant Overlapping Cells'
    )
    # Optional column: absent means there are no "all cells" to match against.
    all_cells, all_cells_list = _read_cell_field(image_row, 'All Overlapping Cells')

    gt_info = f"GT significant: {significant_cells}, all: {all_cells}"

    return gt_info, significant_cells_list, all_cells_list

def determine_correctness_status(model_output, significant_cells_list, all_cells_list):
    """
    Classify a model answer against the ground truth cells.

    The answer is split on whitespace and each token is compared exactly:
    - 'correct' when at least one token is a significant cell
    - 'partial' when no token is significant but at least one is in all cells
    - 'incorrect' otherwise, including empty or whitespace-only answers
    """
    # Nothing to compare when the answer is missing or blank.
    if not (model_output and model_output.strip()):
        return 'incorrect'

    predicted_cells = model_output.split()

    # Significant matches take priority over matches in the broader set.
    if any(token in significant_cells_list for token in predicted_cells):
        return 'correct'

    if any(token in all_cells_list for token in predicted_cells):
        return 'partial'

    return 'incorrect'

def update_ground_truth_info(results_file, gt_directory, output_file=None):
    """
    Update ground truth information in the results JSON file.
    If an image name contains '(1)', that entry will be removed.

    Every remaining entry in ``data['results']`` gets its 'ground_truth_info'
    and 'correctness_status' fields overwritten from the ground truth CSVs.

    Args:
        results_file: path of the results JSON to read (str or path-like).
        gt_directory: directory passed to ``load_ground_truth_data``.
        output_file: where to write the updated JSON. When None, "_updated" is
            inserted before the extension of ``results_file``
            (e.g. "./run.json" -> "./run_updated.json"), so the input file is
            never overwritten by default.

    Returns:
        The path the updated JSON was written to.
    """
    # Load the results file
    with open(results_file, 'r', encoding='utf-8') as f:
        data = json.load(f)

    # Load ground truth data
    gt_data = load_ground_truth_data(gt_directory)

    updated_results = []
    removed_count = 0

    for result in data['results']:
        # Skip images with "(1)" in the name
        if "(1)" in result['file_name']:
            print(f"Removed {result['file_name']} (duplicate with (1))")
            removed_count += 1
            continue

        # Extract image name without extension
        image_name_no_ext = extract_image_name_without_extension(result['file_name'])

        # Get ground truth for this image
        new_gt_info, significant_cells_list, all_cells_list = get_ground_truth_for_image(
            gt_data,
            result['view'],
            result['condition'],
            image_name_no_ext,
        )

        # Determine correctness status
        model_output = result.get('output', '')
        new_correctness_status = determine_correctness_status(
            model_output, significant_cells_list, all_cells_list
        )

        # Update the ground truth info and correctness status
        result['ground_truth_info'] = new_gt_info
        result['correctness_status'] = new_correctness_status

        print(f"Updated {result['file_name']}:")
        print(f"  Model output: {model_output}")
        print(f"  New GT: {new_gt_info}")
        print(f"  New correctness: {new_correctness_status}")
        print()

        updated_results.append(result)

    updated_count = len(updated_results)

    # Replace results with filtered + updated list
    data['results'] = updated_results

    # Save the updated results
    if output_file is None:
        # splitext (rather than str.replace on '.json') guarantees the derived
        # name differs from the input even when the input has another or no
        # extension, and only touches the final suffix.
        root, extension = os.path.splitext(os.fspath(results_file))
        output_file = f"{root}_updated{extension}"

    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(data, f, indent=2)

    print(f"Updated {updated_count} entries, removed {removed_count} entries with (1), and saved to: {output_file}")
    return output_file

def process_multiple_results(results_files, gt_directory):
    """
    Run ``update_ground_truth_info`` on each path in ``results_files``.

    A banner is printed before every file. A failure on one file is reported
    on stdout and does not stop the remaining files from being processed.
    """
    rule = "=" * 80

    for path in results_files:
        print(f"\n{rule}")
        print(f"Processing results file: {path}")
        print(rule)

        try:
            written_to = update_ground_truth_info(path, gt_directory)
            print(f"✅ Successfully updated: {written_to}")
        except Exception as exc:
            # Batch boundary: report and move on to the next file.
            print(f"❌ Error processing {path}: {exc}")


if __name__ == "__main__":
    # Result files to process in this run. Commented-out entries are kept as
    # a record of other inputs that can be switched back on.
    results_files = [
        # "./progress_gpt_4o_2024_05_13.json",
        # "./progress_fewshot_gpt_4o_2024_05_13.json",
        # "./progress_gpt_5.json",
        # "./progress_fewshot_gpt_5.json",
        # "./progress_puyangwang_medgemma-27b-it_q8_updated.json",
        "./progress_gpt_4o_2024_05_13.json",
        "./progress_gpt_5.json",
        "./progress_puyangwang_medgemma-27b-it_q8.json",
    ]

    # Directory holding one subdirectory of ground truth CSVs per view.
    gt_directory = "../ground_truth/test_gt_csvs_16x16"

    process_multiple_results(results_files, gt_directory)
        
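    # Optional: inspect the CSV structure if needed.
    # NOTE: inspect_csv_structure is defined further down in this file, so
    # uncommenting the call here raises NameError unless that definition is
    # moved above this __main__ block.
    # print("\n=== CSV Structure ===")
    # inspect_csv_structure(gt_directory)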

# Additional utility function to inspect CSV structure
def inspect_csv_structure(gt_directory):
    """
    Print a summary of every ground truth CSV under ``gt_directory``.

    For each subdirectory (view) and each ``*.csv`` inside it (condition),
    the column names, the shape and the values of the first row are printed.
    A CSV that cannot be read is reported and skipped.
    """
    for view_dir in Path(gt_directory).iterdir():
        # Only subdirectories represent views.
        if not view_dir.is_dir():
            continue

        print(f"\nView: {view_dir.name}")
        for csv_path in view_dir.glob("*.csv"):
            print(f"  Condition: {csv_path.stem}")
            try:
                frame = pd.read_csv(csv_path)
                column_names = list(frame.columns)
                print(f"    Columns: {column_names}")
                print(f"    Shape: {frame.shape}")
                if len(frame) > 0:
                    print("    Sample row:")
                    first_row = frame.iloc[0]
                    for name in column_names:
                        print(f"      {name}: {first_row[name]}")
            except Exception as e:
                print(f"    Error: {e}")

def verify_image_matches(results_file, gt_directory):
    """
    Report whether the first five result entries have ground truth rows.

    For each entry the matching view/condition dataframe is searched for the
    image name (file name without extension) and the outcome is printed,
    followed by a blank line.
    """
    with open(results_file, 'r') as handle:
        payload = json.load(handle)

    gt_data = load_ground_truth_data(gt_directory)

    print("Checking image matches:")
    # Only a small sample is checked; this is a quick sanity probe.
    for entry in payload['results'][:5]:
        image_name_no_ext = extract_image_name_without_extension(entry['file_name'])
        view = entry['view']
        condition = entry['condition']

        has_gt = view in gt_data and condition in gt_data[view]
        if not has_gt:
            print(f"✗ No GT data for {view}/{condition}")
        else:
            frame = gt_data[view][condition]
            hits = frame[frame['Image Name'] == image_name_no_ext]
            if hits.empty:
                print(f"✗ No match: {image_name_no_ext} in {view}/{condition}")
                # Show a few known names to help spot naming mismatches.
                print(f"  Available images: {frame['Image Name'].head().tolist()}")
            else:
                print(f"✓ Found match: {image_name_no_ext} in {view}/{condition}")
                print(f"  Significant: {hits['Significant Overlapping Cells'].iloc[0]}")
        print()

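# Uncomment the line below to inspect your CSV structure. It sits at the end of
# the module and outside the __main__ guard, so it runs after the __main__ block
# above and also whenever this file is imported. Note that its path differs from
# the gt_directory used in the __main__ block.
# inspect_csv_structure("ground_truth/test_gt_csvs_32x32")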