from safetensors.torch import load_file
import re
import os
import torch.nn as nn
from itertools import chain
from peft.tuners.lora.model import LoraModel
from peft.utils.integrations import dequantize_module_weight, gather_params_ctx, get_bnb_param_type
from peft.utils import (
    TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING,
    ModulesToSaveWrapper,
    _freeze_adapter,
    _get_submodules,
    get_peft_model_state_dict,
    get_quantization_config,
)
from peft.tuners.lora.layer import Conv2d, LoraLayer, dispatch_default


# Directory containing the pre-computed adapter initialisations
# (lora_A_init.safetensors / lora_B_init.safetensors) read by
# init_allorank_from_file. It is mandatory, so fail fast at import time with a
# message that says what is missing.
LORA_INIT_DIR = os.environ.get('LORA_INIT_DIR')
if LORA_INIT_DIR is None:
    raise ValueError(
        "Environment variable LORA_INIT_DIR is not set; it must point to the "
        "directory containing lora_A_init.safetensors and lora_B_init.safetensors."
    )


def init_allorank_from_file(new_module, adapter_name, current_key):
    """Initialise a LoRA adapter from pre-computed tensors and fold it out of the base weight.

    Reads ``lora_A_init.safetensors`` and ``lora_B_init.safetensors`` from
    ``LORA_INIT_DIR``. Copies the tensors stored for ``current_key`` into the
    adapter's ``lora_A`` / ``lora_B`` weights. Then subtracts
    ``scaling * (B @ A)`` from the base layer's weight, so that
    base + scaled LoRA delta equals the original base weight.

    Args:
        new_module: LoRA-wrapped layer exposing ``lora_A``, ``lora_B``,
            ``scaling`` and ``get_base_layer()``.
        adapter_name: Name of the adapter whose weights are initialised.
        current_key: Module path of the target layer; looked up in the files
            as ``base_model.model.{current_key}.weight.lora_A`` / ``.lora_B``.

    Raises:
        KeyError: If either file has no tensor for this layer.
    """
    # Local import: same package as the module-level `load_file` import.
    from safetensors import safe_open

    A_path = os.path.join(LORA_INIT_DIR, "lora_A_init.safetensors")
    B_path = os.path.join(LORA_INIT_DIR, "lora_B_init.safetensors")

    key_A = f"base_model.model.{current_key}.weight.lora_A"
    key_B = f"base_model.model.{current_key}.weight.lora_B"

    # Read only the two tensors this layer needs. `load_file` materialises
    # every tensor in the file, and this function runs once per wrapped layer.
    with safe_open(A_path, framework="pt") as a_file, safe_open(B_path, framework="pt") as b_file:
        if key_A not in a_file.keys() or key_B not in b_file.keys():
            raise KeyError(f"No key: {key_A} or {key_B}")
        A_tensor = a_file.get_tensor(key_A)
        B_tensor = b_file.get_tensor(key_B)

    lora_A_weight = new_module.lora_A[adapter_name].weight
    lora_B_weight = new_module.lora_B[adapter_name].weight
    A_tensor = A_tensor.to(lora_A_weight.device)
    B_tensor = B_tensor.to(lora_B_weight.device)

    lora_A_weight.data.copy_(A_tensor)
    lora_B_weight.data.copy_(B_tensor)

    base_weight = new_module.get_base_layer().weight
    dtype = base_weight.dtype

    delta = B_tensor @ A_tensor
    # Layers created with fan_in_fan_out=True store their weight as
    # (in_features, out_features); PEFT transposes B @ A for them, so do the same.
    if getattr(new_module, "fan_in_fan_out", False):
        delta = delta.T
    # NOTE(review): assumes an unquantised base weight; a quantised base layer
    # would need dequantisation before this subtraction — confirm usage.
    delta = delta.to(base_weight.device)

    new_weight = base_weight.data - new_module.scaling[adapter_name] * delta
    new_module.get_base_layer().weight.data = new_weight.to(dtype)



def monkey_patched_create_and_replace(
    self,
    lora_config,
    adapter_name,
    target,
    target_name,
    parent,
    current_key,
):
    """Replacement for ``LoraModel._create_and_replace`` that initialises new LoRA layers from file.

    Resolves the per-module rank and alpha from ``rank_pattern`` /
    ``alpha_pattern``. If ``target`` is already a (non-AdaLoRA) ``LoraLayer``,
    the adapter is added to it via ``update_layer``. Otherwise ``target`` is
    wrapped in a new LoRA module, which is initialised with
    ``init_allorank_from_file`` and then swapped into ``parent``.

    Args:
        self: The ``LoraModel`` instance (this function is installed as a method).
        lora_config: Adapter configuration providing r, alpha, dropout, patterns, etc.
        adapter_name: Name of the adapter being injected.
        target: Module selected for adaptation.
        target_name: Attribute name under which ``target`` lives on ``parent``.
        parent: Module owning ``target``; receives the replacement module.
        current_key: Fully qualified module path of ``target``.

    Raises:
        ValueError: If ``current_key`` is None.
    """
    # AdaLoraLayer was referenced below without being imported, so the
    # isinstance check raised NameError whenever `target` was a LoraLayer.
    from peft.tuners.adalora import AdaLoraLayer

    if current_key is None:
        raise ValueError("Current Key shouldn't be `None`")

    # Pick the first rank/alpha pattern key matching the tail of current_key;
    # fall back to the full key, which then resolves to the config defaults.
    pattern_keys = list(chain(lora_config.rank_pattern.keys(), lora_config.alpha_pattern.keys()))
    target_name_key = next(filter(lambda key: re.match(rf".*\.{key}$", current_key), pattern_keys), current_key)
    r = lora_config.rank_pattern.get(target_name_key, lora_config.r)
    alpha = lora_config.alpha_pattern.get(target_name_key, lora_config.lora_alpha)

    kwargs = {
        "r": r,
        "lora_alpha": alpha,
        "lora_dropout": lora_config.lora_dropout,
        "fan_in_fan_out": lora_config.fan_in_fan_out,
        "init_lora_weights": lora_config.init_lora_weights,
        "use_rslora": lora_config.use_rslora,
        "use_dora": lora_config.use_dora,
        "ephemeral_gpu_offload": lora_config.runtime_config.ephemeral_gpu_offload,
        "loaded_in_8bit": getattr(self.model, "is_loaded_in_8bit", False),
        "loaded_in_4bit": getattr(self.model, "is_loaded_in_4bit", False),
    }

    # Forward any quantisation config present on the model to the module factory.
    quant_methods = ["gptq", "aqlm", "awq"]
    for quant_method in quant_methods:
        quantization_config = get_quantization_config(self.model, method=quant_method)
        if quantization_config is not None:
            kwargs[f"{quant_method}_quantization_config"] = quantization_config

    if isinstance(target, LoraLayer) and not isinstance(target, AdaLoraLayer):
        # Existing LoRA layer: only add the adapter; the file-based
        # initialisation is not applied on this path.
        target.update_layer(
            adapter_name,
            r,
            lora_alpha=alpha,
            lora_dropout=lora_config.lora_dropout,
            init_lora_weights=lora_config.init_lora_weights,
            use_rslora=lora_config.use_rslora,
            use_dora=lora_config.use_dora,
        )
    else:
        new_module = self._create_new_module(lora_config, adapter_name, target, **kwargs)

        if isinstance(new_module, LoraLayer):
            init_allorank_from_file(
                new_module,
                adapter_name,
                current_key=current_key,
            )

        # An adapter that is not currently active starts out frozen.
        if adapter_name not in self.active_adapters:
            new_module.requires_grad_(False)
        self._replace_module(parent, target_name, new_module, target)


def apply_alloranklora_monkey_patch():
    """Install ``monkey_patched_create_and_replace`` on ``LoraModel``.

    The replacement is set as the class attribute ``_create_and_replace``, so
    every ``LoraModel`` that creates adapter layers after this call goes
    through the file-based initialisation path.
    """
    setattr(LoraModel, "_create_and_replace", monkey_patched_create_and_replace)

