from peft import LoraConfig, PromptEncoderConfig, PrefixTuningConfig, TaskType
import os
import torch

from typing import Dict, Any
from contextlib import contextmanager

# Registry of forward-hook handles installed by install_generated_deltas,
# keyed by module name, so clear_gen_hooks can detach all of them at once.
_GEN_HOOKS: Dict[str, Any] = dict()


def lora_config(*, r=2, lora_alpha=16, lora_dropout=0.1, target_modules=None):
    """Build the LoRA config used for causal-LM adaptation.

    Every argument is keyword-only and optional; calling ``lora_config()`` with
    no arguments reproduces the original hard-coded settings.

    Args:
        r: LoRA rank.
        lora_alpha: LoRA alpha (numerator of the output scale).
        lora_dropout: dropout applied on the LoRA branch.
        target_modules: module names to adapt. ``None`` selects the attention
            projections (q/k/v/o) plus the MLP ``up_proj``/``down_proj``;
            ``gate_proj`` is intentionally not in the default list. Passed to
            ``LoraConfig`` unchanged otherwise.

    Returns:
        A ``LoraConfig`` in training mode (``inference_mode=False``).
    """
    if target_modules is None:
        # Fresh list per call so callers mutating the config cannot leak state.
        target_modules = [
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "up_proj",
            "down_proj",
        ]
    return LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        inference_mode=False,
        r=r,
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
        target_modules=target_modules,
    )


def ptuning_config():
    """Return a P-tuning (prompt-encoder) config for causal LMs.

    Uses 20 virtual tokens and a 128-wide prompt encoder.
    """
    settings = dict(
        task_type=TaskType.CAUSAL_LM,
        num_virtual_tokens=20,
        encoder_hidden_size=128,
    )
    return PromptEncoderConfig(**settings)


def prefix_tuning_config():
    """Return a prefix-tuning config for causal LMs (20 virtual tokens, training mode)."""
    virtual_tokens = 20
    return PrefixTuningConfig(
        num_virtual_tokens=virtual_tokens,
        inference_mode=False,
        task_type=TaskType.CAUSAL_LM,
    )


# Method name -> config builder; peft_factory looks a name up here and calls
# the builder with no arguments.
PEFT_CONFIGS = dict(
    lora=lora_config,
    ptuning=ptuning_config,
    prefix_tuning=prefix_tuning_config,
)


def peft_factory(name):
    """Build the PEFT config registered under ``name`` in ``PEFT_CONFIGS``.

    Args:
        name: registry key, e.g. ``"lora"``, ``"ptuning"``, ``"prefix_tuning"``.

    Returns:
        The config object produced by the registered builder.

    Raises:
        KeyError: if ``name`` is not registered; the message lists valid names.
    """
    # Keep the try body to the lookup only, so a KeyError raised *inside* a
    # builder is not misreported as an unknown method name.
    try:
        builder = PEFT_CONFIGS[name]
    except KeyError:
        known = ", ".join(sorted(PEFT_CONFIGS))
        raise KeyError(
            f"Unknown PEFT method {name!r}; expected one of: {known}"
        ) from None
    return builder()


def save_adapter_weights(model, adapter_name, save_dir="saved_adapters"):
    """Save the LoRA A/B weights of ``adapter_name`` to ``<save_dir>/<adapter_name>.pt``.

    Walks ``model.named_modules()``. For every module whose ``lora_A`` container
    holds ``adapter_name``, it stores ``{"lora_A": A, "lora_B": B}`` under the
    module's qualified name. ``B`` is ``None`` when the module has no matching
    ``lora_B`` entry. The saved tensors are detached CPU copies, so the file
    never aliases live parameters.

    Args:
        model: the ``nn.Module`` to scan.
        adapter_name: adapter key inside ``lora_A`` / ``lora_B``.
        save_dir: output directory, created if missing.

    Returns:
        The path of the written file.
    """
    os.makedirs(save_dir, exist_ok=True)

    def _cpu_copy(layer):
        # A single device->CPU copy; `.data.clone().cpu()` first allocated a
        # full clone on the source device. copy=True also guarantees a copy
        # when the weight already lives on the CPU.
        return layer.weight.detach().to("cpu", copy=True)

    adapter_weights = {}
    for name, module in model.named_modules():
        if not (hasattr(module, "lora_A") and adapter_name in module.lora_A):
            continue
        lora_b = getattr(module, "lora_B", None)
        has_b = lora_b is not None and adapter_name in lora_b
        adapter_weights[name] = {
            "lora_A": _cpu_copy(module.lora_A[adapter_name]),
            "lora_B": _cpu_copy(lora_b[adapter_name]) if has_b else None,
        }

    save_path = os.path.join(save_dir, f"{adapter_name}.pt")
    torch.save(adapter_weights, save_path)
    return save_path


def extract_adapter_specs(model, target_modules, adapter_name):
    """Collect the LoRA layers of ``adapter_name`` whose names match ``target_modules``.

    Walks ``model.base_model.named_modules()`` and keeps every module that
    meets all of the following:

    - it has ``lora_A`` and ``lora_B`` containers;
    - both containers hold ``adapter_name``;
    - its qualified name contains at least one entry of ``target_modules`` as
      a substring.

    Args:
        model: object exposing ``base_model`` (an ``nn.Module``).
        target_modules: iterable of name fragments, or a single string. A
            single string is treated as one fragment rather than being
            iterated character by character. One-shot iterables (generators)
            are materialised once, so every module is checked against all
            fragments.
        adapter_name: adapter key inside ``lora_A`` / ``lora_B``.

    Returns:
        dict mapping module name -> ``(module, shape_A, shape_B)``, where the
        shapes are the ``weight`` shapes of that adapter's A and B factors.
    """
    if isinstance(target_modules, str):
        fragments = (target_modules,)
    else:
        fragments = tuple(target_modules)

    specs = {}
    for name, module in model.base_model.named_modules():
        lora_a = getattr(module, "lora_A", None)
        lora_b = getattr(module, "lora_B", None)
        # Explicit None checks: an empty ModuleDict is falsy but still valid.
        if lora_a is None or lora_b is None:
            continue
        if adapter_name not in lora_a or adapter_name not in lora_b:
            continue
        if not any(fragment in name for fragment in fragments):
            continue
        specs[name] = (
            module,
            tuple(lora_a[adapter_name].weight.shape),
            tuple(lora_b[adapter_name].weight.shape),
        )
    return specs



def apply_adapter_deltas(layer_specs, adapter_deltas, adapter_name, base_weights):
    """Rebind each LoRA factor's ``weight`` to ``base + delta``.

    For every ``name -> (module, ...)`` in ``layer_specs``, the ``"weight"``
    entry in the ``_parameters`` dict of ``module.lora_A[adapter_name]`` and
    ``module.lora_B[adapter_name]`` is replaced with ``base + delta``.

    The replacement is a plain tensor, not an ``nn.Parameter``. It carries the
    autograd history of ``base`` and ``delta``, and it has ``requires_grad=True``
    only if at least one of them requires grad.

    All sums are computed and validated before any module is touched. A bad
    entry therefore leaves every module unchanged.

    Args:
        layer_specs: mapping ``name -> (module, *extra)``; only ``module`` is used.
        adapter_deltas: mapping ``name -> (deltaA, deltaB)``.
        adapter_name: adapter key inside ``lora_A`` / ``lora_B``.
        base_weights: mapping ``name -> (baseA, baseB)``.

    Raises:
        KeyError: if a name is missing from ``adapter_deltas``/``base_weights``
            or ``adapter_name`` is missing from a module's LoRA containers.
        ValueError: if ``base + delta`` broadcasts to a shape different from
            ``base`` (e.g. an accidentally batched delta).
    """
    pending = []
    for name, (module, *_) in layer_specs.items():
        deltaA, deltaB = adapter_deltas[name]
        baseA, baseB = base_weights[name]
        for factor, base, delta in (
            ("lora_A", baseA, deltaA),
            ("lora_B", baseB, deltaB),
        ):
            new_weight = base + delta
            if new_weight.shape != base.shape:
                raise ValueError(
                    f"{name}.{factor}: base {tuple(base.shape)} + delta "
                    f"{tuple(delta.shape)} broadcasts to {tuple(new_weight.shape)}"
                )
            pending.append((getattr(module, factor)[adapter_name], new_weight))

    # Direct assignment into _parameters bypasses register_parameter's
    # nn.Parameter type check, so a non-leaf tensor can be installed.
    for submodule, new_weight in pending:
        submodule._parameters["weight"] = new_weight


def add_zero_adapter(model_lm, peft_config, adapter_name: str):
    """
    Register ``adapter_name`` on ``model_lm``, make it the active adapter, and
    turn each of its LoRA A/B factors into a frozen all-zero weight.

    With zeroed factors the adapter contributes nothing by itself; the
    per-step generator residual is meant to be injected through hooks.
    """
    model_lm.add_adapter(peft_config, adapter_name=adapter_name)
    model_lm.set_adapter(adapter_name)

    for layer in model_lm.base_model.modules():
        if not hasattr(layer, "lora_A") or adapter_name not in layer.lora_A:
            continue
        factors = (layer.lora_A[adapter_name], layer.lora_B[adapter_name])
        # In-place zeroing of leaf parameters must happen outside autograd.
        with torch.no_grad():
            for factor in factors:
                factor.weight.zero_()
        for factor in factors:
            factor.weight.requires_grad_(False)


def clear_gen_hooks():
    """Detach every hook handle stored in ``_GEN_HOOKS``, then empty the registry."""
    handles = list(_GEN_HOOKS.values())
    for handle in handles:
        handle.remove()
    _GEN_HOOKS.clear()


def _resolve_lora_scale(module, adapter_name: str) -> float:
    """Return the output scale to apply for ``adapter_name`` on a LoRA ``module``.

    ``scaling``, ``lora_alpha`` and ``r`` may each be a per-adapter dict (looked
    up by ``adapter_name``) or a plain scalar. The lookup order is:

    1. ``scaling``, which is presumably the exact factor PEFT applies (verify
       against PEFT docs).
    2. ``lora_alpha / r`` as a fallback.
    3. ``1.0`` when neither can be resolved.

    Unusable attribute types emit a ``RuntimeWarning`` instead of being
    silently ignored.
    """
    import warnings

    def pick(value):
        # Per-adapter containers are indexed by adapter name; scalars pass through.
        return value.get(adapter_name) if isinstance(value, dict) else value

    try:
        scaling = pick(getattr(module, "scaling", None))
        if scaling is not None:
            return float(scaling)
        alpha = pick(getattr(module, "lora_alpha", None))
        rank = pick(getattr(module, "r", None))
        if alpha is not None and rank:
            return float(alpha) / float(rank)
    except (TypeError, ValueError) as err:
        warnings.warn(
            f"Could not resolve LoRA scaling for adapter {adapter_name!r} "
            f"({err}); using 1.0",
            RuntimeWarning,
            stacklevel=3,
        )
    return 1.0


def install_generated_deltas(
    model_lm, layer_specs, adapter_name: str, adapter_deltas: dict
):
    """
    Install forward hooks that add a generated low-rank residual to LoRA layers.

    For each ``name -> (module, shape_A, shape_B)`` in ``layer_specs``, a forward
    hook is registered on ``module``. While ``adapter_name`` is among
    ``model_lm.active_adapters()``, the hook returns

        output + scale * (x @ dA^T) @ dB^T

    where:

    - ``x`` is the module's first positional input;
    - ``(dA, dB) = adapter_deltas[name]`` (expected shapes ``[r, in]`` and
      ``[out, r]``);
    - ``scale`` is this adapter's LoRA scale on ``module``.

    Otherwise the output is returned unchanged. No Parameter is modified, so
    toggling adapters stays safe.

    Hooks installed by a previous call are removed first. The new handles are
    stored in ``_GEN_HOOKS`` under ``name``.

    NOTE(review): x, dA and dB are assumed to share dtype/device; matmul
    raises otherwise.
    """
    clear_gen_hooks()

    # Checked on every forward call, so later set_adapter() changes are honoured.
    def is_active():
        act = model_lm.active_adapters()
        # PEFT returns list/tuple/None across versions; normalize:
        if act is None:
            return False
        if isinstance(act, (list, tuple, set)):
            return adapter_name in act
        return act == adapter_name

    def make_hook(delta_a, delta_b, scale_t):
        # Factory so each hook binds its own deltas (avoids late-binding closures).
        def hook(mod, inputs, output):
            if not is_active():
                return output
            x = inputs[0]
            h = x.matmul(delta_a.transpose(-1, -2))  # [*, r]
            y = h.matmul(delta_b.transpose(-1, -2))  # [*, out]
            return output + y * scale_t

        return hook

    for name, (module, _shape_a, _shape_b) in layer_specs.items():
        delta_a, delta_b = adapter_deltas[name]
        # new_tensor puts the scale on the deltas' device/dtype.
        scale_t = delta_a.new_tensor(_resolve_lora_scale(module, adapter_name))
        _GEN_HOOKS[name] = module.register_forward_hook(
            make_hook(delta_a, delta_b, scale_t)
        )
