"""
Log manager for the puzzle benchmark system: persists per-session conversation logs and batch result summaries as JSON files.
"""
import os
import json
import logging
from datetime import datetime
from typing import List, Dict, Any
from dataclasses import asdict

from game_session import SessionResult, RoundData
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from config import LOGS_DIR

# Module-level logger used by LoggerManager for progress and warning messages;
# no handlers are attached here.
logger = logging.getLogger(__name__)


class LoggerManager:
    """Persist and query benchmark logs under ``LOGS_DIR``.

    On-disk layout (directories are created on construction):

    * ``LOGS_DIR/sessions/<session_id>/metadata.json`` and
      ``conversation.json`` -- one directory per game session.
    * ``LOGS_DIR/results/<YYYYMMDD>[_<test_name>]_results.json`` -- batch
      summaries written by :meth:`save_batch_results`.
    """

    # Optional RoundData attributes. Reasoning text is logged only when present
    # and non-empty; token counters are logged whenever the attribute exists.
    _REASONING_FIELDS = ("qa_reasoning_content", "eval_reasoning_content")
    _TOKEN_FIELDS = (
        "qa_prompt_tokens",
        "qa_completion_tokens",
        "qa_total_tokens",
        "eval_prompt_tokens",
        "eval_completion_tokens",
        "eval_total_tokens",
    )

    def __init__(self):
        self.ensure_log_directories()

    def ensure_log_directories(self):
        """Create ``LOGS_DIR`` and its ``sessions``/``results`` subdirectories if missing."""
        os.makedirs(LOGS_DIR, exist_ok=True)
        os.makedirs(os.path.join(LOGS_DIR, "sessions"), exist_ok=True)
        os.makedirs(os.path.join(LOGS_DIR, "results"), exist_ok=True)
        logger.info("Log directories initialized")

    @staticmethod
    def _write_json(path: str, data: Any) -> None:
        """Write *data* to *path* as indented UTF-8 JSON, keeping non-ASCII characters as-is."""
        with open(path, 'w', encoding='utf-8') as f:
            json.dump(data, f, ensure_ascii=False, indent=2)

    @staticmethod
    def _session_header(session_result: SessionResult) -> Dict[str, Any]:
        """Return the identifying fields shared by metadata.json and conversation.json."""
        return {
            "session_id": session_result.session_id,
            "domain": session_result.domain,
            "puzzle_name": session_result.puzzle_name,
            "qa_model": session_result.qa_model,
            "eval_model": session_result.eval_model,
            "start_time": session_result.start_time,
            "end_time": session_result.end_time,
        }

    def save_session_log(self, session_result: SessionResult):
        """Save the logs of a single session.

        Writes two files into ``LOGS_DIR/sessions/<session_id>/``, overwriting
        any previous files for the same id:

        * ``metadata.json`` -- header fields plus success, total_rounds and
          final_result.
        * ``conversation.json`` -- header fields plus every round (see
          :meth:`_round_to_dict`) and final_result.
        """
        session_dir = os.path.join(LOGS_DIR, "sessions", session_result.session_id)
        os.makedirs(session_dir, exist_ok=True)

        header = self._session_header(session_result)

        # Metadata: compact outcome summary of the session.
        metadata = {
            **header,
            "success": session_result.success,
            "total_rounds": session_result.total_rounds,
            "final_result": session_result.final_result,
        }
        self._write_json(os.path.join(session_dir, "metadata.json"), metadata)

        # Conversation: the full per-round transcript.
        conversation_data = {
            **header,
            "rounds": [self._round_to_dict(round_data) for round_data in session_result.rounds],
            "final_result": session_result.final_result,
        }
        self._write_json(os.path.join(session_dir, "conversation.json"), conversation_data)

        logger.info(f"Session log saved: {session_result.session_id}")

    def _round_to_dict(self, round_data: RoundData) -> Dict[str, Any]:
        """Convert one RoundData into a JSON-ready dict.

        The core fields are always copied. Reasoning-content fields are added
        only when the attribute exists and is truthy. Token counters are added
        whenever the attribute exists, even if their value is 0 or None. The
        existence checks let RoundData objects lacking these attributes still
        serialize.
        """
        result = {
            "round": round_data.round,
            "qa_prompt": round_data.qa_prompt,
            "qa_response": round_data.qa_response,
            "eval_prompt": round_data.eval_prompt,
            "eval_response": round_data.eval_response,
            "pure_parse": round_data.pure_parse,
            "pure_answer": round_data.pure_answer,
            "timestamp": round_data.timestamp,
            "qa_api_retries": round_data.qa_api_retries,
            "eval_api_retries": round_data.eval_api_retries,
        }

        for name in self._REASONING_FIELDS:
            value = getattr(round_data, name, None)
            if value:
                result[name] = value

        for name in self._TOKEN_FIELDS:
            if hasattr(round_data, name):
                result[name] = getattr(round_data, name)

        return result

    @staticmethod
    def _safe_filename_part(name: str) -> str:
        """Replace path separators in *name* so it stays a single file-name component."""
        name = str(name)
        for sep in (os.sep, os.altsep):
            if sep:
                name = name.replace(sep, "_")
        return name

    @staticmethod
    def _domain_stats(results: List[SessionResult]) -> Dict[str, Dict[str, Any]]:
        """Group *results* by domain: session count, success count and mean rounds.

        Domains appear in order of first occurrence in *results*.
        """
        stats: Dict[str, Dict[str, Any]] = {}
        rounds_by_domain: Dict[str, list] = {}
        for r in results:
            entry = stats.setdefault(r.domain, {"total": 0, "success": 0})
            entry["total"] += 1
            if r.success:
                entry["success"] += 1
            rounds_by_domain.setdefault(r.domain, []).append(r.total_rounds)

        # Every domain has at least one session, so the division is safe.
        for domain, entry in stats.items():
            rounds = rounds_by_domain[domain]
            entry["avg_rounds"] = sum(rounds) / len(rounds)
        return stats

    @staticmethod
    def _result_row(r: SessionResult) -> Dict[str, Any]:
        """Build one ``detailed_results`` row; a missing (None) final_result yields None fields."""
        final = r.final_result or {}
        return {
            "session_id": r.session_id,
            "domain": r.domain,
            "puzzle_name": r.puzzle_name,
            "qa_model": r.qa_model,
            "eval_model": r.eval_model,
            "success": r.success,
            "total_rounds": r.total_rounds,
            "final_answer": final.get("final_answer"),
            "reason": final.get("reason"),
            "start_time": r.start_time,
            "end_time": r.end_time,
        }

    def save_batch_results(self, results: List[SessionResult], test_name: str = None):
        """Aggregate a batch of session results and save them to one JSON file.

        Args:
            results: Finished sessions to summarize.
            test_name: Optional label. It is stored verbatim in the payload and
                embedded in the file name, with path separators replaced by
                ``_`` so the file always lands in ``LOGS_DIR/results``.

        Returns:
            The dict that was written, or ``None`` when *results* is empty.
            Nothing is written in the empty case.

        NOTE: the file name carries only the date, so a second batch with the
        same *test_name* on the same day overwrites the first.
        """
        if not results:
            logger.warning("No results to save")
            return None

        test_date = datetime.now().strftime("%Y%m%d")
        if test_name:
            filename = f"{test_date}_{self._safe_filename_part(test_name)}_results.json"
        else:
            filename = f"{test_date}_results.json"
        results_path = os.path.join(LOGS_DIR, "results", filename)

        # Overall statistics. results is non-empty here, so the divisions are safe.
        total_puzzles = len(results)
        successful = sum(1 for r in results if r.success)
        failed = total_puzzles - successful
        avg_rounds = sum(r.total_rounds for r in results) / total_puzzles
        success_rate = successful / total_puzzles

        results_data = {
            "test_date": test_date,
            "test_name": test_name,
            "summary": {
                "total_puzzles": total_puzzles,
                "successful": successful,
                "failed": failed,
                "avg_rounds": round(avg_rounds, 2),
                "success_rate": round(success_rate, 3),
            },
            "by_domain": self._domain_stats(results),
            "detailed_results": [self._result_row(r) for r in results],
        }

        self._write_json(results_path, results_data)

        logger.info(f"Batch results saved: {filename}")
        logger.info(f"Test summary: {successful}/{total_puzzles} success ({success_rate:.1%}), avg rounds: {avg_rounds:.1f}")

        return results_data

    def load_session_log(self, session_id: str) -> Dict[str, Any]:
        """Load and return the parsed ``conversation.json`` of *session_id*.

        Raises:
            FileNotFoundError: If the session has no conversation log.
        """
        conversation_path = os.path.join(LOGS_DIR, "sessions", session_id, "conversation.json")
        try:
            with open(conversation_path, 'r', encoding='utf-8') as f:
                return json.load(f)
        except FileNotFoundError:
            raise FileNotFoundError(f"Session log not found: {session_id}") from None

    def list_sessions(self, domain: str = None, date_filter: str = None) -> List[str]:
        """List session ids, i.e. the directory names under ``LOGS_DIR/sessions``.

        Args:
            domain: Keep only ids containing this substring. This presumes
                session ids embed the domain name; verify against the id format.
            date_filter: Keep only ids starting with this prefix.

        Returns:
            Matching ids in descending lexical order. This is newest-first only
            if ids begin with a sortable timestamp (TODO confirm).
        """
        sessions_dir = os.path.join(LOGS_DIR, "sessions")
        if not os.path.exists(sessions_dir):
            return []

        session_ids = [
            item
            for item in os.listdir(sessions_dir)
            if os.path.isdir(os.path.join(sessions_dir, item))
            and (not domain or domain in item)
            and (not date_filter or item.startswith(date_filter))
        ]
        return sorted(session_ids, reverse=True)

    def get_results_summary(self, date_filter: str = None) -> List[Dict[str, Any]]:
        """Summarize every ``*_results.json`` file in ``LOGS_DIR/results``.

        Args:
            date_filter: Keep only files whose name starts with this prefix.

        Returns:
            One entry per file (filename, test_date, test_name, summary), sorted
            by filename in descending order. Unreadable, malformed, or non-object
            files are logged and skipped.
        """
        results_dir = os.path.join(LOGS_DIR, "results")
        if not os.path.exists(results_dir):
            return []

        summaries = []
        for filename in os.listdir(results_dir):
            if not filename.endswith('_results.json'):
                continue
            if date_filter and not filename.startswith(date_filter):
                continue

            results_path = os.path.join(results_dir, filename)
            try:
                with open(results_path, 'r', encoding='utf-8') as f:
                    data = json.load(f)
            except (OSError, ValueError) as e:
                # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
                logger.warning(f"Error reading results file {filename}: {e}")
                continue

            if not isinstance(data, dict):
                logger.warning(f"Error reading results file {filename}: expected a JSON object")
                continue

            summaries.append({
                "filename": filename,
                "test_date": data.get("test_date"),
                "test_name": data.get("test_name"),
                "summary": data.get("summary", {}),
            })

        return sorted(summaries, key=lambda x: x["filename"], reverse=True)