import os

import torch
from transformers import AutoConfig, AutoTokenizer

from src.nlp.models.BaseModel import BaseModel


class LlamaModelLMHead(BaseModel):
    """Heads and tokenizer for the precomputed-embedding Llama setting.

    Used only for the special case with Llama where the backbone embeddings
    are precomputed and then consumed by the downstream task. This class
    itself builds a tokenizer, a square linear head, and a bias-free LM head
    (all parameters in bfloat16); ``forward`` is intentionally unimplemented.
    """

    def __init__(self, cfg, device):
        """Create the tokenizer and the bfloat16 heads.

        Args:
            cfg: Experiment configuration. It is passed to
                ``BaseModel.__init__`` and ``cfg.padding_side`` is read here
                to configure the tokenizer.
            device: Passed through unchanged to ``BaseModel.__init__``.

        Raises:
            KeyError: If the ``HF_TOKEN`` environment variable is not set.
        """
        super().__init__(cfg, device)

        self.model_name = "meta-llama/Llama-3.1-8B"
        # Single strict lookup; a missing token fails here with KeyError.
        hf_token = os.environ["HF_TOKEN"]
        backbone_cfg = AutoConfig.from_pretrained(self.model_name, token=hf_token)
        head_dtype = torch.bfloat16

        # Note: the attributes below are expected to exist by other parts of
        # the code base, but are not used in this model/experiment setting.
        self.hidden_dim = backbone_cfg.hidden_size
        self.tokenizer = tokenizer = AutoTokenizer.from_pretrained(
            self.model_name, token=hf_token
        )
        tokenizer.padding_side = cfg.padding_side
        tokenizer.add_special_tokens({"pad_token": "<pad>"})
        self.linear_head = torch.nn.Linear(
            in_features=self.hidden_dim,
            out_features=self.hidden_dim,
            dtype=head_dtype,
        )
        self.linear_finetune = False

        # One extra output row for the pad token added to the tokenizer above.
        padded_vocab_size = backbone_cfg.vocab_size + 1
        self.lm_head = torch.nn.Linear(
            backbone_cfg.hidden_size,
            padded_vocab_size,
            bias=False,
            dtype=head_dtype,
        )

        # init_pooling is defined outside this class (presumably on
        # BaseModel -- TODO confirm); it receives the same dtype as the heads.
        self.init_pooling(head_dtype)

    def forward(self, batch):
        """Not supported: this model/experiment setting never calls forward.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError()

    def _distributed_setup(self):
        """Bind this module to the rank given by the process environment.

        Reads ``LOCAL_RANK`` and ``RANK`` from the environment, stores them
        as ``self.local_rank`` and ``self.global_rank``, sets ``self.device``
        to the local rank, and calls ``self.to`` with the local rank.

        Returns:
            A ``(local_rank, global_rank)`` tuple of ints.

        Raises:
            KeyError: If ``LOCAL_RANK`` or ``RANK`` is not set.
            ValueError: If either value is not an integer string.
        """
        local_rank = int(os.environ["LOCAL_RANK"])
        global_rank = int(os.environ["RANK"])

        self.local_rank, self.global_rank = local_rank, global_rank
        # The device is stored as the integer local rank, not a torch.device.
        self.device = local_rank
        self.to(local_rank)
        return local_rank, global_rank