from typing import Optional, Union, Dict, List, Any
import json
import re
import torch
import requests
import time
import traceback
import numpy as np
from functools import partial
import tensordict
from tensordict import TensorDict
from roll.configs.worker_config import WorkerConfig
from roll.distributed.executor.worker import Worker
from roll.distributed.scheduler.decorator import Dispatch, register
from roll.distributed.scheduler.protocol import DataProto
from roll.distributed.strategy.factory import create_strategy
from roll.distributed.strategy.strategy import InferenceStrategy, TrainStrategy
from roll.models.model_providers import default_tokenizer_provider, default_reward_model_provider
from roll.utils.logging import get_logger
from roll.utils.context_managers import state_offload_manger
from roll.utils.prompt import *
from roll.datasets.chat_template import get_chat_template
import pandas as pd
from .process import process_stage2_output
from rouge_score import rouge_scorer

# Prompt template for the listwise "strategy alignment" LLM judge.
# It is filled with str.format using the placeholders agent1_name, agent2_name,
# history, current_agent, strategy and actions; the doubled braces ({{ }}) in
# the JSON example render as literal braces after formatting.
# NOTE(review): the wording hard-codes 8 candidate actions (action_1..action_8);
# confirm that callers always pass exactly 8 candidates.
reward_prompt_follow = """
Below is a dialogue scenario. Please rate and compare the candidate actions for their "strategy alignment" with the given strategy for the current turn.

Scenario (placeholders—please substitute):
- Participants: {agent1_name} and {agent2_name}
- Dialogue history: {history}
- Current actor: {current_agent}
- Current strategy description (Guidance and Objective for this turn): {strategy}
- Candidate actions (8): {actions}

Task objective:
- Determine for each candidate action how well it aligns with the provided strategy for this turn: whether it covers the key points indicated by the strategy (what to say, the style to adopt, points to pay attention to such as protecting reputation/privacy, avoiding conflict, guiding the other's reply, etc.), and assess the likelihood that it will achieve the strategy’s intended effects.

Your instructions:
- Perform a listwise evaluation of all 8 candidate actions: compare their relative strengths and weaknesses in the same context.
- For each action, give a concise rationale (2–4 sentences) explaining how the action aligns or deviates from the strategy, and list main strengths and risks (for example: risk of leaking private information, provoking conflict, or successfully guiding the other’s reply).
- Assign an integer score (0–10) to each action, using strategy alignment as the primary criterion.

Scoring scale (overall 0–10):
- 0–2: Completely misaligned or harmful.
- 3–4: Low alignment; obvious risks or insufficient to achieve the strategy’s goal.
- 5–6: Moderate alignment; actionable but may need improvement; limited impact.
- 7–8: High alignment; specific and feasible; low risk; good expected effect.
- 9–10: Very high alignment; fully matches strategy points; highly feasible and likely to achieve the intended effect.
- Empty actions (if a candidate is empty) should be scored 0.

Output requirements (strict):
- Return ONLY JSON, with no extra explanatory text.
- The JSON must include action_1 through action_8 (each containing "reason" and "score").
- Each "reason" must be concise (2–4 sentences), reference the dialogue context or strategy points, and state main pros and cons.
- "score" must be an integer from 0–10.

JSON template example (please keep keys and format):
{{
  "action_1": {{"reason": "……", "score": 0-10}},
  "action_2": {{"reason": "……", "score": 0-10}},
  ...
  "action_8": {{"reason": "……", "score": 0-10}}
}}
"""

def extract_first_to_last_brace(s: str):
    """Return the span from the first '{' through the last '}' of ``s``.

    The literal string "None" (not the ``None`` object) is returned when
    ``s`` is not a string, or when it has no '{' followed later by a '}'.
    """
    if isinstance(s, str):
        opening = s.find('{')
        closing = s.rfind('}')
        # A '{' and a '}' can never share an index, so "closing > opening"
        # is the same test as "closing found and not before opening".
        if opening != -1 and closing > opening:
            return s[opening:closing + 1]
    return "None"

def extract_first_to_first_brace(s: str):
    """Return the span from the first '{' through the first '}' of ``s``.

    The literal string "None" (not the ``None`` object) is returned when
    ``s`` is not a string, when either brace is missing, or when the first
    '}' appears before the first '{'.
    """
    if not isinstance(s, str):
        return "None"
    left, right = s.find('{'), s.find('}')
    well_formed = left >= 0 and right >= left
    return s[left:right + 1] if well_formed else "None"
def get_answer_length_score(num_tokens: int, used_tokens: int):
    """Score how ``used_tokens`` compares with the budget ``abs(num_tokens)``.

    The result is 0.5 when usage equals the budget, grows linearly towards
    1.0 as fewer tokens are used and shrinks linearly towards 0.0 as more
    are used. The slope is 1/75 per token before clipping, so the score
    saturates once the gap reaches 75 tokens in either direction.
    """
    alpha = 1/75  # slope applied when usage is at or above the budget
    beta = alpha  # slope applied when usage is below the budget

    gap = used_tokens - abs(num_tokens)
    slope = beta if gap < 0 else alpha
    raw = slope * gap * -1

    # Clip to [-1, 1], then rescale to [0, 1].
    bounded = min(1, raw)
    bounded = max(-1, bounded)
    return (bounded + 1)/2

def to_scores_and_k(data):
    """Return the ``action_<n>`` scores from ``data``, ordered by ``n``.

    Keys that do not match ``action_<digits>`` exactly are ignored. For each
    matching key the value's ``"score"`` entry is read and divided by 10.
    Despite the function name, only the list of scores is returned.
    """
    # Gather (action number, raw score) pairs.
    numbered = []
    for name, entry in data.items():
        hit = re.fullmatch(r'action_(\d+)', name)
        if hit is None:
            continue
        numbered.append((int(hit.group(1)), entry["score"]))

    # Order by the action number only; the raw scores are never compared.
    ordered = sorted(numbered, key=lambda pair: pair[0])
    return [raw / 10 for _, raw in ordered]

def rouge2_f1(s1,s2):
    """Return the ROUGE-2 F-measure that ``rouge_scorer`` reports for ``(s1, s2)``.

    A stemming ``RougeScorer`` is built on each call; ``s1`` and ``s2`` are
    passed to ``score`` as its first and second argument respectively.
    """
    bigram_scorer = rouge_scorer.RougeScorer(['rouge2'], use_stemmer=True)
    result = bigram_scorer.score(s1, s2)
    return result['rouge2'].fmeasure


class SotopiaFollowRewardWorker(Worker):
    """
    Reward Worker that uses LLM-as-judge to compute rewards.
    """

    def __init__(self, worker_config: WorkerConfig):
        """Set up rank info and read the LLM-judge settings from ``worker_config``.

        Each judge setting is optional on the config: a missing attribute
        falls back to ``None``, except ``judge_model_type`` which falls back
        to ``"api"``. The tokenizer and strategy start as ``None``.
        """
        super().__init__(worker_config=worker_config)
        # Use the global rank / world size as the data-parallel rank / size.
        self.rank_info.dp_rank = self.rank_info.rank
        self.rank_info.dp_size = self.rank_info.world_size
        self.tokenizer = None
        self.strategy: Optional[Union[InferenceStrategy, TrainStrategy]] = None

        # LLM-judge configuration.
        cfg = self.worker_config
        self.judge_prompt = getattr(cfg, "judge_prompt", None)
        # self.judge_prompt = prompt_maps[self.judge_prompt]
        self.judge_model_type = getattr(cfg, "judge_model_type", "api")
        self.judge_model_name = getattr(cfg, "judge_model_name", None)
        self.judge_api_url = getattr(cfg, "judge_api_url", None)
        self.judge_api_key = getattr(cfg, "judge_api_key", None)

    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
    def initialize(self, pipeline_config):
        """Prepare the tokenizer (and, for local inference, the strategy).

        In ``"api"`` mode only a tokenizer is loaded from the worker's model
        args. In ``"inference"`` mode a strategy is created and initialized
        with the default reward model provider, its tokenizer is reused, and
        its states are offloaded afterwards.

        Raises:
            ValueError: if ``judge_model_type`` is neither ``"api"`` nor
                ``"inference"``.
        """
        super().initialize(pipeline_config)

        mode = self.judge_model_type
        if mode not in ("api", "inference"):
            raise ValueError(f"Unsupported model type: {self.judge_model_type}")

        if mode == "api":
            self.tokenizer = default_tokenizer_provider(model_args=self.worker_config.model_args)
            print(f"{self.worker_name} initialized with API model")
            return

        # mode == "inference"
        self.strategy = create_strategy(worker=self)
        self.strategy.initialize(model_provider=default_reward_model_provider)
        self.tokenizer = self.strategy.tokenizer
        print(f"{self.worker_name} initialized with inference model")
        self.strategy.offload_states()

    def _call_api_model(self, messages: List[Dict[str, Any]], retry_times=3) -> str:
        """Query the judge through an OpenAI-compatible chat-completions API.

        Up to ``retry_times`` attempts are made; the first non-empty answer
        ends the loop. Exceptions raised by an attempt are printed and the
        attempt counts as failed.

        Args:
            messages: chat messages, each a dict with ``role`` and ``content``.
            retry_times: maximum number of attempts.

        Returns:
            The content of the first choice, or ``""`` when every attempt
            failed or produced no content.

        Raises:
            ValueError: if the API URL or the API key is not configured.
        """
        from openai import OpenAI

        if not self.judge_api_url or not self.judge_api_key:
            raise ValueError("API URL and API key must be provided for API model type")

        output = ""
        client = None
        attempts_left = retry_times
        while attempts_left > 0:
            attempts_left -= 1
            try:
                # Build the client once and reuse it across retries.
                if client is None:
                    client = OpenAI(
                        api_key=self.judge_api_key,
                        base_url=self.judge_api_url,
                    )
                completion = client.chat.completions.create(model=self.judge_model_name, messages=messages)
                output = completion.choices[0].message.content
            except Exception as e:
                print(e)
                continue

            # Token accounting is informational only: a response without a
            # usage block must not discard an otherwise valid answer.
            usage = getattr(completion, "usage", None)
            if usage is not None:
                token_info = {
                    "total_tokens": getattr(usage, "total_tokens", None),
                    "prompt_token": getattr(usage, "prompt_tokens", None),
                    "completion_token": getattr(usage, "completion_tokens", None),
                }
                print(token_info)

            if output:
                break
        # self.logger.info(f"judge model api output: {str(output)}")
        # Normalise a missing content field (None) to the declared str type.
        return output or ""

    def _run_local_inference(self, messages: Dict) -> str:
        """Generate the judge's answer with the locally loaded strategy.

        The messages are rendered with the configured chat template,
        tokenized, moved to CUDA and passed to ``self.strategy.generate``.
        Only the tokens after the prompt of the first sequence are decoded.

        Returns:
            The decoded completion with surrounding whitespace stripped.

        Raises:
            ValueError: if no strategy has been initialized.
        """
        if not self.strategy:
            raise ValueError("Strategy not initialized for local inference")

        render_chat = get_chat_template(self.worker_config.data_args.template, self.tokenizer)
        encoded = self.tokenizer(render_chat(messages), return_tensors="pt")
        input_ids = encoded["input_ids"].to("cuda")
        attention_mask = encoded["attention_mask"].to("cuda")

        generation_config = self.worker_config.generating_args.to_dict()
        generation_config["eos_token_id"] = [self.tokenizer.eos_token_id]
        generation_config["pad_token_id"] = self.tokenizer.pad_token_id

        tensors = TensorDict(
            {"input_ids": input_ids, "attention_mask": attention_mask},
            batch_size=input_ids.shape[0],
        )
        data = DataProto(batch=tensors).to("cuda")
        data.meta_info = {"micro_batch_size": self.worker_config.infer_batch_size}

        prompt_len = len(input_ids[0])
        with torch.no_grad():
            generated = self.strategy.generate(batch=data, generation_config=generation_config)
            # generate may hand back a raw tensor or an object exposing .batch
            sequences = generated if isinstance(generated, torch.Tensor) else generated.batch["input_ids"]
            # Drop the prompt tokens so only the new completion remains.
            completion_ids = sequences[:, prompt_len:]

        decoded = self.tokenizer.decode(completion_ids[0], skip_special_tokens=True)
        # self.logger.info(f"judge model inference output: {str(output)}")
        return decoded.strip()

    def _format_follow_judge_prompt(self, prompt, policy_li, dialog, agent1_name, agent1_goal, agent2_name, agent2_goal,turn) -> List[Dict[str, str]]:
        """Build the chat messages asking the judge to rate all candidate actions.

        Args:
            prompt: policy prompt text. The strategy is the text between
                "This is a well-thought-out strategy: " and
                "\\nAnd the mode of".
            policy_li: candidate actions; non-string entries are JSON-encoded.
            dialog: sequence of dialogue lines; every entry except the last
                is joined into the history.
            agent1_name, agent2_name: participant names.
            agent1_goal, agent2_goal: unused here; kept so the signature
                matches the callers.
            turn: turn index; agent 1 acts on even turns, agent 2 on odd ones.

        Returns:
            A single-element list holding the user message for the judge.
        """
        # TODO: needs rework, the history may be insufficient.
        rendered = (p if isinstance(p, str) else json.dumps(p, ensure_ascii=False) for p in policy_li)
        policy_string = "\n".join(f"Action {i}: {s}" for i, s in enumerate(rendered, 1))
        history = "\n".join(dialog[:-1])
        current_agent = agent1_name if turn % 2 == 0 else agent2_name

        start_str = "This is a well-thought-out strategy: "
        end_str = "\nAnd the mode of"
        # str.find returns -1 when a marker is absent. Using -1 as a slice
        # bound would start at a bogus offset or drop the last character, so
        # fall back to the start / end of the prompt instead.
        start_idx = prompt.find(start_str)
        strategy_start = start_idx + len(start_str) if start_idx != -1 else 0
        # Only accept an end marker located after the start of the strategy.
        end_idx = prompt.find(end_str, strategy_start)
        strategy_end = end_idx if end_idx != -1 else len(prompt)

        formatted_prompt = reward_prompt_follow.format(
            history=history,
            agent1_name=agent1_name,
            agent2_name=agent2_name,
            actions=policy_string,
            current_agent=current_agent,
            strategy=prompt[strategy_start:strategy_end],
        )
        return [{"role": "user", "content": formatted_prompt}]

    def _get_llm_follow_judgment(self, prompt: str, policy_li, dialog, agent1_name, agent1_goal, agent2_name, agent2_goal, turn) -> List[float]:
        """Ask the judge to score every candidate action for strategy alignment.

        The judge's reply is cut down to its outermost ``{...}`` span, parsed
        as JSON and converted to one score per ``action_<n>`` key, each
        divided by 10.

        Returns:
            One score per entry of ``policy_li``. A neutral ``0.5`` for every
            candidate is returned when the reply cannot be parsed, lacks a
            usable ``score`` field, or does not contain exactly one score per
            candidate.

        Raises:
            ValueError: if ``judge_model_type`` is neither ``"api"`` nor
                ``"inference"``.
        """
        messages = self._format_follow_judge_prompt(prompt, policy_li, dialog, agent1_name, agent1_goal, agent2_name, agent2_goal,turn)

        if self.judge_model_type == "api":
            llm_response = self._call_api_model(messages)
        elif self.judge_model_type == "inference":
            llm_response = self._run_local_inference(messages)
        else:
            raise ValueError(f"Unsupported model type: {self.judge_model_type}")

        neutral = [0.5]*len(policy_li)
        try:
            res = json.loads(extract_first_to_last_brace(llm_response))
            reward_li = to_scores_and_k(res)
        except (json.JSONDecodeError, KeyError, TypeError):
            return neutral

        # The judge may omit or invent action entries; a list of the wrong
        # length would mis-size the reward tensor built by the caller.
        if len(reward_li) != len(policy_li):
            return neutral

        return reward_li


    # @register(dispatch_mode=Dispatch.DP_MP_COMPUTE, clear_cache=False)
    # def compute_rewards(self, data: DataProto):
    #     # self.logger.info(f"输入维度{data}")
    #     global_step = data.meta_info.get("global_step", 0)
    #     is_offload_states = data.meta_info.get("is_offload_states", True)
    #     metrics = {}

    #     if self.judge_model_type == "inference" and self.strategy:
    #         with state_offload_manger(
    #             strategy=self.strategy,
    #             metrics=metrics,
    #             metric_infix=f"{self.cluster_name}/compute_rewards",
    #             is_offload_states=is_offload_states,
    #         ):
    #             return self._compute_rewards_impl(data, metrics)
    #     else:
    #         return self._compute_rewards_impl(data, metrics)

    def _compute_follow_rewards_impl(self, data: DataProto, metrics: Dict):
        """Score all responses of the batch with one listwise judge call.

        Every decoded response is parsed into an ``"<action_type>: <argument>"``
        string. The full list is then judged against the prompt and dialogue
        context of the LAST sample in the batch.
        NOTE(review): this presumes all samples share one prompt/context and
        that the non-tensor arrays are aligned with the responses; verify
        against the caller.

        Args:
            data: batch with ``prompts`` and ``responses`` token tensors and
                the non-tensor fields ``dialog``, ``agent1_name``,
                ``agent1_goal``, ``agent2_name``, ``agent2_goal``, ``turn``.
            metrics: unused here; kept for interface compatibility.

        Returns:
            ``(scores_tensor, mean_reward)``: a float16 tensor with one score
            per response, and the mean of those scores. An empty batch yields
            an empty tensor and ``0.0``.
        """
        response_text_list = self.tokenizer.batch_decode(data.batch["responses"], skip_special_tokens=True)
        if not response_text_list:
            # Nothing to judge; avoid reading context from a non-existent sample.
            return torch.tensor([], dtype=torch.float16), 0.0

        action_li = []
        for response in response_text_list:
            try:
                # On a malformed response fall back to the empty action.
                tmpdict = process_stage2_output(extract_first_to_first_brace(response))
                action = tmpdict['action_type']+": "+tmpdict['argument']
            except (json.JSONDecodeError, KeyError, TypeError):
                # TypeError covers a non-dict parse result or non-string fields.
                action = "vide action"
            action_li.append(action)

        # Context for the listwise judgment comes from the last sample; only
        # that prompt needs decoding.
        prompt_txt = self.tokenizer.batch_decode(data.batch["prompts"][-1:], skip_special_tokens=True)[0]
        extras = data.non_tensor_batch
        reward_li = self._get_llm_follow_judgment(
            prompt_txt,
            action_li,
            extras["dialog"][-1],
            extras['agent1_name'][-1],
            extras['agent1_goal'][-1],
            extras['agent2_name'][-1],
            extras['agent2_goal'][-1],
            extras['turn'][-1],
        )

        scores_tensor = torch.tensor(reward_li, dtype=torch.float16)
        return scores_tensor, np.mean(reward_li)
