import torch.nn as nn
from transformers import (
    AutoConfig,)

from utils import get_model


class PPOModel(nn.Module):
    """Thin ``nn.Module`` wrapper around the model built by ``get_model``.

    The wrapped model is stored as ``base_model`` and switched to ``eval()``
    mode at construction time so dropout is off during RL. ``forward`` and
    ``generate`` delegate to it unchanged.
    """

    def __init__(self, args, device: int):
        """Load the model config and build the wrapped model.

        Args:
            args: Run configuration. This constructor reads ``model_parallel``,
                ``model_path`` and ``model_type`` from it, and passes the whole
                object to ``get_model``.
            device: Device identifier forwarded to ``get_model``.
        """
        super().__init__()
        self.model_parallel = args.model_parallel
        self.config = AutoConfig.from_pretrained(args.model_path)
        self.base_model = get_model(args, device)
        # No dropout for RL. NOTE: if base_model is an nn.Module child (as the
        # .eval() call suggests), calling .train() on this wrapper recurses
        # into it and switches it back to training mode.
        self.base_model.eval()
        self.model_type = args.model_type

    def forward(self, **x):
        """Call the wrapped model with keyword arguments ``x``; return its output as-is."""
        return self.base_model(**x)

    def generate(self, **x):
        """Delegate to ``base_model.generate`` with keyword arguments ``x``."""
        return self.base_model.generate(**x)

    def set_force_gradient_checkpointing(self, value: bool):
        """Enable or disable gradient checkpointing on the wrapped model.

        For ``model_type == "qwen3"`` the Hugging Face
        ``gradient_checkpointing_enable`` / ``gradient_checkpointing_disable``
        methods are used. Any other model type must provide its own
        ``set_force_gradient_checkpointing(value)``, which receives ``value``.

        Args:
            value: ``True`` to enable checkpointing, ``False`` to disable it.
        """
        if self.model_type == "qwen3":
            # Honour `value` in both directions; a disable request must not be
            # turned into an enable.
            # NOTE(review): many HF model implementations only checkpoint while
            # `self.training` is True, and base_model is put in eval() in
            # __init__ — confirm checkpointing is actually active for qwen3.
            if value:
                self.base_model.gradient_checkpointing_enable()
            else:
                self.base_model.gradient_checkpointing_disable()
        else:
            self.base_model.set_force_gradient_checkpointing(value)
