#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
流畅度评估服务 - 基于PPL困惑度评估
"""

import math
import logging
from typing import List, Dict, Any, Optional
import os

from ...core.interfaces.fluency_interfaces import IFluencyEvaluator, FluencyResult
from ...core.exceptions import ServiceError


class FluencyEvaluatorService(IFluencyEvaluator):
    """Fluency evaluation service based on language-model perplexity (PPL).

    When ``torch`` and ``transformers`` can be imported, perplexity is
    computed with a causal language model. Otherwise the service falls back
    to a deterministic mock that derives a pseudo-PPL from the text length.
    """

    # Perplexity at or below this value is considered fluent (inclusive).
    _FLUENT_PPL_THRESHOLD = 100.0

    def __init__(self, model_name: str = "gpt2", device: Optional[str] = None):
        """
        Initialise the fluency evaluator.

        Args:
            model_name: Name of the model used to compute perplexity.
            device: Compute device. When omitted, ``"cuda"`` is used if torch
                reports CUDA as available, otherwise ``"cpu"``.
        """
        self.model_name = model_name
        # Provisional value; refined below once torch is known to be importable.
        self.device = device or "cpu"
        self.logger = logging.getLogger(__name__)

        # Disable tokenizers parallelism to avoid its fork-related warnings.
        os.environ["TOKENIZERS_PARALLELISM"] = "false"

        try:
            # Heavy dependencies are imported lazily so the service still
            # works (in mock mode) when they are not installed.
            import torch
            from transformers import AutoTokenizer, AutoModelForCausalLM

            self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")

            # Load the tokenizer; fall back to EOS as the pad token when the
            # tokenizer does not define one.
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

            # Load the model, move it to the target device, switch to eval mode.
            self.model = AutoModelForCausalLM.from_pretrained(model_name)
            self.model.to(self.device)
            self.model.eval()

            self.has_real_models = True
            self.logger.info(f"流畅度评估器初始化完成: {model_name} on {self.device}")

        except ImportError as e:
            # Only missing dependencies trigger mock mode; any other failure
            # (e.g. while loading the model) propagates to the caller.
            self.has_real_models = False
            self.model = None
            self.tokenizer = None
            self.logger.warning(f"依赖模块不可用，将使用Mock模式: {e}")

    def evaluate_fluency(self, text: str) -> FluencyResult:
        """
        Evaluate the fluency of a single text.

        Args:
            text: The text to evaluate.

        Returns:
            A ``FluencyResult`` carrying the perplexity, the fluent flag and
            the derived fluency score.

        Raises:
            ServiceError: If evaluation fails; the original exception is
                attached as ``__cause__``.
        """
        try:
            if self.has_real_models:
                ppl = self._calculate_perplexity_real(text)
            else:
                ppl = self._calculate_perplexity_mock(text)

            # Fluent means a finite PPL no greater than the threshold.
            is_fluent = math.isfinite(ppl) and ppl <= self._FLUENT_PPL_THRESHOLD

            return FluencyResult(
                text=text,
                perplexity_score=ppl,
                is_fluent=is_fluent,
                fluency_score=self._ppl_to_fluency_score(ppl),
                evaluation_method="ppl_only"
            )

        except Exception as e:
            # Chain the cause so the original traceback is not lost.
            raise ServiceError(f"流畅度评估失败: {str(e)}", "Fluency") from e

    def batch_evaluate_fluency(self, texts: List[str]) -> List[FluencyResult]:
        """
        Evaluate the fluency of several texts.

        A failure on one text does not abort the batch: a placeholder result
        (infinite PPL, not fluent, score 0.0, with ``error_message`` set) is
        recorded for it instead.

        Args:
            texts: The texts to evaluate.

        Returns:
            One ``FluencyResult`` per input text, in input order.
        """
        results = []
        total = len(texts)
        for i, text in enumerate(texts):
            try:
                results.append(self.evaluate_fluency(text))
            except Exception as e:
                self.logger.warning(f"文本评估失败: {e}")
                # Record a failed result so output stays aligned with input.
                results.append(FluencyResult(
                    text=text,
                    perplexity_score=float('inf'),
                    is_fluent=False,
                    fluency_score=0.0,
                    evaluation_method="ppl_only",
                    error_message=str(e)
                ))

            # Report progress every 50 texts, whether or not the item failed.
            if (i + 1) % 50 == 0:
                self.logger.info(f"流畅度评估进度: {i+1}/{total}")

        return results

    def _calculate_perplexity_real(self, text: str) -> float:
        """Compute perplexity with the loaded model.

        Returns ``inf`` for empty/whitespace-only text, for inputs that encode
        to at most one token, and when the computation raises.
        """
        if not text or not text.strip():
            return float('inf')

        try:
            import torch

            # Encode the stripped text, truncated to at most 512 tokens.
            encoded = self.tokenizer(
                text.strip(),
                return_tensors="pt",
                truncation=True,
                max_length=512,
                padding=False
            )
            input_ids = encoded["input_ids"].to(self.device)

            # Inputs of one token or fewer are reported as infinite PPL.
            if input_ids.size(1) <= 1:
                return float('inf')

            with torch.no_grad():
                # Passing the inputs as labels makes the model return its loss.
                outputs = self.model(input_ids, labels=input_ids)
                loss = outputs.loss

                # perplexity = exp(loss)
                return torch.exp(loss).item()

        except Exception as e:
            # Best effort: a failed computation is reported as infinite PPL.
            self.logger.warning(f"PPL计算失败: {e}")
            return float('inf')

    def _calculate_perplexity_mock(self, text: str) -> float:
        """Return a simulated perplexity derived only from the text length.

        Returns ``inf`` for empty/whitespace-only text; otherwise 80.0 for
        fewer than 10 characters, 45.0 for fewer than 50, and 25.0 beyond.
        """
        if not text or not text.strip():
            return float('inf')

        text_len = len(text.strip())

        # Simulation: longer texts are given a lower pseudo-PPL.
        if text_len < 10:
            return 80.0
        if text_len < 50:
            return 45.0
        return 25.0

    def _ppl_to_fluency_score(self, ppl: float) -> float:
        """Map a perplexity to a fluency score in [0, 1].

        Tiers: PPL <= 30 -> 1.0, <= 60 -> 0.8, <= 100 -> 0.5. Above 100 the
        score decays logarithmically from 0.5 and reaches 0.0 at PPL = 100*e.
        Non-finite values (inf, NaN) score 0.0.
        """
        if not math.isfinite(ppl):
            return 0.0

        # Lower PPL means higher fluency.
        if ppl <= 30:
            return 1.0
        if ppl <= 60:
            return 0.8
        if ppl <= self._FLUENT_PPL_THRESHOLD:
            return 0.5

        # Start the decay at 0.5 so the score never increases with PPL; an
        # unscaled ``1 - log(ppl / 100)`` would give ~1.0 just above the
        # threshold, ranking non-fluent text above fluent text.
        decay = math.log(ppl / self._FLUENT_PPL_THRESHOLD)
        return max(0.0, 0.5 * (1.0 - decay))
