import json
import os
import httpx
import asyncio
from tqdm.asyncio import tqdm

import yaml
import argparse
import logging
import re

from prompts import CONSISTENCY_CHECK_PROMPT  # 修改导入

from datetime import datetime


parser = argparse.ArgumentParser(description="Run the LLM consistency check described by a YAML config.")
# The whole script is driven by this file, so fail fast with a usage message
# instead of an opaque TypeError from open(None) when it is omitted.
parser.add_argument('--config', type=str, required=True, help="Path to the YAML run configuration.")
args = parser.parse_args()


# Load the run configuration (paths, model id, sampling and batching settings).
with open(args.config, 'r', encoding='utf-8') as file:
    config = yaml.safe_load(file)

def load_json(file):
    """Parse and return the JSON document stored (UTF-8 encoded) at path ``file``."""
    with open(file, mode='r', encoding="utf8") as handle:
        return json.load(handle)
    
def write_json(file, dict):
    """Serialize ``dict`` to path ``file`` as UTF-8 JSON with 4-space indentation.

    Non-ASCII characters are written verbatim (``ensure_ascii=False``).
    """
    # The parameter name shadows the builtin; alias it for readability.
    payload = dict
    with open(file, mode="w", encoding="utf8") as out:
        json.dump(payload, out, ensure_ascii=False, indent=4)

def build_user_prompts(data):
    """Render the consistency-check prompt for every record, in order.

    Each record must provide ``question`` and ``formalized_cot``; the latter
    fills the template's ``thinking_process`` slot (the raw ``cot`` is not used).
    """
    return [
        CONSISTENCY_CHECK_PROMPT.format(
            question=record['question'],
            thinking_process=record['formalized_cot'],
        )
        for record in data
    ]


def build_messages(data):
    """Turn each record into a single-turn chat conversation.

    Returns:
        One ``[{'role': 'user', 'content': <prompt>}]`` list per record in
        ``data``, in the same order.
    """
    user_prompts = build_user_prompts(data=data)
    return [
        [{'role': 'user', 'content': f'{user_prompts[idx]}'}]
        for idx in range(len(data))
    ]


def extract_between(text: str, tag: str) -> str | None:
    open_tag, close_tag = f"<{tag}>", f"</{tag}>"
    start = text.find(open_tag) + len(open_tag)
    end = text.find(close_tag)
    if start > len(open_tag)-1 and end > -1:
        return text[start:end].strip()
    return None




# Resolve input/output locations from the config.
input_file = config['input_file']
output_dir = os.path.join(
    config['output_dir'],
    config['dataset_version'],
    config['task']
)

output_file = os.path.join(output_dir, f"{config['task']}{config.get('save_id', '')}.json")

# Per-run price log; the timestamp keeps each run's log in its own file.
log_timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
price_log_path = os.path.join(output_dir, 'price', f'log_{log_timestamp}.txt')

# makedirs also creates output_dir itself, so no separate existence check is needed.
os.makedirs(os.path.join(output_dir, 'price'), exist_ok=True)

# Load the per-model price table.
price_file = config.get('price_file')
with open(price_file, 'r') as p:
    price_data = json.load(p)

os.makedirs(os.path.join(output_dir, 'requests'), exist_ok=True)
log_path = os.path.join(output_dir, 'requests', f"log{config.get('save_id', '')}_{log_timestamp}.txt")

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    filename=log_path,
    filemode='a'
)

model_name = config['model_id']
if model_name not in price_data:
    # Same exception type as the plain lookup, but with an actionable message.
    raise KeyError(f"No price entry for model {model_name!r} in {price_file}")
price = price_data[model_name]

print(model_name)
print(config.get('shot_num'))
print(config.get('save_id'))

headers = {
    'Authorization': config['api_key'],
    'Content-Type': 'application/json',
}

# Number of batches allowed in flight at once; get_batch_response holds one
# slot of this semaphore for the duration of a whole batch.
CONCURRENCY = config.get('concurrency')
semaphore = asyncio.Semaphore(CONCURRENCY)

# Running token/cost totals, updated by send_request and summarized by log_token_stats.
token_stats = {
    'total_prompt_tokens': 0,
    'total_completion_tokens': 0,
    'total_tokens': 0,
    'prompt_price': 0,
    'completion_price': 0,
    'total_price': 0,
    'requests_count': 0
}

def log_token_stats():
    """Append a summary of the accumulated ``token_stats`` to the price log.

    Totals are always written; per-request averages are added only when at
    least one request was recorded (avoiding division by zero).
    """
    count = token_stats['requests_count']
    lines = [
        "\n" + "=" * 50 + "\n",
        "SUMMARY STATISTICS\n",
        f"Total requests: {token_stats['requests_count']}\n",
        f"Total prompt tokens: {token_stats['total_prompt_tokens']}\n",
        f"Total completion tokens: {token_stats['total_completion_tokens']}\n",
        f"Total tokens: {token_stats['total_tokens']}\n",
        f"Total prompt price: {token_stats['prompt_price']}\n",
        f"Total completion price: {token_stats['completion_price']}\n",
        f"Total price: {token_stats['total_price']}\n",
    ]
    if count > 0:
        avg_prompt = token_stats['total_prompt_tokens'] / count
        avg_completion = token_stats['total_completion_tokens'] / count
        avg_total = token_stats['total_tokens'] / count
        lines.extend([
            f"Average prompt tokens per request: {avg_prompt}\n",
            f"Average completion tokens per request: {avg_completion}\n",
            f"Average total tokens per request: {avg_total}\n",
            f"Average prompt price per request: {token_stats['prompt_price'] / token_stats['requests_count']}\n",
            f"Average completion price per request: {token_stats['completion_price'] / token_stats['requests_count']}\n",
            f"Average total price per request: {token_stats['total_price'] / token_stats['requests_count']}\n",
        ])
    with open(price_log_path, 'a') as log_file:
        log_file.writelines(lines)


async def get_batch_response_with_retry(messages, n=1, max_retries=3):
    """Call ``get_batch_response`` with up to ``max_retries`` attempts.

    Waits 1s, 2s, ... between attempts. After the final failed attempt the last
    exception is re-raised. Note that ``get_batch_response`` turns per-request
    failures into strings in its result list, so a retry only happens when the
    call as a whole raises. Returns None if ``max_retries`` is 0.
    """
    for attempt in range(max_retries):
        try:
            return await get_batch_response(messages, n)
        except Exception as exc:
            logging.error(f"Attempt {attempt + 1} failed: {str(exc)}")
            if attempt >= max_retries - 1:
                logging.error(f"Failed after {max_retries} attempts")
                raise exc
            await asyncio.sleep(1 * (attempt + 1))
    return None


async def get_batch_response(messages, n=1):
    """Send every conversation in ``messages`` concurrently and collect the replies.

    The whole batch runs while holding one slot of the module-level
    ``semaphore`` and shares a single ``httpx.AsyncClient``. Sampling
    parameters are read from ``config`` for each request.

    Returns:
        A list aligned with ``messages``. Each entry is what ``send_request``
        returned (a list of ``n`` answer dicts), or, if that request raised an
        ``Exception``, its string form. Per-request failures are reported
        in-band rather than raised.
    """
    def make_payload(conversation):
        return {
            'model': model_name,
            'messages': conversation,
            'temperature': config.get('temperature', 1.0) + config.get('add_temperature', 0.0),
            'top_p': config.get('top_p', 1.0),
            'top_k': config.get('top_k', 50),
            'min_p': config.get('min_p'),
            'stream': False,
        }

    collected = []
    async with semaphore:
        pool_limits = httpx.Limits(max_keepalive_connections=5, max_connections=10)
        async with httpx.AsyncClient(timeout=httpx.Timeout(360.0), limits=pool_limits) as client:
            pending = [
                asyncio.create_task(send_request(client, make_payload(conversation), n))
                for conversation in messages
            ]
            outcomes = await asyncio.gather(*pending, return_exceptions=True)

            for outcome in outcomes:
                if not isinstance(outcome, Exception):
                    collected.append(outcome)
                    continue
                logging.error(f"Error in batch: {str(outcome)}")
                collected.append(str(outcome))

    return collected


async def send_request(client, json_data, n):
    """POST the same chat-completion payload ``n`` times and collect the answers.

    For each response, token usage is priced with the module-level ``price``
    table, accumulated into ``token_stats`` and appended to ``price_log_path``.

    Args:
        client: an open ``httpx.AsyncClient``.
        json_data: request payload posted to ``config['base_url']``.
        n: number of sequential requests (samples) to make.

    Returns:
        A list of ``{'answer': ..., 'reasoning': ...}`` dicts, one per request.

    Raises:
        httpx.HTTPStatusError: if the endpoint answers with a 4xx/5xx status.
        KeyError: if a response body lacks the ``choices``/``usage`` fields.
    """
    prompt_answers = []
    for _ in range(n):
        response = await client.post(config['base_url'], headers=headers, json=json_data)
        # Surface API errors with their status code and URL instead of an
        # opaque KeyError('choices') / JSON decode error further down.
        response.raise_for_status()
        result = response.json()

        message = result['choices'][0]['message']
        ans = {
            'answer': message['content'],
            'reasoning': message.get('reasoning_content', '')
        }

        # Token usage for this request.
        usage = result['usage']
        prompt_tokens = usage['prompt_tokens']
        completion_tokens = usage['completion_tokens']
        total_tokens = usage['total_tokens']

        prompt_price = prompt_tokens * price['prompt']
        completion_price = completion_tokens * price['completion']
        all_price = prompt_price + completion_price

        # Update the running totals (no await between here and the log write,
        # so requests_count below refers to this request).
        token_stats['total_prompt_tokens'] += prompt_tokens
        token_stats['total_completion_tokens'] += completion_tokens
        token_stats['total_tokens'] += total_tokens
        token_stats['prompt_price'] += prompt_price
        token_stats['completion_price'] += completion_price
        token_stats['total_price'] += all_price
        token_stats['requests_count'] += 1

        # Per-request cost log entry.
        with open(price_log_path, 'a') as log_file:
            log_file.write(f"Request ID: {token_stats['requests_count']}\n")
            log_file.write(f"  Prompt tokens: {prompt_tokens}\n")
            log_file.write(f"  Completion tokens: {completion_tokens}\n")
            log_file.write(f"  Total tokens: {total_tokens}\n")
            log_file.write(f"  Prompt price: {prompt_price}\n")
            log_file.write(f"  Completion price: {completion_price}\n")
            log_file.write(f"  Total price: {all_price}\n")
            log_file.write("-" * 40 + "\n")

        prompt_answers.append(ans)

    return prompt_answers


def save_json(data, file_name):
    """Write ``data`` to ``file_name`` as pretty-printed UTF-8 JSON, atomically.

    Non-ASCII text is kept verbatim (``ensure_ascii=False``), so the file is
    encoded explicitly as UTF-8 instead of the platform default encoding.
    The payload goes to a sibling ``.tmp`` file that is then moved into place
    with ``os.replace``. An interrupted run therefore never leaves a truncated
    checkpoint, which matters because ``process_tasks`` re-reads this file to
    resume.
    """
    json_data = json.dumps(data, indent=2, ensure_ascii=False)
    tmp_name = f"{file_name}.tmp"
    with open(tmp_name, 'w', encoding='utf-8') as file:
        file.write(json_data)
    os.replace(tmp_name, file_name)


def split_into_batches(input_list, batch_size):
    """Yield consecutive slices of ``input_list`` of length ``batch_size`` (the last may be shorter)."""
    for i in range(0, len(input_list), batch_size):
        yield input_list[i:i + batch_size]


def _apply_answer(item, answer):
    """Store the model's verdict for one record in place.

    ``answer`` is one entry of ``get_batch_response``'s result: normally a list
    of ``{'answer', 'reasoning'}`` dicts (only the first sample is used), or an
    error string when the request failed. On failure ``item`` gets the error
    text in ``verified_cot`` and no ``'consist'`` key. ``process_tasks`` relies
    on that missing key to retry the record on the next run.
    """
    try:
        response_text = answer[0]['answer']

        if response_text.strip() == "No":
            # NOTE(review): consistent records get no 'verified_cot' here,
            # unlike the match_rate == 100 path which copies formalized_cot —
            # confirm this asymmetry is intended.
            item['consist'] = "Yes"
        else:
            extracted = extract_between(response_text, "consist_thinking")
            item['consist'] = "No"

            if extracted:
                item['verified_cot'] = extracted
            else:
                item['verified_cot'] = None
                logging.warning(f"Failed to extract consist_thinking for id {item['id']}")

    except Exception as e:
        item['verified_cot'] = str(answer)
        logging.error(f"Error processing answer for id {item['id']}: {str(e)}")


def _merge_results(existing_results, newly_processed, skipped):
    """Concatenate the result lists and de-duplicate by ``id`` (later entries win)."""
    merged = {d['id']: d for d in existing_results + newly_processed + skipped}
    return list(merged.values())


async def process_tasks(input_path):
    """Run the consistency check over ``input_path`` and checkpoint to ``output_file``.

    Steps:
      1. Load the input JSON list, sliced by ``config['start']:config['end']``
         (either bound may be omitted from the config).
      2. Records with ``match_rate == 100`` skip the LLM and get ``verified_cot``
         copied from ``formalized_cot``. All other records are checked.
      3. Load any existing ``output_file`` and skip records that already have a
         ``'consist'`` verdict there. Records saved after a failed request have
         no verdict, so they are re-sent on the next run.
      4. Send the remaining prompts in batches of ``config['batch_size']``
         (``config['n']`` samples each), awaiting ``gather_size`` batches at a
         time. Save the merged results every ``save_frequency`` records and at
         the end.

    Raises:
        json.JSONDecodeError: if ``output_file`` exists but is not valid JSON.
            Continuing would overwrite it and discard the saved results.
    """
    with open(input_path, 'r', encoding='utf8') as f:
        dataset = json.load(f)[config.get('start'):config.get('end')]

    # Only records that do not match perfectly need the consistency check.
    dataset_to_check = [d for d in dataset if d.get('match_rate') != 100]
    dataset_no_check = [d for d in dataset if d.get('match_rate') == 100]

    logging.info(f'Total dataset size: {len(dataset)}')
    logging.info(f'Data requiring consistency check (match_rate != 100): {len(dataset_to_check)}')
    logging.info(f'Data skipped (match_rate == 100): {len(dataset_no_check)}')

    # Perfect matches: take formalized_cot as the verified CoT directly.
    for d in dataset_no_check:
        d['verified_cot'] = d.get('formalized_cot', '')

    # Load results from earlier (possibly interrupted) runs.
    try:
        with open(output_file, 'r', encoding='utf8') as f3:
            existing_results = json.load(f3)
        logging.info(f'Loaded existing results: {len(existing_results)}')
    except FileNotFoundError:
        existing_results = []
        logging.info('No existing results found')
    except json.JSONDecodeError:
        # Do not start over: the first checkpoint would overwrite this file
        # and silently discard every previously saved result.
        logging.exception(f'Existing results in {output_file} are not valid JSON')
        raise

    # A record counts as done only if it has a verdict; failed ones are retried.
    processed_id = {d['id'] for d in existing_results if 'consist' in d}
    logging.info(f'Already processed: {len(processed_id)}')

    dataset_to_check = [d for d in dataset_to_check if d['id'] not in processed_id]
    print(f'To be processed: {len(dataset_to_check)}')
    logging.info(f'To be processed: {len(dataset_to_check)}')

    messages = build_messages(dataset_to_check)

    batch_size = config['batch_size']
    # Same count split_into_batches will yield, without materializing the batches.
    total_batches = len(range(0, len(messages), batch_size))
    save_frequency = config.get('save_frequency', 20)
    gather_size = config.get('gather_size', 10)

    newly_processed = []
    processed_requests = 0
    batch_counter = 0
    tasks = []
    pbar = tqdm(total=total_batches, desc="Processing batches")
    try:
        for batch_idx, batch in enumerate(split_into_batches(messages, batch_size)):
            tasks.append(asyncio.create_task(get_batch_response_with_retry(batch, n=config['n'])))
            is_last_batch = batch_idx == total_batches - 1
            if len(tasks) < gather_size and not is_last_batch:
                continue

            # gather preserves task order, so results arrive in batch order and
            # batch_counter maps each one back to its slice of dataset_to_check.
            response_message_batches = await asyncio.gather(*tasks)
            tasks = []

            for response_message_batch in response_message_batches:
                pbar.update(1)
                start_index = batch_counter * batch_size
                batch_counter += 1

                batch_results = dataset_to_check[start_index:start_index + len(response_message_batch)]
                for item, answer in zip(batch_results, response_message_batch):
                    _apply_answer(item, answer)
                newly_processed.extend(batch_results)
                processed_requests += len(response_message_batch)

                # Periodic checkpoint: previous results + new ones + skipped records.
                if processed_requests >= save_frequency or is_last_batch:
                    unique_results = _merge_results(existing_results, newly_processed, dataset_no_check)
                    save_json(unique_results, output_file)
                    logging.info(f'Saved progress: {len(unique_results)} total results')
                    processed_requests = 0
    finally:
        pbar.close()

    # Final save (also covers the case where nothing needed processing).
    unique_results = _merge_results(existing_results, newly_processed, dataset_no_check)
    save_json(unique_results, output_file)
    logging.info(f'Final save: {len(unique_results)} total results')

    log_token_stats()
    
async def main():
    """Start this run's price log with a header, then process the configured input."""
    with open(price_log_path, 'w') as log_file:
        header = [
            f"Consistency Check - Token statistics for model: {model_name}\n",
            f"Date: {datetime.now()}\n",
            "=" * 50 + "\n\n",
        ]
        log_file.writelines(header)

    await process_tasks(input_file)


if __name__ == "__main__":
    # asyncio.run creates a fresh event loop, runs main(), cancels any tasks
    # still pending (e.g. after an error), shuts down async generators and
    # closes the loop.
    asyncio.run(main())