from transformers import GenerationConfig
import torch


class Generator:
    """Helpers for scoring sequences and sampling completions with a causal LM.

    The instance holds a tokenizer and a ``GenerationConfig`` built from the
    supplied settings. The model itself is passed to each method.
    """

    def __init__(self, config, tokenizer):
        """Build the generation config from ``config`` and the tokenizer's ids.

        Args:
            config: Mutable mapping of generation settings. It is stored by
                reference and updated IN PLACE with ``eos_token_id`` and
                ``pad_token_id`` before being unpacked as keyword arguments
                into ``GenerationConfig``.
            tokenizer: Object exposing ``eos_token_id``, ``pad_token_id`` and
                ``bos_token_id`` attributes.
        """
        self.config = config
        self.tokenizer = tokenizer

        # The tokenizer is the source of truth for the EOS / PAD ids.
        self.config["eos_token_id"] = self.tokenizer.eos_token_id
        self.config["pad_token_id"] = self.tokenizer.pad_token_id
        # NOTE(review): an override of bos_token_id (to tokenizer.eos_token_id)
        # used to sit here but was disabled; bos_token_id is only set if the
        # incoming `config` already carries it.

        # Debug output of the special-token ids in use.
        print(self.tokenizer.eos_token_id, self.tokenizer.pad_token_id, self.tokenizer.bos_token_id)

        self.generation_config = GenerationConfig(**self.config)

    def get_per_token_logp(
        self,
        model,
        batch,
        labels,
        batch_size=None,
        prompt_length=0,
        use_no_grad=True,
        detach=False,
    ):
        """Return the log-probability the model assigns to each label token.

        Args:
            model: Callable forwarded to :meth:`get_logits`.
            batch: Mapping with ``"input_ids"`` and ``"attention_mask"``.
            labels: Integer (int64) tensor of token ids, one per position
                returned by :meth:`get_logits`, i.e. shaped
                ``(batch, sequence_length - prompt_length - 1)``. The caller
                is responsible for aligning labels with those positions.
            batch_size: Optional chunk size for the forward pass.
            prompt_length: Number of leading positions to skip.
            use_no_grad: Run the forward pass under ``torch.no_grad()``
                (note: defaults to ``True`` here, ``False`` in
                :meth:`get_logits`).
            detach: Detach the result from the autograd graph.

        Returns:
            Tensor with the same shape as ``labels`` holding
            ``log p(label | context)`` for every position.
        """
        # (batch, positions, vocab)
        model_logits = self.get_logits(
            model,
            batch,
            batch_size=batch_size,
            prompt_length=prompt_length,
            use_no_grad=use_no_grad,
        )
        # Normalise over the vocabulary axis; shape is unchanged.
        model_logp = torch.log_softmax(model_logits, dim=-1)
        # Select each label's log-prob: (batch, positions, 1) -> (batch, positions).
        # torch.gather accepts an index that is smaller than the input along
        # the non-gathered axes, so a too-short `labels` is NOT detected here.
        model_per_token_logp = torch.gather(
            model_logp,
            dim=-1,
            index=labels.unsqueeze(-1),
        ).squeeze(-1)
        if detach:
            model_per_token_logp = model_per_token_logp.detach()

        return model_per_token_logp

    def get_logits(
        self, model, batch, batch_size=None, prompt_length=0, use_no_grad=False
    ):
        """Run ``model`` over ``batch`` in chunks and return next-token logits.

        Args:
            model: Callable accepting ``input_ids`` and ``attention_mask``
                keyword arguments and returning an object with a ``logits``
                attribute, expected to be ``(chunk, sequence_length, vocab)``.
            batch: Mapping with ``"input_ids"`` and ``"attention_mask"``
                tensors of identical shape.
            batch_size: Rows per forward pass, to reduce peak memory.
                ``None`` (or ``0``) processes all rows in a single pass.
            prompt_length: Leading positions to drop from the output.
            use_no_grad: If true the forward pass runs under
                ``torch.no_grad()``; otherwise under ``torch.enable_grad()``,
                which turns gradients on even inside an outer no-grad scope.

        Returns:
            Tensor of logits for positions ``prompt_length .. sequence_length - 2``
            (the final position is always dropped), concatenated over chunks
            along the batch axis.

        Raises:
            ValueError: If the two input tensors differ in shape, the batch
                is empty, or ``batch_size`` is negative.
        """
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]

        if input_ids.shape != attention_mask.shape:
            raise ValueError(
                "Input ids and attention mask shape mismatch; tensor must have the same shape."
            )

        num_rows = input_ids.size(0)
        if num_rows == 0:
            # Without this guard the failure surfaces as an opaque range() /
            # torch.cat() error further down.
            raise ValueError("Cannot compute logits for an empty batch.")
        if batch_size is not None and batch_size < 0:
            # A negative step would make the chunk loop silently empty.
            raise ValueError(f"batch_size must be positive, got {batch_size}.")

        chunk_size = batch_size or num_rows
        # Loop-invariant: choose the grad-mode context manager once and
        # instantiate a fresh one per chunk.
        grad_mode = torch.no_grad if use_no_grad else torch.enable_grad

        all_logits = []
        for start in range(0, num_rows, chunk_size):
            stop = start + chunk_size
            with grad_mode():
                model_output = model(
                    input_ids=input_ids[start:stop],
                    attention_mask=attention_mask[start:stop],
                )
            # Drop the prompt prefix and the last position (its prediction has
            # no following token in the input). prompt_length == 0 keeps the
            # whole prefix, so no special case is needed.
            all_logits.append(model_output.logits[:, prompt_length:-1, :])

        return torch.cat(all_logits, dim=0)

    def sample(self, model, batch):
        """Generate sequences for ``batch`` using the stored generation config.

        Args:
            model: Object exposing ``device`` and a ``generate`` method with
                the Hugging Face ``generate`` keyword interface.
            batch: Mapping with ``"input_ids"`` and ``"attention_mask"``
                tensors. They are detached and cloned, so the caller's
                tensors are left untouched.

        Returns:
            The ``sequences`` attribute of the ``generate`` output.
        """
        input_ids = batch["input_ids"].detach().clone()
        attention_mask = batch["attention_mask"].detach().clone()

        # NOTE(review): repetition_penalty / logits_processor / logits_warper
        # are passed explicitly as None; how `generate` treats these explicit
        # None values depends on the installed transformers version - confirm.
        model_output = model.generate(
            input_ids=input_ids.to(model.device),
            attention_mask=attention_mask.to(model.device),
            generation_config=self.generation_config,
            return_dict_in_generate=True,
            output_scores=False,
            output_logits=False,
            repetition_penalty=None,
            logits_processor=None,
            logits_warper=None,
        )

        return model_output.sequences
