import os
import json
import asyncio
import re
import argparse
from typing import List
from tqdm import tqdm
import openai

# Configure the legacy ``openai`` module from the environment at import time.
# Empty or unset variables leave the module defaults untouched; main() may
# later override these from --api_key / --api_base.
_env_api_key = os.getenv("OPENAI_API_KEY")
if _env_api_key:
    openai.api_key = _env_api_key
_env_api_base = os.getenv("OPENAI_API_BASE")
if _env_api_base:
    openai.api_base = _env_api_base


class OpenAIChat:
    """Async batch client built on the legacy (pre-1.0) ``openai`` module API.

    Every request in a batch is issued concurrently via ``asyncio.gather``.
    Model names containing ``"embedding"`` are sent to
    ``openai.Embedding.acreate``; all other models go to
    ``openai.ChatCompletion.acreate``.
    """
    # more details on: https://platform.openai.com/docs/api-reference/chat
    def __init__(
            self,
            model_name='gpt-4o-mini',
            max_tokens=2500,
            temperature=0.5,
            top_p=1,
            request_timeout=180,
            stop=None,
            response_format='text',  # text or json_object
            logprobs=False,
            top_logprobs=None,
            n=1,
    ):
        """Store the request parameters reused for every call.

        The arguments mirror Chat Completions request fields; ``n`` is stored
        under the key ``'sample_n'``. Credentials are not handled here: the
        ``openai`` module's ``api_key``/``api_base`` are set elsewhere in this
        file (from environment variables at import time and from CLI flags in
        ``main()``).
        """
        self.config = {
            'model_name': model_name,
            'max_tokens': max_tokens,
            'temperature': temperature,
            'top_p': top_p,
            'request_timeout': request_timeout,
            'stop': stop,
            'response_format': response_format,
            'logprobs': logprobs,
            'top_logprobs': top_logprobs,
            'sample_n': n,
        }

    async def dispatch_openai_requests(
            self,
            messages_list,
    ) -> list:
        """Send one API request per entry of ``messages_list`` concurrently.

        Each request is attempted up to 3 times, pausing 20 s between attempts
        (but not after the last one).

        Returns:
            The raw API response objects, in the same order as
            ``messages_list``; an entry is ``None`` when every attempt failed.
        """

        async def _request_with_retry(messages, retry=3):
            for try_i in range(retry):
                try:
                    if "embedding" in self.config['model_name']:
                        # Embedding models: the payload is passed as ``input``.
                        return await openai.Embedding.acreate(
                            model=self.config['model_name'],
                            input=messages,
                        )
                    # Chat models.
                    return await openai.ChatCompletion.acreate(
                        model=self.config['model_name'],
                        response_format={'type': self.config['response_format']},
                        messages=messages,
                        max_tokens=self.config['max_tokens'],
                        temperature=self.config['temperature'],
                        top_p=self.config['top_p'],
                        request_timeout=self.config['request_timeout'],
                        stop=self.config['stop'],
                        logprobs=self.config['logprobs'],
                        top_logprobs=self.config['top_logprobs'],
                        n=self.config['sample_n'],
                    )
                except Exception as e:
                    # Broad on purpose: network errors, rate limits and API
                    # errors are all treated as retryable.
                    if try_i + 1 < retry:
                        print(f"An error occurred: {e}. Retrying {try_i + 1}/{retry}...")
                        await asyncio.sleep(20)  # Wait before retrying
                    else:
                        # Last attempt: no point sleeping before giving up.
                        print(f"An error occurred: {e}. Giving up after {retry} attempts.")
            return None

        return await asyncio.gather(
            *(_request_with_retry(messages) for messages in messages_list)
        )

    def _extract_prediction(self, response):
        """Pull the useful payload out of one raw API response.

        Returns the embedding vector for embedding models, the message content
        for chat models, or ``[content, [token logprobs...]]`` when
        ``logprobs`` is enabled. Returns ``None`` for a missing response or
        one whose structure does not match. The caller then re-requests it
        instead of the whole batch crashing.
        """
        if response is None:
            return None
        try:
            if "embedding" in self.config['model_name']:
                return response['data'][0]['embedding']
            choice = response['choices'][0]
            content = choice['message']['content']
            if not self.config['logprobs']:
                return content
            return [content, [d['logprob'] for d in choice['logprobs']['content']]]
        except (KeyError, IndexError, TypeError) as e:
            print(f"Malformed API response, will retry: {e!r}")
            return None

    async def async_run(self, messages_list, expected_type):
        """Run all requests, re-dispatching unfinished ones for up to 3 rounds.

        Args:
            messages_list: one request payload per entry (a chat ``messages``
                list, or the ``input`` for embedding models).
            expected_type: unused; kept for interface compatibility.

        Returns:
            A list aligned with ``messages_list``. Entries that never produced
            a usable result (including chat replies whose content is ``None``
            when logprobs are off) are ``None``.
        """
        responses = [None] * len(messages_list)
        pending = list(range(len(messages_list)))
        rounds_left = 3

        while rounds_left > 0 and pending:
            predictions = await self.dispatch_openai_requests(
                messages_list=[messages_list[i] for i in pending]
            )
            still_pending = []
            for idx, response in zip(pending, predictions):
                pred = self._extract_prediction(response)
                if pred is None:
                    still_pending.append(idx)
                else:
                    responses[idx] = pred
            pending = still_pending
            if pending:
                rounds_left -= 1
        return responses


def generate_answer_batch(prompts: List[str], model_name: str) -> List[str]:
    """Send each prompt as a single user message and collect the model replies.

    Args:
        prompts: user-message texts; one request is made per entry.
        model_name: model identifier forwarded to ``OpenAIChat``.

    Returns:
        One entry per prompt, in order, as produced by
        ``OpenAIChat.async_run``; an entry is ``None`` if no usable reply was
        obtained.
    """
    # Generation settings for CoT output: long replies, moderate sampling.
    client = OpenAIChat(
        model_name=model_name,
        max_tokens=4096,
        temperature=0.5,
        top_p=0.9,
    )
    conversations = []
    for text in prompts:
        conversations.append([{"role": "user", "content": text}])
    return asyncio.run(client.async_run(messages_list=conversations, expected_type=List))


# Captures the text between <reasoning> tags; compiled once for all responses.
_REASONING_RE = re.compile(r'<reasoning>(.*?)</reasoning>', re.DOTALL)


def _build_cot_prompt(item):
    """Return the CoT-generation prompt for one item, or None if any of
    question / path / answer / options (opa-opd) is missing or empty."""
    question = item.get("question")
    path_knowledge = item.get("path")
    answer_key = item.get("answer")

    options_dict = {
        "opa": item.get("opa"), "opb": item.get("opb"),
        "opc": item.get("opc"), "opd": item.get("opd")
    }
    options_dict = {k: v for k, v in options_dict.items() if v is not None}

    if not all([question, path_knowledge, answer_key, options_dict]):
        return None

    # Present options are relabelled A, B, C, ... in order (missing ones are skipped).
    option_labels = {key: chr(ord('A') + i) for i, key in enumerate(options_dict)}
    options_string = " ".join(f"({label}) {options_dict[key]}" for key, label in option_labels.items())
    # NOTE(review): assumes "answer" holds an option key such as "opa"; any other
    # value renders as "(N/A) N/A" in the prompt — confirm against the dataset.
    true_answer_text = options_dict.get(answer_key, "N/A")
    true_answer_label = option_labels.get(answer_key, "N/A")
    true_answer_full = f"({true_answer_label}) {true_answer_text}"

    return f"""You are given a single-choice biological reasoning question. Your task is to generate a detailed step-by-step reasoning process that can guide a model to select the correct answer.
Question: {question}
Options: {options_string}
Additional Knowledge (from Knowledge Graph Path):{path_knowledge}
Correct Answer: {true_answer_full}

Now, write a natural, plausible reasoning process that could lead a knowledgeable person to choose the correct answers. Avoid simply referencing pre-established facts, and instead build your reasoning as a sequence of logical thoughts that connect concepts.
Begin by examining biological basis, including key pathways, affected tissues, and pathological mechanisms. Then reason about which proteins among the options are likely to be functionally or mechanistically involved, based on their roles and interactions. Evaluate each option independently and identify all proteins that are plausible based on this reasoning.
Do not just state facts from the "Additional knowledge",integrate them into a coherent narrative, making the reasoning flow as though it arises naturally from your own deep biological understanding.

Here is an example:
Question: ...
Options: ...
Additional Knowledge (from Knowledge Graph Path): ...
The model should generate like this:
<reasoning>\n...\n</reasoning>

Now, try to answer the above question in the same way.

【Output Format】
<reasoning>
...
</reasoning>"""


def _reasoning_from_response(resp):
    """Return the stripped text inside <reasoning>...</reasoning>, the whole
    stripped response if the tags are absent, or an error string if empty."""
    if not resp:
        return "Error: Failed to generate response from API."
    match = _REASONING_RE.search(resp)
    return match.group(1).strip() if match else resp.strip()


def _flush_batch(pending, f_out, model_name):
    """Generate CoT for every prompted item in ``pending`` and write all items,
    in their original order, to ``f_out`` as one JSON object per line.

    ``pending`` holds ``(item, prompt)`` pairs; ``prompt`` is None for items
    that already carry an error in ``cot_generation``.
    """
    prompts = [prompt for _, prompt in pending if prompt is not None]
    # Skip the API entirely when the batch holds only invalid items.
    responses = iter(generate_answer_batch(prompts, model_name) if prompts else [])
    for item, prompt in pending:
        if prompt is not None:
            item["cot_generation"] = _reasoning_from_response(next(responses, None))
        f_out.write(json.dumps(item, ensure_ascii=False) + "\n")
    # Persist each finished batch so a later crash does not lose it.
    f_out.flush()


def process_json_file(input_file: str, output_file: str, model_name: str, batch_size: int):
    """
    Read a JSON array of question items, generate chain-of-thought reasoning
    for each one, store it in a new field 'cot_generation', and write the
    items to ``output_file`` as JSON Lines (one object per line), preserving
    input order.

    Items missing question/path/answer/options get an error message in
    'cot_generation' instead of an API call. Prompts are sent in batches of
    ``batch_size``. Input errors (missing file, invalid JSON, non-array top
    level) are reported with a printed message and the function returns
    without writing output.
    """
    print(f"Reading data from: {input_file}")
    try:
        with open(input_file, "r", encoding="utf-8") as f:
            data = json.load(f)
    except FileNotFoundError:
        print(f"Error: Input file not found at {input_file}")
        return
    except json.JSONDecodeError:
        print(f"Error: Could not decode JSON from {input_file}.")
        return

    if not isinstance(data, list):
        print(f"Error: Expected a JSON array of items in {input_file}.")
        return

    # dirname() is "" for a bare filename, and os.makedirs("") raises.
    output_dir = os.path.dirname(output_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    print(f"Processing {len(data)} items using model '{model_name}' and writing to: {output_file}")

    # (item, prompt) pairs awaiting output, in input order. Invalid items ride
    # along with prompt=None so they are written in their original position.
    pending = []
    n_prompts = 0

    with open(output_file, "w", encoding="utf-8") as f_out:
        for item in tqdm(data, desc="Generating CoT Reasoning"):
            prompt = _build_cot_prompt(item)
            if prompt is None:
                item["cot_generation"] = "Error: Missing required fields (question, path, answer, or options)."
            else:
                n_prompts += 1
            pending.append((item, prompt))

            if n_prompts >= batch_size:
                _flush_batch(pending, f_out, model_name)
                pending, n_prompts = [], 0

        if pending:
            _flush_batch(pending, f_out, model_name)

    print(f"\n✅ Processing complete. Results written to {output_file}")


def main():
    """Command-line entry point.

    Parses arguments, applies any --api_key / --api_base overrides to the
    ``openai`` module, and runs ``process_json_file``.

    Raises:
        ValueError: if no API key is available from the environment or --api_key.
    """
    parser = argparse.ArgumentParser(description="Generate Chain-of-Thought reasoning for biological questions.")

    parser.add_argument("--input_file", type=str, required=True, help="Path to the input JSON file.")
    parser.add_argument("--output_file", type=str, required=True,
                        help="Path to the output JSON Lines file (one JSON object per line).")
    parser.add_argument("--model_name", type=str, default="gpt-4o-mini",
                        help="Name of the model to use for generation.")
    parser.add_argument("--batch_size", type=int, default=10, help="Number of items to process in a single batch.")
    parser.add_argument("--api_key", type=str, default=None,
                        help="OpenAI API key. (Recommended to use environment variable OPENAI_API_KEY)")
    parser.add_argument("--api_base", type=str, default=None,
                        help="OpenAI API base URL. (Recommended to use environment variable OPENAI_API_BASE)")

    args = parser.parse_args()

    # A non-positive batch size would silently degrade to one item per batch.
    if args.batch_size < 1:
        parser.error("--batch_size must be a positive integer.")

    # CLI flags take precedence over values picked up from the environment.
    if args.api_key:
        openai.api_key = args.api_key
    if args.api_base:
        openai.api_base = args.api_base

    if not openai.api_key:
        raise ValueError(
            "OpenAI API key is not set. Please set it via the OPENAI_API_KEY environment variable or the --api_key argument.")

    process_json_file(
        input_file=args.input_file,
        output_file=args.output_file,
        model_name=args.model_name,
        batch_size=args.batch_size
    )


if __name__ == "__main__":
    main()