import torch

class LRULayer(torch.nn.Module):
    """Single Linear Recurrent Unit layer with a diagonal complex recurrence.

    For every position ``k`` of the input sequence it computes

        x_k = λ ⊙ x_{k-1} + ɣ ⊙ (B u_k),   with x_{-1} = 0
        y_k = Re(C x_k) + D u_k

    where λ is the diagonal of exp(-𝜈 + i𝜃) and ɣ = sqrt(1 - |λ|²).
    The recurrence is evaluated in parallel as a causal convolution with
    kernel λ^(k-j).

    The comments refer to "the paper". This is presumably Orvieto et al.
    (2023), "Resurrecting Recurrent Neural Networks for Long Sequences";
    TODO confirm.

    Args:
        input_size: size of the last dimension of the input ``u``.
        hidden_size: size of the recurrent state and of the output ``y``.
        device: stored as ``self.device`` for backward compatibility.
            ``forward`` allocates on the input's own device instead.
        r_min, r_max: range of the initial eigenvalue magnitudes |λ|.
    """

    def __init__(self, input_size, hidden_size, device, r_min=0, r_max=1):
        super().__init__()
        self.hidden_size = hidden_size
        self.input_size = input_size
        self.device = device

        # Random variables
        u_1 = torch.rand((hidden_size, hidden_size))
        u_2 = torch.rand((hidden_size, hidden_size))
        # 𝜈 defined as in the paper, so that |λ| = exp(-𝜈) lies in
        # [r_min, r_max) at initialisation. Only the diagonal is used (see
        # update_weights). The square shape is kept so that existing
        # state_dicts still load.
        self.nu = torch.nn.Parameter(
            -0.5 * torch.log(u_1 * (r_max**2 - r_min**2) + r_min**2),
            requires_grad=True)
        # 𝜃 (eigenvalue phase) defined as in the paper
        self.theta = torch.nn.Parameter(2 * 3.14159 * u_2, requires_grad=True)
        # B, C, D initialised as in paper
        self.B_real = torch.nn.Parameter(
            torch.randn((hidden_size, input_size)), requires_grad=True)
        self.B_im = torch.nn.Parameter(
            torch.randn((hidden_size, input_size)), requires_grad=True)
        torch.nn.init.xavier_normal_(self.B_real)
        torch.nn.init.xavier_normal_(self.B_im)
        self.C_real = torch.nn.Parameter(
            torch.randn((hidden_size, hidden_size)), requires_grad=True)
        self.C_im = torch.nn.Parameter(
            torch.randn((hidden_size, hidden_size)), requires_grad=True)
        torch.nn.init.xavier_normal_(self.C_real)
        torch.nn.init.xavier_normal_(self.C_im)

        self.D = torch.nn.Parameter(
            torch.randn((hidden_size, input_size)), requires_grad=True)
        torch.nn.init.xavier_normal_(self.D)
        self.update_weights()

    def forward(self, u):
        """Run the layer over a whole sequence in parallel.

        Args:
            u: real tensor of shape ``(..., seq_len, input_size)``. The usual
                batched shape is ``(batch, seq_len, input_size)``.

        Returns:
            A tuple ``(y, x)``:
                y: real output of shape ``(..., seq_len, hidden_size)``.
                x: complex state sequence of the same shape.
        """
        # TODO: This should really be done through torch parameterisation
        self.update_weights()

        seq_len = u.shape[-2]

        # Eigenvalues of Λ
        eigs = torch.diagonal(self.Lambda)
        # ɣ = sqrt(1 - |λ|²): real per-channel input normalisation. Use |λ|²,
        # not the complex square λ². NOTE: becomes NaN if training drives
        # some 𝜈_jj below 0 (|λ| > 1).
        gamma = torch.sqrt(1 - eigs.abs() ** 2)

        # u as a complex tensor, matching u's own dtype and device
        u_complex = torch.complex(u, torch.zeros_like(u))

        # B u_k for every position, scaled by ɣ
        Bu = torch.matmul(self.B, u_complex.unsqueeze(-1)).squeeze(-1)
        gamma_Bu = gamma * Bu

        # Λ is diagonal: lambda_powers[p] = λ**p for p = 0 .. seq_len-1
        lambda_powers = torch.stack(
            [torch.pow(eigs, p) for p in range(seq_len)], dim=0)

        # Causal kernel: kernel[k, j] = λ**(k-j) when j <= k, else 0
        steps = torch.arange(seq_len, device=eigs.device)
        lag = steps.unsqueeze(1) - steps.unsqueeze(0)
        causal = (lag >= 0).to(gamma.dtype).unsqueeze(-1)
        kernel = lambda_powers[lag.clamp(min=0)] * causal

        # x_k = sum_{j<=k} λ**(k-j) ɣ B u_j. einsum avoids materialising a
        # (batch, seq_len, seq_len, hidden) intermediate.
        x = torch.einsum("kjh,...jh->...kh", kernel, gamma_Bu)

        # Multiply x by C and take real part
        C_term = torch.matmul(self.C, x.unsqueeze(-1)).squeeze(-1).real
        # Multiply D by u
        Dmult = torch.matmul(self.D, u.unsqueeze(-1)).squeeze(-1)
        y = C_term + Dmult
        return y, x

    def update_weights(self):
        """Rebuild the complex matrices Λ, B and C from the real parameters.

        Λ is the diagonal of exp(-𝜈 + i𝜃), so each eigenvalue has magnitude
        exp(-𝜈_jj). It lies inside the unit circle while 𝜈_jj > 0.
        """
        # 𝜈 is negated: exp(+𝜈) would place every initial eigenvalue
        # outside the unit circle and make the recurrence diverge.
        self.complex = torch.complex(-self.nu, self.theta)
        self.Lambda = torch.diag(torch.exp(torch.diagonal(self.complex)))
        self.B = torch.complex(self.B_real, self.B_im)
        self.C = torch.complex(self.C_real, self.C_im)


class LRU(torch.nn.Module):
    """Stack of LRULayer modules with a tanh between consecutive layers.

    The first layer maps ``embed_dim`` features to ``hidden_size``. Each of
    the remaining ``num_layers - 1`` layers maps ``hidden_size`` to
    ``hidden_size``. A ``num_layers`` below 1 still yields a single layer.

    Args:
        embed_dim: feature size of the input to the first layer.
        hidden_size: state/output size of every layer.
        num_layers: total number of stacked layers.
        device, r_min, r_max: forwarded unchanged to every LRULayer.
    """

    def __init__(self, embed_dim, hidden_size, num_layers, device, r_min=0,
        r_max=1):
        super().__init__()
        head = LRULayer(embed_dim, hidden_size, device, r_min, r_max)
        tail = [
            LRULayer(hidden_size, hidden_size, device, r_min, r_max)
            for _ in range(num_layers - 1)
        ]
        self.layers = torch.nn.ModuleList([head, *tail])

    def forward(self, u):
        """Feed ``u`` through every layer in order.

        Returns:
            ``(output, hidden)``, both taken from the last layer. ``output``
            is that layer's ``y`` (no tanh applied) and ``hidden`` its
            state sequence.
        """
        final_index = len(self.layers) - 1
        out = u
        for position, block in enumerate(self.layers):
            out, hidden = block(out)
            # Non-linearity only between layers, never after the last one
            if position != final_index:
                out = torch.tanh(out)
        return out, hidden
