import unittest
import sys
import os
import json
import tempfile
import shutil
import logging
from pathlib import Path

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from medical_report_evaluator import MedicalReportEvaluator
from utils.metric_aggregator import MetricAggregator
from utils.results_exporter import ResultsExporter


class TestComprehensiveSystem(unittest.TestCase):
    """End-to-end checks for the medical report evaluation system.

    Exercises ``MedicalReportEvaluator`` (single, batch, export, invalid-input
    and timing paths), ``MetricAggregator`` and ``ResultsExporter`` against a
    small set of hand-written reference/candidate report pairs.

    Unexpected exceptions are deliberately NOT caught inside the tests:
    letting them propagate makes unittest report them as errors with their
    full traceback (and lets ``unittest.SkipTest`` work), instead of
    flattening every problem into an ``AssertionError`` via ``self.fail``.
    """

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        # Configure root logging once for the whole class; basicConfig is a
        # no-op after its first call, so repeating it in setUp is just noise.
        logging.basicConfig(level=logging.INFO)

    def setUp(self):
        self.logger = logging.getLogger("test")

        # Scratch directory for exported files; removed again in tearDown.
        self.temp_dir = tempfile.mkdtemp()

        # Each pair: a reference report and a paraphrased candidate report.
        self.sample_reports = [
            {
                "reference": "The chest X-ray shows clear lung fields bilaterally. Heart size is within normal limits. No acute cardiopulmonary abnormalities are identified.",
                "candidate": "Chest radiograph demonstrates clear lungs on both sides. Normal heart size. No acute findings in the chest."
            },
            {
                "reference": "There is consolidation in the right lower lobe consistent with pneumonia. Small pleural effusion is noted.",
                "candidate": "Right lower lobe pneumonia with associated pleural fluid collection."
            },
            {
                "reference": "No evidence of pneumothorax or pleural effusion. The cardiac silhouette appears normal.",
                "candidate": "No pneumothorax or fluid around the lungs. Heart looks normal."
            }
        ]

        # NOTE(review): no test in this class reads test_config — presumably
        # kept for future config-driven tests; confirm or remove.
        self.test_config = {
            "metrics": {
                "bleu": {"enabled": True, "weight": 0.2},
                "rouge": {"enabled": True, "weight": 0.2},
                "meteor": {"enabled": True, "weight": 0.2},
                "bertscore": {"enabled": False, "weight": 0.0},
                "cider": {"enabled": True, "weight": 0.2},
                "medical": {"enabled": True, "weight": 0.2}
            },
            "aggregation": {
                "method": "weighted_average",
                "normalize_scores": True
            },
            "export": {
                "formats": ["json", "csv", "txt"],
                "include_metadata": True,
                "precision": 4
            }
        }

    def tearDown(self):
        shutil.rmtree(self.temp_dir, ignore_errors=True)

    # ------------------------------------------------------------------ helpers

    def _assert_unit_interval(self, value, label):
        """Assert that *value* lies in the closed interval [0.0, 1.0]."""
        self.assertGreaterEqual(value, 0.0, f"{label} is below 0.0: {value!r}")
        self.assertLessEqual(value, 1.0, f"{label} is above 1.0: {value!r}")

    def _export_and_check(self, exporter, results, filename, fmt):
        """Export *results* into the temp dir as *fmt* and assert the file exists.

        *exporter* is any object exposing
        ``export_results(results=..., output_path=..., format=...)`` that
        returns the written path. Returns that path.
        """
        exported_path = exporter.export_results(
            results=results,
            output_path=os.path.join(self.temp_dir, filename),
            format=fmt,
        )
        self.assertTrue(
            os.path.exists(exported_path),
            f"{fmt} export did not create {exported_path!r}",
        )
        return exported_path

    # -------------------------------------------------------------------- tests

    def test_evaluator_initialization(self):
        evaluator = MedicalReportEvaluator()
        self.assertIsInstance(evaluator, MedicalReportEvaluator)
        self.logger.info("Evaluator initialization successful")

    def test_single_evaluation(self):
        evaluator = MedicalReportEvaluator()
        report = self.sample_reports[0]

        result = evaluator.evaluate_single(
            generated_report=report["candidate"],
            ground_truth_report=report["reference"],
        )

        self.assertIsInstance(result, dict)
        for key in ("overall_score", "metrics", "metadata"):
            self.assertIn(key, result)

        overall_score = result["overall_score"]
        self._assert_unit_interval(overall_score, "overall_score")
        self.logger.info("Single evaluation successful: %.4f", overall_score)

    def test_batch_evaluation(self):
        evaluator = MedicalReportEvaluator()
        report_pairs = [
            (report["candidate"], report["reference"])
            for report in self.sample_reports
        ]

        results = evaluator.evaluate_batch(report_pairs=report_pairs)

        self.assertIsInstance(results, dict)
        for key in ("individual_results", "batch_summary", "aggregate_metrics"):
            self.assertIn(key, results)

        individual_results = results["individual_results"]
        self.assertEqual(len(individual_results), len(self.sample_reports))
        for index, result in enumerate(individual_results):
            with self.subTest(index=index):
                self.assertIn("overall_score", result)
                self._assert_unit_interval(result["overall_score"], "overall_score")

        # NOTE(review): a missing "batch_overall_score" falls back to 0.0, which
        # trivially passes the range check — confirm whether the key is
        # guaranteed and, if so, assert its presence instead.
        batch_score = results["aggregate_metrics"].get("batch_overall_score", 0.0)
        self._assert_unit_interval(batch_score, "batch_overall_score")
        self.logger.info("Batch evaluation successful: %.4f", batch_score)

    def test_metric_aggregator(self):
        aggregator = MetricAggregator(self.logger)
        metric_scores = {
            "bleu": 0.75,
            "rouge_1_f1": 0.68,
            "rouge_2_f1": 0.45,
            "rouge_l_f1": 0.72,
            "meteor": 0.58,
            "cider": 0.82,
            "medical_score": 0.65
        }

        result = aggregator.aggregate_scores(
            metric_scores=metric_scores,
            method="weighted_average",
        )

        self.assertIsInstance(result, dict)
        for key in ("overall_score", "valid_scores", "aggregation_details"):
            self.assertIn(key, result)

        overall_score = result["overall_score"]
        self._assert_unit_interval(overall_score, "overall_score")
        self.logger.info("Metric aggregation successful: %.4f", overall_score)

    def test_results_exporter(self):
        exporter = ResultsExporter(self.logger)
        sample_results = {
            "evaluation_id": "test_001",
            "timestamp": "2024-01-15T10:30:00",
            "overall_score": 0.75,
            "individual_scores": {
                "bleu": 0.8,
                "rouge_1_f1": 0.7,
                "meteor": 0.9
            },
            "metadata": {
                "reference_length": 150,
                "candidate_length": 140,
                "evaluation_time": 2.5
            }
        }

        # JSON is the only format whose content is round-tripped and checked.
        json_path = self._export_and_check(
            exporter, sample_results, "test_results.json", "json"
        )
        with open(json_path, "r", encoding="utf-8") as handle:
            loaded_data = json.load(handle)
        self.assertEqual(loaded_data["evaluation_id"], "test_001")
        self.assertEqual(loaded_data["overall_score"], 0.75)

        self._export_and_check(exporter, sample_results, "test_results.csv", "csv")
        self._export_and_check(exporter, sample_results, "test_results.txt", "txt")

        self.logger.info("Results export successful")

    def test_end_to_end_workflow(self):
        evaluator = MedicalReportEvaluator()
        report = self.sample_reports[0]

        result = evaluator.evaluate_single(
            generated_report=report["candidate"],
            ground_truth_report=report["reference"],
        )

        exported_path = self._export_and_check(
            evaluator, result, "end_to_end_results.json", "json"
        )
        with open(exported_path, "r", encoding="utf-8") as handle:
            exported_data = json.load(handle)

        self.assertIn("overall_score", exported_data)
        self.assertIn("metrics", exported_data)
        self.logger.info("End-to-end workflow successful")

    def test_error_handling(self):
        evaluator = MedicalReportEvaluator()

        # Each invalid input may either be handled gracefully (the call
        # returns normally) or be rejected with one of the listed exception
        # types. Any other exception propagates and is reported as an error.
        invalid_calls = [
            ("empty generated report", (ValueError,),
             lambda: evaluator.evaluate_single("", "some text")),
            ("None generated report", (ValueError, TypeError),
             lambda: evaluator.evaluate_single(None, "some text")),
            ("empty batch", (ValueError,),
             lambda: evaluator.evaluate_batch(report_pairs=[])),
        ]
        for label, allowed_errors, call in invalid_calls:
            with self.subTest(case=label):
                try:
                    call()
                except allowed_errors:
                    pass

        self.logger.info("Error handling tests passed")

    def test_performance_tracking(self):
        evaluator = MedicalReportEvaluator()
        report = self.sample_reports[0]

        result = evaluator.evaluate_single(
            generated_report=report["candidate"],
            ground_truth_report=report["reference"],
        )

        # Timing metadata is optional; only validate it when it is reported.
        if "metadata" in result and "evaluation_time" in result["metadata"]:
            eval_time = result["metadata"]["evaluation_time"]
            self.assertGreater(eval_time, 0.0)
            self.logger.info("Performance tracking: %.4fs", eval_time)
        else:
            self.logger.info("Performance tracking structure verified")


# Runs end-to-end tests for the complete medical report evaluation system
def run_comprehensive_tests():
    print("Running Comprehensive System Tests...")
    print("=" * 60)
    
    loader = unittest.TestLoader()
    suite = unittest.TestSuite()
    
    suite.addTests(loader.loadTestsFromTestCase(TestComprehensiveSystem))
    
    runner = unittest.TextTestRunner(verbosity=2)
    result = runner.run(suite)
    
    print("\n" + "=" * 60)
    print(f"Comprehensive System Tests Summary:")
    print(f"   Tests run: {result.testsRun}")
    print(f"   Failures: {len(result.failures)}")
    print(f"   Errors: {len(result.errors)}")
    
    if result.failures:
        print("\nFailures:")
        for test, traceback in result.failures:
            print(f"   - {test}")
            print(f"     {traceback.split('AssertionError:')[-1].strip()}")
    
    if result.errors:
        print("\nErrors:")
        for test, traceback in result.errors:
            print(f"   - {test}")
            error_msg = traceback.split('\n')[-2] if '\n' in traceback else traceback
            print(f"     {error_msg}")
    
    success = len(result.failures) == 0 and len(result.errors) == 0
    
    if success:
        print("\nAll comprehensive system tests passed!")
        print("The medical report evaluation system is working correctly!")
    else:
        print(f"\n{len(result.failures + result.errors)} tests failed!")
        print("System needs fixes before proceeding to Phase 5")
    
    return success, result


if __name__ == "__main__":
    # Use sys.exit rather than the site-injected exit() helper, which is not
    # available under `python -S` or in some embedded/frozen interpreters.
    success, _ = run_comprehensive_tests()
    sys.exit(0 if success else 1)