import os
import random
import sys
import pandas as pd
import numpy as np
import cv2
import matplotlib.pyplot as plt
from pathlib import Path
from datetime import datetime
from scipy.stats import pearsonr
from scipy.spatial.distance import jensenshannon
from sklearn.metrics import mean_squared_error

# Repository root: the parent of the directory that contains this file.
project_root = Path(__file__).parent.parent
# Put the root on sys.path so the `main.medical_report_generator` import below resolves.
sys.path.append(str(project_root))

from main.medical_report_generator import MedicalReportGenerator


# Formats medical report for console display
def format_medical_report(analysis_result):
    """Render the report text of an analysis result as an indented console block.

    Args:
        analysis_result: Mapping that may hold a 'medical_report' entry whose
            'report_text' value is the raw report string.

    Returns:
        A "CLINICAL FINDINGS:" header, a 50-dash rule, and every non-blank
        report line stripped and indented by three spaces. Returns
        "No medical report generated" when the result is empty or has no
        'medical_report' entry.
    """
    has_report = bool(analysis_result) and 'medical_report' in analysis_result
    if not has_report:
        return "No medical report generated"

    raw_text = analysis_result['medical_report'].get('report_text', 'No report text available')

    # Header first, then one entry per non-blank report line.
    pieces = ["CLINICAL FINDINGS:\n", "-" * 50 + "\n"]
    for raw_line in raw_text.split('\n'):
        stripped = raw_line.strip()
        if stripped:
            pieces.append(f"   {stripped}\n")

    return "".join(pieces)


# Gets specific chest X-ray image from MIMIC dataset
def get_specific_xray_image(patient_id="6c2b39fa-2c251fcf-addd31da-83faee60-044fa8f9", img_dir=None):
    """Return the path of one patient's chest X-ray PNG as a string.

    Args:
        patient_id: Image identifier; the file looked up is
            ``<patient_id>.png``. Defaults to the case this script was written
            for, so existing zero-argument calls behave as before.
        img_dir: Directory holding the PNG images. Defaults to
            ``<project_root>/data_dump/output/img_png``.

    Returns:
        The image path as ``str``.

    Raises:
        FileNotFoundError: If the image directory or the patient's image file
            does not exist.
    """
    img_dir = Path(img_dir) if img_dir is not None else project_root / "data_dump" / "output" / "img_png"

    if not img_dir.is_dir():
        raise FileNotFoundError(f"Image directory not found: {img_dir}")

    specific_image = img_dir / f"{patient_id}.png"

    if not specific_image.is_file():
        raise FileNotFoundError(f"Specific patient image not found: {specific_image}")

    print(f"Analyzing specific patient: {patient_id}")
    print(f"Full path: {specific_image}")

    return str(specific_image)


# Loads and parses human gaze fixation data
def load_fixation_data(fixations_path):
    """Read a gaze-fixation CSV and print a short summary of its contents.

    Args:
        fixations_path: Path of the CSV file, passed to ``pd.read_csv``.

    Returns:
        The loaded DataFrame, or ``None`` if the file could not be read, is
        missing any of the 'x_norm', 'y_norm', 'gaze_duration' columns, or the
        summary could not be produced. The failure reason is printed.
    """
    needed_columns = ('x_norm', 'y_norm', 'gaze_duration')

    try:
        table = pd.read_csv(fixations_path)

        print(f"Loaded {len(table)} fixation points")
        print(f"Columns: {list(table.columns)}")

        absent = [name for name in needed_columns if name not in table.columns]
        if absent:
            raise ValueError(f"Missing required columns: {absent}")

        x_vals, y_vals, durations = (table[name] for name in needed_columns)
        print(f"X range: {x_vals.min():.3f} - {x_vals.max():.3f}")
        print(f"Y range: {y_vals.min():.3f} - {y_vals.max():.3f}")
        print(f"Duration range: {durations.min():.1f} - {durations.max():.1f} ms")
    except Exception as err:
        # Best-effort loader: callers treat None as "no gaze data available".
        print(f"Error loading fixation data: {err}")
        return None

    return table


# Generates heatmap from human gaze fixation data
def generate_gaze_heatmap(fixations_df, heatmap_size=(224, 224), sigma=15):
    """Build a peak-normalised, duration-weighted heatmap from gaze fixations.

    Each fixation adds ``gaze_duration / 1000`` to the pixel it falls on. The
    accumulated map is optionally Gaussian-blurred and then divided by its
    maximum.

    Args:
        fixations_df: DataFrame with 'x_norm', 'y_norm' and 'gaze_duration'
            columns. Coordinates are multiplied by the map width/height,
            truncated to integers and clamped into the map. ``None`` or an
            empty frame yields an all-zero map. Rows containing NaN/inf in any
            of the three columns are skipped.
        heatmap_size: ``(height, width)`` of the returned map.
        sigma: Standard deviation passed to ``cv2.GaussianBlur``; values
            ``<= 0`` skip the blur.

    Returns:
        A ``float32`` array of shape ``heatmap_size``. Its maximum is 1 unless
        no weight was accumulated, in which case it is all zeros.
    """
    height, width = heatmap_size
    # Same dtype on every return path (the blur/normalise path is float32).
    heatmap = np.zeros((height, width), dtype=np.float32)

    if fixations_df is None or len(fixations_df) == 0:
        return heatmap

    x_norm = fixations_df['x_norm'].to_numpy(dtype=np.float64)
    y_norm = fixations_df['y_norm'].to_numpy(dtype=np.float64)
    weights = fixations_df['gaze_duration'].to_numpy(dtype=np.float64) / 1000.0

    # A fixation with a missing coordinate or duration cannot be placed.
    usable = np.isfinite(x_norm) & np.isfinite(y_norm) & np.isfinite(weights)
    if not usable.any():
        return heatmap

    # Truncate toward zero, then clamp, so out-of-range points land on the border.
    cols = np.clip(np.trunc(x_norm[usable] * width), 0, width - 1).astype(np.intp)
    rows = np.clip(np.trunc(y_norm[usable] * height), 0, height - 1).astype(np.intp)

    # np.add.at accumulates repeated (row, col) pairs; plain fancy-index += would not.
    accumulator = np.zeros((height, width), dtype=np.float64)
    np.add.at(accumulator, (rows, cols), weights[usable])
    heatmap = accumulator.astype(np.float32)

    if sigma > 0:
        heatmap = cv2.GaussianBlur(heatmap, (0, 0), sigma)

    peak = heatmap.max()
    if peak > 0:
        heatmap = heatmap / peak

    return heatmap


# Calculates spatial correlation metrics between human gaze and model attention
def calculate_attention_metrics(human_heatmap, model_attention):
    """Compare a human gaze heatmap with a model attention map.

    Args:
        human_heatmap: 2-D array of non-negative gaze density.
        model_attention: 2-D model attention map. If its shape differs from
            ``human_heatmap`` it is resized to the human map's size with
            ``cv2.resize``.

    Returns:
        Dict with the following keys:

        * 'pearson_correlation', 'pearson_p_value': pixel-wise Pearson r and
          its p-value.
        * 'jensen_shannon_divergence': value of
          ``scipy.spatial.distance.jensenshannon`` on the two maps normalised
          to probability distributions. NOTE: scipy returns the
          Jensen-Shannon *distance* (square root of the divergence, natural
          log); the key name is kept for compatibility.
        * 'mean_squared_error': mean of squared pixel differences.
        * 'normalized_scanpath_saliency': mean of the z-scored model map
          weighted by the human heatmap (a density-weighted form of NSS, since
          discrete fixation points are not available here). 0.0 when the model
          map is constant or the human map sums to zero.
        * 'human_attention_entropy', 'model_attention_entropy': Shannon
          entropy (nats) of the two normalised maps.
    """
    human_heatmap = np.asarray(human_heatmap, dtype=np.float64)
    model_attention = np.asarray(model_attention, dtype=np.float64)

    if human_heatmap.shape != model_attention.shape:
        # cv2.resize takes (width, height).
        model_attention = cv2.resize(model_attention, (human_heatmap.shape[1], human_heatmap.shape[0]))

    human_flat = human_heatmap.ravel()
    model_flat = model_attention.ravel()

    pearson_corr, pearson_p = pearsonr(human_flat, model_flat)

    # Epsilon smoothing keeps both distributions strictly positive.
    epsilon = 1e-8
    human_prob = (human_flat + epsilon) / (human_flat.sum() + epsilon * len(human_flat))
    model_prob = (model_flat + epsilon) / (model_flat.sum() + epsilon * len(model_flat))
    js_divergence = jensenshannon(human_prob, model_prob)

    mse = float(np.mean((human_flat - model_flat) ** 2))

    # NSS: standardise the model map, then average it where humans looked.
    # Deterministic, unlike sampling random pixel locations.
    model_std = model_flat.std()
    human_mass = human_flat.sum()
    if model_std > 0 and human_mass > 0:
        z_scores = (model_flat - model_flat.mean()) / model_std
        nss = float(np.sum(z_scores * human_flat) / human_mass)
    else:
        nss = 0.0

    return {
        'pearson_correlation': pearson_corr,
        'pearson_p_value': pearson_p,
        'jensen_shannon_divergence': js_divergence,
        'mean_squared_error': mse,
        'normalized_scanpath_saliency': nss,
        'human_attention_entropy': -np.sum(human_prob * np.log(human_prob + epsilon)),
        'model_attention_entropy': -np.sum(model_prob * np.log(model_prob + epsilon))
    }


# Saves visualization comparing human gaze and model attention
def save_attention_comparison(human_heatmap, model_attention, metrics, patient_id, output_dir):
    """Save a 2x2 figure comparing the human gaze heatmap with model attention.

    The panels are the human heatmap, the model attention map, their absolute
    difference, and a text box listing the metric values.

    Args:
        human_heatmap: 2-D array of human gaze density.
        model_attention: 2-D model attention array. It is displayed as given;
            for the difference panel it is resized to ``human_heatmap``'s
            shape when the two shapes differ.
        metrics: Mapping providing 'pearson_correlation', 'pearson_p_value',
            'jensen_shannon_divergence', 'mean_squared_error',
            'normalized_scanpath_saliency', 'human_attention_entropy' and
            'model_attention_entropy'.
        patient_id: Identifier used in the figure title and the file name.
        output_dir: Directory for the PNG (``str`` or ``Path``); created with
            any missing parents.

    Returns:
        ``Path`` of the written PNG, named
        ``gaze_attention_comparison_<patient_id>_<timestamp>.png``.
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # The subtraction below needs both maps on the same grid; cv2.resize
    # takes (width, height).
    if model_attention.shape != human_heatmap.shape:
        model_for_diff = cv2.resize(model_attention, (human_heatmap.shape[1], human_heatmap.shape[0]))
    else:
        model_for_diff = model_attention

    metrics_text = f"""CORRELATION METRICS:

Pearson Correlation: {metrics['pearson_correlation']:.4f}
P-value: {metrics['pearson_p_value']:.4f}

Jensen-Shannon Divergence: {metrics['jensen_shannon_divergence']:.4f}
Mean Squared Error: {metrics['mean_squared_error']:.4f}

Normalized Scanpath Saliency: {metrics['normalized_scanpath_saliency']:.4f}

Human Attention Entropy: {metrics['human_attention_entropy']:.4f}
Model Attention Entropy: {metrics['model_attention_entropy']:.4f}

INTERPRETATION:
• Pearson > 0.3: Good spatial correlation
• JS Divergence < 0.5: Similar distributions  
• Lower MSE: Better pixel-wise match
• Higher NSS: Better fixation prediction
"""

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    plot_filename = f"gaze_attention_comparison_{patient_id}_{timestamp}.png"
    plot_path = output_dir / plot_filename

    fig, axes = plt.subplots(2, 2, figsize=(12, 10))
    try:
        fig.suptitle(f'Gaze-Attention Comparison: {patient_id}', fontsize=16, fontweight='bold')

        panels = (
            (axes[0, 0], human_heatmap, 'hot', 'Human Radiologist Gaze Heatmap'),
            (axes[0, 1], model_attention, 'hot', 'Model Attention Heatmap (Grad-CAM)'),
            (axes[1, 0], np.abs(human_heatmap - model_for_diff), 'viridis', 'Absolute Difference Map'),
        )
        for ax, image, cmap, title in panels:
            im = ax.imshow(image, cmap=cmap, interpolation='bilinear')
            ax.set_title(title, fontweight='bold')
            ax.axis('off')
            fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

        axes[1, 1].axis('off')
        axes[1, 1].text(0.05, 0.95, metrics_text, transform=axes[1, 1].transAxes,
                        fontsize=10, verticalalignment='top', fontfamily='monospace',
                        bbox=dict(boxstyle='round', facecolor='lightgray', alpha=0.8))

        fig.tight_layout()
        fig.savefig(plot_path, dpi=300, bbox_inches='tight')
    finally:
        # Release this figure even if plotting or saving failed.
        plt.close(fig)

    print(f"Gaze-attention comparison saved: {plot_path}")
    return plot_path



# Gets fixation data path for specific patient from dataset
def get_patient_fixation_path(patient_id, dataset_path=None):
    """Look up the fixation-file path recorded for a patient in the dataset CSV.

    The CSV is scanned in chunks and the first row whose 'dicom_id' equals
    ``patient_id`` is used.

    Args:
        patient_id: Value to match against the 'dicom_id' column.
        dataset_path: CSV file to search. Defaults to
            ``<project_root>/final_dataset_fixed.csv``.

    Returns:
        The 'fixations_path' value of the first matching row, or ``None`` when
        the patient is absent, the matching row has no path, or the dataset
        cannot be read. The reason is printed in each case.
    """
    if dataset_path is None:
        dataset_path = project_root / "final_dataset_fixed.csv"

    try:
        # Only two columns are needed; the context manager closes the file
        # even when we return from inside the loop.
        with pd.read_csv(dataset_path, usecols=['dicom_id', 'fixations_path'], chunksize=1000) as reader:
            for chunk in reader:
                matches = chunk.loc[chunk['dicom_id'] == patient_id, 'fixations_path']
                if matches.empty:
                    continue

                fixations_path = matches.iloc[0]
                if pd.isna(fixations_path):
                    # NaN is truthy, so callers testing `if path:` would be misled.
                    print(f"Patient {patient_id} has no fixation path recorded in dataset")
                    return None

                print(f"Found fixation data: {fixations_path}")
                return fixations_path

        print(f"Patient {patient_id} not found in dataset")
        return None

    except Exception as e:
        # Best-effort lookup: callers continue without gaze data on None.
        print(f"Error reading dataset: {e}")
        return None


# Runs comprehensive X-ray analysis with gaze-attention validation
def main():
    """Run the full analysis for the image returned by get_specific_xray_image().

    Steps, in order: locate the image, load the patient's gaze fixations (if
    any) and turn them into a heatmap, run MedicalReportGenerator on the image,
    compare gaze and model attention when both are available, write a debug
    copy of the LLM prompt, then print the formatted report and save it.

    Returns:
        0 when the run completes; 1 when any exception reaches the outer
        handler (a troubleshooting checklist is printed in that case).
    """
    print("Starting Specific Patient X-ray Analysis with Gaze-Attention Validation...")
    print("=" * 80)
    
    try:
        print("\nLoading specific patient X-ray image...")
        image_path = get_specific_xray_image()
        # The image file stem is used as the id for the dataset lookup below.
        patient_id = Path(image_path).stem
        
        print("\nLoading human gaze fixation data...")
        fixations_path = get_patient_fixation_path(patient_id)
        
        # Gaze data is optional: every failure path leaves human_gaze_heatmap
        # as None and the run continues without the comparison.
        if fixations_path:
            # The path from the dataset is joined onto the project root.
            abs_fixations_path = project_root / fixations_path
            fixations_df = load_fixation_data(abs_fixations_path)
            
            if fixations_df is not None:
                print("Human gaze data loaded successfully!")
                print("Generating human gaze heatmap...")
                human_gaze_heatmap = generate_gaze_heatmap(fixations_df, heatmap_size=(224, 224))
                print(f"Human gaze heatmap shape: {human_gaze_heatmap.shape}")
                print(f"Human gaze intensity range: {human_gaze_heatmap.min():.4f} - {human_gaze_heatmap.max():.4f}")
            else:
                print("Could not load human gaze data, continuing without gaze comparison")
                human_gaze_heatmap = None
        else:
            print("No fixation data found for this patient, continuing without gaze comparison")
            human_gaze_heatmap = None
        
        print("\nInitializing Medical Report Generator...")
        print("Loading models and dependencies...")
        generator = MedicalReportGenerator()
        print("Generator initialized successfully!")
        
        print("\nAnalyzing chest X-ray...")
        print("• Running multimodal classification...")
        print("• Extracting clinical keywords...")
        print("• Generating attention analysis...")
        print("• Creating medical report with LLM...")
        print("UPDATED: Including image in LLM request for enhanced analysis")
        
        # `result` is used as a mapping from here on (membership tests and
        # key access); its exact schema is defined by MedicalReportGenerator.
        result = generator.generate_comprehensive_analysis(
            image_path=image_path,
            template="standard",
            include_image_in_llm=True
        )
        
        # Validation needs both a gaze heatmap and an attention analysis.
        if human_gaze_heatmap is not None and 'attention_analysis' in result:
            print("\nPerforming Gaze-Attention Validation (Priority 1)...")
            
            attention_data = result['attention_analysis']
            print(f"DEBUG: attention_analysis keys: {list(attention_data.keys())}")
            print(f"DEBUG: attention_analysis content preview: {str(attention_data)[:200]}...")
            
            if 'attention_map' in attention_data:
                # Used via .shape/.min()/.max() below - presumably a NumPy
                # array; TODO confirm against MedicalReportGenerator.
                model_attention = attention_data['attention_map']
                print(f"Model attention shape: {model_attention.shape}")
                print(f"Model attention range: {model_attention.min():.4f} - {model_attention.max():.4f}")
                
                print("Calculating spatial correlation metrics...")
                correlation_metrics = calculate_attention_metrics(human_gaze_heatmap, model_attention)
                
                print(f"Pearson Correlation: {correlation_metrics['pearson_correlation']:.4f}")
                print(f"Jensen-Shannon Divergence: {correlation_metrics['jensen_shannon_divergence']:.4f}")
                print(f"Mean Squared Error: {correlation_metrics['mean_squared_error']:.4f}")
                
                print("Saving gaze-attention comparison visualization...")
                output_dir = project_root / "main" / "gaze_attention_analysis"
                comparison_plot = save_attention_comparison(
                    human_gaze_heatmap, model_attention, correlation_metrics, patient_id, output_dir
                )
                
                # Attach the validation outcome to the result so the summary
                # section further down can print it.
                result['gaze_attention_validation'] = {
                    'correlation_metrics': correlation_metrics,
                    'human_gaze_heatmap_stats': {
                        'shape': human_gaze_heatmap.shape,
                        'min': float(human_gaze_heatmap.min()),
                        'max': float(human_gaze_heatmap.max()),
                        'mean': float(human_gaze_heatmap.mean())
                    },

                    'comparison_plot_path': str(comparison_plot)
                }
                
                print("Gaze-Attention Validation completed!")
            else:
                print("No attention map found in model results")
                print(f"Available keys: {list(attention_data.keys())}")
                print("Need to modify MedicalReportGenerator to save raw attention maps")
        else:
            print("Skipping gaze-attention validation (no human gaze data or attention analysis)")
        
        # Debug prompt dump is best-effort: a failure here is reported and the
        # run continues with report formatting.
        try:
            if 'llm_response' in result and 'raw_response' in result['llm_response']:
                print("Extracting and saving AI prompt for debugging...")
                
                # The prompt is rebuilt from fields of `result`, with
                # include_image=False, so only text is written to the file.
                debug_prompt = generator.create_medical_report_prompt(
                    condition_predictions=result.get('model_predictions', {}),
                    prediction_keywords=result.get('prediction_keywords', {}),
                    spatial_keywords=result.get('spatial_keywords', []),
                    attention_analysis=result.get('attention_analysis', {}),
                    relevant_anatomical_regions=result.get('relevant_anatomical_regions', {}),
                    template="standard",
                    patient_info=None,
                    include_image=False
                )
                
                prompt_file = project_root / "main" / "debug_ai_prompt.txt"
                with open(prompt_file, 'w', encoding='utf-8') as f:
                    f.write("=" * 80 + "\n")
                    f.write("AI PROMPT DEBUG OUTPUT\n")
                    f.write(f"Generated at: {datetime.now()}\n")
                    f.write(f"Patient ID: {Path(image_path).stem}\n")
                    f.write("=" * 80 + "\n\n")
                    
                    # Token count is a rough estimate: whitespace-separated
                    # word count multiplied by 1.3.
                    f.write(f"PROMPT STATISTICS:\n")
                    f.write(f"- Total Characters: {len(debug_prompt):,}\n")
                    f.write(f"- Total Words: {len(debug_prompt.split()):,}\n")
                    f.write(f"- Estimated Tokens: ~{len(debug_prompt.split()) * 1.3:.0f}\n")
                    f.write(f"- Size in KB: {len(debug_prompt.encode('utf-8')) / 1024:.2f} KB\n")
                    f.write("\n" + "=" * 80 + "\n\n")
                    
                    f.write("FULL PROMPT CONTENT:\n")
                    f.write("-" * 40 + "\n")
                    f.write(debug_prompt)
                
                print(f"AI prompt saved to: {prompt_file}")
                print(f"Prompt size: {len(debug_prompt):,} characters ({len(debug_prompt.encode('utf-8')) / 1024:.2f} KB)")
                print(f"Estimated tokens: ~{len(debug_prompt.split()) * 1.3:.0f}")
                
        except Exception as e:
            print(f"Could not save debug prompt: {e}")
        
        print("\nAnalysis complete! Formatting results...")
        print("\n")
        
        formatted_report = format_medical_report(result)
        print(formatted_report)
        
        # Present only when the validation branch above ran to completion.
        if 'gaze_attention_validation' in result:
            print("\n" + "=" * 60)
            print("GAZE-ATTENTION VALIDATION RESULTS (Priority 1)")
            print("=" * 60)
            
            metrics = result['gaze_attention_validation']['correlation_metrics']
            print(f"Spatial Correlation Analysis:")
            print(f"   • Pearson Correlation: {metrics['pearson_correlation']:.4f}")
            print(f"   • P-value: {metrics['pearson_p_value']:.4f}")
            print(f"   • Jensen-Shannon Divergence: {metrics['jensen_shannon_divergence']:.4f}")
            print(f"   • Mean Squared Error: {metrics['mean_squared_error']:.4f}")
            print(f"   • Normalized Scanpath Saliency: {metrics['normalized_scanpath_saliency']:.4f}")
            
            # Correlation bands: > 0.3 good, > 0.1 moderate, otherwise poor.
            print(f"\nInterpretation:")
            corr = metrics['pearson_correlation']
            if corr > 0.3:
                print(f"   GOOD spatial correlation ({corr:.4f}) - model attention aligns with human gaze")
            elif corr > 0.1:
                print(f"   MODERATE spatial correlation ({corr:.4f}) - some alignment but room for improvement")
            else:
                print(f"   POOR spatial correlation ({corr:.4f}) - model attention differs significantly from human gaze")
            
            # JS bands: < 0.3 similar, < 0.6 moderately different, else very different.
            js_div = metrics['jensen_shannon_divergence']
            if js_div < 0.3:
                print(f"   SIMILAR attention distributions (JS={js_div:.4f})")
            elif js_div < 0.6:
                print(f"   MODERATELY different distributions (JS={js_div:.4f})")
            else:
                print(f"   VERY different attention distributions (JS={js_div:.4f})")
            
            print(f"\nVisualization saved: {result['gaze_attention_validation']['comparison_plot_path']}")
        
        # The same text that was printed is written out; the file is
        # overwritten on every run.
        output_file = project_root / "main" / "latest_fixed_analysis.txt"
        with open(output_file, 'w', encoding='utf-8') as f:
            f.write(formatted_report)
        
        print(f"Report also saved to: {output_file}")
        
    except Exception as e:
        # Top-level boundary: report the error with a setup checklist and
        # signal failure through the return code.
        print(f"\nError during analysis: {str(e)}")
        print("\nPlease ensure:")
        print("1. Virtual environment is activated")
        print("2. LLM server is running at http://127.0.0.1:1234")
        print("3. All model files are available")
        print("4. CUDA GPU is available (recommended)")
        
        return 1
    
    return 0


# Script entry point: hand main()'s status code to the interpreter as the exit code.
if __name__ == "__main__":
    sys.exit(main())