import torch
import sys

def reshape_attention_matrix(att_matrix):
    """
    Merge the head axis into the channel axis: (B, num_heads, N, C // num_heads) -> (B, N, C).

    :param att_matrix: 4-D tensor shaped (B, num_heads, N, C // num_heads)
    :return: tensor shaped (B, N, C); may be a view or a copy of the input
        depending on its memory layout (see ``torch.flatten``)
    :raises ValueError: if ``att_matrix`` is not 4-D
    """
    if att_matrix.dim() != 4:
        raise ValueError(
            "expected a 4-D tensor (B, num_heads, N, C // num_heads), "
            f"got shape {tuple(att_matrix.shape)}"
        )
    # (B, H, N, D) -> (B, N, H, D) so each token's heads are adjacent, then
    # fuse (H, D) -> C. flatten(2) also copes with zero-sized tensors, where
    # reshape(B, N, -1) cannot infer the -1 dimension.
    return att_matrix.transpose(1, 2).flatten(2)

def reverse_reshape_attention_matrix(reshaped_att_matrix, num_heads):
    """
    Split the channel axis back into heads: (B, N, C) -> (B, num_heads, N, C // num_heads).

    :param reshaped_att_matrix: tensor shaped (B, N, C)
    :param num_heads: number of heads to split C into
    :return: tensor shaped (B, num_heads, N, C // num_heads)
    :raises ValueError: if C is not a multiple of num_heads
    """
    batch, tokens, channels = reshaped_att_matrix.shape
    head_dim, leftover = divmod(channels, num_heads)
    if leftover:
        raise ValueError("C must be divisible by num_heads")

    split = reshaped_att_matrix.reshape(batch, tokens, num_heads, head_dim)
    # Swap the token and head axes: (B, N, H, D) -> (B, H, N, D).
    return split.transpose(1, 2)

def energy_alignment(Q_t, K_s, V_s, W_K, align_strength=0.01, verbose=False,
                     *, gamma_attn=0.5, gamma_reg=0.5, beta=0.5):
    """
    Shift the student's values by an alignment term built from the teacher's
    queries and the student's keys.

    With Q and K the head-merged (B, N, C) forms of ``Q_t`` and ``K_s``::

        attn  = gamma_attn * softmax(beta * K @ Q^T, over teacher tokens) @ Q @ W_K^T
        reg   = gamma_reg  * softmax(0.5 * ||k_i||^2, over tokens)     @ K @ W_K^T
        V_out = V_s - align_strength * (attn - reg)

    Args:
        Q_t (Tensor, (B, num_heads, N, C//num_heads)): query of teacher
        K_s (Tensor, (B, num_heads, N, C//num_heads)): key of student
        V_s (Tensor, (B, num_heads, N, C//num_heads)): value of student.
            Not modified; a new tensor is returned.
        W_K (Tensor): matrix applied as ``x @ W_K.T``; it must map C -> C so the
            result can be split back into ``num_heads`` heads (presumably the
            student's key projection weight -- TODO confirm with callers).
        align_strength (float): step size of the update.
        verbose (bool): if True, also return the mean of the alignment term
            formatted with 4 decimals.
        gamma_attn (float): weight of the attention component.
        gamma_reg (float): weight of the regularisation component.
        beta (float): inverse temperature of the attention softmax.

    Returns:
        Tensor with the same shape and dtype as ``V_s``, or
        ``(Tensor, str)`` when ``verbose`` is True.
    """
    num_heads = Q_t.shape[1]
    Q = reshape_attention_matrix(Q_t)  # (B, N, C)
    K = reshape_attention_matrix(K_s)  # (B, N, C)

    # (B, N, N) scores: each student token (row) attends over teacher tokens.
    attn_weights = torch.softmax(beta * (K @ Q.transpose(-1, -2)), dim=-1)
    attn_component = gamma_attn * (attn_weights @ Q @ W_K.T)

    # Per-token squared key norms, (B, N); equal to diag(K @ K^T) without
    # materialising the N x N matrix. Normalise over tokens (last dim) so each
    # sample is independent of the rest of the batch (dim=0 mixed samples).
    sq_norms = (K * K).sum(dim=-1)
    reg_weights = torch.softmax(0.5 * sq_norms, dim=-1).unsqueeze(1)  # (B, 1, N)
    # (B, 1, C): a single token-weighted key summary, broadcast over all tokens.
    reg_component = gamma_reg * (reg_weights @ K @ W_K.T)

    alignment = attn_component - reg_component  # (B, N, C)
    # Back to (B, num_heads, N, C // num_heads) to match V_s.
    alignment = reverse_reshape_attention_matrix(alignment, num_heads)

    # Out-of-place update: in-place `-=` on a caller's tensor breaks autograd
    # for leaf tensors requiring grad and for multi-output views (e.g. unbind).
    # Cast back so the result keeps V_s's dtype, as the in-place form did.
    V_s = (V_s - align_strength * alignment).to(V_s.dtype)

    if verbose:
        return V_s, '{:.4f}'.format(alignment.mean().detach().item())

    return V_s

def parameters_string(module):
    """
    Build a human-readable table of the trainable parameters of ``module``.

    One row per parameter with ``requires_grad=True`` (name, shape, element
    count), then a summary row whose count is the sum of the listed rows.
    If the module also has frozen parameters, the summary is labelled
    "trainable parameters" and an extra row reports the frozen element count,
    so the total never silently includes rows that were not listed.

    :param module: object exposing ``named_parameters()`` (e.g. ``torch.nn.Module``)
    :return: the table as one newline-joined string, with a leading and a
        trailing empty line
    """
    row_format = "{name:<40} {shape:>20} ={total_size:>12,d}"
    lines = [
        "",
        "List of model parameters:",
        "=========================",
    ]

    trainable_total = 0
    frozen_total = 0
    for name, param in module.named_parameters():
        count = param.numel()
        if not param.requires_grad:
            frozen_total += count
            continue
        trainable_total += count
        lines.append(row_format.format(
            name=name,
            shape=" * ".join(str(p) for p in param.size()),
            total_size=count,
        ))

    lines.append("=" * 75)
    # The summary must agree with its "sum of above" label; when nothing is
    # frozen the listed rows are all parameters, so the original label holds.
    lines.append(row_format.format(
        name="trainable parameters" if frozen_total else "all parameters",
        shape="sum of above",
        total_size=trainable_total,
    ))
    if frozen_total:
        lines.append(row_format.format(
            name="frozen parameters",
            shape="not listed",
            total_size=frozen_total,
        ))
    lines.append("")
    return "\n".join(lines)


def assert_exactly_one(lst):
    """
    Check that exactly one element of ``lst`` is truthy.

    :param lst: any iterable; it is consumed once, so generators are fine
    :raises AssertionError: if zero or more than one element is truthy; the
        message lists every element, comma-separated
    """
    # Materialise first so a one-shot iterator can be both counted and reported.
    items = list(lst)
    # Explicit raise rather than `assert` so the check survives `python -O`;
    # AssertionError is kept so existing callers' handlers still match.
    if sum(1 for el in items if el) != 1:
        raise AssertionError(", ".join(str(el) for el in items))
 
def export(fn):
    """
    Decorator: add ``fn``'s name to its defining module's ``__all__``.

    Creates ``__all__`` on the module if it does not exist yet. Returns
    ``fn`` unchanged, so it can be stacked with other decorators.
    """
    owner = sys.modules[fn.__module__]
    try:
        public_names = owner.__all__
    except AttributeError:
        owner.__all__ = [fn.__name__]
    else:
        public_names.append(fn.__name__)
    return fn


def parameter_count(module):
    """Return the total number of elements across every parameter of ``module``,
    trainable or frozen."""
    total = 0
    for tensor in module.parameters():
        total += int(tensor.numel())
    return total
