from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import torch.nn.functional as F

class LabelModel:
    """Pairwise labeller that reads a causal LM's next-token logits.

    For a pair of candidates the model is prompted twice, once per
    presentation order. Each time the logits of two designated answer
    tokens (the "zero" and "one" tokens, see ``set_constrained``) at the
    next-token position are turned into a 2-way probability. The two
    results are averaged, with the second one flipped so that both refer
    to the same candidate, and the token with the higher average is
    returned.
    """

    # System message placed in front of every chat prompt. Can be replaced
    # per instance through the ``system_prompt`` constructor argument.
    system_prompt = "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."

    def __init__(self, model_name = "meta-llama/Meta-Llama-3.1-8B-Instruct", device = 'cuda:1', use_auto = False, cache_dir = 'cache/', system_prompt = None) -> None:
        """Load the model and tokenizer.

        Args:
            model_name: Identifier passed to ``from_pretrained``.
            device: Device the model is moved to when ``use_auto`` is false.
            use_auto: When true, load with ``device_map='auto'`` instead of
                moving the model to ``device``.
            cache_dir: Cache directory passed to ``from_pretrained``.
            system_prompt: Optional replacement for the class-level default
                system message. ``None`` keeps the default.
        """
        self.model_name = model_name
        if use_auto:
            self.model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map = 'auto', cache_dir = cache_dir)
        else:
            self.model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, cache_dir = cache_dir).to(device)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir = cache_dir)
        self.device = device
        if system_prompt is not None:
            self.system_prompt = system_prompt

    def set_constrained(self, zero = '1', one = '2'):
        """Choose the two answer tokens whose logits are compared.

        Args:
            zero: Token string reported when index 0 of the averaged
                probabilities wins.
            one: Token string reported otherwise.

        Raises:
            ValueError: If either string is not a known single token of the
                tokenizer. The previous configuration is left untouched.
        """
        # Resolve both ids before assigning anything so a failure cannot
        # leave the instance half-configured.
        zero_id = self._lookup_token_id(zero)
        one_id = self._lookup_token_id(one)
        self.zero_token = zero
        self.one_token = one
        self.zero_token_id = zero_id
        self.one_token_id = one_id

    def _lookup_token_id(self, token):
        """Return the vocabulary id of ``token`` or raise ``ValueError``."""
        token_id = self.tokenizer.convert_tokens_to_ids(token)
        unk_id = getattr(self.tokenizer, 'unk_token_id', None)
        unk_token = getattr(self.tokenizer, 'unk_token', None)
        # An unknown token comes back as None or as the unk id, depending
        # on the tokenizer; either would make the logit lookup meaningless.
        is_unknown = token_id is None or (
            unk_id is not None and token_id == unk_id and token != unk_token
        )
        if is_unknown:
            raise ValueError(
                f"{token!r} is not a single token in the tokenizer vocabulary"
            )
        return token_id

    def _ensure_constrained(self):
        """Apply the default answer tokens if ``set_constrained`` was never called."""
        if not hasattr(self, 'zero_token_id'):
            self.set_constrained()

    def _encode_chat(self, prompt):
        """Wrap ``prompt`` in a system+user chat and return its input ids on the model device."""
        message = [
            {"role": "system", "content": self.system_prompt},
            {"role": "user", "content": prompt},
        ]
        return self.tokenizer.apply_chat_template(
            message,
            add_generation_prompt=True,
            return_tensors="pt"
        ).to(self.model.device)

    def __call__(self, context, first, second, prompt_template, reprompt_template = None, **kwargs):
        """Label a pair of candidates, averaging over both presentation orders.

        Args:
            context: Passed through to ``prompt_template``.
            first: First candidate.
            second: Second candidate.
            prompt_template: Callable ``(context, a, b) -> str`` building the
                prompt. It is called once as ``(context, first, second)``
                and once as ``(context, second, first)``.
            reprompt_template: Optional callable turning generated text into
                a follow-up prompt; when given, scoring goes through
                ``generate_reprompt`` instead of ``generate``.
            **kwargs: Forwarded to ``generate_reprompt`` (unused otherwise).

        Returns:
            ``self.zero_token`` if the averaged probability at index 0 is
            greater than or equal to the one at index 1, else
            ``self.one_token``.
        """
        self._ensure_constrained()

        def score(prompt):
            if reprompt_template is not None:
                return self.generate_reprompt(prompt, reprompt_template, **kwargs)
            return self.generate(prompt)

        forward = score(prompt_template(context, first, second))
        # In the swapped prompt the candidates trade places, so flip the
        # pair to line it up with the first ordering before averaging.
        swapped = torch.flip(score(prompt_template(context, second, first)), dims = [0])

        probs = (forward + swapped) / 2
        if probs[0] >= probs[1]:
            return self.zero_token
        return self.one_token

    def generate_reprompt(self, prompt, reprompt_template, **kwargs):
        """Generate free text for ``prompt``, then score the follow-up prompt.

        Args:
            prompt: User prompt for the first (free-form) generation.
            reprompt_template: Callable mapping the decoded text to the
                prompt that is scored by ``generate``.
            **kwargs: Forwarded to ``self.model.generate``.

        Returns:
            The two-element probability tensor produced by ``generate``.
        """
        input_ids = self._encode_chat(prompt)

        output_ids = self.model.generate(input_ids, eos_token_id=self.tokenizer.eos_token_id, **kwargs)
        # NOTE(review): the whole sequence is decoded, so the text handed to
        # reprompt_template includes the prompt as well as the newly
        # generated part - confirm this is intended.
        generated_text = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)

        return self.generate(reprompt_template(generated_text))

    def generate(self, prompt):
        """Return the 2-way probability of the zero/one tokens for ``prompt``.

        The model is run once on the chat-formatted prompt; the logits at
        the last position for the two answer tokens are soft-maxed against
        each other.

        Returns:
            A float32 CPU tensor ``[p_zero, p_one]`` that sums to 1.
        """
        self._ensure_constrained()
        inputs = self._encode_chat(prompt)

        with torch.no_grad():
            logits = self.model(inputs)['logits'][:, -1, :]

        pair = logits[0, [self.zero_token_id, self.one_token_id]]
        # Upcast and move to CPU so the result does not depend on the
        # model's dtype or device.
        return F.softmax(pair.float().cpu(), dim=0)


class TaskModel:
    """Thin text-generation wrapper around a causal language model."""

    def __init__(self, model_name = "meta-llama/Meta-Llama-3.1-8B-Instruct", device = 'cuda:1', cache_dir = 'cache/', use_auto = False) -> None:
        """Load the model and tokenizer.

        Args:
            model_name: Identifier passed to ``from_pretrained``.
            device: Device the model is moved to when ``use_auto`` is false.
            cache_dir: Cache directory passed to ``from_pretrained``.
            use_auto: When true, load with ``device_map='auto'`` instead of
                moving the model to ``device``.
        """
        self.model_name = model_name

        load_options = {'torch_dtype': torch.bfloat16, 'cache_dir': cache_dir}
        if use_auto:
            load_options['device_map'] = 'auto'
        loaded = AutoModelForCausalLM.from_pretrained(model_name, **load_options)
        # Explicit placement is only needed when no device map was used.
        self.model = loaded if use_auto else loaded.to(device)

        self.tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir = cache_dir)
        self.device = device

    def __call__(self, prompt, **kwargs):
        """Generate a completion for ``prompt`` and return only the new text.

        When the model name contains "instruct" (case-insensitive) the
        prompt is wrapped as a single user chat message; otherwise it is
        tokenized as plain text.

        Args:
            prompt: Input text.
            **kwargs: Forwarded to ``self.model.generate``.

        Returns:
            The decoded tokens that follow the prompt, with special tokens
            skipped.
        """
        target = self.model.device
        is_instruct = 'instruct' in self.model_name.lower()

        if not is_instruct:
            encoded = self.tokenizer(prompt, return_tensors="pt")
            input_ids = encoded['input_ids'].to(target)
        else:
            chat = [{"role": "user", "content": prompt}]
            input_ids = self.tokenizer.apply_chat_template(
                chat,
                add_generation_prompt=True,
                return_tensors="pt"
            ).to(target)

        prompt_length = input_ids.shape[1]
        generated = self.model.generate(input_ids, eos_token_id=self.tokenizer.eos_token_id, **kwargs)
        # Drop the echoed prompt so only the continuation is decoded.
        completion_ids = generated[0, prompt_length:]
        return self.tokenizer.decode(completion_ids, skip_special_tokens=True)
    
    
        
    