import re
import numpy as np
from typing import Dict, List, Optional, Tuple, Union
import logging

import sys
from pathlib import Path
sys.path.append(str(Path(__file__).parent))

from base_metric import BaseMetric


# BERTScore metric using contextual embeddings for semantic similarity
class BERTScoreScorer(BaseMetric):
    """Score a candidate text against a reference using contextual embeddings.

    Each text is encoded with a Hugging Face transformer model; every token
    of one text is greedily matched to its most similar token (cosine
    similarity) in the other text. Precision averages the best match of each
    candidate token, recall averages the best match of each reference token,
    and F1 is their harmonic mean.

    The model and tokenizer are loaded lazily by ``initialize()``, which
    ``calculate()`` invokes on first use.
    """

    # Floor applied to embedding norms so that an all-zero embedding row
    # yields a similarity of 0 instead of NaN (0 / 0).
    _NORM_EPS = 1e-12

    def __init__(self, model_name: str = "bert-base-uncased", 
                 use_fast_tokenizer: bool = True, device: str = "auto",
                 batch_size: int = 32, max_length: int = 512,
                 logger: Optional[logging.Logger] = None):
        """Store the configuration; no model is loaded here.

        Args:
            model_name: Name or path handed to ``from_pretrained``.
            use_fast_tokenizer: Forwarded as ``use_fast`` to the tokenizer.
            device: Target device string; ``"auto"`` selects CUDA when
                available and CPU otherwise (resolved in ``initialize()``).
            batch_size: Stored on the instance; not referenced by the
                methods defined in this class.
            max_length: Maximum tokenized length; longer inputs are truncated.
            logger: Logger forwarded to ``BaseMetric.__init__``.
        """
        super().__init__("bertscore", logger)
        self.description = "BERTScore metric using contextual embeddings for semantic similarity"
        self.metric_type = "similarity"

        self.model_name = model_name
        self.use_fast_tokenizer = use_fast_tokenizer
        self.device = device
        self.batch_size = batch_size
        self.max_length = max_length

        # Populated by initialize()
        self.tokenizer = None
        self.model = None
        self.actual_device = None

        # Maps raw text -> (embeddings or None, tokens). Grows without bound
        # until clear_cache() is called.
        self.embedding_cache = {}
        self.use_cache = True

        self.logger.debug(f"Initialized BERTScore scorer with model={model_name}")

    def initialize(self) -> None:
        """Load the tokenizer and model and move the model to its device.

        Does nothing when ``self.is_initialized`` is already truthy.

        Raises:
            ImportError: If importing ``transformers``/``torch`` (or anything
                they need while loading) fails.
            Exception: Any other loading error is logged and re-raised.
        """
        if self.is_initialized:
            return

        try:
            # Imported lazily so the module can be imported without the
            # heavy optional dependencies installed.
            from transformers import AutoTokenizer, AutoModel
            import torch

            self.logger.info(f"Loading BERT model: {self.model_name}")

            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_name,
                use_fast=self.use_fast_tokenizer
            )

            self.model = AutoModel.from_pretrained(self.model_name)

            if self.device == "auto":
                self.actual_device = "cuda" if torch.cuda.is_available() else "cpu"
            else:
                self.actual_device = self.device

            self.model.to(self.actual_device)
            self.model.eval()

            self.logger.info(f"BERTScore model loaded on device: {self.actual_device}")

        except ImportError as e:
            self.logger.error("Transformers library not available. Install with: pip install transformers torch")
            raise ImportError("BERTScore requires transformers library") from e
        except Exception as e:
            self.logger.error(f"Failed to initialize BERTScore model: {e}")
            raise

        # Base-class hook; presumably marks the metric as initialized
        # (defined in BaseMetric, not visible here).
        super().initialize()

    def calculate(self, reference: str, candidate: str, **kwargs) -> Dict[str, float]:
        """Compute BERTScore precision, recall and F1 for one text pair.

        Args:
            reference: Ground-truth text.
            candidate: Text being evaluated.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            Dict with ``bertscore_precision``, ``bertscore_recall``,
            ``bertscore_f1``, ``bertscore`` (alias of F1), token counts and
            summary statistics of the similarity matrix. All values are 0.0
            when embeddings are unavailable or scoring raises.

        Raises:
            ValueError: If ``validate_inputs`` reports the inputs as invalid.
            ImportError: Propagated from ``initialize()`` on first use.
        """
        is_valid, issues = self.validate_inputs(reference, candidate)
        if not is_valid:
            raise ValueError(f"Invalid inputs for BERTScore calculation: {issues}")

        if not self.is_initialized:
            self.initialize()

        try:
            ref_embeddings, ref_tokens = self._get_embeddings(reference)
            cand_embeddings, cand_tokens = self._get_embeddings(candidate)

            # A text with no scorable tokens (or a failed encoding) cannot be matched
            if ref_embeddings is None or cand_embeddings is None:
                return self._get_zero_scores()

            similarity_matrix = self._calculate_similarity_matrix(ref_embeddings, cand_embeddings)

            precision = self._calculate_precision(similarity_matrix)
            recall = self._calculate_recall(similarity_matrix)

            if precision + recall > 0:
                f1 = 2 * precision * recall / (precision + recall)
            else:
                f1 = 0.0

            return {
                "bertscore_precision": precision,
                "bertscore_recall": recall,
                "bertscore_f1": f1,
                "bertscore": f1,
                "reference_tokens": float(len(ref_tokens)),
                "candidate_tokens": float(len(cand_tokens)),
                "similarity_mean": float(similarity_matrix.mean()),
                "similarity_max": float(similarity_matrix.max()),
                "similarity_min": float(similarity_matrix.min())
            }

        except Exception as e:
            # Best-effort: scoring failures degrade to zero scores instead of raising
            self.logger.error(f"Error calculating BERTScore: {e}")
            return self._get_zero_scores()

    def _get_embeddings(self, text: str) -> Tuple[Optional[np.ndarray], List[str]]:
        """Encode ``text`` and return per-token embeddings without special tokens.

        Args:
            text: Text to encode (truncated to ``self.max_length`` tokens).

        Returns:
            ``(embeddings, tokens)`` where ``embeddings`` has one row per
            retained token and ``tokens`` holds the matching token strings.
            Returns ``(None, [])`` when no token remains after filtering or
            when encoding fails. Results are memoised per text when
            ``self.use_cache`` is set; failures are not cached.
        """
        try:
            import torch

            if self.use_cache and text in self.embedding_cache:
                return self.embedding_cache[text]

            inputs = self.tokenizer(
                text,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=self.max_length
            )

            inputs = {k: v.to(self.actual_device) for k, v in inputs.items()}

            with torch.no_grad():
                outputs = self.model(**inputs)
                # Single-text batch: drop the leading batch dimension
                embeddings = outputs.last_hidden_state.squeeze(0)

            embeddings = embeddings.cpu().numpy()

            # Plain Python ints keep the id -> token lookup independent of tensor types
            token_ids = inputs["input_ids"].squeeze(0).tolist()
            tokens = self.tokenizer.convert_ids_to_tokens(token_ids)

            # Built once, outside the loop; these markers carry no content to score
            special_tokens = {
                self.tokenizer.cls_token,
                self.tokenizer.sep_token,
                self.tokenizer.pad_token,
            }

            valid_indices = []
            valid_tokens = []
            for i, token in enumerate(tokens):
                if token not in special_tokens:
                    valid_indices.append(i)
                    valid_tokens.append(token)

            valid_embeddings = embeddings[valid_indices] if valid_indices else None

            if self.use_cache:
                self.embedding_cache[text] = (valid_embeddings, valid_tokens)

            return valid_embeddings, valid_tokens

        except Exception as e:
            self.logger.error(f"Error getting embeddings: {e}")
            return None, []

    def _calculate_similarity_matrix(self, ref_embeddings: np.ndarray, 
                                   cand_embeddings: np.ndarray) -> np.ndarray:
        """Return pairwise cosine similarities between token embeddings.

        Args:
            ref_embeddings: 2-D array, one row per reference token.
            cand_embeddings: 2-D array, one row per candidate token.

        Returns:
            Array of shape ``(n_ref, n_cand)`` with values in ``[-1, 1]``.
            A zero-norm row yields similarity 0 rather than NaN.
        """
        ref_norms = np.linalg.norm(ref_embeddings, axis=1, keepdims=True)
        cand_norms = np.linalg.norm(cand_embeddings, axis=1, keepdims=True)

        # Flooring the norm avoids 0/0 -> NaN, which would otherwise poison
        # every max/mean computed from the matrix.
        ref_unit = ref_embeddings / np.maximum(ref_norms, self._NORM_EPS)
        cand_unit = cand_embeddings / np.maximum(cand_norms, self._NORM_EPS)

        similarity_matrix = np.dot(ref_unit, cand_unit.T)

        # Floating-point rounding can land marginally outside the cosine range
        return np.clip(similarity_matrix, -1.0, 1.0)

    def _calculate_precision(self, similarity_matrix: np.ndarray) -> float:
        """Average, over candidate tokens (columns), of the best reference match."""
        max_similarities = np.max(similarity_matrix, axis=0)
        return float(np.mean(max_similarities))

    def _calculate_recall(self, similarity_matrix: np.ndarray) -> float:
        """Average, over reference tokens (rows), of the best candidate match."""
        max_similarities = np.max(similarity_matrix, axis=1)
        return float(np.mean(max_similarities))

    def _get_zero_scores(self) -> Dict[str, float]:
        """Return the score dict used when a pair cannot be scored (all 0.0)."""
        return {
            "bertscore_precision": 0.0,
            "bertscore_recall": 0.0,
            "bertscore_f1": 0.0,
            "bertscore": 0.0,
            "reference_tokens": 0.0,
            "candidate_tokens": 0.0,
            "similarity_mean": 0.0,
            "similarity_max": 0.0,
            "similarity_min": 0.0
        }

    @staticmethod
    def _weighted_mean(values: np.ndarray, weights: np.ndarray) -> float:
        """Return the weighted mean, or the plain mean if the weights are unusable.

        A non-positive or non-finite weight total would make the weighted
        average undefined (division by zero / NaN), so the unweighted mean
        is used instead in that case.
        """
        total = float(np.sum(weights))
        if total <= 0.0 or not np.isfinite(total):
            return float(np.mean(values))
        return float(np.sum(values * weights) / total)

    def calculate_with_idf(self, reference: str, candidate: str, 
                          idf_weights: Optional[Dict[str, float]] = None) -> Dict[str, float]:
        """Compute BERTScore plus IDF-weighted precision, recall and F1.

        Args:
            reference: Ground-truth text.
            candidate: Text being evaluated.
            idf_weights: Mapping from token string (as produced by the
                tokenizer) to its weight. Tokens missing from the mapping
                get weight 1.0. When ``None``, only the plain scores are
                returned.

        Returns:
            The dict from ``calculate()``; when weighting succeeds it also
            contains ``bertscore_idf_precision``, ``bertscore_idf_recall``,
            ``bertscore_idf_f1`` and ``bertscore_idf`` (alias of the F1).
            On any failure the plain scores are returned unchanged.
        """
        scores = self.calculate(reference, candidate)

        if idf_weights is None:
            return scores

        try:
            ref_embeddings, ref_tokens = self._get_embeddings(reference)
            cand_embeddings, cand_tokens = self._get_embeddings(candidate)

            if ref_embeddings is None or cand_embeddings is None:
                return scores

            similarity_matrix = self._calculate_similarity_matrix(ref_embeddings, cand_embeddings)

            ref_weights = np.array([idf_weights.get(token, 1.0) for token in ref_tokens])
            cand_weights = np.array([idf_weights.get(token, 1.0) for token in cand_tokens])

            # Precision weights each candidate token's best match;
            # recall weights each reference token's best match.
            weighted_precision = self._weighted_mean(
                np.max(similarity_matrix, axis=0), cand_weights
            )
            weighted_recall = self._weighted_mean(
                np.max(similarity_matrix, axis=1), ref_weights
            )

            if weighted_precision + weighted_recall > 0:
                weighted_f1 = 2 * weighted_precision * weighted_recall / (weighted_precision + weighted_recall)
            else:
                weighted_f1 = 0.0

            scores.update({
                "bertscore_idf_precision": float(weighted_precision),
                "bertscore_idf_recall": float(weighted_recall),
                "bertscore_idf_f1": float(weighted_f1),
                "bertscore_idf": float(weighted_f1)
            })

            return scores

        except Exception as e:
            self.logger.error(f"Error calculating IDF-weighted BERTScore: {e}")
            return scores

    def clear_cache(self) -> None:
        """Drop every cached embedding."""
        self.embedding_cache.clear()
        self.logger.debug("Cleared BERTScore embedding cache")

    def get_cache_size(self) -> int:
        """Return the number of texts currently held in the embedding cache."""
        return len(self.embedding_cache)

    def get_name(self) -> str:
        """Return the display name of the metric."""
        return "BERTScore"

    def get_description(self) -> str:
        """Return a one-sentence description that names the configured model."""
        return (f"BERTScore metric using {self.model_name} for semantic similarity "
                f"measurement with contextual embeddings")


# Test function for the BERTScore Scorer
def test_bertscore_scorer():
    """Run a manual smoke test of ``BERTScoreScorer`` and print the results.

    Exercises pairwise scoring, IDF weighting, batch scoring, the embedding
    cache and performance statistics. Each section reports its own failure
    and lets the remaining sections run.

    Returns:
        True when the run completed, or was skipped because the optional
        dependencies are missing; False when setup (including loading the
        model) failed with any other error.
    """
    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger("test")

    try:
        print("Testing BERTScore Scorer...")

        bertscore = BERTScoreScorer(
            model_name="distilbert-base-uncased",
            device="cpu",
            logger=logger
        )

        # Load the model up front. It is otherwise loaded lazily inside
        # calculate(), where the per-section `except Exception` handlers
        # would swallow the ImportError and the "skipped" path below would
        # never be reached.
        bertscore.initialize()

        print(f"Metric name: {bertscore.get_name()}")
        print(f"Metric description: {bertscore.get_description()}")

        test_cases = [
            {
                "reference": "Normal chest radiograph with clear lung fields.",
                "candidate": "Normal chest radiograph with clear lung fields.",
                "description": "Perfect match"
            },

            {
                "reference": "Normal chest radiograph with clear lung fields.",
                "candidate": "Normal chest X-ray with clear lungs.",
                "description": "High semantic similarity"
            },

            {
                "reference": "Normal chest radiograph with clear lung fields.",
                "candidate": "Chest imaging shows normal pulmonary findings.",
                "description": "Medium similarity (paraphrase)"
            },

            {
                "reference": "Normal chest radiograph with clear lung fields.",
                "candidate": "Patient has severe pneumonia and effusion.",
                "description": "Low similarity (different meaning)"
            },

            {
                "reference": "Normal chest radiograph with clear lung fields.",
                "candidate": "The weather is sunny today.",
                "description": "No similarity"
            }
        ]

        print("\n--- BERTScore Tests ---")
        for i, test_case in enumerate(test_cases, 1):
            try:
                scores = bertscore.calculate(test_case["reference"], test_case["candidate"])

                print(f"\nTest {i}: {test_case['description']}")
                print(f"Reference: '{test_case['reference']}'")
                print(f"Candidate: '{test_case['candidate']}'")
                print(f"BERTScore F1: {scores['bertscore_f1']:.4f}")
                print(f"BERTScore Precision: {scores['bertscore_precision']:.4f}")
                print(f"BERTScore Recall: {scores['bertscore_recall']:.4f}")
                print(f"Similarity Mean: {scores['similarity_mean']:.4f}")
                print(f"Reference Tokens: {scores['reference_tokens']}")
                print(f"Candidate Tokens: {scores['candidate_tokens']}")

            except Exception as e:
                print(f"Test {i} failed: {e}")
                continue

        print("\n--- IDF Weighting Test ---")
        try:
            idf_weights = {
                "normal": 2.0,
                "chest": 2.5,
                "radiograph": 3.0,
                "lung": 2.5,
                "fields": 2.0,
                "clear": 1.5,
                "the": 0.1,
                "with": 0.1,
                "and": 0.1
            }

            ref_text = "Normal chest radiograph with clear lung fields."
            cand_text = "Normal chest X-ray with clear lungs."

            idf_scores = bertscore.calculate_with_idf(ref_text, cand_text, idf_weights)
            regular_scores = bertscore.calculate(ref_text, cand_text)

            # The IDF keys are absent when weighting failed; format the value
            # like the other scores when it is present.
            idf_f1 = idf_scores.get('bertscore_idf_f1')
            idf_f1_text = f"{idf_f1:.4f}" if idf_f1 is not None else "N/A"

            print(f"Regular BERTScore F1: {regular_scores['bertscore_f1']:.4f}")
            print(f"IDF-weighted BERTScore F1: {idf_f1_text}")

        except Exception as e:
            print(f"IDF test failed: {e}")

        print("\n--- Batch Calculation Test ---")
        try:
            ref_list = [
                "Normal chest radiograph with clear lung fields.",
                "Heart size is within normal limits.",
                "No acute cardiopulmonary abnormalities."
            ]
            cand_list = [
                "Normal chest X-ray with clear lungs.",
                "Heart size appears normal.",
                "No acute findings identified."
            ]

            # calculate_batch is not defined in this file; presumably
            # inherited from BaseMetric.
            batch_scores = bertscore.calculate_batch(ref_list, cand_list)
            print(f"Batch results: {len(batch_scores)} scores calculated")

            valid_scores = [score.get('bertscore_f1', 0) for score in batch_scores if score]
            if valid_scores:
                avg_bertscore = sum(valid_scores) / len(valid_scores)
                print(f"Average BERTScore F1: {avg_bertscore:.4f}")

        except Exception as e:
            print(f"Batch test failed: {e}")

        print("\n--- Cache Test ---")
        try:
            cache_size_before = bertscore.get_cache_size()

            # Same pair twice: the second call should be served from the cache
            bertscore.calculate("Normal chest radiograph.", "Normal chest X-ray.")
            bertscore.calculate("Normal chest radiograph.", "Normal chest X-ray.")

            cache_size_after = bertscore.get_cache_size()
            print(f"Cache size before: {cache_size_before}")
            print(f"Cache size after: {cache_size_after}")

            bertscore.clear_cache()
            cache_size_cleared = bertscore.get_cache_size()
            print(f"Cache size after clearing: {cache_size_cleared}")

        except Exception as e:
            print(f"Cache test failed: {e}")

        print("\n--- Performance Test ---")
        try:
            # get_performance_stats is not defined in this file; presumably
            # inherited from BaseMetric.
            perf_stats = bertscore.get_performance_stats()
            print(f"Calculations performed: {perf_stats['calculation_count']}")
            print(f"Average calculation time: {perf_stats['average_time']:.4f}s")

        except Exception as e:
            print(f"Performance test failed: {e}")

        print("\nBERTScore scorer tests completed!")
        return True

    except ImportError as e:
        # Missing optional dependencies count as a skip, not a failure
        print(f"BERTScore test skipped - missing dependencies: {e}")
        print("Install with: pip install transformers torch")
        return True

    except Exception as e:
        print(f"Test failed: {e}")
        import traceback
        traceback.print_exc()
        return False


if __name__ == "__main__":
    # Run the smoke test when executed as a script and report the outcome
    success = test_bertscore_scorer()
    print("\nBERTScore Scorer tests passed!" if success else "\nSome tests failed!")