from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Union, Any, Tuple
import logging
import time
from datetime import datetime


class BaseMetric(ABC):
    """Abstract base class for metrics comparing a reference and a candidate text.

    Subclasses must implement :meth:`calculate`, :meth:`get_name` and
    :meth:`get_description`. The base class supplies configuration storage,
    lazy initialization, timing/performance bookkeeping, batch evaluation,
    input validation and result formatting.
    """

    def __init__(self, name: str, logger: Optional[logging.Logger] = None):
        """Create the metric.

        Args:
            name: Identifier used in log messages and formatted results.
            logger: Logger to use; defaults to ``logging.getLogger(__name__)``.
        """
        self.name = name
        self.logger = logger or logging.getLogger(__name__)
        self.version = "1.0"
        self.description = ""
        self.metric_type = "similarity"

        # Performance bookkeeping; updated only by calculate_with_timing().
        self.calculation_count = 0
        self.total_calculation_time = 0.0
        self.last_calculation_time = 0.0

        self.config: Dict[str, Any] = {}
        self.is_initialized = False

        self.logger.debug("Initialized %s metric", self.name)

    @abstractmethod
    def calculate(self, reference: str, candidate: str, **kwargs) -> Dict[str, float]:
        """Calculate metric scores between ``reference`` and ``candidate`` texts."""

    @abstractmethod
    def get_name(self) -> str:
        """Return the human-readable metric name."""

    @abstractmethod
    def get_description(self) -> str:
        """Return a description of the metric."""

    def get_version(self) -> str:
        """Return the metric version string."""
        return self.version

    def get_metric_type(self) -> str:
        """Return the metric type category (e.g. ``"similarity"``)."""
        return self.metric_type

    def configure(self, config: Dict[str, Any]) -> None:
        """Merge ``config`` into the metric's configuration (later keys win)."""
        self.config.update(config)
        self.logger.debug("Configured %s with parameters: %s", self.name, list(config.keys()))

    def initialize(self) -> None:
        """Prepare the metric for use; subclasses may override to load resources.

        The base implementation only marks the metric as initialized.
        """
        self.is_initialized = True
        self.logger.debug("%s metric initialized", self.name)

    def calculate_with_timing(self, reference: str, candidate: str,
                              **kwargs) -> Tuple[Dict[str, float], float]:
        """Run :meth:`calculate` and record how long it took.

        Initializes the metric first if needed. On success, updates
        ``calculation_count``, ``total_calculation_time`` and
        ``last_calculation_time``.

        Returns:
            ``(scores, elapsed_seconds)``.

        Raises:
            Any exception raised by :meth:`calculate`, after logging it; the
            performance statistics are not updated in that case.
        """
        if not self.is_initialized:
            self.initialize()

        # perf_counter is monotonic and high-resolution; time.time() follows the
        # wall clock, which can jump, and is not meant for measuring durations.
        start_time = time.perf_counter()
        try:
            scores = self.calculate(reference, candidate, **kwargs)
        except Exception as e:
            elapsed = time.perf_counter() - start_time
            self.logger.error("Error calculating %s after %.4fs: %s", self.name, elapsed, e)
            raise
        calculation_time = time.perf_counter() - start_time

        self.calculation_count += 1
        self.total_calculation_time += calculation_time
        self.last_calculation_time = calculation_time

        self.logger.debug("%s calculation completed in %.4fs", self.name, calculation_time)
        return scores, calculation_time

    def calculate_batch(self, reference_list: List[str], candidate_list: List[str],
                        **kwargs) -> List[Dict[str, float]]:
        """Calculate the metric for each ``(reference, candidate)`` pair.

        Pairs that raise are logged and yield ``{}`` so the result stays
        index-aligned with the inputs. Performance statistics are not updated.

        Raises:
            ValueError: if the two lists have different lengths.
        """
        if len(reference_list) != len(candidate_list):
            raise ValueError(f"Reference and candidate lists must have same length: "
                             f"{len(reference_list)} vs {len(candidate_list)}")

        # Match calculate_with_timing(): subclasses may need initialize() to run first.
        if not self.is_initialized:
            self.initialize()

        total = len(reference_list)
        self.logger.info("Calculating %s for batch of %d pairs", self.name, total)

        batch_scores: List[Dict[str, float]] = []
        batch_start_time = time.perf_counter()

        for i, (ref, cand) in enumerate(zip(reference_list, candidate_list)):
            try:
                scores = self.calculate(ref, cand, **kwargs)
            except Exception as e:
                # Best effort: one bad pair must not abort the whole batch.
                self.logger.error("Error calculating %s for pair %d: %s", self.name, i, e)
                scores = {}
            batch_scores.append(scores)

            if (i + 1) % 100 == 0:
                self.logger.debug("Processed %d/%d pairs", i + 1, total)

        batch_time = time.perf_counter() - batch_start_time
        # Guard the per-pair average: an empty batch would otherwise divide by zero.
        per_pair = batch_time / total if total else 0.0
        self.logger.info("Batch calculation completed in %.2fs (%.4fs per pair)",
                         batch_time, per_pair)

        return batch_scores

    def format_results(self, scores: Dict[str, float], precision: int = 4) -> Dict[str, Any]:
        """Wrap ``scores`` with metric metadata for output.

        Float values are rounded to ``precision`` digits; other values pass
        through unchanged. ``calculation_time`` is the last timed calculation.
        """
        formatted_scores = {
            key: round(value, precision) if isinstance(value, float) else value
            for key, value in scores.items()
        }

        return {
            "metric_name": self.name,
            "metric_type": self.metric_type,
            "scores": formatted_scores,
            "metadata": {
                "version": self.version,
                "calculation_time": self.last_calculation_time,
                "timestamp": datetime.now().isoformat()
            }
        }

    def validate_inputs(self, reference: str, candidate: str) -> Tuple[bool, List[str]]:
        """Check that both texts are non-blank strings of at most 50,000 characters.

        Returns:
            ``(is_valid, issues)``, where ``issues`` lists every problem found.
        """
        issues = []

        if reference is None:
            issues.append("Reference text is None")
        elif not isinstance(reference, str):
            issues.append(f"Reference text must be string, got {type(reference)}")
        elif not reference.strip():
            issues.append("Reference text is empty")

        if candidate is None:
            issues.append("Candidate text is None")
        elif not isinstance(candidate, str):
            issues.append(f"Candidate text must be string, got {type(candidate)}")
        elif not candidate.strip():
            issues.append("Candidate text is empty")

        if isinstance(reference, str) and len(reference) > 50000:
            issues.append("Reference text is extremely long (>50k chars)")

        if isinstance(candidate, str) and len(candidate) > 50000:
            issues.append("Candidate text is extremely long (>50k chars)")

        return not issues, issues

    def get_performance_stats(self) -> Dict[str, Any]:
        """Return timing statistics gathered by :meth:`calculate_with_timing`."""
        avg_time = (self.total_calculation_time / self.calculation_count
                    if self.calculation_count > 0 else 0.0)

        return {
            "metric_name": self.name,
            "calculation_count": self.calculation_count,
            "total_time": self.total_calculation_time,
            "average_time": avg_time,
            "last_calculation_time": self.last_calculation_time,
            "is_initialized": self.is_initialized
        }

    def reset_performance_stats(self) -> None:
        """Zero the counters reported by :meth:`get_performance_stats`."""
        self.calculation_count = 0
        self.total_calculation_time = 0.0
        self.last_calculation_time = 0.0
        self.logger.debug("Reset performance stats for %s", self.name)

    def __str__(self) -> str:
        """Short form: ``name (vVERSION) - type``."""
        return f"{self.name} (v{self.version}) - {self.metric_type}"

    def __repr__(self) -> str:
        """Debug form including class name and initialization state."""
        return (f"{self.__class__.__name__}(name='{self.name}', "
                f"type='{self.metric_type}', version='{self.version}', "
                f"initialized={self.is_initialized})")


class DummyMetric(BaseMetric):
    """Trivial word-count metric used to exercise BaseMetric functionality."""

    def __init__(self, logger: Optional[logging.Logger] = None):
        """Create the metric named ``"dummy"`` with type ``"test"``."""
        super().__init__("dummy", logger)
        self.description = "Dummy metric for testing base class functionality"
        self.metric_type = "test"

    def calculate(self, reference: str, candidate: str, **kwargs) -> Dict[str, float]:
        """Score two texts by the ratio of their whitespace-separated word counts.

        Returns:
            ``dummy_score`` (shorter / longer word count, in [0, 1]),
            ``reference_length``, ``candidate_length`` and
            ``length_difference`` (absolute word-count difference), all floats.

        Raises:
            ValueError: if ``validate_inputs`` reports any issue.
        """
        is_valid, issues = self.validate_inputs(reference, candidate)
        if not is_valid:
            raise ValueError(f"Invalid inputs: {issues}")

        ref_len = len(reference.split())
        cand_len = len(candidate.split())
        longest = max(ref_len, cand_len)

        # BaseMetric.validate_inputs rejects blank texts, so longest >= 1 here.
        # The guard only matters if a subclass relaxes validation.
        length_ratio = min(ref_len, cand_len) / longest if longest > 0 else 0.0

        return {
            "dummy_score": length_ratio,
            "reference_length": float(ref_len),
            "candidate_length": float(cand_len),
            # float() keeps every value consistent with the Dict[str, float] contract.
            "length_difference": float(abs(ref_len - cand_len))
        }

    def get_name(self) -> str:
        """Return the display name of this metric."""
        return "Dummy Metric"

    def get_description(self) -> str:
        """Return the description set in ``__init__``."""
        return self.description


# Smoke test for BaseMetric, driven through DummyMetric.
def test_base_metric():
    """Exercise the BaseMetric API via DummyMetric, printing results and checking them.

    Returns:
        True if every step ran and every check held. Otherwise False; the
        failure and its traceback are printed.
    """
    import traceback

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger("test")

    try:
        print("Testing Base Metric Class...")

        dummy = DummyMetric(logger)

        print(f"Metric name: {dummy.get_name()}")
        print(f"Metric description: {dummy.get_description()}")
        print(f"Metric type: {dummy.get_metric_type()}")
        print(f"Metric version: {dummy.get_version()}")

        dummy.configure({"test_param": 1.0, "another_param": "test"})
        print(f"Configuration: {dummy.config}")
        assert dummy.config == {"test_param": 1.0, "another_param": "test"}, \
            "configure() did not store the parameters"

        ref_text = "Normal chest radiograph with clear lung fields."
        cand_text = "Chest X-ray shows normal lungs and heart."

        scores, calc_time = dummy.calculate_with_timing(ref_text, cand_text)
        print(f"Single calculation: {scores}")
        print(f"Calculation time: {calc_time:.4f}s")
        assert 0.0 <= scores["dummy_score"] <= 1.0, "dummy_score outside [0, 1]"
        assert calc_time >= 0.0, "negative calculation time"

        formatted = dummy.format_results(scores)
        print(f"Formatted results: {formatted['scores']}")
        assert formatted["metric_name"] == dummy.name, "format_results lost the metric name"

        is_valid, issues = dummy.validate_inputs(ref_text, cand_text)
        print(f"Input validation: valid={is_valid}, issues={issues}")
        assert is_valid and not issues, "valid inputs were rejected"

        is_valid, issues = dummy.validate_inputs("", None)
        print(f"Invalid input test: valid={is_valid}, issues={len(issues)} found")
        assert not is_valid and len(issues) == 2, "invalid inputs were not fully reported"

        ref_list = [
            "Normal chest radiograph.",
            "Clear lung fields bilaterally.",
            "Heart size within normal limits."
        ]
        cand_list = [
            "Normal chest X-ray.",
            "Lungs are clear on both sides.",
            "Normal heart size."
        ]

        batch_scores = dummy.calculate_batch(ref_list, cand_list)
        print(f"Batch calculation: {len(batch_scores)} results")
        assert len(batch_scores) == len(ref_list), "batch result count mismatch"
        # calculate_batch substitutes {} for pairs that raised.
        assert all(batch_scores), "a batch pair failed to score"

        perf_stats = dummy.get_performance_stats()
        print(f"Performance stats: {perf_stats['calculation_count']} calculations")
        assert perf_stats["calculation_count"] >= 1, "timed calculation was not recorded"

        print(f"String repr: {str(dummy)}")
        print(f"Detailed repr: {repr(dummy)}")

        print("\nAll base metric tests completed!")
        return True

    except Exception as e:
        # Top-level harness boundary: report the failure with its traceback
        # rather than just str(e), which is empty for a bare AssertionError.
        print(f"Test failed: {e}")
        traceback.print_exc()
        return False


if __name__ == "__main__":
    success = test_base_metric()

    if success:
        print("\nBase Metric tests passed!")
    else:
        print("\nSome tests failed!")

    # Report the outcome via the exit status so shell/CI callers can detect failures.
    raise SystemExit(0 if success else 1)