import logging
import os
import sys
from typing import Dict, List, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer

sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
from models.base_model import LLMModel
from configs.LLM_configs import LLM_CONFIG

# Named project logger; logging.getLogger returns this same Logger object to
# every module that asks for the "Uncertainty Align" name.
LOGGER = logging.getLogger(name="Uncertainty Align")

class QwenLoraModule(nn.Module):
    """
    LoRA fine-tuning module for Qwen models, used for training and inference.

    Holds the Hugging Face base model (``base_model``), its tokenizer
    (``tokenizer``) and the PEFT-wrapped model (``model``) produced by
    ``get_peft_model``. LoRA hyper-parameters come from the ``"lora_config"``
    entry of the ``"qwen_lora"`` config section when present. Otherwise a
    built-in default is used: r=16, lora_alpha=32, lora_dropout=0.05, applied
    to q_proj/v_proj.
    """

    def __init__(self,
                 model_path: str = None,
                 inference_mode: bool = False
                 ):
        """
        Load the base model and tokenizer and attach a fresh LoRA adapter.

        Args:
            model_path: Path (or hub id) of the base model. When ``None``, the
                config's ``model_path`` is used. If that is also missing, it
                falls back to ``Qwen/Qwen2.5-7B-Instruct`` next to this file.
            inference_mode: Forwarded to ``LoraConfig(inference_mode=...)``.
        """
        super().__init__()

        # .get so a missing "qwen_lora" section falls back to the "qwen"
        # section / built-in defaults instead of raising KeyError.
        config = LLM_CONFIG.get("qwen_lora", LLM_CONFIG.get("qwen", {}))
        default_lora_config = {
            "r": 16,
            "lora_alpha": 32,
            "lora_dropout": 0.05,
            "target_modules": ["q_proj", "v_proj"],
            "task_type": "CAUSAL_LM"
        }

        configured_lora = config.get("lora_config")
        # Copy before adding "inference_mode". The source dict lives in the
        # module-level LLM_CONFIG (or is the default above) and must not be
        # mutated as a side effect of building a module.
        lora_config = dict(default_lora_config if configured_lora is None else configured_lora)
        if model_path is None:
            self.model_path = config.get("model_path", os.path.join(os.path.dirname(__file__), "Qwen", "Qwen2.5-7B-Instruct"))
        else:
            self.model_path = model_path
        self.inference_mode = inference_mode

        LOGGER.info(f"Loading base model from {self.model_path}")

        # "auto": let transformers pick the dtype and device placement.
        self.base_model = AutoModelForCausalLM.from_pretrained(
            self.model_path,
            torch_dtype="auto",
            device_map="auto",
        )
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
        lora_config["inference_mode"] = inference_mode
        self.lora_config = LoraConfig(**lora_config)

        self.model = get_peft_model(self.base_model, self.lora_config)

    def forward(self, **inputs):
        """Forward pass through the PEFT model, used for training."""
        return self.model(**inputs)

    def generate(self, **inputs):
        """Generate text with the PEFT model, used for inference."""
        return self.model.generate(**inputs)

    def load_adapter(self, adapter_path, adapter_name="lora"):
        """
        Load LoRA weights from ``adapter_path`` and make them active.

        The weights are registered as ``adapter_name`` and set as the active
        adapter, and the model is put in eval mode. If ``adapter_path`` does
        not exist, only a warning is logged and the model is left unchanged.
        """
        if os.path.exists(adapter_path):
            LOGGER.info(f"Loading adapter from {adapter_path}")
            self.model.load_adapter(adapter_path, adapter_name=adapter_name)
            self.model.set_adapter(adapter_name)
            self.model.eval()
            LOGGER.info(f"Successfully loaded and activated adapter from {adapter_path}")
        else:
            LOGGER.warning(f"Adapter does not exist: {adapter_path}")

    def save_adapter(self, save_path, adapter_name="lora"):
        """
        Save LoRA weights to ``save_path`` via ``save_pretrained``.

        The parent directory of ``save_path`` is created if it is missing.
        """
        parent_dir = os.path.dirname(save_path)
        # dirname("adapter") == "" and os.makedirs("") raises
        # FileNotFoundError, so only create a parent when there is one.
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        LOGGER.info(f"Saving adapter to {save_path}")
        # NOTE(review): PeftModel.save_pretrained documents `selected_adapters`
        # for choosing which adapters to save. Confirm that `adapter_name` is
        # honoured by the installed peft version and not ignored via **kwargs.
        self.model.save_pretrained(save_path, adapter_name=adapter_name)

    def prepare_inputs_for_generation(self, text):
        """
        Tokenize ``text`` as a single-turn chat prompt on the model's device.

        ``text`` is wrapped in one chat message and the chat template is
        applied with the generation prompt appended. The result is tokenized
        as a batch of one and moved to the model's device.
        """
        # NOTE(review): the prompt is sent with role "system", not "user".
        # Confirm this is intended for the chat template in use.
        messages = [{"role": "system", "content": text}]
        formatted_text = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
        return self.tokenizer([formatted_text], return_tensors="pt", padding=True).to(self.model.device)


class QwenLoraModel(LLMModel):
    """
    LoRA fine-tuned version of Qwen model, providing external interface.

    The underlying ``QwenLoraModule`` (base model + tokenizer + PEFT wrapper)
    is stored on the class. Every ``QwenLoraModel`` instance in the process
    therefore shares one loaded model and one active adapter.
    ``_current_checkpoint`` records the adapter path most recently loaded into
    that shared module. ``call`` reloads the instance's own checkpoint
    whenever it differs.
    """
    # Class-level (shared) state, not per-instance.
    _model_module = None
    _current_checkpoint = None

    def __init__(self,
                 model_name: str,
                 model_path: str = None,
                 lora_checkpoint_path: str = None,
                 checkpoint_name: str = None,
                 temperature: float = 0,
                 max_completion_tokens: int = 1024,
                 top_p: float = 1.0):
        """
        Args:
            model_name: Forwarded to ``LLMModel.__init__``.
            model_path: Base model path. If falsy, the config's ``model_path``
                is used, falling back to ``Qwen/Qwen2.5-7B-Instruct`` next to
                this file.
            lora_checkpoint_path: Directory that contains adapter checkpoints.
                If falsy, the config's ``lora_checkpoint_path`` is used.
            checkpoint_name: Entry inside ``lora_checkpoint_path`` to load.
            temperature: 0 selects greedy decoding. Any other value enables
                sampling at that temperature.
            max_completion_tokens: ``max_new_tokens`` passed to generation.
            top_p: Nucleus-sampling threshold, only used when sampling.
        """
        super().__init__(model_name)
        config = LLM_CONFIG.get("qwen_lora", LLM_CONFIG.get("qwen", {}))
        if model_path:
            self.model_path = model_path
        else:
            self.model_path = config.get('model_path', os.path.join(os.path.dirname(__file__), "Qwen", "Qwen2.5-7B-Instruct"))

        # Set LoRA related parameters
        self.lora_checkpoint_path = lora_checkpoint_path or config.get('lora_checkpoint_path', None)
        self.checkpoint_name = checkpoint_name

        # Set generation parameters
        self.temperature = temperature
        self.max_completion_tokens = max_completion_tokens
        self.top_p = top_p

        # LoRA configuration
        self.lora_config = config.get('lora_config', None)
        if QwenLoraModel._model_module is None:
            self.load_model()
        elif QwenLoraModel._model_module.model_path != self.model_path:
            # The shared module is built only once per process, so a different
            # model_path here would otherwise be ignored without notice.
            LOGGER.warning(
                f"Shared Qwen LoRA model already loaded from {QwenLoraModel._model_module.model_path}; "
                f"ignoring requested model_path {self.model_path}"
            )

    def load_model(self):
        """Build the shared ``QwenLoraModule`` in inference mode."""
        QwenLoraModel._model_module = QwenLoraModule(
            model_path=self.model_path,
            inference_mode=True
        )

    def _checkpoint_path(self) -> str:
        """Return ``lora_checkpoint_path/checkpoint_name``, or raise if either is unset."""
        if self.lora_checkpoint_path is None or self.checkpoint_name is None:
            raise ValueError(
                "LoRA checkpoint is not configured: both lora_checkpoint_path "
                "and checkpoint_name must be set."
            )
        return os.path.join(self.lora_checkpoint_path, self.checkpoint_name)

    def load_checkpoint(self):
        """Load this instance's checkpoint into the shared module unless it is already active."""
        checkpoint_path = self._checkpoint_path()
        if QwenLoraModel._current_checkpoint == checkpoint_path:
            return
        QwenLoraModel._model_module.load_adapter(checkpoint_path)
        QwenLoraModel._model_module.model.eval()
        # load_adapter only logs a warning for a missing path. Record the
        # checkpoint only when it exists, so a later call retries once it has
        # been written instead of assuming it is already loaded.
        if os.path.exists(checkpoint_path):
            QwenLoraModel._current_checkpoint = checkpoint_path

    def set_checkpoint(self, checkpoint_name):
        """Switch to ``checkpoint_name`` (inside ``lora_checkpoint_path``) and load it."""
        self.checkpoint_name = checkpoint_name
        self.load_checkpoint()

    def _generate_response(self, model_module, model_inputs) -> str:
        """Generate with the configured decoding settings and decode only the new tokens of the first prompt."""
        if self.temperature == 0:
            generated_ids = model_module.generate(
                **model_inputs,
                max_new_tokens=self.max_completion_tokens,
                do_sample=False
            )
        else:
            generated_ids = model_module.generate(
                **model_inputs,
                max_new_tokens=self.max_completion_tokens,
                temperature=self.temperature,
                top_p=self.top_p,
                do_sample=True
            )

        # generate() returns prompt + completion; strip the prompt tokens.
        generated_ids = [
            output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
        ]
        return model_module.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

    @staticmethod
    def _top_next_token_probs(model_module, model_inputs, top_k: int) -> Dict[str, float]:
        """
        Return ``{token_text: probability}`` for the ``top_k`` most likely next tokens after the prompt.

        Token text is stripped. Candidates that strip to the same text have
        their probabilities summed. Whitespace-only tokens are reported as
        ``"<space>"`` and empty decodes as ``"<special>"``.
        """
        tokenizer = model_module.tokenizer
        with torch.no_grad():
            # Hidden states are never used, so they are not requested.
            outputs = model_module.model(**model_inputs, return_dict=True)
            # Logits at the last prompt position of the first (only) sequence.
            last_token_logits = outputs.logits[0, -1, :]
            # Softmax in float32. With torch_dtype="auto" the model may run in
            # bf16/fp16, whose probabilities keep only a few significant digits.
            probs = F.softmax(last_token_logits.float(), dim=-1)
            k = min(top_k, probs.shape[-1])
            topk_probs, topk_indices = torch.topk(probs, k=k)

        probs_dict = {}
        for idx, prob in zip(topk_indices.tolist(), topk_probs.cpu().tolist()):
            raw_token = tokenizer.decode([idx])
            token = raw_token.strip()
            if not token:
                token = "<space>" if " " in raw_token else "<special>"
            probs_dict[token] = probs_dict.get(token, 0.0) + float(prob)
        return probs_dict

    def call(self, llm_input: str,
             return_probs: bool = False,
             return_top_num: int = 5,
             do_generation: bool = True,
             ) -> Union[str, Tuple[str, List[Dict[str, float]]]]:
        """
        Run the shared LoRA model on a single prompt.

        Args:
            llm_input: Prompt text. It is wrapped into a chat prompt by
                ``QwenLoraModule.prepare_inputs_for_generation``.
            return_probs: If True, also report the probabilities of the
                ``return_top_num`` most likely next tokens after the prompt.
            return_top_num: Number of next-token candidates to report.
            do_generation: Only consulted when ``return_probs`` is True. If
                False, generation is skipped and the response is ``""``.

        Returns:
            The decoded completion. When ``return_probs`` is True, returns a
            tuple ``(completion, [probs])`` instead, where ``probs`` maps
            stripped token text to probability.

        Raises:
            RuntimeError: If the shared model has not been loaded.
            ValueError: If no LoRA checkpoint is configured.
        """
        model_module = QwenLoraModel._model_module
        if model_module is None:
            raise RuntimeError("Model not loaded.")
        if QwenLoraModel._current_checkpoint != self._checkpoint_path():
            self.load_checkpoint()

        model_inputs = model_module.prepare_inputs_for_generation(llm_input)

        if not return_probs:
            # Regular text generation
            return self._generate_response(model_module, model_inputs)

        probs_dict = self._top_next_token_probs(model_module, model_inputs, return_top_num)
        response = self._generate_response(model_module, model_inputs) if do_generation else ""
        return response, [probs_dict]