import torch
import numpy as np

from pdb import set_trace as pds

def optimized_pytorch_toeplitz(V):
    """
    Build one d x d Toeplitz-style matrix per row of V with a single gather.

    For batch row b, entry (i, j) of the output is V[b, i - j] on and below
    the main diagonal. Above the diagonal the offset i - j is negative and is
    clamped to 0, so every such entry equals V[b, 0].

    INPUT:
    V: torch tensor (batch_size x d)

    OUTPUT:
    T: (batch_size x d x d), returned as a transposed view of the gathered
       tensor (last two axes swapped)
    """
    _, d = V.shape
    positions = torch.arange(d, device=V.device)

    # offsets[r, c] = c - r via broadcasting; negative offsets fall back to 0.
    offsets = (positions.unsqueeze(0) - positions.unsqueeze(1)).clamp(min=0)

    # gathered[b, r, c] = V[b, offsets[r, c]]; swapping the last two axes
    # places V[b, i - j] at position (i, j).
    return V[:, offsets].transpose(1, 2)

def roll_columns_after_largest_norm_parallel(x, k):
    """
    Overwrite the column that follows each high-norm column with a rolled copy.

    For every head, the n - k columns with the largest norm (torch.norm over
    the row dimension) are selected. For each selected column idx other than
    the last one, column idx + 1 of the output becomes column idx of the
    input shifted down by one row circularly (the last row wraps to the top).
    All other columns are copied unchanged. Source columns are always read
    from the original x, so adjacent selected columns do not cascade.

    The whole operation is vectorized across heads and columns; no Python
    loop or per-element tensor comparison is performed.

    INPUT:
    x: torch tensor (num_heads x n x n); the last two dims are assumed equal
    k: int; n - k columns are selected per head

    OUTPUT:
    result: (num_heads x n x n) new tensor; x itself is not modified
    top_indices: (num_heads x (n - k)) selected column indices, sorted
                 ascending within each head
    """
    num_heads, n, _ = x.shape  # assume shape is [num_heads, n, n]

    # Norm of each column for all heads: [num_heads, n]
    column_norms = torch.norm(x, dim=1)

    # Indices of the n - k columns with the largest norm, sorted per head.
    _, top_indices = torch.topk(column_norms, k=n - k, dim=1)
    top_indices, _ = torch.sort(top_indices, dim=1)

    # Mark the destination columns (idx + 1). The extra slot at position n
    # absorbs idx == n - 1, which has no following column, and is sliced off.
    overwrite = torch.zeros(num_heads, n + 1, dtype=torch.bool, device=x.device)
    overwrite.scatter_(1, top_indices + 1, True)
    overwrite = overwrite[:, :n].unsqueeze(1)  # [num_heads, 1, n]

    # shifted[h, :, j] is column j - 1 of x rolled down by one row. Column 0
    # wraps around to column n - 1, but destination 0 is never marked.
    shifted = torch.roll(x, shifts=(1, 1), dims=(1, 2))

    result = torch.where(overwrite, shifted, x)
    return result, top_indices

# Fixed seed so every run of this script draws the same random tensors.
seed = 42
# Seed the global NumPy and PyTorch generators at import time.
np.random.seed(seed)
torch.manual_seed(seed)

def main():
    """Demo: run the column-roll helper, then build a Toeplitz-mixed tensor and print both."""
    head_count, size, k = 2, 5, 4

    # Random stack of square matrices, one per head.
    x = torch.randn(head_count, size, size)
    print("Original tensor shape:", x.shape)
    print("heads of original tensor:")
    print(x)

    rolled, selected = roll_columns_after_largest_norm_parallel(x, k)
    print("\nTop indices for each head:")
    print(selected)
    print("\nModified tensor shape:", rolled.shape)
    print("heads of modified tensor:")
    print(rolled)

    # Manual check, if needed: for each idx in selected[0] with idx < size - 1,
    # rolled[0, :, idx + 1] should be x[0, :, idx] rolled by one position.

    # Add a leading batch axis, then keep the first k - 1 columns as they are
    # and replace the rest with the leading columns of the Toeplitz matrix
    # built from column k - 1.
    batched = x.unsqueeze(0)
    seq_len = batched.shape[-1]
    heads = batched[0]
    kept_columns = heads[:, :, : k - 1]
    toeplitz_columns = optimized_pytorch_toeplitz(heads[:, :, k - 1])[:, :, : seq_len - k + 1]
    merged = torch.cat([kept_columns, toeplitz_columns], dim=-1).unsqueeze(0)
    print("Modified attn_weights shape:", merged.shape)
    print(f"attn_weights: \n{merged}")



# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
