import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.configuration_utils import PretrainedConfig
from transformers.models.opt.modeling_opt import OPTForCausalLM
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import replace_return_docstrings
from typing import List, Optional, Tuple, Union
from torch.nn import CrossEntropyLoss

from lavis.models.blip2_models.modeling_llama import LlamaPreTrainedModel, LlamaForCausalLM, LLAMA_INPUTS_DOCSTRING, _CONFIG_FOR_DOC
from transformers.utils import add_start_docstrings_to_model_forward
import loralib as lora

def apply_lora_to_llama(model, lora_r=16, target_modules=None):
    """Replace selected ``nn.Linear`` sub-modules of ``model`` with ``loralib`` LoRA layers.

    Every ``nn.Linear`` whose qualified module name contains one of the strings
    in ``target_modules`` is swapped, in place, for a ``lora.Linear`` of the
    same shape. The original weight (and bias, if any) is copied into the new
    layer. Afterwards ``lora.mark_only_lora_as_trainable`` is applied to the
    whole model, and the replaced layers and trainable parameters are printed.

    Args:
        model: The module tree to modify. It is mutated in place.
        lora_r: LoRA rank passed to ``lora.Linear`` as ``r``.
        target_modules: Substrings to look for in qualified module names.
            Defaults to ``["q_proj", "k_proj", "v_proj"]``.

    Returns:
        The same ``model`` object, after modification.
    """
    if target_modules is None:
        target_modules = ["q_proj", "k_proj", "v_proj"]

    # Collect the layers to replace *before* touching the module tree, so the
    # replacement below never mutates a container that is still being iterated.
    # NOTE: matching is a substring test against the full dotted name, not an
    # exact match on the last path component.
    candidates = [
        (name, module)
        for name, module in list(model.named_modules())
        if isinstance(module, nn.Linear)
        and any(target_name in name for target_name in target_modules)
    ]

    replaced = []
    for name, module in candidates:
        # Split "a.b.c" into parent path "a.b" and attribute "c"; a name
        # without dots yields an empty parent path, i.e. the root model.
        parent_name, _, attr_name = name.rpartition(".")
        parent_module = model.get_submodule(parent_name) if parent_name else model

        has_bias = module.bias is not None
        lora_layer = lora.Linear(
            module.in_features,
            module.out_features,
            r=lora_r,
            bias=has_bias,
        )

        # Carry over the pretrained parameters. Only weight/bias are copied;
        # the LoRA factors keep whatever dtype/device loralib created them with.
        # NOTE(review): confirm callers run under autocast or cast/move the
        # model afterwards when the base weights are not float32 on CPU.
        lora_layer.weight.data = module.weight.data.clone()
        if has_bias:
            lora_layer.bias.data = module.bias.data.clone()

        setattr(parent_module, attr_name, lora_layer)
        replaced.append(name)

    lora.mark_only_lora_as_trainable(model)

    print("Replaced layers:")
    for name in replaced:
        print("  -", name)

    print("\nTrainable parameters:")
    for name, param in model.named_parameters():
        if param.requires_grad:
            print("  -", name)
    return model

class LoraVicuna(LlamaPreTrainedModel):
    """LLaMA causal language model whose backbone is fine-tuned through LoRA adapters.

    The backbone is loaded with ``LlamaForCausalLM.from_pretrained`` in
    ``torch.float16``, passed through ``apply_lora_to_llama``, and then every
    backbone parameter whose name does not contain ``"lora_"`` is frozen.
    """

    def __init__(self, llm_model, config, lora_r=4, target_modules=None):
        """Load the backbone, inject LoRA layers and freeze everything else.

        Args:
            llm_model: Identifier/path forwarded to ``LlamaForCausalLM.from_pretrained``.
            config: Model config, used for both this wrapper and the backbone.
            lora_r: LoRA rank forwarded to ``apply_lora_to_llama``.
            target_modules: Module-name substrings forwarded to
                ``apply_lora_to_llama`` (``None`` selects its defaults).
        """
        super().__init__(config)
        self.backbone = LlamaForCausalLM.from_pretrained(
            llm_model, torch_dtype=torch.float16, config=config
        )
        self.lora_r = lora_r
        self.target_modules = target_modules

        self.backbone = apply_lora_to_llama(
            self.backbone,
            lora_r=lora_r,
            target_modules=target_modules,
        )

        # Make the training split explicit: only parameters with "lora_" in
        # their name receive gradients.
        for name, param in self.backbone.named_parameters():
            param.requires_grad = "lora_" in name

    def print_trainable_parameters(self):
        """Print the trainable / total parameter counts of the backbone and their ratio."""
        trainable_params = 0
        all_param = 0
        for _, param in self.backbone.named_parameters():
            all_param += param.numel()
            if param.requires_grad:
                trainable_params += param.numel()
        print(f"trainable params in llm: {trainable_params:,} || all params: {all_param:,} || trainable%: {100 * trainable_params / all_param:.2f}")

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        reduction: Optional[str] = "mean",
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run the backbone and, when ``labels`` is given, compute the next-token loss.

        ``output_attentions``, ``output_hidden_states`` and ``return_dict``
        fall back to the corresponding ``self.config`` values when ``None``.

        Args:
            labels: Targets for the language-modelling loss. Logits are shifted
                so that position ``i`` is scored against label ``i + 1``.
            reduction: Passed to ``CrossEntropyLoss``. With ``"none"`` the
                per-token losses are reshaped to one row per batch element
                and averaged along that row.

        Returns:
            A ``CausalLMOutputWithPast`` when ``return_dict`` is truthy,
            otherwise a tuple ``(loss?, logits, *backbone_outputs[1:])`` where
            ``loss`` is present only if ``labels`` was provided.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.backbone.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        # The logits are used as produced by the LM head. A previous revision
        # multiplied them by an all-ones tensor of the same shape, which left
        # the values unchanged but allocated a second logits-sized buffer.
        logits = self.backbone.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Predict token t+1 from position t: drop the last logit and the
            # first label so the two line up.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = CrossEntropyLoss(reduction=reduction)
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)
            if reduction == "none":
                # One scalar per batch element. NOTE: this is a plain mean over
                # all shifted positions, so positions skipped by the loss's
                # ignore_index still count in the denominator.
                loss = loss.view(logits.size(0), -1).mean(1)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        """Assemble the keyword arguments for one generation step.

        When a cache is present only the last token of ``input_ids`` is kept.
        If ``position_ids`` is not supplied via ``kwargs`` it is derived from
        ``attention_mask``. ``inputs_embeds`` is used only when no cache
        exists; otherwise ``input_ids`` is passed.

        Returns:
            A dict with ``input_ids`` or ``inputs_embeds`` plus
            ``position_ids``, ``past_key_values``, ``use_cache`` and
            ``attention_mask``.
        """
        if past_key_values:
            input_ids = input_ids[:, -1:]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # Running count of attended tokens gives each token its position;
            # masked-out slots are filled with 1.
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -1].unsqueeze(-1)

        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
            }
        )
        return model_inputs

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        """Re-index every cached tensor along dim 0 with ``beam_idx``.

        Returns a tuple with one tuple of reordered tensors per entry of
        ``past_key_values``.
        """
        return tuple(
            tuple(past_state.index_select(0, beam_idx) for past_state in layer_past)
            for layer_past in past_key_values
        )

    def save_lora_weights(self, save_path):
        """Save the trainable LoRA parameters of the backbone to ``save_path``.

        Only parameters whose name contains ``"lora_"`` and that currently
        require gradients are written. Tensors are stored as detached CPU
        copies so the file does not depend on the device the model is on.
        """
        lora_state_dict = {
            name: param.detach().cpu()
            for name, param in self.backbone.named_parameters()
            if "lora_" in name and param.requires_grad
        }
        torch.save(lora_state_dict, save_path)
        print(f"LoRA weights saved to {save_path}")

    def load_lora_weights(self, load_path):
        """Load LoRA parameters from ``load_path`` into the backbone.

        The checkpoint is loaded non-strictly, so keys absent from the file
        are left untouched. Missing LoRA keys and unexpected keys are printed.
        """
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from trusted sources.
        lora_state_dict = torch.load(load_path, map_location="cpu")
        missing_keys, unexpected_keys = self.backbone.load_state_dict(lora_state_dict, strict=False)
        print(f"LoRA weights loaded from {load_path}")

        # A LoRA-only checkpoint never contains the frozen base weights, so
        # those always appear as "missing". Report only the adapter keys,
        # which are the ones that indicate a real mismatch.
        missing_keys = [key for key in missing_keys if "lora_" in key]
        if missing_keys:
            print(f"Missing keys: {missing_keys}")
        if unexpected_keys:
            print(f"Unexpected keys: {unexpected_keys}")

def create_lora_llm(llm_model, config, lora_r=4,  target_modules=None):
    """Build a ``LoraVicuna`` model and print its trainable-parameter summary.

    Args:
        llm_model: Forwarded to ``LoraVicuna`` as ``llm_model``.
        config: Forwarded to ``LoraVicuna`` as ``config``.
        lora_r: LoRA rank forwarded to ``LoraVicuna``.
        target_modules: Module-name substrings to adapt. When ``None`` the
            query, value and key projections are used.

    Returns:
        The constructed ``LoraVicuna`` instance.
    """
    # Default to the attention q/v/k projections only.
    selected_modules = (
        target_modules
        if target_modules is not None
        else ["q_proj", "v_proj", "k_proj"]
    )
    lora_llm = LoraVicuna(
        llm_model=llm_model,
        config=config,
        lora_r=lora_r,
        target_modules=selected_modules,
    )
    lora_llm.print_trainable_parameters()
    return lora_llm
