import torch
import torch.nn.functional as F

def generate_antisymmetric_matrices_batch(batch_size, n):
    """Draw a batch of random antisymmetric n x n matrices.

    Each matrix is the antisymmetric part of a standard-normal sample G,
    i.e. (G - G^T) / 2, so the result satisfies M[b, i, j] == -M[b, j, i].

    Args:
        batch_size: Number of matrices to generate.
        n: Number of rows/columns of each matrix.

    Returns:
        Tensor of shape (batch_size, n, n).
    """
    gaussian = torch.randn(batch_size, n, n)
    # Swap the two matrix axes to get the per-sample transpose.
    gaussian_t = gaussian.permute(0, 2, 1)
    return (gaussian - gaussian_t) / 2

def contract_over_pairs_batch_Tij_Type1(batch):
    r"""Index-weighted row contraction of each matrix, followed by a ReLU.

    For each tensor T in the batch we calculate
        f(T)_i = relu( \sum_j (i + j) * T_ij )

    The sum runs over every j, including j = i. For antisymmetric input
    T_ii = 0, so this coincides with the sum restricted to j != i.

    Args:
        batch: Tensor of shape (batch_size, n, n).

    Returns:
        Tensor of shape (batch_size, n) on the same device as ``batch``.
        A floating-point input dtype is preserved; any other dtype is
        converted to the default floating-point dtype first.
    """
    _, n, _ = batch.shape  # enforces a rank-3 input

    if batch.is_floating_point():
        values = batch
    else:
        values = batch.to(torch.get_default_dtype())

    # weights[i, j] = i + j, built once and broadcast over the batch axis
    # instead of looping over every (b, i, j) entry in Python.
    positions = torch.arange(n, device=values.device, dtype=values.dtype)
    weights = positions.unsqueeze(1) + positions.unsqueeze(0)

    contracted = (values * weights).sum(dim=2)
    return F.relu(contracted)

def contract_over_pairs_batch_Tij_Type2(batch):
    r"""Row-sum contraction of each matrix, followed by a ReLU.

    For each tensor T in the batch we calculate
        f(T)_i = relu( \sum_j T_ij )

    The sum runs over every j, including j = i. For antisymmetric input
    T_ii = 0, so this coincides with the sum restricted to j != i.

    Args:
        batch: Tensor whose dimension 2 is the contracted index j
            (shape (batch_size, n, n) for a batch of matrices).

    Returns:
        Tensor with dimension 2 summed out (shape (batch_size, n) for a
        batch of matrices), with negative entries clamped to zero.
    """
    row_sums = batch.sum(dim=2)
    return F.relu(row_sums)

def generate_equivariant_target_data(batch):
    """Return the sum of squared entries along the last dimension of ``batch``.

    For a batch of matrices this is the squared norm of each row.
    """
    squared = batch.pow(2)
    return torch.sum(squared, dim=-1)


def generate_toy_data(num_tuples, n, seed = None):
    """Build a toy dataset of antisymmetric matrices and their targets.

    The target is the Type2 contraction (row sums followed by a ReLU)
    computed by ``contract_over_pairs_batch_Tij_Type2``.

    Args:
        num_tuples: Number of matrices to generate.
        n: Size of each matrix (n x n).
        seed: If not None, passed to ``torch.manual_seed`` before sampling.
            This reseeds PyTorch's global random number generator.

    Returns:
        Tuple ``(matrices, targets)`` as produced by the two helpers above.
    """
    should_seed = seed is not None
    if should_seed:
        torch.manual_seed(seed)

    inputs = generate_antisymmetric_matrices_batch(num_tuples, n)
    targets = contract_over_pairs_batch_Tij_Type2(inputs)
    return inputs, targets

# 3D Antisymmetric Tensors
def generate_antisymmetric_matrices_batch_third(batch_size, n):
    """Generate a batch of random 3rd-order antisymmetric tensors of size n x n x n.

    One standard-normal value is drawn per sample for every index triple
    i < j < k (in lexicographic order) and written, with the sign of the
    permutation, to all six orderings of (i, j, k). Entries with a repeated
    index stay zero.

    Args:
        batch_size: Number of tensors to generate
        n: Size of each dimension of the tensor

    Returns:
        Tensor of shape (batch_size, n, n, n) that is antisymmetric in all indices.
        For n < 3 there is no valid triple and the all-zero tensor is returned
        without drawing any random numbers.
    """
    from itertools import combinations

    result = torch.zeros(batch_size, n, n, n)

    # All strictly increasing triples, in the same lexicographic order a
    # nested i < j < k loop would produce.
    triples = list(combinations(range(n), 3))
    if not triples:  # n too small for any independent component
        return result

    # One independent component per triple and per sample.
    values = torch.randn(batch_size, len(triples))

    i, j, k = torch.tensor(triples, dtype=torch.long).unbind(dim=1)

    # Advanced indexing writes all triples at once; the six permutations of
    # distinct triples never collide, so the assignments are independent.
    even_permutations = ((i, j, k), (j, k, i), (k, i, j))
    odd_permutations = ((i, k, j), (j, i, k), (k, j, i))
    for a, b, c in even_permutations:
        result[:, a, b, c] = values
    for a, b, c in odd_permutations:
        result[:, a, b, c] = -values

    return result

import torch

def contract_over_pairs_batch_Tijk(batch):
    r""" For each antisymmetric tensor T in the batch, we calculate
        f(T)_ij = \sum_{k: k \neq i, k \neq j} T_ijk

    Args:
        batch: Tensor of shape (batch_size, n, n, n) where each tensor is antisymmetric

    Returns:
        Tensor of shape (batch_size, n, n) containing the contraction result
    """
    _, n, _, _ = batch.shape  # enforces a rank-4 input

    # Build the index on the batch's device so the mask can be contracted
    # with it directly.
    index = torch.arange(n, device=batch.device)
    k_axis = index.view(1, 1, n)

    # keep[i, j, k] is True exactly when k differs from both i and j; the
    # two comparisons broadcast to shape (n, n, n).
    k_neq_i = index.view(n, 1, 1) != k_axis
    k_neq_j = index.view(1, n, 1) != k_axis
    keep = (k_neq_i & k_neq_j).to(batch.dtype)

    # 'bijk,ijk->bij': multiply batch[b,i,j,k] by keep[i,j,k] and sum over k.
    return torch.einsum('bijk,ijk->bij', batch, keep)

def generate_toy_data_3D(num_tuples, n, seed = None):
    r"""Build a toy dataset of 3rd-order antisymmetric tensors and their targets.

    The target is tanh applied elementwise to the contraction
        f(T)_ij = \sum_{k: k \neq i, k \neq j} T_ijk
    computed by ``contract_over_pairs_batch_Tijk``.

    Args:
        num_tuples: Number of tensors to generate.
        n: Size of each dimension of the tensors (n x n x n).
        seed: If not None, passed to ``torch.manual_seed`` before sampling.
            This reseeds PyTorch's global random number generator.

    Returns:
        Tuple ``(tensors, targets)`` as produced by the two helpers above.
    """
    if seed is not None:
        torch.manual_seed(seed)
    antisymm_tensors = generate_antisymmetric_matrices_batch_third(num_tuples, n)
    contracted = contract_over_pairs_batch_Tijk(antisymm_tensors)
    # torch.tanh rather than the legacy F.tanh alias, which older PyTorch
    # releases flag with a deprecation warning; the values are the same.
    targets = torch.tanh(contracted)
    return antisymm_tensors, targets

import torch

def batch_pfaffian_abs(batch, eps=1e-12):
    """
    Compute the absolute value of the Pfaffian for a batch of antisymmetric matrices.

    Uses the identity Pf(A)^2 = det(A) for antisymmetric A, so the result is
    sqrt(|det(A)| + eps).

    Args:
        batch: Tensor of shape (batch_size, n, n), antisymmetric matrices
        eps: Non-negative offset added under the square root. With the
            default, a matrix with zero determinant yields sqrt(eps) rather
            than exactly 0; pass ``eps=0.0`` for the unregularized value.

    Returns:
        Tensor of shape (batch_size,) with absolute Pfaffian values
    """
    dets = torch.linalg.det(batch)
    # abs() guards against a slightly negative determinant from round-off
    # before the square root is taken.
    return torch.sqrt(torch.abs(dets) + eps)

def generate_toy_data_pfaff(num_tuples, n, seed = None):
    """Build a toy dataset of antisymmetric matrices and their |Pfaffian| targets.

    The target for each matrix is the value returned by ``batch_pfaffian_abs``.

    Args:
        num_tuples: Number of matrices to generate.
        n: Size of each matrix (n x n).
        seed: If not None, passed to ``torch.manual_seed`` before sampling.
            This reseeds PyTorch's global random number generator.

    Returns:
        Tuple ``(matrices, abs_pfaffians)`` as produced by the two helpers.
    """
    should_seed = seed is not None
    if should_seed:
        torch.manual_seed(seed)

    inputs = generate_antisymmetric_matrices_batch(num_tuples, n)
    targets = batch_pfaffian_abs(inputs)
    return inputs, targets


if __name__ == "__main__":
    # Demo 1: five 3 x 3 antisymmetric matrices with their Type2 targets.
    matrices, matrix_targets = generate_toy_data(5, 3)
    print(matrices, matrix_targets)
    for tensor in (matrices, matrix_targets):
        print(tensor.shape)

    # Demo 2: a single 4 x 4 x 4 antisymmetric tensor with its contraction target.
    tensors, tensor_targets = generate_toy_data_3D(1, 4)
    for tensor in (tensors, tensor_targets):
        print(tensor)
    for tensor in (tensors, tensor_targets):
        print(tensor.shape)
