import os
import json
import torch
from typing import Literal
from src.utils.tallmask_utils import construct_consensus_mask
from src.utils.utils import vector_to_state_dict, state_dict_to_vector, topk_values_mask
from src.utils.ties_utils import ties_merging
from omegaconf import DictConfig

def _per_head_subspace_boosting(param, config, svd_thresh, cumsum, num_heads=None):
    """Apply cumulative-sum subspace boosting to every head of a fused Q/K/V weight.

    ``param`` is split row-wise into three ``(embed_dim, embed_dim)`` blocks
    (Q, K, V), with ``embed_dim = param.shape[1]``. Each block is reshaped into
    ``num_heads`` slices of shape ``(embed_dim // num_heads, embed_dim)``. Every
    slice is rebuilt from its SVD after raising all singular values below the
    cutoff value to that value.

    Args:
        param: Fused attention in-projection weight; assumed to be shaped
            ``(3 * embed_dim, embed_dim)`` (the 3-way split below requires it).
        config: Configuration object. ``config.model`` selects the default head
            count (16 for ``"ViT-L-14"``, 12 otherwise) when ``num_heads`` is None.
        svd_thresh: Cumulative singular-value ratio. The cutoff is the first
            singular value at which ``cumsum(S) / sum(S) >= svd_thresh``.
        cumsum: Must be True; the index-based variant is not implemented.
        num_heads: Optional explicit head count that overrides the default
            derived from ``config.model``.

    Returns:
        Tensor shaped like ``param`` holding the boosted weights.

    Raises:
        NotImplementedError: If ``cumsum`` is False.
    """
    if not cumsum:
        raise NotImplementedError("Per_head_subspace_boosting is not implemented for cumsum=False.")

    if num_heads is None:
        num_heads = 16 if config.model == "ViT-L-14" else 12
    embed_dim = param.shape[1]
    head_dim = embed_dim // num_heads

    W_Q, W_K, W_V = torch.split(param, embed_dim, dim=0)

    boosted = []
    for W in (W_Q, W_K, W_V):
        heads = W.reshape(num_heads, head_dim, embed_dim)
        clamped = torch.empty_like(heads)
        for head_idx, head in enumerate(heads):
            U, S, Vh = torch.linalg.svd(head, full_matrices=False)
            ratio = torch.cumsum(S, dim=0) / S.sum()
            hits = (ratio >= svd_thresh).nonzero(as_tuple=False)
            # No hit means the slice is all zero (0/0 -> NaN ratio) or the
            # threshold is never reached. Fall back to the smallest singular
            # value, which leaves the slice unchanged, instead of raising IndexError.
            cutoff_idx = hits[0].item() if hits.numel() else S.numel() - 1
            S_damped = torch.clamp(S, min=S[cutoff_idx])
            clamped[head_idx] = (U * S_damped.unsqueeze(0)) @ Vh
        boosted.append(clamped.reshape(embed_dim, embed_dim))

    # Re-stack Q, K, V into the fused layout of the input.
    return torch.cat(boosted, dim=0)
    
def _per_qkv_subspace_boosting(param, config, attn_svd_thresh, cumsum):
    """Apply cumulative-sum subspace boosting separately to the Q, K and V blocks.

    ``param`` is split row-wise into three ``(embed_dim, embed_dim)`` blocks with
    ``embed_dim = param.shape[1]``. Each block is rebuilt from its SVD after
    raising every singular value below the cutoff value to that value.

    Args:
        param: Fused attention in-projection weight; assumed to be shaped
            ``(3 * embed_dim, embed_dim)`` (the 3-way split below requires it).
        config: Unused here; kept so the signature matches the per-head variant.
        attn_svd_thresh: Cumulative singular-value ratio. The cutoff is the first
            singular value at which ``cumsum(S) / sum(S) >= attn_svd_thresh``.
        cumsum: Must be True; the index-based variant is not implemented.

    Returns:
        Tensor shaped like ``param`` holding the boosted weights.

    Raises:
        NotImplementedError: If ``cumsum`` is False.
    """
    if not cumsum:
        raise NotImplementedError("Per_qkv_subspace_boosting is not implemented for cumsum=False.")

    embed_dim = param.shape[1]
    W_Q, W_K, W_V = torch.split(param, embed_dim, dim=0)

    boosted = []
    for mat in (W_Q, W_K, W_V):
        U, S, Vh = torch.linalg.svd(mat, full_matrices=False)
        ratio = torch.cumsum(S, dim=0) / S.sum()
        hits = (ratio >= attn_svd_thresh).nonzero(as_tuple=False)
        # No hit means the block is all zero (0/0 -> NaN ratio) or the threshold
        # is never reached. Fall back to the smallest singular value, which
        # leaves the block unchanged, instead of raising IndexError.
        cutoff_idx = hits[0].item() if hits.numel() else S.numel() - 1
        S_damped = torch.clamp(S, min=S[cutoff_idx])
        boosted.append((U * S_damped.unsqueeze(0)) @ Vh)

    # Re-stack Q, K, V into the fused layout of the input.
    return torch.cat(boosted, dim=0)

def _boost_singular_values(param, thresh, cumsum):
    """Rebuild a 2-D weight from its SVD with small singular values raised to a cutoff value.

    With ``cumsum=True`` the cutoff is the first singular value at which
    ``cumsum(S) / sum(S) >= thresh``. Otherwise it is the singular value at
    index ``int(thresh * len(S))``.
    """
    U, S, Vh = torch.linalg.svd(param, full_matrices=False)
    last_idx = S.numel() - 1
    if cumsum:
        ratio = torch.cumsum(S, dim=0) / S.sum()
        hits = (ratio >= thresh).nonzero(as_tuple=False)
        # No hit means an all-zero matrix (0/0 -> NaN ratio) or a threshold that
        # is never reached. The smallest singular value leaves the matrix unchanged.
        cutoff_idx = hits[0].item() if hits.numel() else last_idx
    else:
        # Index-based cutoff; keep it in range when thresh >= 1.
        cutoff_idx = min(int(thresh * S.numel()), last_idx)
    S_damped = torch.clamp(S, min=S[cutoff_idx])
    return (U * S_damped.unsqueeze(0)) @ Vh


def subspace_boosting(
        tv_flat_checks, 
        ptm_check, 
        base_method: Literal["sum", "ties", "consensus"] = "ties",
        config: DictConfig = None,
        reset_thresh=20, # TODO: refactor the parameter list and just use the config
        svd_thresh=0.01,
        attn_svd_thresh=0.10,
        cumsum=True, 
        remove_keys=None,
    ):
    """
    Subspace boosting for merging task vectors.

    The task vectors are first combined with a base merging method. The merged
    vector is then unflattened into a state dict. Every weight whose key contains
    "attn.in_proj_weight", "attn.out_proj.weight", "c_fc.weight" or
    "c_proj.weight" is rebuilt from its SVD. During the rebuild, every singular
    value below a cutoff value is raised to that cutoff value.

    Parameters:
        tv_flat_checks: Flattened task vectors.
        ptm_check:
            Pretrained model; passed to ``vector_to_state_dict`` (and to
            ``construct_consensus_mask`` for the "consensus" base method).
        base_method:
            Base merging method. Options are "sum", "ties", or "consensus". Defaults to "ties".
        config:
            Required configuration object. Fields read:
              - ``method.apply_to_attn``: one of False, "full_attn", "per_qkv", "per_head".
                Controls how "attn.in_proj_weight" tensors are boosted (False: not at all).
              - ``method.attn_svd_thresh``: threshold for "attn.in_proj_weight" tensors.
              - ``method.k``: top-k parameter for "sum" and for "consensus" without TIES.
              - ``method.use_ties``: "consensus" only; selects TIES vs. summation.
              - ``model``: read by the "per_head" variant to pick the head count.
        reset_thresh: default 20
            Passed as ``reset_thresh`` to ``ties_merging`` for the "ties" base
            method and for the TIES variant of "consensus".
        svd_thresh: default 0.01
            Threshold for non-"attn.in_proj_weight" weights. If cumsum is True, it is
            used as a cumulative singular-value ratio; otherwise it is used as a
            fraction of the total number of singular values.
        attn_svd_thresh: default 0.10
            Currently unused: the attention threshold is read from
            ``config.method.attn_svd_thresh``.
        cumsum:
            Whether to use the cumulative sum approach for thresholding the singular values.
            "per_qkv" and "per_head" only support cumsum=True.
        remove_keys:
            Optional list of keys to remove from the state dict conversion. None means [].

    Returns:
        A merged flat vector representing the task vector after subspace boosting.

    Raises:
        ValueError: If base_method or config.method.apply_to_attn is not one of the
            defined options.
        NotImplementedError: If cumsum is False with apply_to_attn "per_qkv" or "per_head".
    """
    if remove_keys is None:
        remove_keys = []

    # Validate options up front so a typo fails before the expensive base merge.
    if base_method not in ("sum", "ties", "consensus"):
        raise ValueError(f"Method {base_method} not defined.")
    apply_to_attn = config.method.apply_to_attn
    if apply_to_attn not in [False, "full_attn", "per_qkv", "per_head"]:
        raise ValueError(f"Apply to attention method {apply_to_attn} not defined.")

    # Base merging method
    if base_method == "sum": # Task Arithmetic
        # NOTE(review): the result is used as a tensor below; confirm that
        # topk_values_mask does not return a (masked, mask) tuple in this project.
        tv_flat_checks = topk_values_mask(tv_flat_checks, K=config.method.k)
        merged_flat_vector = tv_flat_checks.sum(dim=0)
    elif base_method == "ties":
        merged_flat_vector = ties_merging(tv_flat_checks, reset_thresh=reset_thresh, merge_func="dis-mean")
    else: # "consensus"
        # construct consensus mask (assuming the TALL masks have already been constructed)
        # Set the prun_thre_k=2 to remove both catastrophic and selfish weights
        consensus_mask = construct_consensus_mask(ptm_check, 2, config, remove_keys)

        if config.method.use_ties:
            merged_flat_vector = ties_merging(tv_flat_checks, reset_thresh=reset_thresh, merge_func="dis-sum")
        else: # Use TA by default
            tv_flat_checks = topk_values_mask(tv_flat_checks, K=config.method.k)
            merged_flat_vector = tv_flat_checks.sum(dim=0)
        merged_flat_vector = merged_flat_vector * consensus_mask

    merged_state_dict = vector_to_state_dict(merged_flat_vector, ptm_check, remove_keys)

    keys_to_eval = [
        "attn.in_proj_weight", # Attention weight matrices
        "attn.out_proj.weight",
        "c_fc.weight",
        "c_proj.weight",
    ]
    attn_key = keys_to_eval[0]

    # Only existing keys are reassigned, so updating the dict while iterating is safe.
    for key, param in merged_state_dict.items():
        if not (any(name in key for name in keys_to_eval) and isinstance(param, torch.Tensor)):
            continue

        if attn_key not in key:
            # Attention output projection and MLP weights use svd_thresh.
            merged_state_dict[key] = _boost_singular_values(param, svd_thresh, cumsum)
        elif apply_to_attn == "per_head":
            merged_state_dict[key] = _per_head_subspace_boosting(param, config, config.method.attn_svd_thresh, cumsum)
        elif apply_to_attn == "per_qkv":
            merged_state_dict[key] = _per_qkv_subspace_boosting(param, config, config.method.attn_svd_thresh, cumsum)
        elif apply_to_attn == "full_attn":
            merged_state_dict[key] = _boost_singular_values(param, config.method.attn_svd_thresh, cumsum)
        # apply_to_attn is False: the fused in-projection weight is left untouched.

    return state_dict_to_vector(merged_state_dict)