import torch 

def count_no_global_repeated_layers(dataset, num_lay=5, num_mat=5):
    '''
    Count the materials in which no material is used by more than one layer.

    The first ``num_lay * num_mat`` columns of every row are viewed as a
    ``num_lay x num_mat`` matrix (one row per layer) and rounded to the
    nearest integer. A material counts as repeated when its column sums to
    2 or more.

    Returns:
        (count, mask): the number of rows satisfying the condition (int) and
        the boolean mask selecting them, usable to subset ``dataset``.
    '''
    width = num_lay * num_mat
    layer_matrices = torch.round(dataset[:, :width].reshape(-1, num_lay, num_mat))

    # For each material (column), how many layers use it -> (N, num_mat).
    usage_per_material = torch.sum(layer_matrices, dim=1)

    # The most-used material of each matrix -> (N,).
    highest_usage = torch.max(usage_per_material, dim=1)[0]

    no_repeats = highest_usage < 2
    return no_repeats.sum().item(), no_repeats


def count_no_global_repeated_layers_2(dataset, num_lay=5, num_mat=5):
    '''
    Count the materials in which every layer uses a different material.

    Each layer's one-hot block (length ``num_mat``) is collapsed to its
    material index via argmax; a row qualifies when all ``num_lay`` indices
    are distinct.

    Returns:
        (count, mask): the number of qualifying rows (int) and a boolean mask
        (on the same device as ``dataset``) selecting them.
    '''
    # Material index per layer -> (N, num_lay).
    indexed = torch.argmax(
        dataset[:, :num_lay * num_mat].reshape(-1, num_lay, num_mat), dim=2
    ).reshape(-1, num_lay)

    # Sorting puts equal materials next to each other, so a row is free of
    # repeats iff no two neighbours in the sorted row are equal. This is the
    # same as "number of unique values == num_lay" (the layer count, not the
    # material count), vectorised and device-agnostic.
    sorted_materials = indexed.sort(dim=1).values
    has_repeat = (sorted_materials[:, 1:] == sorted_materials[:, :-1]).any(dim=1)

    mask = ~has_repeat
    return mask.sum().item(), mask


def count_no_consecutive_repeated_layers(dataset, num_lay=5, num_mat=5):
    '''
    Count the materials in which no two adjacent layers share the same material.

    The first ``num_lay * num_mat`` columns of each row are one-hot blocks of
    length ``num_mat``, one per layer. For example, with num_lay=2 and
    num_mat=5 the row [0, 1, 0, 0, 0, 0, 0, 0, 0, 1] encodes the materials [1, 4].

    Returns:
        (count, mask): the number of qualifying rows (int) and the boolean mask
        selecting them.
    '''
    one_hot_blocks = dataset[:, :num_lay * num_mat].reshape(-1, num_lay, num_mat)

    # Collapse each layer's one-hot block to its material index -> (N, num_lay).
    materials = one_hot_blocks.argmax(dim=2).reshape(-1, num_lay)

    # Compare layer j with layer j + 1 for every j in one shot.
    left = materials[:, :num_lay - 1]
    right = materials[:, 1:]
    adjacent_repeat = (left == right).any(dim=1)

    no_adjacent_repeat = ~adjacent_repeat
    return no_adjacent_repeat.sum().item(), no_adjacent_repeat


def count_number_palindrome_materials(dataset, num_lay=5, num_mat=5, up_to=2):
    '''
    Count the materials whose outer ``up_to`` layers mirror each other.

    A row qualifies when, for every i < ``up_to``, layer i has the same material
    as layer ``num_lay - 1 - i``. With ``up_to <= 0`` every row qualifies.

    Returns:
        (count, mask): the number of qualifying rows (int) and the boolean mask
        selecting them.
    '''
    one_hot_blocks = dataset[:, :num_lay * num_mat].reshape(-1, num_lay, num_mat)
    materials = one_hot_blocks.argmax(dim=2).reshape(-1, num_lay)

    is_mirrored = torch.ones(materials.shape[0], dtype=torch.bool, device=materials.device)

    for left in range(up_to):
        right = num_lay - 1 - left
        is_mirrored = is_mirrored & (materials[:, left] == materials[:, right])

    return is_mirrored.sum().item(), is_mirrored


def count_number_metamat_use_all_materials(dataset, num_lay=5, num_mat=5):
    '''
    Count the metamaterials that use every available material (all ``num_mat``)
    at least once.

    Each layer's one-hot block is collapsed to its material index via argmax.
    Argmax indices always lie in [0, num_mat), so "uses every material" is
    equivalent to "the set of indices equals {0, ..., num_mat - 1}".

    Returns:
        (count, mask): the number of qualifying rows (int) and a boolean mask
        (on the same device as ``dataset``) selecting them. An empty dataset
        yields ``(0, <empty bool tensor>)``.
    '''
    # Material index per layer -> (N, num_lay).
    indexed = torch.argmax(
        dataset[:, :num_lay * num_mat].reshape(-1, num_lay, num_mat), dim=2
    ).reshape(-1, num_lay)

    # Broadcast (N, num_lay, 1) against (num_mat,) -> (N, num_lay, num_mat);
    # reducing over layers tells which materials appear at least once per row.
    # This avoids a per-row host round-trip and keeps the mask on-device.
    all_materials = torch.arange(num_mat, device=indexed.device)
    present = (indexed.unsqueeze(2) == all_materials).any(dim=1)

    row_contains_all = present.all(dim=1)
    return row_contains_all.sum().item(), row_contains_all


def __has_pattern__(row, pattern_len=2):
    for i in range(row.shape[0] - pattern_len):
        if row[i] != row[i + pattern_len]:
            return torch.tensor(False).to(row.device)
        
        for step in range(1, pattern_len):
            if row[i] == row[i + step]:
                return torch.tensor(False).to(row.device)
        
    return torch.tensor(True).to(row.device)
    
def count_number_hyperbolic_materials(dataset, num_lay=5, num_mat=5, pattern_len=2):
    '''
    Count the materials whose layers follow a hyperbolic (periodic) pattern of
    length ``pattern_len``.

    A row of material indices (argmax of each layer's one-hot block) matches
    when, for every start position i with ``i + pattern_len < num_lay``:
      * ``layer[i] == layer[i + pattern_len]`` (the pattern repeats), and
      * ``layer[i] != layer[i + k]`` for ``0 < k < pattern_len`` (materials
        inside one period are distinct).
    Rows with ``num_lay <= pattern_len`` match trivially.

    Returns:
        (count, mask): the number of matching rows (int) and a boolean mask
        (on the same device as ``dataset``) selecting them. An empty dataset
        yields ``(0, <empty bool tensor>)`` instead of failing.
    '''
    # Material index per layer -> (N, num_lay).
    indexed = torch.argmax(
        dataset[:, :num_lay * num_mat].reshape(-1, num_lay, num_mat), dim=2
    ).reshape(-1, num_lay)

    # Number of start positions to check. Clamp at 0 so that short rows give
    # empty slices (and thus trivially match) rather than negative-index slices.
    span = max(num_lay - pattern_len, 0)
    head = indexed[:, :span]

    # Periodicity: layer i equals layer i + pattern_len.
    mask = (head == indexed[:, pattern_len:pattern_len + span]).all(dim=1)

    # Distinctness within a period: layer i differs from layers i+1 .. i+pattern_len-1.
    for step in range(1, pattern_len):
        mask &= (head != indexed[:, step:step + span]).all(dim=1)

    return mask.sum().item(), mask