import json
import os
import httpx
import asyncio
from tqdm.asyncio import tqdm

import yaml
import argparse
import logging
import re

from prompts import DATA_CONSTRUCTION_USER_PROMPT

from datetime import datetime



# Command-line interface: the only option is the path to the YAML run configuration.
parser = argparse.ArgumentParser()
# Nothing below can run without a configuration file, so let argparse reject a
# missing --config with a usage message instead of failing later in open(None).
parser.add_argument('--config', type=str, required=True)
args = parser.parse_args()


# Read the run configuration. The encoding is pinned to UTF-8 (as the JSON
# helpers in this file do) so the result does not depend on the platform locale.
with open(args.config, 'r', encoding='utf-8') as file:
    config = yaml.safe_load(file)

def load_json(file):
    """Parse the UTF-8 encoded JSON document stored at ``file`` and return it."""
    with open(file, encoding="utf8") as handle:
        raw_text = handle.read()
    return json.loads(raw_text)
    
def write_json(file, dict):
    """Serialize ``dict`` as 4-space-indented JSON and write it to ``file`` as UTF-8.

    Non-ASCII characters are written as-is rather than as ``\\uXXXX`` escapes.
    """
    rendered = json.dumps(dict, indent=4, ensure_ascii=False)
    with open(file, "w", encoding="utf8") as sink:
        sink.write(rendered)

def build_user_prompts(data):
    """Render one user prompt per record in ``data``.

    Each record must provide a ``'question'`` and a ``'cot'`` entry; they are
    substituted into ``DATA_CONSTRUCTION_USER_PROMPT`` as ``question`` and
    ``original_response`` respectively.

    Returns:
        A list of prompt strings, in the same order as ``data``.
    """
    return [
        DATA_CONSTRUCTION_USER_PROMPT.format(
            question=record['question'],
            original_response=record['cot'],
        )
        for record in data
    ]


def build_messages(data):
    """Wrap the prompt of every record in a single-turn chat conversation.

    Returns:
        A list with one entry per record; each entry is a one-element list
        holding a ``{'role': 'user', 'content': <prompt>}`` message.
    """
    user_prompts = build_user_prompts(data=data)
    return [
        [{'role': 'user', 'content': f'{user_prompts[idx]}'}]
        for idx in range(len(data))
    ]



def extract_after_compressed_reasoning(text):
    """Return the stripped text following the first ``[[COMPRESSED REASONING]]`` marker.

    Returns ``None`` when the marker does not occur in ``text``.
    """
    marker = "[[COMPRESSED REASONING]]"
    if marker not in text:
        return None
    # partition() splits at the first occurrence; only the tail is of interest.
    _, _, remainder = text.partition(marker)
    return remainder.strip()


def count_code(text):
    """Count the ``<interpreter>...</interpreter>`` sections found in ``text``.

    ``None`` is treated as text without any section and yields 0.
    """
    if text is None:
        return 0

    interpreter_block = r'<interpreter>\s*([\s\S]*?)\s*</interpreter>'
    return sum(1 for _ in re.finditer(interpreter_block, text))

def extract_between(text: str, tag: str) -> str | None:
    open_tag, close_tag = f"<{tag}>", f"</{tag}>"
    start = text.find(open_tag) + len(open_tag)
    end = text.find(close_tag)
    if start > len(open_tag)-1 and end > -1:
        return text[start:end].strip()
    return None


def extract_correct_thinking(dataset, max_candidates=8):
    """Keep the records that have at least one correct, extractable chain of thought.

    For every record the first ``max_candidates`` (answer, judgment) pairs are
    inspected in order. The first pair whose judgment score is ``"Yes"`` and
    whose answer contains a non-empty ``<think>...</think>`` section wins.

    The winning record is modified IN PLACE: ``'cot'`` (the think section) and
    ``'summary'`` (the text after the last ``</think>``) are added, and
    ``'generated_answer'`` and ``'judgments'`` are removed. Records without a
    usable pair are left untouched and are not returned.

    Args:
        dataset: Iterable of dicts with ``'generated_answer'`` and
            ``'judgments'`` entries.
        max_candidates: How many leading candidates to inspect per record.

    Returns:
        The list of (mutated) records for which a chain of thought was found.

    Raises:
        KeyError: If a record lacks ``'generated_answer'`` or ``'judgments'``.
    """
    processed_dataset = []
    for d in dataset:
        answers = d['generated_answer']
        judgment = d['judgments']

        for i in range(max_candidates):
            try:
                if judgment[i]['score'] != "Yes":
                    continue
                raw_answer = answers[i]['answer']
                cot = extract_between(raw_answer, 'think')
                summary = raw_answer.split("</think>")[-1].strip()
            except (KeyError, IndexError, TypeError, AttributeError):
                # Missing or malformed candidate (short list, None entry,
                # absent key, non-string answer): skip it, best effort.
                continue

            if not cot:
                continue

            # Have at least one extractable correct CoT.
            d['cot'] = cot
            d['summary'] = summary
            d.pop('generated_answer')
            d.pop('judgments')
            processed_dataset.append(d)
            break
    return processed_dataset
     
        
import os
import json
from copy import deepcopy  # used to avoid mutating the original dataset

def extract_correct_thinking_sample(dataset, sample_num, output_meta_path=None, max_candidates=8):
    """Sample records per pass-rate tier and attach their first valid correct CoT.

    Records are bucketed by ``yes_ratio`` (the share of judgments scored
    ``"Yes"``) into four tiers: 1.0, 0.75, 0.5 and 0.25. For each tier, up to
    ``sample_num`` records are taken in dataset order; a record only counts
    if one of its first ``max_candidates`` answers is judged ``"Yes"`` and
    contains a non-empty ``<think>...</think>`` section.

    Every selected record is a deep copy of the original with:
      - all original fields kept,
      - ``'generated_answer'`` and ``'judgments'`` removed,
      - ``'cot'``, ``'summary'`` and ``'pass_rate'`` added.

    Records whose ``'id'`` is missing or falsy, or that lack
    ``'generated_answer'`` / ``'judgments'``, are skipped.

    Args:
        dataset: List of raw records.
        sample_num: Maximum number of records to take per tier.
        output_meta_path: Optional path of a JSON file that receives the
            selected ids per tier.
        max_candidates: How many leading candidates to inspect per record.

    Returns:
        The selected records, ordered by tier 1.0 -> 0.75 -> 0.5 -> 0.25.
        The input ``dataset`` is not modified.
    """
    target_ratios = [1.0, 0.75, 0.5, 0.25]
    ratio_to_samples = {r: [] for r in target_ratios}
    ratio_to_ids = {r: [] for r in target_ratios}

    def compute_yes_ratio(judgments):
        """Return the fraction of judgments scored "Yes", rounded to 4 decimals."""
        if not judgments:
            return 0.0
        # Every entry counts towards the total; only dict entries can be a "Yes".
        yes_count = sum(
            1 for j in judgments
            if isinstance(j, dict) and j.get("score") == "Yes"
        )
        return round(yes_count / len(judgments), 4)

    for d in dataset:
        if not d.get('id') or 'generated_answer' not in d or 'judgments' not in d:
            continue

        yes_ratio = compute_yes_ratio(d['judgments'])
        if yes_ratio not in target_ratios:
            continue
        if len(ratio_to_samples[yes_ratio]) >= sample_num:
            continue

        answers = d['generated_answer']
        judgments = d['judgments']

        for i in range(min(max_candidates, len(answers), len(judgments))):
            try:
                if judgments[i] is None:
                    continue
                if judgments[i].get('score') != "Yes":
                    continue

                raw_answer = answers[i]['answer']
                extracted_cot = extract_between(raw_answer, 'think')
                if not extracted_cot:
                    continue

                # Deep-copy the record so the caller's dataset stays untouched.
                new_d = deepcopy(d)

                # Add the derived fields.
                new_d['cot'] = extracted_cot
                new_d['summary'] = raw_answer.split("</think>")[-1].strip()
                new_d['pass_rate'] = yes_ratio

                # Drop the bulky raw fields.
                new_d.pop('generated_answer', None)
                new_d.pop('judgments', None)

                ratio_to_samples[yes_ratio].append(new_d)
                ratio_to_ids[yes_ratio].append(d['id'])
                break
            except (KeyError, IndexError, AttributeError, TypeError):
                # Malformed candidate: try the next one.
                continue

    # Merge the tiers in the order 1.0 -> 0.75 -> 0.5 -> 0.25.
    processed_dataset = []
    for ratio in target_ratios:
        processed_dataset.extend(ratio_to_samples[ratio])

    # Save the meta file.
    if output_meta_path:
        meta = {str(r): ratio_to_ids[r] for r in target_ratios}
        # dirname() is '' for a bare file name, and os.makedirs('') raises,
        # so only create the directory when the path actually has one.
        meta_dir = os.path.dirname(output_meta_path)
        if meta_dir:
            os.makedirs(meta_dir, exist_ok=True)
        with open(output_meta_path, 'w', encoding='utf-8') as f:
            json.dump(meta, f, indent=2, ensure_ascii=False)
        print(f"Meta file saved to: {output_meta_path}")

    return processed_dataset


   
# ---------------------------------------------------------------------------
# Module-level runtime setup: resolve paths from the config, create the output
# directories, configure logging and initialise the shared request state.
# ---------------------------------------------------------------------------
input_file = config['input_file']
# Results live under <output_dir>/<dataset_version>/<task>/.
output_dir = os.path.join(
    config['output_dir'], 
    config['dataset_version'],
    config['task']
    # config['model_id'].split('/')[-1])
)

# Result file name: "<task><save_id>.json" (save_id is optional and defaults to '').
output_file = os.path.join(output_dir, f"{config['task']}{config.get('save_id', '')}.json")

# Set up the log files; one timestamp is shared by the price log and the request log.
log_timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
price_log_path = os.path.join(output_dir, 'price', f'log_{log_timestamp}.txt')

os.makedirs(os.path.join(output_dir, 'price'), exist_ok=True)

# Load the price configuration: a JSON object indexed by model id below.
# NOTE(review): config.get() returns None when 'price_file' is absent, which
# makes open() fail with a TypeError — confirm the key is mandatory.
price_file = config.get('price_file')
with open(price_file, 'r') as p:
    price_data = json.load(p)

# log_path = os.path.join(output_dir, 'log.txt')

os.makedirs(os.path.join(output_dir, 'requests'), exist_ok=True)
log_path = os.path.join(output_dir, 'requests', f"log{config.get('save_id', '')}_{log_timestamp}.txt")

# Route all `logging` calls of this script to the per-run request log.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    filename=log_path,
    filemode='a'  # append mode
)



# NOTE(review): output_dir has already been created as a parent of the
# 'price' and 'requests' directories above, so this looks redundant.
if not os.path.exists(output_dir):
    os.makedirs(output_dir)
    
model_name = config['model_id']

# Price entry of the selected model; send_request reads its 'prompt' and
# 'completion' values and multiplies them by the token counts.
price = price_data[model_name]


print(model_name)
print(config.get('shot_num'))
print(config.get('save_id'))

# HTTP headers sent with every API request.
headers = {
    'Authorization': config['api_key'],
    'Content-Type': 'application/json',
}
# Concurrency level.
# NOTE(review): this is None when 'concurrency' is missing from the config —
# confirm the key is mandatory.
CONCURRENCY = config.get('concurrency')
# Semaphore used to bound how many batches run at the same time.
semaphore = asyncio.Semaphore(CONCURRENCY)

# Running token and price totals, updated by send_request after every response.
token_stats = {
    'total_prompt_tokens': 0,
    'total_completion_tokens': 0,
    'total_tokens': 0,
    'prompt_price': 0,
    'completion_price': 0,
    'total_price': 0,
    'requests_count': 0
}

def log_token_stats():
    """Append the summary section (totals and per-request averages) to the price log.

    Reads the module-level ``token_stats`` counters and writes to
    ``price_log_path``. The averages are omitted when no request was counted.
    """
    stats = token_stats
    request_total = stats['requests_count']

    with open(price_log_path, 'a') as log_file:
        report = [
            "\n" + "=" * 50 + "\n",
            "SUMMARY STATISTICS\n",
            f"Total requests: {request_total}\n",
            f"Total prompt tokens: {stats['total_prompt_tokens']}\n",
            f"Total completion tokens: {stats['total_completion_tokens']}\n",
            f"Total tokens: {stats['total_tokens']}\n",
            f"Total prompt price: {stats['prompt_price']}\n",
            f"Total completion price: {stats['completion_price']}\n",
            f"Total price: {stats['total_price']}\n",
        ]

        if request_total > 0:
            report += [
                f"Average prompt tokens per request: {stats['total_prompt_tokens'] / request_total}\n",
                f"Average completion tokens per request: {stats['total_completion_tokens'] / request_total}\n",
                f"Average total tokens per request: {stats['total_tokens'] / request_total}\n",
                f"Average prompt price per request: {stats['prompt_price'] / request_total}\n",
                f"Average completion price per request: {stats['completion_price'] / request_total}\n",
                f"Average total price per request: {stats['total_price'] / request_total}\n",
            ]

        log_file.writelines(report)

# async def get_batch_response_with_retry(messages, n=1, max_retries=3):
#     for attempt in range(max_retries):
#         try:
#             return await get_batch_response(messages, n)
#         except Exception as e:
#             if attempt == max_retries - 1:
#                 raise
#             await asyncio.sleep(1 * (attempt + 1)) 

async def get_batch_response_with_retry(messages, n=1, max_retries=3):
    """Call ``get_batch_response`` and retry with a growing delay when it raises.

    After the k-th failed attempt (1-based) the coroutine sleeps k seconds
    before trying again. Every failure is logged; once ``max_retries``
    attempts have failed, the last exception is re-raised.

    Args:
        messages: The batch of conversations to send.
        n: Number of completions requested per conversation.
        max_retries: Maximum number of attempts.

    Returns:
        Whatever ``get_batch_response`` returns. When ``max_retries`` is not
        positive no attempt is made and ``None`` is returned.
    """
    attempt_no = 0
    while attempt_no < max_retries:
        attempt_no += 1
        try:
            return await get_batch_response(messages, n)
        except Exception as exc:
            logging.error(f"Attempt {attempt_no} failed: {str(exc)}")
            if attempt_no >= max_retries:
                logging.error(f"Failed after {max_retries} attempts")
                raise exc
            # Linear back-off: 1s, 2s, 3s, ...
            await asyncio.sleep(1 * attempt_no)
            



async def get_batch_response(messages, n=1):
    """Send every conversation of a batch concurrently and collect the outcomes.

    The whole batch runs while holding the module-level ``semaphore``. One
    ``httpx.AsyncClient`` is created per batch and one task per conversation
    is started on it.

    Args:
        messages: List of conversations (each a list of chat messages).
        n: Number of completions requested per conversation.

    Returns:
        A list aligned with ``messages``. Each entry is the value returned by
        ``send_request`` or, when that request raised an ``Exception``, the
        exception converted to a string (the error is also logged).
    """
    collected = []
    async with semaphore:
        pool_limits = httpx.Limits(max_keepalive_connections=5, max_connections=10)
        request_timeout = httpx.Timeout(360.0)
        async with httpx.AsyncClient(timeout=request_timeout, limits=pool_limits) as client:
            pending = [
                asyncio.create_task(
                    send_request(
                        client,
                        {
                            'model': model_name,
                            'messages': conversation,
                            'temperature': config.get('temperature', 1.0) + config.get('add_temperature', 0.0),
                            'stream': False,
                        },
                        n,
                    )
                )
                for conversation in messages
            ]

            # return_exceptions=True keeps one failed request from cancelling the rest.
            outcomes = await asyncio.gather(*pending, return_exceptions=True)

            for outcome in outcomes:
                if not isinstance(outcome, Exception):
                    collected.append(outcome)
                    continue
                logging.error(f"Error in batch: {str(outcome)}")
                collected.append(str(outcome))
    return collected



async def send_request(client, json_data, n):
    """POST the same payload ``n`` times, one after another, and record the cost.

    For every response the first choice is extracted, the module-level
    ``token_stats`` counters are updated from the ``usage`` section and a
    per-request entry is appended to the price log.

    Args:
        client: The HTTP client used for the POST requests.
        json_data: The JSON request body.
        n: Number of sequential requests to send.

    Returns:
        A list of ``n`` dicts with the keys ``'answer'`` (message content) and
        ``'reasoning'`` (``reasoning_content`` of the message, ``''`` if absent).
    """
    collected = []
    for _ in range(n):
        response = await client.post(config['base_url'], headers=headers, json=json_data)
        payload = response.json()

        # The choice is read before the usage section on purpose: the order
        # decides which missing key is reported for a malformed response.
        first_message = payload['choices'][0]['message']
        ans = {
            'answer': first_message['content'],
            'reasoning': first_message.get('reasoning_content', ''),
        }

        # Token usage reported by the API.
        usage = payload['usage']
        prompt_tokens = usage['prompt_tokens']
        completion_tokens = usage['completion_tokens']
        total_tokens = usage['total_tokens']

        prompt_price = prompt_tokens * price['prompt']
        completion_price = completion_tokens * price['completion']
        all_price = prompt_price + completion_price

        # Update the running totals.
        increments = (
            ('total_prompt_tokens', prompt_tokens),
            ('total_completion_tokens', completion_tokens),
            ('total_tokens', total_tokens),
            ('prompt_price', prompt_price),
            ('completion_price', completion_price),
            ('total_price', all_price),
            ('requests_count', 1),
        )
        for counter, amount in increments:
            token_stats[counter] += amount

        # Append the per-request entry to the price log.
        entry = [
            f"Request ID: {token_stats['requests_count']}\n",
            f"  Prompt tokens: {prompt_tokens}\n",
            f"  Completion tokens: {completion_tokens}\n",
            f"  Total tokens: {total_tokens}\n",
            f"  Prompt price: {prompt_price}\n",
            f"  Completion price: {completion_price}\n",
            f"  Total price: {all_price}\n",
            "-" * 40 + "\n",
        ]
        with open(price_log_path, 'a') as log_file:
            log_file.writelines(entry)

        collected.append(ans)

    return collected

def save_json(data, file_name):
    """Write ``data`` to ``file_name`` as 2-space-indented JSON.

    The document is serialized before the file is opened, so a serialization
    error does not truncate an existing file.
    """
    serialized = json.dumps(data, indent=2, ensure_ascii=False)
    with open(file_name, 'w') as sink:
        sink.write(serialized)

def split_into_batches(input_list, batch_size):
    """Lazily yield consecutive slices of ``input_list`` of at most ``batch_size`` items."""
    offsets = range(0, len(input_list), batch_size)
    yield from (input_list[begin:begin + batch_size] for begin in offsets)

async def process_tasks(input_path):
    """Run the full generation job for the dataset stored at ``input_path``.

    Steps:
      1. Load the dataset and, if the config provides both ``start`` and
         ``end``, restrict it to that slice.
      2. Resume: load previously saved results from ``output_file`` and keep
         those that belong to the current slice and carry a usable
         ``formalized_cot``; the matching records are not requested again.
      3. Send the remaining records in batches of ``config['batch_size']``,
         awaiting them in groups of up to 10 batches, and store the extracted
         ``<revised_thinking_process>`` section (or, on failure, the stringified
         response) in ``formalized_cot``.
      4. Save progress periodically, then add ``call_num`` to every result,
         save once more and append the token summary to the price log.

    Args:
        input_path: Path of the JSON dataset (a list of records with an ``id``).
    """
    with open(input_path, 'r') as f:
        dataset = json.load(f)

    try:
        dataset = dataset[config['start']: config['end']]
    except (KeyError, TypeError):
        # Slicing is optional: without usable start/end the full dataset is processed.
        logging.info('No usable start/end in config; processing the full dataset')

    # Sets make the membership tests below O(1) instead of scanning a list per record.
    current_dataset_ids = {d['id'] for d in dataset}

    # Values of formalized_cot that mark a failed request; such results are retried.
    failure_markers = ("All connection attempts failed", "'choices'")

    try:
        with open(output_file, 'r') as f3:
            results = json.load(f3)
        logging.info(f'Loaded results: {len(results)}')
        results = [d for d in results if d['id'] in current_dataset_ids]
        results = [
            d for d in results
            if d["formalized_cot"] and d["formalized_cot"] not in failure_markers
        ]
        logging.info(f'Filtered results (valid data): {len(results)}')
    except FileNotFoundError:
        results = []
        logging.info('No existing results found')
    except (OSError, ValueError, KeyError, TypeError) as exc:
        # Unreadable or malformed results file: start over, but say why, since
        # the file is going to be overwritten.
        results = []
        logging.warning(f'Could not reuse existing results ({exc!r}); starting from scratch')

    processed_id = [d['id'] for d in results]
    processed_id_set = set(processed_id)

    logging.info(f'Total dataset size: {len(dataset)}')
    logging.info(f'Already processed: {len(processed_id)}')

    print(len(dataset))
    print('processed: ', len(processed_id))

    dataset = [d for d in dataset if d['id'] not in processed_id_set]

    print(f'to be processed: {len(dataset)}')

    logging.info(f'To be processed: {len(dataset)}')

    messages = build_messages(dataset)

    batch_size = config['batch_size']
    batches = list(split_into_batches(messages, batch_size))
    total_batches = len(batches)
    pbar = tqdm(total=total_batches, desc="Processing batches")

    # Counter and threshold for the periodic saves.
    processed_requests = 0
    save_frequency = 20  # save after every 20 processed requests

    tasks = []
    try:
        for batch_idx, batch in enumerate(batches):
            task = asyncio.create_task(get_batch_response_with_retry(batch, n=config['n']))
            tasks.append(task)

            # Await in groups of 10 batches (or at the last batch) so that
            # progress can be saved more often.
            if len(tasks) >= 10 or batch_idx == total_batches - 1:
                response_message_batches = await asyncio.gather(*tasks)
                tasks = []  # reset for the next group

                for i, response_message_batch in enumerate(response_message_batches):
                    pbar.update(1)
                    # Index of the first dataset record covered by this batch:
                    # the group ends at batch_idx, so its i-th batch has the
                    # global batch number batch_idx - group_size + i + 1.
                    start_index = (batch_idx - len(response_message_batches) + i + 1) * batch_size
                    for j, answer in enumerate(response_message_batch):
                        dataset_index = start_index + j
                        if dataset_index < len(dataset):
                            try:
                                dataset[dataset_index]['formalized_cot'] = extract_between(answer[0]['answer'], "revised_thinking_process")
                            except (TypeError, KeyError, IndexError, AttributeError):
                                # A failed request arrives as an error string
                                # (or otherwise malformed); keep its text so
                                # the resume filter can recognise it.
                                dataset[dataset_index]['formalized_cot'] = str(answer)

                    results.extend(dataset[start_index:start_index + len(response_message_batch)])
                    processed_requests += len(response_message_batch)

                    # Save once 'save_frequency' requests have accumulated, or at the very end.
                    if processed_requests >= save_frequency or (batch_idx == total_batches - 1 and i == len(response_message_batches) - 1):
                        save_json(results, output_file)
                        logging.info(f'Saved progress: {len(results)} results processed')
                        processed_requests = 0  # reset after saving
    finally:
        pbar.close()

    # Count the <interpreter> sections of every result.
    for r in results:
        r['call_num'] = count_code(r["formalized_cot"])

    # Final save to make sure everything is stored.
    save_json(results, output_file)
    logging.info(f'Final save: {len(results)} total results')

    # Write the final token statistics.
    log_token_stats()
    

async def main():
    """Start a fresh price log with a header, then process the configured input file."""
    with open(price_log_path, 'w') as log_file:
        log_file.writelines([
            f"Token statistics for model: {model_name}\n",
            f"Date: {datetime.now()}\n",
            "=" * 50 + "\n\n",
        ])

    await process_tasks(input_file)
    
    
# Script entry point: run main() on a dedicated event loop, then stamp the
# price log with the finishing time.
if __name__ == "__main__":
    
    # asyncio.run(main())
    # A new event loop is created and installed explicitly instead of using
    # asyncio.run() (kept above for reference).
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        loop.run_until_complete(main())
    finally:
        # Close the loop even when main() raised.
        loop.close()
        
    # Only reached when main() completed without raising: append the end time.
    with open(price_log_path, 'a') as log_file:
        log_file.write(f"Date: {datetime.now()}\n")