import json
import os
import logging
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Set
from datetime import datetime

import sys
sys.path.append(str(Path(__file__).parents[1]))

from utils.ground_truth_loader import GroundTruthLoader
from utils.text_extractor import ReportTextExtractor
from utils.ground_truth_validator import GroundTruthValidator


class GroundTruthIndexBuilder:
    """Build, validate, persist and query an index of ground truth reports.

    The index maps each image ID to the reports loaded for it and each report
    file path back to its image ID. It also carries per-source statistics and
    a validation summary. Report discovery and loading are delegated to
    ``GroundTruthLoader``.
    """

    def __init__(self, mimic_base_path: str, logger: Optional[logging.Logger] = None):
        """Create the builder and its helper components.

        Args:
            mimic_base_path: Base directory handed to ``GroundTruthLoader``.
            logger: Logger shared with the helpers; defaults to this module's
                logger.
        """
        self.mimic_base_path = Path(mimic_base_path)
        self.logger = logger or logging.getLogger(__name__)

        self.loader = GroundTruthLoader(str(self.mimic_base_path), self.logger)
        self.text_extractor = ReportTextExtractor(self.logger)
        self.validator = GroundTruthValidator(self.logger)

        # Populated by build_comprehensive_index() or load_index().
        self.index_data = {}
        self.index_metadata = {}

    def build_comprehensive_index(self, source_types: Optional[List[str]] = None) -> Dict:
        """Build the index of all available ground truth reports.

        Every report that loads successfully and carries an ``image_id`` is
        indexed. Reports are tracked by file path, so several reports that
        belong to the same image are all retained in ``image_to_reports``.

        Args:
            source_types: Report sources to scan. ``None`` means
                ``["radiology"]``. A single string is treated as a
                one-element list.

        Returns:
            The index dictionary with the keys ``metadata``,
            ``image_to_reports``, ``report_to_image``, ``source_statistics``
            and ``validation_summary``. The same object is stored on
            ``self.index_data``.
        """
        # Normalise into a fresh list: the value is stored in the returned
        # metadata and must not alias a shared default or the caller's list.
        if source_types is None:
            source_types = ["radiology"]
        elif isinstance(source_types, str):
            source_types = [source_types]
        else:
            source_types = list(source_types)

        self.logger.info(f"Building comprehensive index for sources: {source_types}")

        index = {
            "metadata": {
                "creation_date": datetime.now().isoformat(),
                "mimic_base_path": str(self.mimic_base_path),
                "source_types": source_types,
                "total_reports": 0,
                "total_images": 0,
                "sources": {}
            },
            "image_to_reports": {},
            "report_to_image": {},
            "source_statistics": {},
            "validation_summary": {}
        }

        # Keyed by report file path (not image ID) so that multiple reports
        # for one image do not overwrite each other.
        all_reports = {}

        for source_type in source_types:
            self.logger.info(f"Processing {source_type} reports...")

            report_files = self.loader.find_report_files(source_type)
            source_files = report_files.get(source_type, [])

            self.logger.info(f"Found {len(source_files)} {source_type} files")

            source_reports = {}
            for file_path in source_files:
                report_data = self.loader.load_single_report(file_path)
                # Skip reports that failed to load or cannot be tied to an image.
                if report_data and report_data.get("image_id"):
                    source_reports[report_data["file_path"]] = report_data

            index["metadata"]["sources"][source_type] = {
                "file_count": len(source_files),
                "loaded_count": len(source_reports),
                "success_rate": len(source_reports) / len(source_files) * 100 if source_files else 0
            }

            all_reports.update(source_reports)

        for report_path, report_data in all_reports.items():
            image_id = report_data["image_id"]
            report_info = {
                "image_id": image_id,
                "file_path": report_path,
                "source_type": report_data["source_type"],
                "text_length": report_data["text_length"],
                "has_timestamps": report_data["has_timestamps"],
                "timestamp_count": report_data["timestamp_count"],
                "loaded_at": report_data["loaded_at"]
            }

            index["image_to_reports"].setdefault(image_id, []).append(report_info)
            index["report_to_image"][report_path] = image_id

        index["metadata"]["total_reports"] = len(all_reports)
        index["metadata"]["total_images"] = len(index["image_to_reports"])

        index["source_statistics"] = self._generate_source_statistics(all_reports)
        index["validation_summary"] = self._validate_index(index, all_reports)

        self.index_data = index
        self.index_metadata = index["metadata"]

        self.logger.info(f"Index building complete: {index['metadata']['total_images']} images, "
                        f"{index['metadata']['total_reports']} reports")

        return index

    def _generate_source_statistics(self, all_reports: Dict) -> Dict:
        """Summarise the loaded reports per source type.

        Args:
            all_reports: Mapping whose values are report dictionaries with at
                least ``source_type``, ``text_length``, ``has_timestamps`` and
                ``timestamp_count``.

        Returns:
            A dictionary whose ``by_source`` entry holds count, text-length
            and timestamp aggregates for each source. The remaining keys are
            returned empty.
        """
        stats = {
            "by_source": {},
            "text_length_distribution": {},
            "timestamp_availability": {},
            "quality_scores": {}
        }

        by_source = {}
        for report_data in all_reports.values():
            by_source.setdefault(report_data["source_type"], []).append(report_data)

        # Every group has at least one report, so the divisions are safe.
        for source, reports in by_source.items():
            count = len(reports)
            text_lengths = [r["text_length"] for r in reports]
            stats["by_source"][source] = {
                "count": count,
                "avg_text_length": sum(text_lengths) / count,
                "min_text_length": min(text_lengths),
                "max_text_length": max(text_lengths),
                "timestamp_percentage": sum(1 for r in reports if r["has_timestamps"]) / count * 100,
                "avg_timestamp_count": sum(r["timestamp_count"] for r in reports) / count
            }

        return stats

    def _validate_index(self, index: Dict, all_reports: Dict) -> Dict:
        """Check the completeness and internal consistency of an index.

        A report-count mismatch or an inconsistent mapping marks the index as
        invalid. Images without reports and images with more than one report
        are only recorded as issues.

        Args:
            index: Index with ``image_to_reports`` and ``report_to_image``.
            all_reports: The reports the index was built from; only its size
                is used.

        Returns:
            A dictionary with ``is_valid``, ``issues``, ``statistics`` and
            ``completeness_score`` (the mean of report coverage, mapping
            consistency and mapping quality).
        """
        validation = {
            "is_valid": True,
            "issues": [],
            "statistics": {},
            "completeness_score": 0.0
        }

        image_to_reports = index["image_to_reports"]
        image_count = len(image_to_reports)
        report_count = len(index["report_to_image"])
        expected_reports = len(all_reports)

        if report_count != expected_reports:
            validation["is_valid"] = False
            validation["issues"].append(f"Report count mismatch: expected {expected_reports}, got {report_count}")

        missing_mappings = sum(1 for reports in image_to_reports.values() if not reports)
        duplicate_mappings = sum(1 for reports in image_to_reports.values() if len(reports) > 1)

        if missing_mappings > 0:
            validation["issues"].append(f"Missing mappings: {missing_mappings} images")

        if duplicate_mappings > 0:
            validation["issues"].append(f"Duplicate mappings: {duplicate_mappings} images")

        # A report is consistent when its image lists that same report path.
        inconsistent_mappings = 0
        for report_path, image_id in index["report_to_image"].items():
            listed = image_to_reports.get(image_id, [])
            if not any(info["file_path"] == report_path for info in listed):
                inconsistent_mappings += 1

        if inconsistent_mappings > 0:
            validation["is_valid"] = False
            validation["issues"].append(f"Inconsistent mappings: {inconsistent_mappings} reports")

        completeness_factors = []

        # Coverage is undefined without any expected reports, so it is left out.
        if expected_reports > 0:
            completeness_factors.append(report_count / expected_reports)

        completeness_factors.append(1.0 - (inconsistent_mappings / max(report_count, 1)))
        completeness_factors.append(1.0 - (missing_mappings / max(image_count, 1)))

        validation["completeness_score"] = sum(completeness_factors) / len(completeness_factors)

        validation["statistics"] = {
            "total_images": image_count,
            "total_reports": report_count,
            "expected_reports": expected_reports,
            "missing_mappings": missing_mappings,
            "duplicate_mappings": duplicate_mappings,
            "inconsistent_mappings": inconsistent_mappings,
            "completeness_score": validation["completeness_score"]
        }

        return validation

    def export_index(self, output_path: str, include_validation: bool = True) -> None:
        """Write the current index to a JSON file.

        Logs an error and does nothing when no index has been built or
        loaded. Failures while writing are logged, not raised.

        Args:
            output_path: Destination file; missing parent directories are
                created.
            include_validation: When False, ``validation_summary`` is left
                out of the exported file.
        """
        if not self.index_data:
            self.logger.error("No index data to export. Build index first.")
            return

        try:
            # A shallow copy is enough: only a top-level key is removed.
            export_data = self.index_data.copy()

            if not include_validation:
                export_data.pop("validation_summary", None)

            target = Path(output_path)
            target.parent.mkdir(parents=True, exist_ok=True)

            with open(target, 'w', encoding='utf-8') as f:
                json.dump(export_data, f, indent=2, ensure_ascii=False)

            self.logger.info(f"Index exported to: {target}")

        except Exception as e:
            self.logger.error(f"Failed to export index: {e}")

    def load_index(self, index_path: str) -> Dict:
        """Load an index from a JSON file and make it the current index.

        Args:
            index_path: Path of the JSON file to read.

        Returns:
            The loaded index, or an empty dictionary when the file cannot be
            read, is not valid JSON, or does not contain a JSON object. On
            failure the previously held index is left untouched.
        """
        try:
            with open(index_path, 'r', encoding='utf-8') as f:
                index_data = json.load(f)

            # Validate before touching instance state so a malformed file
            # cannot replace a good index with unusable data.
            if not isinstance(index_data, dict):
                raise ValueError(
                    f"expected a JSON object at top level, got {type(index_data).__name__}"
                )

            self.index_data = index_data
            self.index_metadata = index_data.get("metadata", {})

            self.logger.info(f"Index loaded from: {index_path}")
            self.logger.info(f"Index contains {len(index_data.get('image_to_reports', {}))} images")

            return index_data

        except Exception as e:
            self.logger.error(f"Failed to load index: {e}")
            return {}

    def get_reports_for_image(self, image_id: str) -> List[Dict]:
        """Return the report entries indexed for ``image_id``.

        Returns an empty list when no index is available or the image is
        unknown.
        """
        if not self.index_data:
            return []

        return self.index_data.get("image_to_reports", {}).get(image_id, [])

    def get_image_for_report(self, report_path: str) -> Optional[str]:
        """Return the image ID indexed for ``report_path``.

        Returns ``None`` when no index is available or the path is unknown.
        """
        if not self.index_data:
            return None

        return self.index_data.get("report_to_image", {}).get(report_path)

    def get_index_statistics(self) -> Dict:
        """Return metadata, statistics and validation data for the index.

        Returns:
            A dictionary with ``metadata``, ``source_statistics``,
            ``validation_summary`` and ``current_status``, or
            ``{"error": ...}`` when no index is available.
        """
        if not self.index_data:
            return {"error": "No index data available"}

        return {
            "metadata": self.index_metadata,
            "source_statistics": self.index_data.get("source_statistics", {}),
            "validation_summary": self.index_data.get("validation_summary", {}),
            "current_status": {
                "total_images": len(self.index_data.get("image_to_reports", {})),
                "total_reports": len(self.index_data.get("report_to_image", {})),
                "is_loaded": bool(self.index_data)
            }
        }


# Test the Ground Truth Index Builder functionality
def test_ground_truth_index_builder():
    """Smoke-test GroundTruthIndexBuilder against a local dataset checkout.

    Builds a radiology index, prints a summary, performs one sample lookup
    and round-trips the index through a JSON export and load. The export is
    written to a temporary directory so nothing in the working directory is
    overwritten or left behind.

    Returns:
        True when the run completes and the export/load round trip yields a
        non-empty index; False when the round trip fails or any exception is
        raised.
    """
    # Local import keeps this function self-contained.
    import tempfile

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger("test")

    mimic_path = "../../mimic-eye-integrating-mimic-datasets-with-reflacx-and-eye-gaze-for-multimodal-deep-learning-applications-1.0.0/mimic-eye/"

    try:
        print("Testing Ground Truth Index Builder...")

        builder = GroundTruthIndexBuilder(mimic_path, logger)

        print("Building test index...")
        index = builder.build_comprehensive_index(["radiology"])

        round_trip_ok = True

        if index:
            print("Index built successfully:")
            print(f"   Total images: {index['metadata']['total_images']}")
            print(f"   Total reports: {index['metadata']['total_reports']}")
            print(f"   Sources: {list(index['metadata']['sources'].keys())}")

            validation = index.get('validation_summary', {})
            print(f"   Validation: {'Valid' if validation.get('is_valid') else 'Invalid'}")
            print(f"   Completeness score: {validation.get('completeness_score', 0):.2f}")

            image_to_reports = index.get('image_to_reports', {})
            if image_to_reports:
                sample_image_id = next(iter(image_to_reports))
                reports = builder.get_reports_for_image(sample_image_id)
                print(f"   Sample lookup: {len(reports)} reports for image {sample_image_id}")

            # The temporary directory is removed even if the round trip raises.
            with tempfile.TemporaryDirectory() as tmp_dir:
                test_export_path = os.path.join(tmp_dir, "test_index.json")
                builder.export_index(test_export_path)
                loaded_index = builder.load_index(test_export_path)

            if loaded_index:
                print("Index export/load successful")
            else:
                # export_index/load_index log their errors instead of raising,
                # so an empty result is the only failure signal available.
                print("Index export/load failed")
                round_trip_ok = False

        print("\nAll index builder tests completed!")
        return round_trip_ok

    except Exception as e:
        print(f"Test failed: {e}")
        return False


if __name__ == "__main__":
    # Script entry point: run the smoke test and print a one-line verdict.
    success = test_ground_truth_index_builder()
    verdict = "\nGround Truth Index Builder tests passed!" if success else "\nSome tests failed!"
    print(verdict)