# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib.util
import os
import sys
from abc import ABC, abstractmethod
from collections import defaultdict
from functools import partial
from typing import Callable, Dict, List, Optional, Tuple, TypedDict

import torch
from transformers import PreTrainedTokenizer

from ...protocol import DataProto
from .config import RewardConfig
from .length_reward_func import *

# Per-sample input handed to a reward function by the managers below:
#   response        -- response token ids decoded with the manager's tokenizer
#   response_length -- number of valid response tokens (sum of response_mask)
#   ground_truth    -- reference answer from data.non_tensor_batch["ground_truth"]
RewardInput = TypedDict(
    "RewardInput",
    {
        "response": str,
        "response_length": int,
        "ground_truth": str,
    },
)


# Score returned by a reward function for one sample. All three keys are
# required, but "format" and "accuracy" may hold None:
#   overall  -- scalar written into the reward tensor at the last valid token
#   format   -- optional sub-score, collected into the reward metrics
#   accuracy -- optional sub-score, collected into the reward metrics
RewardScore = TypedDict(
    "RewardScore",
    {
        "overall": float,
        "format": Optional[float],
        "accuracy": Optional[float],
    },
)


# Reward function that scores a single RewardInput per call
# (used by SequentialFunctionRewardManager).
SequentialRewardFunction = Callable[
    [RewardInput],
    RewardScore,
]

# Reward function that scores a whole list of RewardInputs in one call and
# returns one RewardScore per input; the batch-style managers index the
# result positionally, so it must be in input order.
BatchRewardFunction = Callable[
    [List[RewardInput]],
    List[RewardScore],
]


class FunctionRewardManager(ABC):
    """Reward manager for rule-based reward.

    Loads the Python file at ``config.reward_function`` as the module
    ``custom_reward_fn``, looks up the attribute named
    ``config.reward_function_name`` and stores it in ``self.reward_fn`` with
    ``config.reward_function_kwargs`` pre-bound via ``functools.partial``.
    Subclasses decide how ``self.reward_fn`` is invoked by implementing
    :meth:`compute_reward`.
    """

    def __init__(self, config: RewardConfig, tokenizer: PreTrainedTokenizer):
        """Load the configured reward function.

        Args:
            config: Reward configuration. ``reward_function`` (file path),
                ``reward_function_name`` and ``reward_function_kwargs`` are read
                here; subclasses read further fields from ``self.config``.
            tokenizer: Tokenizer stored for subclasses to decode response ids.

        Raises:
            ValueError: If ``config.reward_function`` is None.
            FileNotFoundError: If ``config.reward_function`` does not exist.
            RuntimeError: If the file cannot be imported as a Python module.
            AttributeError: If the module has no attribute
                ``config.reward_function_name``.
        """
        if config.reward_function is None:
            raise ValueError("Reward function is not provided.")

        if not os.path.exists(config.reward_function):
            raise FileNotFoundError(f"Reward function file {config.reward_function} not found.")

        spec = importlib.util.spec_from_file_location("custom_reward_fn", config.reward_function)
        if spec is None or spec.loader is None:
            # spec_from_file_location returns None when no loader handles the
            # path (e.g. an unrecognised suffix); report it as a load failure
            # instead of crashing inside module_from_spec.
            raise RuntimeError(
                f"Failed to load reward function: cannot import {config.reward_function} as a Python module"
            )

        module = importlib.util.module_from_spec(spec)
        # Register before executing, as in the importlib "import a source file
        # directly" recipe; undo the registration on failure so a
        # half-initialised module is not left in sys.modules.
        sys.modules["custom_reward_fn"] = module
        try:
            spec.loader.exec_module(module)
        except Exception as e:
            sys.modules.pop("custom_reward_fn", None)
            raise RuntimeError(f"Failed to load reward function: {e}") from e

        if not hasattr(module, config.reward_function_name):
            raise AttributeError(f"Module {module} does not have function {config.reward_function_name}.")

        reward_fn = getattr(module, config.reward_function_name)
        print(f"Using reward function `{config.reward_function_name}` from `{config.reward_function}`.")
        self.reward_fn = partial(reward_fn, **config.reward_function_kwargs)
        self.config = config
        self.tokenizer = tokenizer

    @abstractmethod
    def compute_reward(self, data: DataProto) -> Tuple[torch.Tensor, Dict[str, List[float]]]:
        """Compute reward for a batch of data.

        Returns:
            A ``(reward_tensor, reward_metrics)`` pair. The implementations in
            this module return a float32 tensor shaped like
            ``data.batch["responses"]`` and a mapping from score key to the
            per-sample values of that key.
        """
        ...


class SequentialFunctionRewardManager(FunctionRewardManager):
    """Reward manager that calls ``self.reward_fn`` once per sample."""

    reward_fn: SequentialRewardFunction

    def compute_reward(self, data: DataProto) -> Tuple[torch.Tensor, Dict[str, List[float]]]:
        """Score every response in ``data`` individually.

        Each response is cut to its valid length (sum of ``response_mask``),
        decoded, and passed to ``self.reward_fn`` together with its ground
        truth. The returned ``"overall"`` value is written at index
        ``length - 1`` of that response's row. For an empty response this is
        index -1, the last column. Every key of every returned score is
        appended to ``reward_metrics``.

        Returns:
            ``(reward_tensor, reward_metrics)``. ``reward_tensor`` is a
            float32 tensor shaped like ``data.batch["responses"]`` that is zero
            except at the positions described above.
        """
        response_ids = data.batch["responses"]
        response_length = data.batch["response_mask"].sum(dim=-1)
        reward_tensor = torch.zeros_like(response_ids, dtype=torch.float32)
        reward_metrics = defaultdict(list)
        for i in range(len(data)):
            # Use a plain Python int: RewardInput declares response_length as
            # int, and reward functions should not receive a 0-d tensor.
            cur_response_length = int(response_length[i])
            valid_response_ids = response_ids[i][:cur_response_length]
            response_str = self.tokenizer.decode(
                valid_response_ids, skip_special_tokens=self.config.skip_special_tokens
            )
            score = self.reward_fn(
                {
                    "response": response_str,
                    "response_length": cur_response_length,
                    "ground_truth": data.non_tensor_batch["ground_truth"][i],
                }
            )
            reward_tensor[i, cur_response_length - 1] = score["overall"]
            for key, value in score.items():
                reward_metrics[key].append(value)

        return reward_tensor, reward_metrics


class BatchFunctionRewardManager(FunctionRewardManager):
    """Reward manager that calls ``self.reward_fn`` once on the whole batch."""

    reward_fn: BatchRewardFunction

    def compute_reward(self, data: DataProto) -> Tuple[torch.Tensor, Dict[str, List[float]]]:
        """Score all responses in ``data`` with a single reward-function call.

        Each response is cut to its valid length (sum of ``response_mask``)
        and decoded. The list of inputs is passed to ``self.reward_fn``, whose
        i-th score is assigned to the i-th response. The ``"overall"`` value is
        written at index ``length - 1`` of that response's row; for an empty
        response this is index -1. Every key of every score is appended to
        ``reward_metrics``.

        Returns:
            ``(reward_tensor, reward_metrics)``. ``reward_tensor`` is a
            float32 tensor shaped like ``data.batch["responses"]``.

        Raises:
            ValueError: If the reward function returns a different number of
                scores than there are responses.
        """
        response_ids = data.batch["responses"]
        response_length = data.batch["response_mask"].sum(dim=-1)

        reward_inputs = []
        valid_lengths = []
        for i in range(len(data)):
            # Plain Python int, matching RewardInput.response_length: int.
            cur_response_length = int(response_length[i])
            valid_lengths.append(cur_response_length)
            valid_response_ids = response_ids[i][:cur_response_length]
            response_str = self.tokenizer.decode(
                valid_response_ids, skip_special_tokens=self.config.skip_special_tokens
            )
            reward_inputs.append(
                {
                    "response": response_str,
                    "response_length": cur_response_length,
                    "ground_truth": data.non_tensor_batch["ground_truth"][i],
                }
            )

        scores = list(self.reward_fn(reward_inputs))
        # Scores are matched to responses by position. A count mismatch would
        # otherwise silently leave zero rewards or fail with an IndexError.
        if len(scores) != len(reward_inputs):
            raise ValueError(
                f"Reward function returned {len(scores)} scores for {len(reward_inputs)} responses."
            )

        reward_tensor = torch.zeros_like(response_ids, dtype=torch.float32)
        reward_metrics = defaultdict(list)
        for i, (cur_response_length, score) in enumerate(zip(valid_lengths, scores)):
            reward_tensor[i, cur_response_length - 1] = score["overall"]
            for key, value in score.items():
                reward_metrics[key].append(value)

        return reward_tensor, reward_metrics


class AdaThinkFunctionRewardManager(FunctionRewardManager):
    """Batch reward manager whose final rewards come from
    ``entropy_length_reward_question_level``.

    ``self.reward_fn`` is called once on the whole batch. The resulting scores
    are then handed to ``entropy_length_reward_question_level``, which fills
    the reward tensor and metrics and returns updated ``self.meta`` state.
    Before that hand-off, ``accuracy`` is optionally overwritten with
    ``rm_scores``.
    """

    reward_fn: BatchRewardFunction

    def __init__(self, *args, **kwargs):
        """Initialise the base manager and the state threaded through calls."""
        super().__init__(*args, **kwargs)
        # Passed to entropy_length_reward_question_level on every call and
        # replaced by the value it returns. The meaning of these initial values
        # is defined by that function; NOTE(review): confirm there.
        self.meta = {
            'entropy_mu': 2.5,
            'entropy_var': 1.0,
            'expect_len': {}
        }

    def compute_reward(self, data: DataProto) -> Tuple[torch.Tensor, Dict[str, List[float]]]:
        """Compute rewards for a batch of grouped rollouts.

        Responses are cut to their valid length (sum of ``response_mask``),
        decoded and scored in one ``self.reward_fn`` call. When
        ``config.model.model_path`` is set, each score's ``"accuracy"`` is
        overwritten with the matching value from ``data.batch["rm_scores"]``.
        Scores are grouped by ``data.non_tensor_batch["uid"]``, and every uid
        must occur more than once. The final tensor and metrics come from
        ``entropy_length_reward_question_level``.

        Returns:
            ``(reward_tensor, reward_metrics)`` as returned by
            ``entropy_length_reward_question_level``.

        Raises:
            ValueError: If the reward function returns a different number of
                scores than there are responses, or if any uid appears only
                once in the batch.
        """
        response_ids = data.batch["responses"]
        index = data.non_tensor_batch['uid']
        response_length = data.batch["response_mask"].sum(dim=-1)

        reward_inputs = []
        num_printed = 0
        for i in range(len(data)):
            # Plain Python int, matching RewardInput.response_length: int.
            cur_response_length = int(response_length[i])
            valid_response_ids = response_ids[i][:cur_response_length]
            response_str = self.tokenizer.decode(
                valid_response_ids, skip_special_tokens=self.config.skip_special_tokens
            )
            reward_inputs.append(
                {
                    "response": response_str,
                    "response_length": cur_response_length,
                    "ground_truth": data.non_tensor_batch["ground_truth"][i],
                }
            )
            # Debug output: print at most two samples (rows 0 and 8 when present).
            if i % 8 == 0 and num_printed < 2:
                num_printed += 1
                print(
                    "\n[Prompt]\n",
                    data.non_tensor_batch["question"][i],
                    "\n[output]\n",
                    response_str,
                    "\n",
                    "[Ground Truth]\n",
                    data.non_tensor_batch["ground_truth"][i],
                )

        scores = list(self.reward_fn(reward_inputs))
        if len(scores) != len(reward_inputs):
            raise ValueError(
                f"Reward function returned {len(scores)} scores for {len(reward_inputs)} responses."
            )

        if self.config.model.model_path:
            # Per-sample values from rm_scores replace the rule-based accuracy
            # (the score dicts are updated in place).
            rm_scores = data.batch["rm_scores"]
            for idx, score in enumerate(scores):
                score['accuracy'] = rm_scores[idx].item()

        # Group accuracies by uid. Each uid must have several rollouts.
        id2score = defaultdict(list)
        for uid, score in zip(index, scores):
            id2score[uid].append(score['accuracy'])
        id2complexity = {}
        for uid, accuracies in id2score.items():
            # Explicit check rather than `assert`, which is stripped under -O.
            if len(accuracies) <= 1:
                raise ValueError("GRPO needs rollout.n > 1.")
            id2complexity[uid] = 1 - sum(accuracies) / len(accuracies)

        # config.length_reward_method / config.method_version are not
        # consulted: this manager always uses entropy_length_reward_question_level.
        # NOTE(review): id2complexity is computed above, but an empty dict is
        # passed below. Confirm whether the computed mapping should be used.
        reward_tensor = torch.zeros_like(response_ids, dtype=torch.float32)
        reward_metrics = defaultdict(list)
        reward_tensor, reward_metrics, self.meta = entropy_length_reward_question_level(
            data=data,
            config=self.config,
            index=index,
            scores=scores,
            id2complexity={},
            response_length=response_length,
            reward_tensor=reward_tensor,
            reward_metrics=reward_metrics,
            meta=self.meta,
        )

        return reward_tensor, reward_metrics
    
    