import os
import torch
from omegaconf import DictConfig
from transformers import AutoModelForCausalLM
from torch.nn.parallel import DistributedDataParallel

class BasePolicy:
    """Thin wrapper around a Hugging Face causal language model.

    The wrapped model is stored in ``self.network``. It may later be replaced
    by a ``DistributedDataParallel`` wrapper; ``inference`` handles both cases.
    """

    def __init__(self, cfg: DictConfig, torch_dtype: torch.dtype = 'auto', checkpoint_dir: str = None):
        """Load the causal LM from a checkpoint directory or from the config.

        Args:
            cfg: Configuration object. ``cfg.policy.model`` is used as the
                model identifier when no checkpoint is given;
                ``cfg.io.load_root`` is the directory checkpoints are
                resolved against.
            torch_dtype: Forwarded unchanged to ``from_pretrained``. The
                default is the string ``'auto'`` (not a ``torch.dtype``).
            checkpoint_dir: Optional path relative to ``cfg.io.load_root``.
                When falsy (``None`` or ``""``), ``cfg.policy.model`` is
                loaded instead.
        """
        self.cfg = cfg
        # cfg.io.load_root is only read when a checkpoint is requested, so a
        # config without an ``io`` section still works for the default path.
        if checkpoint_dir:
            model_path = os.path.join(cfg.io.load_root, checkpoint_dir)
        else:
            model_path = cfg.policy.model
        # NOTE(review): trust_remote_code=True allows Python code shipped with
        # the model repository to run at load time; only point this at
        # trusted model sources.
        self.network = AutoModelForCausalLM.from_pretrained(
            model_path, trust_remote_code=True, torch_dtype=torch_dtype
        )

    def _backbone(self):
        """Return the ``.model`` attribute of the (DDP-unwrapped) network."""
        network = self.network
        if isinstance(network, DistributedDataParallel):
            # DDP keeps the wrapped module under ``.module``.
            network = network.module
        # NOTE(review): presumably the base transformer without the LM head;
        # this requires the architecture to expose it as ``.model`` — confirm
        # for every model family used with this class.
        return network.model

    @torch.no_grad()
    def inference(self, batch, return_hidden_states: bool = False):
        """Run the backbone without gradients and optionally pool hidden states.

        Args:
            batch: Mapping with ``"input_ids"`` and ``"attention_mask"``
                tensors; ``"labels"`` is additionally required when
                ``return_hidden_states`` is True.
            return_hidden_states: When True, return pooled hidden states;
                otherwise return ``None``.

        Returns:
            ``None`` when ``return_hidden_states`` is False. Otherwise a tuple
            ``(last, second_to_last, mean)`` with one row per batch element:
            the hidden state at index ``sequence_length - 1``, at index
            ``sequence_length - 2``, and the mean of the hidden states from
            the first position whose label is not ``-100`` up to (excluding)
            ``sequence_length``.
        """
        # NOTE(review): the forward pass runs even when its output is
        # discarded below. Kept as in the original because skipping it could
        # change RNG consumption or forward hooks — confirm whether an early
        # return before this call is acceptable.
        output = self._backbone()(
            input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]
        )
        if not return_hidden_states:
            return None

        hidden = output.last_hidden_state
        # Number of attended tokens per row. Using it as an end index assumes
        # right-padded batches — TODO confirm against the collator.
        sequence_lengths = batch["attention_mask"].sum(dim=1)
        # First position per row whose label is not -100. argmax returns 0 for
        # a row with no such position, so that row is averaged from index 0.
        response_starts = torch.argmax((batch["labels"] != -100).int(), dim=1)

        batch_indices = torch.arange(
            batch["input_ids"].shape[0], device=batch["input_ids"].device
        )
        last_hidden_states = hidden[batch_indices, sequence_lengths - 1]
        second_to_last_hidden_states = hidden[batch_indices, sequence_lengths - 2]

        # Per-row slices have different lengths, so the mean is taken row by
        # row. Distinct loop names avoid rebinding the tensors being zipped.
        mean_hidden_states = torch.stack([
            hidden[row, start:end].mean(dim=0)
            for row, (start, end) in enumerate(zip(response_starts, sequence_lengths))
        ])
        return last_hidden_states, second_to_last_hidden_states, mean_hidden_states

    def forward(self, batch):
        """Call the full network with inputs and labels and return its output.

        Args:
            batch: Mapping with ``"input_ids"``, ``"attention_mask"`` and
                ``"labels"`` entries, passed to the network as keyword
                arguments.

        Returns:
            Whatever ``self.network(...)`` returns, unmodified.
        """
        return self.network(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            labels=batch["labels"],
        )
        
    