import os
import json
import asyncio
import re
import argparse
from typing import List
from tqdm import tqdm
import openai

# Seed the OpenAI client settings from the environment when the variables are
# set to a non-empty value; main() may override them from command-line flags.
if os.environ.get("OPENAI_API_KEY"):
    openai.api_key = os.environ["OPENAI_API_KEY"]
if os.environ.get("OPENAI_API_BASE"):
    openai.api_base = os.environ["OPENAI_API_BASE"]


class OpenAIChat():
    """Asynchronous batch client built on ``openai.ChatCompletion.acreate``.

    The constructor only records sampling/request settings in ``self.config``;
    no network traffic happens until ``async_run`` or
    ``dispatch_openai_requests`` is awaited.

    NOTE(review): ``openai.ChatCompletion.acreate`` looks like the pre-1.0
    ``openai`` SDK interface -- confirm the SDK version pinned by the project.
    """

    # Seconds to wait between failed attempts of a single request.
    RETRY_DELAY_SECONDS = 20

    def __init__(
            self,
            model_name='gpt-4o-mini',
            max_tokens=2500,
            temperature=0.5,
            top_p=1,
            request_timeout=180,
            stop=None,
            response_format='text',  # text or json_object
            logprobs=False,
            top_logprobs=None,
            n=1,
    ):
        """Store the request settings that are forwarded with every API call.

        Args:
            model_name: Model identifier sent as ``model``.
            max_tokens: Completion token limit sent as ``max_tokens``.
            temperature: Sampling temperature.
            top_p: Nucleus-sampling parameter.
            request_timeout: Value forwarded as ``request_timeout``.
            stop: Optional stop sequence(s).
            response_format: ``'text'`` or ``'json_object'``; sent as
                ``{'type': response_format}``.
            logprobs: Forwarded as ``logprobs``.
            top_logprobs: Forwarded as ``top_logprobs``.
            n: Number of choices requested (stored under ``'sample_n'``).
                Only the first choice is read back by ``async_run``.
        """
        self.config = {
            'model_name': model_name,
            'max_tokens': max_tokens,
            'temperature': temperature,
            'top_p': top_p,
            'request_timeout': request_timeout,
            'stop': stop,
            'response_format': response_format,
            'logprobs': logprobs,
            'top_logprobs': top_logprobs,
            'sample_n': n,
        }

    @staticmethod
    def _extract_content(response):
        """Return the first choice's message content, or None if unavailable.

        A missing response, an empty ``choices`` list, or a payload without
        the expected keys all yield None so the caller can retry the request
        instead of crashing the whole batch.
        """
        if response is None:
            return None
        try:
            return response['choices'][0]['message']['content']
        except (KeyError, IndexError, TypeError):
            return None

    async def dispatch_openai_requests(
            self,
            messages_list,
    ) -> list:
        """Send one chat-completion request per conversation, concurrently.

        Args:
            messages_list: Sequence of conversations; each conversation is
                passed unchanged as the ``messages`` argument of one request.

        Returns:
            A list aligned with ``messages_list`` holding the raw API response
            for each conversation, or None where every attempt failed.
        """

        async def _request_with_retry(messages, retry=3):
            for try_i in range(retry):
                try:
                    return await openai.ChatCompletion.acreate(
                        model=self.config['model_name'],
                        response_format={'type': self.config['response_format']},
                        messages=messages,
                        max_tokens=self.config['max_tokens'],
                        temperature=self.config['temperature'],
                        top_p=self.config['top_p'],
                        request_timeout=self.config['request_timeout'],
                        stop=self.config['stop'],
                        logprobs=self.config['logprobs'],
                        top_logprobs=self.config['top_logprobs'],
                        n=self.config['sample_n'],
                    )
                except Exception as e:
                    if try_i + 1 >= retry:
                        # Last attempt: do not waste a full back-off delay
                        # when nothing will be retried afterwards.
                        print(f"An error occurred: {e}. Giving up after {retry} attempts.")
                        break
                    print(f"An error occurred: {e}. Retrying {try_i + 1}/{retry}...")
                    await asyncio.sleep(self.RETRY_DELAY_SECONDS)
            return None

        async_responses = [
            _request_with_retry(messages)
            for messages in messages_list
        ]
        return await asyncio.gather(*async_responses)

    async def async_run(self, messages_list, expected_type=None):
        """Run all conversations and return the text of each first choice.

        Conversations whose response is missing or has no usable content are
        re-dispatched, for at most three rounds in total.

        Args:
            messages_list: Sequence of conversations (see
                ``dispatch_openai_requests``).
            expected_type: Unused; kept so existing callers keep working.

        Returns:
            A list aligned with ``messages_list``; each entry is the content
            string of the first choice, or None if no round produced one.
        """
        retry = 3
        responses = [None] * len(messages_list)
        pending = list(range(len(messages_list)))

        while retry > 0 and pending:
            predictions = await self.dispatch_openai_requests(
                messages_list=[messages_list[i] for i in pending])

            still_pending = []
            for index, prediction in zip(pending, predictions):
                content = self._extract_content(prediction)
                if content is None:
                    still_pending.append(index)
                else:
                    responses[index] = content

            pending = still_pending
            if pending:
                retry -= 1
        return responses


def refine_reasoning_batch(prompts: List[str], model_name: str) -> List[str]:
    """Send each prompt as a single-turn user message and collect the replies.

    A fresh ``OpenAIChat`` client (max_tokens=4096, temperature=0.5,
    top_p=0.9) is created per call and driven to completion with
    ``asyncio.run``.

    Args:
        prompts: Prompt strings, one request each.
        model_name: Model identifier handed to ``OpenAIChat``.

    Returns:
        Replies in the same order as ``prompts``; an entry is None when no
        response could be obtained for that prompt.
    """
    client = OpenAIChat(model_name=model_name, max_tokens=4096, temperature=0.5, top_p=0.9)

    conversations = []
    for prompt in prompts:
        conversations.append([{"role": "user", "content": prompt}])

    return asyncio.run(client.async_run(messages_list=conversations, expected_type=List))


# Compiled once: extracts the text between <reasoning> tags (may span lines).
_REASONING_PATTERN = re.compile(r'<reasoning>(.*?)</reasoning>', re.DOTALL)


def _extract_reasoning(resp):
    """Return the text inside <reasoning>...</reasoning>, else the stripped response."""
    match = _REASONING_PATTERN.search(resp)
    return match.group(1).strip() if match else resp.strip()


def _format_options(item):
    """Render the non-null 'opa'..'opd' values of *item* as "(A) ... (B) ...".

    Labels are assigned in order to the options that are present, so a missing
    option does not leave a gap in the lettering.
    """
    present = [item.get(key) for key in ("opa", "opb", "opc", "opd")]
    present = [value for value in present if value is not None]
    return " ".join(f"({chr(ord('A') + i)}) {value}" for i, value in enumerate(present))


def _build_pruning_prompt(question, options_string, original_output):
    """Fill the pruning/refining prompt template for one question."""
    return f"""
Modify the scientific reasoning to meet the following requirements.
Please follow these instructions:

1. Keep the technical accuracy and depth of the original reasoning.
2. Delete the content that is not related to the question.
3. Add realistic thought processes: moments of doubt and elimination of distractors.
4. Ensure every answer option is explicitly considered and evaluated.

Here is an example: ... 
Original output: ...
The model should rewrite like this: ...

Now, try to rewrite the thinking process of the next question in the same way.
Question: {question}  
Options: {options_string}
Original output: {original_output}

【Output Format】
<reasoning>\n ...\n</reasoning>
"""


def _write_pruned_batch(prompts, items, model_name, f_out):
    """Refine one batch of prompts and append each annotated item as a JSON line."""
    responses = refine_reasoning_batch(prompts, model_name)
    for resp, itm in zip(responses, items):
        if resp:
            # Save to the new 'cot_pruning' field
            itm["cot_pruning"] = _extract_reasoning(resp)
        else:
            itm["cot_pruning"] = "Error: Failed to generate response from API."
        f_out.write(json.dumps(itm, ensure_ascii=False) + "\n")


def process_and_prune_file(input_file: str, output_file: str, model_name: str, batch_size: int):
    """
    Reads a JSON file with generated CoT, prunes/refines the reasoning,
    adds it as a new field 'cot_pruning', and writes to a new file.

    The input must be a JSON array of objects. The output is written as JSON
    Lines (one object per line). Items lacking 'question' or 'cot_generation'
    are written immediately with an error message in 'cot_pruning', so they
    can appear in the output ahead of earlier items still waiting in a batch.

    Args:
        input_file: Path of the JSON file to read.
        output_file: Path of the JSON Lines file to write; its parent
            directory is created if needed.
        model_name: Model identifier passed to ``refine_reasoning_batch``.
        batch_size: Number of prompts sent per ``refine_reasoning_batch`` call.

    Returns:
        None. Problems reading the input are reported on stdout and the
        function returns without writing anything.
    """
    print(f"Reading data from: {input_file}")
    try:
        with open(input_file, "r", encoding="utf-8") as f:
            data = json.load(f)
    except FileNotFoundError:
        print(f"Error: Input file not found at {input_file}")
        return
    except json.JSONDecodeError:
        print(f"Error: Could not decode JSON from {input_file}.")
        return

    if not isinstance(data, list):
        # Iterating a dict/str would fail later with an opaque AttributeError.
        print(f"Error: Expected a JSON array of items in {input_file}.")
        return

    # dirname is '' for a bare filename, and os.makedirs('') raises.
    output_dir = os.path.dirname(output_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    print(f"Pruning {len(data)} items using model '{model_name}' and writing to: {output_file}")

    batch_items = []
    batch_prompts = []

    with open(output_file, "w", encoding="utf-8") as f_out:
        for item in tqdm(data, desc="Pruning CoT Reasoning"):
            question = item.get("question")
            # The original reasoning is taken from the 'cot_generation' field
            original_output = item.get("cot_generation")

            if not question or not original_output:
                item["cot_pruning"] = "Error: Missing 'question' or 'cot_generation' field."
                f_out.write(json.dumps(item, ensure_ascii=False) + "\n")
                continue

            batch_prompts.append(
                _build_pruning_prompt(question, _format_options(item), original_output))
            batch_items.append(item)

            if len(batch_prompts) >= batch_size:
                _write_pruned_batch(batch_prompts, batch_items, model_name, f_out)
                batch_prompts, batch_items = [], []

        # Process the final, possibly partial, batch
        if batch_prompts:
            _write_pruned_batch(batch_prompts, batch_items, model_name, f_out)

    print(f"\n✅ Processing complete. Results written to {output_file}")


def main():
    """Command-line entry point: parse flags, set credentials, run the pruning job.

    ``--api_key`` / ``--api_base`` override whatever is already set on the
    ``openai`` module when they are given.

    Raises:
        ValueError: If no OpenAI API key is configured after applying the flags.
    """
    cli = argparse.ArgumentParser(
        description="Prune and refine Chain-of-Thought reasoning for biological questions.")

    # (flag, add_argument keyword options), registered in this order.
    argument_specs = [
        ("--input_file", dict(
            type=str, required=True,
            help="Path to the input JSON file (must contain 'cot_generation' field).")),
        ("--output_file", dict(
            type=str, required=True,
            help="Path for the output JSON file with 'cot_pruning' field.")),
        ("--model_name", dict(
            type=str, default="gpt-4o-mini",
            help="Name of the model to use for refinement.")),
        ("--batch_size", dict(
            type=int, default=10,
            help="Number of items to process in a single batch.")),
        ("--api_key", dict(
            type=str, default=None,
            help="OpenAI API key. (Recommended: use environment variable OPENAI_API_KEY)")),
        ("--api_base", dict(
            type=str, default=None,
            help="OpenAI API base URL. (Recommended: use environment variable OPENAI_API_BASE)")),
    ]
    for flag, options in argument_specs:
        cli.add_argument(flag, **options)

    args = cli.parse_args()

    # Non-empty command-line values take precedence over existing settings.
    for setting in ("api_key", "api_base"):
        override = getattr(args, setting)
        if override:
            setattr(openai, setting, override)

    if not openai.api_key:
        raise ValueError(
            "OpenAI API key is not set. Please set it via the OPENAI_API_KEY environment variable or the --api_key argument.")

    job_options = {
        "input_file": args.input_file,
        "output_file": args.output_file,
        "model_name": args.model_name,
        "batch_size": args.batch_size,
    }
    process_and_prune_file(**job_options)


# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()