
import importlib
import math
import re
import warnings
from dataclasses import asdict, dataclass, field
from enum import Enum
from typing import List, Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.pytorch_utils import Conv1D

from ..utils import PeftConfig, PeftType, transpose


def is_bnb_available():
    """Return True when the optional `bitsandbytes` package can be imported.

    Only the import machinery is queried; the package itself is not imported here.
    """
    # `import importlib` alone does not guarantee that the `importlib.util`
    # submodule has been loaded, so import it explicitly before using it.
    import importlib.util

    return importlib.util.find_spec("bitsandbytes") is not None


# Optional dependency: `bnb` is only bound when bitsandbytes is installed, so
# code that touches `bnb` must first check `is_bnb_available()`.
if is_bnb_available():
    import bitsandbytes as bnb


class HypernetworkMLP(nn.Module):
    """Two-headed MLP that maps an input vector to two flat output vectors.

    A shared `Linear -> ReLU` trunk feeds two independent linear heads, `mlp_A`
    and `mlp_B`.

    Args:
        input_dim: Size of the last dimension of the input.
        hidden_dim: Width of the shared hidden layer.
        output_dims: Indexable pair; `output_dims[0]` is the output size of head A
            and `output_dims[1]` the output size of head B.
    """

    def __init__(self, input_dim, hidden_dim, output_dims):
        super().__init__()
        self.mlp_common = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
        )
        # The parameter is named `output_dims`; index it rather than unpack so any
        # indexable with at least two entries keeps working.
        self.mlp_A = nn.Linear(hidden_dim, int(output_dims[0]))
        self.mlp_B = nn.Linear(hidden_dim, int(output_dims[1]))

        # Head B is zeroed completely (weight and bias), so its output is exactly
        # zero right after construction. Head A only has its bias zeroed and keeps
        # the default `nn.Linear` weight initialisation.
        self.mlp_A.bias.data.zero_()
        self.mlp_B.weight.data.zero_()
        self.mlp_B.bias.data.zero_()

    def forward(self, x):
        """Return `(output_A, output_B)` computed from the shared hidden features."""
        common = self.mlp_common(x)
        output_A = self.mlp_A(common)
        output_B = self.mlp_B(common)
        return output_A, output_B


        


@dataclass
class OPLoraConfig(PeftConfig):
    """
    Configuration class to store the configuration of a OPLora model.

    How the fields are consumed by the code in this file:

    - `r`: rank passed to every replacement layer.
    - `target_modules`: a string is matched against full module names with `re.fullmatch`;
      a list is matched by suffix (`str.endswith`) against module names.
    - `lora_alpha`: numerator of the layer scaling (`lora_alpha / r`). Defaults to `None`,
      so set it explicitly.
    - `lora_dropout`: dropout probability applied to the adapter input. Defaults to `None`,
      so set it explicitly (`0.0` disables dropout).
    - `merge_weights`: forwarded to the replacement layers, OR-ed with `inference_mode` and
      forced off when the wrapped model has an `hf_device_map`.
    - `fan_in_fan_out`: forwarded to the replacement layers.
    - `bias`: forwarded to `mark_only_hyper_lora_as_trainable`; must be one of
      `'none'`, `'all'` or `'lora_only'`.
    - `hypernetwork_input_dim` / `hypernetwork_hidden_dim`: sizes of the per-layer
      hypernetwork input vector and hidden layer.
    - `enable_lora` and `modules_to_save` are not read by the code visible in this file
      (NOTE(review): possibly consumed elsewhere in the package — confirm).
    """

    r: int = field(default=8, metadata={"help": "Lora attention dimension"})
    target_modules: Optional[Union[List[str], str]] = field(
        default=None,
        metadata={
            "help": "List of module names or regex expression of the module names to replace with OPLora."
        },
    )
    # Both default to None even though they are annotated as int / float.
    lora_alpha: int = field(default=None, metadata={"help": "Lora alpha"})
    lora_dropout: float = field(default=None, metadata={"help": "Lora dropout"})
    merge_weights: bool = field(
        default=False, metadata={"help": "Merge weights of the original model and the Lora model"}
    )
    fan_in_fan_out: bool = field(
        default=False,
        metadata={"help": "Set this to True if the layer to replace stores weight like (fan_in, fan_out)"},
    )
    enable_lora: Optional[List[bool]] = field(default=None, metadata={"help": "Used with `lora.MergedLinear`."})
    bias: str = field(default="none", metadata={"help": "Bias type for Lora. Can be 'none', 'all' or 'lora_only'"})
    modules_to_save: Optional[List[str]] = field(
        default=None,
        metadata={
            "help": "List of modules apart from LoRA layers to be set as trainable and saved in the final checkpoint."
        },
    )
    hypernetwork_input_dim: int = field(default=128, metadata={"help": "Input dimension for hypernetwork MLP"})
    hypernetwork_hidden_dim: int = field(default=32, metadata={"help": "Hidden dimension for hypernetwork MLP"})

    def __post_init__(self):
        # Tag the config so it is identified as an OPLora configuration.
        self.peft_type = PeftType.OPLORA


class OPLoraModel(torch.nn.Module):
    """
    Creates Low Rank Adapter (OPLora) model from a pretrained transformers model.

    On construction every sub-module of `model` whose name matches
    `config.target_modules` is replaced by a hypernetwork-driven layer, and all
    parameters except the hypernetwork-related ones are frozen.

    Args:
        config ([`OPLoraConfig`]): The configuration of the OPLora model.
        model ([`transformers.PreTrainedModel`]): The model to be adapted.
    """

    def __init__(self, config, model):
        super().__init__()
        self.peft_config = config
        self.model = model
        self._find_and_replace()
        mark_only_hyper_lora_as_trainable(self.model, self.peft_config.bias)
        # Calling the wrapper runs the wrapped model's forward directly.
        self.forward = self.model.forward

    def _find_and_replace(self):
        """Swap every targeted linear layer of the wrapped model for its OPLora counterpart.

        Raises:
            ImportError: the model is loaded in 8-bit but `bitsandbytes` is not installed.
            NotImplementedError: an 8-bit target is found but no 8-bit OPLora layer exists.
            ValueError: no module name matches `peft_config.target_modules`.
        """
        loaded_in_8bit = getattr(self.model, "is_loaded_in_8bit", False)
        if loaded_in_8bit and not is_bnb_available():
            raise ImportError(
                "To use OPLora with 8-bit quantization, please install the `bitsandbytes` package."
            )
        is_target_modules_in_base_model = False
        is_hf_device_map_available = hasattr(self.model, "hf_device_map")
        kwargs = {
            "r": self.peft_config.r,
            "lora_alpha": self.peft_config.lora_alpha,
            "lora_dropout": self.peft_config.lora_dropout,
            "fan_in_fan_out": self.peft_config.fan_in_fan_out,
            "merge_weights": (self.peft_config.merge_weights or self.peft_config.inference_mode)
            and not is_hf_device_map_available,
            "hypernetwork_input_dim": self.peft_config.hypernetwork_input_dim,
            "hypernetwork_hidden_dim": self.peft_config.hypernetwork_hidden_dim,
        }
        # Snapshot the names first: modules are replaced while we walk the list.
        key_list = [key for key, _ in self.model.named_modules()]
        for key in key_list:
            if isinstance(self.peft_config.target_modules, str):
                target_module_found = re.fullmatch(self.peft_config.target_modules, key)
            else:
                target_module_found = any(key.endswith(target_key) for target_key in self.peft_config.target_modules)
            if not target_module_found:
                continue

            is_target_modules_in_base_model = True
            parent, target, target_name = self._get_submodules(key)
            if loaded_in_8bit and isinstance(target, bnb.nn.Linear8bitLt):
                # The 8-bit layer class is looked up lazily so that a missing
                # definition yields a clear error instead of a NameError.
                hyper_8bit_cls = globals().get("HyperLinear8bitLt")
                if hyper_8bit_cls is None:
                    raise NotImplementedError(
                        "OPLora has no 8-bit layer implementation (`HyperLinear8bitLt` is not defined)."
                    )
                # Use a per-target copy so the 8-bit-only options never leak into
                # the kwargs of later, non-quantized targets.
                module_kwargs = dict(
                    kwargs,
                    has_fp16_weights=target.state.has_fp16_weights,
                    memory_efficient_backward=target.state.memory_efficient_backward,
                    threshold=target.state.threshold,
                    index=target.index,
                )
                new_module = hyper_8bit_cls(
                    target.in_features, target.out_features, bias=target.bias is not None, **module_kwargs
                )
            elif isinstance(target, torch.nn.Linear):
                new_module = HyperLinear(
                    target.in_features, target.out_features, bias=target.bias is not None, **kwargs
                )
            else:
                # Without this branch `new_module` would be unbound, or would still
                # hold the replacement built for a previous target.
                warnings.warn(
                    f"Target module '{key}' of type {type(target).__name__} is not supported by OPLora "
                    "and was left unchanged."
                )
                continue
            self._replace_module(parent, target_name, new_module, target)
        if not is_target_modules_in_base_model:
            raise ValueError(
                f"Target modules {self.peft_config.target_modules} not found in the base model."
            )

    def _get_submodules(self, key):
        """Return `(parent module, target module, target attribute name)` for a dotted module path."""
        parent = self.model.get_submodule(".".join(key.split(".")[:-1]))
        target_name = key.split(".")[-1]
        target = self.model.get_submodule(key)
        return parent, target, target_name

    def _replace_module(self, parent_module, child_name, new_module, old_module):
        """Install `new_module` in place of `old_module`, reusing its weight, bias and device."""
        setattr(parent_module, child_name, new_module)
        new_module.weight = old_module.weight
        if old_module.bias is not None:
            new_module.bias = old_module.bias
        device = old_module.weight.device
        if getattr(old_module, "state", None) is not None:
            new_module.state = old_module.state
            new_module.to(device)

        # dispatch to correct device
        # The adapter pieces are created on the default device; they live under the
        # `hypernetwork` sub-module and the `input_vector` parameter, so both have
        # to follow the base weight.
        for name, module in new_module.named_modules():
            if "hyper_lora_" in name or "hypernetwork" in name:
                module.to(device)
        input_vector = getattr(new_module, "input_vector", None)
        if isinstance(input_vector, torch.nn.Parameter):
            input_vector.data = input_vector.data.to(device)

    def __getattr__(self, name: str):
        """Forward missing attributes to the wrapped model."""
        try:
            return super().__getattr__(name)  # defer to nn.Module's logic
        except AttributeError:
            return getattr(self.model, name)

    @property
    def modules_to_save(self):
        """Always `None` for this model."""
        return None

    def get_peft_config_as_dict(self, inference: bool = False):
        """Return the config as a plain dict with enum members replaced by their values.

        When `inference` is True the returned dict has `inference_mode` forced to True.
        """
        config = {k: v.value if isinstance(v, Enum) else v for k, v in asdict(self.peft_config).items()}
        if inference:
            config["inference_mode"] = True
        return config

    def _set_adapter_layers(self, enabled=True):
        """Set `disable_adapters` on every OPLora layer of the wrapped model."""
        for module in self.model.modules():
            # The adapter mixin defined in this file is `OPLoraLayer`.
            if isinstance(module, OPLoraLayer):
                module.disable_adapters = not enabled

    def enable_adapter_layers(self):
        """Turn the adapter contribution on for all OPLora layers."""
        self._set_adapter_layers(enabled=True)

    def disable_adapter_layers(self):
        """Turn the adapter contribution off for all OPLora layers."""
        self._set_adapter_layers(enabled=False)


def mark_only_hyper_lora_as_trainable(model: nn.Module, bias: str = "none") -> None:
    """Freeze every parameter that does not belong to the hypernetwork adapter.

    A parameter is frozen unless its name contains "hypernetwork", "input_vector"
    or "rescale". Bias parameters are then optionally re-enabled.

    Args:
        model: Module whose parameters are updated in place.
        bias: "none" leaves biases as they are after freezing; "all" makes every
            parameter whose name contains "bias" trainable; "lora_only" makes the
            `bias` of each `OPLoraLayer` module trainable.

    Raises:
        NotImplementedError: `bias` is not one of the three supported values
            (raised after the freezing pass has already run).
    """
    adapter_markers = ("hypernetwork", "input_vector", "rescale")
    for param_name, param in model.named_parameters():
        belongs_to_adapter = any(marker in param_name for marker in adapter_markers)
        if not belongs_to_adapter:
            param.requires_grad = False

    if bias == "none":
        return

    if bias == "all":
        for param_name, param in model.named_parameters():
            if "bias" in param_name:
                param.requires_grad = True
        return

    if bias == "lora_only":
        for module in model.modules():
            if not isinstance(module, OPLoraLayer):
                continue
            if getattr(module, "bias", None) is not None:
                module.bias.requires_grad = True
        return

    raise NotImplementedError


class OPLoraLayer:
    """Mixin holding the adapter hyper-parameters and runtime flags of an OPLora layer.

    Args:
        r: Rank of the low-rank update.
        lora_alpha: Scaling numerator stored for the concrete layer to use.
        lora_dropout: Dropout probability for the adapter input. `None` or any
            value `<= 0.0` means no dropout.
        merge_weights: Stored as-is on the instance.
        hypernetwork: Optional module that generates the low-rank factors.
    """

    def __init__(
        self,
        r: int,
        lora_alpha: int,
        lora_dropout: float,
        merge_weights: bool,
        hypernetwork: Optional[nn.Module] = None,
    ):
        self.r = r
        self.lora_alpha = lora_alpha
        self.hypernetwork = hypernetwork  # Hypernetwork MLP
        # Optional dropout. A `None` probability is treated as "no dropout"
        # instead of failing on the `>` comparison.
        if lora_dropout is not None and lora_dropout > 0.0:
            self.lora_dropout = nn.Dropout(p=lora_dropout)
        else:
            self.lora_dropout = lambda x: x
        self.merged = False
        self.merge_weights = merge_weights
        self.disable_adapters = False



class HyperLinear(nn.Linear, OPLoraLayer):
    """Linear layer whose low-rank update is generated by a per-layer hypernetwork.

    The layer owns a learnable `input_vector` of shape `(1, hypernetwork_input_dim)`
    and a `HypernetworkMLP`. On every adapted forward pass the hypernetwork output is
    reshaped into `lora_A` with shape `(in_features, r)` and `lora_B` with shape
    `(r, out_features)`, and `(x @ lora_A @ lora_B) * scaling` is added to the base
    projection, where `scaling = lora_alpha / r`.

    Args:
        in_features: Input size of the linear layer.
        out_features: Output size of the linear layer.
        r: Rank of the generated update; the adapter path is only used when `r > 0`.
        lora_alpha: Scaling numerator.
        lora_dropout: Dropout probability applied to the adapter input.
        fan_in_fan_out: When True the weight data is replaced by its transpose at
            construction and transposed back for the base projection.
        merge_weights: Stored on the instance via `OPLoraLayer`.
        hypernetwork_input_dim: Size of `input_vector`.
        hypernetwork_hidden_dim: Hidden width of the hypernetwork.
        **kwargs: Forwarded to `nn.Linear.__init__` (e.g. `bias`).
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.0,
        fan_in_fan_out: bool = False,
        merge_weights: bool = True,
        hypernetwork_input_dim: int = 128,
        hypernetwork_hidden_dim: int = 256,
        **kwargs,
    ):
        nn.Linear.__init__(self, in_features, out_features, **kwargs)
        # One flat output per factor: r * in_features values for A, r * out_features for B.
        self.hypernetwork = HypernetworkMLP(
            hypernetwork_input_dim,
            hypernetwork_hidden_dim,
            [r * in_features, r * out_features],
        )
        input_vector = torch.empty(1, hypernetwork_input_dim).normal_(mean=0.0, std=0.01)
        self.input_vector = nn.Parameter(input_vector)
        OPLoraLayer.__init__(
            self,
            r=r,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
            merge_weights=merge_weights,
            hypernetwork=self.hypernetwork,
        )
        self.fan_in_fan_out = fan_in_fan_out
        if r > 0:
            # The base weight stays frozen; only the adapter pieces are meant to train.
            self.weight.requires_grad = False
            self.scaling = self.lora_alpha / self.r
        if fan_in_fan_out:
            self.weight.data = self.weight.data.T

    def generate_lora_parameters(self):
        """Run the hypernetwork and store the reshaped factors as `lora_A` / `lora_B`."""
        if self.hypernetwork is not None and self.input_vector is not None:
            output_A, output_B = self.hypernetwork(self.input_vector)
            lora_A_size = (self.in_features, self.r)
            lora_B_size = (self.r, self.out_features)
            self.lora_A = output_A.reshape(lora_A_size)
            self.lora_B = output_B.reshape(lora_B_size)

    def forward(self, x: torch.Tensor):
        """Return the base projection, plus the generated low-rank update when enabled."""
        result = F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
        if self.disable_adapters or self.r <= 0:
            # No adapter contribution, so the hypernetwork does not need to run.
            return result
        self.generate_lora_parameters()
        low_rank = self.lora_dropout(x.to(self.lora_A.dtype)) @ self.lora_A
        # In-place add keeps the dtype of the base projection.
        result += (low_rank @ self.lora_B) * self.scaling
        return result


# Placeholder: nothing is defined here even when bitsandbytes is available.
# NOTE(review): `OPLoraModel._find_and_replace` refers to `HyperLinear8bitLt`, which
# is not defined in the visible part of this file — presumably the 8-bit layer was
# meant to be declared in this guarded section; confirm.
if is_bnb_available():
    pass
