"""Utilities for hypernetwork-based LoRA merging and visualization.

This module provides helper functions used by the hypernetwork to:
- Extract and precompute LoRA deltas (ΔW) from per-domain adapters
- Stack per-site, per-domain LoRA columns into tensors expected by the hypernet
- Generate merged deltas using hypernet weights 
- Generate heatmaps for merge weights
- Register forward hooks to apply LoRA deltas without mutating base CLIP weights
"""

from __future__ import annotations
import random
import logging
from collections import OrderedDict
from operator import attrgetter
from typing import Dict, List

import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
from omegaconf import OmegaConf
import open_clip
import wandb    

import math, torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Dict, List, Tuple
from tqdm import tqdm
import re
import os

import matplotlib.pyplot as plt

# ------------------------------------------------------------------------------------
#  Helpers for precomputing and merging LoRA deltas
# ------------------------------------------------------------------------------------

def _adapter_name(module):
    """Return the first adapter key found on ``module.lora_A``.

    Falls back to ``"default"`` when the module has no ``lora_A`` attribute,
    when ``lora_A`` does not expose ``keys()``, or when it has no entries.
    """
    adapters = getattr(module, "lora_A", None)
    # Only mapping-like containers (anything exposing ``keys``) are inspected.
    if not hasattr(adapters, "keys"):
        return "default"
    return next(iter(adapters.keys()), "default")

@torch.inference_mode()
def precompute_lora_deltas(source_models, lora_sites, device):
    """Precompute LoRA deltas ΔW per site and per source domain.

    Computes ΔW = (B @ A) * scale for each requested site and domain. The
    scale is resolved in this order: the module's ``scaling`` attribute
    (optionally a dict keyed by adapter name), then ``lora_alpha / r`` from
    the model's ``peft_config`` entry for the adapter, then 1.0.

    Parameters
    ----------
    source_models : Dict[str, nn.Module]
        Mapping domain -> model with LoRA adapters.
    lora_sites : List[str]
        Names of LoRA sites (e.g., 'visual.transformer.resblocks.0.attn.in_proj.qkv').
    device : torch.device | str
        Device for intermediate computation and returned tensors.

    Returns
    -------
    Dict[str, Dict[str, torch.Tensor]]
        Nested mapping site -> domain -> ΔW tensor matching base weight shape [out, in],
        cast to the dtype of ``module.weight``.

    Raises
    ------
    RuntimeError
        If the LoRA A/B weights cannot be read at a site for a domain, or if
        the computed ΔW does not have the same shape as ``module.weight``.
    """
    print("Pre-computing LoRA deltas …")
    output_deltas: Dict[str, Dict[str, torch.Tensor]] = {}

    for site in tqdm(lora_sites, ncols=100, desc="ΔW sites"):
        per_domain: Dict[str, torch.Tensor] = {}

        for domain, mdl in source_models.items():
            module = _find_module(mdl, site)

            try:
                adapter = _adapter_name(module)
                A = module.lora_A[adapter].weight.to(device).to(torch.float32)  # [r, in]
                B = module.lora_B[adapter].weight.to(device).to(torch.float32)  # [out, r]
            except Exception as err:
                # Raise instead of calling exit(): keeps the root cause chained,
                # lets callers handle the failure, and never ends the process
                # with a success status.
                raise RuntimeError(f"No LoRA at {site} for domain {domain}") from err
            r = A.shape[0]

            # 1) Get scale value (alpha/r) if stored on the module
            scale = None
            if hasattr(module, "scaling"):
                s = module.scaling
                if isinstance(s, dict):
                    s = s.get(adapter, None)
                if torch.is_tensor(s):
                    scale = float(s.detach().cpu().item())
                elif isinstance(s, (int, float)):
                    scale = float(s)

            # 2) Fallback: compute it manually -> model.peft_config[adapter].lora_alpha / r
            if scale is None and hasattr(mdl, "peft_config") and adapter in getattr(mdl, "peft_config", {}):
                alpha = getattr(mdl.peft_config[adapter], "lora_alpha", None)
                if alpha is not None:
                    scale = float(alpha) / float(r)

            # 3) Last resort -> unscaled product
            if scale is None:
                scale = 1.0

            # Low-rank product computed in float32
            delta = (B @ A) * scale  # [out, in]

            # Optional: scale rows by a magnitude vector when one is stored as a
            # tensor (or a dict of tensors keyed by adapter).
            # NOTE(review): assumes the vector has one entry per output row — confirm.
            mag = None
            if hasattr(module, "lora_magnitude_vector"):
                mv = module.lora_magnitude_vector
                if isinstance(mv, dict):
                    mv = mv.get(adapter, None)
                if torch.is_tensor(mv):
                    mag = mv.to(device).to(delta.dtype)
            if mag is not None:
                delta = delta * mag.unsqueeze(1)

            # Modules flagged fan_in_fan_out get the transposed delta
            if getattr(module, "fan_in_fan_out", False):
                delta = delta.t().contiguous()

            # Match the base weight's dtype; keep the result on the requested device
            delta = delta.to(dtype=module.weight.dtype, device=device)

            # A delta that cannot be added to the base weight is a hard error
            if delta.shape != module.weight.shape:
                raise RuntimeError(
                    f"ΔW shape mismatch at {site} for domain {domain}: "
                    f"delta {delta.shape} vs weight {module.weight.shape}"
                )

            per_domain[domain] = delta

        output_deltas[site] = per_domain

    return output_deltas


def stack_columns_for_sites_by_type(
    deltas_by_site: Dict[str, Dict[str, torch.Tensor]],
    site_names: List[str],                 # only qkv OR only proj sites
    domain_order: List[str],
    device: torch.device,
    mask_domain_name: Optional[str] = None,
) -> torch.Tensor:
    """
    Stack transposed per-domain deltas of one site type into a single tensor.

    The output has shape [S_t, D, C, F] where
      - S_t: len(site_names)
      - D:   len(domain_order)
      - F,C: rows and columns of the first delta found for the first site
    Each [C, F] slice is the transpose of the stored [F, C] delta. When
    mask_domain_name is one of the entries of domain_order, the slices of that
    domain are filled with zeros; an unknown name is ignored.

    Returns:
      column_tensor: [S_t, D, C, F] (or an empty [0, D, 0, 0] tensor when
      site_names is empty)
    """
    num_sites = len(site_names)
    if num_sites == 0:
        return torch.empty((0, len(domain_order), 0, 0), device=device)

    domains = list(domain_order)
    num_domains = len(domains)

    # The first available delta of the first site fixes the matrix size.
    probe = next(iter(deltas_by_site[site_names[0]].values()))  # [F, C]
    n_rows, n_cols = probe.shape

    stacked = torch.empty((num_sites, num_domains, n_cols, n_rows), device=device)

    # Position of the masked domain, if it is one of the known domains.
    masked_slot = None
    if mask_domain_name is not None and mask_domain_name in domains:
        masked_slot = domains.index(mask_domain_name)

    for site_idx, site in enumerate(site_names):
        for dom_idx, dom in enumerate(domains):
            if dom_idx == masked_slot:
                stacked[site_idx, dom_idx] = torch.zeros((n_cols, n_rows), device=device)
                continue
            stacked[site_idx, dom_idx] = deltas_by_site[site][dom].t()  # [F, C] -> [C, F]

    return stacked.contiguous()



def _extract_layer_id(site_name: str) -> int:
    """Parse the integer layer index embedded in a site string.

    The patterns are tried in order: ``resblocks.<n>``, ``block.<n>`` /
    ``blocks.<n>``, then ``layer<n>``. The first one that matches wins.
    Returns -1 if none of them matches.
    """
    patterns = (
        r"\bresblocks\.(\d+)\b",
        r"\bblocks?\.(\d+)\b",
        r"\blayer(\d+)\b",
    )
    for pattern in patterns:
        found = re.search(pattern, site_name)
        if found:
            return int(found.group(1))
    return -1

def generate_merge_deltas(
    hypernet: nn.Module,
    deltas_by_site: Dict[str, Dict[str, torch.Tensor]],  
    site_order: List[str],                                
    domain_order: List[str],                              
    domain_representation: torch.Tensor,                  
    device: torch.device,
    mask_domain_name: Optional[str] = None,               
) -> Dict[str, torch.Tensor]:
    """Combine per-domain deltas into one merged ΔW per site.

    Sites are split into the two supported types (``.qkv`` and ``.proj``).
    For each type the per-domain deltas are stacked as columns, the hypernet
    is called once to obtain weights, and the columns are summed over the
    domain axis using those weights.

    Parameters
    ----------
    hypernet : nn.Module
        Called with keyword arguments ``column_tensor``, ``site_type``,
        ``domain_representation``, ``mask_domain_idx`` and ``site_ids``; its
        output is consumed as a 3-D weight tensor indexed [site, domain, column].
    deltas_by_site : Dict[str, Dict[str, torch.Tensor]]
        Mapping site -> (domain -> ΔW [F, C]).
    site_order : List[str]
        Ordered list of all sites to process (qkv and proj).
    domain_order : List[str]
        Ordered list of domain names.
    domain_representation : torch.Tensor
        Passed through unchanged to the hypernet.
    device : torch.device
        Device on which the stacked tensors and site ids are created.
    mask_domain_name : Optional[str]
        Source domain to mask; its index in ``domain_order`` is forwarded to
        the hypernet and its columns are zeroed before stacking.

    Returns
    -------
    Dict[str, torch.Tensor]
        Mapping site_name -> merged ΔW tensor of shape [F, C], in the order of
        ``site_order``.

    Raises
    ------
    ValueError
        If a site name ends in neither ``.qkv`` nor ``.proj``, or if
        ``mask_domain_name`` is given but is not in ``domain_order``.
    """
    # (hypernet site_type, name suffix) pairs; qkv is always processed first.
    type_suffixes = (("qkv", ".qkv"), ("proj", ".proj"))

    sites_by_type = {
        site_type: [name for name in site_order if name.endswith(suffix)]
        for site_type, suffix in type_suffixes
    }
    unsupported = [name for name in site_order if not name.endswith((".qkv", ".proj"))]
    if unsupported:
        raise ValueError(f"Unsupported site types (not .qkv/.proj): {unsupported}")

    # Stage 1: column tensors [S_t, D, C, F] per type (masked domain zeroed)
    columns = {
        site_type: stack_columns_for_sites_by_type(
            deltas_by_site, sites_by_type[site_type], domain_order, device, mask_domain_name
        )
        for site_type, _ in type_suffixes
    }

    # Stage 2: site ids are the layer indices parsed from the site names
    layer_ids = {
        site_type: torch.tensor(
            [_extract_layer_id(name) for name in sites_by_type[site_type]],
            dtype=torch.long,
            device=device,
        )
        for site_type, _ in type_suffixes
    }

    mask_idx = None if mask_domain_name is None else domain_order.index(mask_domain_name)

    # Stage 3: one hypernet call per type
    weights = {
        site_type: hypernet(
            column_tensor=columns[site_type],
            site_type=site_type,
            domain_representation=domain_representation,
            mask_domain_idx=mask_idx,
            site_ids=layer_ids[site_type],
        )
        for site_type, _ in type_suffixes
    }

    # Stage 4: weighted sum over domains, then back to [F, C] per site
    merged_by_site: Dict[str, torch.Tensor] = {}
    for site_type, _ in type_suffixes:
        merged = torch.einsum('sdc,sdcf->scf', weights[site_type], columns[site_type])  # [S_t, C, F]
        merged = merged.permute(0, 2, 1).contiguous()                                   # [S_t, F, C]
        for row, site in enumerate(sites_by_type[site_type]):
            merged_by_site[site] = merged[row]

    return {site: merged_by_site[site] for site in site_order}

# ------------------------------------------------------------------------------------
#  Heatmap visualization helpers
# ------------------------------------------------------------------------------------

@torch.inference_mode()
def compute_domain_weights(
    hypernet: nn.Module,
    deltas_by_site: Dict[str, Dict[str, torch.Tensor]],
    site_order: List[str],
    domain_order: List[str],
    domain_representation: torch.Tensor,
    device: torch.device,
    mask_domain_name: Optional[str] = None,
):
    """
    Return raw weight tensors from the hypernet used for heatmap visualization.

    Sites whose name ends in neither ".qkv" nor ".proj" are ignored.

    Returns a dict with:
      - 'w_qkv':  [S_q, D, C*]
      - 'w_proj': [S_p, D, C*]
      - 'per_domain_qkv':  [D] mean over S_q and C*
      - 'per_domain_proj': [D] mean over S_p and C*
      - 'per_site_domain_qkv':  [S_q, D] mean over C*
      - 'per_site_domain_proj': [S_p, D] mean over C*
      - 'qkv_sites':  List[str]
      - 'proj_sites': List[str]

    When the hypernet output for a type has no elements, its per-domain
    profile is an all-zero [D] tensor and its per-site table is an empty
    [0, D] tensor.

    Raises ValueError if mask_domain_name is given but is not in domain_order.
    """
    # Group sites by type
    qkv_sites = [name for name in site_order if name.endswith(".qkv")]
    proj_sites = [name for name in site_order if name.endswith(".proj")]

    # Build column tensors per type
    cols_qkv = stack_columns_for_sites_by_type(
        deltas_by_site, qkv_sites, domain_order, device, mask_domain_name
    )
    cols_proj = stack_columns_for_sites_by_type(
        deltas_by_site, proj_sites, domain_order, device, mask_domain_name
    )

    # Site ids extracted from layer indices
    site_ids_qkv = torch.tensor([_extract_layer_id(s) for s in qkv_sites], dtype=torch.long, device=device)
    site_ids_proj = torch.tensor([_extract_layer_id(s) for s in proj_sites], dtype=torch.long, device=device)

    # Domain mask index if any
    mask_idx = domain_order.index(mask_domain_name) if mask_domain_name is not None else None

    # Forward through hypernet to get raw weights
    w_qkv = hypernet(
        column_tensor=cols_qkv,
        site_type="qkv",
        domain_representation=domain_representation,
        mask_domain_idx=mask_idx,
        site_ids=site_ids_qkv,
    )
    w_proj = hypernet(
        column_tensor=cols_proj,
        site_type="proj",
        domain_representation=domain_representation,
        mask_domain_idx=mask_idx,
        site_ids=site_ids_proj,
    )

    num_domains = len(domain_order)

    def _summarize(weights):
        """Return (per-domain [D], per-site-per-domain [S, D]) means of ``weights``."""
        if weights.numel() > 0:
            return weights.mean(dim=(0, 2)), weights.mean(dim=2)
        # No weights for this type: return zeros rather than torch.empty, whose
        # contents are uninitialized and would show up as arbitrary "weights".
        return (
            torch.zeros((num_domains,), device=device),
            torch.empty((0, num_domains), device=device),  # zero elements, nothing to initialize
        )

    per_domain_qkv, per_site_domain_qkv = _summarize(w_qkv)
    per_domain_proj, per_site_domain_proj = _summarize(w_proj)

    return {
        "w_qkv": w_qkv,
        "w_proj": w_proj,
        "per_domain_qkv": per_domain_qkv,
        "per_domain_proj": per_domain_proj,
        "per_site_domain_qkv": per_site_domain_qkv,
        "per_site_domain_proj": per_site_domain_proj,
        "qkv_sites": qkv_sites,
        "proj_sites": proj_sites,
    }

def _make_heatmap(
    data: torch.Tensor,
    row_labels,
    col_labels,
    title: str,
    title_size: int = 13,
    tick_size: int = 13,
    label_size: int = 13,
):
    """
    Draw ``data`` as a labelled heatmap with a colorbar and return the figure.

    Tensors are detached and moved to CPU before plotting; any other input is
    handed to ``imshow`` unchanged. The figure grows with the number of row
    and column labels (minimum 6 x 4).
    """
    values = data.detach().cpu().numpy() if torch.is_tensor(data) else data

    n_rows = len(row_labels)
    n_cols = len(col_labels)
    width = max(6, n_cols * 0.5)
    height = max(4, n_rows * 0.3)

    fig, ax = plt.subplots(figsize=(width, height))
    image = ax.imshow(values, aspect='auto', cmap='viridis')
    ax.set_title(title, fontsize=title_size)

    # One tick per label on each axis
    ax.set_xticks(range(n_cols))
    ax.set_xticklabels(col_labels, rotation=45, ha='right', fontsize=tick_size)
    ax.set_yticks(range(n_rows))
    ax.set_yticklabels(row_labels, fontsize=tick_size)
    ax.set_xlabel('Source domains', fontsize=label_size)
    ax.set_ylabel('LoRA layers (resblock id)', fontsize=label_size)

    colorbar = fig.colorbar(image, ax=ax, fraction=0.046, pad=0.04)
    colorbar.ax.set_ylabel('weight', rotation=-90, va='bottom', fontsize=label_size)
    colorbar.ax.tick_params(labelsize=tick_size)

    fig.tight_layout()
    return fig


def render_domain_weight_heatmaps(weights_info: Dict[str, torch.Tensor], domain_labels, out_dir: str, prefix: str = "weights", step: int = 0):
    """
    Save one heatmap per site type (qkv, proj) from per-site per-domain weights.

    ``out_dir`` is created if needed. A panel is skipped when its weight table
    is missing from ``weights_info`` or has no elements. Returns a dict keyed
    by 'qkv' / 'proj' whose values hold the matplotlib figure ('fig') and the
    PNG path it was saved to ('path').
    """
    os.makedirs(out_dir, exist_ok=True)

    # (result key, weights key, site-list key, title heading) for each panel
    panels = (
        ("qkv", "per_site_domain_qkv", "qkv_sites", "ATTN"),
        ("proj", "per_site_domain_proj", "proj_sites", "PROJ"),
    )

    results = {}
    for kind, weights_key, sites_key, heading in panels:
        table = weights_info.get(weights_key)
        sites = weights_info.get(sites_key, [])
        if table is None or table.numel() == 0:
            continue

        # Rows are labelled with the layer index parsed from each site name
        layer_labels = [str(_extract_layer_id(site)) for site in sites]
        figure = _make_heatmap(
            table,
            row_labels=layer_labels,
            col_labels=domain_labels,
            title=f"{heading} domain weights (step {step})",
        )
        target = os.path.join(out_dir, f"{prefix}_{kind}_step{step}.png")
        figure.savefig(target, dpi=200, bbox_inches='tight')
        results[kind] = {'fig': figure, 'path': target}

    return results

# ------------------------------------------------------------------------------------
#  Hook factory & registration helpers
# ------------------------------------------------------------------------------------

def remove_all_hooks_(model):
    """
    Reset the hook registries of every sub-module under
    ``model.visual.transformer.resblocks``.

    The forward and forward-pre hook dicts are replaced with fresh empty
    ``OrderedDict`` objects, and so is the backward hook dict when the module
    has one. Operates in place and returns nothing.
    """
    resblocks = model.visual.transformer.resblocks
    for block in resblocks:
        for submodule in block.modules():
            for registry in ("_forward_hooks", "_forward_pre_hooks"):
                setattr(submodule, registry, OrderedDict())
            if hasattr(submodule, "_backward_hooks"):
                setattr(submodule, "_backward_hooks", OrderedDict())


def _find_module(model, site: str):
    """Look up ``site`` on ``model`` as a dotted attribute path and return the result."""
    resolve = attrgetter(site)  # follows each dot-separated name in turn
    return resolve(model)

def _lora_hook_factory(delta_weight: torch.Tensor, debug_info=None):
    """Build a forward hook that adds x·ΔWᵀ to the hooked module's output.

    ``delta_weight`` is cast to the module weight's device and dtype at call
    time when they differ. If the module output is a tuple, only its first
    element is modified. When ``debug_info`` (a dict with 'site' and
    'original_weight') is given, each call also prints a numerical comparison
    against a forward pass with the merged weight W + ΔW.
    """
    delta_weight = delta_weight.contiguous()

    def _traced_delta(module, x, dw, output):
        """Print the debug verification and return x·ΔWᵀ."""
        print(f"\n=== HOOK EXECUTION DEBUG for {debug_info['site']} ===")
        print(f"Input x shape: {x.shape}")
        print(f"Original module weight unchanged: {torch.equal(module.weight, debug_info['original_weight'])}")

        # Reference terms: x @ Wᵀ, x @ ΔWᵀ and x @ (W + ΔW)ᵀ
        original_output = F.linear(x, module.weight, module.bias)
        delta = F.linear(x, dw)
        expected_merged_output = F.linear(x, module.weight + dw, module.bias)

        # What the hook is about to return (first element for tuple outputs)
        base_output = output[0] if isinstance(output, tuple) else output
        actual_hook_output = base_output + delta

        print(f"Original output (first element): {original_output.flatten()[0]}")
        print(f"Delta output (first element): {delta.flatten()[0]}")
        print(f"Expected merged output (first element): {expected_merged_output.flatten()[0]}")
        print(f"Hook will produce (first element): {actual_hook_output.flatten()[0]}")

        print(f"Expected shape: {expected_merged_output.shape}, Actual shape: {actual_hook_output.shape}")
        diff = torch.abs(actual_hook_output - expected_merged_output)
        max_diff = torch.max(diff)
        mean_diff = torch.mean(diff)
        print(f"Max difference: {max_diff}, Mean difference: {mean_diff}")

        matches = torch.allclose(actual_hook_output, expected_merged_output, atol=1e-5, rtol=1e-5)
        print(f"Hook output matches expected merged: {matches}")

        if not matches:
            matches_relaxed = torch.allclose(actual_hook_output, expected_merged_output, atol=1e-4, rtol=1e-4)
            print(f"Matches with relaxed tolerance (1e-4): {matches_relaxed}")

        print("=" * 50)
        return delta

    def _hook(module, args, output):
        weight = module.weight
        dw = delta_weight
        if dw.device != weight.device or dw.dtype != weight.dtype:
            dw = dw.to(device=weight.device, dtype=weight.dtype, copy=False)

        x = args[0]
        if debug_info is None:
            delta = F.linear(x, dw)
        else:
            delta = _traced_delta(module, x, dw, output)

        if isinstance(output, tuple):
            return (output[0] + delta, *output[1:])
        return output + delta

    return _hook

def register_lora_hooks(model, delta_dict: Dict[str, torch.Tensor], debug_verification=False):
    """Attach one LoRA-delta forward hook per site listed in ``delta_dict``.

    Parameters
    ----------
    model : nn.Module
        Base model; each site string is resolved on it as a dotted path.
    delta_dict : Dict[str, torch.Tensor]
        Mapping from site string to ΔW tensor [out, in] for that site.
    debug_verification : bool
        If True, the first site processed prints its weight/delta previews at
        registration time and its hook prints a numerical check on each call.

    Returns
    -------
    List[torch.utils.hooks.RemovableHandle]
        One handle per registered hook, in ``delta_dict`` order, so callers
        can remove them later. A notice is printed when the list is empty.
    """
    handles = []

    for site, delta in delta_dict.items():
        target = _find_module(model, site)

        # Only the very first site (no handle registered yet) is instrumented.
        debug_info = None
        if debug_verification and not handles:
            print(f"\n=== DEBUG HOOK VERIFICATION for site: {site} ===")
            print(f"Original weight shape: {target.weight.shape}")
            print(f"Original weight (first 3x3): \n{target.weight[:3, :3]}")
            print(f"LoRA delta shape: {delta.shape}")
            print(f"LoRA delta (first 3x3): \n{delta[:3, :3]}")
            print(f"Expected merged (first 3x3): \n{(target.weight + delta)[:3, :3]}")
            print("=" * 50)

            # Snapshot of the weight so the hook can report whether it changed
            debug_info = dict(site=site, original_weight=target.weight.clone())

        hook = _lora_hook_factory(delta, debug_info)
        handles.append(target.register_forward_hook(hook))

    if len(handles) == 0:
        print("No LoRA hooks registered – check site list.")

    return handles
