import os
import json
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
import matplotlib.patheffects as PathEffects
import matplotlib.colors as mcolors
import cv2
import re
from collections import defaultdict

# Shared matplotlib styling applied to every figure this script draws.
# Text sizes:
plt.rcParams.update({
    'font.size': 10,
    'axes.labelsize': 10,
    'xtick.labelsize': 9,
    'ytick.labelsize': 9,
    'legend.fontsize': 8,
})
# Spacing around ticks and axis labels:
plt.rcParams.update({
    'xtick.major.pad': 8,
    'ytick.major.pad': 8,
    'axes.labelpad': 12,
})

def read_json_results(json_path):
    """
    Read a JSON results file and group its entries by condition and view.

    The file must contain a top-level object; its 'results' list (if any)
    supplies the entries. Each entry is filed under its 'condition' and
    'view' fields; a missing 'view' is treated as 'frontal'.

    Parameters:
    json_path: path to the JSON results file

    Returns:
    dict: {condition: {view: [{'file_name': str, 'output': str, 'ground_truth_info': str}, ...]}}
          Both levels are plain dicts. 'output' and 'ground_truth_info' are
          '' when the key is missing or holds an explicit JSON null.
    """
    # Be explicit about the encoding instead of relying on the platform default.
    with open(json_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    results = defaultdict(lambda: defaultdict(list))

    # `or []` also covers a 'results' key that is present but null.
    for result in data.get('results') or []:
        condition = result.get('condition')
        view = result.get('view', 'frontal')

        # dict.get's default only applies to a missing key; a null value would
        # come back as None and break callers that call .split() on it.
        output = result.get('output')
        gt_info = result.get('ground_truth_info')

        results[condition][view].append({
            'file_name': result.get('file_name'),
            'output': '' if output is None else output,
            'ground_truth_info': '' if gt_info is None else gt_info,
        })

    # Hand back plain dicts so a lookup of an unknown view raises KeyError
    # instead of silently creating an empty bucket.
    return {condition: dict(views) for condition, views in results.items()}

def extract_ground_truth_from_json(json_path, use_significant=True):
    """
    Extract ground truth cell strings from a JSON results file.

    Each entry's 'ground_truth_info' is expected to look like
    "GT significant: E6 F6 E5 F5, all: E7 F7 G7 D6 E6 F6".

    Parameters:
    json_path: path to JSON file
    use_significant: if True, use the "GT significant" list, else the "all" list

    Returns:
    dict: {condition: {view: [ground_truth_cells_string, ...]}}
          One string per entry, in file order. An entry whose info is
          missing, null, not a string, or lacks the requested section
          contributes '' so that list lengths still match the entry count.
    """
    # Be explicit about the encoding instead of relying on the platform default.
    with open(json_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    results = defaultdict(lambda: defaultdict(list))

    # `or []` also covers a 'results' key that is present but null.
    for result in data.get('results') or []:
        condition = result.get('condition')
        view = result.get('view', 'frontal')
        gt_info = result.get('ground_truth_info', '')

        # A JSON null (or any non-string) would make the `in` tests below raise.
        if not isinstance(gt_info, str):
            gt_info = ''

        if use_significant and 'GT significant:' in gt_info:
            # The significant list ends at the comma that precedes "all:".
            gt_cells = gt_info.split('GT significant:')[1].split(',')[0].strip()
        elif not use_significant and 'all:' in gt_info:
            # The "all" list runs to the end of the string.
            gt_cells = gt_info.split('all:')[1].strip()
        else:
            gt_cells = ""

        results[condition][view].append(gt_cells)

    # Plain dicts: an unknown view lookup raises instead of creating a bucket.
    return {condition: dict(views) for condition, views in results.items()}

def is_valid_cell(cell):
    """Return True when *cell* is exactly one column letter A-H followed by one row digit 1-8."""
    # fullmatch instead of match with '$': '$' also matches just before a
    # trailing newline, so a token such as "A1" + newline used to be accepted.
    return re.fullmatch(r'[A-H][1-8]', cell) is not None

def calculate_frequencies(grid_data):
    """
    Return, for each valid cell, the percentage of entries that mention it.

    Every entry is a whitespace-separated string of cell names. A cell is
    counted at most once per entry, and tokens rejected by is_valid_cell
    are ignored. The result is empty when grid_data is empty.
    """
    total = len(grid_data)
    if not total:
        return {}

    hits = defaultdict(int)
    for entry in grid_data:
        # Set semantics: repeats inside one entry count once.
        for cell in {token for token in entry.split() if is_valid_cell(token)}:
            hits[cell] += 1

    return {cell: (count / total) * 100 for cell, count in hits.items()}

def create_frequency_grid(frequencies):
    """
    Place per-cell frequencies into an 8x8 array.

    The cell's letter selects the column ('A' is column 0) and its digit
    selects the row ('1' is row 0). Cells not mentioned stay at 0.
    """
    first_column = ord('A')
    grid = np.zeros((8, 8))
    for name, value in frequencies.items():
        column_index = ord(name[0]) - first_column
        row_index = int(name[1]) - 1
        grid[row_index, column_index] = value
    return grid

def read_and_preprocess_xray(image_path):
    """
    Load the X-ray from disk as RGB with a mild intensity scale-up.

    Raises ValueError when OpenCV cannot read the file.
    """
    bgr = cv2.imread(image_path)
    if bgr is None:
        raise ValueError(f"Could not read image at {image_path}")
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    # Scale pixel values by 1.1 with no offset.
    return cv2.convertScaleAbs(rgb, alpha=1.1, beta=0)

def create_grid_overlay(image, grid_size=(8, 8)):
    """
    Resize the image so its height and width are exact multiples of the grid.

    grid_size is (rows, columns). Nothing is drawn here; the function only
    resizes. Returns (resized_image, cell_height, cell_width).
    """
    n_rows = grid_size[0]
    n_cols = grid_size[1]
    cell_height = image.shape[0] // n_rows
    cell_width = image.shape[1] // n_cols
    # cv2.resize takes the target size as (width, height).
    resized = cv2.resize(image, (cell_width * n_cols, cell_height * n_rows))
    return resized, cell_height, cell_width

def get_prediction_colors():
    """
    Return the RGBA colour stops for each built-in prediction type.

    Every type maps to three stops: fully transparent, then its base colour
    at alpha 0.15, then the same colour at alpha 0.6.
    """
    transparent = (0, 0, 0, 0)
    gt_rgb = (0.75, 0.33, 0.02)
    zeroshot_rgb = (0.20, 0.63, 0.17)
    return {
        'gt': [transparent, (*gt_rgb, 0.15), (*gt_rgb, 0.6)],
        'zeroshot': [transparent, (*zeroshot_rgb, 0.15), (*zeroshot_rgb, 0.6)],
    }

def create_colormap(pred_type='gt', style='blocky'):
    """
    Create the transparent-to-opaque colormap for a prediction type.

    Types known to get_prediction_colors() use their fixed palette. Any
    other type gets a colour from matplotlib's tab20 palette, chosen from a
    stable hash of its name so the same name yields the same colour on
    every run.

    Parameters:
    pred_type: prediction type key (e.g. 'gt', 'zeroshot', or a model name)
    style: accepted for backward compatibility; not used by this function

    Returns:
    A LinearSegmentedColormap named "<pred_type>_colormap" with 256 levels.
    """
    import zlib  # local import: only the fallback colour below needs it

    colors = get_prediction_colors()
    if pred_type not in colors:
        # The built-in hash() of a str is salted per process and is a very
        # large integer: once multiplied by a float its fractional part is
        # rounded away, so `% 1` was almost always 0.0 and every unknown type
        # received the same colour. CRC32 is reproducible and small enough
        # for the fraction to survive.
        golden_ratio = (1 + 5 ** 0.5) / 2
        digest = zlib.crc32(str(pred_type).encode('utf-8'))
        hue = (digest * golden_ratio) % 1
        rgb = plt.cm.tab20(hue)[:3]
        colors[pred_type] = [(0, 0, 0, 0), (*rgb, 0.15), (*rgb, 0.6)]
    n_bins = 256
    return mcolors.LinearSegmentedColormap.from_list(f"{pred_type}_colormap", colors[pred_type], N=n_bins)

def overlay_heatmap(image_path, frequency_grid, output_path, title, pred_type='gt'):
    """
    Overlay a blocky heatmap on the X-ray image and save it to disk.

    Each cell with a positive frequency is tinted with the prediction
    type's colour; opacity scales with the cell's frequency relative to the
    largest frequency in the grid.

    Parameters:
    image_path: path to X-ray image
    frequency_grid: 8x8 numpy array of frequencies, row 0 being grid row "1"
    output_path: path to save the output image
    title: title of the plot; titles longer than 30 characters are broken
           after the first ' - ' separator
    pred_type: prediction type for colormap

    Raises:
    ValueError: propagated from read_and_preprocess_xray when the image
                cannot be read.
    """
    image = read_and_preprocess_xray(image_path)
    image, cell_height, cell_width = create_grid_overlay(image)
    height = image.shape[0]

    # Flip the frequency grid so that data for cell "A1" (originally row 0) appears at the bottom.
    grid_flipped = np.flipud(frequency_grid)
    max_freq = np.max(grid_flipped)
    fill_color = create_colormap(pred_type, style='blocky')(1.0)[:3]

    # Everything that can fail without a figure has run; from here on the
    # figure is always closed, even when drawing or saving raises. Callers
    # catch exceptions and keep looping, so a leaked figure would pile up.
    fig, ax = plt.subplots(figsize=(6, 6.5))
    try:
        ax.imshow(image, alpha=0.9)

        # Draw blocky overlay rectangles based on the flipped grid.
        for i in range(8):
            for j in range(8):
                # A positive cell implies max_freq > 0, so the division is safe.
                if grid_flipped[i, j] > 0:
                    alpha = min(grid_flipped[i, j] / max_freq, 1.0)
                    ax.add_patch(Rectangle(
                        (j * cell_width, i * cell_height),
                        cell_width, cell_height,
                        facecolor=fill_color,
                        alpha=alpha * 0.6,
                        edgecolor='none'
                    ))

        # Draw grid lines.
        for i in range(9):
            ax.axhline(y=i * cell_height, color='white', linewidth=0.8, alpha=0.5)
            ax.axvline(x=i * cell_width, color='white', linewidth=0.8, alpha=0.5)

        fontsize_labels = 9
        # Column letters below the image.
        for i in range(8):
            ax.text(i * cell_width + cell_width/2, height + 5, chr(65 + i),
                    ha='center', va='top', color='black', fontsize=fontsize_labels)
        # Row numbers to the left, so that top row is "8" and bottom is "1".
        for i in range(8):
            ax.text(-5, i * cell_height + cell_height/2, str(8 - i),
                    ha='right', va='center', color='black', fontsize=fontsize_labels)

        # Title formatting: split long titles into multiple lines if necessary.
        if len(title) > 30:
            parts = title.split(' - ')
            if len(parts) > 1:
                title = parts[0] + '\n' + ' - '.join(parts[1:])
        ax.set_title(title, pad=3, fontsize=11)

        # Hides ticks, labels and spines in one go.
        ax.axis('off')

        fig.tight_layout(pad=0.8)
        fig.patch.set_edgecolor('lightgray')
        fig.patch.set_linewidth(0.5)
        fig.savefig(output_path, dpi=600, bbox_inches='tight', pad_inches=0.1)
    finally:
        plt.close(fig)

def process_prediction_overlays(model_data, xray_path, view='frontal'):
    """
    Render one zero-shot prediction heatmap per (model, condition) pair.

    model_data maps a model name to the path of its JSON results file.
    Models whose file is missing are skipped with a warning. Any error
    raised while handling a model is printed and ends that model's
    processing; the remaining models still run.
    """
    for model_name in model_data:
        json_path = model_data[model_name]
        print(f"Processing predictions for model: {model_name}")

        if not os.path.exists(json_path):
            print(f"Warning: JSON file {json_path} does not exist, skipping model {model_name}")
            continue

        try:
            grouped = read_json_results(json_path)

            for condition, by_view in grouped.items():
                if view not in by_view:
                    print(f"Warning: View {view} not found for {model_name} - {condition}")
                    continue

                # The model's raw answers live in the 'output' field.
                predictions = [record['output'] for record in by_view[view]]
                if len(predictions) == 0:
                    print(f"Warning: No predictions found for {model_name} - {condition}")
                    continue

                frequency_grid = create_frequency_grid(calculate_frequencies(predictions))

                output_dir = os.path.join('overlaid_heatmaps_multi_model', 'predictions')
                os.makedirs(output_dir, exist_ok=True)
                output_path = os.path.join(output_dir, f"{condition}_{model_name}_zeroshot.png")

                overlay_heatmap(
                    xray_path,
                    frequency_grid,
                    output_path,
                    f"{condition} - {model_name} - Zero-Shot",
                    pred_type='zeroshot',
                )
                print(f"Generated zero-shot overlay for {model_name} - {condition} ({len(predictions)} samples)")

        except Exception as e:
            print(f"Error processing {model_name}: {e}")

def process_ground_truth_overlays(model_data, xray_path, view='frontal'):
    """
    Process ground truth from JSON files and create overlaid heatmaps.

    Ground truth is read from the first model's JSON file (it should be the
    same across all models). Two heatmaps are produced per condition: one
    from the "significant" cells and one from "all" cells.

    Parameters:
    model_data: mapping of model name to JSON results path
    xray_path: path to the X-ray image used as background
    view: which view's entries to use

    Returns None. Prints a warning and does nothing when model_data is empty
    or the first JSON file does not exist. Errors raised while processing
    are printed, not propagated.
    """
    # An empty mapping used to raise IndexError before any handling ran.
    if not model_data:
        print("Warning: No models provided, cannot extract ground truth")
        return

    # Use the first model's JSON to extract ground truth (should be same across all models)
    first_model_json = next(iter(model_data.values()))

    if not os.path.exists(first_model_json):
        print(f"Warning: JSON file {first_model_json} does not exist")
        return

    output_dir = os.path.join('overlaid_heatmaps_multi_model', 'ground_truth')

    try:
        # Process both "all" and "significant" ground truth
        for use_significant in [True, False]:
            gt_results = extract_ground_truth_from_json(first_model_json, use_significant)
            gt_type = "significant" if use_significant else "all"

            for condition in gt_results:
                if view not in gt_results[condition]:
                    continue

                gt_data = gt_results[condition][view]
                if not gt_data:
                    continue

                frequencies = calculate_frequencies(gt_data)
                frequency_grid = create_frequency_grid(frequencies)

                # Created lazily so no directory appears when nothing is written.
                os.makedirs(output_dir, exist_ok=True)

                title = f"{condition} - Ground Truth - {gt_type.title()}"
                output_path = os.path.join(output_dir, f"{condition}_gt_{gt_type}.png")
                overlay_heatmap(xray_path, frequency_grid, output_path, title, pred_type='gt')
                print(f"Generated ground truth overlay for {condition} - {gt_type} ({len(gt_data)} samples)")

    except Exception as e:
        print(f"Error processing ground truth: {e}")

def _draw_comparison_panel(ax, image, grid, pred_type, title, cell_height, cell_width,
                           show_column_labels, show_row_labels, fontsize=8):
    """Draw one X-ray panel: frequency overlay, grid lines, optional cell labels and title."""
    height = image.shape[0]
    ax.imshow(image, alpha=0.9)

    fill_color = create_colormap(pred_type, style='blocky')(1.0)[:3]
    # Row 0 of the grid is grid row "1"; flip so it is drawn at the bottom.
    freq_grid = np.flipud(grid)
    peak = np.max(freq_grid)
    max_freq = peak if peak > 0 else 1

    for i in range(8):
        for j in range(8):
            if freq_grid[i, j] > 0:
                alpha = min(freq_grid[i, j] / max_freq, 1.0)
                ax.add_patch(Rectangle((j * cell_width, i * cell_height),
                                       cell_width, cell_height,
                                       facecolor=fill_color,
                                       alpha=alpha * 0.6,
                                       edgecolor='none'))

    for i in range(9):
        ax.axhline(y=i * cell_height, color='white', linewidth=0.5, alpha=0.4)
        ax.axvline(x=i * cell_width, color='white', linewidth=0.5, alpha=0.4)

    if show_column_labels:
        for i in range(8):
            ax.text(i * cell_width + cell_width/2, height + 5, chr(65 + i),
                    ha='center', va='top', color='black', fontsize=fontsize)
    if show_row_labels:
        # Top row is "8", bottom row is "1".
        for i in range(8):
            ax.text(-5, i * cell_height + cell_height/2, str(8 - i),
                    ha='right', va='center', color='black', fontsize=fontsize)

    ax.set_title(title, pad=3, fontsize=9)
    # Hides ticks, labels and spines in one go.
    ax.axis('off')


def create_comparison_plot(condition, gt_grid, model_grids, xray_path, output_dir):
    """
    Create a comparison plot with ground truth and multiple model predictions.

    The ground truth panel comes first, followed by one panel per model in
    the order of model_grids. The subplot arrangement depends on the panel
    count: a single row for up to 2 panels, 2x2 for up to 4, otherwise rows
    of 3. The figure is saved as "<condition>_comparison.png" in output_dir.

    Parameters:
    condition: condition name, used as figure title and in the file name
    gt_grid: 8x8 ground truth frequency array
    model_grids: mapping of model name to 8x8 prediction frequency array
    xray_path: path to the X-ray image used as background
    output_dir: existing directory to write the figure into

    Raises:
    ValueError: propagated from read_and_preprocess_xray when the image
                cannot be read.
    """
    total_plots = len(model_grids) + 1  # +1 for ground truth

    # Calculate subplot arrangement
    if total_plots <= 2:
        rows, cols = 1, total_plots
        fig_width = 4.5 * total_plots
        fig_height = 4
    elif total_plots <= 4:
        rows, cols = 2, 2
        fig_width = 9
        fig_height = 8
    else:
        # For more models, arrange in rows of 3
        cols = 3
        rows = (total_plots + cols - 1) // cols
        fig_width = 4.5 * cols
        fig_height = 4 * rows

    # Load the image before creating the figure so an unreadable file does
    # not leave an open figure behind.
    image = read_and_preprocess_xray(xray_path)
    image, cell_height, cell_width = create_grid_overlay(image)

    fig, axes = plt.subplots(rows, cols, figsize=(fig_width, fig_height))
    try:
        # plt.subplots returns a bare Axes for a 1x1 layout, an array otherwise.
        axes = [axes] if total_plots == 1 else axes.flatten()

        fig.subplots_adjust(wspace=0.02, hspace=0.1)
        fig.suptitle(f"{condition}", fontsize=12, y=0.95)

        panels = [('Ground Truth', gt_grid, 'gt')]
        panels.extend((f'{model_name}', pred_grid, 'zeroshot')
                      for model_name, pred_grid in model_grids.items())

        for idx, (title, grid, pred_type) in enumerate(panels):
            _draw_comparison_panel(
                axes[idx], image, grid, pred_type, title, cell_height, cell_width,
                # Column letters go under the lowest visible panel of each
                # column. When the last row is only partly filled, this also
                # labels a panel that sits above a hidden slot.
                show_column_labels=idx + cols >= total_plots,
                # Row numbers go on the leftmost column only.
                show_row_labels=idx % cols == 0,
            )

        # Hide unused subplots
        for unused in axes[total_plots:]:
            unused.set_visible(False)

        fig.tight_layout(pad=0.5, rect=[0, 0, 1, 0.93])
        fig.patch.set_edgecolor('lightgray')
        fig.patch.set_linewidth(0.5)
        output_path = os.path.join(output_dir, f"{condition}_comparison.png")
        fig.savefig(output_path, dpi=600, bbox_inches='tight', pad_inches=0.1)
    finally:
        # Always release the figure, even when drawing or saving fails.
        plt.close(fig)

def process_comparisons(model_data, xray_path, view='frontal'):
    """
    Generate comparison plots for each condition with ground truth and multiple models.

    Ground truth ("significant" cells) is read from the first model's JSON
    file. Each model's results file is loaded once and reused for every
    condition. Conditions are processed in the order they are first seen
    across the models, so output order is the same on every run.

    Parameters:
    model_data: mapping of model name to JSON results path
    xray_path: path to the X-ray image used as background
    view: which view's entries to use

    Returns None. Prints a warning and does nothing (the output directory
    is not created) when model_data is empty or the first JSON file does
    not exist. Errors raised while processing are printed, not propagated.
    """
    # An empty mapping used to raise IndexError before any handling ran.
    if not model_data:
        print("Warning: No models provided, cannot generate comparisons")
        return

    # Use first model's JSON to extract ground truth
    first_model_json = next(iter(model_data.values()))

    if not os.path.exists(first_model_json):
        print(f"Warning: JSON file {first_model_json} does not exist")
        return

    output_dir = os.path.join('overlaid_heatmaps_multi_model', 'comparisons')
    os.makedirs(output_dir, exist_ok=True)

    try:
        # Get ground truth (using significant)
        gt_results = extract_ground_truth_from_json(first_model_json, use_significant=True)

        # Load every available model once instead of once per condition.
        model_results = {}
        for model_name, json_path in model_data.items():
            if os.path.exists(json_path):
                model_results[model_name] = read_json_results(json_path)

        # dict.fromkeys de-duplicates while keeping first-seen order.
        all_conditions = dict.fromkeys(
            condition
            for results in model_results.values()
            for condition in results
        )

        for condition in all_conditions:
            if condition not in gt_results or view not in gt_results[condition]:
                print(f"Warning: Ground truth not found for {condition} - {view}")
                continue

            gt_data = gt_results[condition][view]
            if not gt_data:
                continue

            gt_frequencies = calculate_frequencies(gt_data)
            gt_grid = create_frequency_grid(gt_frequencies)

            model_grids = {}

            # Collect predictions from all models
            for model_name, results in model_results.items():
                try:
                    if condition in results and view in results[condition]:
                        predictions = [item['output'] for item in results[condition][view]]
                        if predictions:
                            pred_frequencies = calculate_frequencies(predictions)
                            model_grids[model_name] = create_frequency_grid(pred_frequencies)
                        else:
                            print(f"Warning: No predictions found for {model_name} - {condition}")
                    else:
                        print(f"Warning: Condition {condition} not found for {model_name}")

                except Exception as e:
                    # One model's bad data must not block the others.
                    print(f"Error processing {model_name} for {condition}: {e}")

            # Generate comparison plot if we have at least one model
            if model_grids:
                create_comparison_plot(condition, gt_grid, model_grids, xray_path, output_dir)
                print(f"Generated comparison plot for {condition} with {len(model_grids)} models")
            else:
                print(f"Warning: No model predictions found for {condition}")

    except Exception as e:
        print(f"Error processing comparisons: {e}")

if __name__ == "__main__":
    import argparse

    # (model name, results JSON) pairs used when --models is not given.
    default_model_pairs = [
        ("gpt-4o-2024-05-13", "./progress_gpt_4o_2024_05_13_updated.json"),
        ("gpt-5", "./progress_gpt_5_updated.json"),
        ("medgemma-27b", "./progress_puyangwang_medgemma-27b-it_q8_updated.json"),
    ]

    parser = argparse.ArgumentParser(
        description='Generate professional heatmap overlays for chest X-ray analysis with multiple models from JSON results'
    )
    parser.add_argument(
        '--xray', type=str, default="./xray.png",
        help='Path to the X-ray image',
    )
    parser.add_argument(
        '--models', type=str, nargs='*', default=None,
        help='List of model JSON paths in format "model_name:path" (e.g., "GPT4o:./results_gpt4o.json" "Claude:./results_claude.json"). If not provided, uses default model paths.',
    )
    parser.add_argument(
        '--mode', type=str, choices=['pred', 'gt', 'both'], default='both',
        help='Which overlays to generate (default: both)',
    )
    parser.add_argument(
        '--comparison', action='store_true',
        help='Generate comparison plots',
    )
    parser.add_argument(
        '--view', type=str, default='frontal',
        help='Which view to process (default: frontal)',
    )

    args = parser.parse_args()

    # Build the {model name: JSON path} mapping.
    if args.models:
        model_data = {}
        for spec in args.models:
            # Split on the first ':' only, so the path part may contain colons.
            name, separator, path = spec.partition(':')
            if not separator:
                raise ValueError(f"Model argument must be in format 'model_name:path', got: {spec}")
            model_data[name] = path
    else:
        model_data = dict(default_model_pairs)

    print(f"Processing {len(model_data)} models: {list(model_data.keys())}")

    # --comparison takes precedence; --mode only applies without it.
    if args.comparison:
        process_comparisons(model_data, args.xray, args.view)
    else:
        if args.mode != 'gt':
            process_prediction_overlays(model_data, args.xray, args.view)
        if args.mode != 'pred':
            process_ground_truth_overlays(model_data, args.xray, args.view)