import torch
from pdb import set_trace as pds

def optimized_pytorch_toeplitz(V):
    """
    Creates the symmetric Toeplitz matrix for each row in V with one gather.

    Entry (i, j) of every output matrix is V[..., |i - j|], so each row of V
    becomes the first row (and, by symmetry, the first column) of its matrix.
    Only a d x d integer index table is built; the result is produced by a
    single advanced-indexing gather, with no padded intermediate copies of V.

    INPUT:
    V: torch tensor (batch_size x d). Extra leading dimensions (or none, i.e.
       a 1-D tensor of length d) are also accepted; only the last dimension
       is expanded into a d x d matrix.

    OUTPUT:
    T: (batch_size x d x d) with T[b, i, j] = V[b, |i - j|]
    """
    d = V.shape[-1]
    pos = torch.arange(d, device=V.device)
    # |i - j| index table. It is symmetric, so no transpose/permute of the
    # gathered result is needed (a signed offset clamped at 0 would instead
    # fill the entire upper triangle with V[..., 0]).
    offsets = (pos.unsqueeze(1) - pos.unsqueeze(0)).abs()
    return V[..., offsets]

def pytorch_toeplitz(V):
    '''
    Build one symmetric Toeplitz matrix per row of V.

    Output entry (b, i, j) equals V[b, |i - j|]: row b of V is the first row
    (and, by symmetry, the first column) of matrix b.

    INPUT:
    V: torch tensor (batch_size x d). NOTE(review): built on
       torch.nn.functional.unfold, which presumably needs a floating-point
       dtype — confirm before passing integer tensors.

    OUTPUT:
    T: (batch_size x d x d)

    EXAMPLE:
    V = torch.tensor([[1, 0.5, 0.1], [1.3, 0.9, -0.1]])
    T = pytorch_toeplitz(V)
    print(T.shape)   # torch.Size([2, 3, 3])
    print(T[0])      # [[1.0, 0.5, 0.1], [0.5, 1.0, 0.5], [0.1, 0.5, 1.0]]
    print(T[1])
    '''
    size = V.shape[1]
    # Treat each row as a single-channel image of height 1: (batch, 1, 1, d).
    row_img = V[:, None, None, :]
    # Prepend the mirrored tail so the padded row reads
    # v_{d-1} ... v_1 v_0 v_1 ... v_{d-1}  (length 2d - 1).
    mirror = row_img[..., 1:].flip(dims=(3,))
    padded = torch.cat((mirror, row_img), dim=3)
    # Length-d sliding windows: windows[b, k, l] = padded[b, 0, 0, l + k].
    windows = torch.nn.functional.unfold(padded, kernel_size=(1, size))
    # Reversing the window axis turns that into V[b, |k - l|].
    return windows.flip(dims=(2,))


def roll_columns_after_largest_norm(matrix, k):
    """
    Overwrite the column following each of the k // 2 largest-norm columns.

    The Euclidean norm of every column of ``matrix`` is computed and the
    ``k // 2`` columns with the largest norms are selected. For each selected
    column ``idx`` that is not the last column, column ``idx + 1`` of the
    result is replaced by a copy of the ORIGINAL column ``idx`` shifted down by
    one row (``torch.roll(..., shifts=1, dims=0)``). A selected last column has
    no successor and is ignored. Values are always read from the input, so the
    outcome is the same even when two selected columns are adjacent.

    INPUT:
    matrix: 2-D torch tensor (rows x n columns). It is not modified.
    k: int; k // 2 columns are selected (k < 2 selects none).

    OUTPUT:
    result: a new tensor with the same shape as ``matrix``.

    torch.topk raises if k // 2 exceeds the number of columns.
    """
    n = matrix.shape[1]  # number of columns

    column_norms = torch.norm(matrix, dim=0)
    _, top_indices = torch.topk(column_norms, k=k // 2)

    # The last column has no successor to overwrite, so drop it.
    sources = top_indices[top_indices < n - 1]

    result = matrix.clone()
    if sources.numel() > 0:
        # topk indices are distinct, hence the targets sources + 1 are distinct
        # too, and one vectorized assignment equals a column-by-column loop.
        result[:, sources + 1] = torch.roll(matrix[:, sources], shifts=1, dims=0)

    return result

# Seed torch's global RNG so torch.randn in main() is reproducible across runs.
# NOTE(review): this executes at import time, so it also reseeds the global
# RNG for any module that imports this file; consider moving it into main().
seed = 42
torch.manual_seed(seed)

def main():
    """Demo: print a random square matrix before and after roll_columns_after_largest_norm."""
    n = 5  # the demo matrix is n x n
    k = 4  # k // 2 = 2 largest-norm columns are selected
    matrix = torch.randn(n, n)
    print("Original matrix:")
    print(matrix)

    result = roll_columns_after_largest_norm(matrix, k)
    print("\nModified matrix:")
    print(result)




# Run the demo only when executed as a script, not when imported.
if __name__ == "__main__":
    main()

