import torch
from itertools import permutations

def create_equi_test(F):
    """Build a checker that accumulates ``F`` over column permutations of ``W``.

    Args:
        F: callable invoked as ``F(W_permuted, W)``; it must return a torch
            tensor. Presumably a discrepancy measure between the two matrices
            (TODO confirm with callers).

    Returns:
        The ``equi_test(W, symm_blocks)`` closure documented below.
    """
    def equi_test(W, symm_blocks):
        """Sum ``F`` over every permutation of each consecutive column block.

        The columns of ``W`` are split into consecutive blocks, in the order
        given by ``symm_blocks``. For each block, every permutation of that
        block's columns is applied (other columns unchanged), and
        ``F(W[:, permuted], W)`` is added to the block's running error. The
        identity permutation is included.

        Args:
            W: torch tensor indexed as ``W[:, cols]``; its second dimension
                holds the permuted columns.
            symm_blocks: iterable of blocks; ``block[0]`` is the number of
                consecutive columns in that block.

        Returns:
            list holding one accumulated, detached NumPy error per block,
            followed by one trailing ``0`` entry (so its length is
            ``len(symm_blocks) + 1``).

        Note:
            Every one of the ``block[0]!`` permutations is evaluated, so the
            cost grows factorially with block size.
        """
        start = 0
        errors = [0]
        for block in symm_blocks:
            size = block[0]
            for perm in permutations(range(size)):
                indices = torch.arange(W.shape[1])
                indices[start:start + size] = torch.tensor(perm, dtype=torch.long) + start
                # .cpu() is a no-op on CPU tensors and lets .numpy() work when
                # W (and thus F's output) lives on an accelerator.
                error = F(W[:, indices], W).detach().cpu().numpy()
                errors[-1] += error
            start += size
            # Open the accumulator for the next block; after the last block
            # this leaves the trailing 0 entry.
            errors.append(0)
        return errors
    return equi_test

def create_equi_grad(F):
    """Build a function returning per-block gradients of the permutation error.

    Args:
        F: callable invoked as ``F(W_permuted, W)``; it must return a tensor
            that torch.autograd.grad can differentiate. Presumably a
            scalar discrepancy between the two matrices (TODO confirm).

    Returns:
        The ``equi_grad(W, symm_blocks, params)`` closure documented below.
    """
    def equi_grad(W, symm_blocks, params):
        """Return one gradient per column block of the summed permutation error.

        The columns of ``W`` are split into consecutive blocks, in the order
        given by ``symm_blocks``. For each block, ``F(W[:, permuted], W)`` is
        summed over every permutation of that block's columns, identity
        included. The sum is then differentiated with respect to ``params``.

        Args:
            W: torch tensor indexed as ``W[:, cols]``; it must depend on
                ``params`` in the autograd graph.
            symm_blocks: iterable of blocks; ``block[0]`` is the number of
                consecutive columns in that block.
            params: the ``inputs`` argument of ``torch.autograd.grad`` (a
                tensor or a sequence of tensors). Only the gradient with
                respect to the first input is kept.

        Returns:
            list of detached, cloned gradient tensors, one per block.

        Note:
            Every one of the ``block[0]!`` permutations is evaluated, so the
            cost grows factorially with block size.
        """
        start = 0
        grads = []
        for block in symm_blocks:
            size = block[0]
            block_error = 0
            for perm in permutations(range(size)):
                indices = torch.arange(W.shape[1])
                indices[start:start + size] = torch.tensor(perm, dtype=torch.long) + start
                block_error = block_error + F(W[:, indices], W)
            start += size
            # retain_graph=True: every block's graph shares whatever computation
            # produced W from params. Without it, the first grad call frees
            # those buffers, so later blocks (and the caller's own backward)
            # raise "Trying to backward through the graph a second time".
            grad = torch.autograd.grad(block_error, params, retain_graph=True)[0]
            grads.append(grad.detach().clone())
        return grads
    return equi_grad