import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import (
    AutoTokenizer, 
    AutoModelForCausalLM, 
    AutoModelForSequenceClassification,
    AutoConfig, 
    BitsAndBytesConfig
)
import math
from openai import OpenAI
import asyncio
from typing import List, Dict, Tuple, Optional, Union
from abc import ABC, abstractmethod
import json
import os
import re

try:
    # from .constants import get_scoring_prompts, OPENAI_API_KEY, SCORING_SYSTEM_PROMPT, OBJECTIVE_SCORING_PROMPT, SCORING_PROMPT_BASE_TEMPLATE, SCORING_OBJECTIVE_DESCRIPTIONS, SCORING_RUBRICS_HH, SCORING_RUBRICS_TLDR
    from .reward_combiner import RewardCombiner, create_reward_combiner
    from .objective_scorer import ObjectiveScorer
except ImportError:
    # from constants import get_scoring_prompts, OPENAI_API_KEY, SCORING_SYSTEM_PROMPT, OBJECTIVE_SCORING_PROMPT, SCORING_PROMPT_BASE_TEMPLATE, SCORING_OBJECTIVE_DESCRIPTIONS, SCORING_RUBRICS_HH, SCORING_RUBRICS_TLDR
    from reward_combiner import RewardCombiner, create_reward_combiner
    from objective_scorer import ObjectiveScorer


class RewardFunction(ABC):
    """
    Abstract base class for all reward functions.
    Defines the interface that all reward functions must implement.
    """

    def __init__(
        self,
        device: str = "auto",
        max_length: int = 4096,
        normalize_scores: bool = False,
        cache_dir: str = "../"
    ):
        """
        Initialize base reward function.

        Args:
            device: Device to run models on
            max_length: Maximum sequence length for tokenization
            normalize_scores: Whether to normalize scores to [0, 1] range
            cache_dir: Cache directory, stored as ``self.cache_dir`` for use by subclasses
        """
        self.device = device
        self.max_length = max_length
        self.normalize_scores = normalize_scores
        self.cache_dir = cache_dir

    @abstractmethod
    def compute_reward(self, queries: List[str], responses: List[str], denormalize_scores: bool = False) -> torch.Tensor:
        """
        Compute rewards for a batch of query-response pairs.

        Args:
            queries: List of original queries/prompts
            responses: List of responses to evaluate (one per query)
            denormalize_scores: Whether to return denormalized scores (implementations may ignore it)

        Returns:
            Tensor of rewards with shape (batch_size,)
        """
        pass

    @abstractmethod
    def group_compute_reward(self, queries: List[str], responses_list: List[List[str]], denormalize_scores: bool = False) -> torch.Tensor:
        """
        Compute rewards for groups of responses (trajectories) for each query.

        Args:
            queries: List of original queries/prompts
            responses_list: List of lists, where each inner list contains multiple responses for a query
            denormalize_scores: Whether to return denormalized scores

        Returns:
            Tensor of rewards with shape (total_responses,) flattened across all queries
        """
        pass

    def _normalize_score(self, score: float, min_val: float = 1.0, max_val: float = 10.0) -> float:
        """
        Linearly map a score from [min_val, max_val] to [0, 1], clamping out-of-range values.

        Args:
            score: Raw score to normalize
            min_val: Score that maps to 0.0
            max_val: Score that maps to 1.0

        Returns:
            float: Normalized score between 0 and 1

        Raises:
            ValueError: If min_val == max_val (the range is degenerate).
        """
        if max_val == min_val:
            # Guard against a ZeroDivisionError from a degenerate range.
            raise ValueError(f"min_val and max_val must differ (both are {min_val})")
        normalized = (score - min_val) / (max_val - min_val)
        return max(0.0, min(1.0, normalized))  # Clamp to [0, 1]

    def save_config(self, config_path: str):
        """
        Save configuration to a JSON file.

        Default implementation saves ``device``, ``max_length`` and the concrete
        class name (the latter for bookkeeping; ``from_config`` ignores it).
        """
        config = {
            "device": self.device,
            "max_length": self.max_length,
            "class_name": self.__class__.__name__
        }
        with open(config_path, 'w', encoding='utf-8') as f:
            json.dump(config, f, indent=2)

    @classmethod
    def from_config(cls, config_path: str, **kwargs):
        """
        Construct an instance from a JSON file written by ``save_config``.

        Keyword arguments override values read from the file. Should be
        overridden by subclasses that need custom loading.
        """
        with open(config_path, 'r', encoding='utf-8') as f:
            config = json.load(f)
        config.update(kwargs)  # Allow override of config values
        # "class_name" is metadata written by save_config, not a constructor
        # argument; passing it through would raise TypeError.
        config.pop("class_name", None)
        return cls(**config)


class LLMRewardFunction(RewardFunction):
    """
    Reward function that uses an LLM to score responses based on multiple objectives.
    Uses ObjectiveScorer for individual objective scoring and a RewardCombiner to
    merge the per-objective scores into one scalar reward per response.
    """

    def __init__(
        self,
        model_name: str = "meta-llama/Llama-3.1-8B-Instruct",
        use_api: bool = False,
        combiner_type: str = 'linear',
        manual_weights: Optional[Dict[str, float]] = None,
        manual_bias: float = 0.0,
        device: str = "auto",
        max_length: int = 4096,
        reward_combiner: RewardCombiner = None,
        objective_names: Optional[List[str]] = None,
        init_args=None,
        dataset_type: str = "hh",  # "hh" for HH-RLHF, "tldr" for Reddit TLDR
        use_detailed_rubric: bool = False,  # Whether to use detailed scoring rubrics
        normalize_scores: bool = False,  # Whether to normalize scores to [0, 1]
        cache_dir: Optional[str] = None,
        save_dir: Optional[str] = None,
        max_concurrent: int = 50,  # Maximum concurrent API calls for async scoring
        **combiner_kwargs
    ):
        """
        Initialize the LLM-based reward function.

        Args:
            model_name: Name or path to the LLM model (or OpenAI model if use_api=True)
            use_api: Whether to use OpenAI API instead of local model
            combiner_type: Type of combiner ('linear', 'linear_regression', etc.)
            manual_weights: For 'linear' type, dictionary of manual weights.
                Defaults to equal weights over ``objective_names``.
            manual_bias: For 'linear' type, bias term
            device: Device to run the model on
            max_length: Maximum length for tokenized input
            reward_combiner: Pre-configured RewardCombiner instance; when given,
                combiner_type / manual_weights / manual_bias / combiner_kwargs are ignored
            objective_names: List of objectives to evaluate
            init_args: Currently unused; accepted for backward compatibility
            dataset_type: Type of dataset ('hh' for HH-RLHF, 'tldr' for Reddit TLDR)
            use_detailed_rubric: Whether to use detailed scoring rubrics for objectives
            normalize_scores: Whether to map raw 1-10 scores to [0, 1]
            cache_dir: Passed to the base class and to ObjectiveScorer
            save_dir: Passed to ObjectiveScorer
            max_concurrent: Maximum concurrent API calls for async scoring (default: 50)
            **combiner_kwargs: Additional arguments for combiner

        Raises:
            ValueError: If equal default linear weights are needed but
                ``objective_names`` is None or empty.
        """
        super().__init__(device=device, max_length=max_length, normalize_scores=normalize_scores, cache_dir=cache_dir)

        # Upper bound on in-flight scoring coroutines in the async code paths.
        self.max_concurrent = max_concurrent

        # Set objectives
        self.objective_names = objective_names

        needs_default_weights = (
            reward_combiner is None and combiner_type == 'linear' and manual_weights is None
        )
        if needs_default_weights and not objective_names:
            # Fail fast, before ObjectiveScorer (which may load a model) is built;
            # equal weights cannot be derived from zero objectives.
            raise ValueError(
                "objective_names must be a non-empty list when no reward_combiner "
                "or manual_weights are given for the 'linear' combiner"
            )

        # Initialize ObjectiveScorer to handle scoring
        self.scorer = ObjectiveScorer(
            use_detailed_rubric=use_detailed_rubric,
            dataset_type=dataset_type,
            use_api=use_api,
            model_name=model_name,
            device=device,
            max_length=max_length,
            load_quantized=True,  # Always requested; init_args is not consulted
            cache_dir=cache_dir,
            save_dir=save_dir  # Custom rewards doesn't have a save_dir by default
        )

        # Initialize reward combiner
        if reward_combiner is not None:
            self.reward_combiner = reward_combiner
        else:
            if needs_default_weights:
                share = 1.0 / len(self.objective_names)
                manual_weights = {obj: share for obj in self.objective_names}

            self.reward_combiner = create_reward_combiner(
                combiner_type=combiner_type,
                objective_names=self.objective_names,
                manual_weights=manual_weights,
                manual_bias=manual_bias,
                **combiner_kwargs
            )

    # ------------------------------------------------------------------
    # Small shared helpers
    # ------------------------------------------------------------------

    def _maybe_normalize(self, score: float) -> float:
        """Map a raw 1-10 score to [0, 1] if normalize_scores is enabled; else return it unchanged."""
        if self.normalize_scores:
            return self._normalize_score(score, min_val=1.0, max_val=10.0)
        return score

    @staticmethod
    def _check_batch_lengths(queries: List[str], paired: list, paired_name: str) -> None:
        """Raise ValueError if queries and their paired responses differ in length."""
        if len(queries) != len(paired):
            raise ValueError(
                f"queries and {paired_name} must have the same length "
                f"(got {len(queries)} and {len(paired)})"
            )

    @staticmethod
    def _run_coroutine(coro_factory):
        """
        Run the coroutine returned by ``coro_factory()`` to completion from synchronous code.

        If an event loop is already running in this thread (e.g. a Jupyter
        notebook), nest_asyncio is applied so that loop can be re-entered;
        otherwise a fresh loop is created with ``asyncio.run``.

        NOTE(review): asyncio.run creates a new loop on every call; confirm the
        scorer's async client tolerates being used from successive loops.
        """
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            loop = None

        if loop is not None:
            # Already in an async context: patch the running loop to allow re-entry.
            import nest_asyncio
            nest_asyncio.apply()
            return loop.run_until_complete(coro_factory())
        # Normal Python - just run
        return asyncio.run(coro_factory())

    def _combine_pointwise(
        self,
        batch_size: int,
        active_objectives: set,
        scores_by_idx: List[Dict[str, Tuple[float, float]]],
    ) -> Tuple[torch.Tensor, torch.Tensor, List[Dict]]:
        """
        Combine per-example objective scores into reward tensors.

        ``scores_by_idx[i][objective]`` must hold ``(raw_score, normalized_score)``
        for every active objective; inactive objectives are filled with 0.0.
        """
        # NOTE: rewards are stored as bfloat16, which keeps only ~3 significant digits.
        normalized_rewards = torch.zeros(batch_size, dtype=torch.bfloat16)
        denormalized_rewards = torch.zeros(batch_size, dtype=torch.bfloat16)
        objective_scores_list = []

        for idx in range(batch_size):
            norm_objective_scores = {}
            denorm_objective_scores = {}
            for objective in self.objective_names:
                if objective in active_objectives:
                    score, normalized_score = scores_by_idx[idx][objective]
                else:
                    # Objectives with zero coefficients were not scored
                    score, normalized_score = 0.0, 0.0
                denorm_objective_scores[objective] = score
                norm_objective_scores[objective] = normalized_score

            normalized_rewards[idx] = self.reward_combiner.combine_rewards(norm_objective_scores)
            denormalized_rewards[idx] = self.reward_combiner.combine_rewards(denorm_objective_scores)
            objective_scores_list.append(norm_objective_scores)

        return normalized_rewards, denormalized_rewards, objective_scores_list

    def _combine_group(
        self,
        responses_list: List[List[str]],
        scores_by_query: List[Dict[str, Tuple[list, list]]],
    ) -> Tuple[torch.Tensor, torch.Tensor, List[Dict]]:
        """
        Combine group scores into flattened reward tensors.

        ``scores_by_query[q][objective]`` must hold ``(scores, normalized_scores)``,
        each indexable in step with ``responses_list[q]``. Returns empty
        tensors (instead of letting ``torch.cat`` fail) when there are no queries.
        """
        all_normalized_rewards = []
        all_denormalized_rewards = []
        all_objective_scores = []

        for query_idx, responses in enumerate(responses_list):
            num_responses = len(responses)
            normalized_rewards_for_query = torch.zeros(num_responses, dtype=torch.bfloat16)
            denormalized_rewards_for_query = torch.zeros(num_responses, dtype=torch.bfloat16)

            norm_per_response = [{} for _ in range(num_responses)]
            denorm_per_response = [{} for _ in range(num_responses)]
            for objective in self.objective_names:
                scores, normalized_scores = scores_by_query[query_idx][objective]
                for i in range(num_responses):
                    denorm_per_response[i][objective] = scores[i]
                    norm_per_response[i][objective] = normalized_scores[i]

            for i in range(num_responses):
                normalized_rewards_for_query[i] = self.reward_combiner.combine_rewards(norm_per_response[i])
                denormalized_rewards_for_query[i] = self.reward_combiner.combine_rewards(denorm_per_response[i])
                all_objective_scores.append(norm_per_response[i])

            all_normalized_rewards.append(normalized_rewards_for_query)
            all_denormalized_rewards.append(denormalized_rewards_for_query)

        if not all_normalized_rewards:
            # torch.cat raises on an empty list; an empty batch yields empty rewards.
            return (
                torch.zeros(0, dtype=torch.bfloat16),
                torch.zeros(0, dtype=torch.bfloat16),
                all_objective_scores,
            )

        combined_normalized = torch.cat(all_normalized_rewards)
        combined_denormalized = torch.cat(all_denormalized_rewards)
        return combined_normalized, combined_denormalized, all_objective_scores

    # ------------------------------------------------------------------
    # Per-objective scoring (delegates to ObjectiveScorer)
    # ------------------------------------------------------------------

    def _score_single_objective(self, query: str, response: str, objective: str) -> Tuple[float, float]:
        """
        Score a single response for one objective using ObjectiveScorer.

        Args:
            query: The input query/prompt
            response: The response to evaluate
            objective: The objective to score against

        Returns:
            Tuple of (score, normalized_score); normalized_score equals score
            when normalization is disabled.
        """
        score = self.scorer.score_single_objective(query, response, objective)
        return score, self._maybe_normalize(score)

    def _group_score_single_objective(self, query: str, responses: List[str], objective: str) -> Tuple[List[float], List[float]]:
        """
        Score a group of responses for one objective using ObjectiveScorer's group scoring.

        Args:
            query: The input query/prompt
            responses: List of responses to evaluate
            objective: The objective to score against

        Returns:
            Tuple of (scores, normalized_scores) where:
            - scores: Original scores from group_score_single_objective
            - normalized_scores: Normalized scores if self.normalize_scores else same as scores
        """
        scores = self.scorer.group_score_single_objective(query, responses, objective)
        return scores, [self._maybe_normalize(score) for score in scores]

    async def _async_group_score_single_objective(self, query: str, responses: List[str], objective: str) -> Tuple[List[float], List[float]]:
        """
        Async version of ``_group_score_single_objective``.

        Returns:
            Tuple of (scores, normalized_scores)
        """
        scores = await self.scorer.async_group_score_single_objective(query, responses, objective)
        return scores, [self._maybe_normalize(score) for score in scores]

    async def _async_score_single_objective(self, query: str, response: str, objective: str) -> Tuple[float, float]:
        """
        Async version of ``_score_single_objective``.

        Returns:
            Tuple of (score, normalized_score)
        """
        score = await self.scorer.async_score_single_objective(query, response, objective)
        return score, self._maybe_normalize(score)

    # ------------------------------------------------------------------
    # Pointwise rewards
    # ------------------------------------------------------------------

    def compute_reward(self, queries: List[str], responses: List[str], denormalize_scores: bool = False) -> Tuple[torch.Tensor, torch.Tensor, List[Dict]]:
        """
        Compute rewards for a batch of query-response pairs.

        Uses async parallel scoring when the scorer runs against the API and
        exposes an ``async_client``; otherwise scores sequentially. Objectives
        the combiner reports as inactive are not scored and count as 0.0.

        Args:
            queries: List of queries/prompts
            responses: List of responses to evaluate (same length as queries)
            denormalize_scores: Unused, kept for backward compatibility

        Returns:
            Tuple of (normalized_rewards, denormalized_rewards, objective_scores_list):
            two bfloat16 tensors of shape (batch_size,) combining the normalized
            and raw objective scores respectively, and one dict of normalized
            per-objective scores per example.

        Raises:
            ValueError: If queries and responses differ in length.
        """
        print('Using compute_reward')
        self._check_batch_lengths(queries, responses, "responses")

        # Check if async is available - if so, use the parallel version
        use_async = hasattr(self.scorer, 'async_client') and self.scorer.use_api
        if use_async:
            return self._compute_reward_async_impl(queries, responses, denormalize_scores)

        # Get objectives with non-zero weights (optimization for linear regression)
        active_objectives = set(self.reward_combiner.get_active_objectives())

        scores_by_idx = [
            {
                objective: self._score_single_objective(query, response, objective)
                for objective in self.objective_names
                if objective in active_objectives
            }
            for query, response in zip(queries, responses)
        ]
        return self._combine_pointwise(len(queries), active_objectives, scores_by_idx)

    def _compute_reward_async_impl(
        self,
        queries: List[str],
        responses: List[str],
        denormalize_scores: bool = False
    ) -> Tuple[torch.Tensor, torch.Tensor, List[Dict]]:
        """
        Async implementation of ``compute_reward``.

        Schedules every (example, active objective) pair concurrently, with at
        most ``self.max_concurrent`` requests in flight.

        Args:
            queries: List of queries/prompts
            responses: List of responses to evaluate (same length as queries)
            denormalize_scores: Unused, kept for backward compatibility

        Returns:
            Tuple of (normalized_rewards, denormalized_rewards, objective_scores_list)
        """
        self._check_batch_lengths(queries, responses, "responses")
        batch_size = len(queries)
        # Get objectives with non-zero weights (optimization for linear regression)
        active_objectives = set(self.reward_combiner.get_active_objectives())
        print('Active objectives for scoring:', active_objectives)

        async def run_all_scoring_parallel():
            """Run all (example, objective) scoring tasks concurrently."""
            # Created inside the coroutine so it binds to the loop that runs it.
            semaphore = asyncio.Semaphore(self.max_concurrent)

            async def score_with_semaphore(idx: int, query: str, response: str, objective: str):
                async with semaphore:
                    score, normalized_score = await self._async_score_single_objective(query, response, objective)
                    return idx, objective, score, normalized_score

            tasks = [
                score_with_semaphore(idx, query, response, objective)
                for idx, (query, response) in enumerate(zip(queries, responses))
                for objective in self.objective_names
                if objective in active_objectives
            ]
            return await asyncio.gather(*tasks)

        results = self._run_coroutine(run_all_scoring_parallel)

        # scores_by_idx[idx][objective] = (score, normalized_score)
        scores_by_idx = [{} for _ in range(batch_size)]
        for idx, objective, score, normalized_score in results:
            scores_by_idx[idx][objective] = (score, normalized_score)

        return self._combine_pointwise(batch_size, active_objectives, scores_by_idx)

    # ------------------------------------------------------------------
    # Group rewards
    # ------------------------------------------------------------------

    def group_compute_reward(self, queries: List[str], responses_list: List[List[str]], denormalize_scores: bool = False) -> Tuple[torch.Tensor, torch.Tensor, List[Dict]]:
        """
        Compute rewards for groups of responses using group scoring.

        Uses async parallel scoring when the scorer runs against the API and
        exposes an ``async_client``. Unlike ``compute_reward``, every objective
        is scored regardless of combiner weight.

        Args:
            queries: List of queries/prompts
            responses_list: List of lists, where each inner list contains responses for a query
            denormalize_scores: Unused, kept for backward compatibility

        Returns:
            Tuple of (normalized_rewards, denormalized_rewards, objective_scores_list),
            flattened across all queries in order; empty tensors for an empty batch.

        Raises:
            ValueError: If queries and responses_list differ in length.
        """
        # Check if async is available - if so, use the parallel version
        print('Using group_compute_reward')
        self._check_batch_lengths(queries, responses_list, "responses_list")
        use_async = hasattr(self.scorer, 'async_client') and self.scorer.use_api
        if use_async:
            return self._group_compute_reward_async_impl(queries, responses_list, denormalize_scores)

        scores_by_query = [
            {
                objective: self._group_score_single_objective(query, responses, objective)
                for objective in self.objective_names
            }
            for query, responses in zip(queries, responses_list)
        ]
        return self._combine_group(responses_list, scores_by_query)

    def _group_compute_reward_async_impl(
        self,
        queries: List[str],
        responses_list: List[List[str]],
        denormalize_scores: bool = False
    ) -> Tuple[torch.Tensor, torch.Tensor, List[Dict]]:
        """
        Async implementation of ``group_compute_reward``.

        Schedules every (query, objective) group-scoring call concurrently, with
        at most ``self.max_concurrent`` in flight.

        Args:
            queries: List of queries/prompts
            responses_list: List of lists, where each inner list contains responses for a query
            denormalize_scores: Unused, kept for backward compatibility

        Returns:
            Tuple of (normalized_rewards, denormalized_rewards, objective_scores_list)
        """
        self._check_batch_lengths(queries, responses_list, "responses_list")
        num_queries = len(queries)

        async def run_all_scoring_parallel():
            """Run all (query, objective) group-scoring tasks concurrently."""
            semaphore = asyncio.Semaphore(self.max_concurrent)

            async def score_with_semaphore(query_idx: int, query: str, responses: List[str], objective: str):
                async with semaphore:
                    scores, normalized_scores = await self._async_group_score_single_objective(query, responses, objective)
                    return query_idx, objective, scores, normalized_scores

            tasks = [
                score_with_semaphore(query_idx, query, responses, objective)
                for query_idx, (query, responses) in enumerate(zip(queries, responses_list))
                for objective in self.objective_names
            ]
            return await asyncio.gather(*tasks)

        results = self._run_coroutine(run_all_scoring_parallel)

        # scores_by_query[query_idx][objective] = (scores, normalized_scores)
        scores_by_query = [{} for _ in range(num_queries)]
        for query_idx, objective, scores, normalized_scores in results:
            scores_by_query[query_idx][objective] = (scores, normalized_scores)

        return self._combine_group(responses_list, scores_by_query)

    # ------------------------------------------------------------------
    # Persistence
    # ------------------------------------------------------------------

    def save_config(self, config_path: str):
        """
        Save configuration including scorer settings.

        The reward combiner is saved separately next to the config: for
        ``cfg.json`` it goes to ``cfg_combiner``; for a path without a
        ``.json`` suffix, ``_combiner`` is appended. The path therefore never
        collides with ``config_path``.
        """
        config = {
            "model_name": self.scorer.model_name,
            "use_api": self.scorer.use_api,
            "dataset_type": self.scorer.dataset_type,
            "use_detailed_rubric": self.scorer.use_detailed_rubric,
            "objective_names": self.objective_names,
            "device": self.scorer.device,
            "max_length": self.scorer.max_length,
            "normalize_scores": self.normalize_scores,
            "max_concurrent": self.max_concurrent,
            "class_name": self.__class__.__name__
        }

        # Only strip a trailing ".json"; str.replace would also rewrite
        # directory names and leave suffix-less paths unchanged (so the config
        # would overwrite the combiner).
        if config_path.endswith('.json'):
            combiner_path = config_path[:-len('.json')] + '_combiner'
        else:
            combiner_path = config_path + '_combiner'
        self.reward_combiner.save(combiner_path)

        with open(config_path, 'w', encoding='utf-8') as f:
            json.dump(config, f, indent=2)


class RewardModelFunction(RewardFunction):
    """
    Reward function using a trained reward model from HuggingFace.

    These models typically output a scalar reward value directly. Inputs are
    formatted according to the model family inferred from ``model_name``
    (Skywork, DeBERTa, Ray2333 GPT-2 helpful RM, or a generic chat template)
    and scored in batches.
    """

    # "User:" / "Assistant:" role tags that some prompts carry; stripped before
    # formatting because the target format adds its own role markers.
    _LEADING_ROLE_TAG = re.compile(r'^(User|Assistant):\s*')
    _INNER_ROLE_TAG = re.compile(r'\n\n(User|Assistant):\s*')

    # Model-name substrings whose model config gets an explicit pad_token_id
    # before batched (padded) scoring.
    _NEEDS_MODEL_PAD_ID = ('allenai/Llama-3.1-8B-Instruct-RM-RB2', 'gpt2-large-helpful', 'Ray2333')

    def __init__(
        self,
        model_name: str = "OpenAssistant/reward-model-deberta-v3-large-v2",
        device: str = "auto",
        max_length: int = 512,
        use_quantization: bool = False,
        init_args=None,
        normalize_scores: bool = False,
        cache_dir: str = None
    ):
        """
        Initialize reward model function.

        Loads the tokenizer, then tries ``AutoModelForSequenceClassification``
        (optionally 4-bit quantized). If that raises, it falls back to
        ``AutoModelForCausalLM`` without quantization.

        Args:
            model_name: HuggingFace model name or local path
            device: Device map passed to ``from_pretrained``
            max_length: Maximum sequence length used for truncation
            use_quantization: Whether to load the classifier with 4-bit NF4 quantization
            init_args: Currently unused; accepted for interface compatibility
            normalize_scores: Whether to also map scores to [0, 1] with a sigmoid
            cache_dir: Stored by the base class; not currently forwarded to
                ``from_pretrained``
        """
        super().__init__(device=device, max_length=max_length, normalize_scores=normalize_scores, cache_dir=cache_dir)

        self.model_name = model_name
        self.use_quantization = use_quantization

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

        # Load model - try sequence classification first, fall back to causal LM.
        try:
            load_kwargs = {
                "device_map": device,
                "torch_dtype": torch.bfloat16,
                "trust_remote_code": True,
                "use_safetensors": True,
                "num_labels": 1,
            }
            if use_quantization:
                load_kwargs["quantization_config"] = BitsAndBytesConfig(
                    load_in_4bit=True,
                    bnb_4bit_use_double_quant=True,
                    bnb_4bit_quant_type='nf4',
                    bnb_4bit_compute_dtype=torch.bfloat16,
                )
            self.model = AutoModelForSequenceClassification.from_pretrained(model_name, **load_kwargs)
            self.model_type = "classification"
        except Exception as e:
            print(e)
            print("Falling back to causal LM with reward head...")
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name,
                device_map=device,
                torch_dtype=torch.bfloat16,
                trust_remote_code=True,
            )
            self.model_type = "causal"

        self.model.eval()

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _sigmoid(x: float) -> float:
        """Numerically stable logistic sigmoid (no OverflowError for very negative x)."""
        if x >= 0:
            return 1.0 / (1.0 + math.exp(-x))
        z = math.exp(x)
        return z / (1.0 + z)

    @staticmethod
    def _last_token_index(attention_mask: torch.Tensor) -> torch.Tensor:
        """Return, per row, the index of the last position whose mask is non-zero.

        Unlike ``mask.sum(1) - 1`` this is correct for left padding and for
        masks with gaps as well as for right padding.
        """
        mask = attention_mask.to(torch.long)
        # argmax returns the first maximal index, i.e. the last 1 before flipping.
        return (mask.shape[1] - 1) - mask.flip(dims=[1]).argmax(dim=1)

    def _chat_template_text(self, query: str, response: str) -> str:
        """Render a user/assistant exchange with the tokenizer's chat template, minus a leading BOS.

        The BOS token is stripped because the tokenizer call adds it again.
        """
        messages = [{"role": "user", "content": query}, {"role": "assistant", "content": response}]
        text = self.tokenizer.apply_chat_template(messages, tokenize=False)
        bos = self.tokenizer.bos_token
        if bos is not None and text.startswith(bos):
            text = text[len(bos):]
        return text

    def _ensure_pad_token(self) -> None:
        """Make sure batched padding works: default pad to EOS and sync the model config where needed."""
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
        needs_model_pad = any(name in self.model_name for name in self._NEEDS_MODEL_PAD_ID)
        if needs_model_pad and self.model.config.pad_token_id is None:
            self.model.config.pad_token_id = self.tokenizer.pad_token_id

    def _tokenize_batch(self, queries: List[str], responses: List[str]) -> Dict[str, torch.Tensor]:
        """Format and tokenize one batch for this model family; tensors are moved to the model device."""
        common = {"return_tensors": "pt", "padding": True, "truncation": True, "max_length": self.max_length}

        if 'Skywork' in self.model_name:
            texts = []
            for query, response in zip(queries, responses):
                query = self._LEADING_ROLE_TAG.sub('', query.strip())
                query = self._INNER_ROLE_TAG.sub('', query)
                texts.append(self._chat_template_text(query, response.strip()))
            encoded = self.tokenizer(texts, **common)
        elif 'deberta' in self.model_name:
            # DeBERTa reward models score (question, answer) as a text pair.
            encoded = self.tokenizer(queries, responses, **common)
        elif 'gpt2-large-helpful' in self.model_name or 'Ray2333' in self.model_name:
            # Ray2333/gpt2-large-helpful-reward_model expects the Human/Assistant
            # format, with query and response as separate tokenizer arguments.
            formatted_queries = []
            cleaned_responses = []
            for query, response in zip(queries, responses):
                clean_query = self._LEADING_ROLE_TAG.sub('', query.strip())
                clean_query = self._INNER_ROLE_TAG.sub('\n\n', clean_query).strip()
                formatted_queries.append(f"\n\nHuman: {clean_query} \n\nAssistant:")
                cleaned_responses.append(response.strip())
            encoded = self.tokenizer(formatted_queries, cleaned_responses, **common)
        else:
            texts = [self._chat_template_text(q, r) for q, r in zip(queries, responses)]
            encoded = self.tokenizer(texts, **common)

        return {k: v.to(self.model.device) for k, v in encoded.items()}

    def _forward_rewards(self, inputs: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Run the model on a tokenized batch and reduce its output to one reward per row."""
        with torch.no_grad():
            if self.model_type == "classification":
                logits = self.model(**inputs).logits
                # logits: [batch], [batch, 1] or [batch, num_labels]
                if logits.dim() == 1:
                    return logits
                if logits.shape[-1] == 1:
                    return logits.squeeze(-1)
                return logits.max(dim=-1).values

            outputs = self.model(**inputs, output_hidden_states=True)
            if hasattr(outputs, 'rewards'):
                rewards = outputs.rewards
            elif hasattr(outputs, 'score'):
                rewards = outputs.score
            else:
                # Fallback: mean of the final-layer hidden state at the last real token.
                hidden = outputs.hidden_states[-1]
                attention_mask = inputs.get('attention_mask')
                if attention_mask is None:
                    return hidden[:, -1, :].mean(dim=-1)
                last_idx = self._last_token_index(attention_mask).to(hidden.device)
                batch_idx = torch.arange(hidden.shape[0], device=hidden.device)
                return hidden[batch_idx, last_idx, :].mean(dim=-1)

            if rewards.dim() > 1:
                rewards = rewards.squeeze(-1)
            return rewards

    # ------------------------------------------------------------------ scoring

    def _score_single_objective(self, query: str, response: str, objective: str = None) -> Tuple[float, float]:
        """
        Score a single response using the reward model.

        Delegates to the batched path (batch of one) so that formatting,
        truncation and reward extraction are identical to ``compute_reward``.

        Args:
            query: The query/prompt
            response: The response to evaluate
            objective: Ignored for reward models (kept for interface consistency)

        Returns:
            Tuple of (raw_score, normalized_score). ``normalized_score`` equals
            ``raw_score`` unless ``normalize_scores`` is set, in which case it
            is ``sigmoid(raw_score)``.
        """
        scores, normalized_scores = self._batch_score_single_objective(
            [query], [response], objective=objective, batch_size=1
        )
        return scores[0], normalized_scores[0]

    def _batch_score_single_objective(
        self,
        queries: List[str],
        responses: List[str],
        objective: str = None,
        batch_size: int = 8
    ) -> Tuple[List[float], List[float]]:
        """
        Score multiple query-response pairs in batches using the reward model.

        Args:
            queries: List of queries/prompts
            responses: List of responses to evaluate (paired with ``queries``)
            objective: Ignored for reward models (kept for interface consistency)
            batch_size: Number of samples to process in each forward pass

        Returns:
            Tuple of (raw_scores, normalized_scores) as lists

        Raises:
            ValueError: if ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        self._ensure_pad_token()

        all_scores: List[float] = []
        all_normalized_scores: List[float] = []
        for start in range(0, len(queries), batch_size):
            stop = start + batch_size
            inputs = self._tokenize_batch(queries[start:stop], responses[start:stop])
            rewards = self._forward_rewards(inputs)
            for reward in rewards:
                score = float(reward.cpu())
                all_scores.append(score)
                all_normalized_scores.append(self._sigmoid(score) if self.normalize_scores else score)

        return all_scores, all_normalized_scores

    def compute_reward(self, queries: List[str], responses: List[str], denormalize_scores: bool = False, batch_size: int = 8) -> Tuple[torch.Tensor, torch.Tensor, None]:
        """
        Compute rewards using the trained reward model.

        Args:
            queries: List of queries/prompts
            responses: List of responses to evaluate
            denormalize_scores: Unused, kept for backward compatibility
            batch_size: Batch size for model forward passes

        Returns:
            Tuple of (normalized_rewards, denormalized_rewards, None), both bfloat16 tensors
        """
        print('Using compute_reward for RewardModelFunction')
        denorm_scores, norm_scores = self._batch_score_single_objective(
            queries, responses, objective=None, batch_size=batch_size
        )

        normalized_rewards = torch.tensor(norm_scores, dtype=torch.bfloat16)
        denormalized_rewards = torch.tensor(denorm_scores, dtype=torch.bfloat16)

        return normalized_rewards, denormalized_rewards, None

    def group_compute_reward(self, queries: List[str], responses_list: List[List[str]], denormalize_scores: bool = False, batch_size: int = 8) -> Tuple[torch.Tensor, torch.Tensor, None]:
        """
        Compute rewards for groups of responses.

        Every (query, response) pair is flattened in order (all responses of
        the first query, then the second, ...) and scored in one batched call.

        Args:
            queries: One query per group
            responses_list: Responses for each query
            denormalize_scores: Unused, kept for backward compatibility
            batch_size: Batch size for model forward passes

        Returns:
            Tuple of (normalized_rewards, denormalized_rewards, None) in flattened order
        """
        print('Using group_compute_reward for RewardModelFunction')
        pairs = [(query, response) for query, responses in zip(queries, responses_list) for response in responses]
        flat_queries = [query for query, _ in pairs]
        flat_responses = [response for _, response in pairs]

        norm_rewards, denorm_rewards, _ = self.compute_reward(flat_queries, flat_responses, batch_size=batch_size)
        return norm_rewards, denorm_rewards, None

    def save_config(self, config_path: str):
        """Save reward model configuration as JSON to ``config_path``."""
        config = {
            "model_name": self.model_name,
            "device": self.device,
            "max_length": self.max_length,
            "model_type": self.model_type,
            "normalize_scores": self.normalize_scores,
            "use_quantization": getattr(self, "use_quantization", False),
            "class_name": self.__class__.__name__
        }
        with open(config_path, 'w') as f:
            json.dump(config, f, indent=2)


class DPORewardFunction(RewardFunction):
    """
    Reward function for Direct Preference Optimization (DPO).

    Skeleton only: construction and scoring raise ``NotImplementedError``.

    DPO typically uses implicit rewards derived from the policy and reference
    policy log probabilities rather than an explicit reward model.
    """

    def __init__(
        self,
        policy_model_name: str,
        reference_model_name: str,
        beta: float = 0.1,
        device: str = "auto",
        max_length: int = 512,
        cache_dir: str = None
    ):
        """
        Initialize DPO reward function.

        Args:
            policy_model_name: Name/path of the policy model being trained
            reference_model_name: Name/path of the reference model
            beta: Temperature parameter for DPO
            device: Device to run models on
            max_length: Maximum sequence length
            cache_dir: Forwarded to the base class

        Raises:
            NotImplementedError: always, until model loading is implemented.
        """
        super().__init__(device=device, max_length=max_length, normalize_scores=False, cache_dir=cache_dir)

        self.policy_model_name = policy_model_name
        self.reference_model_name = reference_model_name
        self.beta = beta

        # TODO: Load policy and reference models
        # self.policy_model = ...
        # self.reference_model = ...
        # self.tokenizer = ...

        raise NotImplementedError("DPORewardFunction is not yet implemented")

    def _score_single_objective(self, query: str, response: str, objective: str = None) -> float:
        """
        Score a single response using DPO implicit reward.
        Note: objective parameter is ignored for DPO as it computes implicit rewards.

        Args:
            query: The query/prompt
            response: The response to evaluate
            objective: Ignored for DPO (kept for interface consistency)

        Returns:
            Float implicit reward score: β * log[π(y|x) / π_ref(y|x)]

        Raises:
            NotImplementedError: always, for now.
        """
        # TODO: Implement DPO scoring
        # 1. Tokenize query + response
        # 2. Get log probability from policy model: log π(y|x)
        # 3. Get log probability from reference model: log π_ref(y|x)
        # 4. Return β * (log_prob_policy - log_prob_ref)

        raise NotImplementedError("DPO single objective scoring not yet implemented")

    def compute_reward(self, queries: List[str], responses: List[str], denormalize_scores: bool = False) -> torch.Tensor:
        """
        Compute DPO implicit rewards.

        The DPO reward is typically:
        r(x, y) = β * log[π(y|x) / π_ref(y|x)]

        where:
        - π is the policy model
        - π_ref is the reference model
        - β is the temperature parameter

        Args:
            queries: List of queries/prompts
            responses: List of responses to evaluate
            denormalize_scores: Accepted to match the base-class signature; unused

        Raises:
            NotImplementedError: always, for now.
        """
        # TODO: once _score_single_objective is implemented, fill a
        # torch.zeros(len(queries), dtype=torch.bfloat16) tensor with one
        # score per (query, response) pair and return it.
        raise NotImplementedError("DPO reward computation not yet implemented")

    def save_config(self, config_path: str):
        """Save DPO configuration as JSON to ``config_path``."""
        config = {
            "policy_model_name": self.policy_model_name,
            "reference_model_name": self.reference_model_name,
            "beta": self.beta,
            "device": self.device,
            "max_length": self.max_length,
            "class_name": self.__class__.__name__
        }
        with open(config_path, 'w') as f:
            json.dump(config, f, indent=2)


# Kept for backward compatibility: ZeroBackbone (a pass-through backbone) and CustomRewardModel (a TRL-compatible wrapper around a RewardFunction).
class ZeroBackbone(nn.Module):
    """
    Stand-in backbone that echoes ``input_ids`` back as "hidden states".

    The reward wrapper's ``score`` step turns these hidden states back into
    token ids (``squeeze(-1).long()``). The backbone therefore returns the ids
    unchanged, only with a trailing singleton feature dimension.
    """

    def forward(self, input_ids=None, attention_mask=None, **kwargs):
        """
        Return an object whose ``hidden_states`` is ``[input_ids.unsqueeze(-1)]``.

        ``attention_mask`` and any extra keyword arguments are accepted and ignored.

        Raises:
            ValueError: if ``input_ids`` is None.
        """
        from types import SimpleNamespace  # local import keeps this block self-contained

        if input_ids is None:
            raise ValueError("ZeroBackbone.forward requires input_ids")
        # Return a lightweight instance instead of building a new class object per call.
        return SimpleNamespace(hidden_states=[input_ids.unsqueeze(-1)])


class CustomRewardModel(nn.Module):
    """
    Custom Reward Model wrapper for TRL PPOTrainer compatibility.

    Wraps any RewardFunction. Token ids are decoded back to text, and each text
    is split into (query, response) at the last occurrence of
    ``query_response_separator``. The pairs are then scored by the wrapped
    reward function.
    """
    base_model_prefix = "pretrained_model"

    def __init__(
        self,
        reward_function: RewardFunction,
        tokenizer: AutoTokenizer,
        query_response_separator: str = "Assistant:"
    ):
        """
        Initialize the custom reward model.

        Args:
            reward_function: Instance of any RewardFunction subclass
            tokenizer: Tokenizer for decoding input_ids
            query_response_separator: String separating query from response
        """
        super().__init__()
        self.reward_function = reward_function
        self.tokenizer = tokenizer
        # Pass-through backbone exposed under ``base_model_prefix``.
        self.pretrained_model = ZeroBackbone()
        self.query_response_separator = query_response_separator

    def _split_query_response(self, text: str) -> Tuple[str, str]:
        """
        Split text into query and response at the LAST separator occurrence.

        The separator stays at the end of the query. Without a separator, the
        query is empty and the whole (stripped) text is the response.
        """
        separator = self.query_response_separator
        if separator in text:
            cut = text.rfind(separator) + len(separator)
            return text[:cut].strip(), text[cut:].strip()
        return "", text.strip()

    @staticmethod
    def _last_token_index(attention_mask: torch.Tensor) -> torch.Tensor:
        """Return, per row, the index of the last position whose mask is non-zero.

        Correct for right padding, left padding, and masks with gaps (where
        ``mask.sum(1) - 1`` would point at the wrong position).
        """
        mask = attention_mask.to(torch.long)
        # argmax returns the first maximal index, i.e. the last 1 before flipping.
        return (mask.shape[1] - 1) - mask.flip(dims=[1]).argmax(dim=1)

    def _score_token_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Decode ``input_ids``, split each text into (query, response) and score the batch."""
        decoded_texts = self.tokenizer.batch_decode(input_ids, skip_special_tokens=True)
        pairs = [self._split_query_response(text) for text in decoded_texts]
        result = self.reward_function.compute_reward(
            [query for query, _ in pairs],
            [response for _, response in pairs],
        )
        # The base signature returns a tensor, while RewardModelFunction returns
        # (normalized, denormalized, None). Use the leading score tensor either
        # way, instead of unpacking into a fixed number of names.
        if isinstance(result, tuple):
            result = result[0]
        return result

    def score(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Compute per-token scores for TRL compatibility.

        Args:
            hidden_states: Output of ``ZeroBackbone``: token ids with a trailing
                singleton dimension, shape (batch, seq_len, 1).

        Returns:
            Tensor of shape (batch, seq_len, 1) with each row's score repeated
            at every position.
        """
        _, seq_len, _ = hidden_states.shape
        input_ids = hidden_states.squeeze(-1).long()

        scores = self._score_token_ids(input_ids)
        scores = scores.to(dtype=torch.bfloat16, device=hidden_states.device)

        # Broadcast to all token positions.
        return scores.unsqueeze(1).repeat(1, seq_len).unsqueeze(-1)

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor = None, **kwargs):
        """
        TRL-compatible forward method.

        Returns:
            Tuple of:
            - final_scores (batch, seq_len, 1): zero everywhere, except the
              score at each row's last attended position.
            - scores_tensor (batch,): the per-sequence scores.
            - last_token_indices (batch,): the positions where scores were placed.
        """
        batch_size, seq_len = input_ids.shape
        device = input_ids.device

        # Replace masked-out positions with the pad id so they vanish on decode.
        if attention_mask is not None:
            clean_input_ids = torch.where(
                attention_mask.bool(),
                input_ids,
                self.tokenizer.pad_token_id
            )
        else:
            clean_input_ids = input_ids

        scores_tensor = self._score_token_ids(clean_input_ids)
        scores_tensor = scores_tensor.to(dtype=torch.bfloat16, device=device)

        final_scores = torch.zeros(batch_size, seq_len, 1, dtype=torch.bfloat16, device=device)

        if attention_mask is not None:
            last_token_indices = self._last_token_index(attention_mask).to(device)
        else:
            last_token_indices = torch.full((batch_size,), seq_len - 1, dtype=torch.long, device=device)

        # Keep both index tensors on the same device as final_scores.
        final_scores[torch.arange(batch_size, device=device), last_token_indices, 0] = scores_tensor

        return final_scores, scores_tensor, last_token_indices

    def save_pretrained(self, save_directory: str):
        """Save the wrapped reward function's configuration to ``<save_directory>/reward_config.json``."""
        os.makedirs(save_directory, exist_ok=True)
        self.reward_function.save_config(os.path.join(save_directory, "reward_config.json"))

    def to(self, device):
        """Handle device placement."""
        super().to(device)
        return self

    def eval(self):
        """Set to evaluation mode."""
        super().eval()
        return self

    def train(self, mode=True):
        """Set training mode."""
        super().train(mode)
        return self