from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments,BitsAndBytesConfig
import torch
import os
import torch.nn.functional as F



class ModelConfig:
    """Bundle of the settings needed to build a ``Model_Factory``.

    Every field is a plain string and is stored exactly as given; no
    validation or path resolution happens here.

    Attributes:
        model_type: identifier for the model family (``Model_Factory``
            compares it against ``"glm4-instruct"``).
        model_path: location handed to ``AutoModelForCausalLM.from_pretrained``.
        tokenizer_path: location handed to ``AutoTokenizer.from_pretrained``.
        device_id: device the loaded model is moved to.
    """

    def __init__(
        self,
        model_type: str = "",
        model_path: str = "",
        tokenizer_path: str = "",
        device_id: str = "",
    ):
        # Where the weights and the tokenizer are loaded from.
        self.model_path = model_path
        self.tokenizer_path = tokenizer_path
        # How the loaded model is identified and where it runs.
        self.model_type = model_type
        self.device_id = device_id
      
        



class Model_Factory:
    """Wrap a causal LM and its tokenizer for speculative generation/verification.

    The wrapper keeps the token ids of the running prompt between calls, so
    successive ``Speculative_generate`` calls extend the same context and
    ``model_verify`` can score candidate continuations against it.
    ``Clear_Cache`` drops that context.
    """

    def __init__(self, model_config: "ModelConfig"):
        """Load the tokenizer and model described by *model_config*.

        Args:
            model_config: supplies ``tokenizer_path``, ``model_path``,
                ``model_type`` and ``device_id``.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_config.tokenizer_path, trust_remote_code=True
        )
        # Reuse EOS as the padding token; batched verification pads inputs.
        self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

        # `load_in_4bit` is deprecated in transformers (quantization goes
        # through `quantization_config` now) and False was the default anyway.
        self.model = AutoModelForCausalLM.from_pretrained(
            model_config.model_path,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
        ).eval()
        self.model_type = model_config.model_type
        self.device = model_config.device_id
        self.model = self.model.to(self.device)
        # Token ids (batch of 1) of the running context; None means "no
        # context yet" (before the first generate call or after Clear_Cache).
        self.__previous_id = None

        # NOTE(review): glm4-instruct keeps its tokenizer's own padding side,
        # presumably because that tokenizer only supports left padding —
        # confirm. get_verify_score handles either padding side.
        if self.model_type != "glm4-instruct":
            self.tokenizer.padding_side = "right"

    # inference generate stage
    def Speculative_generate(self, input_text, each_step_max_generate_token):
        """Extend the running context with *input_text* and sample a continuation.

        On the first call (no stored context), *input_text* is wrapped as a
        single user turn with the tokenizer's chat template. On later calls it
        is tokenized without special tokens and appended to the stored ids.
        The resulting prompt ids, not the sampled tokens, become the new
        stored context.

        Args:
            input_text: prompt text (first call) or text to append.
            each_step_max_generate_token: ``max_new_tokens`` for this step.

        Returns:
            ``(decoded_continuation, model_type)``.
        """
        if self.__previous_id is None:
            messages = [{"role": "user", "content": input_text}]
            input_ids = self.tokenizer.apply_chat_template(
                messages,
                add_generation_prompt=True,
                return_tensors="pt",
            ).to(self.device)
        else:
            token_ids = self.tokenizer.encode(
                input_text, add_special_tokens=False, return_tensors="pt"
            ).to(self.device)
            input_ids = torch.cat((self.__previous_id, token_ids), 1)

        self.__previous_id = input_ids

        attention_mask = torch.ones(
            input_ids.shape, dtype=torch.long, device=self.device
        )
        outputs = self.model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            do_sample=True,
            temperature=0.6,
            top_p=0.9,
            pad_token_id=self.tokenizer.eos_token_id,
            max_new_tokens=each_step_max_generate_token,
        )
        # The output row is prompt + continuation; keep only the new tokens.
        response = outputs[0][input_ids.shape[-1]:]
        return self.tokenizer.decode(response), self.model_type

    # Verify the quality of newly generated text with one batched forward pass.
    @torch.no_grad()
    def model_verify(self, new_generate):
        """Score each candidate continuation by the model's mean token probability.

        Each candidate is appended (after a single space) to the decoded
        stored context. The batch is run through the model in one forward
        pass, and ``get_verify_score`` averages the probability the model
        assigned to each candidate's tokens.

        Args:
            new_generate: sequence of candidate continuation strings.

        Returns:
            ``(scores, model_type)``, where ``scores`` stacks one 0-d score
            per candidate.

        Raises:
            RuntimeError: if no context exists yet (``Speculative_generate``
                has not been called since construction or ``Clear_Cache``).
        """
        if self.__previous_id is None:
            raise RuntimeError(
                "model_verify() needs a context; call Speculative_generate() first"
            )

        # Rebuild the full text: previous context plus each candidate.
        previous_text = self.tokenizer.decode(self.__previous_id[0])
        candidate_texts = [previous_text + " " + text for text in new_generate]

        # The stored context was built by apply_chat_template, which does not
        # add special tokens itself. Any BOS/template tokens are therefore
        # already present in the decoded text. Letting the tokenizer add them
        # again would shift every position and misalign `length_previous_ids`.
        inputs = self.tokenizer(
            candidate_texts,
            padding=True,
            add_special_tokens=False,
            return_tensors="pt",
        ).to(self.device)

        # The logits at this (unpadded) index predict the first candidate token.
        length_previous_ids = self.__previous_id.shape[-1] - 1

        logits = self.model(**inputs).logits

        score = torch.stack([
            self.get_verify_score(length_previous_ids, row_ids, row_logits, row_mask)
            for row_ids, row_logits, row_mask in zip(
                inputs["input_ids"], logits, inputs["attention_mask"]
            )
        ])
        return score, self.model_type

    # Mean probability score of one candidate row.
    def get_verify_score(self, start_position, tokens_id, logits, att_mask):
        """Return the mean probability the model gave to a row's candidate tokens.

        Works for both right- and left-padded rows. Positions are measured
        from the row's first non-padding token, so a left-padded row is
        scored on the same tokens as an unpadded one.

        Args:
            start_position: index, relative to the first real token, of the
                logits row that predicts the first candidate token.
            tokens_id: 1-D token ids of one padded row.
            logits: per-position logits for that row; any leading singleton
                dims are flattened away, leaving ``(seq_len, vocab)``.
            att_mask: 1-D attention mask for that row (non-zero = real token).

        Returns:
            0-d tensor: mean over candidate positions t of
            ``softmax(logits[t])[tokens_id[t + 1]]``. The result is NaN when
            the candidate contributed no tokens.
        """
        logits = logits.reshape(-1, logits.shape[-1])

        real_positions = att_mask.nonzero(as_tuple=True)[0]
        first_real = int(real_positions[0])
        end = int(real_positions[-1]) + 1

        begin = first_real + start_position
        # Logits row t predicts token t + 1, hence the one-step offset.
        step_logits = logits[begin:end - 1]
        targets = tokens_id[begin + 1:end].unsqueeze(-1)
        # Softmax is row-wise, so normalising only the needed rows is exact.
        probs = F.softmax(step_logits, dim=-1)
        return probs.gather(1, targets).squeeze(-1).mean()

    def Clear_Cache(self):
        """Forget the stored context so the next generate call starts a new chat."""
        self.__previous_id = None

    

        




