import json
import os
import httpx
import asyncio
from tqdm.asyncio import tqdm
import yaml
import argparse
import logging
from datetime import datetime
import re
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import hashlib

from prompts import JUDGE_USER_PROMPT, JUDGE_SYSTEM_PROMPT

def load_json(file):
    """Read a UTF-8 encoded JSON file and return the parsed content.

    Args:
        file: Path of the JSON file to read.

    Returns:
        The deserialized JSON value (dict, list, str, number, bool or None).
    """
    with open(file, mode='r', encoding="utf8") as handle:
        raw_text = handle.read()
    return json.loads(raw_text)


def write_json(file, data):
    """Write ``data`` to ``file`` as indented, non-ASCII-escaped UTF-8 JSON.

    The parent directory is created when it does not exist yet. A path
    without a directory component is written relative to the current
    working directory.

    Args:
        file: Destination path.
        data: JSON-serializable value.
    """
    # An empty dirname means "current directory"; makedirs('') would fail.
    parent_dir = os.path.dirname(file) or '.'
    os.makedirs(parent_dir, exist_ok=True)

    with open(file, "w", encoding="utf8") as handle:
        json.dump(data, handle, indent=4, ensure_ascii=False)


def extract_boxed_content(latex_string: str) -> Optional[str]:
    r"""Return the content of the last ``\boxed{...}`` in ``latex_string``.

    Nested braces inside the box are balanced, so ``\boxed{\frac{1}{2}}``
    yields ``\frac{1}{2}``. Surrounding whitespace is stripped.

    Args:
        latex_string: Text that may contain one or more ``\boxed{`` groups.

    Returns:
        The stripped content of the last ``\boxed{`` group, or ``None`` when
        the input is empty / not a string, contains no ``\boxed{``, or the
        last group is never closed.
    """
    if not latex_string or not isinstance(latex_string, str):
        return None

    marker = '\\boxed{'
    marker_pos = latex_string.rfind(marker)
    if marker_pos == -1:
        return None

    # Walk forward from just after the opening brace, tracking brace depth
    # until the brace that closes the \boxed group is found.
    content_start = marker_pos + len(marker)
    depth = 1
    for idx in range(content_start, len(latex_string)):
        char = latex_string[idx]
        if char == '{':
            depth += 1
        elif char == '}':
            depth -= 1
            if depth == 0:
                return latex_string[content_start:idx].strip()

    # Ran off the end of the string without closing the group.
    return None


def extract_final_decision(response_text: str) -> str:
    """Extract the judge's verdict from its response text.

    The preferred format is ``[[FINAL DECISION]] Yes`` / ``No`` (matched
    case-insensitively, first occurrence wins). As a fallback the text is
    searched for ``final decision: yes`` / ``final decision: no``.

    Args:
        response_text: Raw text returned by the judge model.

    Returns:
        ``"Yes"`` or ``"No"`` (always in this canonical casing), or
        ``"unknown"`` when no verdict is found or the input is not a string.
    """
    if not isinstance(response_text, str):
        return "unknown"

    pattern = r'\[\[FINAL DECISION\]\]\s*(Yes|No)'
    match = re.search(pattern, response_text, re.IGNORECASE)
    if match:
        # The match is case-insensitive, so normalize "yes"/"YES"/"nO" to the
        # same canonical spelling the fallback branch returns.
        return match.group(1).capitalize()

    response_lower = response_text.lower()
    if "final decision: yes" in response_lower:
        return "Yes"
    if "final decision: no" in response_lower:
        return "No"
    return "unknown"


def generate_task_id(data_id: str, answer_idx: int) -> str:
    """Build the unique task ID for one generated answer of a dataset item.

    The ID is the data ID and the answer index joined by ``###``.
    """
    return "{}###{}".format(data_id, answer_idx)


def parse_task_id(task_id: str) -> Tuple[str, int]:
    """Split a task ID of the form ``<data_id>###<answer_idx>`` into its parts.

    Args:
        task_id: Identifier in the ``<data_id>###<answer_idx>`` format.

    Returns:
        A ``(data_id, answer_idx)`` tuple.

    Raises:
        IndexError: If ``task_id`` contains no ``###`` separator.
        ValueError: If the part after the last ``###`` is not an integer.
    """
    # Split on the LAST separator only, so a data_id that itself contains
    # "###" still round-trips correctly.
    parts = task_id.rsplit("###", 1)
    return parts[0], int(parts[1])


class JudgmentTask:
    """One judging job: a single generated answer compared to the ground truth.

    Besides the constructor arguments, each task tracks its outcome:
    ``result`` (initially ``None``), ``is_completed`` (initially ``False``)
    and ``error_count`` (initially ``0``).
    """

    def __init__(self, task_id: str, data_id: str, answer_idx: int, 
                 question: str, ground_truth: str, student_answer: str,
                 original_answer: str):
        # Identity of the task inside the dataset.
        self.task_id, self.data_id, self.answer_idx = task_id, data_id, answer_idx

        # Content handed to the judge, plus the unprocessed answer text.
        self.question, self.ground_truth = question, ground_truth
        self.student_answer, self.original_answer = student_answer, original_answer

        # Mutable processing state.
        self.result: Optional[dict] = None
        self.is_completed: bool = False
        self.error_count: int = 0


class TaskManager:
    """Tracks judgment tasks and persists each result as a JSON file in a cache directory.

    ``tasks`` maps task IDs to ``JudgmentTask`` objects; ``completed_tasks``
    maps task IDs to the result dicts that were saved or restored from cache.
    """

    def __init__(self, cache_dir: str):
        """Create ``cache_dir`` if necessary and start with empty task tables."""
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self.tasks: Dict[str, JudgmentTask] = {}
        self.completed_tasks: Dict[str, dict] = {}

    def create_tasks_from_dataset(self, dataset: List[dict]) -> List[JudgmentTask]:
        """Register one task per generated answer of every dataset item.

        Each item must provide ``id``, ``question`` and ``answer``; its
        ``generated_answer`` list is optional. The student answer handed to
        the judge is the last ``\\boxed{...}`` content of the raw answer, or
        the placeholder ``<NO EXTRACTED ANSWER>`` when nothing is extracted.

        Returns:
            The newly created tasks, in dataset order.
        """
        created: List[JudgmentTask] = []

        for item in dataset:
            item_id = item['id']
            item_question = item['question']
            reference = item['answer']
            candidates = item.get('generated_answer', [])

            for position, candidate in enumerate(candidates):
                raw_answer = candidate.get('answer', '')
                boxed = extract_boxed_content(raw_answer)

                new_task = JudgmentTask(
                    task_id=generate_task_id(item_id, position),
                    data_id=item_id,
                    answer_idx=position,
                    question=item_question,
                    ground_truth=reference,
                    student_answer=boxed or "<NO EXTRACTED ANSWER>",
                    original_answer=raw_answer,
                )

                self.tasks[new_task.task_id] = new_task
                created.append(new_task)

        return created

    def load_cached_results(self) -> int:
        """Restore results from ``*.json`` files in the cache directory.

        Only files whose ``task_id`` belongs to a registered task are
        considered. Valid results mark the task as completed; invalid ones
        are deleted so the task is processed again. Files that cannot be
        read are logged and skipped.

        Returns:
            The number of results restored.
        """
        restored = 0

        for cache_path in self.cache_dir.glob("*.json"):
            try:
                with open(cache_path, 'r', encoding='utf-8') as handle:
                    cached = json.load(handle)

                cached_id = cached.get('task_id')
                if not cached_id or cached_id not in self.tasks:
                    continue

                if not self._validate_result(cached):
                    logging.warning(f"Invalid cached result for {cached_id}, will reprocess")
                    cache_path.unlink()  # drop the unusable cache entry
                    continue

                self.completed_tasks[cached_id] = cached
                self.tasks[cached_id].is_completed = True
                self.tasks[cached_id].result = cached
                restored += 1
            except Exception as e:
                logging.error(f"Error loading cache {cache_path}: {e}")

        return restored

    def _validate_result(self, result: dict) -> bool:
        """Return True when ``result`` has all required fields and a usable score.

        Scores of ``'error'``, ``'unknown'`` or ``None`` are treated as invalid.
        """
        for field in ('task_id', 'data_id', 'answer_idx', 'score', 'full_response'):
            if field not in result:
                return False

        return result['score'] not in ['error', 'unknown', None]

    def save_result(self, task_id: str, result: dict):
        """Write one task result to ``<cache_dir>/<task_id>.json`` and mark it completed.

        Failures while saving are logged, not raised.
        """
        target = self.cache_dir / f"{task_id}.json"
        try:
            write_json(str(target), result)
            self.completed_tasks[task_id] = result
            if task_id in self.tasks:
                finished = self.tasks[task_id]
                finished.is_completed = True
                finished.result = result
        except Exception as e:
            logging.error(f"Error saving result for {task_id}: {e}")

    def get_pending_tasks(self) -> List[JudgmentTask]:
        """Return the registered tasks that are not completed yet."""
        pending = []
        for candidate in self.tasks.values():
            if candidate.is_completed:
                continue
            pending.append(candidate)
        return pending

    def get_completion_stats(self) -> dict:
        """Return total / completed / pending counts and the completion rate."""
        total = len(self.tasks)
        done = len(self.completed_tasks)
        rate = done / total if total > 0 else 0
        return {
            'total': total,
            'completed': done,
            'pending': total - done,
            'completion_rate': rate,
        }

    def export_final_results(self, original_dataset: List[dict]) -> List[dict]:
        """Return shallow copies of the dataset items with a ``judgments`` list added.

        ``judgments`` has one slot per generated answer; a slot stays ``None``
        when no result is available for that answer.
        """
        exported = []

        for item in original_dataset:
            item_id = item['id']
            enriched = item.copy()

            # One slot per generated answer, filled in below where possible.
            answer_count = len(item.get('generated_answer', []))
            enriched['judgments'] = [None] * answer_count

            for position in range(answer_count):
                lookup_id = generate_task_id(item_id, position)
                if lookup_id not in self.completed_tasks:
                    continue
                verdict = self.completed_tasks[lookup_id]

                # Re-extract the boxed answer from the raw text for the report.
                raw_answer = item['generated_answer'][position].get('answer', '')
                extracted = extract_boxed_content(raw_answer)

                enriched['judgments'][position] = {
                    'full_response': verdict['full_response'],
                    'score': verdict['score'],
                    'reasoning': verdict.get('reasoning', ''),
                    'extracted_answer': extracted,
                }

            exported.append(enriched)

        return exported


# ==================== Concurrent request handling ====================

class TokenStatistics:
    """Token usage and cost accounting shared by concurrent request coroutines.

    Counter updates are serialized with an ``asyncio.Lock``, i.e. they are
    safe across coroutines of one event loop (this is not a thread lock).
    Every recorded request is also appended to a plain-text log file.
    """
    def __init__(self, price_log_path: str, model_name: str, price: dict):
        """Initialize the counters and (re)create the log file with a header.

        Args:
            price_log_path: Path of the text log; parent directories are
                created when missing. An existing file is overwritten.
            model_name: Model name written to the log header.
            price: Mapping with ``'prompt'`` and ``'completion'`` entries;
                each is multiplied by the matching token count.
        """
        self.lock = asyncio.Lock()
        self.stats = {
            'total_prompt_tokens': 0,
            'total_completion_tokens': 0,
            'total_tokens': 0,
            'prompt_price': 0.0,
            'completion_price': 0.0,
            'total_price': 0.0,
            'requests_count': 0,
            'success_count': 0,
            'error_count': 0
        }
        self.price_log_path = price_log_path
        self.model_name = model_name
        self.price = price
        
        # Initialize the log file. A bare filename has an empty dirname, and
        # os.makedirs('') raises, so fall back to the current directory.
        os.makedirs(os.path.dirname(price_log_path) or '.', exist_ok=True)
        with open(price_log_path, 'w', encoding='utf-8') as f:
            f.write(f"Token statistics for model: {model_name}\n")
            f.write(f"Start time: {datetime.now()}\n")
            f.write("=" * 50 + "\n\n")
    
    async def record_usage(self, usage: dict, success: bool = True):
        """Add one request to the running totals and append it to the log.

        Args:
            usage: Usage mapping with optional ``prompt_tokens``,
                ``completion_tokens`` and ``total_tokens`` entries. ``None``
                (or missing / null entries) counts as zero tokens.
            success: Whether the request is counted as a success or an error.
        """
        # Some responses carry "usage": null; treat that like an empty mapping
        # instead of failing on None.get(...).
        usage = usage or {}
        async with self.lock:
            prompt_tokens = usage.get('prompt_tokens') or 0
            completion_tokens = usage.get('completion_tokens') or 0
            # Fall back to the sum when the total is not reported.
            total_tokens = usage.get('total_tokens') or (prompt_tokens + completion_tokens)
            
            prompt_price = prompt_tokens * self.price['prompt']
            completion_price = completion_tokens * self.price['completion']
            total_price = prompt_price + completion_price
            
            self.stats['total_prompt_tokens'] += prompt_tokens
            self.stats['total_completion_tokens'] += completion_tokens
            self.stats['total_tokens'] += total_tokens
            self.stats['prompt_price'] += prompt_price
            self.stats['completion_price'] += completion_price
            self.stats['total_price'] += total_price
            self.stats['requests_count'] += 1
            
            if success:
                self.stats['success_count'] += 1
            else:
                self.stats['error_count'] += 1
            
            # Append the per-request entry while still holding the lock so
            # entries from concurrent coroutines do not interleave.
            with open(self.price_log_path, 'a', encoding='utf-8') as f:
                f.write(f"Request #{self.stats['requests_count']} ({'Success' if success else 'Error'})\n")
                f.write(f"  Prompt tokens: {prompt_tokens}\n")
                f.write(f"  Completion tokens: {completion_tokens}\n")
                f.write(f"  Total tokens: {total_tokens}\n")
                f.write(f"  Total price: {total_price:.6f}\n")
                f.write("-" * 40 + "\n")
    
    def write_summary(self):
        """Append the aggregated totals (and per-request averages) to the log."""
        with open(self.price_log_path, 'a', encoding='utf-8') as f:
            f.write("\n" + "=" * 50 + "\n")
            f.write("SUMMARY STATISTICS\n")
            f.write(f"End time: {datetime.now()}\n")
            f.write(f"Total requests: {self.stats['requests_count']}\n")
            f.write(f"Success: {self.stats['success_count']}\n")
            f.write(f"Errors: {self.stats['error_count']}\n")
            f.write(f"Total prompt tokens: {self.stats['total_prompt_tokens']}\n")
            f.write(f"Total completion tokens: {self.stats['total_completion_tokens']}\n")
            f.write(f"Total tokens: {self.stats['total_tokens']}\n")
            f.write(f"Total price: {self.stats['total_price']:.6f}\n")
            
            # Averages are only meaningful once at least one request was recorded.
            if self.stats['requests_count'] > 0:
                avg_total = self.stats['total_tokens'] / self.stats['requests_count']
                avg_price = self.stats['total_price'] / self.stats['requests_count']
                f.write(f"Average tokens per request: {avg_total:.2f}\n")
                f.write(f"Average price per request: {avg_price:.6f}\n")


class RequestHandler:
    """Sends judge requests to the configured endpoint and stores the verdicts.

    Concurrency is bounded by the semaphore passed in; token usage of every
    request is reported to the shared ``TokenStatistics``.
    """
    def __init__(self, config: dict, headers: dict, token_stats: TokenStatistics, 
                 semaphore: asyncio.Semaphore):
        """Store the request settings.

        Args:
            config: Must contain ``model_id`` and ``base_url``; ``temperature``
                (default 1.0) and ``add_temperature`` (default 0.0) are summed
                to form the sampling temperature.
            headers: HTTP headers sent with every request.
            token_stats: Collector for token usage.
            semaphore: Limits how many tasks are processed at the same time.
        """
        self.config = config
        self.headers = headers
        self.token_stats = token_stats
        self.semaphore = semaphore
        self.model_name = config['model_id']
        self.base_url = config['base_url']
        self.temperature = config.get('temperature', 1.0) + config.get('add_temperature', 0.0)
    
    async def send_single_request(self, client: httpx.AsyncClient, 
                                  messages: List[dict]) -> dict:
        """Send one non-streaming chat request and record its token usage.

        Each request is recorded exactly once, as a success or as an error.

        Returns:
            On success ``{'success': True, 'answer': ..., 'reasoning': ...}``;
            on any failure ``{'success': False, 'error': <message>}``.
        """
        json_data = {
            'model': self.model_name,
            'messages': messages,
            'temperature': self.temperature,
            'stream': False,
        }
        
        # Filled in as soon as the response body is parsed, so a malformed
        # body is still billed with the tokens it reported.
        usage: dict = {}
        try:
            response = await client.post(self.base_url, headers=self.headers, json=json_data)
            response.raise_for_status()
            result = response.json()
            
            # "usage" may be absent or null; both count as "no usage reported".
            usage = result.get('usage') or {}
            message = result['choices'][0]['message']
            reply = {
                'success': True,
                'answer': message['content'],
                'reasoning': message.get('reasoning_content', '')
            }
        
        except Exception as e:
            # repr() keeps the exception type visible even when its message
            # is empty.
            logging.error(f"Request failed: {e!r}")
            await self.token_stats.record_usage(usage, success=False)
            
            return {
                'success': False,
                'error': str(e) or repr(e)
            }
        
        # Record the success only after the body was parsed, so one request
        # is never counted both as a success and as an error.
        await self.token_stats.record_usage(usage, success=True)
        return reply
    
    async def process_task(self, task: JudgmentTask, task_manager: TaskManager, 
                          max_retries: int = 3) -> bool:
        """Judge one task, retrying failed requests with exponential backoff.

        Args:
            task: The task to judge.
            task_manager: Receives the result (or the final error record).
            max_retries: Maximum number of request attempts.

        Returns:
            True when a response was received and saved; False when every
            attempt failed (an error result with score ``'error'`` is saved).
        """
        async with self.semaphore:
            # Build the chat messages.
            prompt = JUDGE_USER_PROMPT.format(
                question=task.question,
                ground_truth=task.ground_truth,
                student_answer=task.student_answer
            )
            
            messages = [
                {'role': 'system', 'content': JUDGE_SYSTEM_PROMPT},
                {'role': 'user', 'content': prompt}
            ]
            
            limits = httpx.Limits(max_keepalive_connections=5, max_connections=10)
            timeout = httpx.Timeout(360.0)
            
            # Retry loop.
            last_error = ''
            for attempt in range(max_retries):
                try:
                    async with httpx.AsyncClient(timeout=timeout, limits=limits) as client:
                        result = await self.send_single_request(client, messages)
                except Exception as e:
                    logging.error(f"Task {task.task_id} attempt {attempt + 1} failed: {e}")
                    result = {'success': False, 'error': str(e) or repr(e)}
                
                if result['success']:
                    # Extract the verdict and persist the result.
                    response_text = result['answer']
                    final_decision = extract_final_decision(response_text)
                    
                    task_result = {
                        'task_id': task.task_id,
                        'data_id': task.data_id,
                        'answer_idx': task.answer_idx,
                        'full_response': response_text,
                        'reasoning': result.get('reasoning', ''),
                        'score': final_decision,
                        'timestamp': datetime.now().isoformat()
                    }
                    
                    task_manager.save_result(task.task_id, task_result)
                    return True
                
                last_error = result.get('error', '')
                task.error_count += 1
                if attempt < max_retries - 1:
                    await asyncio.sleep(2 ** attempt)  # exponential backoff
            
            # Every attempt failed.
            logging.error(f"Task {task.task_id} failed after {max_retries} attempts")
            
            # Persist an error record; it carries score 'error' and the last
            # error message for diagnosis.
            error_result = {
                'task_id': task.task_id,
                'data_id': task.data_id,
                'answer_idx': task.answer_idx,
                'full_response': f"ERROR after {max_retries} attempts",
                'reasoning': '',
                'score': 'error',
                'error_count': task.error_count,
                'last_error': last_error,
                'timestamp': datetime.now().isoformat()
            }
            
            task_manager.save_result(task.task_id, error_result)
            return False


# ==================== Main processing flow ====================

async def process_all_tasks(config: dict, headers: dict, dataset: List[dict], 
                           output_file: str, price_log_path: str, price: dict):
    """Judge every generated answer in ``dataset`` and write the merged results.

    Results already present in the ``cache`` directory next to
    ``output_file`` are reused; only the remaining tasks are sent to the
    judge model, concurrently. The final results are written to
    ``output_file`` and also returned.

    Args:
        config: Run configuration (``model_id``, optional ``concurrency``, ...).
        headers: HTTP headers for the judge requests.
        dataset: Dataset items to judge.
        output_file: Path of the final JSON output.
        price_log_path: Path of the token/price log.
        price: Price entry of the configured model.

    Returns:
        The exported result list.
    """
    # The per-task cache lives next to the output file.
    output_parent = os.path.dirname(output_file)
    task_manager = TaskManager(os.path.join(output_parent, 'cache'))

    # Register one task per generated answer.
    logging.info("Creating tasks from dataset...")
    all_tasks = task_manager.create_tasks_from_dataset(dataset)
    logging.info(f"Created {len(all_tasks)} tasks")

    # Reuse whatever was judged in a previous run.
    logging.info("Loading cached results...")
    cached_count = task_manager.load_cached_results()
    logging.info(f"Loaded {cached_count} cached results")

    pending_tasks = task_manager.get_pending_tasks()
    stats = task_manager.get_completion_stats()
    progress_line = (
        "Total: {total}, Completed: {completed}, "
        "Pending: {pending}, Progress: {completion_rate:.1%}"
    ).format(**stats)
    logging.info(progress_line)

    if pending_tasks:
        # Shared helpers for all request coroutines.
        concurrency_gate = asyncio.Semaphore(config.get('concurrency', 10))
        token_stats = TokenStatistics(price_log_path, config['model_id'], price)
        request_handler = RequestHandler(config, headers, token_stats, concurrency_gate)

        logging.info(f"Processing {len(pending_tasks)} pending tasks...")

        coroutines = []
        for pending in pending_tasks:
            coroutines.append(
                request_handler.process_task(pending, task_manager, max_retries=3)
            )

        # Consume results in completion order while driving the progress bar.
        outcomes = []
        with tqdm(total=len(coroutines), desc="Processing tasks") as pbar:
            for finished in asyncio.as_completed(coroutines):
                outcomes.append(await finished)
                pbar.update(1)

                # Refresh the success/failure counters every 50 finished tasks.
                if len(outcomes) % 50 == 0:
                    succeeded = sum(1 for ok in outcomes if ok)
                    pbar.set_postfix({
                        'success': succeeded,
                        'failed': len(outcomes) - succeeded
                    })

        # Append the totals to the price log.
        token_stats.write_summary()

        success_count = sum(1 for ok in outcomes if ok)
        logging.info(f"Processing completed: {success_count}/{len(outcomes)} successful")
    else:
        logging.info("All tasks already completed!")

    # Merge the judgments back into the dataset structure and persist them.
    logging.info("Exporting final results...")
    final_results = task_manager.export_final_results(dataset)
    write_json(output_file, final_results)

    final_stats = task_manager.get_completion_stats()
    logging.info(f"Final statistics: {final_stats}")

    return final_results


# ==================== Entry point ====================

def main():
    """Command-line entry point.

    Reads the YAML file given by ``--config_file``, derives the output, price
    log and run log paths from it, configures logging, loads the price table
    and the dataset slice, and runs ``process_all_tasks``.
    """
    cli = argparse.ArgumentParser(description='Concurrent judgment system')
    cli.add_argument('--config_file', type=str, required=True, help='Path to config YAML file')
    cli_args = cli.parse_args()

    # Load the run configuration.
    with open(cli_args.config_file, 'r', encoding='utf-8') as config_stream:
        config = yaml.safe_load(config_stream)

    # Derive all paths: <output_dir>/<dataset_version>/<task>/...
    input_file = config['input_file']
    output_root = config['output_dir']
    dataset_version = config['dataset_version']
    task_name = config['task']
    output_dir = os.path.join(output_root, dataset_version, task_name)

    stamp_format = "%Y-%m-%d_%H-%M-%S"
    output_name = f"{task_name}{config.get('save_id', '')}.json"
    output_file = os.path.join(output_dir, output_name)
    price_log_name = f'log_{datetime.now().strftime(stamp_format)}.txt'
    price_log_path = os.path.join(output_dir, 'price', price_log_name)

    os.makedirs(output_dir, exist_ok=True)

    # Log both to a timestamped file and to the console.
    run_log_name = f'log{config.get("save_id", "")}_{datetime.now().strftime(stamp_format)}.txt'
    log_path = os.path.join(output_dir, 'logs', run_log_name)
    os.makedirs(os.path.dirname(log_path), exist_ok=True)

    log_handlers = [
        logging.FileHandler(log_path, encoding='utf-8'),
        logging.StreamHandler(),
    ]
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=log_handlers,
    )

    # Look up the price entry of the configured model.
    with open(config.get('price_file'), 'r', encoding='utf-8') as price_stream:
        price_table = json.load(price_stream)
    price = price_table[config['model_id']]

    # HTTP headers for the judge requests.
    headers = {
        'Authorization': config['api_key'],
        'Content-Type': 'application/json',
    }

    # Load the dataset and keep only the configured [start, end) slice.
    logging.info(f"Loading dataset from {input_file}")
    full_dataset = load_json(input_file)

    start_idx = config.get('start', 0)
    end_idx = config.get('end', len(full_dataset))
    dataset = full_dataset[start_idx:end_idx]

    logging.info(f"Dataset loaded: {len(dataset)} items (from {start_idx} to {end_idx})")
    logging.info(f"Model: {config['model_id']}")
    logging.info(f"Concurrency: {config.get('concurrency', 10)}")
    logging.info(f"Output: {output_file}")

    # Run the whole pipeline on a fresh event loop.
    pipeline = process_all_tasks(config, headers, dataset, output_file, price_log_path, price)
    asyncio.run(pipeline)

    logging.info("All done!")


# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()