# -*- coding: utf-8 -*-
""" the IntentRL Workflow"""

from __future__ import annotations

import re
import time
from collections import Counter
from typing import List, Optional, Union
from concurrent.futures import ThreadPoolExecutor, as_completed

import openai
import os

from trinity.common.experience import Experience
from trinity.common.models.model import ModelWrapper
from trinity.common.workflows import WORKFLOWS, SimpleWorkflow, Task
from trinity.utils.log import get_logger

from trinity.plugins.utils import check_language_by_frequency, call_embedding_model

# Module-level logger shared by everything in this file.
logger = get_logger(__name__)

# Canned feedback messages (English / Chinese variants).
# `Learn2AskWorkflow.llm_reward` counts exact occurrences of these strings among
# the user turns of the conversation history: REPEAT_* occurrences raise the
# repetition penalty, IRRELEVANT_* occurrences set the irrelevance penalty.
# NOTE(review): presumably these are the replies a simulated user gives to a
# repeated / unimportant question -- confirm against the data pipeline.
REPEAT_NOTE_EN = "You repeat a question that has already been asked in the conversation history! Please avoid asking repetitive questions and generate new, more valuable ones."
REPEAT_NOTE_ZH = "你重复了之前问过的问题！请勿重复提问，请生成更有价值的问题。"
IRRELEVANT_NOTE_EN = "This question is not important to me, you don't need to focus on this."
IRRELEVANT_NOTE_ZH = "这个问题对我来说不重要，你不需要关注这一点。"

@WORKFLOWS.register_module("intentrl")
class Learn2AskWorkflow(SimpleWorkflow):
    """A workflow for Elem training with local model."""

    def __init__(
        self,
        *,
        task: Task,
        model: ModelWrapper,
        auxiliary_models: Optional[List[openai.OpenAI]] = None,
    ):
        """Read the workflow options from the task, then run the base initialiser.

        Args:
            task: Task whose ``workflow_args`` may carry ``train_mode``
                (default ``"r+p"``) and ``judge_mode`` (default ``"local"``).
            model: Rollout model wrapper, forwarded to the base class.
            auxiliary_models: Optional judge clients, forwarded to the base class.
        """
        options = task.workflow_args
        # Set before the base initialiser runs.
        # NOTE(review): presumably the base class calls `reset()`, which reads
        # `judge_mode` -- confirm against SimpleWorkflow.
        self.train_mode = options.get("train_mode", "r+p")
        self.judge_mode = options.get("judge_mode", "local")
        super().__init__(task=task, model=model, auxiliary_models=auxiliary_models)

    @property
    def resettable(self):
        return True

    def reset(self, task: Task):
        """Load a new task into the workflow.

        Copies the task description, formatting arguments and ground-truth
        fields onto the instance and selects the system prompt that matches the
        expected action. In ``online`` judge mode it also (re)creates the
        OpenAI-compatible reward client from the ``OPENAI_API_KEY`` and
        ``OPENAI_API_BASE_URL`` environment variables.

        Args:
            task: The task to run next. ``task.raw_task`` must support ``.get``
                and ``in`` (it is used like a dict).
        """
        if self.judge_mode == "online":
            # Both settings are constructor keyword arguments; the previous
            # version closed the call early, which was a syntax error and would
            # have stored a 1-tuple instead of a client.
            self.reward_client = openai.OpenAI(
                api_key=os.environ.get("OPENAI_API_KEY"),
                base_url=os.environ.get("OPENAI_API_BASE_URL"),
            )
            self.reward_model_name = "gpt-4.1-2025-04-14"
        self.format_args = task.format_args
        self.reply_prefix = task.format_args.reply_prefix

        self.raw_task = task.raw_task
        self.task_desc = task.task_desc
        # self.ask_web_info = task.raw_task.get("ask_web_info", [])
        # self.language = task.raw_task.get("language", check_language_by_frequency(self.task_desc[0]['content']))
        self.action_truth = task.raw_task.get("decision_truth", "continue")
        # An explicit null in the data means the same as a missing key.
        if self.action_truth is None:
            self.action_truth = "continue"
        if self.action_truth == "continue":
            # Expected action: ask a follow-up question.
            from trinity.plugins.prompt_v2 import (
                rollout_prompt_en as system_prompt,
            )
            self.info_truth = task.raw_task["info_truth"] if "info_truth" in task.raw_task else "None"  # type: ignore
        else:
            # Any other expected action is handled with the summary prompt.
            from trinity.plugins.prompt_v2 import (
                summary_prompt_en as system_prompt,
            )
            self.finegrained_query = task.raw_task["finegrained_query"] if "finegrained_query" in task.raw_task else "None"  # type: ignore
        self.system_prompt = system_prompt

    def set_repeat_times(self, repeat_times, run_id_base):
        self.repeat_times = repeat_times
        self.task.rollout_args.n = repeat_times
        self.run_id_base = run_id_base

    def format_messages(self):
        """Format messages for the instruct model."""
        if isinstance(self.task_desc, list):
            if self.action_truth == "continue":
                messages = [{"role": "system", "content": self.system_prompt}] + self.task_desc
            else:
                messages = [{"role": "system", "content": self.system_prompt.format(self.task_desc[0]['content'])}] + self.task_desc
        elif isinstance(self.task_desc, str):
            messages = [
                {"role": "system", "content": self.system_prompt},
                {"role": "user", "content": self.task_desc},
            ]
        else:
            raise ValueError("`task_desc` must be a list or a string")
        return messages

    def parse_tag_string(self, text):
        """Extract ``<tag>value</tag>`` pairs from ``text`` into a dict.

        Values are matched non-greedily and cannot span line breaks. When a
        tag occurs more than once, the last occurrence wins.
        """
        tag_pairs = re.findall(r"<(\w+)>(.*?)</\1>", text)
        # dict() keeps the last value for a repeated key, like a manual loop.
        return dict(tag_pairs)

    def merge_msg_list(self, msg_list):
        """Flatten user and assistant turns into ``"role: content\\n"`` lines.

        Messages with any other role are skipped.
        """
        kept_lines = [
            f"{msg['role']}: {msg['content']}\n"
            for msg in msg_list
            if msg["role"] in ("user", "assistant")
        ]
        return "".join(kept_lines)

    def _process_single_response(self, args: tuple[int, Experience]):
        """Score one rollout in place and log the result.

        Args:
            args: ``(index, experience)`` pair as produced by ``enumerate``.
                The experience's ``reward`` is set to
                ``2 * content_score + format_score``.
        """
        repeat_idx, experience = args
        scores = self.reward_fn(response=experience.response_text)
        content_score, format_score = scores
        experience.reward = 2 * content_score + format_score

        # Keep every log record on a single line.
        flat_text = experience.response_text.replace("\n", " ")
        cid = self.raw_task.get('cid', 'xxx')
        logger.info(
            f"cid: {cid}, repeat: {repeat_idx}, content: {content_score}, format: {format_score}, response: {flat_text}"
        )

    def run(self) -> List[Experience]:
        """Generate rollouts for the current task and attach a reward to each.

        Returns:
            The experiences returned by the model, with ``reward`` filled in.
        """
        # TODO: Optimize the generate function
        prompt = self.format_messages()
        logger.debug("start chat")
        experiences = self.model.chat(prompt, **self.rollout_args)
        with ThreadPoolExecutor() as pool:
            scoring = pool.map(self._process_single_response, enumerate(experiences))
            # Drain the iterator so an exception raised in a worker surfaces here.
            for _ in scoring:
                pass
        return experiences

    def call_local_vllm(self, client, messages, reward_model_stream=False):
        """Send ``messages`` to an OpenAI-compatible client and return the reply text.

        Args:
            client: Client exposing ``chat.completions.create`` and a
                ``model_path`` attribute, which is used as the model name.
            messages: Chat messages to send.
            reward_model_stream: When true, request a streamed completion and
                concatenate the streamed pieces.

        Returns:
            The reply content. In streaming mode this is the concatenation of
            all non-empty deltas (an empty string if there were none).
        """
        completion = client.chat.completions.create(
            model=client.model_path, messages=messages, stream=reward_model_stream
        )

        if not reward_model_stream:
            return completion.choices[0].message.content

        pieces = []
        for chunk in completion:
            if not chunk.choices:
                continue
            piece = chunk.choices[0].delta.content
            # Streamed deltas may carry no text (e.g. role-only or final
            # chunks); concatenating None would raise TypeError.
            if piece:
                pieces.append(piece)
        return "".join(pieces)
    
    def call_online_api(self, messages):
        completion = self.reward_client.chat.completions.create(
            model=self.reward_model_name,
            messages=messages,
            temperature=0.0,  
        )
        content = completion.choices[0].message.content
        return content

    def llm_reward(self, response, mode):
        """Score one response with the configured judge.

        Stage 1 asks the judge model to grade the response and parses the
        ``<tag>value</tag>`` pairs of its reply. In ``embedding`` judge mode
        with ``mode == "clarify"`` the ``content_score`` entry is replaced by
        the embedding similarity to ``info_truth``.

        Stage 2 runs only when ``train_mode == "r+p"``, ``mode == "clarify"``
        and the history has more than one message. It computes a penalty from
        repetition of earlier assistant questions, the judge's validity
        verdict, and the irrelevance penalty from stage 1, and stores it under
        the ``"panalty"`` key.

        Both stages share a single retry budget; every failed attempt sleeps a
        little longer than the previous one.

        Args:
            response: The model output to score.
            mode: ``"clarify"`` or ``"summary"``.

        Returns:
            The parsed score dict, or ``{}`` when stage 2 runs out of retries.

        Raises:
            NotImplementedError: For an unknown ``mode`` or ``judge_mode``.
        """
        # `cid` may be missing or non-string; normalise so the check cannot raise.
        cid = str(self.raw_task.get("cid") or "")
        trace = "onpolicy" in cid
        if trace:
            logger.info(f"{cid}: Mode is {mode}")

        if mode == "clarify":
            if self.judge_mode == "embedding":
                from trinity.plugins.prompt_v2 import clarify_reward_prompt_onlyformat_en as reward_prompt
                messages = [
                    {"role": "user", "content": reward_prompt.format(response)},
                ]
            else:
                from trinity.plugins.prompt_v2 import clarify_reward_prompt_en as reward_prompt
                messages = [
                    {"role": "system", "content": reward_prompt.format(self.info_truth)},
                    {"role": "user", "content": f"Follow-up Question: {response}"},
                ]
        elif mode == "summary":
            from trinity.plugins.prompt_v2 import summary_reward_prompt_en as reward_prompt
            messages = [
                {"role": "system", "content": reward_prompt.format(self.finegrained_query)},
                {"role": "user", "content": f"Genearted Query: {response}"},
            ]
        else:
            raise NotImplementedError

        # A configuration error can never succeed on retry; fail immediately
        # instead of sleeping through the whole retry budget.
        if self.judge_mode not in ("local", "embedding", "online"):
            raise NotImplementedError(f"unsupported judge_mode: {self.judge_mode!r}")

        # `task_desc` may be a plain string (see `format_messages`); treat it
        # as a single user turn so the history scans below stay valid.
        history = self.task_desc
        if isinstance(history, str):
            history = [{"role": "user", "content": history}]
        history_question = [m["content"] for m in history if m["role"] == "assistant"]
        history_user = [m["content"] for m in history if m["role"] == "user"]

        score_dict = {}
        try_count, max_retries = 0, 50
        irrelevant_penalty = 0

        # ---- Stage 1: content / format scores --------------------------------
        while try_count <= max_retries:
            try:
                content = self._call_judge(messages)
                score_dict = self.parse_tag_string(content)
                if self.judge_mode == "embedding" and mode == "clarify":
                    score_dict["content_score"] = call_embedding_model([response], self.info_truth)
                    if score_dict["content_score"] < 0.85:
                        irrelevant_en = history_user.count(IRRELEVANT_NOTE_EN)
                        irrelevant_zh = history_user.count(IRRELEVANT_NOTE_ZH)
                        irrelevant_penalty = (irrelevant_en + irrelevant_zh) / 2
                        if trace:
                            logger.info(
                                f"{cid}: irrelavant_en: {irrelevant_en} irrelavant_zh: {irrelevant_zh} history: {history_user}"
                            )
                break
            except Exception as e:
                try_count += 1
                if try_count > max_retries:
                    logger.warning("retried too many times, abort task.")
                    break
                logger.warning(f"error: {e}, response:{response}, retries: {try_count}")
                time.sleep(try_count + 1)

        # ---- Stage 2: penalty -------------------------------------------------
        if self.train_mode == "r+p" and mode == "clarify" and len(history) > 1:
            from trinity.plugins.prompt_v2 import invalid_panalty_prompt_en as panalty_prompt

            penalty_messages = [
                {"role": "system", "content": panalty_prompt},
                {"role": "user", "content": f"# User Request: {history[0]['content']}\n# Clarification Question: {response}"},
            ]
            # The loop is skipped entirely if stage 1 used up the retry budget.
            while try_count <= max_retries:
                try:
                    if call_embedding_model([response], history_question) > 0.85:
                        repeat_en = history_user.count(REPEAT_NOTE_EN)
                        repeat_zh = history_user.count(REPEAT_NOTE_ZH)
                        overlap = (repeat_en + repeat_zh) / 2 + 1
                        if trace:
                            logger.info(
                                f"{cid}: overlap_en: {repeat_en} overlap_zh: {repeat_zh} history: {history_user}"
                            )
                    else:
                        overlap = 0
                    # Route through the configured judge: the previous code
                    # reused a local-only client variable that was never bound
                    # in "online" mode.
                    penalty_content = self._call_judge(penalty_messages)
                    penalty_dict = self.parse_tag_string(penalty_content)
                    invalid = int(penalty_dict.get("verdict", 0))
                    # The key spelling is part of the contract with `reward_fn`.
                    score_dict["panalty"] = max(overlap * 2, invalid, irrelevant_penalty)
                    break
                except Exception as e:
                    try_count += 1
                    if try_count > max_retries:
                        logger.warning("retried too many times, abort task.")
                        return {}
                    logger.warning(f"error: {e}, response:{response}, retries: {try_count}")
                    time.sleep(try_count + 1)
        return score_dict

    def _call_judge(self, messages):
        """Send ``messages`` to the judge selected by ``judge_mode`` and return its reply text."""
        if self.judge_mode == "online":
            return self.call_online_api(messages)
        # "local" and "embedding" both use the first auxiliary model, non-streaming.
        return self.call_local_vllm(self.auxiliary_models[0], messages, False)
    def reward_vote(self, data_list):
        """Pick the most frequent value in ``data_list``.

        Returns ``0.0`` for an empty list. When one or two values tie for the
        highest count the smaller one is returned; when three or more tie, the
        second of them in first-appearance order is returned.
        """
        if not data_list:
            return 0.0
        tally = Counter(data_list)
        top_count = max(tally.values())
        # Counter keeps first-appearance order, so `modes` is ordered the same way.
        modes = [value for value, seen in tally.items() if seen == top_count]
        return modes[1] if len(modes) > 2 else min(modes)
    def reward_mean(self, data_list):
        """Return the arithmetic mean of ``data_list``, or ``0.0`` when it is empty."""
        return sum(data_list) / len(data_list) if data_list else 0.0
    def reward_fn(self, response):
        """
        content_score: R_a, the reward for response quality
        format_score: P, the reward for response format
        """
        mode = 'clarify' if self.action_truth == "continue" else 'summary'
        score_dicts = []
        workers = 1
        with ThreadPoolExecutor(max_workers=workers) as executor:
            futures = [executor.submit(self.llm_reward, response=response, mode=mode) for _ in range(workers)]
            for future in as_completed(futures):
                try:
                    score_dicts.append(future.result())
                except Exception as e:
                    logger.error(f"A reward calculation task failed: {e}")
                    score_dicts.append({})

        format_score_list = []
        content_score_list = []
        panalty_list = []
        for score_dict in score_dicts:
            panalty = 0
            if self.action_truth == "continue":
                # score_dict = self.llm_reward(response=response, mode='clarify')
                if score_dict != {}:
                    panalty = score_dict.get("panalty", 0)
                    format_score = float(score_dict.get("format_score", 0.0))
                    # The content score is set to 0 when asking a repetitive question
                    content_score = float(score_dict.get("content_score", 0.0))
                else:
                    format_score, content_score = 0.0, 0.0
            else:
                if score_dict != {}:
                    format_score = float(score_dict.get("format", 0))
                    completeness_score = int(score_dict.get("completeness", 1))
                    accuracy_score = int(score_dict.get("accuracy", 1))
                    content_score = float((completeness_score+accuracy_score-2)/8)
                    if content_score < 1/3:
                        content_score = 0.0
                    elif content_score > 2/3:
                        content_score = 1.0
                    else:
                        content_score = 0.5
                else:
                    format_score, content_score= 0.0, 0.0
                # Check if response contains summarize token with proper format
            format_score_list.append(format_score)
            content_score_list.append(content_score)
            panalty_list.append(panalty)
        final_panalty = self.reward_vote(panalty_list)
        if final_panalty != 0:
            final_content_score = -1 * final_panalty
        else:
            final_content_score = self.reward_vote(content_score_list)
        final_format_score = self.reward_vote(format_score_list)

        return final_content_score, final_format_score
