import pandas as pd
import numpy as np
import re
import emoji
import json
from .evaluator import Evaluator
import logging

logging.basicConfig(format='%(asctime)s,%(msecs)03d %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s',
    datefmt='%Y-%m-%d:%H:%M:%S',
    level=logging.WARN)
logger = logging.getLogger(__name__)


class RuleEvaluator(Evaluator):
    """
    Evaluator that checks responses against predefined rules without using an LM judge.
    For example, checking if a response contains emojis, follows specific formats, etc.
    """
    
    def __init__(self, model_name, **kwargs):
        self.model_name = model_name
        self.concept_id = kwargs.get("concept_id", None)
        self.steer_dataset_type = kwargs.get("steer_dataset_type", None)
        
        # Map rule types to their evaluation functions
        self.rule_functions = {
            "emoji": self._check_emoji,
            "length": self._check_length,
            "keyword": self._check_keyword,
            "code": self._check_code,
            "question": self._check_question,
            "capital": self._check_uppercase,
            "chinese": self._check_chinese,
            "spanish": self._check_spanish,
            "all_caps": self._check_all_caps,
            "all_lowercase": self._check_all_lowercase,
            "korean": self._check_korean,
            "word_limit_max": self._check_word_limit_max,
            "word_limit_min": self._check_word_limit_min,
            "postscript": self._check_postscript,
            "json_format": self._check_json_format,
            "numbered_list": self._check_numbered_list,
            "double_breaks": self._check_double_breaks,
            "asterisk_sep": self._check_asterisk_separation,
            "start_phrase": self._check_starts_with_phrase,
            "quoted_words": self._check_words_in_quotes,
            "max_sentences": self._check_max_sentences,
            "help_phrase": self._check_ends_with_help,
            "exclamation": self._check_has_exclamation,
            "hashtags": self._check_has_hashtags,
            "citations": self._check_has_citations,
            "same_letter_ending": self._check_same_letter_ending,
            "telephone_number": self._check_telephone_number,
            "email": self._check_email,
        }
        
        # Additional parameters for specific rules
        self.min_length = 50
        self.max_length = 800
        self.keywords = ["beauty", 'pretty']
        
    def __str__(self): 
        return f'RuleEvaluator'
    
    def _check_emoji(self, text):
        """Check if text contains emojis"""
        emoji_count = sum(1 for char in text if char in emoji.EMOJI_DATA)
        # Return a score between 0 and 2 based on emoji presence
        if emoji_count > 3:
            return 2.0
        elif emoji_count > 0:
            return 1.0
        return 0.0
    
    def _check_length(self, text):
        """Check if text length is within specified range"""
        length = len(text)
        if self.min_length <= length <= self.max_length:
            return 2.0
        elif length < self.min_length:
            return length / self.min_length * 2.0
        else:
            return max(0.0, 2.0 - (length - self.max_length) / self.max_length)
    
    def _check_keyword(self, text):
        """Check if text contains specified keywords"""
        text_lower = text.lower()
        found_keywords = sum(1 for keyword in self.keywords if keyword.lower() in text_lower)
        if not self.keywords:
            return 0.0
        return min(2.0, (found_keywords / len(self.keywords)) * 2.0)
    
    def _check_code(self, text):
        """Check if text contains code blocks"""
        code_blocks = re.findall(r'```[\s\S]*?```', text)
        if code_blocks:
            return 2.0
        elif re.search(r'`[^`]+`', text):  # Inline code
            return 1.0
        return 0.0
    
    def _check_question(self, text):
        """Check if text contains questions"""
        questions = re.findall(r'\?', text)
        if len(questions) > 2:
            return 2.0
        elif len(questions) > 0:
            return 1.0
        return 0.0
    
    def _check_uppercase(self, text):
        """Check if text contains uppercase words for emphasis"""
        words = text.split()
        uppercase_words = sum(1 for word in words if len(word) > 1 and word.isupper())
        if uppercase_words > 10:
            return 2.0
        elif uppercase_words > 0:
            return 1.0
        return 0.0
    
    def _check_chinese(self, text):
        """Check if text contains Chinese characters"""
        # Unicode ranges for Chinese characters
        chinese_chars = sum(1 for char in text if '\u4e00' <= char <= '\u9fff')
        if chinese_chars > 10:
            return 2.0
        elif chinese_chars > 0:
            return 1.0
        return 0.0
    
    def _check_spanish(self, text):
        """Check if text is only in Spanish"""
        # Basic Spanish character set check (including ñ, á, é, í, ó, ú, ü)
        spanish_pattern = re.compile(r'^[a-zA-ZáéíóúüñÁÉÍÓÚÜÑ\s.,!?]+$')
        return 2.0 if bool(spanish_pattern.match(text.strip())) else 0.0
    
    def _check_all_caps(self, text):
        """Check if text uses only capital letters"""
        if not text.strip():
            return 0.0
        return 2.0 if text.isupper() else 0.0

    def _check_all_lowercase(self, text):
        """Check if all words are in lowercase letters"""
        if not text.strip():
            return 0.0
        return 2.0 if text.islower() else 0.0

    def _check_korean(self, text):
        """Check if text is only in Korean"""
        korean_pattern = re.compile(r'^[\uac00-\ud7a3\s.,!?]+$')
        return 2.0 if bool(korean_pattern.match(text.strip())) else 0.0

    def _check_word_limit_max(self, text, max_words=50):
        """Check if text has at most max_words words"""
        word_count = len(text.split())
        if word_count <= max_words:
            return 2.0
        return max(0.0, 2.0 - (word_count - max_words) / max_words)

    def _check_word_limit_min(self, text, min_words=800):
        """Check if text has at least min_words words"""
        word_count = len(text.split())
        if word_count >= min_words:
            return 2.0
        return (word_count / min_words) * 2.0

    def _check_postscript(self, text):
        """Check if text includes a P.S. at the end"""
        return 2.0 if bool(re.search(r'P\.S\..*$', text, re.MULTILINE)) else 0.0

    def _check_json_format(self, text):
        """Check if text is in valid JSON format"""
        try:
            json.loads(text)
            return 2.0
        except json.JSONDecodeError:
            return 0.0

    def _check_numbered_list(self, text):
        """Check if text is in numbered list format"""
        # Check for both formats:
        # 1. With double line breaks: 1.\n\n2.\n\n
        # 2. With single line breaks: 1. text 2. text
        pattern1 = r'^\d+\.\s*\n\n'  # Original pattern
        pattern2 = r'\d+\.\s+[^\n]+(?:\s+\d+\.\s+[^\n]+)+'  # New pattern for inline numbering
        
        if bool(re.search(pattern1, text, re.MULTILINE)) or bool(re.search(pattern2, text)):
            return 2.0
        return 0.0

    def _check_double_breaks(self, text):
        """Check if paragraphs are separated by double line breaks"""
        paragraphs = text.split('\n\n')
        if len(paragraphs) > 1:
            return 2.0
        return 0.0

    def _check_asterisk_separation(self, text):
        """Check if paragraphs are separated by ***"""
        return 2.0 if '***' in text else 0.0

    def _check_starts_with_phrase(self, text, phrase="Here is my response"):
        """Check if text starts with specific phrase"""
        return 2.0 if text.strip().startswith(phrase) else 0.0

    def _check_words_in_quotes(self, text):
        """Check if every word is wrapped in double quotation marks"""
        words = text.split()
        if not words:
            return 0.0
        quoted_words = sum(1 for word in words if word.startswith('"') and word.endswith('"'))
        return (quoted_words / len(words)) * 2.0

    def _check_max_sentences(self, text, max_sentences=3):
        """Check if text contains exactly max_sentences sentences, handling special cases."""
        # Pre-process text to handle common abbreviations and initials
        text = re.sub(r'(?<=[A-Z])\.\s*(?=[A-Z]\.)', 'DOT', text)  # Handle initials like "H.S."
        text = re.sub(r'Mr\.|Mrs\.|Dr\.|Ms\.|U\.S\.|D\.|Jr\.|Sr\.', lambda m: m.group().replace('.', 'DOT'), text)
        
        # Split into sentences and clean up
        sentences = [s.strip() for s in re.split(r'[.!?]+', text) if s.strip()]
        
        # Restore periods and recount
        sentences = [s.replace('DOT', '.') for s in sentences]
        
        # Return score based on exact match of sentence count
        return 2.0 if len(sentences) == max_sentences else 0.0

    def _check_ends_with_help(self, text):
        """Check if text ends with help phrase"""
        return 2.0 if text.strip().endswith("Is there anything else I can help with?") else 0.0

    def _check_has_exclamation(self, text):
        """Check if text contains exclamation marks"""
        exclamation_count = text.count('!')
        return min(2.0, exclamation_count * 0.5)

    def _check_has_hashtags(self, text, min_hashtags=4):
        """Check if text includes at least min_hashtags hashtags"""
        hashtags = re.findall(r'#\w+', text)
        if len(hashtags) >= min_hashtags:
            return 2.0
        return (len(hashtags) / min_hashtags) * 2.0

    def _check_has_citations(self, text):
        """Check if text includes citations with URLs"""
        url_pattern = re.compile(
            r'http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\\(\\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+'
        )
        urls = url_pattern.findall(text)
        return 2.0 if urls else 0.0

    def _check_same_letter_ending(self, text):
        """Check if every sentence ends with the letter 'y'"""
        # Split into sentences and clean up
        sentences = [s.strip() for s in re.split(r'[.!?]+', text) if s.strip()]
        if not sentences:
            return 0.0
        
        # Get last letter of each sentence (ignoring punctuation)
        last_letters = [re.sub(r'[.!?,\s]', '', s)[-1].lower() for s in sentences]
        
        # Check if all letters are 'y'
        return 2.0 if all(letter == 'y' for letter in last_letters) else 0.0

    def _check_telephone_number(self, text):
        """Check if text contains a telephone number"""
        # Match various phone number formats
        phone_patterns = [
            r'\d{3}[-.\s]?\d{3}[-.\s]?\d{4}',  # US format: 123-456-7890
            r'\(\d{3}\)\s*\d{3}[-.\s]?\d{4}',  # US format: (123) 456-7890
            r'\+?\d{1,3}[-.\s]?\d{3}[-.\s]?\d{3}[-.\s]?\d{4}',  # International format
            r'\d{3}[-.\s]?\d{4}',  # 7-digit format: 555-1234
        ]
        
        for pattern in phone_patterns:
            if re.search(pattern, text):
                return 2.0
        return 0.0

    def _check_email(self, text):
        """Check if text contains a proper email address"""
        email_pattern = r'[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}'
        return 2.0 if re.search(email_pattern, text) else 0.0
    
    def _evaluate_text(self, text):
        """Apply the selected rule to evaluate the text"""
        if self.rule_type in self.rule_functions:
            return self.rule_functions[self.rule_type](text)
        else:
            logger.warning(f"Unknown rule type: {self.rule_type}")
            return 0.0
    
    def compute_metrics(self, data, write_to_dir=None, rule_type=None):
        """
        Evaluate responses based on the specified rule.
        Returns metrics similar to LMJudgeEvaluator but based on rule-based evaluation.
        
        Parameters:
        - data: DataFrame containing the data to evaluate
        - write_to_dir: Directory to write results to (if needed)
        - rule_type: The type of rule to apply (overrides the one set during initialization)
        """
        # Use the rule_type passed to this method, if provided
        current_rule_type = rule_type
        
        if current_rule_type not in self.rule_functions:
            logger.warning(f"Unknown rule type: {current_rule_type}")
            return {}
            
        logger.warning(
            f"Starting rule evaluation for concept_id: {self.concept_id}, "
            f"rule_type: {current_rule_type}, model: {self.model_name}")
        
        data_copy = data.copy()
        
        # Apply rule evaluation to each row
        rule_ratings = []
        for idx, row in data_copy.iterrows():
            generation = row[f"{self.model_name}_steered_generation"]
            rating = self.rule_functions[current_rule_type](generation)
            rule_ratings.append(rating)
        
        # Store ratings in the dataframe
        data_copy[f"{self.model_name}_rule_rating"] = rule_ratings
        
        # Group by factor and compute means
        metrics = {
            "rule_following": [],
            "factor": [],
            "raw_rule_following": rule_ratings  # Add raw ratings similar to LMJudgeEvaluator
        }
        
        grouped = data_copy.groupby("factor")
        for factor, group in grouped:
            metrics["rule_following"].append(group[f"{self.model_name}_rule_rating"].mean())
            metrics["factor"].append(factor)
        
        return metrics
    
    
    def compute_metrics_train(self, data, write_to_dir=None, rule_type=None):
        """
        Evaluate responses based on the specified rule.
        Returns metrics similar to LMJudgeEvaluator but based on rule-based evaluation.
        
        Parameters:
        - data: DataFrame containing the data to evaluate
        - write_to_dir: Directory to write results to (if needed)
        - rule_type: The type of rule to apply (overrides the one set during initialization)
        """
        # Use the rule_type passed to this method, if provided
        current_rule_type = rule_type
        
        if current_rule_type not in self.rule_functions:
            logger.warning(f"Unknown rule type: {current_rule_type}")
            return {}
            
        logger.warning(
            f"Starting rule evaluation for concept_id: {self.concept_id}, "
            f"rule_type: {current_rule_type}, model: {self.model_name}")
        
        data_copy = data.copy()
        
        # Apply rule evaluation to each row
        rule_ratings_winning = []
        rule_ratings_losing = []
        for idx, row in data_copy.iterrows():
            #generation = row[f"{self.model_name}_steered_generation"]
            generation = row["winning_output"]
            rating = self.rule_functions[current_rule_type](generation)
            rule_ratings_winning.append(rating)
            generation_losing = row["losing_output"]
            rating = self.rule_functions[current_rule_type](generation_losing)
            rule_ratings_losing.append(rating)

        
        # Store ratings in the dataframe
        data_copy[f"winning_rule_rating"] = rule_ratings_winning
        data_copy[f"losing_rule_rating"] = rule_ratings_losing
        
        # Group by factor and compute means
        metrics = {
            "rule_following": [],
            "factor": [],
            "raw_rule_following_winning": rule_ratings_winning,  # Add raw ratings similar to LMJudgeEvaluator
            "raw_rule_following_losing": rule_ratings_losing
        }
        
        return metrics