import unittest
import sys
import os
import logging
import time
from unittest.mock import Mock, patch

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from metrics.base_metric import BaseMetric, DummyMetric


class TestBaseMetric(unittest.TestCase):
    """Tests for the shared BaseMetric behaviour, exercised through DummyMetric."""

    def setUp(self):
        # spec= makes the mock reject attributes logging.Logger does not have,
        # so a misspelled logger call inside the metric fails loudly.
        self.logger = Mock(spec=logging.Logger)

    def test_abstract_base_class_instantiation(self):
        """Constructing BaseMetric directly raises TypeError."""
        with self.assertRaises(TypeError):
            BaseMetric("test_metric", self.logger)

    def test_dummy_metric_instantiation(self):
        """DummyMetric is a BaseMetric with the expected name and description."""
        dummy = DummyMetric(self.logger)
        self.assertIsInstance(dummy, BaseMetric)
        self.assertEqual(dummy.get_name(), "Dummy Metric")
        self.assertIn("dummy metric", dummy.get_description().lower())

    def test_dummy_metric_calculation(self):
        """calculate() returns a dict whose float 'dummy_score' lies in [0, 1]."""
        dummy = DummyMetric(self.logger)

        result = dummy.calculate("reference text", "candidate text")
        self.assertIsInstance(result, dict)
        self.assertIn("dummy_score", result)
        score = result["dummy_score"]
        self.assertIsInstance(score, float)
        # Separate bound checks report the offending value on failure.
        self.assertGreaterEqual(score, 0.0)
        self.assertLessEqual(score, 1.0)

    def test_performance_tracking(self):
        """calculate_with_timing() updates the call counter and timing fields."""
        dummy = DummyMetric(self.logger)

        self.assertEqual(dummy.calculation_count, 0)
        self.assertEqual(dummy.total_calculation_time, 0.0)

        _, calc_time = dummy.calculate_with_timing("ref", "cand")

        self.assertEqual(dummy.calculation_count, 1)
        self.assertGreaterEqual(dummy.total_calculation_time, 0.0)
        self.assertGreaterEqual(calc_time, 0.0)
        self.assertEqual(dummy.last_calculation_time, calc_time)

        dummy.calculate_with_timing("ref2", "cand2")
        self.assertEqual(dummy.calculation_count, 2)

    def test_batch_processing(self):
        """calculate_batch() returns one result dict per reference/candidate pair."""
        dummy = DummyMetric(self.logger)

        references = ["ref1", "ref2", "ref3"]
        candidates = ["cand1", "cand2", "cand3"]

        results = dummy.calculate_batch(references, candidates)

        self.assertEqual(len(results), 3)
        for result in results:
            self.assertIsInstance(result, dict)
            self.assertIn("dummy_score", result)

    def test_batch_processing_length_mismatch(self):
        """calculate_batch() raises ValueError when the input lists differ in length."""
        dummy = DummyMetric(self.logger)

        references = ["ref1", "ref2"]
        candidates = ["cand1", "cand2", "cand3"]

        with self.assertRaises(ValueError):
            dummy.calculate_batch(references, candidates)

    def test_input_validation(self):
        """validate_inputs() accepts non-empty strings and rejects empty or None."""
        dummy = DummyMetric(self.logger)

        is_valid, errors = dummy.validate_inputs("valid ref", "valid cand")
        self.assertTrue(is_valid)
        self.assertEqual(len(errors), 0)

        invalid_pairs = [
            ("", "valid cand"),
            ("valid ref", ""),
            (None, "valid cand"),
        ]
        for ref, cand in invalid_pairs:
            # subTest reports every rejected pair separately instead of
            # stopping at the first failing one.
            with self.subTest(ref=ref, cand=cand):
                is_valid, errors = dummy.validate_inputs(ref, cand)
                self.assertFalse(is_valid)
                self.assertGreater(len(errors), 0)

    def test_result_formatting(self):
        """format_results() wraps scores with metadata and honours `precision`."""
        dummy = DummyMetric(self.logger)

        scores = {"score1": 0.123456789, "score2": 0.987654321}

        formatted = dummy.format_results(scores)
        self.assertIn("scores", formatted)
        self.assertIn("metadata", formatted)

        formatted = dummy.format_results(scores, precision=2)
        for key, score in formatted["scores"].items():
            if isinstance(score, float):
                # Compare numerically instead of counting characters of
                # str(score). The old check allowed 3 decimals for precision=2
                # and miscounted values printed in scientific notation ('1e-05').
                self.assertEqual(
                    score, round(score, 2),
                    f"{key}={score!r} is not rounded to 2 decimal places",
                )

    def test_configuration_handling(self):
        """configure() merges the given keys into the metric's config dict."""
        dummy = DummyMetric(self.logger)

        self.assertIsInstance(dummy.config, dict)

        test_config = {"param1": "value1", "param2": 42}
        dummy.configure(test_config)

        self.assertEqual(dummy.config["param1"], "value1")
        self.assertEqual(dummy.config["param2"], 42)

    def test_initialization_flag(self):
        """is_initialized is set by initialize() or by the first timed calculation."""
        dummy = DummyMetric(self.logger)

        self.assertFalse(dummy.is_initialized)

        dummy.initialize()
        self.assertTrue(dummy.is_initialized)

        dummy2 = DummyMetric(self.logger)
        self.assertFalse(dummy2.is_initialized)
        dummy2.calculate_with_timing("ref", "cand")
        self.assertTrue(dummy2.is_initialized)

    def test_performance_stats(self):
        """get_performance_stats() summarises timings; reset clears the counters."""
        dummy = DummyMetric(self.logger)

        dummy.calculate_with_timing("ref1", "cand1")
        dummy.calculate_with_timing("ref2", "cand2")

        stats = dummy.get_performance_stats()

        for key in ("calculation_count", "total_time", "average_time",
                    "last_calculation_time"):
            self.assertIn(key, stats)

        self.assertEqual(stats["calculation_count"], 2)
        self.assertGreaterEqual(stats["total_time"], 0.0)
        self.assertGreaterEqual(stats["average_time"], 0.0)

        dummy.reset_performance_stats()
        self.assertEqual(dummy.calculation_count, 0)
        self.assertEqual(dummy.total_calculation_time, 0.0)

    def test_string_representations(self):
        """str() mentions the metric and repr() names the concrete class."""
        dummy = DummyMetric(self.logger)

        # Case-insensitive, matching the description check above; get_name()
        # is "Dummy Metric" with a capital D.
        self.assertIn("dummy", str(dummy).lower())

        self.assertIn("DummyMetric", repr(dummy))

    def test_metric_type_and_version(self):
        """DummyMetric reports version "1.0" and metric type "test"."""
        dummy = DummyMetric(self.logger)

        self.assertEqual(dummy.get_version(), "1.0")
        self.assertEqual(dummy.get_metric_type(), "test")

    def test_error_handling_in_calculation(self):
        """calculate_with_timing() raises when either input is None."""
        dummy = DummyMetric(self.logger)

        # NOTE(review): the concrete exception type is not pinned here;
        # tighten to the specific class once the metric's contract is fixed.
        with self.assertRaises(Exception):
            dummy.calculate_with_timing(None, "valid")

        with self.assertRaises(Exception):
            dummy.calculate_with_timing("valid", None)


class TestDummyMetricSpecific(unittest.TestCase):
    """Tests targeting DummyMetric's own scoring behaviour and extra outputs."""

    def setUp(self):
        # spec= keeps the mock faithful to the logging.Logger interface.
        self.logger = Mock(spec=logging.Logger)
        self.dummy = DummyMetric(self.logger)

    def test_dummy_score_range(self):
        """dummy_score stays within [0, 1] for every input pair calculate() accepts."""
        test_cases = [
            ("identical text", "identical text"),
            ("completely different", "totally unrelated content"),
            ("", ""),
            ("short", "text"),
            ("very long text with many words and complex structure", "another long text"),
        ]

        for ref, cand in test_cases:
            with self.subTest(ref=ref, cand=cand):
                # Only the calculate() call is guarded. Previously the range
                # assertion sat inside `except Exception: pass`, and since
                # AssertionError is an Exception, failures were swallowed and
                # the test could never fail.
                try:
                    result = self.dummy.calculate(ref, cand)
                except Exception:
                    # Inputs the metric rejects (possibly the empty pair) are
                    # out of scope for a range check.
                    continue
                score = result["dummy_score"]
                self.assertTrue(0.0 <= score <= 1.0,
                                f"Score {score} out of range for inputs: '{ref}', '{cand}'")

    def test_dummy_score_consistency(self):
        """Repeated calculate() calls on the same inputs give the same score."""
        ref = "test reference text"
        cand = "test candidate text"

        score1 = self.dummy.calculate(ref, cand)["dummy_score"]
        score2 = self.dummy.calculate(ref, cand)["dummy_score"]

        self.assertEqual(score1, score2, "Dummy scores should be consistent")

    def test_dummy_additional_metrics(self):
        """calculate() also reports non-negative length statistics."""
        result = self.dummy.calculate("reference", "candidate")

        for key in ("reference_length", "candidate_length", "length_difference"):
            self.assertIn(key, result)

        self.assertGreaterEqual(result["reference_length"], 0.0)
        self.assertGreaterEqual(result["candidate_length"], 0.0)
        self.assertGreaterEqual(result["length_difference"], 0)


# Runs comprehensive tests for BaseMetric and DummyMetric implementations
def run_base_metric_tests():
    print("Running Base Metric Tests...")
    print("=" * 50)
    
    loader = unittest.TestLoader()
    suite = unittest.TestSuite()
    
    suite.addTests(loader.loadTestsFromTestCase(TestBaseMetric))
    suite.addTests(loader.loadTestsFromTestCase(TestDummyMetricSpecific))
    
    runner = unittest.TextTestRunner(verbosity=2)
    result = runner.run(suite)
    
    print("\n" + "=" * 50)
    print(f"Base Metric Tests Summary:")
    print(f"   Tests run: {result.testsRun}")
    print(f"   Failures: {len(result.failures)}")
    print(f"   Errors: {len(result.errors)}")
    
    if result.failures:
        print("\nFailures:")
        for test, traceback in result.failures:
            print(f"   - {test}: {traceback.split('AssertionError:')[-1].strip()}")
    
    if result.errors:
        print("\nErrors:")
        for test, traceback in result.errors:
            print(f"   - {test}: {traceback.split('Exception:')[-1].strip()}")
    
    success = len(result.failures) == 0 and len(result.errors) == 0
    
    if success:
        print("\nAll base metric tests passed!")
    else:
        print(f"\n{len(result.failures + result.errors)} tests failed!")
    
    return success, result


if __name__ == "__main__":
    # sys.exit rather than the site-provided exit(), which is meant for the
    # interactive interpreter and is missing when Python runs with -S.
    success, _ = run_base_metric_tests()
    sys.exit(0 if success else 1)