from typing import Optional 
import torch 
import torch .nn .functional as F 
from torch .distributed .tensor .parallel import loss_parallel 
from transformers import PreTrainedModel 
from .base import LM ,SampleOutput 


class HFLoss:
    """
    Compute the next-token-prediction loss for an autoregressive model.

    The model is fed ``batch["input_ids"][:, :-1]``. Targets are
    ``batch["labels"][:, 1:]`` when the batch carries explicit labels, otherwise
    the input ids shifted left by one. Target positions equal to the model's
    ``pad_token_id`` are excluded from the loss.
    """

    # Target value skipped by cross-entropy (also F.cross_entropy's default).
    ignore_index: int = -100

    def __call__(self, model: "PreTrainedModel", batch: dict) -> dict:
        """
        Args:
            model: Causal LM. Calling it with a tensor of ids must return an
                object with ``.logits``. ``model.config.pad_token_id`` is read
                to mask padding targets.
            batch: Mapping with ``"input_ids"`` (indexed as ``[:, ...]``) and an
                optional ``"labels"`` of the same layout. Neither the mapping
                nor its tensors are modified.

        Returns:
            dict with ``loss`` (mean cross-entropy over non-ignored targets)
            and ``perplexity`` (``exp(loss)``).
        """
        full_ids = batch["input_ids"]
        full_labels = batch.get("labels")
        if full_labels is None:
            # No explicit labels: predict the input sequence itself.
            full_labels = full_ids

        input_ids = full_ids[:, :-1].contiguous()
        labels = full_labels[:, 1:]

        pad_token_id = model.config.pad_token_id
        if pad_token_id is not None:
            # masked_fill is out-of-place. An in-place write here could alias
            # the caller's tensor, because a slice of a batch of size 1 is
            # already contiguous.
            labels = labels.masked_fill(labels == pad_token_id, self.ignore_index)

        logits = model(input_ids).logits

        loss = F.cross_entropy(
            logits.reshape(-1, logits.size(-1)),
            labels.reshape(-1),
            ignore_index=self.ignore_index,
        )
        perplexity = torch.exp(loss)
        return dict(loss=loss, perplexity=perplexity)


class Autoregressive(LM):
    """
    Autoregressive language-model wrapper.

    Uses :class:`HFLoss` as ``loss_fn`` when no ``loss_config`` is supplied.
    Samples sequences with ``self.model.generate`` starting from a single BOS
    token.
    """

    # NOTE(review): presumably consulted by the trainer to decide whether the
    # loss may run under torch's ``loss_parallel`` context — confirm in LM.
    support_loss_parallel: bool = True

    def __init__(
        self,
        loss_config: Optional[dict] = None,
        generation_config: Optional[dict] = None,
        **kwargs,
    ):
        """
        Args:
            loss_config: Forwarded to ``LM.__init__``. When ``None``,
                ``HFLoss`` is installed as ``self.loss_fn``.
            generation_config: Keyword arguments for ``self.model.generate``.
                When ``None``, greedy defaults derived from the model config
                are built on the first call to :meth:`sample`.
            **kwargs: Forwarded to ``LM.__init__``.
        """
        super().__init__(loss_config=loss_config, **kwargs)
        if loss_config is None:
            self.loss_fn = HFLoss()
        self.generation_config = generation_config

    def sample(self, num_samples, **kwargs) -> SampleOutput:
        """
        Generate ``num_samples`` sequences, each prompted with only BOS.

        Args:
            num_samples: Number of sequences to generate.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            SampleOutput with the generated token ids (``output_ids``), their
            decoded strings (``texts``), and the per-step logits stacked along
            dim 1 (``logits``).
        """
        if self.generation_config is None:
            # Greedy decoding up to the model's maximum context length.
            config = self.model.config
            self.generation_config = {
                "do_sample": False,
                "num_beams": 1,
                "bos_token_id": config.bos_token_id,
                "eos_token_id": config.eos_token_id,
                "pad_token_id": config.pad_token_id,
                "max_length": config.max_position_embeddings,
            }

        # Per-call overrides are merged into a fresh dict so the stored (and
        # possibly caller-owned) config is never mutated. A structured output
        # is required to read ``sequences`` and ``logits`` below.
        generate_kwargs = {
            **self.generation_config,
            "return_dict_in_generate": True,
            "output_logits": True,
        }

        inputs = torch.full(
            (num_samples, 1),
            self.model.config.bos_token_id,
            dtype=torch.long,
            device=self.device,
        )
        outputs = self.model.generate(inputs, **generate_kwargs)

        # ``outputs.logits`` holds one (batch, vocab) tensor per generated step.
        # Stacking on dim 1 yields (batch, steps, vocab). The former
        # ``torch.cat(...).permute(1, 0, 2)`` produced a 2-D tensor and failed.
        logits = torch.stack(outputs.logits, dim=1)
        sequences = outputs.sequences
        texts = self.tokenizer.batch_decode(sequences)
        return SampleOutput(output_ids=sequences, texts=texts, logits=logits)

    def compute_metrics(self, batch, loss_dict):
        """
        Convert the validation loss from nats to bits.

        Assumes ``loss_dict["val/loss"]`` is a natural-log cross-entropy tensor
        (as produced by ``HFLoss``). The result is reported under ``"bpc"``. It
        equals bits per character only if each token is a single character —
        NOTE(review): confirm the tokenizer granularity.

        Args:
            batch: Unused.
            loss_dict: Mapping containing ``"val/loss"``.
        """
        loss = loss_dict["val/loss"]
        bpc = loss / torch.log(torch.tensor(2.0, device=loss.device))
        return {"bpc": bpc}
