from typing import Optional, Union, Dict, List, Any
import json
import re
import torch
import requests
import time
import traceback
import numpy as np
from functools import partial
import tensordict
from tensordict import TensorDict
from roll.configs.worker_config import WorkerConfig
from roll.distributed.executor.worker import Worker
from roll.distributed.scheduler.decorator import Dispatch, register
from roll.distributed.scheduler.protocol import DataProto
from roll.distributed.strategy.factory import create_strategy
from roll.distributed.strategy.strategy import InferenceStrategy, TrainStrategy
from roll.models.model_providers import default_tokenizer_provider, default_reward_model_provider
from roll.utils.logging import get_logger
from roll.utils.context_managers import state_offload_manger
from roll.utils.prompt import *
from roll.datasets.chat_template import get_chat_template
import pandas as pd
from .sotopia_follow_reward_worker import SotopiaFollowRewardWorker
from .process import process_stage2_output
from rouge_score import rouge_scorer

# Judge prompt template. It is filled with ``str.format`` using the fields
# ``history``, ``agent1_name``, ``agent1_goal``, ``agent2_name`` and
# ``agent2_goal``; the doubled braces in the schema section are escapes that
# render as literal ``{`` / ``}`` in the final prompt.
reward_prompt = """{history}

Based on previous interactions, evaluate how well participants achieve their goals. 

[Information]
Agent1: {agent1_name}
Agent1's Goal: {agent1_goal}

Agent2: {agent2_name}
Agent2's Goal: {agent2_goal}

[Requirements]
1. Please first reiterate agent's social goals. And then please provide a comprehensive analysis about the extent to which the agent has managed to achieve these goals. In the "reasoning" field, provide a comprehensive account of the logic or thought process that led you to your conclusion. Further, provide an integer score ranging from 0 and 10 in the "score" field. 0 represents minimal goals achievement, 10 represents complete goal achievement, and a higher score indicates that the agent is making progress towards their social goals.
2. Please following the output format.

Here is the output schema:
```json
{{
    "agent1": {{
        "reasoning": "",
        "score": "", 
    }},
    "agent2": {{
        "reasoning": "",
        "score": "", 
    }}
}}
```
Please provide your response directly below this prompt."""

def extract_first_to_last_brace(s: str):
    """Return the text from the first ``{`` through the last ``}`` of ``s``.

    The literal string ``"None"`` (not the ``None`` object) is returned when
    ``s`` is not a string, when either brace is missing, or when the last
    ``}`` sits before the first ``{``.
    """
    if isinstance(s, str):
        opening = s.find('{')
        closing = s.rfind('}')
        if opening != -1 and closing != -1 and closing >= opening:
            return s[opening:closing + 1]
    return "None"

def extract_first_to_first_brace(s: str):
    """Return the text from the first ``{`` through the first ``}`` after it.

    The closing brace is searched for starting at the opening brace, so a
    stray ``}`` earlier in the text (e.g. ``'} {"a": 1}'``) no longer hides a
    valid span that follows it.

    Returns the literal string ``"None"`` (not the ``None`` object) when ``s``
    is not a string or when no such span exists.
    """
    if not isinstance(s, str):
        return "None"
    start = s.find('{')
    if start == -1:
        return "None"
    # Only a closing brace at or after the opening one can terminate the span.
    end = s.find('}', start)
    if end == -1:
        return "None"
    return s[start:end + 1]
def get_answer_length_score(num_tokens: int, used_tokens: int):
    """Map the gap between used tokens and the token budget to a score in [0, 1].

    The budget is ``abs(num_tokens)``. Using exactly the budget yields 0.5;
    every token under the budget raises the score and every token over it
    lowers the score, at a slope of 1/75 per token before clipping. The raw
    value is clipped to [-1, 1] and then rescaled to [0, 1].
    """
    over_slope = 1/75
    under_slope = over_slope

    overshoot = used_tokens - abs(num_tokens)
    # Separate slopes are kept for under- and over-budget answers even though
    # they are currently equal.
    slope = under_slope if overshoot < 0 else over_slope
    raw = slope * overshoot * -1
    clipped = max(-1, min(1, raw))
    return (clipped + 1)/2

def reward_format(data):
    """Parse and validate the judge model's JSON verdict.

    The outermost ``{...}`` span of ``data`` is decoded as JSON. If decoding
    fails, trailing commas before a closing brace/bracket are stripped and
    decoding is retried once (the schema shown in ``reward_prompt`` itself
    contains trailing commas, which a model may echo back).

    A verdict is accepted only if it is an object with ``agent1`` and
    ``agent2`` entries, each an object holding ``reasoning`` and a ``score``
    convertible to a float within [0, 10]. Accepted scores are converted to
    ``float`` in place.

    Args:
        data: Raw judge output. Anything that is not a string is rejected.

    Returns:
        The validated dict, or an empty list ``[]`` when the input cannot be
        parsed or fails validation.
    """
    if not isinstance(data, str):
        # The judge may produce no output at all; treat that as unparseable
        # instead of letting ``re.search`` raise TypeError.
        return []

    match = re.search(r'\{.*\}', data, re.DOTALL)
    if match is None:
        return []
    json_str = match.group(0)

    try:
        parsed = json.loads(json_str)
    except ValueError:
        # Repair the most likely malformation (trailing commas) and retry.
        try:
            parsed = json.loads(re.sub(r',\s*([}\]])', r'\1', json_str))
        except ValueError:
            return []

    if not isinstance(parsed, dict):
        return []

    # Explicit checks instead of ``assert`` so validation survives ``python -O``.
    for agent in ('agent1', 'agent2'):
        verdict = parsed.get(agent)
        if not isinstance(verdict, dict):
            return []
        if 'reasoning' not in verdict or 'score' not in verdict:
            return []
        try:
            score = float(verdict['score'])
        except (TypeError, ValueError, OverflowError):
            return []
        # NaN fails this comparison and is therefore rejected as well.
        if not 0 <= score <= 10:
            return []
        verdict['score'] = score
    return parsed

def extract_rewards(data, turn):
    """Assign the two agents' scores to actor/opponent roles for a given turn.

    On even turns ``agent1`` is the actor and ``agent2`` the opponent; on odd
    turns the roles are swapped. Both scores are returned as floats under the
    keys ``'actor'`` and ``'oppo'``.
    """
    if turn % 2 == 0:
        actor_key, oppo_key = 'agent1', 'agent2'
    else:
        actor_key, oppo_key = 'agent2', 'agent1'
    return {
        'actor': float(data[actor_key]['score']),
        'oppo': float(data[oppo_key]['score']),
    }

def history_format(dialog,response):
    """Build the dialogue history text with the latest response appended.

    The first ``{...}`` span of ``response`` is handed to
    ``process_stage2_output`` and the ``"argument"`` entry of its result is
    used as the utterance. If that raises ``json.JSONDecodeError``,
    ``KeyError`` or ``TypeError``, the utterance falls back to the string
    ``"None"``.

    The utterance is appended (after a space, followed by a newline) to the
    last entry of a copy of ``dialog``; the entries are then joined with
    newlines and the result is stripped. ``dialog`` itself is not modified.
    """
    try:
        utterance = process_stage2_output(extract_first_to_first_brace(response))["argument"]
    except (json.JSONDecodeError, KeyError, TypeError):
        utterance = "None"

    turns = dialog.copy()
    turns[-1] += " " + str(utterance) + "\n"
    return "\n".join(turns).strip()

def scale_gradient(grad, current_state):
    """Normalise a score change by the room left in its direction on a 0-10 scale.

    A positive ``grad`` is divided by the headroom ``10 - current_state``; a
    negative ``grad`` is divided by ``current_state``. A zero ``grad`` returns
    the integer ``0`` without dividing. With plain Python numbers a zero
    denominator raises ``ZeroDivisionError``.
    """
    if grad == 0:
        return 0
    if grad > 0:
        return grad / (10 - current_state)
    return grad / current_state

def repeat_p(s1,s2):
    """Return the ROUGE-2 F-measure of ``s1`` and ``s2``, scored with stemming enabled."""
    rouge2 = rouge_scorer.RougeScorer(['rouge2'], use_stemmer=True)
    result = rouge2.score(s1, s2)
    return result['rouge2'].fmeasure

class SotopiaRewardWorker(SotopiaFollowRewardWorker):
    """
    Reward Worker that uses LLM-as-judge to compute rewards.

    For every sample the judge model scores how far each agent has progressed
    towards its goal after the new response. The actor's score change relative
    to the previous ``state`` is scaled, mapped to [0, 1] and multiplied by an
    answer-length score to form the response-level reward.

    ``judge_model_type``, ``_call_api_model``, ``_run_local_inference``,
    ``tokenizer``, ``strategy``, ``logger``, ``cluster_name`` and
    ``_compute_follow_rewards_impl`` are used here but not defined in this
    class; they are expected to come from ``SotopiaFollowRewardWorker``.
    """

    # Batch size at which the follow-reward computation is additionally run.
    _FOLLOW_GROUP_SIZE = 8
    # The follow-reward computation only runs when the variance of the judged
    # actor scores in the batch is below this value.
    _FOLLOW_VARIANCE_THRESHOLD = 0.3
    # Token budget handed to ``get_answer_length_score``.
    _TARGET_ANSWER_TOKENS = 150
    # Score used when no usable judge verdict is available.
    _NEUTRAL_SCORE = 0.5

    def __init__(self, worker_config: WorkerConfig):
        super().__init__(worker_config=worker_config)
        # Not read anywhere in this class; kept in case external code relies on it.
        self.vide = 0

    def _format_judge_prompt(self, prompt, response, dialog, agent1_name, agent1_goal, agent2_name, agent2_goal) -> List[Dict[str, str]]:
        """Build the single-turn chat message list sent to the judge model.

        ``prompt`` is accepted for interface compatibility only: the judge
        prompt is built from the dialogue history (with ``response`` appended)
        and the two agents' names and goals.

        Returns:
            A one-element list holding a ``{"role": "user", "content": ...}`` message.
        """
        history = history_format(dialog, response)
        formatted_prompt = reward_prompt.format(
            history=history,
            agent1_name=agent1_name,
            agent1_goal=agent1_goal,
            agent2_name=agent2_name,
            agent2_goal=agent2_goal,
        )
        return [{"role": "user", "content": formatted_prompt}]

    def _get_llm_judgment(self, prompt: str, response: str, dialog, agent1_name, agent1_goal, agent2_name, agent2_goal,turn) -> Optional[Dict[str, float]]:
        """Query the judge model and turn its verdict into actor/opponent scores.

        Returns:
            ``{"actor": float, "oppo": float}`` on success, or ``None`` when
            the judge output is missing, cannot be parsed, or fails validation.

        Raises:
            ValueError: If ``self.judge_model_type`` is neither ``"api"`` nor
                ``"inference"``.
        """
        messages = self._format_judge_prompt(prompt, response, dialog, agent1_name, agent1_goal, agent2_name, agent2_goal)

        if self.judge_model_type == "api":
            llm_response = self._call_api_model(messages)
        elif self.judge_model_type == "inference":
            llm_response = self._run_local_inference(messages)
        else:
            raise ValueError(f"Unsupported model type: {self.judge_model_type}")

        # Only hand real text to the parser; a missing judge output counts as
        # "no verdict" rather than an error.
        formatted_result = reward_format(llm_response) if isinstance(llm_response, str) else []
        if not formatted_result:
            return None

        try:
            return extract_rewards(formatted_result, turn)
        except (KeyError, TypeError, ValueError) as e:
            # Previously this path referenced an undefined name and raised
            # NameError; a failed extraction is now reported as "no verdict".
            self.logger.warning(f"Failed to extract rewards from judge output: {e}")
            return None

    @register(dispatch_mode=Dispatch.DP_MP_COMPUTE, clear_cache=False)
    def compute_rewards(self, data: DataProto):
        """Compute rewards for a batch, wrapping local-judge inference in state offloading.

        When the judge runs as a local inference model and a strategy is set,
        the computation runs inside ``state_offload_manger``; otherwise it
        runs directly.
        """
        is_offload_states = data.meta_info.get("is_offload_states", True)
        metrics = {}

        if self.judge_model_type == "inference" and self.strategy:
            with state_offload_manger(
                strategy=self.strategy,
                metrics=metrics,
                metric_infix=f"{self.cluster_name}/compute_rewards",
                is_offload_states=is_offload_states,
            ):
                return self._compute_rewards_impl(data, metrics)
        return self._compute_rewards_impl(data, metrics)

    def _compute_rewards_impl(self, data: DataProto, metrics: Dict):
        """Score every sample in ``data`` and package the rewards as a ``DataProto``.

        Reads ``data.batch["prompts"]`` / ``data.batch["responses"]`` and the
        non-tensor fields ``dialog``, ``agent1_name``, ``agent1_goal``,
        ``agent2_name``, ``agent2_goal``, ``turn`` and ``state``.

        Returns:
            A ``DataProto`` with ``token_level_rewards`` (all zeros),
            ``response_level_rewards`` and ``scores`` (both the per-sample
            scores), plus per-batch metric means under ``custom_metrics``.
        """
        prompts_text_list = self.tokenizer.batch_decode(data.batch["prompts"], skip_special_tokens=True)
        response_text_list = self.tokenizer.batch_decode(data.batch["responses"], skip_special_tokens=True)

        if not re.search(r'is: (\w+)-oriented \(Go', prompts_text_list[0]):
            # The first prompt lacks the expected marker: treat the whole batch
            # as malformed and hand back neutral scores.
            self.logger.info("输入混乱！！！！！！！！！")
            neutral = [self._NEUTRAL_SCORE] * len(prompts_text_list)
            return DataProto.from_dict(
                tensors={
                    "token_level_rewards": torch.zeros_like(data.batch["responses"], dtype=torch.float16),
                    "response_level_rewards": torch.tensor(neutral, dtype=torch.float16),
                    "scores": torch.tensor(neutral, dtype=torch.float16),
                },
                meta_info={"custom_metrics": metrics},
            )

        non_tensor = data.non_tensor_batch
        scores = []
        li_rewards = []
        for prompt_txt, response, dialog, agent1_name, agent1_goal, agent2_name, agent2_goal, turn, state in zip(
            prompts_text_list,
            response_text_list,
            non_tensor["dialog"],
            non_tensor['agent1_name'],
            non_tensor['agent1_goal'],
            non_tensor['agent2_name'],
            non_tensor['agent2_goal'],
            non_tensor['turn'],
            non_tensor["state"],
        ):
            reward = self._get_llm_judgment(prompt_txt, response, dialog, agent1_name, agent1_goal, agent2_name, agent2_goal, turn)
            if reward is None:
                # No usable verdict: neutral score, and no entry in the metrics.
                scores.append(self._NEUTRAL_SCORE)
                continue

            state_reward = extract_rewards(state, turn)
            reward_dict = {
                "actor_state": state_reward['actor'],
                "oppo_state": state_reward['oppo'],
                "actor_reward": reward['actor'],
                "oppo_reward": reward['oppo'],
                "actor_grad": reward['actor'] - state_reward['actor'],
                "oppo_grad": reward['oppo'] - state_reward['oppo'],
                "answer_token": len(self.tokenizer.tokenize(extract_first_to_first_brace(response))),
            }

            # Scale the actor's progress by the room it had, then map to [0, 1].
            actor_grad = scale_gradient(reward_dict['actor_grad'], reward_dict['actor_state'])
            reward_dict['actor_grad_scaled'] = (actor_grad + 1) / 2
            reward_dict['grad_score'] = reward_dict['actor_grad_scaled']

            reward_dict['answer_length_score'] = get_answer_length_score(
                self._TARGET_ANSWER_TOKENS, reward_dict['answer_token']
            )
            reward_dict['score'] = reward_dict['grad_score'] * reward_dict['answer_length_score']
            li_rewards.append(reward_dict)
            scores.append(reward_dict["score"])

        scores_tensor = torch.tensor(scores, dtype=torch.float16)
        token_level_rewards = torch.zeros_like(data.batch["responses"], dtype=torch.float16)
        mean_dict = pd.DataFrame(li_rewards).mean().to_dict()

        li_tmp = [entry["actor_reward"] for entry in li_rewards]
        self.logger.info(f"状态列表如下{li_tmp}目前长度如下{len(prompts_text_list)}")

        # With no successful judgments the variance is undefined; use NaN
        # directly instead of triggering numpy's empty-slice warning.
        var_state = np.var(li_tmp) if li_tmp else float("nan")
        self.logger.info(f"计算方差如下{var_state}")
        mean_dict["var"] = var_state

        follow_score = 0
        use_follow = (
            len(prompts_text_list) == self._FOLLOW_GROUP_SIZE
            and var_state < self._FOLLOW_VARIANCE_THRESHOLD
        )
        if use_follow:
            # NOTE(review): the original code multiplied the first return value
            # by the per-sample answer_length_score into a `score_tensor` that
            # was never used; the tensors returned below are the judge-based
            # scores in both branches. Only `follow_score` reaches the output
            # (as a metric). Confirm whether the follow-adjusted scores were
            # meant to replace `scores_tensor` here.
            _prescore_tensor, follow_score = self._compute_follow_rewards_impl(data, metrics)

        mean_dict["if_follow"] = 1 if use_follow else 0
        mean_dict["folow_score"] = follow_score

        return DataProto.from_dict(
            tensors={
                "token_level_rewards": token_level_rewards,
                "response_level_rewards": scores_tensor,
                "scores": scores_tensor,
            },
            meta_info={"custom_metrics": mean_dict},
        )
