from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch

class HybridTextGenerator:
    """Greedy text generator that mixes two causal language models.

    At every step the main model's next-token distribution is computed, and
    the probability mass it places on a fixed set of common words is
    re-distributed among those words according to the auxiliary model's
    relative preferences. All other token probabilities come from the main
    model unchanged (up to a final renormalisation).
    """

    def __init__(self, main_model_name, aux_model_name, device="cuda"):
        """Load both models/tokenizers and pre-compute the common-word ids.

        Args:
            main_model_name: Name or path passed to ``from_pretrained`` for
                the main model and tokenizer (loaded in float16).
            aux_model_name: Name or path passed to ``from_pretrained`` for
                the auxiliary model and tokenizer.
            device: Device both models and the id tensors are placed on.
        """
        self.device = device
        self.main_tokenizer = AutoTokenizer.from_pretrained(main_model_name)
        self.main_model = AutoModelForCausalLM.from_pretrained(
            main_model_name,
            device_map={"": device},
            torch_dtype=torch.float16,
        )
        self.aux_tokenizer = AutoTokenizer.from_pretrained(aux_model_name)
        self.aux_model = AutoModelForCausalLM.from_pretrained(aux_model_name).to(self.device)

        self.common_words = self.get_common_words()
        main_ids, aux_ids = self._single_token_id_pairs(
            self.main_tokenizer, self.aux_tokenizer, self.common_words
        )
        if not main_ids:
            import warnings

            # With no aligned words the blend is a no-op; make that visible.
            warnings.warn(
                "No common word is a single token in both tokenizers; "
                "generation will follow the main model only.",
                stacklevel=2,
            )
        # Explicit long dtype so that an empty id list still yields a valid
        # index tensor.
        self.common_ids_main = torch.tensor(main_ids, dtype=torch.long, device=self.device)
        self.common_ids_aux = torch.tensor(aux_ids, dtype=torch.long, device=self.device)

    @staticmethod
    def get_common_words():
        """Return the list of common words (each with a leading space)."""
        return [
            ' the', ' to', ' and', ' of', ' a', ' in', ' that', ' you', ' it', ' for',
            ' on', ' he', ' with', ' this', ' as', ' we', ' but', ' at', ' they', ' what',
            ' his', ' from', ' by', ' or', ' she', ' my', ' all', ' an', ' her', ' about',
            ' me', ' if', ' your', ' can', ' who', ' out', ' their', ' like', ' would', ' when',
            ' him', ' them', ' some', ' how', ' which', ' than', ' our', ' into', ' because', ' these',
            ' over', ' us', ' its', ' where', ' after', ' any', ' those', ' should', ' may', ' through',
            ' why', ' before', ' off', ' while', ' around', ' another', ' both', ' between', ' every', ' each',
            ' might', ' since', ' against', ' without', ' must', ' during', ' under', ' though', ' until', ' whether',
            ' among', ' along', ' within', ' across', ' behind', ' either', ' himself', ' although', ' outside', ' themselves',
            ' is', ' was', ' be', ' have', ' are', ' do', ' had', ' has', ' were', ' will',
            ' did', ' been', ' could', ' does', ' need', ' being', ' am', ' used', ' doing', ' having',
        ]

    @staticmethod
    def _single_token_id_pairs(main_tokenizer, aux_tokenizer, words):
        """Return aligned ``(main_ids, aux_ids)`` lists for usable words.

        A word is kept only if it encodes to exactly one token in *both*
        tokenizers, and only if neither of its ids was already produced by an
        earlier word. Taking the first sub-token of a multi-token word would
        map the word to an unrelated prefix token, and repeated ids would be
        double-counted when probability mass is summed.
        """
        main_ids, aux_ids = [], []
        seen_main, seen_aux = set(), set()
        for word in words:
            main_enc = main_tokenizer.encode(word, add_special_tokens=False)
            aux_enc = aux_tokenizer.encode(word, add_special_tokens=False)
            if len(main_enc) != 1 or len(aux_enc) != 1:
                continue
            main_id, aux_id = main_enc[0], aux_enc[0]
            if main_id in seen_main or aux_id in seen_aux:
                continue
            seen_main.add(main_id)
            seen_aux.add(aux_id)
            main_ids.append(main_id)
            aux_ids.append(aux_id)
        return main_ids, aux_ids

    def _blend_distributions(self, p_main: torch.Tensor, p_aux: torch.Tensor) -> torch.Tensor:
        """Re-distribute the main model's common-word mass using ``p_aux``.

        The total probability the main model gives to the common words is
        kept, but it is split among them in proportion to the auxiliary
        model's probabilities. If the auxiliary model gives the common words
        zero total mass, a copy of ``p_main`` is returned unchanged.

        ``self.common_ids_main`` is expected to hold unique ids (as built by
        ``__init__``); indexed assignment with repeated indices is not
        well-defined.
        """
        aux_common = p_aux[self.common_ids_aux]
        p_main_sum = p_main[self.common_ids_main].sum()
        p_aux_sum = aux_common.sum()

        p_gen = p_main.clone()
        if p_aux_sum.item() == 0:
            return p_gen

        redistributed = p_main_sum * aux_common / p_aux_sum
        # The two models may run in different precisions; indexed assignment
        # needs the source cast to the destination dtype.
        p_gen[self.common_ids_main] = redistributed.to(p_gen.dtype)
        return p_gen / p_gen.sum()

    def _next_token_probs(self, token_ids: torch.Tensor) -> torch.Tensor:
        """Return the blended next-token distribution for a 1-D id tensor."""
        batch = token_ids.unsqueeze(0)
        with torch.no_grad():
            # The auxiliary model has its own tokenizer, so the context is
            # round-tripped through text.
            # NOTE(review): decode() is called with its defaults here, so any
            # special tokens in the context presumably end up as literal text
            # for the aux tokenizer - confirm this is intended.
            curr_text = self.main_tokenizer.decode(batch[0])
            main_logits = self.main_model(batch).logits[0, -1]
            p_main = torch.softmax(main_logits, dim=0)

            aux_inputs = self.aux_tokenizer(curr_text, return_tensors="pt").to(self.device)
            aux_logits = self.aux_model(aux_inputs.input_ids).logits[0, -1]
            p_aux = torch.softmax(aux_logits, dim=0)

            return self._blend_distributions(p_main, p_aux)

    def hybrid_generate(self, text, prefix_len=32, max_new_tokens=48, remove_prefix=False):
        """Greedily continue ``text`` using the blended distribution.

        Args:
            text: Prompt; it is tokenized with the main tokenizer and only
                the first sequence of the resulting batch is used.
            prefix_len: Maximum number of prompt tokens kept as the prefix.
            max_new_tokens: Number of tokens appended (there is no early
                stop on an end-of-sequence token).
            remove_prefix: If true, return only the newly generated text.

        Returns:
            The decoded text (special tokens skipped), with or without the
            prefix depending on ``remove_prefix``.
        """
        inputs = self.main_tokenizer(text, return_tensors="pt", padding=True).to(self.device)
        prefix_tokens = inputs.input_ids[0][:prefix_len]
        # The prompt may be shorter than prefix_len; remember the real length
        # so that remove_prefix never cuts into the generated tokens.
        actual_prefix_len = prefix_tokens.shape[0]
        generated_ids = prefix_tokens.unsqueeze(0)

        for _ in range(max_new_tokens):
            p_gen = self._next_token_probs(generated_ids[0])
            next_token = torch.argmax(p_gen).reshape(1, 1)
            generated_ids = torch.cat([generated_ids, next_token], dim=1)

        kept = generated_ids[0][actual_prefix_len:] if remove_prefix else generated_ids[0]
        return self.main_tokenizer.decode(kept, skip_special_tokens=True)

    def next_token_logits(self, inputs):
        """Return the blended next-token *probabilities* for ``inputs``.

        Despite the name, the result is a normalised probability vector over
        the main model's vocabulary, not raw logits.

        Args:
            inputs: 1-D tensor of main-tokenizer token ids (presumably already
                on ``self.device`` - it is passed to the model as is).
        """
        return self._next_token_probs(inputs)

