import re
from collections import Counter
from typing import Dict, List, Optional, Tuple, Set
import logging

import sys
from pathlib import Path
sys.path.append(str(Path(__file__).parent))

from base_metric import BaseMetric


class ROUGEScorer(BaseMetric):
    # ROUGE metric implementation for text similarity evaluation
    
    def __init__(self, use_stemming: bool = False, remove_stopwords: bool = False,
                 logger: Optional[logging.Logger] = None):
        """Configure the ROUGE scorer.

        Args:
            use_stemming: When True, try to load NLTK's ``PorterStemmer`` and
                stem every token. If NLTK cannot be imported the option is
                switched off and a warning is logged.
            remove_stopwords: When True, tokens found in ``self.stopwords``
                are dropped during preprocessing.
            logger: Forwarded to ``BaseMetric.__init__`` together with the
                metric id ``"rouge"``.
                NOTE(review): ``self.logger`` is used below and is presumably
                set by the base class - confirm against ``BaseMetric``.
        """
        super().__init__("rouge", logger)
        self.description = "ROUGE (Recall-Oriented Understudy for Gisting Evaluation) metric"
        self.metric_type = "similarity"

        self.use_stemming = use_stemming
        self.remove_stopwords = remove_stopwords

        # Lower-case stopwords removed by _tokenize_and_preprocess when
        # remove_stopwords is enabled.
        self.stopwords = {
            'a', 'an', 'and', 'are', 'as', 'at', 'be', 'by', 'for', 'from',
            'has', 'he', 'in', 'is', 'it', 'its', 'of', 'on', 'that', 'the',
            'to', 'was', 'will', 'with', 'there', 'this', 'these', 'they',
            'have', 'had', 'been', 'being', 'do', 'does', 'did', 'can', 'could',
            'should', 'would', 'may', 'might', 'must', 'shall', 'am',
        }

        self.stemmer = None
        if self.use_stemming:
            # Only the import is guarded: a missing NLTK degrades gracefully,
            # while any other failure should surface instead of being hidden.
            try:
                from nltk.stem import PorterStemmer
            except ImportError:
                self.logger.warning("NLTK not available, stemming disabled")
                self.use_stemming = False
            else:
                self.stemmer = PorterStemmer()
                self.logger.debug("Initialized Porter stemmer for ROUGE")

        # Report the *effective* settings (stemming may have been disabled by
        # the fallback above), not merely what the caller asked for.
        self.logger.debug(
            "Initialized ROUGE scorer with stemming=%s, remove_stopwords=%s",
            self.use_stemming, self.remove_stopwords,
        )
    
    # Calculates ROUGE scores between reference and candidate texts
    def calculate(self, reference: str, candidate: str, **kwargs) -> Dict[str, float]:
        """Score ``candidate`` against ``reference`` with ROUGE-1, ROUGE-2 and ROUGE-L.

        Both texts are first checked with ``self.validate_inputs`` and then
        tokenized by ``_tokenize_and_preprocess``. ``kwargs`` is accepted but
        not used by this method.

        Returns:
            A dict holding ``rouge_{1,2,l}_{precision,recall,f1}``, the mean of
            the three F1 values under ``rouge_avg_f1``, and the length
            statistics ``reference_length``, ``candidate_length`` and
            ``length_ratio`` (candidate / reference token count, 0.0 when the
            reference has no tokens).

        Raises:
            ValueError: If ``validate_inputs`` reports the inputs as invalid.
        """
        ok, problems = self.validate_inputs(reference, candidate)
        if not ok:
            raise ValueError(f"Invalid inputs for ROUGE calculation: {problems}")

        ref_tokens = self._tokenize_and_preprocess(reference)
        cand_tokens = self._tokenize_and_preprocess(candidate)

        # One (key prefix, result) pair per ROUGE variant, in reporting order.
        variants = (
            ("rouge_1", self._calculate_rouge_n(ref_tokens, cand_tokens, n=1)),
            ("rouge_2", self._calculate_rouge_n(ref_tokens, cand_tokens, n=2)),
            ("rouge_l", self._calculate_rouge_l(ref_tokens, cand_tokens)),
        )

        scores = {
            f"{prefix}_{part}": result[part]
            for prefix, result in variants
            for part in ("precision", "recall", "f1")
        }

        f1_values = [result["f1"] for _, result in variants]
        scores["rouge_avg_f1"] = sum(f1_values) / len(f1_values)

        ref_count = len(ref_tokens)
        cand_count = len(cand_tokens)
        scores["reference_length"] = float(ref_count)
        scores["candidate_length"] = float(cand_count)
        if ref_count > 0:
            scores["length_ratio"] = cand_count / ref_count
        else:
            scores["length_ratio"] = 0.0

        return scores
    
    # Tokenizes and preprocesses text
    def _tokenize_and_preprocess(self, text: str) -> List[str]:
        text = text.lower()
        
        text = re.sub(r'([a-z])\.([a-z])', r'\1. \2', text)
        
        tokens = re.findall(r'\b\w+\b', text)
        
        if self.remove_stopwords:
            tokens = [token for token in tokens if token not in self.stopwords]
        
        if self.use_stemming and self.stemmer:
            tokens = [self.stemmer.stem(token) for token in tokens]
        
        return tokens
    
    # Extracts n-grams from a list of tokens
    def _get_ngrams(self, tokens: List[str], n: int) -> List[Tuple[str, ...]]:
        if len(tokens) < n:
            return []
        
        ngrams = []
        for i in range(len(tokens) - n + 1):
            ngram = tuple(tokens[i:i + n])
            ngrams.append(ngram)
        
        return ngrams
    
    # Calculates ROUGE-N scores
    def _calculate_rouge_n(self, ref_tokens: List[str], cand_tokens: List[str], 
                          n: int) -> Dict[str, float]:
        ref_ngrams = self._get_ngrams(ref_tokens, n)
        cand_ngrams = self._get_ngrams(cand_tokens, n)
        
        if not ref_ngrams and not cand_ngrams:
            return {"precision": 1.0, "recall": 1.0, "f1": 1.0}
        
        if not ref_ngrams:
            return {"precision": 0.0, "recall": 0.0, "f1": 0.0}
        
        if not cand_ngrams:
            return {"precision": 0.0, "recall": 0.0, "f1": 0.0}
        
        ref_counts = Counter(ref_ngrams)
        cand_counts = Counter(cand_ngrams)
        
        overlap = 0
        for ngram, count in cand_counts.items():
            if ngram in ref_counts:
                overlap += min(count, ref_counts[ngram])
        
        precision = overlap / len(cand_ngrams) if len(cand_ngrams) > 0 else 0.0
        recall = overlap / len(ref_ngrams) if len(ref_ngrams) > 0 else 0.0
        
        if precision + recall > 0:
            f1 = 2 * precision * recall / (precision + recall)
        else:
            f1 = 0.0
        
        return {
            "precision": precision,
            "recall": recall,
            "f1": f1
        }
    
    # Calculates ROUGE-L scores based on longest common subsequence
    def _calculate_rouge_l(self, ref_tokens: List[str], cand_tokens: List[str]) -> Dict[str, float]:
        if not ref_tokens and not cand_tokens:
            return {"precision": 1.0, "recall": 1.0, "f1": 1.0}
        
        if not ref_tokens or not cand_tokens:
            return {"precision": 0.0, "recall": 0.0, "f1": 0.0}
        
        lcs_length = self._lcs_length(ref_tokens, cand_tokens)
        
        precision = lcs_length / len(cand_tokens) if len(cand_tokens) > 0 else 0.0
        recall = lcs_length / len(ref_tokens) if len(ref_tokens) > 0 else 0.0
        
        if precision + recall > 0:
            f1 = 2 * precision * recall / (precision + recall)
        else:
            f1 = 0.0
        
        return {
            "precision": precision,
            "recall": recall,
            "f1": f1
        }
    
    # Calculates the length of the longest common subsequence
    def _lcs_length(self, seq1: List[str], seq2: List[str]) -> int:
        m, n = len(seq1), len(seq2)
        
        dp = [[0] * (n + 1) for _ in range(m + 1)]
        
        for i in range(1, m + 1):
            for j in range(1, n + 1):
                if seq1[i - 1] == seq2[j - 1]:
                    dp[i][j] = dp[i - 1][j - 1] + 1
                else:
                    dp[i][j] = max(dp[i - 1][j], dp[i][j - 1])
        
        return dp[m][n]
    
    # Calculates ROUGE-W (weighted longest common subsequence)
    def calculate_rouge_w(self, reference: str, candidate: str, weight: float = 1.2) -> Dict[str, float]:
        ref_tokens = self._tokenize_and_preprocess(reference)
        cand_tokens = self._tokenize_and_preprocess(candidate)
        
        if not ref_tokens or not cand_tokens:
            return {"rouge_w_precision": 0.0, "rouge_w_recall": 0.0, "rouge_w_f1": 0.0}
        
        wlcs_length = self._weighted_lcs_length(ref_tokens, cand_tokens, weight)
        
        ref_norm = self._weighted_lcs_length(ref_tokens, ref_tokens, weight)
        cand_norm = self._weighted_lcs_length(cand_tokens, cand_tokens, weight)
        
        precision = wlcs_length / cand_norm if cand_norm > 0 else 0.0
        recall = wlcs_length / ref_norm if ref_norm > 0 else 0.0
        
        if precision + recall > 0:
            f1 = 2 * precision * recall / (precision + recall)
        else:
            f1 = 0.0
        
        return {
            "rouge_w_precision": precision,
            "rouge_w_recall": recall,
            "rouge_w_f1": f1
        }
    
    # Calculates weighted LCS length with consecutive match bonus
    def _weighted_lcs_length(self, seq1: List[str], seq2: List[str], weight: float) -> float:
        m, n = len(seq1), len(seq2)
        
        dp = [[[0.0, 0] for _ in range(n + 1)] for _ in range(m + 1)]
        
        for i in range(1, m + 1):
            for j in range(1, n + 1):
                if seq1[i - 1] == seq2[j - 1]:
                    consecutive = dp[i - 1][j - 1][1] + 1
                    length = dp[i - 1][j - 1][0] + weight ** consecutive
                    dp[i][j] = [length, consecutive]
                else:
                    if dp[i - 1][j][0] > dp[i][j - 1][0]:
                        dp[i][j] = [dp[i - 1][j][0], 0]
                    else:
                        dp[i][j] = [dp[i][j - 1][0], 0]
        
        return dp[m][n][0]
    
    def get_name(self) -> str:
        return "ROUGE"
    
    def get_description(self) -> str:
        return (f"ROUGE (Recall-Oriented Understudy for Gisting Evaluation) metric "
                f"measuring n-gram and LCS overlap (stemming={self.use_stemming}, "
                f"remove_stopwords={self.remove_stopwords})")


# Tests ROUGE scorer with various medical text similarity scenarios
def test_rouge_scorer():
    """Run a print-based smoke test of ROUGEScorer on radiology-style sentences.

    Exercises single-pair scoring, ROUGE-W, stopword removal, batch scoring,
    two edge cases and the performance statistics, printing the results of
    each step. Nothing is asserted: the run counts as successful when no step
    raises.

    Returns:
        True if every step completed, False if any exception escaped (the
        message and traceback are printed).
    """
    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger("test")

    try:
        print("Testing ROUGE Scorer...")

        rouge = ROUGEScorer(use_stemming=False, remove_stopwords=False, logger=logger)

        print(f"Metric name: {rouge.get_name()}")
        print(f"Metric description: {rouge.get_description()}")

        test_cases = [
            {
                "reference": "Normal chest radiograph with clear lung fields.",
                "candidate": "Normal chest radiograph with clear lung fields.",
                "description": "Perfect match"
            },
            {
                "reference": "Normal chest radiograph with clear lung fields.",
                "candidate": "Normal chest X-ray with clear lung areas.",
                "description": "High similarity"
            },
            {
                "reference": "Normal chest radiograph with clear lung fields.",
                "candidate": "Chest X-ray shows normal findings.",
                "description": "Medium similarity"
            },
            {
                "reference": "Normal chest radiograph with clear lung fields.",
                "candidate": "Patient has pneumonia and pleural effusion.",
                "description": "Low similarity"
            },
            {
                "reference": "Normal chest radiograph with clear lung fields and normal heart size.",
                "candidate": "Normal chest radiograph.",
                "description": "Shorter candidate"
            },
            {
                "reference": "Normal chest radiograph.",
                "candidate": "Normal chest radiograph with clear lung fields and normal heart size.",
                "description": "Longer candidate"
            }
        ]

        print("\n--- ROUGE Score Tests ---")
        for i, test_case in enumerate(test_cases, 1):
            scores = rouge.calculate(test_case["reference"], test_case["candidate"])

            print(f"\nTest {i}: {test_case['description']}")
            print(f"Reference: '{test_case['reference']}'")
            print(f"Candidate: '{test_case['candidate']}'")
            print(f"ROUGE-1 F1: {scores['rouge_1_f1']:.4f}")
            print(f"ROUGE-2 F1: {scores['rouge_2_f1']:.4f}")
            print(f"ROUGE-L F1: {scores['rouge_l_f1']:.4f}")
            print(f"Average F1: {scores['rouge_avg_f1']:.4f}")
            print(f"Length ratio: {scores['length_ratio']:.4f}")

        print("\n--- ROUGE-W Test ---")
        ref_text = "Normal chest radiograph with clear lung fields."
        cand_text = "Normal chest X-ray with clear lung areas."

        rouge_w_scores = rouge.calculate_rouge_w(ref_text, cand_text)
        print(f"ROUGE-W F1: {rouge_w_scores['rouge_w_f1']:.4f}")
        print(f"ROUGE-W Precision: {rouge_w_scores['rouge_w_precision']:.4f}")
        print(f"ROUGE-W Recall: {rouge_w_scores['rouge_w_recall']:.4f}")

        print("\n--- Stopword Removal Test ---")
        rouge_no_stop = ROUGEScorer(remove_stopwords=True, logger=logger)

        ref_with_stops = "The chest radiograph shows normal lung fields and the heart is normal."
        cand_with_stops = "Chest radiograph shows normal lung fields and heart is normal."

        scores_with_stops = rouge.calculate(ref_with_stops, cand_with_stops)
        scores_no_stops = rouge_no_stop.calculate(ref_with_stops, cand_with_stops)

        print(f"With stopwords - ROUGE-1 F1: {scores_with_stops['rouge_1_f1']:.4f}")
        print(f"Without stopwords - ROUGE-1 F1: {scores_no_stops['rouge_1_f1']:.4f}")

        print("\n--- Batch Calculation Test ---")
        ref_list = [
            "Normal chest radiograph with clear lung fields.",
            "Heart size is within normal limits.",
            "No acute cardiopulmonary abnormalities."
        ]
        cand_list = [
            "Normal chest X-ray with clear lungs.",
            "Heart size appears normal.",
            "No acute findings identified."
        ]

        # calculate_batch is not defined in this file; presumably inherited
        # from BaseMetric and returning one score dict per pair - confirm.
        batch_scores = rouge.calculate_batch(ref_list, cand_list)
        print(f"Batch results: {len(batch_scores)} scores calculated")

        def batch_mean(key):
            # Missing keys count as 0, as in a failed/partial batch entry.
            return sum(score.get(key, 0) for score in batch_scores) / len(batch_scores)

        avg_rouge_1 = batch_mean('rouge_1_f1')
        avg_rouge_l = batch_mean('rouge_l_f1')

        print(f"Average ROUGE-1 F1: {avg_rouge_1:.4f}")
        print(f"Average ROUGE-L F1: {avg_rouge_l:.4f}")

        print("\n--- Edge Cases Test ---")

        # Empty inputs are expected to be rejected with ValueError; the
        # outcome is only reported, it does not fail the run.
        try:
            rouge.calculate("", "")
        except ValueError:
            print("Empty strings correctly handled")
        else:
            print("Empty strings should raise error")

        short_scores = rouge.calculate("Normal.", "Normal.")
        print(f"Short text ROUGE-1 F1: {short_scores['rouge_1_f1']:.4f}")

        print("\n--- Performance Test ---")
        # get_performance_stats is not defined in this file; presumably
        # provided by BaseMetric - confirm the available keys there.
        perf_stats = rouge.get_performance_stats()
        print(f"Calculations performed: {perf_stats['calculation_count']}")
        print(f"Average calculation time: {perf_stats['average_time']:.4f}s")

        print("\nAll ROUGE scorer tests completed!")
        return True

    except Exception as e:
        # Top-level boundary of the smoke test: report and signal failure.
        print(f"Test failed: {e}")
        import traceback
        traceback.print_exc()
        return False


if __name__ == "__main__":
    # Run the smoke test when executed as a script and mirror the outcome in
    # the process exit status so shells and CI jobs can detect a failure.
    success = test_rouge_scorer()

    if success:
        print("\nROUGE Scorer tests passed!")
    else:
        print("\nSome tests failed!")
        sys.exit(1)