import numpy as np
import torch
from torch import nn
from utils.locally_connected import LocallyConnected


class LinearODE(nn.Module):
    """Time-varying dynamics model rolled out step by step over short windows.

    The model receives trajectories ``X`` and targets ``Y`` (both indexed as
    ``[sample, time, feature]``), pre-computes the window start states
    (``batch_x``) and the corresponding target windows (``batch_y``), and
    exposes:

    * ``odeint`` -- interpolates a coarse sequence of weight matrices ``W``
      linearly in time and repeatedly applies ``y <- y + f(y @ W_t)``.
    * ``l2_reg`` -- squared-L2 penalty on ``W`` and the layer weights.
    * ``log_mse_loss`` -- half mean squared error against ``batch_y``.

    Args:
        X: Input trajectories; ``X.shape[0]`` is stored as ``n`` and
            ``X.shape[1]`` as ``t``. The last axis is assumed to have size
            ``dims[0]`` (required by the reshapes in ``odeint``).
        Y: Target trajectories, indexed the same way as ``X``.
        dims: Layer sizes; ``dims[0]`` is the number of variables ``d``.
            One ``LocallyConnected`` layer is created for every consecutive
            pair ``(dims[l + 1], dims[l + 2])``.
        kernal_size: Number of time steps in each rollout window (the
            spelling is part of the public interface).
        stride: Number of time steps between two consecutive rows of ``W``.
        bias: Forwarded to every ``LocallyConnected`` layer.
        device: Device the ``LocallyConnected`` layers are moved to. ``X``
            and ``Y`` are used as given and are not moved.
    """

    def __init__(self, X, Y, dims, kernal_size, stride, bias, device):
        super().__init__()
        self.dims = dims
        self.d = dims[0]
        self.n = X.shape[0]
        self.t = X.shape[1]
        self.X = X
        self.Y = Y
        self.kernal_size = kernal_size
        self.stride = stride
        self.bias = bias
        self.device = device

        # Every time index from 0 up to the last stride-aligned position that
        # still leaves room for a full window is used as a window start.
        last_start = int((self.t - kernal_size) / stride) * stride
        num_starts = int(max(0, min(self.t, last_start + 1)))
        initial_index = torch.arange(num_starts, dtype=torch.int64)

        # batch_x: (1, num_starts, n, features) -- state at each window start.
        self.batch_x = self.X[:, initial_index, None, :].permute(2, 1, 0, 3)
        # batch_y: (kernal_size, num_starts, n, features) -- target windows,
        # where slice i holds the targets i steps after each window start.
        self.batch_y = torch.stack(
            [self.Y[:, initial_index + i, :] for i in range(self.kernal_size)],
            dim=2,
        ).permute(2, 1, 0, 3)
        self.elu = nn.ELU(inplace=True)

        layers = [
            LocallyConnected(self.d, dims[l + 1], dims[l + 2], bias=self.bias)
            for l in range(len(dims) - 2)
        ]
        self.fc = nn.ModuleList(layers).to(self.device)

    def odeint(self, W):
        """Roll the model forward for ``kernal_size`` steps from every window start.

        Args:
            W: Weight sequence that is viewed as ``(-1, d, d * dims[1])``
                after interpolation. Consecutive rows are ``stride`` time
                steps apart, so for the batched matmul to line up the
                interpolated length ``(W.shape[0] - 1) * stride + 1`` has to
                match the number of window starts (or be 1, in which case it
                broadcasts).

        Returns:
            Tensor of shape ``(kernal_size, num_starts, n, d)`` holding the
            start states followed by the predicted states. When ``d == 1``
            the trailing axis is squeezed away and the result is 3-D.
        """
        # Linearly interpolate between consecutive rows of W so that there is
        # one weight matrix per time step instead of one per stride.
        left = W[:-1, :, :]
        slope = (W[1:, :, :] - left) / self.stride
        segments = [left]
        for j in range(1, self.stride):
            segments.append(slope * j + left)
        Ws = torch.stack(segments, dim=1).view(-1, self.d, self.d * self.dims[1])
        # The final row of W has no right neighbour; append it unchanged.
        Ws = torch.cat([Ws, W[-1:]], dim=0)

        est_y = [self.batch_x]
        for _ in range(1, self.kernal_size):
            h = torch.matmul(est_y[-1], Ws)
            h = h.view(h.shape[0] * h.shape[1] * h.shape[2], self.d, self.dims[1])
            # NOTE(review): presumably each LocallyConnected layer maps
            # (batch, d, m_in) -> (batch, d, m_out) with the last m_out == 1;
            # confirm against utils.locally_connected.
            for fc in self.fc:
                h = self.elu(h)
                h = fc(h)
            est_delta_y = h.reshape(1, -1, self.n, self.d)
            # Additive update: next state = current state + predicted change.
            est_y.append(est_y[-1] + est_delta_y)
        return torch.cat(est_y, dim=0).squeeze(dim=-1)

    def l2_reg(self, W):
        """Return the squared-L2 penalty on ``W`` and all layer weights.

        The contribution of ``W`` is divided by ``W.shape[0]``; the weights
        of the layers in ``self.fc`` are summed without normalisation.
        """
        reg = torch.sum(W ** 2) / W.shape[0]
        for fc in self.fc:
            reg = reg + torch.sum(fc.weight ** 2)
        return reg

    def log_mse_loss(self, output):
        """Return half the mean squared error between ``output`` and ``batch_y``.

        Despite the name, no logarithm is applied. The squared error is
        summed over all elements and divided by the product of the first
        three dimensions of ``output``.

        Args:
            output: Prediction as returned by ``odeint``. A 3-D tensor is
                accepted when the targets have a trailing axis of size 1,
                because ``odeint`` squeezes that axis in this case.

        Raises:
            IndexError: If ``output`` has fewer than four dimensions and
                cannot be aligned with ``batch_y`` as described above.
        """
        target = self.batch_y
        # odeint drops the feature axis when d == 1; restore it so that the
        # subtraction below is element-wise instead of a broadcast mismatch.
        if output.dim() == 3 and target.shape[-1] == 1:
            output = output.unsqueeze(-1)
        if output.dim() < 4:
            raise IndexError(
                "log_mse_loss expects a 4-D output, got %d dimensions" % output.dim()
            )
        n = output.shape[0] * output.shape[1] * output.shape[2]
        return 0.5 / n * torch.sum((output - target) ** 2)