
"""
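Here are all 52 partitions of the set {1, 2, 3, 4, 5}, grouped by number of blocks.
Each partition corresponds, in the order listed, to one of the features 0-51
computed in PermEquiv32Linear.forward.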
Partition with 1 block:

{(1, 2, 3, 4, 5)}

Partitions with 2 blocks:

{(1), (2, 3, 4, 5)}
{(2), (1, 3, 4, 5)}
{(3), (1, 2, 4, 5)}
{(4), (1, 2, 3, 5)}
{(5), (1, 2, 3, 4)}
{(1, 2), (3, 4, 5)}
{(1, 3), (2, 4, 5)}
{(1, 4), (2, 3, 5)}
{(1, 5), (2, 3, 4)}
{(2, 3), (1, 4, 5)}
{(2, 4), (1, 3, 5)}
{(2, 5), (1, 3, 4)}
{(3, 4), (1, 2, 5)}
{(3, 5), (1, 2, 4)}
{(4, 5), (1, 2, 3)}

Partitions with 3 blocks:

{(1), (2), (3, 4, 5)}
{(1), (3), (2, 4, 5)}
{(1), (4), (2, 3, 5)}
{(1), (5), (2, 3, 4)}
{(2), (3), (1, 4, 5)}
{(2), (4), (1, 3, 5)}
{(2), (5), (1, 3, 4)}
{(3), (4), (1, 2, 5)}
{(3), (5), (1, 2, 4)}
{(4), (5), (1, 2, 3)}
{(1), (2, 3), (4, 5)}
{(1), (2, 4), (3, 5)}
{(1), (2, 5), (3, 4)}
{(2), (1, 3), (4, 5)}
{(2), (1, 4), (3, 5)}
{(2), (1, 5), (3, 4)}
{(3), (1, 2), (4, 5)}
{(3), (1, 4), (2, 5)}
{(3), (1, 5), (2, 4)}
{(4), (1, 2), (3, 5)}
{(4), (1, 3), (2, 5)}
{(4), (1, 5), (2, 3)}
{(5), (1, 2), (3, 4)}
{(5), (1, 3), (2, 4)}
{(5), (1, 4), (2, 3)}

Partitions with 4 blocks:

{(1), (2), (3), (4, 5)}
{(1), (2), (4), (3, 5)}
{(1), (2), (5), (3, 4)}
{(1), (3), (4), (2, 5)}
{(1), (3), (5), (2, 4)}
{(1), (4), (5), (2, 3)}
{(2), (3), (4), (1, 5)}
{(2), (3), (5), (1, 4)}
{(2), (4), (5), (1, 3)}
{(3), (4), (5), (1, 2)}

Partition with 5 blocks:

{(1), (2), (3), (4), (5)}
"""

import torch
import torch.nn as nn
from .Perm22Model import PermEquiv22Linear

class PermEquiv32Linear(nn.Module):
    """Learned linear map from an order-3 tensor to an order-2 tensor.

    The layer computes 52 fixed "basis features" of the input
    ``T[b, x, y, z]`` and returns their weighted sum, with one trainable
    weight per feature (``param_vector``).  Each feature is labelled by a
    set partition of {1, 2, 3, 4, 5}: elements 1 and 2 stand for the output
    indices ``i`` and ``j``, elements 3, 4, 5 for the three input axes.
    Indices in the same block are tied together, and input indices that are
    not tied to an output index are summed over.  (The class name suggests
    these span the permutation-equivariant linear maps; that property is not
    checked here.)

    Args:
        n: Size of every non-batch axis of the input and of the output.
    """

    # One (einsum spec, layout) pair per feature.  The einsum reduces T to a
    # tensor indexed by the output indices it depends on; the layout says how
    # that reduction is written into the (n, n) output slice:
    #   "full"       v[b, i, j] -> out[i, j]
    #   "row"        v[b, i]    -> out[i, j] for every j
    #   "col"        v[b, j]    -> out[i, j] for every i
    #   "const"      v[b]       -> out[i, j] for every i, j
    #   "diag"       v[b, i]    -> out[i, i] (off-diagonal stays 0)
    #   "diag_const" v[b]       -> out[i, i] for every i (off-diagonal stays 0)
    _FEATURES = (
        ("biii->bi", "diag"),         # 0:  {(1, 2, 3, 4, 5)}
        ("bjjj->bj", "col"),          # 1:  {(1), (2, 3, 4, 5)}
        ("biii->bi", "row"),          # 2:  {(2), (1, 3, 4, 5)}
        ("bjii->bi", "diag"),         # 3:  {(3), (1, 2, 4, 5)}
        ("biji->bi", "diag"),         # 4:  {(4), (1, 2, 3, 5)}
        ("biij->bi", "diag"),         # 5:  {(5), (1, 2, 3, 4)}
        ("biii->b", "diag_const"),    # 6:  {(1, 2), (3, 4, 5)}
        ("bijj->bij", "full"),        # 7:  {(1, 3), (2, 4, 5)}
        ("bjij->bij", "full"),        # 8:  {(1, 4), (2, 3, 5)}
        ("bjji->bij", "full"),        # 9:  {(1, 5), (2, 3, 4)}
        ("bjii->bij", "full"),        # 10: {(2, 3), (1, 4, 5)}
        ("biji->bij", "full"),        # 11: {(2, 4), (1, 3, 5)}
        ("biij->bij", "full"),        # 12: {(2, 5), (1, 3, 4)}
        ("bjji->bi", "diag"),         # 13: {(3, 4), (1, 2, 5)}
        ("bjij->bi", "diag"),         # 14: {(3, 5), (1, 2, 4)}
        ("bijj->bi", "diag"),         # 15: {(4, 5), (1, 2, 3)}
        ("bkkk->b", "const"),         # 16: {(1), (2), (3, 4, 5)}
        ("bkjj->bj", "col"),          # 17: {(1), (3), (2, 4, 5)}
        ("bjkj->bj", "col"),          # 18: {(1), (4), (2, 3, 5)}
        ("bjjk->bj", "col"),          # 19: {(1), (5), (2, 3, 4)}
        ("bkii->bi", "row"),          # 20: {(2), (3), (1, 4, 5)}
        ("biki->bi", "row"),          # 21: {(2), (4), (1, 3, 5)}
        ("biik->bi", "row"),          # 22: {(2), (5), (1, 3, 4)}
        ("bjki->bi", "diag"),         # 23: {(3), (4), (1, 2, 5)}
        ("bjik->bi", "diag"),         # 24: {(3), (5), (1, 2, 4)}
        ("bijk->bi", "diag"),         # 25: {(4), (5), (1, 2, 3)}
        ("bjkk->bj", "col"),          # 26: {(1), (2, 3), (4, 5)}
        ("bkjk->bj", "col"),          # 27: {(1), (2, 4), (3, 5)}
        ("bkkj->bj", "col"),          # 28: {(1), (2, 5), (3, 4)}
        ("bikk->bi", "row"),          # 29: {(2), (1, 3), (4, 5)}
        ("bkik->bi", "row"),          # 30: {(2), (1, 4), (3, 5)}
        ("bkki->bi", "row"),          # 31: {(2), (1, 5), (3, 4)}
        ("bjkk->b", "diag_const"),    # 32: {(3), (1, 2), (4, 5)}
        ("bkij->bij", "full"),        # 33: {(3), (1, 4), (2, 5)}
        ("bkji->bij", "full"),        # 34: {(3), (1, 5), (2, 4)}
        ("bjkj->b", "diag_const"),    # 35: {(4), (1, 2), (3, 5)}
        ("bikj->bij", "full"),        # 36: {(4), (1, 3), (2, 5)}
        ("bjki->bij", "full"),        # 37: {(4), (1, 5), (2, 3)}
        ("bjjk->b", "diag_const"),    # 38: {(5), (1, 2), (3, 4)}
        ("bijk->bij", "full"),        # 39: {(5), (1, 3), (2, 4)}
        ("bjik->bij", "full"),        # 40: {(5), (1, 4), (2, 3)}
        ("bkll->b", "const"),         # 41: {(1), (2), (3), (4, 5)}
        ("bklk->b", "const"),         # 42: {(1), (2), (4), (3, 5)}
        ("bkkl->b", "const"),         # 43: {(1), (2), (5), (3, 4)}
        ("bklj->bj", "col"),          # 44: {(1), (3), (4), (2, 5)}
        ("bkjl->bj", "col"),          # 45: {(1), (3), (5), (2, 4)}
        ("bjkl->bj", "col"),          # 46: {(1), (4), (5), (2, 3)}
        ("bkli->bi", "row"),          # 47: {(2), (3), (4), (1, 5)}
        ("bkil->bi", "row"),          # 48: {(2), (3), (5), (1, 4)}
        ("bikl->bi", "row"),          # 49: {(2), (4), (5), (1, 3)}
        ("bjkl->b", "diag_const"),    # 50: {(3), (4), (5), (1, 2)}
        ("bklm->b", "const"),         # 51: {(1), (2), (3), (4), (5)}
    )

    def __init__(self, n: int):
        super().__init__()

        # One trainable weight per basis feature (52 of them).
        self.param_vector = nn.Parameter(torch.randn(len(self._FEATURES)))
        self.n = n

    def forward(self, T: torch.Tensor) -> torch.Tensor:
        """Apply the layer.

        Args:
            T: Tensor of shape ``(batch, n, n, n)``.

        Returns:
            Tensor of shape ``(batch, n, n)``: the ``param_vector``-weighted
            sum of the 52 basis features of ``T``.

        Raises:
            ValueError: If ``T`` is not 4-dimensional or its non-batch axes
                are not all of size ``n``.
        """
        n = self.n
        if T.dim() != 4 or tuple(T.shape[1:]) != (n, n, n):
            raise ValueError(
                f"expected input of shape (batch, {n}, {n}, {n}), "
                f"got {tuple(T.shape)}"
            )

        num_batches = T.shape[0]
        weights = self.param_vector

        # Allocate on the input's device and in the weights' dtype so the
        # final contraction works after .to(device) / .double() etc.
        res = torch.zeros(
            num_batches, len(self._FEATURES), n, n,
            dtype=weights.dtype, device=T.device,
        )
        diag = torch.arange(n, device=T.device)

        for k, (spec, layout) in enumerate(self._FEATURES):
            reduced = torch.einsum(spec, T)
            if layout == "full":
                res[:, k] = reduced
            elif layout == "row":
                # Depends on i only: broadcast along the column axis.
                res[:, k] = reduced.unsqueeze(2)
            elif layout == "col":
                # Depends on j only: broadcast along the row axis.
                res[:, k] = reduced.unsqueeze(1)
            elif layout == "const":
                res[:, k] = reduced.view(num_batches, 1, 1)
            elif layout == "diag":
                res[:, k, diag, diag] = reduced
            else:  # "diag_const": same scalar on every diagonal entry
                res[:, k, diag, diag] = reduced.view(num_batches, 1).expand(num_batches, n)

        # Weighted sum over the feature axis.
        return torch.einsum('bknm,k->bnm', res, weights)


class PermEquiv32Model(nn.Module):
    """Three-stage network: one 3->2 layer followed by two 2->2 layers.

    Every linear stage is followed by a ``Tanh``.  The stages are held in
    ``self.layer`` (an ``nn.Sequential``) in application order.

    Args:
        n: Size passed to each of the three linear stages.
    """

    def __init__(self, n):
        super().__init__()

        self.n = n
        stages = [
            PermEquiv32Linear(n),
            nn.Tanh(),
            PermEquiv22Linear(n),
            nn.Tanh(),
            PermEquiv22Linear(n),
            nn.Tanh(),
        ]
        self.layer = nn.Sequential(*stages)

    def forward(self, T):
        """Run ``T`` through the stacked stages and return the result."""
        out = self.layer(T)
        return out


if __name__ == "__main__":
    # Smoke run: push one hand-written 3x3x3 sample through a single
    # (randomly initialised) 3->2 layer and print what comes out.
    n = 3
    model = PermEquiv32Linear(n)

    # A batch of one sample; three 3x3 slices.
    toy_values = [[
        [[1, 4, 0],
         [6, 0, 2],
         [0, -2, 0]],

        [[-3, 1, 3],
         [0, 0, 5],
         [2, 7, 0]],

        [[3, 2, 2],
         [-2, 0, 9],
         [0, 7, 8]],
    ]]
    T = torch.tensor(toy_values, dtype=torch.float32)

    print(T.shape)

    # Forward pass
    output = model(T)
    print(output, output.shape)