import os
import sys
import torch
import torch.nn.functional as F
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
from transformers import AutoModelForCausalLM, AutoTokenizer
from configs.LLM_configs import LLM_CONFIG

import logging

sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
from models.base_model import LLMModel

# Module-level logger; messages from this file are emitted under the
# "Uncertainty Align" logger name.
LOGGER = logging.getLogger("Uncertainty Align")

class QWENModel(LLMModel):
    """Chat wrapper around a locally loaded Qwen causal language model.

    The model and tokenizer are stored on the class rather than on the
    instance, so they are loaded once per process and shared by every
    ``QWENModel`` instance. Sampling settings (``temperature``, ``top_p``,
    ``max_completion_tokens``) are per instance.
    """

    # Process-wide cache shared by all instances; populated by load_model().
    _model = None
    _tokenizer = None
    # Path the cached model/tokenizer were loaded from (None until loaded).
    _loaded_model_path = None

    def __init__(self,
                 model_name: str,
                 model_path: "str | None" = None,
                 temperature: float = 1.0,
                 max_completion_tokens: int = 1024,
                 top_p: float = 1.0
                 ):
        """
        Configure the wrapper and load the shared model if it is not loaded yet.

        Args:
            model_name: Name forwarded to the ``LLMModel`` base constructor.
            model_path: Directory or identifier passed to ``from_pretrained``.
                When empty, ``LLM_CONFIG["qwen"]["model_path"]`` is used, and
                failing that a ``Qwen/Qwen2.5-7B-Instruct`` directory next to
                this file.
            temperature: Sampling temperature; ``0`` selects greedy decoding.
            max_completion_tokens: Upper bound on newly generated tokens.
            top_p: Nucleus sampling threshold used when ``temperature != 0``.
        """
        super().__init__(model_name)

        config = LLM_CONFIG["qwen"]
        if model_path:
            self.model_path = model_path
        else:
            bundled_path = os.path.join(os.path.dirname(__file__), "Qwen", "Qwen2.5-7B-Instruct")
            self.model_path = config.get('model_path', bundled_path)
        self.temperature = temperature
        self.max_completion_tokens = max_completion_tokens
        self.top_p = top_p

        if QWENModel._model is None or QWENModel._tokenizer is None:
            self.load_model()
        elif QWENModel._loaded_model_path not in (None, self.model_path):
            # The cache is class-wide, so a different path on a later instance
            # does NOT trigger a reload; make that visible instead of silent.
            LOGGER.warning(
                "A model is already loaded from %s; reusing it and ignoring model_path %s",
                QWENModel._loaded_model_path,
                self.model_path,
            )

    def load_model(self):
        """Load the model and tokenizer from ``self.model_path`` into the class-level cache."""
        LOGGER.info("Loading model from %s", self.model_path)
        print(f"Loading model from {self.model_path}")
        QWENModel._model = AutoModelForCausalLM.from_pretrained(
            self.model_path,
            torch_dtype="auto",
            device_map="auto",
        )
        QWENModel._tokenizer = AutoTokenizer.from_pretrained(self.model_path)
        QWENModel._loaded_model_path = self.model_path

    def call(self, llm_input: str,
             return_probs: bool = False,
             return_top_num: int = 5,
             do_generation: bool = True,
             ) -> "str | tuple[str, list[dict[str, float]]]":
        """
        Call Qwen model to get response

        Args:
            llm_input: Input text
            return_probs: Whether to return probability distribution
            return_top_num: Return top N tokens with highest probability
                (capped at the vocabulary size)
            do_generation: Only consulted when ``return_probs`` is True; when
                False, generation is skipped and the returned text is ``""``

        Returns:
            If return_probs is False, return generated text
            If return_probs is True, return (generated_text, probability_distribution) tuple,
            where probability_distribution is a one-element list holding a
            ``{token_text: probability}`` dict for the first generated position
        """
        assert self._model is not None and self._tokenizer is not None, "Model not loaded."

        model_inputs = self._build_inputs(llm_input)

        if not return_probs:
            return self._generate(model_inputs)

        probs_dict = self._next_token_probs(model_inputs, return_top_num)
        response = self._generate(model_inputs) if do_generation else ""
        return response, [probs_dict]

    def _build_inputs(self, llm_input: str):
        """Render ``llm_input`` through the chat template and tokenize it onto the model's device."""
        # NOTE(review): the prompt is sent with the "system" role and no user
        # turn -- confirm this is intended rather than role "user".
        messages = [
            {"role": "system", "content": llm_input}
        ]
        text = self._tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
        return self._tokenizer([text], return_tensors="pt", padding=True).to(self._model.device)

    def _next_token_probs(self, model_inputs, top_n: int) -> dict:
        """Return ``{token_text: probability}`` for the ``top_n`` most likely next tokens."""
        with torch.no_grad():
            # Only the logits are needed, so hidden states are not requested.
            logits = self._model(**model_inputs, return_dict=True).logits

        # Upcast before softmax: with torch_dtype="auto" the logits may be
        # bf16/fp16, which would noticeably coarsen the probabilities.
        probs = F.softmax(logits[0, -1, :].float(), dim=-1)

        # torch.topk raises if k exceeds the vocabulary size.
        k = min(top_n, probs.shape[-1])
        topk_probs, topk_indices = torch.topk(probs, k=k)

        probs_dict = {}
        for token_id, prob in zip(topk_indices.tolist(), topk_probs.tolist()):
            raw_token = self._tokenizer.decode([token_id])
            token = raw_token.strip()
            if not token:
                # Whitespace-only / empty decodings get a readable placeholder.
                token = "<space>" if " " in raw_token else "<special>"
            # Distinct ids can strip to the same text; merge their mass.
            probs_dict[token] = probs_dict.get(token, 0.0) + prob
        return probs_dict

    def _generate(self, model_inputs) -> str:
        """Generate a completion for ``model_inputs`` and return only the newly generated text."""
        if self.temperature == 0:
            # Greedy decoding; the None overrides clear sampling parameters
            # that the checkpoint's generation config may carry.
            decoding_kwargs = {
                "do_sample": False,
                "temperature": None,
                "top_p": None,
                "top_k": None,
            }
        else:
            # temperature/top_p only take effect when sampling is enabled, so
            # request it explicitly instead of relying on the checkpoint default.
            decoding_kwargs = {
                "do_sample": True,
                "temperature": self.temperature,
                "top_p": self.top_p,
            }

        generated_ids = self._model.generate(
            **model_inputs,
            max_new_tokens=self.max_completion_tokens,
            **decoding_kwargs
        )

        # generate() returns prompt + completion; drop the prompt tokens.
        prompt_length = len(model_inputs.input_ids[0])
        completion_ids = generated_ids[0][prompt_length:]
        return self._tokenizer.decode(completion_ids, skip_special_tokens=True)


