import numpy as np
import torch
from torch import nn
from utils.locally_connected import LocallyConnected


class LinearLag(nn.Module):
    """Lagged model whose weights are linearly interpolated between anchors.

    At construction time the two 3-D series ``X`` and ``Y`` are cut into
    sliding windows along dimension 1 and cached as ``batch_x`` /
    ``batch_y``.  :meth:`linear` then receives a stack of anchor weight
    matrices ``W`` (one anchor every ``stride`` window starts), interpolates
    a weight matrix for every window start, applies it to ``batch_x`` and
    feeds the result through the ``LocallyConnected`` layers in ``self.fc``.

    Args:
        X: 3-D tensor; dimension 1 is read as the series length ``t`` and
            dimension 2 as the feature width ``p``.
        Y: 3-D tensor windowed exactly like ``X``; used as the target in
            :meth:`log_mse_loss`.
        dims: layer sizes.  ``dims[0]`` is stored as ``d``; one
            ``LocallyConnected(d, dims[l + 1], dims[l + 2])`` layer is built
            for every ``l`` in ``range(len(dims) - 2)``.
        lag: number of lags; together with ``ins`` it fixes the row count
            that :meth:`linear` expects of each weight matrix.
        ins: if truthy the weight matrices have ``d * (lag + 1)`` rows,
            otherwise ``d * lag``.
        kernal_size: number of consecutive time steps in one window.
        stride: spacing, in window starts, between two anchor matrices.
        bias: forwarded to every ``LocallyConnected`` layer.
        device: device the ``LocallyConnected`` layers are moved to.

    Raises:
        ValueError: if ``stride`` or ``kernal_size`` is smaller than 1, or
            if ``kernal_size`` exceeds the series length.
    """

    def __init__(self, X, Y, dims, lag, ins, kernal_size, stride, bias, device):
        super().__init__()
        self.dims = dims
        self.d = dims[0]
        self.p = X.shape[2]
        self.t = X.shape[1]
        self.X = X
        self.Y = Y
        self.lag = lag
        self.ins = ins
        self.kernal_size = kernal_size
        self.stride = stride
        self.bias = bias
        self.device = device

        if stride < 1 or kernal_size < 1:
            raise ValueError(
                "stride and kernal_size must be >= 1, got stride=%r, kernal_size=%r"
                % (stride, kernal_size)
            )
        if kernal_size > self.t:
            # Without this check the windows are either empty or read past
            # the end of the series, and the failure only shows up later.
            raise ValueError(
                "kernal_size (%r) exceeds the series length (%r)"
                % (kernal_size, self.t)
            )

        # Window starts are the consecutive steps 0..last_start, where
        # last_start is the largest multiple of stride that still leaves
        # room for a full window.  That gives (number of anchors - 1) * stride
        # + 1 starts, matching the interpolated weights built in linear().
        # NOTE: the variant without interpolation ("exam without Taylor" in
        # the original notes) would instead use only the starts i * stride
        # with i * stride + kernal_size <= t.
        last_start = (self.t - kernal_size) // stride * stride
        n_starts = last_start + 1
        self.batch_x = self._sliding_windows(self.X, n_starts, kernal_size)
        self.batch_y = self._sliding_windows(self.Y, n_starts, kernal_size)

        l_layers = []
        for l in range(len(dims) - 2):
            l_layers.append(LocallyConnected(self.d, dims[l + 1], dims[l + 2], bias=self.bias))
        self.fc = nn.ModuleList(l_layers).to(self.device)
        # NOTE(review): this activation is not used anywhere in this class;
        # linear() applies torch.relu.  Kept so the attribute stays available.
        self.relu = nn.LeakyReLU()

    @staticmethod
    def _sliding_windows(series, n_starts, kernal_size):
        """Stack ``kernal_size`` shifted copies of ``series`` into windows.

        For a ``series`` of shape ``(n, t, p)`` the result has shape
        ``(kernal_size, n_starts, n, p)`` with
        ``result[i, s, b] == series[b, s + i]``.
        """
        # narrow() is bounds-checked, so a series that is too short raises
        # instead of being silently truncated.
        shifted = [series.narrow(1, offset, n_starts) for offset in range(kernal_size)]
        return torch.stack(shifted, dim=2).permute(2, 1, 0, 3)

    def linear(self, W):
        """Apply interpolated weights to the cached windows.

        Args:
            W: 3-D tensor of anchor weight matrices with shape
                ``(anchors, rows, d * dims[1])`` where ``rows`` is
                ``d * (lag + 1)`` if ``ins`` is truthy and ``d * lag``
                otherwise.

        Returns:
            Tensor of estimates; its first three dimensions are those of
            the product ``batch_x @ weights`` and the fourth is ``d``.
        """
        base = W[:-1, :, :]
        # Per-step increment between two neighbouring anchors.
        slope = (W[1:, :, :] - base) / self.stride
        pieces = [base] + [slope * step + base for step in range(1, self.stride)]

        rows = self.d * (self.lag + 1) if self.ins else self.d * self.lag
        cols = self.d * self.dims[1]
        # The leading size is given explicitly: with a single anchor the
        # stacked tensor is empty and view(-1, ...) cannot infer it.
        n_interp = (W.shape[0] - 1) * self.stride
        Ws = torch.stack(pieces, dim=1).view(n_interp, rows, cols)
        # The last anchor has no successor to interpolate towards.
        Ws = torch.cat([Ws, W[-1:]], dim=0)
        # NOTE: the variant without interpolation ("exam without Taylor" in
        # the original notes) would use Ws = W here.

        h = torch.matmul(self.batch_x, Ws).squeeze(dim=-1)
        a, b, c = h.shape[0], h.shape[1], h.shape[2]
        h = h.view(a * b * c, self.d, self.dims[1])
        for fc in self.fc:
            h = torch.relu(h)
            h = fc(h)
        # The trailing dimension must have size 1 at this point.
        h = h.view(a, b, c, self.d, 1)
        est_y = h.squeeze(dim=-1)
        return est_y

    def l2_reg(self, W):
        """Return sum(W**2) / W.shape[0] plus the squared weights of ``fc``."""
        reg = 0.
        reg += torch.sum(W ** 2) / W.shape[0]
        for fc in self.fc:
            reg += torch.sum(fc.weight ** 2)
        return reg

    def log_mse_loss(self, output):
        """Return ``0.5 * d * log(mean squared error)`` against ``batch_y``.

        ``d`` is ``output.shape[3]`` and the mean is taken over the product
        of the first three dimensions of ``output``.
        """
        n = output.shape[0] * output.shape[1] * output.shape[2]
        d = output.shape[3]
        loss = 0.5 * d * torch.log(1 / n * torch.sum((output - self.batch_y) ** 2))
        return loss