import numpy as np
import torch
from torch import nn
from utils.locally_connected import LocallyConnected
# Import-time side effect: make float64 the default floating-point dtype, so
# tensors and module parameters created afterwards without an explicit dtype
# are double precision.
torch.set_default_dtype(torch.float64)

class WindowsEncode(nn.Module):
    """Map a windowed multivariate series to one weight tensor per window.

    The network is a 2-D convolution sliding along the time axis followed by
    two ``LocallyConnected`` layers. ``h_func`` and ``l1_reg`` are penalty
    terms computed on the produced weights.

    Args:
        dims: Sequence whose first entry is the number of variables ``d``
            (used as the convolution's input channels and as the size of the
            identity matrix) and whose second entry is the hidden width.
        sem_type: Model variant; ``'ode'`` makes the last layer ``d`` wide.
        n: Height of the convolution kernel (the kernel spans ``n`` rows).
        lag: Number of lags; scales the width of the last layer when
            ``sem_type`` is not ``'ode'``.
        ins: When truthy (and ``sem_type`` is not ``'ode'``) the last layer
            is ``d * (lag + 1)`` wide instead of ``d * lag`` -- presumably
            to add instantaneous effects; confirm against the caller.
        kernel_size: Width of the convolution kernel along the time axis.
        stride: Stride passed to the convolution.
        bias: Whether the convolution has a bias term.
        device: Device on which the layers and the identity are created.
    """

    def __init__(self, dims, sem_type, n, lag, ins, kernel_size, stride, bias, device):
        super().__init__()
        self.d = dims[0]
        self.dims = dims
        self.sem_type = sem_type
        self.n = n
        self.lag = lag
        self.ins = ins
        self.kernel_size = kernel_size
        self.stride = stride
        self.bias = bias
        self.device = device
        # Identity used by h_func; kept as a plain attribute (not a buffer).
        self.I = torch.eye(self.d).to(self.device)

        hidden = dims[1]
        self.conv2d = nn.Conv2d(
            in_channels=self.d,
            out_channels=self.d * hidden,
            bias=self.bias,
            kernel_size=(self.n, self.kernel_size),
            stride=self.stride,
            padding=0,
        ).to(self.device)

        # Only the output width of the second layer depends on the variant;
        # 'ode' takes precedence over `ins`.
        if self.sem_type == 'ode':
            out_features = self.d
        elif self.ins:
            out_features = self.d * (self.lag + 1)
        else:
            out_features = self.d * self.lag

        # Layers are built in the order conv -> fc1 -> fc2 so any random
        # initialisation they perform consumes the RNG in a fixed order.
        self.w_fc1 = LocallyConnected(self.d, hidden, hidden * self.d, bias=True).to(self.device)
        self.w_fc2 = LocallyConnected(self.d * hidden, self.d, out_features, bias=True).to(self.device)

        # Start every layer from zero weights (biases keep their own init).
        # Original author's note: this is for the dynamic setting and can be
        # removed for the static one.
        nn.init.zeros_(self.conv2d.weight)
        nn.init.zeros_(self.w_fc1.weight)
        nn.init.zeros_(self.w_fc2.weight)

    def forward(self, x):
        """Encode ``x`` into per-window weights.

        Args:
            x: 3-D tensor indexed as ``[n, t, d]``; it is permuted to
                ``[1, d, n, t]`` before the convolution.

        Returns:
            The output of the second ``LocallyConnected`` layer with its last
            two axes swapped; the leading axis indexes convolution windows.
            The exact trailing shape depends on ``LocallyConnected``, which
            is defined outside this file.
        """
        x = x.permute(2, 0, 1)[None, :, :, :]  # [1, d, n, t]
        # Conv output is [1, d * m1, 1, windows]. Drop exactly the batch and
        # kernel-height axes: a bare squeeze() would also collapse the
        # channel/window axis when it has size 1 and leave a 1-D tensor.
        h = self.conv2d(x).squeeze(2).squeeze(0).T  # [windows, d * m1]
        # reshape (not view) because h is a transposed tensor.
        h = torch.relu(h).reshape(-1, self.d, self.dims[1])
        h = self.w_fc1(h)
        h = torch.relu(h).transpose(1, 2)
        W = self.w_fc2(h).transpose(1, 2)
        return W

    def h_func(self, W):
        """Return the log-det penalty ``-log|det(I - A / ||A||_1)|`` per window.

        ``A`` is the ``d x d`` matrix obtained by summing ``W * W`` over the
        trailing axis after reshaping ``W`` to ``[batch, d, d, -1]``.

        Args:
            W: Tensor with ``W.shape[0]`` windows and ``d * d * k`` elements
                per window.

        Returns:
            1-D tensor with one penalty value per window. An all-zero ``W``
            yields 0 rather than NaN.
        """
        # reshape also accepts the non-contiguous tensor returned by forward.
        W = W.reshape(W.shape[0], self.d, self.d, -1)
        A = torch.sum(W * W, dim=3)
        with torch.no_grad():
            # The normaliser is treated as a constant (no gradient).
            N = torch.linalg.matrix_norm(A, ord=1)
            # A zero matrix has norm 0; dividing by it would give 0/0 = NaN.
            # Use 1 there so A / N stays zero and the penalty is exactly 0.
            N = torch.where(N > 0, N, torch.ones_like(N))
        # Follow A's device in case the module was moved after construction.
        eye = self.I.to(A.device)
        # slogdet returns (sign, log|det|); only the log-magnitude is used.
        return -torch.slogdet(eye - A / N[:, None, None])[1]

    def l1_reg(self, W):
        """Return the sum of ``|W|`` divided by ``W.shape[0]``."""
        return torch.sum(torch.abs(W)) / W.shape[0]