import torch.nn as nn
import torch
from safetensors.torch import  save_file
import os
from transformers import AutoModel
import torch.nn.functional as F

class DecoderModel(nn.Module):
    """Pretrained backbone followed by a pooling step and a single-logit head.

    The backbone's ``last_hidden_state`` is reduced to one vector per sequence
    and passed through ``classification_head``. ``mode`` selects the optional
    parts by substring match:

    * contains ``'attn'``  -> learned attention pooling over the sequence;
      otherwise the hidden state of the last non-padded token is used.
    * contains ``'dense'`` -> an extra ``Linear`` + ``GELU`` projection is
      applied to the pooled vector before the classification head.

    Any other value (including the default ``'generation'``) gives last-token
    pooling with no dense projection.
    """

    def __init__(self, pretrained_model_name_or_path, mode='generation'):
        """Load the backbone and create the heads selected by ``mode``.

        Args:
            pretrained_model_name_or_path: Forwarded to
                ``AutoModel.from_pretrained``.
            mode: String whose substrings ``'attn'`` / ``'dense'`` enable the
                corresponding optional layers.
        """
        super().__init__()

        # NOTE(review): trust_remote_code=True lets the checkpoint's repository
        # run its own Python code at load time; only use trusted checkpoints.
        self.base_model = AutoModel.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True)

        self.mode = mode
        hidden_size = self.base_model.config.hidden_size
        if 'attn' in self.mode:
            # One scalar score per token, normalised into pooling weights.
            self.attn = nn.Linear(hidden_size, 1)
        if 'dense' in self.mode:
            self.pooler_head = nn.Linear(hidden_size, hidden_size)
            self.pooler_activation = nn.GELU()

        self.classification_head = nn.Linear(hidden_size, 1)

    def forward(self, input_ids, attention_mask, get_pooler_output=False, **kwargs):
        """Run the backbone, pool its hidden states and compute the logit.

        Args:
            input_ids: Token ids passed to the backbone.
            attention_mask: Mask passed to the backbone and used for pooling;
                zero entries mark padding.
            get_pooler_output: When true, the pooled vector (after the dense
                projection, if enabled) is returned as well.
            **kwargs: Extra keyword arguments forwarded to the backbone.

        Returns:
            ``{'logits': ...}``, plus ``'pooler_output'`` when requested.
        """
        outputs = self.base_model(input_ids=input_ids, attention_mask=attention_mask, **kwargs)
        hidden_states = outputs.last_hidden_state

        if 'attn' in self.mode:
            pooled_states = self._attention_pool(hidden_states, attention_mask)
        else:
            pooled_states = self.last_token_pool(hidden_states, attention_mask)

        if 'dense' in self.mode:
            pooled_states = self.pooler_head(pooled_states)
            pooled_states = self.pooler_activation(pooled_states)

        output = self.classification_head(pooled_states)
        if get_pooler_output:
            return {'logits': output, 'pooler_output': pooled_states}
        return {'logits': output}

    def _attention_pool(self, hidden_states, attention_mask=None):
        """Return the attention-weighted sum of ``hidden_states`` over dim 1.

        Positions where ``attention_mask`` is zero are excluded from the
        softmax so padding never contributes to the pooled vector (without
        this, a sequence's result would depend on how much padding the rest
        of the batch forces onto it). With ``attention_mask=None`` every
        position is used.
        """
        attn_score = self.attn(hidden_states).squeeze(-1)  # (B, seq_len)
        if attention_mask is not None:
            # Use the dtype's minimum rather than -inf: a fully masked row
            # then softmaxes to uniform weights instead of NaN.
            attn_score = attn_score.masked_fill(
                attention_mask == 0, torch.finfo(attn_score.dtype).min
            )
        attn_weight = F.softmax(attn_score, dim=1).unsqueeze(-1)  # (B, seq_len, 1)
        return torch.sum(hidden_states * attn_weight, dim=1)  # (B, hidden_size)

    def last_token_pool(self, hidden_states, attention_mask):
        """Select the hidden state of the last non-padded token per sequence.

        If every row of ``attention_mask`` has a one in the final column the
        batch is treated as left-padded and the final position is returned
        directly; otherwise the index is derived from each row's mask sum
        (which assumes right padding).
        """
        # first see if left padding is used
        left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
        if left_padding:
            return hidden_states[:, -1]

        seq_lens = attention_mask.long().sum(-1)
        batch_idx = torch.arange(hidden_states.shape[0], device=hidden_states.device)
        return hidden_states[batch_idx, seq_lens - 1]

    def save_pretrained(self, path, state_dict_file='model.safetensors'):
        """Write the backbone config and this module's weights into ``path``.

        Args:
            path: Target directory; created (with parents) if missing.
            state_dict_file: File name for the safetensors weights, which hold
                the full ``state_dict()`` of this module (backbone and heads).
        """
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(path, exist_ok=True)
        self.base_model.config.save_pretrained(path)
        # safetensors refuses non-contiguous tensors; contiguous() returns the
        # tensor itself when it is already contiguous.
        state_dict = {name: tensor.contiguous() for name, tensor in self.state_dict().items()}
        save_file(state_dict, os.path.join(path, state_dict_file))