import torch
import matplotlib.pyplot as plt
import itertools

class AbstractTwoLayerModel(torch.nn.Module):
    """
    Abstract base class for two layer models.

    Subclasses provide the first-layer weight matrix through ``create_W``;
    the output is the sum of the activated hidden units.
    """
    def __init__(self, activation = 'relu') -> None:
        """
        Select the elementwise activation applied to the hidden layer.

        activation: 'relu' or 'erf'. Any other value raises ValueError.
        """
        super().__init__()
        if activation == 'relu':
            self.activation = torch.relu
        elif activation == 'erf':
            self.activation = torch.erf
        else:
            raise ValueError("Unsupported activation")

    def forward(self, x):
        """
        Compute ``sum(activation(W @ x), dim=0)`` with ``W = self.create_W()``.
        """
        W = self.create_W()
        hidden = self.activation(W@x)
        # The second layer is a plain sum over the hidden units (dim 0 of W@x).
        return torch.sum(hidden, dim=0)

    def create_W(self):
        """
        Return the first-layer weight matrix. Must be overridden by subclasses.
        """
        raise NotImplementedError

    def get_dtype(self):
        """
        Return the dtype of the first registered parameter, or None when the
        module has no parameters.
        """
        first_param = next(iter(self.parameters()), None)
        return None if first_param is None else first_param.dtype

    def visualize_weights(self, axes = None,perm = None, cmap = 'RdBu', vmax = None):
        """
        Draw the weight matrix with ``imshow`` on a colour scale symmetric
        around zero.

        axes: object with an ``imshow`` method; ``plt`` is used when None.
        perm: optional row index (list, tensor, array, ...) applied to W.
        cmap: colormap name forwarded to ``imshow``.
        vmax: colour-scale limit; defaults to the largest absolute weight.

        Returns whatever ``imshow`` returns.
        """
        # Detach and move to host memory so that .numpy() works for
        # parameters that require grad or live on an accelerator.
        W = self.create_W().detach().cpu()
        # Identity checks: `perm != None` is evaluated elementwise for
        # array-like inputs and cannot be used as a truth value.
        if perm is not None:
            W = W[perm,:]
        if vmax is None:
            vmax = torch.max(torch.abs(W)).item()
        target = plt if axes is None else axes
        return target.imshow(W.numpy(), cmap=cmap, vmax=vmax, vmin=-vmax)
    

class TwoLayerModel(AbstractTwoLayerModel):
    """
    Two layer model with an unconstrained, directly learnable weight matrix.
    """
    def __init__(self, n_input ,n_hidden, W = None, activation = 'relu') -> None:
        """
        n_input: number of input features (columns of W).
        n_hidden: number of hidden units (rows of W).
        W: optional initial weights of shape (n_hidden, n_input); drawn from
           a standard normal when omitted.
        activation: forwarded to AbstractTwoLayerModel.
        """
        super().__init__(activation = activation)
        # Identity check: `W == None` goes through the tensor's __eq__ and is
        # evaluated elementwise for array-like inputs.
        if W is None:
            W = torch.randn(n_hidden,n_input)
        assert W.shape == torch.Size([n_hidden,n_input]), \
            f"W has shape {tuple(W.shape)}, expected ({n_hidden}, {n_input})"
        self.W = torch.nn.Parameter(W)

    def create_W(self):
        """Return the learnable weight matrix itself."""
        return self.W


class EquiBlocks(AbstractTwoLayerModel):
    """
    Two layer model whose weight matrix is block structured.

    ``symm_blocks`` is a sequence of ``(size, count)`` pairs. Each pair
    contributes ``count`` square ``size x size`` blocks, stacked vertically
    over the same ``size`` input columns; every block has one value on its
    diagonal and one value off the diagonal (two parameters per block).
    The remaining ``invar_nodes`` rows are constant across the columns of
    each symmetric group and free on the columns after the last group.
    """
    def __init__(self, n_input, symm_blocks, invar_nodes, equi_params = None, invar_params = None, activation='relu', normed=False) -> None:
        """
        n_input: number of input features (columns of W).
        symm_blocks: sequence of (size, count) pairs, see the class docstring.
        invar_nodes: number of trailing rows built from ``invar_params``.
        equi_params: optional list with one (count, 2) tensor per pair.
        invar_params: optional (invar_nodes, k) tensor.
        activation: forwarded to AbstractTwoLayerModel.
        normed: forwarded to ``create_W_``.

        When either parameter set is missing, both are drawn by ``gen_model``.
        """
        super().__init__(activation = activation)
        # Identity checks: `== None` on a tensor goes through Tensor.__eq__.
        if equi_params is None or invar_params is None:
            total_nodes, equi_params, invar_params = self.gen_model(n_input, symm_blocks, invar_nodes)
        else:
            total_nodes = sum(block[0]*block[1] for block in symm_blocks)+invar_nodes

        self.equi_params = torch.nn.ParameterList([torch.nn.Parameter(param) for param in equi_params])
        self.invar_params = torch.nn.Parameter(invar_params)

        self.n_input = n_input
        self.symm_blocks = symm_blocks
        self.invar_nodes = invar_nodes
        self.total_nodes = total_nodes
        self.normed = normed

    def create_W(self):
        """Assemble the full weight matrix from the current parameters."""
        return self.create_W_(self.n_input, self.symm_blocks, self.invar_nodes, self.total_nodes, self.equi_params, self.invar_params, normed = self.normed)

    @staticmethod
    def gen_model(d, symm_blocks = None, invar_nodes = None , requires_grad = False):
        """
        Draw standard-normal parameters for the given layout.

        d: number of input features.
        symm_blocks: sequence of (size, count) pairs; None means no blocks.
        invar_nodes: number of invariant rows; None means zero.
        requires_grad: forwarded to ``torch.randn``.

        Returns ``(total_nodes, equi_params, invar_params)`` where
        ``equi_params`` holds one (count, 2) tensor per pair and
        ``invar_params`` has shape (invar_nodes, d - sum(size - 1)).
        """
        symm_blocks = [] if symm_blocks is None else symm_blocks
        invar_nodes = 0 if invar_nodes is None else invar_nodes
        equi_params = []
        invar_params_n = d
        total_nodes = 0
        for block in symm_blocks:
            equi_params.append(torch.randn(block[1], 2, requires_grad = requires_grad))
            # Each group of `size` columns shares a single invariant parameter.
            invar_params_n -= block[0]-1
            total_nodes += block[0]*block[1]
        total_nodes += invar_nodes
        invar_params = torch.randn(invar_nodes, invar_params_n, requires_grad = requires_grad)
        return total_nodes, equi_params, invar_params

    @staticmethod
    def create_W_(d, symm_blocks, invar_nodes, total_nodes, equi_params, invar_params, dtype = torch.float32, normed = False):
        """
        Build the (total_nodes, d) weight matrix.

        d: number of columns.
        symm_blocks: sequence of (size, count) pairs.
        invar_nodes: unused here; the invariant rows are whatever remains
            after the symmetric blocks.
        total_nodes: number of rows.
        equi_params: per pair, an iterable of 2-vectors (diag, off-diag).
        invar_params: (rows, k) tensor expanded to d columns.
        dtype: dtype of the assembled matrix; parameters are cast to it.
        normed: when True the diagonal value is divided by sqrt(size) and
            the off-diagonal value by sqrt(size**2 - size), i.e. by the
            Frobenius norms of the two patterns.
        """
        W = torch.zeros(total_nodes, d, dtype = dtype)
        # `expand` maps the compressed invariant parameters to d columns:
        # one shared entry per symmetric group, identity on the rest.
        expand = torch.zeros(invar_params.shape[1], d, dtype = dtype)
        col, row = 0, 0
        for i, block in enumerate(symm_blocks):
            size = block[0]
            # Build the patterns in the target dtype so float64 models do not
            # silently round through float32.
            diag_pattern = torch.eye(size, dtype = dtype)
            off_pattern = 1-diag_pattern
            if normed:
                diag_norm = size**0.5
                # A 1x1 block has no off-diagonal entries; avoid dividing by
                # sqrt(0), which would turn the block into NaN.
                off_norm = (size**2-size)**0.5 if size > 1 else 1.0
            for params in equi_params[i]:
                params = params.to(dtype)
                if normed:
                    equi_block = diag_pattern*(params[0]/diag_norm)+off_pattern*(params[1]/off_norm)
                else:
                    equi_block = diag_pattern*params[0]+off_pattern*params[1]
                W[row:row+size,col:col+size] = equi_block
                row += size
            expand[i,col:col+size] += 1
            col += size
        # Use the number of groups rather than the leaked loop index so an
        # empty `symm_blocks` works.
        expand[len(symm_blocks):,col:] = torch.eye(d-col, dtype = dtype)
        W[row:,:] = invar_params.to(dtype)@expand
        return W


class EquiTwoLayer(AbstractTwoLayerModel):
    """
    Two layer model with full-width structured rows.

    ``symm_blocks`` is a sequence of ``(size, count)`` pairs. For every pair
    there are ``count`` chunks of ``size`` rows. Each row of a chunk is
    constant on every column group (one value per symmetric group, one per
    remaining column), except that inside the chunk's own group the diagonal
    entries take a separate value. The remaining rows are built from
    ``invar_params`` the same way without the diagonal term.
    """
    def __init__(self, n_input, symm_blocks, invar_nodes, equi_params = None, invar_params = None, activation='relu', normed = False, dtype = torch.float32) -> None:
        """
        n_input: number of input features (columns of W).
        symm_blocks: sequence of (size, count) pairs, see the class docstring.
        invar_nodes: number of trailing rows built from ``invar_params``.
        equi_params: optional list with one (count, k + 1) tensor per pair,
            where k is the number of columns of ``invar_params``.
        invar_params: optional (invar_nodes, k) tensor.
        activation: forwarded to AbstractTwoLayerModel.
        normed: forwarded to ``create_W_``.
        dtype: dtype of the assembled weight matrix.

        When either parameter set is missing, both are drawn by ``gen_model``.
        """
        super().__init__(activation = activation)
        # Identity checks: `== None` on a tensor goes through Tensor.__eq__.
        if equi_params is None or invar_params is None:
            total_nodes, equi_params, invar_params = self.gen_model(n_input, symm_blocks, invar_nodes)
        else:
            total_nodes = sum(block[0]*block[1] for block in symm_blocks)+invar_nodes

        self.equi_params = torch.nn.ParameterList([torch.nn.Parameter(param) for param in equi_params])
        self.invar_params = torch.nn.Parameter(invar_params)

        self.n_input = n_input
        self.symm_blocks = symm_blocks
        self.invar_nodes = invar_nodes
        self.total_nodes = total_nodes
        self.normed = normed

        self.dtype = dtype

    def create_W(self):
        """Assemble the full weight matrix from the current parameters."""
        return self.create_W_(self.n_input, self.symm_blocks, self.invar_nodes, self.total_nodes, self.equi_params, self.invar_params, self.normed, self.dtype)

    @staticmethod
    def gen_model(d, symm_blocks = None, invar_nodes = None , requires_grad = False):
        """
        Draw standard-normal parameters for the given layout.

        d: number of input features.
        symm_blocks: sequence of (size, count) pairs; None means no blocks.
        invar_nodes: number of invariant rows; None means zero.
        requires_grad: forwarded to ``torch.randn``.

        Returns ``(total_nodes, equi_params, invar_params)`` where, with
        ``k = d - sum(size - 1)``, ``equi_params`` holds one (count, k + 1)
        tensor per pair and ``invar_params`` has shape (invar_nodes, k).
        """
        symm_blocks = [] if symm_blocks is None else symm_blocks
        invar_nodes = 0 if invar_nodes is None else invar_nodes
        equi_params = []
        invar_params_n = d - sum(block[0]-1 for block in symm_blocks)
        total_nodes = 0
        for block in symm_blocks:
            # k shared values plus one extra value for the block diagonal.
            equi_params.append(torch.randn(block[1], invar_params_n+1, requires_grad = requires_grad))
            total_nodes += block[0]*block[1]
        total_nodes += invar_nodes
        invar_params = torch.randn(invar_nodes, invar_params_n, requires_grad = requires_grad)
        return total_nodes, equi_params, invar_params

    @staticmethod
    def create_W_(d, symm_blocks, invar_nodes, total_nodes, equi_params, invar_params, normed=False, dtype = torch.float32):
        """
        Build the (total_nodes, d) weight matrix.

        d: number of columns.
        symm_blocks: sequence of (size, count) pairs.
        invar_nodes: unused here; the invariant rows are whatever remains
            after the symmetric chunks.
        total_nodes: number of rows.
        equi_params: per pair, an iterable of (k + 1)-vectors; the first k
            entries are the per-column-group values, the last one is the
            diagonal value.
        invar_params: (rows, k) tensor expanded to d columns.
        normed: when True the diagonal value is multiplied by sqrt(size - 1).
        dtype: dtype of the assembled matrix; parameters are cast to it.

        Note that ``normed`` and ``dtype`` come in the opposite order from
        ``EquiBlocks.create_W_``.
        """
        W = torch.zeros(total_nodes, d, dtype = dtype)
        # `expand` maps k compressed values to d columns: one shared entry
        # per symmetric group, identity on the remaining columns.
        expand = torch.zeros(invar_params.shape[1], d, dtype = dtype)
        col = 0
        for i, block in enumerate(symm_blocks):
            expand[i,col:col+block[0]] += 1
            col += block[0]
        # Use the number of groups rather than the leaked loop index so an
        # empty `symm_blocks` works.
        expand[len(symm_blocks):,col:] = torch.eye(d-col, dtype = dtype)

        col, row = 0, 0
        for i, block in enumerate(symm_blocks):
            size = block[0]
            # Built in the target dtype so float64 models do not round
            # through float32.
            diag_pattern = torch.eye(size, dtype = dtype)
            for params in equi_params[i]:
                params = params.to(dtype)
                const_parts = params[:-1]
                diag_value = params[-1]
                if normed:
                    diag_value = diag_value*(size-1.0)**0.5
                # Every row of the chunk starts as the same expanded vector.
                equi_chunk = (const_parts@expand)[None,:].expand(size,-1).clone()
                # Replace the group's shared value by `diag_value` on the
                # diagonal of the chunk's own block.
                equi_chunk[:,col:col+size] += diag_pattern*(diag_value-const_parts[i])
                W[row:row+size,:] = equi_chunk
                row += size
            col += size
        W[row:,:] = invar_params.to(dtype)@expand
        return W


class EquiRegSymm(AbstractTwoLayerModel):
    """
    Two layer model whose rows are all coordinate permutations of a small
    set of learnable template vectors.
    """
    def __init__(self, d, N_t_reps, stacks=1, activation='relu'):
        """
        d: length of each template vector.
        N_t_reps: number of templates.
        stacks: number of length-d vectors concatenated in each row.
        activation: forwarded to AbstractTwoLayerModel.
        """
        super().__init__(activation)
        self.d = d
        self.N_t_reps = N_t_reps
        # One (stacks, d) template per representative, standard-normal init.
        self.weights = torch.nn.Parameter(torch.randn(N_t_reps, stacks, d))

    def create_W(self):
        """Assemble the weight matrix from the current templates."""
        return self.create_W_(self.N_t_reps, self.d, self.weights)

    @staticmethod
    def create_W_(N_t_reps, d, params):
        """
        Stack every permutation of every template into one matrix.

        For each template in ``params`` and each permutation of ``range(d)``
        one row is produced: the permutation is applied to every stack of the
        template and the results are concatenated. ``N_t_reps`` is not used;
        the number of templates is taken from ``params``.
        """
        orderings = [torch.tensor(order) for order in itertools.permutations(range(d))]
        rows = [
            torch.cat([layer[ordering] for layer in template])
            for template in params
            for ordering in orderings
        ]
        return torch.stack(rows)





def proj_equi_form(W, symm_blocks = None, invar_nodes = None):
    """
    Project W onto the block structure used by ``EquiBlocks``.

    W: (rows, d) weight matrix.
    symm_blocks: sequence of (size, count) pairs; None means no blocks.
    invar_nodes: unused; the invariant rows are whatever remains after the
        symmetric blocks.

    Every ``size x size`` block is replaced by the mean of its diagonal on
    the diagonal and the mean of its off-diagonal entries elsewhere; entries
    of those rows outside the block become zero. The remaining rows are
    averaged over each symmetric column group and kept as they are on the
    columns after the last group.

    Returns a new tensor with the shape and dtype of W.
    """
    symm_blocks = [] if symm_blocks is None else symm_blocks
    d = W.shape[1]
    proj_W = torch.zeros_like(W)

    invar_params_n = d-sum(blocks[0]-1 for blocks in symm_blocks)
    expand = torch.zeros(invar_params_n, d, dtype = W.dtype, device = W.device)
    col, row = 0, 0
    for i, blocks in enumerate(symm_blocks):
        size = blocks[0]
        diag_mask = torch.eye(size, dtype = W.dtype, device = W.device)
        off_mask = 1-diag_mask
        for _ in range(blocks[1]):
            sub = W[row:row+size,col:col+size]
            projected = torch.sum(sub*diag_mask)*diag_mask/size
            # A 1x1 block has no off-diagonal entries; skipping the term
            # avoids the 0/0 that would otherwise produce NaN.
            if size > 1:
                projected = projected+torch.sum(sub*off_mask)*off_mask/(size**2-size)
            proj_W[row:row+size,col:col+size] += projected
            row += size
        expand[i,col:col+size] += 1
        col += size
    # Use the number of groups rather than the leaked loop index so an empty
    # `symm_blocks` works.
    expand[len(symm_blocks):,col:] = torch.eye(d-col, dtype = W.dtype, device = W.device)
    # Average over each column group, then broadcast the averages back.
    proj_W[row:,:] = W[row:,:]@expand.T/torch.sum(expand,dim=1)@expand
    return proj_W

def proj_equi_params(W, symm_blocks = None, invar_nodes = None, requires_grad = False):
    """
    Extract ``EquiBlocks``-style parameters from a weight matrix.

    W: (rows, d) weight matrix; a detached copy is used.
    symm_blocks: sequence of (size, count) pairs; None means no blocks.
    invar_nodes: unused; the invariant rows are whatever remains after the
        symmetric blocks.
    requires_grad: when True every returned tensor has requires_grad set.

    Returns ``(rows, equi_params, invar_params)``. ``equi_params`` holds one
    (count, 2) tensor per pair with the diagonal mean and the off-diagonal
    mean of each block (the latter is 0 for 1x1 blocks). ``invar_params``
    holds, for the remaining rows, the mean over each symmetric column group
    followed by the values on the columns after the last group. All returned
    tensors share the dtype of W.
    """
    symm_blocks = [] if symm_blocks is None else symm_blocks
    W = W.clone().detach()
    d = W.shape[1]

    invar_params_n = d-sum(blocks[0]-1 for blocks in symm_blocks)
    expand = torch.zeros(invar_params_n, d, dtype = W.dtype, device = W.device)
    col, row = 0, 0
    equi_params = []
    for i, blocks in enumerate(symm_blocks):
        size = blocks[0]
        diag_mask = torch.eye(size, dtype = W.dtype, device = W.device)
        off_mask = 1-diag_mask
        # Same dtype as W so float64 weights are not rounded to float32.
        block_params = torch.zeros(blocks[1], 2, dtype = W.dtype, device = W.device)
        for rep in range(blocks[1]):
            sub = W[row:row+size,col:col+size]
            block_params[rep,0] = torch.sum(sub*diag_mask)/size
            # A 1x1 block has no off-diagonal entries; leave the parameter
            # at 0 instead of computing 0/0.
            if size > 1:
                block_params[rep,1] = torch.sum(sub*off_mask)/(size**2-size)
            row += size
        if requires_grad:
            block_params.requires_grad_()
        equi_params.append(block_params)
        expand[i,col:col+size] += 1
        col += size
    # Use the number of groups rather than the leaked loop index so an empty
    # `symm_blocks` works.
    expand[len(symm_blocks):,col:] = torch.eye(d-col, dtype = W.dtype, device = W.device)
    invar_params = W[row:,:]@expand.T/torch.sum(expand,dim=1)
    if requires_grad:
        invar_params.requires_grad_()
    return W.shape[0], equi_params, invar_params

def proj_equi_form_full(W, symm_blocks = None, invar_nodes = None):
    """
    Project W onto the row structure used by ``EquiTwoLayer``.

    W: (rows, d) weight matrix.
    symm_blocks: sequence of (size, count) pairs; None means no blocks.
    invar_nodes: unused; the invariant rows are whatever remains after the
        symmetric chunks.

    Each chunk of ``size`` rows is projected onto (a) the patterns that are
    constant over the chunk's rows and over one column group, and (b) the
    zero-sum pattern ``size * I - 1`` on the chunk's own block, which
    separates the diagonal from the off-diagonal entries. The remaining rows
    are averaged over each symmetric column group and kept as they are on
    the columns after the last group.

    Returns a new tensor with the shape and dtype of W.
    """
    symm_blocks = [] if symm_blocks is None else symm_blocks
    d = W.shape[1]
    proj_W = torch.zeros_like(W)

    invar_params_n = d-sum(blocks[0]-1 for blocks in symm_blocks)
    expand = torch.zeros(invar_params_n, d, dtype = W.dtype, device = W.device)
    col = 0
    for i, block in enumerate(symm_blocks):
        expand[i,col:col+block[0]] += 1
        col += block[0]
    # Use the number of groups rather than the leaked loop index so an empty
    # `symm_blocks` works.
    expand[len(symm_blocks):,col:] = torch.eye(d-col, dtype = W.dtype, device = W.device)

    col, row = 0, 0
    for blocks in symm_blocks:
        size = blocks[0]
        # One basis element per column group, repeated over the chunk's rows
        # and scaled to unit Frobenius norm.
        expand_basis = expand[:,None,:].expand([-1,size,-1])
        norms = torch.sqrt(torch.sum(expand_basis,dim=(1,2)))[:,None,None]
        expand_basis = expand_basis/norms
        diag_dir = size*torch.eye(size, dtype = W.dtype, device = W.device) - 1
        diag_dir_sq = torch.sum(diag_dir*diag_dir)
        for _ in range(blocks[1]):
            chunk = W[row:row+size,:]
            # sum_i <basis_i, chunk> * basis_i
            proj_W[row:row+size,:] += torch.einsum('ijk,jk,iab -> ab',expand_basis, chunk,expand_basis)
            # For a 1x1 block `diag_dir` is zero, so the extra direction does
            # not exist; skipping it avoids a 0/0 that would produce NaN.
            if size > 1:
                sub = chunk[:,col:col+size]
                proj_W[row:row+size,col:col+size] += torch.sum(sub*diag_dir)*diag_dir/diag_dir_sq
            row += size
        col += size
    # Average over each column group, then broadcast the averages back.
    proj_W[row:,:] = W[row:,:]@expand.T/torch.sum(expand,dim=1)@expand
    return proj_W

def proj_equi_params_full(W, symm_blocks = None, invar_nodes = None, requires_grad = False):
    """
    Extract ``EquiTwoLayer``-style parameters from a weight matrix.

    W: (rows, d) weight matrix.
    symm_blocks: sequence of (size, count) pairs; None means no blocks.
    invar_nodes: unused; the invariant rows are whatever remains after the
        symmetric chunks.
    requires_grad: when True ``requires_grad_()`` is called on every
        returned tensor.

    Returns ``(rows, equi_params, invar_params)``. With k the number of
    column groups, ``equi_params`` holds one (count, k + 1) tensor per pair:
    the first k entries are the chunk's mean over each column group, except
    that the entry of the chunk's own group is the mean of the off-diagonal
    entries of its block; the last entry is the mean of the block diagonal.
    ``invar_params`` holds the per-group means of the remaining rows. All
    returned tensors share the dtype of W.
    """
    symm_blocks = [] if symm_blocks is None else symm_blocks
    d = W.shape[1]

    invar_params_n = d-sum(blocks[0]-1 for blocks in symm_blocks)
    expand = torch.zeros(invar_params_n, d, dtype = W.dtype, device = W.device)
    col = 0
    for i, block in enumerate(symm_blocks):
        expand[i,col:col+block[0]] += 1
        col += block[0]
    # Use the number of groups rather than the leaked loop index so an empty
    # `symm_blocks` works.
    expand[len(symm_blocks):,col:] = torch.eye(d-col, dtype = W.dtype, device = W.device)

    col, row = 0, 0
    equi_params = []

    for i, blocks in enumerate(symm_blocks):
        size = blocks[0]
        # Dividing by the number of entries turns the contraction below into
        # a mean over (chunk rows) x (column group).
        expand_basis = expand[:,None,:].expand([-1,size,-1])
        norms = torch.sum(expand_basis,dim=(1,2))[:,None,None]
        expand_basis = expand_basis/norms
        diag_mask = torch.eye(size, dtype = W.dtype, device = W.device)
        # Same dtype as W so float64 weights are not rounded to float32.
        block_params = torch.zeros(blocks[1], invar_params_n+1, dtype = W.dtype, device = W.device)
        for rep in range(blocks[1]):
            chunk = W[row:row+size,:]
            block_params[rep,:-1] = torch.einsum('ijk,jk -> i',expand_basis, chunk)
            block_params[rep,-1] = torch.sum(chunk[:,col:col+size]*diag_mask)/size
            # Turn the block mean into the off-diagonal mean:
            # size * mean = diag_mean + (size - 1) * off_mean.
            # A 1x1 block has no off-diagonal entries, so the entry is left
            # as the block mean instead of dividing by zero.
            if size > 1:
                block_params[rep,i] = (block_params[rep,i]*size-block_params[rep,-1])/(size-1)
            row += size
        # Match proj_equi_params: the flag applies to the block parameters
        # as well, not only to the invariant ones.
        if requires_grad:
            block_params.requires_grad_()
        equi_params.append(block_params)
        col += size

    invar_params = W[row:,:]@expand.T/torch.sum(expand,dim=1)
    if requires_grad:
        invar_params.requires_grad_()
    return W.shape[0], equi_params, invar_params

def gen_models_OG(n,k,d):
    """
    Draw two standard-normal matrices with d columns.

    Returns ``(W, V)``: W has shape (n, d) and requires grad, V has shape
    (k, d) and does not. W is drawn before V.
    """
    trainable = torch.randn(n, d, requires_grad=True)
    fixed = torch.randn(k, d)
    return trainable, fixed