import numpy as np
import torch
import torch.nn as nn


def get_res(hidden_dim, radius, connectivity):
    """Build a random sparse square recurrent matrix with a target spectral radius.

    Entries are drawn uniformly from [-0.5, 0.5), each entry is kept with
    probability ``connectivity`` (otherwise zeroed), and the matrix is then
    rescaled so that its largest eigenvalue magnitude equals ``radius``.

    Args:
        hidden_dim: number of rows/columns of the matrix.
        radius: target spectral radius after rescaling.
        connectivity: probability that an individual entry is kept non-zero.

    Returns:
        A ``(hidden_dim, hidden_dim)`` float64 ndarray. If the sampled matrix
        has spectral radius exactly 0 (e.g. every entry was masked out), it
        cannot be rescaled and is returned unscaled instead of producing NaNs
        from a division by zero.
    """
    w = np.random.rand(hidden_dim, hidden_dim) - 0.5
    mask = np.random.rand(hidden_dim, hidden_dim) < connectivity
    w *= mask
    # Only eigenvalues are needed; eigvals skips the unused eigenvector computation.
    max_lambda = np.max(np.abs(np.linalg.eigvals(w)))
    if max_lambda == 0:
        # Nothing to rescale; avoid radius / 0 -> inf and 0 * inf -> nan.
        return w
    return w * (radius / max_lambda)


class CuDeRes(nn.Module):
    """Cubic (3-D) reservoir that turns a grid input into a ridge-readout feature.

    Every cell ``(h, w, c)`` of the input grid drives a reservoir state that
    also receives the states of its predecessor cells along height, width and
    channel. Per batch element, a ridge regression from the three predecessor
    states (plus a bias) to the cell's input is solved, and the flattened
    readout weights are returned as the feature vector.
    """

    def __init__(self, input_dim, reservoir_size=50, radii=(0.9, 0.9, 0.9), regular=1):
        """
        Args:
            input_dim: size of the last axis of the input.
            reservoir_size: number of units in each reservoir state.
            radii: target spectral radii of the height/width/channel
                recurrent matrices (Wx, Wy, Wz).
            regular: ridge regularization strength.

        input: [batch_size, height, width, channels, input_dim]
        """
        super().__init__()
        self.input_dim = input_dim
        self.reservoir_size = reservoir_size
        self.radii = radii
        # Expected ~5 non-zero entries per row, but never below 10% density.
        self.connectivity = max(0.1, 5 / reservoir_size)
        # Ridge penalty for the (3 * reservoir_size + 1)-dim design matrix.
        # A non-persistent buffer follows .to()/.cuda() like the weights do,
        # without adding a key to state_dict.
        self.register_buffer(
            "regular",
            regular * torch.eye(reservoir_size * 3 + 1).unsqueeze(0),
            persistent=False,
        )
        # Fixed (untrained) input and recurrent weights.
        self.Win = nn.Parameter(torch.randn(self.input_dim, self.reservoir_size), requires_grad=False)
        self.Wx = nn.Parameter(torch.from_numpy(
            get_res(reservoir_size, self.radii[0], self.connectivity).astype(np.float32)), requires_grad=False)
        self.Wy = nn.Parameter(torch.from_numpy(
            get_res(reservoir_size, self.radii[1], self.connectivity).astype(np.float32)), requires_grad=False)
        self.Wz = nn.Parameter(torch.from_numpy(
            get_res(reservoir_size, self.radii[2], self.connectivity).astype(np.float32)), requires_grad=False)

    def compute_state(self, x, decay):
        """Run the reservoir over the grid and return every cell's state.

        Cells are visited channel-outermost, then width, then height, so the
        predecessor states (h-1, w, c), (h, w-1, c) and (h, w, c-1) are always
        computed before they are used; out-of-grid predecessors are zero.

        Args:
            x: tensor of shape [batch_size, height, width, channels, input_dim].
            decay: three scalars weighting the height, width and channel
                predecessor contributions.

        Returns:
            Tensor of shape [batch_size, height, width, channels, reservoir_size]
            on the same device and with the same dtype as ``x``.
        """
        batch_size, height, width, channels, _ = x.shape
        # Allocate on x's device/dtype so the module also works off-CPU.
        states = x.new_zeros((batch_size, height, width, channels, self.reservoir_size))
        initial_state = x.new_zeros((batch_size, self.reservoir_size))
        # The input drive does not depend on the recurrence; project all cells at once.
        drive = x @ self.Win

        for c in range(channels):
            for w in range(width):
                for h in range(height):
                    sx = initial_state if h == 0 else states[:, h - 1, w, c, :]
                    sy = initial_state if w == 0 else states[:, h, w - 1, c, :]
                    sz = initial_state if c == 0 else states[:, h, w, c - 1, :]
                    states[:, h, w, c, :] = torch.tanh(
                        drive[:, h, w, c, :]
                        + decay[0] * sx @ self.Wx
                        + decay[1] * sy @ self.Wy
                        + decay[2] * sz @ self.Wz
                    )
        return states

    def get_feature(self, states, x):
        """Fit a per-sample ridge readout and return its flattened weights.

        For every interior cell (h, w, c all >= 1), the design row is the
        concatenation of the predecessor states along height, width and
        channel plus a constant 1 (bias). The target is ``x`` at that cell.
        The method solves ``(H^T H + regular) W = H^T X`` for each batch element.

        Args:
            states: output of ``compute_state``,
                [batch_size, height, width, channels, reservoir_size].
            x: the input grid, [batch_size, height, width, channels, input_dim].

        Returns:
            Tensor of shape [batch_size, (3 * reservoir_size + 1) * input_dim].

        Raises:
            ValueError: if height, width or channels is smaller than 2, in which
                case there are no interior cells to regress on.
        """
        batch_size, height, width, channel, _ = states.shape
        if min(height, width, channel) < 2:
            raise ValueError(
                "get_feature needs height, width and channels >= 2, got "
                f"{(height, width, channel)}"
            )

        # Aligned slices: element [i, j, k] of each refers to interior cell (i+1, j+1, k+1).
        prev_h = states[:, :-1, 1:, 1:, :]   # state at (h-1, w, c)
        prev_w = states[:, 1:, :-1, 1:, :]   # state at (h, w-1, c)
        prev_c = states[:, 1:, 1:, :-1, :]   # state at (h, w, c-1)
        ones = states.new_ones(tuple(prev_h.shape[:-1]) + (1,))
        design = torch.cat([prev_h, prev_w, prev_c, ones], dim=-1)
        target = x[:, 1:, 1:, 1:, :]

        # One row per interior cell, ordered by channel, then height, then width.
        H = design.permute(0, 3, 1, 2, 4).reshape(batch_size, -1, design.shape[-1])
        X = target.permute(0, 3, 1, 2, 4).reshape(batch_size, -1, target.shape[-1])

        Ht = H.transpose(1, 2)
        HtH = torch.bmm(Ht, H)
        HtX = torch.bmm(Ht, X)
        # `regular` has a leading batch dim of 1 and broadcasts over the batch.
        W_out = torch.linalg.solve(HtH + self.regular.to(HtH), HtX)
        # NOTE(review): the bias row is scaled by sqrt(3); the reason is not
        # visible in this file — kept as-is.
        W_out[:, -1, :] *= pow(3.0, 0.5)
        return W_out.flatten(start_dim=1)

    def forward(self, x, decay=(1, 1, 1)):
        """Compute reservoir states for ``x`` and return the ridge-readout feature.

        Args:
            x: [batch_size, height, width, channels, input_dim].
            decay: weights for the height, width and channel predecessor states.

        Returns:
            [batch_size, (3 * reservoir_size + 1) * input_dim] feature tensor.
        """
        states = self.compute_state(x, decay)
        return self.get_feature(states, x)



