import os
import torch
import numpy as np
from typing import List, Optional, Union, Tuple, Dict, Any, ClassVar
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from autometrics.metrics.reference_free.ReferenceFreeMetric import ReferenceFreeMetric
from autometrics.metrics.utils.device_utils import get_model_device, ensure_tensor_on_device

class GRMRewardModel(ReferenceFreeMetric):
    """---
# Metric Card for GRMRewardModel

The GRMRewardModel is a general-purpose reward model designed to evaluate the quality and safety of LLM-generated outputs. It achieves high generalization performance by applying a novel regularization method on hidden states during supervised fine-tuning. GRMRewardModel is fine-tuned on the decontaminated Skywork/Skywork-Reward-Preference-80K-v0.2 dataset and achieves state-of-the-art results among models of comparable size (3B), even outperforming some 8B reward models and proprietary LLM judges on RewardBench.

## Metric Details

### Metric Description

The GRMRewardModel is a transformer-based reward model that assigns scalar scores to LLM responses based on their alignment with human preferences. Its main innovation is hidden state regularization: rather than training only a reward head on top of the backbone, GRM (Yang et al., 2024) retains the base model's language-model head and adds text-generation losses during reward training, so that the hidden states shared by both heads keep their text-generation capabilities. The paper reports that this improves generalization under distribution shift.

This particular instantiation, `Ray2333/GRM-Llama3.2-3B-rewardmodel-ft`, is based on the Llama-3.2-3B-Instruct model and fine-tuned on the Skywork preference dataset using pairwise comparisons of completions. Given a message (prompt + response), the model outputs a scalar reward score indicating the desirability of the response.

- **Metric Type:** Reference-Free  
- **Range:** Unbounded (the raw scalar output of the reward head; this implementation applies no rescaling)  
- **Higher is Better?:** Yes  
- **Reference-Based?:** No  
- **Input-Required?:** Yes

### Formal Definition

Let $x$ denote the input prompt and $y$ the response generated by a model. The GRMRewardModel, $f_\\theta$, takes the concatenated sequence $(x, y)$ and computes a scalar reward:

$$
\\text{Reward} = f_\\theta(x, y)
$$

The model is trained on pairwise preferences $(x, y^{+}, y^{-})$ using a pairwise loss function:

$$
\\mathcal{L}_{\\text{pairwise}} = -\\log \\sigma(f_\\theta(x, y^{+}) - f_\\theta(x, y^{-}))
$$
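
As a rough illustration of the pairwise objective above, the following sketch computes the loss for a single preference pair from two scalar rewards, using only the Python standard library. The names `preference_probability` and `pairwise_loss` are illustrative and not part of the GRM codebase. The first function assumes the usual Bradley-Terry reading, in which the sigmoid of the reward gap is the modeled probability that the preferred response wins.

```python
import math


def preference_probability(reward_pos: float, reward_neg: float) -> float:
    # Sigmoid of the reward gap, branching on the sign to stay numerically stable.
    gap = reward_pos - reward_neg
    if gap >= 0:
        return 1.0 / (1.0 + math.exp(-gap))
    exp_gap = math.exp(gap)
    return exp_gap / (1.0 + exp_gap)


def pairwise_loss(reward_pos: float, reward_neg: float) -> float:
    # -log(sigmoid(gap)) rewritten as softplus(-gap) to avoid taking log(0).
    gap = reward_pos - reward_neg
    return max(-gap, 0.0) + math.log1p(math.exp(-abs(gap)))
```

The loss shrinks toward zero as the preferred response's reward exceeds the dispreferred one by a wider margin, and grows roughly linearly when the ordering is reversed.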

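Additionally, during training, a regularization term is combined with the reward loss. In the formulation of Yang et al. (2024), the language-model head is retained and a text-generation loss is computed on the same hidden states that feed the reward head:

$$
\\mathcal{L}_{\\text{total}} = (1 - \\alpha) \\mathcal{L}_{\\text{pairwise}} + \\alpha \\mathcal{L}_{\\text{reg}}
$$

where $\\alpha$ balances the two terms and $\\mathcal{L}_{\\text{reg}}$ is a text-generation loss (e.g., an SFT-style loss on the preferred response $y^{+}$). The regularizer is used only during training; at inference only the scalar reward is computed.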

### Inputs and Outputs

- **Inputs:**  
  - Tokenized message (prompt + response pair), formatted using the chat template from the underlying Llama tokenizer.  
  - Attention mask.  
  - (During training) Preference pairs.

- **Outputs:**  
  - A scalar reward value representing alignment with human preferences.
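
The input format described above can be sketched as follows. It mirrors the two-turn conversation this class builds before applying the chat template. `build_conversation` and `build_batch` are hypothetical helper names used only for illustration.

```python
from typing import Dict, List


def build_conversation(prompt: str, response: str) -> List[Dict[str, str]]:
    # One user turn (the prompt) followed by one assistant turn (the response).
    return [
        {"role": "user", "content": prompt},
        {"role": "assistant", "content": response},
    ]


def build_batch(prompts: List[str], responses: List[str]) -> List[List[Dict[str, str]]]:
    # Pair prompts and responses positionally; mismatched lengths are rejected
    # rather than silently truncated.
    if len(prompts) != len(responses):
        raise ValueError("prompts and responses must have the same length")
    return [build_conversation(p, r) for p, r in zip(prompts, responses)]
```

Each conversation is then passed to the tokenizer's `apply_chat_template` method with `tokenize=True` and `return_tensors="pt"`, as this class does, to obtain the model inputs.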

## Intended Use

### Domains and Tasks

- **Domain:** Text Generation  
- **Tasks:** Dialogue Generation, Response Generation, Safety Evaluation

### Applicability and Limitations

- **Best Suited For:**  
  Evaluating response quality and safety in conversational agents, especially in RLHF or reranking pipelines where candidate responses to the same prompt must be scored and compared.

- **Not Recommended For:**  
  Evaluating creative generation (e.g., poetry or storytelling) or tasks that require ground-truth references (e.g., translation, summarization).
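
A reranking pipeline of the kind mentioned above might look like the sketch below. Here `score_fn` stands for any callable that returns a scalar reward for a prompt and response (for example, a thin wrapper around this metric). `rerank` and `toy_score` are hypothetical names, and `toy_score` is a stand-in used only so the example runs without a model; it is not a reward model.

```python
from typing import Callable, List, Tuple


def rerank(
    prompt: str,
    candidates: List[str],
    score_fn: Callable[[str, str], float],
) -> List[Tuple[str, float]]:
    # Score every candidate against the same prompt, then sort best-first
    # (higher reward is better).
    scored = [(cand, score_fn(prompt, cand)) for cand in candidates]
    return sorted(scored, key=lambda pair: pair[1], reverse=True)


def toy_score(prompt: str, response: str) -> float:
    # Stand-in scorer for demonstration only: counts response words found in the prompt.
    prompt_words = set(prompt.lower().split())
    return float(sum(word in prompt_words for word in response.lower().split()))


ranked = rerank(
    "name a red fruit",
    ["a red fruit is the apple", "bananas", "red"],
    toy_score,
)
best_response, best_score = ranked[0]
```

Because reward scores are unbounded, they are most meaningful when compared across candidates for the same prompt, as done here, rather than read as absolute quality levels.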

## Metric Implementation

### Reference Implementations

- **Libraries/Packages:**  
  - [Hugging Face Transformers](https://huggingface.co/Ray2333/GRM-Llama3.2-3B-rewardmodel-ft)  
  - Pretrained model available: `Ray2333/GRM-Llama3.2-3B-rewardmodel-ft`

### Computational Complexity

- **Efficiency:**  
  Each score requires a single forward pass through a 3B-parameter transformer. Examples can be scored one at a time or in batches controlled by `batch_size`.

- **Scalability:**  
  Suitable for large-scale evaluation or online RLHF pipelines. The resource statistics recorded for this class show roughly 6 GB of GPU memory in float16, so the model fits on a single modern GPU (e.g., a single A100).
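
A back-of-the-envelope estimate helps put the hardware requirement in context. The sketch below computes the memory needed for the weights alone at different precisions. The parameter count is an assumed round figure for a "3B" model and is not taken from the model card, and `estimate_weight_memory_mib` is a hypothetical helper.

```python
def estimate_weight_memory_mib(num_params: float, bytes_per_param: int) -> float:
    # Memory for the weights alone; activations and framework overhead are excluded.
    return num_params * bytes_per_param / (1024 ** 2)


ASSUMED_PARAMS = 3.2e9  # assumption: rough parameter count for a "3B" model

fp16_mib = estimate_weight_memory_mib(ASSUMED_PARAMS, 2)  # float16 or bfloat16
fp32_mib = estimate_weight_memory_mib(ASSUMED_PARAMS, 4)  # float32
```

Under that assumption the float16 estimate comes to roughly 6,100 MiB, the same order as the GPU memory figure recorded for this class (`gpu_mem`). Loading in float32 would about double it.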

## Known Limitations

- **Biases:**  
  Depends on the biases of the preference data (Skywork dataset), which may reflect annotator or cultural preferences.  
  Sensitivity to prompt formatting may affect score stability.

- **Task Misalignment Risks:**  
  Since it is trained on generic preference data, it may not align with task-specific criteria or specialized user needs.

- **Failure Cases:**  
  May output high reward scores for fluent but factually incorrect or manipulative responses, especially outside the training distribution.  
  Does not explicitly verify factual correctness or consistency.

## Related Metrics

- **OpenAI's GPT-Judge**: LLM-as-a-judge paradigm with GPT-4.  
- **DPO/RM models from Anthropic and OpenAI**: Trained using preference data, but often larger in scale and less interpretable.  
- **RRHF Reward Models**: Reward models trained on Reddit or StackExchange upvotes.  
- **UltraRM, Starling-RM**: Competing public reward models evaluated on RewardBench.

## Further Reading

- **Papers:**  
  - Yang et al. (2024), *Regularizing Hidden States Enables Learning Generalizable Reward Model for LLMs*, NeurIPS 2024  
    https://openreview.net/forum?id=jwh9MHEfmY

- **Blogs/Tutorials:**  
  - [More Information Needed]

## Citation

For GRM:
```
@inproceedings{yang2024regularizing,  
  title={Regularizing Hidden States Enables Learning Generalizable Reward Model for {LLM}s},  
  author={Rui Yang and Ruomeng Ding and Yong Lin and Huan Zhang and Tong Zhang},  
  booktitle={The Thirty-eighth Annual Conference on Neural Information Processing Systems},  
  year={2024},  
  url={https://openreview.net/forum?id=jwh9MHEfmY}  
}
```

## Metric Card Authors

- **Authors:** ANONYMOUS
- **Acknowledgment of AI Assistance:**  
  Portions of this metric card were drafted with assistance from generative AI. All content has been reviewed and curated by the author to ensure accuracy.  
- **Contact:** ANONYMOUS@example.com"""

    # Resource usage statistics (in megabytes)
    gpu_mem: ClassVar[float] = 6160.84033203125  # in MB
    cpu_mem: ClassVar[float] = 2003.9375  # in MB
    description: ClassVar[str] = "The GRMRewardModel is a general-purpose reward model designed to evaluate the quality and safety of LLM-generated outputs. It achieves high generalization performance by applying a novel regularization method on hidden states during supervised fine-tuning. GRMRewardModel is fine-tuned on the decontaminated Skywork/Skywork-Reward-Preference-80K-v0.2 dataset and achieves state-of-the-art results among models of comparable size (3B), even outperforming some 8B reward models and proprietary LLM judges on RewardBench."

    def __init__(
        self,
        name: str = "GRMRewardModel",
        description: str = "The GRMRewardModel is a general-purpose reward model designed to evaluate the quality and safety of LLM-generated outputs. It achieves high generalization performance by applying a novel regularization method on hidden states during supervised fine-tuning. GRMRewardModel is fine-tuned on the decontaminated Skywork/Skywork-Reward-Preference-80K-v0.2 dataset and achieves state-of-the-art results among models of comparable size (3B), even outperforming some 8B reward models and proprietary LLM judges on RewardBench.",
        model_name: str = "Ray2333/GRM-Llama3.2-3B-rewardmodel-ft",
        torch_dtype: str = "float16",
        device_map: Union[str, dict] = "auto",
        batch_size: int = 1,
        persistent: bool = True,
        **kwargs
    ):
        super().__init__(name, description, model_name=model_name, torch_dtype=torch_dtype, device_map=device_map, batch_size=batch_size, persistent=persistent, **kwargs)
        self.model_name = model_name
        # Only "bfloat16" selects bfloat16; any other value falls back to float16.
        self.torch_dtype = torch.bfloat16 if torch_dtype == "bfloat16" else torch.float16
        self.device_map = device_map
        self.batch_size = batch_size
        self.persistent = persistent
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.tokenizer = None
        self.model = None

        self.exclude_from_cache_key('model_name', 'device_map', 'batch_size', 'persistent')

    def _load_model(self):
        """Load tokenizer and model into memory."""
        if self.model is None:
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
            self.model = AutoModelForSequenceClassification.from_pretrained(
                self.model_name,
                torch_dtype=self.torch_dtype,
                device_map=self.device_map,
                trust_remote_code=True
            )
            self.model.eval()

    def _unload_model(self):
        """Unload model and tokenizer from memory to free resources."""
        if self.model is not None:
            del self.model
            del self.tokenizer
            torch.cuda.empty_cache()
            self.model = None
            self.tokenizer = None

    def _call_model(self, tok: Union[torch.Tensor, Dict[str, Any]], model_device: torch.device) -> torch.Tensor:
        """Helper method to call the model with appropriate arguments based on the type of tokenized input."""
        with torch.no_grad():
            if isinstance(tok, torch.Tensor):
                # If tok is a tensor, it's the input_ids tensor
                outputs = self.model(input_ids=tok)
            else:
                # If tok is a dictionary, unpack it as kwargs
                outputs = self.model(**tok)
                
            logits = outputs.logits if hasattr(outputs, 'logits') else outputs[0]
            return logits

    def _calculate_impl(self, input: str, output: str, references=None, **kwargs) -> float:
        """
        Score a single input-output pair.
        """
        if self.model is None:
            self._load_model()
        # prepare conversation
        conv = [
            {"role": "user", "content": input},
            {"role": "assistant", "content": output}
        ]
        
        # Get model device and ensure tensors are on that device
        model_device = get_model_device(self.model, fallback_device=self.device)
        
        # tokenize chat template and ensure it's on the model's device
        tok = self.tokenizer.apply_chat_template(conv, tokenize=True, return_tensors="pt")
        tok = ensure_tensor_on_device(tok, model_device)
        
        logits = self._call_model(tok, model_device)
        # shape (batch,1)
        score = logits.squeeze(-1).squeeze(0).cpu().item()
        
        if not self.persistent:
            self._unload_model()
        return score

    def _calculate_batched_impl(self, inputs: List[str], outputs: List[str], references=None, **kwargs) -> List[float]:
        """
        Score batches of input-output pairs.
        """
        if self.model is None:
            self._load_model()
            
        # Get model device
        model_device = get_model_device(self.model, fallback_device=self.device)
                
        all_scores: List[float] = []
        for i in range(0, len(inputs), self.batch_size):
            chunk_in = inputs[i:i+self.batch_size]
            chunk_out = outputs[i:i+self.batch_size]
            convs = [
                [{"role": "user", "content": inp}, {"role": "assistant", "content": out}]
                for inp, out in zip(chunk_in, chunk_out)
            ]
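            # Pad only when the chunk holds several conversations, so that sequences of
            # different lengths can be stacked (assumes the tokenizer defines a pad token).
            # return_dict=True also yields the attention mask, so padding is masked out.
            tok = self.tokenizer.apply_chat_template(
                convs, tokenize=True, padding=len(convs) > 1, return_dict=True, return_tensors="pt"
            )
            # Move every tensor (input_ids, attention_mask) to the model's device
            tok = {key: value.to(model_device) for key, value in tok.items()}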
            
            logits = self._call_model(tok, model_device)
            vals = logits.squeeze(-1).cpu().tolist()
            # ensure list
            if isinstance(vals, float):
                all_scores.append(vals)
            else:
                all_scores.extend(vals)
        if not self.persistent:
            self._unload_model()
        return all_scores 