from typing import Optional, Union, Dict, List, Any
import json
import re
import torch
import requests
import time
import traceback
import numpy as np
from functools import partial
import tensordict
from tensordict import TensorDict
from roll.configs.worker_config import WorkerConfig
from roll.distributed.executor.worker import Worker
from roll.distributed.scheduler.decorator import Dispatch, register
from roll.distributed.scheduler.protocol import DataProto
from roll.distributed.strategy.factory import create_strategy
from roll.distributed.strategy.strategy import InferenceStrategy, TrainStrategy
from roll.models.model_providers import default_tokenizer_provider, default_reward_model_provider
from roll.utils.logging import get_logger
from roll.utils.context_managers import state_offload_manger
from roll.utils.prompt import *
from roll.datasets.chat_template import get_chat_template
import pandas as pd
from ast import literal_eval
from .process import process_stage1_output

# Listwise LLM-judge prompt template, filled in with ``str.format``.
# Placeholders: {agent1_name}, {agent2_name}, {history}, {current_agent},
# {current_agent_goal} and {policys}. Literal JSON braces are doubled ({{ }}) so that
# ``str.format`` leaves them in place.
# NOTE(review): the text hard-codes "8 total" / "policy_8"; presumably callers always
# submit exactly 8 candidate policies per judgment - confirm against the rollout config.
reward_prompt = """
Here is a dialogue between {agent1_name} and {agent2_name}: 
{history}
{current_agent} is about to take an action then.

Task objective:
- Evaluate the candidate policies and select the single best one that maximizes {current_agent}’s likelihood of achieving {current_agent}’s stated goal in the next turn.

{current_agent}’s goal: {current_agent_goal}



Your job:

I will give you a list of candidate policies, you should perform a listwise evaluation: compare all policies relative to each other in the given dialogue context for {current_agent}.
For each policy, briefly reason step by step about its strengths and weaknesses in this specific context (keep it concise: 2–4 sentences).
Then assign each policy one overall score from 0 to 10 using the framework below, with Goal Completion (GOAL) as the primary driver and the other metrics as auxiliary modifiers.
Finally, choose exactly one policy as the final_choice.
Scoring framework:
Primary criteria:
Goal Completion (GOAL): How directly and effectively the policy advances {current_agent}’s defined goal(s) in the dialogue.

Auxiliary checks (use to adjust the overall score up/down):
Believability (BEL): Naturalness and consistency with {current_agent}’s character profile (personality, values).
Knowledge (KNO): Whether the policy helps {current_agent} acquire new and important information.
Secret (SEC): Risk of leaking {current_agent}’s private information/intentions; severe leaks should significantly reduce the score.
Relationship (REL): Impact on {current_agent}’s relationships/reputation; harm reduces the score, clear improvements can slightly raise it.
Social Rules (SOC): Compliance with legal rules and social norms; violations strongly reduce the score.
Financial/Material (FIN): Expected material benefits or harms; significant costs reduce, clear benefits can slightly raise.
Practical criteria (affect GOAL and expected impact):

Feasibility and specificity: Clear, actionable steps {current_agent} can take immediately.
Context fit and tone: Respectful, non-aggressive, and appropriate to the dialogue history (sometimes quit is the appropriate action).
Risk of backfiring and ethics: Anticipate foreseeable downsides; prefer low-risk, ethically compliant policies.
Expected impact: Magnitude of positive effect toward achieving the goal.
Scoring scale (overall score per policy):

0–2: Misaligned or harmful; not actionable; high risk or unethical.
3–4: Weak alignment or low feasibility; noticeable risks; limited impact.
5–6: Moderately aligned; somewhat actionable; some risks or modest impact.
7–8: Strong alignment; actionable and context-appropriate; low risk; good impact.
9–10: Excellent alignment; highly actionable and tailored; minimal risk; high impact.
And Empty policies must be scored 0.

Candidate policys (8 total):
{policys}

Return ONLY JSON, no extra text.
For each policy key (e.g., "policy_1", "policy_2", ...), include:
"reason": a concise rationale referencing the dialogue context and the criteria above; briefly note strengths and weaknesses.
"score": an integer from 0 to 10 (no decimals).
Provide "final_choice": the key name (e.g., "policy_3") of the selected policy.

JSON schema and example format ("required": ["policy_1","policy_2",...,"policy_8", "final_choice"]):
```json
{{
  "policy_1": {{"reason": "...", "score": 0-10}},
  "policy_2": {{"reason": "...", "score": 0-10}},
  ...
  "policy_8": {{"reason": "...", "score": 0-10}},
  "final_choice": "policy_k"
}}
```
"""

def extract_first_to_last_brace(s: str):
    """Return the substring from the first ``{`` through the last ``}`` of *s*.

    This is the widest brace-delimited span, so nested objects are kept whole.
    The literal string ``"None"`` (not the ``None`` object) is returned when
    *s* is not a string or no such span exists.
    """
    if isinstance(s, str):
        opening = s.find('{')
        closing = s.rfind('}')
        # Both braces must exist and the last '}' must come after the first '{'.
        if opening != -1 and closing != -1 and closing >= opening:
            return s[opening:closing + 1]
    return "None"


def extract_first_to_first_brace(s: str):
    """Return the substring from the first ``{`` to the first ``}`` that follows it.

    This is the narrowest brace-delimited span, so a nested object is cut at
    its first closing brace. The closing brace is searched for starting at the
    opening brace, so a stray ``}`` appearing earlier in the text does not
    hide a valid span that comes later.

    The literal string ``"None"`` (not the ``None`` object) is returned when
    *s* is not a string or no such span exists.
    """
    if not isinstance(s, str):
        return "None"
    start = s.find('{')
    if start == -1:
        return "None"
    # Search from `start` so a '}' located before the first '{' is ignored.
    end = s.find('}', start)
    if end == -1:
        return "None"
    return s[start:end + 1]
def get_answer_length_score(num_tokens: int, used_tokens: int):
    """Score how ``used_tokens`` compares with a budget of ``abs(num_tokens)``.

    The result lies in ``[0, 1]``: exactly ``0.5`` when the budget is met,
    rising linearly towards ``1`` as fewer tokens are used and falling
    linearly towards ``0`` as more are used. The slope is 1/75 per token on
    the pre-normalisation scale, which is clipped to ``[-1, 1]``.
    """
    slope = 1/75
    overshoot = used_tokens - abs(num_tokens)

    # The same linear penalty applies whether the budget is exceeded or not.
    raw = slope * overshoot * -1

    # Clip to [-1, 1] before mapping onto [0, 1].
    if raw > 1:
        clipped = 1
    elif raw < -1:
        clipped = -1
    else:
        clipped = raw
    return (clipped + 1)/2

def to_scores_and_k(data):
    """Convert a parsed judge response into ``(scores, k)``.

    Args:
        data: Mapping with ``"policy_<n>"`` keys, each mapping to a dict that
            holds a ``"score"`` entry, plus an optional ``"final_choice"``
            entry naming the selected policy (e.g. ``"policy_3"``).

    Returns:
        ``(scores, k)`` where ``scores`` lists the ``"score"`` values ordered
        by policy number and ``k`` is the 1-based number of the chosen policy.
        ``k`` falls back to 1 when ``final_choice`` is missing, malformed, or
        names a policy that was not scored, so callers never receive an index
        that has no matching score.

    Raises:
        KeyError: a policy entry has no ``"score"``.
        TypeError: a policy entry is not subscriptable by string, or
            ``final_choice`` is not a string.
    """
    # Collect (policy number, score) pairs.
    pairs = []
    for key, val in data.items():
        m = re.fullmatch(r'policy_(\d+)', key)
        if m:
            pairs.append((int(m.group(1)), val["score"]))

    pairs.sort(key=lambda pair: pair[0])
    scores = [score for _, score in pairs]
    scored_indices = {idx for idx, _ in pairs}

    m = re.fullmatch(r'policy_(\d+)', data.get("final_choice", ""))
    k = int(m.group(1)) if m else 1
    # A choice such as "policy_0" or "policy_9" that names no scored policy
    # would later be used as a list index; use the same default as a missing choice.
    if k not in scored_indices:
        k = 1

    return scores, k


class SotopiaPolicyRewardWorker(Worker):
    """
    Reward worker that scores candidate policies with an LLM acting as judge.

    The judge is either a remote OpenAI-compatible endpoint
    (``judge_model_type == "api"``) or a locally hosted model driven through a
    strategy (``judge_model_type == "inference"``). One judge call scores the
    whole batch of candidate policies listwise and names a single best policy.
    """

    def __init__(self, worker_config: WorkerConfig):
        super().__init__(worker_config=worker_config)
        self.rank_info.dp_rank = self.rank_info.rank
        self.rank_info.dp_size = self.rank_info.world_size
        self.tokenizer = None
        self.strategy: Optional[Union[InferenceStrategy, TrainStrategy]] = None

        # LLM-judge configuration. Every option is optional on the worker config;
        # the model type defaults to "api", everything else to None.
        self.judge_prompt = getattr(self.worker_config, "judge_prompt", None)
        self.judge_model_type = getattr(self.worker_config, "judge_model_type", "api")
        self.judge_model_name = getattr(self.worker_config, "judge_model_name", None)
        self.judge_api_url = getattr(self.worker_config, "judge_api_url", None)
        self.judge_api_key = getattr(self.worker_config, "judge_api_key", None)

    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
    def initialize(self, pipeline_config):
        """Load the tokenizer and, for local inference, the judge model.

        Raises:
            ValueError: ``judge_model_type`` is neither ``"api"`` nor ``"inference"``.
        """
        super().initialize(pipeline_config)
        if self.judge_model_type == "api":
            # The tokenizer is still needed to decode prompts/responses of the batch.
            self.tokenizer = default_tokenizer_provider(model_args=self.worker_config.model_args)
            print(f"{self.worker_name} initialized with API model")

        elif self.judge_model_type == "inference":
            self.strategy = create_strategy(worker=self)
            self.strategy.initialize(model_provider=default_reward_model_provider)
            self.tokenizer = self.strategy.tokenizer
            print(f"{self.worker_name} initialized with inference model")
            self.strategy.offload_states()
        else:
            raise ValueError(f"Unsupported model type: {self.judge_model_type}")

    def _call_api_model(self, messages: List[Dict[str, str]], retry_times=3) -> str:
        """Query the remote judge, retrying on errors or empty completions.

        Args:
            messages: Chat messages in OpenAI format.
            retry_times: Maximum number of attempts.

        Returns:
            The completion text, or ``""`` when no attempt produced any text.

        Raises:
            ValueError: the API URL or API key is not configured.
        """
        from openai import OpenAI

        if not self.judge_api_url or not self.judge_api_key:
            raise ValueError("API URL and API key must be provided for API model type")

        # Bound before the loop so a failure on every attempt returns "" instead
        # of raising UnboundLocalError.
        output = ""
        client = None
        for _ in range(retry_times):
            try:
                # Built lazily inside the try so a construction error is retried like
                # any other failure, and reused across attempts once it succeeds.
                if client is None:
                    client = OpenAI(
                        api_key=self.judge_api_key,
                        base_url=self.judge_api_url,
                    )
                completion = client.chat.completions.create(model=self.judge_model_name, messages=messages)
                output = completion.choices[0].message.content
                usage = completion.usage
                # Usage is only reported for diagnostics; a missing usage block must
                # not discard an otherwise valid completion.
                if usage is not None:
                    token_info = {
                        "total_tokens": usage.total_tokens,
                        "prompt_token": usage.prompt_tokens,
                        "completion_token": usage.completion_tokens,
                    }
                    print(token_info)
                if output is not None and output != "":
                    break
            except Exception as e:
                # Best effort: report the failure and try again.
                print(e)
        return output if output is not None else ""

    def _run_local_inference(self, messages: List[Dict[str, str]]) -> str:
        """Generate the judge response with the local strategy.

        Args:
            messages: Chat messages rendered through the configured chat template.

        Returns:
            The decoded, whitespace-stripped continuation of the first sequence.

        Raises:
            ValueError: no strategy has been initialized.
        """
        if not self.strategy:
            raise ValueError("Strategy not initialized for local inference")

        template_name = self.worker_config.data_args.template
        chat_template_func = get_chat_template(template_name, self.tokenizer)
        text = chat_template_func(messages)

        tokenized = self.tokenizer(text, return_tensors="pt")
        input_ids = tokenized["input_ids"].to("cuda")
        attention_mask = tokenized["attention_mask"].to("cuda")
        prompt_len = input_ids.shape[1]

        generation_config = self.worker_config.generating_args.to_dict()
        generation_config["eos_token_id"] = [self.tokenizer.eos_token_id]
        generation_config["pad_token_id"] = self.tokenizer.pad_token_id

        data = DataProto(
            batch=TensorDict({"input_ids": input_ids, "attention_mask": attention_mask}, batch_size=input_ids.shape[0])
        )
        data = data.to("cuda")
        data.meta_info = {"micro_batch_size": self.worker_config.infer_batch_size}

        with torch.no_grad():
            output = self.strategy.generate(batch=data, generation_config=generation_config)
            # The strategy may return a raw tensor or a DataProto; in both cases the
            # leading prompt tokens are sliced off to keep only the generated ones.
            if isinstance(output, torch.Tensor):
                generate_ids = output[:, prompt_len:]
            else:
                generate_ids = output.batch["input_ids"][:, prompt_len:]

        output = self.tokenizer.decode(generate_ids[0], skip_special_tokens=True)
        return output.strip()

    def _format_judge_prompt(
        self, prompt, policy_li, dialog, agent1_name, agent1_goal, agent2_name, agent2_goal, turn
    ) -> List[Dict[str, str]]:
        """Build the single-turn chat message that asks the judge to rank policies.

        Args:
            prompt: Unused; kept for signature compatibility.
            policy_li: Candidate policies; non-string entries are JSON-encoded.
            dialog: Dialogue turns; every turn except the last is used as history.
            agent1_name, agent1_goal: Name and goal of the agent acting on even turns.
            agent2_name, agent2_goal: Name and goal of the agent acting on odd turns.
            turn: Turn counter whose parity selects the acting agent.

        Returns:
            A one-element list holding the user message.
        """
        # TODO: revisit - the history passed to the judge may be insufficient.
        parts = []
        for i, p in enumerate(policy_li, 1):
            s = p if isinstance(p, str) else json.dumps(p, ensure_ascii=False)
            parts.append(f"Policy {i}: {s}")
        policy_string = "\n".join(parts)
        history = "\n".join(dialog[:-1])

        is_agent1_turn = turn % 2 == 0
        current_agent = agent1_name if is_agent1_turn else agent2_name
        current_agent_goal = agent1_goal if is_agent1_turn else agent2_goal

        formatted_prompt = reward_prompt.format(
            history=history,
            agent1_name=agent1_name,
            agent2_name=agent2_name,
            policys=policy_string,
            current_agent=current_agent,
            current_agent_goal=current_agent_goal,
        )
        return [{"role": "user", "content": formatted_prompt}]

    def _get_llm_judgment(
        self, prompt: str, policy_li, dialog, agent1_name, agent1_goal, agent2_name, agent2_goal, turn
    ) -> tuple:
        """Ask the judge to score ``policy_li`` and pick the best policy.

        Returns:
            ``(reward_li, k)``: one score per policy and the 1-based index of the
            chosen policy. When the response cannot be parsed, or does not hold
            exactly one numeric score per policy, every policy gets the neutral
            score 5 and ``k`` is 1. When only the chosen index is out of range,
            the scores are kept and ``k`` falls back to 1.

        Raises:
            ValueError: ``judge_model_type`` is neither ``"api"`` nor ``"inference"``.
        """
        messages = self._format_judge_prompt(
            prompt, policy_li, dialog, agent1_name, agent1_goal, agent2_name, agent2_goal, turn
        )

        if self.judge_model_type == "api":
            llm_response = self._call_api_model(messages)
        elif self.judge_model_type == "inference":
            llm_response = self._run_local_inference(messages)
        else:
            raise ValueError(f"Unsupported model type: {self.judge_model_type}")

        num_policies = len(policy_li)
        fallback = ([5] * num_policies, 1)

        try:
            res = json.loads(extract_first_to_last_brace(llm_response))
            reward_li, k = to_scores_and_k(res)
        except (json.JSONDecodeError, KeyError, TypeError):
            return fallback

        # The scores become a per-sample tensor, so there must be exactly one
        # numeric score per policy; anything else counts as an unusable response.
        if len(reward_li) != num_policies:
            return fallback
        if not all(isinstance(s, (int, float)) and not isinstance(s, bool) for s in reward_li):
            return fallback

        # `k` is used as a 1-based list index by the caller.
        if not 1 <= k <= num_policies:
            k = 1

        return reward_li, k

    @register(dispatch_mode=Dispatch.DP_MP_COMPUTE, clear_cache=False)
    def compute_rewards(self, data: DataProto):
        """Compute rewards for ``data``.

        For a local judge the computation runs inside ``state_offload_manger``;
        ``data.meta_info["is_offload_states"]`` (default ``True``) is forwarded to it.
        """
        is_offload_states = data.meta_info.get("is_offload_states", True)
        metrics = {}

        if self.judge_model_type == "inference" and self.strategy:
            with state_offload_manger(
                strategy=self.strategy,
                metrics=metrics,
                metric_infix=f"{self.cluster_name}/compute_rewards",
                is_offload_states=is_offload_states,
            ):
                return self._compute_rewards_impl(data, metrics)
        else:
            return self._compute_rewards_impl(data, metrics)

    def _compute_rewards_impl(self, data: DataProto, metrics: Dict):
        """Extract one policy per response, judge them listwise and package the rewards.

        All rows of the batch are judged together in one call that uses the
        prompt, dialogue and agent fields of the LAST row.
        NOTE(review): presumably every row shares the same dialogue context -
        confirm against the caller.

        Returns:
            A ``DataProto`` with ``token_level_rewards`` (zeros),
            ``response_level_rewards`` / ``scores`` (the judge scores),
            ``chosed_index`` (the chosen 1-based policy index, repeated per row) and
            ``mode`` (0 when the chosen policy's mode is ``"social-oriented"``,
            otherwise 1, repeated per row).

        Raises:
            ValueError: the batch is empty.
        """
        prompts_text_list = self.tokenizer.batch_decode(data.batch["prompts"], skip_special_tokens=True)
        response_text_list = self.tokenizer.batch_decode(data.batch["responses"], skip_special_tokens=True)

        policy_li = []
        mode_li = []
        for prompt_txt, response, dialog, agent1_name, agent1_goal, agent2_name, agent2_goal, turn, _state in zip(
            prompts_text_list,
            response_text_list,
            data.non_tensor_batch["dialog"],
            data.non_tensor_batch['agent1_name'],
            data.non_tensor_batch['agent1_goal'],
            data.non_tensor_batch['agent2_name'],
            data.non_tensor_batch['agent2_goal'],
            data.non_tensor_batch['turn'],
            data.non_tensor_batch["state"],
        ):
            try:
                tmpdict = process_stage1_output(extract_first_to_first_brace(response))
                response0 = tmpdict["strategy"]
                mode0 = tmpdict["mode"]
            except (json.JSONDecodeError, KeyError):
                # A malformed response is replaced by a placeholder empty policy.
                response0 = "vide policy"
                mode0 = "goal-oriented"
            policy_li.append(response0)
            mode_li.append(mode0)

        if not policy_li:
            # Without any row there is no dialogue context to judge.
            raise ValueError("compute_rewards received an empty batch")

        # The loop variables still hold the last row's context here.
        reward_li, k = self._get_llm_judgment(
            prompt_txt, policy_li, dialog, agent1_name, agent1_goal, agent2_name, agent2_goal, turn
        )

        scores_tensor = torch.tensor(reward_li, dtype=torch.float16)
        token_level_rewards = torch.zeros_like(data.batch["responses"], dtype=torch.float16)
        response_level_rewards = scores_tensor

        mode = 0 if mode_li[k - 1] == "social-oriented" else 1
        output = DataProto.from_dict(
            tensors={
                "token_level_rewards": token_level_rewards,
                "response_level_rewards": response_level_rewards,
                "scores": scores_tensor,
                # TODO: this carries the chosen best index k, repeated for every row.
                "chosed_index": torch.tensor([k] * len(scores_tensor), dtype=torch.float16),
                "mode": torch.tensor([mode] * len(scores_tensor), dtype=torch.float16),
            },
            meta_info={"custom_metrics": {
                "stage_1_rewards": np.mean(reward_li)
            }}
        )
        self.logger.info(f"策略奖励——最后奖励长这样{scores_tensor}")
        print(f"Computed rewards for {len(policy_li)} samples")
        return output
