import math
import torch

def _orthonormalize_columns(mat: torch.Tensor) -> torch.Tensor:
    """Return a matrix with orthonormal columns obtained via QR factorization."""
    # QR gives mat = Q R with Q having orthonormal columns
    # Use reduced mode for efficiency
    q, _ = torch.linalg.qr(mat, mode="reduced")
    return q

def _orthonormalize_rows(mat: torch.Tensor) -> torch.Tensor:
    """Orthonormalize the rows of ``mat``.

    Args:
        mat: 2-D matrix whose rows should be made orthonormal.

    Returns:
        A contiguous tensor obtained by orthonormalizing the columns of
        ``mat.t()`` and transposing the result back.
    """
    # The rows of `mat` are the columns of its transpose, so reuse the column
    # routine and flip the result back afterwards.
    q_cols = _orthonormalize_columns(mat.t())
    # `.t()` only returns a view; materialize it in row-major layout.
    return q_cols.t().contiguous()
  
def gram_schmidt_init(embed_dim, rank, num_heads, b_var, a_var):
    """Build per-head factors taken from one shared orthonormal basis.

    A random ``(embed_dim, rank * num_heads)`` matrix is orthonormalized
    column-wise (via QR, despite the function name) and cut into
    ``num_heads`` blocks of ``rank`` columns for the B factors. The A factors
    are built the same way from an independent random
    ``(rank * num_heads, embed_dim)`` matrix that is orthonormalized row-wise.

    Args:
        embed_dim: Size of the embedding dimension.
        rank: Number of columns (B) / rows (A) given to each head.
        num_heads: Number of heads to create factors for.
        b_var: Unused; the B factors are returned unscaled. Accepted so the
            signature matches the other ``*_init`` functions in this file.
        a_var: Unused; the A factors are returned unscaled. Accepted for the
            same reason as ``b_var``.

    Returns:
        ``(Bs, As)``: two lists of length ``num_heads``. Each ``Bs[i]`` has
        shape ``(embed_dim, rank)`` and each ``As[i]`` has shape
        ``(rank, embed_dim)``.

    Raises:
        ValueError: If ``rank * num_heads > embed_dim``.
    """
    total_rank = rank * num_heads
    # A reduced QR of an (embed_dim, total_rank) matrix yields at most
    # embed_dim orthonormal vectors. Beyond that the slices below would come
    # back truncated or empty, so fail loudly instead.
    if total_rank > embed_dim:
        raise ValueError(
            f"gram_schmidt init requires rank * num_heads <= embed_dim, "
            f"got rank={rank}, num_heads={num_heads}, embed_dim={embed_dim}"
        )

    # B is drawn before A; keep this order so seeded runs stay reproducible.
    B_all = _orthonormalize_columns(torch.randn(embed_dim, total_rank))
    Bs = [B_all[:, i * rank : (i + 1) * rank].clone() for i in range(num_heads)]

    A_all = _orthonormalize_rows(torch.randn(total_rank, embed_dim))
    As = [A_all[i * rank : (i + 1) * rank, :].clone() for i in range(num_heads)]

    return Bs, As
  
def constant_init(embed_dim, rank, num_heads, b_var, a_var):
    """Give every head an identical copy of one random (B, A) pair.

    Args:
        embed_dim: Size of the embedding dimension.
        rank: Inner dimension of the factors.
        num_heads: Number of copies to produce.
        b_var: Multiplier applied to the standard-normal draw for B.
        a_var: Multiplier applied to the standard-normal draw for A.

    Returns:
        ``(Bs, As)``: two lists of length ``num_heads``. Every ``Bs[i]`` is an
        independent clone of the same ``(embed_dim, rank)`` tensor and every
        ``As[i]`` an independent clone of the same ``(rank, embed_dim)``
        tensor.
    """
    # B is sampled before A; the draw order matters for seeded runs.
    template_b = torch.randn(embed_dim, rank) * b_var
    template_a = torch.randn(rank, embed_dim) * a_var

    Bs, As = [], []
    for _ in range(num_heads):
        # Clone so that the heads do not share storage with each other.
        Bs.append(template_b.clone())
        As.append(template_a.clone())

    return Bs, As

def shared_subspace_init(embed_dim, rank, num_heads, b_var, a_var):
    """Build per-head factors as small perturbations of one shared pair.

    A single random ``M`` of shape ``(embed_dim, rank)`` and ``N`` of shape
    ``(rank, embed_dim)`` are drawn once. Each head then gets its own
    near-identity mixing matrix ``R_i = I + 0.01 * noise`` and receives
    ``B_i = (M @ R_i) * b_var`` and ``A_i = (R_i @ N) * a_var``.

    Args:
        embed_dim: Size of the embedding dimension.
        rank: Inner dimension of the factors.
        num_heads: Number of heads to create factors for.
        b_var: Multiplier applied to every B factor.
        a_var: Multiplier applied to every A factor.

    Returns:
        ``(Bs, As)``: two lists of length ``num_heads``. Each ``Bs[i]`` has
        shape ``(embed_dim, rank)`` and each ``As[i]`` has shape
        ``(rank, embed_dim)``.
    """
    # Shared bases; M is drawn before N, then one R_i per head, in order.
    M = torch.randn(embed_dim, rank)
    N = torch.randn(rank, embed_dim)

    Bs, As = [], []
    for _ in range(num_heads):
        # The same R_i is used on both sides of this head's factor pair.
        R_i = torch.eye(rank) + 0.01 * torch.randn(rank, rank)
        # `@` and `*` share precedence and associate left to right, so the
        # scale is applied to the matrix product. Each result is already a
        # fresh tensor, so no extra copy is needed.
        Bs.append((M @ R_i) * b_var)
        As.append((R_i @ N) * a_var)

    return Bs, As

def init_adapters(model, init_scheme, embed_dim, rank, num_heads, b_var, a_var):
    """Overwrite the per-head LoRA factors of every adapted layer in ``model``.

    For each parameter whose name contains ``base_layer.weight``, a fresh set
    of factors is generated with the selected scheme and copied in place into
    the state-dict entries obtained by replacing ``base_layer.weight`` with
    ``lora_A_{i}`` / ``lora_B_{i}`` for ``i in range(num_heads)``.

    Args:
        model: Object exposing ``named_parameters()`` and ``state_dict()``
            (presumably a ``torch.nn.Module``; confirm against callers).
        init_scheme: One of ``'random_normal'`` (leave the model untouched),
            ``'constant'``, ``'gram_schmidt'`` or ``'shared_subspace'``.
        embed_dim: Forwarded to the selected ``*_init`` function.
        rank: Forwarded to the selected ``*_init`` function.
        num_heads: Number of ``lora_A_{i}`` / ``lora_B_{i}`` pairs per layer.
        b_var: Forwarded to the selected ``*_init`` function.
        a_var: Forwarded to the selected ``*_init`` function.

    Raises:
        ValueError: If ``init_scheme`` is not one of the supported names.
        KeyError: If an expected ``lora_A_{i}`` / ``lora_B_{i}`` entry is
            missing from the model's state dict.
    """
    if init_scheme == 'random_normal':
        # Keep whatever values the adapters were constructed with.
        return

    # Reject unknown schemes up front; otherwise they would surface later as
    # an opaque failure when indexing factors that were never generated.
    if init_scheme not in ('constant', 'gram_schmidt', 'shared_subspace'):
        raise ValueError(
            f"Unknown init_scheme {init_scheme!r}; expected one of "
            "'random_normal', 'constant', 'gram_schmidt', 'shared_subspace'"
        )

    init_fn = {
        'constant': constant_init,
        'gram_schmidt': gram_schmidt_init,
        'shared_subspace': shared_subspace_init,
    }[init_scheme]

    # Fetch the state dict once instead of rebuilding it for every copy. The
    # in-place copies below rely on its tensors aliasing the model's
    # parameter storage.
    state = model.state_dict()
    suffix = 'base_layer.weight'

    for name, _ in model.named_parameters():
        # Match exactly the substring that is replaced below. A name that
        # lacks it would be left unchanged by `replace`, and the factors
        # would then be copied into the parameter itself.
        if suffix not in name:
            continue

        # Each adapted layer gets its own independent draw.
        Bs, As = init_fn(embed_dim, rank, num_heads, b_var, a_var)

        for i in range(num_heads):
            name_A = name.replace(suffix, f'lora_A_{i}')
            name_B = name.replace(suffix, f'lora_B_{i}')

            state[name_A].data.copy_(As[i])
            state[name_B].data.copy_(Bs[i])