import unittest
import sys
import os
import logging
import warnings
import numpy as np
from unittest.mock import Mock, patch, MagicMock
from typing import Dict, Any

# Prepend the parent of this file's directory to the import path so the
# `metrics.*` imports below presumably resolve when this file is run directly
# (e.g. `python <this file>`) rather than via an installed package.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from metrics.bleu_scorer import BLEUScorer
from metrics.rouge_scorer import ROUGEScorer
from metrics.meteor_scorer import METEORScorer
from metrics.bertscore_scorer import BERTScoreScorer
from metrics.cider_scorer import CIDErScorer
from metrics.medical_scorer import MedicalScorer


class TestBLEUScorer(unittest.TestCase):
    """Unit tests for :class:`BLEUScorer` (n-gram precision metric)."""

    def setUp(self):
        # A spec'd Mock only allows attributes that logging.Logger really has.
        self.logger = Mock(spec=logging.Logger)
        self.bleu = BLEUScorer(self.logger)

    def test_bleu_instantiation(self):
        scorer = self.bleu
        self.assertEqual(scorer.get_name(), "BLEU")
        self.assertIn("bleu", scorer.get_description().lower())
        self.assertEqual(scorer.get_metric_type(), "n-gram")

    def test_identical_texts(self):
        sentence = "The patient has clear lung fields."
        scores = self.bleu.calculate(sentence, sentence)

        # All four n-gram orders must be reported.
        for key in ("bleu_1", "bleu_2", "bleu_3", "bleu_4"):
            self.assertIn(key, scores)

        self.assertAlmostEqual(scores["bleu_1"], 1.0, places=2)

    def test_completely_different_texts(self):
        reference = "The patient has clear lung fields."
        candidate = "Elephant dancing in the moonlight."

        scores = self.bleu.calculate(reference, candidate)

        self.assertLess(scores["bleu_1"], 0.5)
        self.assertLess(scores["bleu_4"], 0.1)

    def test_empty_texts(self):
        # An empty string on either side is rejected.
        for reference, candidate in (("", "some text"), ("some text", "")):
            with self.assertRaises(ValueError):
                self.bleu.calculate(reference, candidate)

    def test_whitespace_only_texts(self):
        # Whitespace-only input on either side is rejected like empty input.
        for reference, candidate in (("   ", "some text"), ("some text", "\t\n  ")):
            with self.assertRaises(ValueError):
                self.bleu.calculate(reference, candidate)

    def test_single_word_texts(self):
        matching = self.bleu.calculate("pneumonia", "pneumonia")
        self.assertAlmostEqual(matching["bleu_1"], 1.0, places=2)

        mismatching = self.bleu.calculate("pneumonia", "infection")
        self.assertAlmostEqual(mismatching["bleu_1"], 0.0, places=2)

    def test_case_sensitivity(self):
        reference = "The Patient Has Clear Lung Fields."
        candidate = "the patient has clear lung fields."

        scores = self.bleu.calculate(reference, candidate)
        self.assertGreater(scores["bleu_1"], 0.8)

    def test_punctuation_handling(self):
        reference = "The patient has clear lung fields."
        candidate = "The patient has clear lung fields!"

        scores = self.bleu.calculate(reference, candidate)
        self.assertGreater(scores["bleu_1"], 0.8)

    def test_smoothing_methods(self):
        reference = "The patient shows improvement."
        candidate = "Patient shows some improvement."

        # Score the same pair under each smoothing method, in order.
        scores_by_method = {}
        for method in ("add_one", "epsilon"):
            self.bleu.configure({"smoothing_method": method})
            scores_by_method[method] = self.bleu.calculate(reference, candidate)

        self.assertNotEqual(
            scores_by_method["add_one"]["bleu_4"],
            scores_by_method["epsilon"]["bleu_4"],
        )

    def test_boundary_values(self):
        # A very long text and a one-character text both self-match perfectly.
        very_long = "word " * 1000
        for sample in (very_long, "a"):
            scores = self.bleu.calculate(sample, sample)
            self.assertAlmostEqual(scores["bleu_1"], 1.0, places=2)


class TestROUGEScorer(unittest.TestCase):
    """Unit tests for :class:`ROUGEScorer` (n-gram / LCS overlap metric)."""

    def setUp(self):
        # setUp runs before every test, so each test gets a fresh scorer.
        self.logger = Mock(spec=logging.Logger)
        self.rouge = ROUGEScorer(self.logger)

    def test_rouge_instantiation(self):
        self.assertEqual(self.rouge.get_name(), "ROUGE")
        self.assertIn("rouge", self.rouge.get_description().lower())
        self.assertEqual(self.rouge.get_metric_type(), "n-gram")

    def test_rouge_calculation(self):
        ref = "The patient has clear lung fields and normal heart size."
        cand = "Patient shows clear lungs and normal heart."

        result = self.rouge.calculate(ref, cand)

        for key in ("rouge_1_f", "rouge_1_p", "rouge_1_r", "rouge_2_f", "rouge_l_f"):
            self.assertIn(key, result)

        # Every reported score must be a proportion. The dedicated assertions
        # name the offending key and value on failure; assertTrue on a chained
        # comparison would only report "False is not true".
        for key, score in result.items():
            self.assertGreaterEqual(score, 0.0, msg=f"{key} below 0: {score!r}")
            self.assertLessEqual(score, 1.0, msg=f"{key} above 1: {score!r}")

    def test_identical_texts_rouge(self):
        text = "The patient has clear lung fields."
        result = self.rouge.calculate(text, text)

        self.assertAlmostEqual(result["rouge_1_f"], 1.0, places=2)
        self.assertAlmostEqual(result["rouge_l_f"], 1.0, places=2)

    def test_empty_text_rouge(self):
        with self.assertRaises(ValueError):
            self.rouge.calculate("", "some text")

        with self.assertRaises(ValueError):
            self.rouge.calculate("some text", "")

    def test_no_overlap(self):
        ref = "pneumonia consolidation infiltrate"
        cand = "elephant monkey zebra"

        result = self.rouge.calculate(ref, cand)
        self.assertAlmostEqual(result["rouge_1_f"], 0.0, places=2)
        self.assertAlmostEqual(result["rouge_2_f"], 0.0, places=2)

    def test_partial_overlap(self):
        ref = "The patient has pneumonia and fever."
        cand = "Patient shows pneumonia symptoms."

        result = self.rouge.calculate(ref, cand)
        self.assertGreater(result["rouge_1_f"], 0.0)
        self.assertLess(result["rouge_1_f"], 1.0)

    def test_lcs_calculation(self):
        # "A C E" is an in-order (non-contiguous) subsequence of the reference,
        # so ROUGE-L must credit it even though the bigrams differ.
        ref = "A B C D E"
        cand = "A C E"

        result = self.rouge.calculate(ref, cand)
        self.assertGreater(result["rouge_l_f"], 0.0)

    def test_rouge_precision_recall(self):
        # Both reference tokens appear in the candidate (recall 1), and 2 of the
        # candidate's 3 tokens appear in the reference (precision 2/3).
        ref = "patient pneumonia"
        cand = "patient pneumonia infection"

        result = self.rouge.calculate(ref, cand)

        self.assertAlmostEqual(result["rouge_1_r"], 1.0, places=2)
        self.assertAlmostEqual(result["rouge_1_p"], 2/3, places=2)


class TestMETEORScorer(unittest.TestCase):
    """Unit tests for :class:`METEORScorer` (alignment-based metric)."""

    def setUp(self):
        # setUp runs before every test, so configure() calls never leak between tests.
        self.logger = Mock(spec=logging.Logger)
        self.meteor = METEORScorer(self.logger)

    def test_meteor_instantiation(self):
        self.assertEqual(self.meteor.get_name(), "METEOR")
        self.assertIn("meteor", self.meteor.get_description().lower())
        self.assertEqual(self.meteor.get_metric_type(), "alignment")

    def test_meteor_calculation(self):
        ref = "The patient has clear lung fields."
        cand = "Patient shows clear lungs."

        result = self.meteor.calculate(ref, cand)

        self.assertIn("meteor", result)
        # Separate bound checks report the actual value on failure.
        score = result["meteor"]
        self.assertGreaterEqual(score, 0.0)
        self.assertLessEqual(score, 1.0)

    def test_identical_texts_meteor(self):
        text = "The patient has pneumonia."
        result = self.meteor.calculate(text, text)

        self.assertAlmostEqual(result["meteor"], 1.0, places=2)

    def test_empty_text_meteor(self):
        with self.assertRaises(ValueError):
            self.meteor.calculate("", "some text")

        with self.assertRaises(ValueError):
            self.meteor.calculate("some text", "")

    def test_stemming_functionality(self):
        # Inflected forms ("walking"/"walks", "slowly"/"slow"): enabling
        # stemming must not lower the score relative to exact matching.
        ref = "The patient is walking slowly."
        cand = "The patient walks slow."

        config = {"use_stemming": True}
        self.meteor.configure(config)

        result_with_stem = self.meteor.calculate(ref, cand)

        config = {"use_stemming": False}
        self.meteor.configure(config)

        result_without_stem = self.meteor.calculate(ref, cand)

        self.assertGreaterEqual(result_with_stem["meteor"], result_without_stem["meteor"])

    def test_medical_synonyms(self):
        ref = "The patient has pneumonia."
        cand = "The patient has lung infection."

        config = {"use_medical_synonyms": True}
        self.meteor.configure(config)

        result = self.meteor.calculate(ref, cand)
        self.assertGreater(result["meteor"], 0.0)

    # NOTE(review): patch() replaces the attribute on the nltk module, so it only
    # intercepts calls the scorer makes through that module attribute. If
    # metrics.meteor_scorer does `from nltk.translate.meteor_score import
    # meteor_score`, it holds its own reference and this patch is bypassed; the
    # target would then be 'metrics.meteor_scorer.meteor_score'. Confirm against
    # the scorer's imports.
    @patch('nltk.translate.meteor_score.meteor_score')
    def test_meteor_with_mock(self, mock_meteor):
        mock_meteor.return_value = 0.75

        ref = "Test reference"
        cand = "Test candidate"

        result = self.meteor.calculate(ref, cand)
        self.assertEqual(result["meteor"], 0.75)

    def test_meteor_error_handling(self):
        # Symbol-only tokens must still yield a float score rather than raise.
        ref = "Patient has @@@ symptoms."
        cand = "Patient shows ### signs."

        result = self.meteor.calculate(ref, cand)
        self.assertIsInstance(result["meteor"], float)


class TestBERTScoreScorer(unittest.TestCase):
    """Unit tests for :class:`BERTScoreScorer` (embedding-based semantic metric).

    NOTE(review): ``@patch('bert_score.score')`` only intercepts calls made via
    the ``bert_score`` module attribute. If ``metrics.bertscore_scorer`` does
    ``from bert_score import score``, it keeps its own reference and the patch
    is bypassed; the target would then be ``'metrics.bertscore_scorer.score'``.
    The mocks also return plain lists where the real function returns tensors,
    which assumes the scorer does not call tensor-only methods on them.
    TODO: confirm both against the scorer implementation.
    """

    def setUp(self):
        self.logger = Mock(spec=logging.Logger)
        self.bertscore = BERTScoreScorer(self.logger)

    def test_bertscore_instantiation(self):
        self.assertEqual(self.bertscore.get_name(), "BERTScore")
        self.assertIn("bert", self.bertscore.get_description().lower())
        self.assertEqual(self.bertscore.get_metric_type(), "semantic")

    @patch('bert_score.score')
    def test_bertscore_calculation(self, mock_bert_score):
        # Mocked (precision, recall, f1), one entry per candidate.
        mock_bert_score.return_value = (
            [0.85],
            [0.80],
            [0.82]
        )

        ref = "The patient has clear lung fields."
        cand = "Patient shows clear lungs."

        result = self.bertscore.calculate(ref, cand)

        self.assertIn("bert_precision", result)
        self.assertIn("bert_recall", result)
        self.assertIn("bert_f1", result)

        self.assertEqual(result["bert_precision"], 0.85)
        self.assertEqual(result["bert_recall"], 0.80)
        self.assertEqual(result["bert_f1"], 0.82)

    @patch('bert_score.score')
    def test_bertscore_identical_texts(self, mock_bert_score):
        mock_bert_score.return_value = ([1.0], [1.0], [1.0])

        text = "The patient has pneumonia."
        result = self.bertscore.calculate(text, text)

        self.assertEqual(result["bert_f1"], 1.0)

    def test_bertscore_fallback(self):
        # No mock here: presumably exercises whatever path the scorer takes when
        # the real model is unavailable (name suggests a fallback) - confirm.
        ref = "The patient has pneumonia."
        cand = "Patient shows lung infection."

        result = self.bertscore.calculate(ref, cand)

        self.assertIn("bert_precision", result)
        self.assertIn("bert_recall", result)
        self.assertIn("bert_f1", result)

        # Separate bound checks name the offending key and value on failure.
        for key, score in result.items():
            self.assertGreaterEqual(score, 0.0, msg=f"{key} below 0: {score!r}")
            self.assertLessEqual(score, 1.0, msg=f"{key} above 1: {score!r}")

    def test_bertscore_empty_text(self):
        with self.assertRaises(ValueError):
            self.bertscore.calculate("", "some text")

        with self.assertRaises(ValueError):
            self.bertscore.calculate("some text", "")

    @patch('bert_score.score')
    def test_bertscore_model_configuration(self, mock_bert_score):
        mock_bert_score.return_value = ([0.75], [0.75], [0.75])

        config = {"model_type": "distilbert-base-uncased"}
        self.bertscore.configure(config)

        ref = "Patient has pneumonia."
        cand = "Patient shows infection."

        result = self.bertscore.calculate(ref, cand)

        mock_bert_score.assert_called_once()
        # Without this check the test passes even if configure() ignores
        # "model_type". Require the configured model to reach the call, whether
        # it is passed positionally or by keyword.
        call_args, call_kwargs = mock_bert_score.call_args
        self.assertIn(
            config["model_type"],
            list(call_args) + list(call_kwargs.values()),
            msg="configured model_type was not passed to bert_score.score",
        )
        self.assertEqual(result["bert_f1"], 0.75)


class TestCIDErScorer(unittest.TestCase):
    """Unit tests for :class:`CIDErScorer` (consensus metric)."""

    def setUp(self):
        self.logger = Mock(spec=logging.Logger)
        self.cider = CIDErScorer(self.logger)

    def _assert_float_cider(self, scores):
        """Check that *scores* holds a float under the ``"cider"`` key."""
        self.assertIn("cider", scores)
        self.assertIsInstance(scores["cider"], float)

    def test_cider_instantiation(self):
        scorer = self.cider
        self.assertEqual(scorer.get_name(), "CIDEr")
        self.assertIn("cider", scorer.get_description().lower())
        self.assertEqual(scorer.get_metric_type(), "consensus")

    def test_cider_calculation(self):
        scores = self.cider.calculate(
            "The patient has clear lung fields.",
            "Patient shows clear lungs.",
        )
        self._assert_float_cider(scores)

    def test_cider_identical_texts(self):
        sentence = "The patient has pneumonia."
        scores = self.cider.calculate(sentence, sentence)
        self.assertGreater(scores["cider"], 0.5)

    def test_multiple_references(self):
        # A list of references is accepted in place of a single string.
        references = [
            "The patient has pneumonia.",
            "Patient shows lung infection.",
            "Pneumonia detected in patient.",
        ]
        scores = self.cider.calculate(references, "Patient has lung infection.")
        self._assert_float_cider(scores)

    def test_empty_text_cider(self):
        for reference, candidate in (("", "some text"), ("some text", "")):
            with self.assertRaises(ValueError):
                self.cider.calculate(reference, candidate)

    def test_tf_idf_weighting(self):
        # Repeated "the" in the reference exercises term-frequency weighting.
        scores = self.cider.calculate(
            "the the the patient has pneumonia",
            "the patient shows pneumonia",
        )
        self._assert_float_cider(scores)

    def test_cider_no_overlap(self):
        scores = self.cider.calculate(
            "pneumonia consolidation infiltrate",
            "elephant monkey zebra",
        )
        self.assertLessEqual(scores["cider"], 0.1)


class TestMedicalScorer(unittest.TestCase):
    """Unit tests for :class:`MedicalScorer` (domain-specific clinical metric)."""

    def setUp(self):
        self.logger = Mock(spec=logging.Logger)
        self.medical = MedicalScorer(self.logger)

    def _score(self, reference, candidate):
        """Score *candidate* against *reference* with the scorer under test."""
        return self.medical.calculate(reference, candidate)

    def test_medical_instantiation(self):
        scorer = self.medical
        self.assertEqual(scorer.get_name(), "Medical")
        self.assertIn("medical", scorer.get_description().lower())
        self.assertEqual(scorer.get_metric_type(), "domain-specific")

    def test_medical_terminology_matching(self):
        scores = self._score(
            "Patient has pneumonia with bilateral infiltrates.",
            "Patient shows pneumonia and bilateral lung infiltrates.",
        )

        for key in ("medical_terminology", "clinical_accuracy", "overall_medical"):
            self.assertIn(key, scores)

        self.assertGreater(scores["medical_terminology"], 0.7)

    def test_negation_detection(self):
        # A negated finding versus an asserted one must be penalized.
        scores = self._score("No evidence of pneumonia.", "Patient has pneumonia.")
        self.assertLess(scores["clinical_accuracy"], 0.5)

    def test_anatomical_terms(self):
        # Left/right laterality swap.
        scores = self._score("Left upper lobe consolidation.", "Right upper lobe consolidation.")
        self.assertLess(scores["clinical_accuracy"], 0.8)

    def test_severity_assessment(self):
        # Severity mismatch on the same finding.
        scores = self._score("Mild pneumonia.", "Severe pneumonia.")
        self.assertLess(scores["clinical_accuracy"], 0.7)

    def test_abbreviation_normalization(self):
        scores = self._score(
            "Patient has CHF and COPD.",
            "Patient has congestive heart failure and chronic obstructive pulmonary disease.",
        )
        self.assertGreater(scores["medical_terminology"], 0.8)

    def test_empty_text_medical(self):
        for reference, candidate in (("", "some text"), ("some text", "")):
            with self.assertRaises(ValueError):
                self._score(reference, candidate)

    def test_non_medical_text(self):
        scores = self._score("The cat sat on the mat.", "A cat is sitting on a mat.")
        self.assertLess(scores["medical_terminology"], 0.3)

    def test_clinical_concept_extraction(self):
        # Lay phrasing in the reference versus clinical terms in the candidate.
        scores = self._score(
            "Patient presents with chest pain, shortness of breath, and fever.",
            "Patient has chest pain, dyspnea, and pyrexia.",
        )
        self.assertGreater(scores["medical_terminology"], 0.7)


class TestMetricIntegration(unittest.TestCase):
    """Cross-metric checks run against every scorer on shared inputs."""

    def setUp(self):
        # One instance of each scorer, all sharing a spec'd mock logger.
        self.logger = Mock(spec=logging.Logger)
        self.metrics = {
            'bleu': BLEUScorer(self.logger),
            'rouge': ROUGEScorer(self.logger),
            'meteor': METEORScorer(self.logger),
            'bertscore': BERTScoreScorer(self.logger),
            'cider': CIDErScorer(self.logger),
            'medical': MedicalScorer(self.logger)
        }

    def test_all_metrics_basic_functionality(self):
        # Every metric must return a non-empty dict whose values are all numeric.
        ref = "The patient has pneumonia with bilateral lung infiltrates."
        cand = "Patient shows pneumonia and bilateral lung infiltration."

        for name, metric in self.metrics.items():
            # Guard only the scorer call. A crash inside a scorer is reported
            # as a failure naming the metric, and the original exception is
            # kept as context. The assertions below stay outside the try, so
            # their own messages are not caught and re-wrapped.
            try:
                result = metric.calculate(ref, cand)
            except Exception as e:
                self.fail(f"Metric {name} failed: {e}")

            self.assertIsInstance(result, dict, msg=f"Metric {name}")
            self.assertGreater(len(result), 0, msg=f"Metric {name} returned no scores")

            for key, score in result.items():
                self.assertIsInstance(score, (int, float), msg=f"Metric {name}.{key}")

    def test_metric_consistency(self):
        # Repeated calls on the same input must return the same score keys and,
        # to 3 decimal places, the same values.
        ref = "Patient has clear lung fields."
        cand = "Patient shows clear lungs."

        for name, metric in self.metrics.items():
            results = [metric.calculate(ref, cand) for _ in range(3)]

            first_result = results[0]
            for result in results[1:]:
                # Compare key sets first, so a key that appears or disappears
                # on a later run is reported rather than silently skipped.
                self.assertEqual(
                    set(first_result),
                    set(result),
                    msg=f"Inconsistent score keys for {name}",
                )
                for key in first_result:
                    self.assertAlmostEqual(
                        first_result[key],
                        result[key],
                        places=3,
                        msg=f"Inconsistent results for {name}.{key}"
                    )

    def test_metric_performance_tracking(self):
        import time

        ref = "The patient presents with acute chest pain, shortness of breath, and diaphoresis."
        cand = "Patient has acute chest pain, dyspnea, and sweating."
        runs = 10

        performance = {}

        for name, metric in self.metrics.items():
            # perf_counter is monotonic and high-resolution. time.time() follows
            # the adjustable wall clock and is unsuitable for measuring intervals.
            start_time = time.perf_counter()

            for _ in range(runs):
                metric.calculate(ref, cand)

            avg_time = (time.perf_counter() - start_time) / runs

            performance[name] = avg_time

            self.assertLess(avg_time, 1.0, f"Metric {name} too slow: {avg_time:.3f}s")

        print("\n=== Metric Performance ===")
        for name, time_taken in sorted(performance.items(), key=lambda x: x[1]):
            print(f"{name:12s}: {time_taken:.4f}s per evaluation")


def run_comprehensive_metric_tests():
    """Run every metric test class with a verbose runner and print a summary.

    Returns:
        bool: ``result.wasSuccessful()`` for the combined run.
    """
    loader = unittest.TestLoader()
    suite = unittest.TestSuite()

    test_classes = [
        TestBLEUScorer,
        TestROUGEScorer,
        TestMETEORScorer,
        TestBERTScoreScorer,
        TestCIDErScorer,
        TestMedicalScorer,
        TestMetricIntegration
    ]

    for test_class in test_classes:
        suite.addTests(loader.loadTestsFromTestCase(test_class))

    runner = unittest.TextTestRunner(verbosity=2)
    result = runner.run(suite)

    total_tests = result.testsRun
    failures = len(result.failures)
    errors = len(result.errors)
    skipped = len(getattr(result, "skipped", ()))
    passed = total_tests - failures - errors - skipped

    # Skipped tests neither pass nor fail, so the rate is computed over the
    # tests that actually executed. This keeps it consistent with "Passed".
    # The guard avoids ZeroDivisionError when nothing ran (or all were skipped).
    executed = total_tests - skipped
    success_rate = passed / executed * 100 if executed else 0.0

    print("\n=== Test Summary ===")
    print(f"Total Tests: {total_tests}")
    print(f"Passed: {passed}")
    print(f"Failed: {failures}")
    print(f"Errors: {errors}")
    print(f"Skipped: {skipped}")
    print(f"Success Rate: {success_rate:.1f}%")

    return result.wasSuccessful()


if __name__ == "__main__":
    # Suppress all warnings so the verbose runner output stays readable.
    warnings.filterwarnings("ignore")

    success = run_comprehensive_metric_tests()

    if success:
        print("\nAll comprehensive metric tests passed!")
    else:
        print("\nSome tests failed!")

    # sys.exit instead of the site-provided exit(): exit() is meant for the
    # interactive interpreter and does not exist under `python -S`.
    sys.exit(0 if success else 1)