import json
import asyncio
import aiohttp
from tqdm import tqdm
from datetime import datetime
from asyncio import Semaphore
from transformers import AutoTokenizer

def extract_think(text):
    """Return the stripped text between the first ``<think>`` and the next ``</think>``.

    If ``text`` is not a string, or either tag is missing, ``text`` is
    returned unchanged so the caller can still fall back to the raw response.
    """
    if not isinstance(text, str):
        return text
    open_tag, close_tag = "<think>", "</think>"
    start = text.find(open_tag)
    if start == -1:
        return text
    start += len(open_tag)
    # Search for the closing tag only after the opening tag; a stray earlier
    # "</think>" would otherwise yield an empty (inverted) slice.
    end = text.find(close_tag, start)
    if end == -1:
        return text
    return text[start:end].strip()

class RateLimiter:
    """Fixed-window limiter granting at most ``rate`` acquisitions per second.

    Meant to be shared by coroutines on one event loop: ``wait`` has no
    ``await`` between the window check and the counter increment, so
    coroutines on the same loop cannot over-admit.
    """

    def __init__(self, rate):
        self.rate = rate  # max acquisitions per one-second window
        self.current = 0  # acquisitions granted in the current window
        self.last_check = datetime.now()  # start of the current window

    async def wait(self):
        """Suspend until a slot is available in the current one-second window."""
        while True:
            now = datetime.now()
            # total_seconds() keeps the fractional part; timedelta.seconds is
            # truncated, which made the old `> 1` test wait for 2+ seconds.
            elapsed = (now - self.last_check).total_seconds()
            # A negative value means the wall clock stepped backwards; start a new window.
            if elapsed >= 1 or elapsed < 0:
                self.last_check = now
                self.current = 0
                elapsed = 0.0
            if self.current < self.rate:
                self.current += 1
                return
            # Sleep only until the current window ends, not a fixed second.
            await asyncio.sleep(1 - elapsed)

async def request_chat_async(rate_limiter, semaphore, session, url, tokenizer, line, max_retries=10):
    """Request a safety analysis plus final answer for one record.

    Builds a prompt from ``line['prompt']`` and the reasoning extracted from
    ``line['response']``. The prompt is rendered with
    ``tokenizer.apply_chat_template`` and POSTed to ``url`` as
    ``{"text": [...], "sampling_params": {...}}``.

    On success, stores ``data[0]['text']`` of the JSON reply in
    ``line['self_final']`` and returns ``line``, which is mutated in place.
    After ``max_retries`` failed attempts it prints a message and returns ``""``.
    Each attempt first waits on ``rate_limiter`` and then holds ``semaphore``
    while the request is in flight.
    """
    query = line['prompt']
    think = extract_think(line['response'])
    content = f'''
    You are an AI language model tasked with analyzing the safety of queries and producing safe, guided final answers. For each input, perform the following steps:

    ### Instructions:

    1. **Generate a <safe> Section**:  
    Produce a concise safety analysis based on the given `query` and `<think>` content. This step should include:
    - **Core Issue Safety Analysis**: Briefly identify the core issue and assess for any potential safety risks. Remember, uncertainty also qualifies as a risk.
    - **Reasoning Process Safety Analysis**: If the core issue has potential risks, evaluate whether the `<think>` content sufficiently considers these risks, and summarize your findings briefly.
    - **Final Answer Guidance Content**: Provide short guidance on crafting a safe and appropriate final answer based on your analyses.

    2. **Generate a Detailed Final Answer**:  
    Using the analyses from the `<safe>` section, formulate a detailed and safe final answer that comprehensively addresses the query while mitigating identified risks.

    ### Input Format:
    - **Query**: {query}
    - **Think**: {think}

    ### Output Format:
    ```plain
    <safe>
    1. Core Issue Safety Analysis: [Concise analysis here]
    2. Reasoning Process Safety Analysis: [Concise analysis, if applicable]
    3. Final Answer Guidance Content: [Brief guidance here]
    </safe>

    Final Answer: [Detailed and thorough response based on <think> and <safe> analyses]
    '''
    temperature = 0.5
    messages = [
        {"role": "user", "content": content}
    ]
    input_str = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False, enable_thinking=True)
    data_entry = {
        "text": [input_str],
        "sampling_params": {
            "temperature": temperature,
            "max_new_tokens": 2048,
            "skip_special_tokens": False,
        },
    }
    for _ in range(max_retries):
        await rate_limiter.wait()
        async with semaphore:
            try:
                async with session.post(url, json=data_entry) as response:
                    # Count non-2xx replies as failed attempts with a clear error
                    # instead of parsing an error body and failing on indexing.
                    response.raise_for_status()
                    raw_data = await response.read()
                data = json.loads(raw_data.decode('utf-8'))
                line['self_final'] = data[0]['text']
                return line
            except Exception as e:
                print(f'ds api exception: {e}')
        # Back off after releasing the semaphore so a failing request does not
        # occupy a concurrency slot while it sleeps.
        await asyncio.sleep(1)
    print('Maximum retry attempts reached, returning error')
    return ""
              
async def process_prompts_chunk_async(rate_limiter, semaphore, session, url, tokenizer, line, max_retries=10):
    """Process a single record by delegating to ``request_chat_async``.

    Returns whatever ``request_chat_async`` returns for ``line``.
    """
    return await request_chat_async(
        rate_limiter,
        semaphore,
        session,
        url,
        tokenizer,
        line,
        max_retries,
    )

async def gen_assistant_async(data, url, tokenizer, outpath, qps=2, max_concurrent=20, max_retries=4):
    """Send every record in ``data`` to ``url`` concurrently and dump the results to ``outpath``.

    Args:
        data: iterable of records handed to ``request_chat_async``.
        url: generation endpoint.
        tokenizer: passed through to ``request_chat_async`` for prompt rendering.
        outpath: JSON file written once all requests have finished.
        qps: request starts allowed per second (``RateLimiter``).
        max_concurrent: maximum in-flight requests (``Semaphore``).
        max_retries: attempts per record.

    Results are written in the same order as ``data``. A record whose
    attempts all failed appears as whatever ``request_chat_async`` returned.
    """
    rate_limiter = RateLimiter(qps)
    semaphore = Semaphore(max_concurrent)
    timeout = aiohttp.ClientTimeout(total=6000)

    async def _run_indexed(session, index, line):
        # as_completed yields in completion order; carry the input position
        # along so results can be written back in input order.
        result = await request_chat_async(rate_limiter, semaphore, session, url, tokenizer, line, max_retries)
        return index, result

    async with aiohttp.ClientSession(timeout=timeout) as session:
        tasks = [_run_indexed(session, i, line) for i, line in enumerate(data)]
        results = [None] * len(tasks)
        for future in tqdm(asyncio.as_completed(tasks), total=len(tasks)):
            index, line = await future
            results[index] = line
    with open(outpath, 'w', encoding='utf-8') as f:
        json.dump(results, f, ensure_ascii=False, indent=4)

def gen(path, outpath, url, tokenizer_path='', qps=400, max_concurrent=600, max_retries=4):
    """Load records from ``path``, generate answers via ``url``, and write them to ``outpath``.

    Args:
        path: UTF-8 JSON file holding the list of input records.
        outpath: destination JSON file.
        url: generation endpoint.
        tokenizer_path: name or path passed to ``AutoTokenizer.from_pretrained``.
        qps: request starts allowed per second.
        max_concurrent: maximum in-flight requests.
        max_retries: attempts per record.

    Returns without loading the tokenizer or writing anything when the input is empty.
    """
    # Read with the same encoding the output is written with.
    with open(path, encoding='utf-8') as f:
        data = json.load(f)
    if not data:
        return

    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
    # asyncio.run creates and closes its own loop; calling
    # asyncio.get_event_loop() with no running loop is deprecated.
    asyncio.run(gen_assistant_async(data, url, tokenizer, outpath, qps=qps, max_concurrent=max_concurrent, max_retries=max_retries))

if __name__ == '__main__':
    # Fill in the input JSON path, output JSON path and endpoint before running.
    input_path, output_path, server_url = "", "", ""
    gen(input_path, output_path, server_url)
