import torch
import torch.nn as nn
import torch.nn.functional as F


class Layer(nn.Module):
    """One softmax-free attention read-out for the last column of ``x``.

    Computes ``(V @ x @ x^T @ W @ x_last)[:d]``, where ``x_last`` is the last
    column of ``x`` and ``V``/``W`` are learned square matrices of side
    ``2 * (d + 1)``.

    Args:
        data_num: Stored as ``self.window_size``; not read by this class.
        d: Output width. The input must have ``2 * (d + 1)`` rows.
    """

    def __init__(self, data_num, d):
        super(Layer, self).__init__()
        self.d = d
        # NOTE(review): not used inside Layer; presumably read by callers.
        self.window_size = data_num
        dim = 2 * (d + 1)
        # The randn values are immediately overwritten by nn.init.normal_
        # below. They are kept so that the global RNG stream, and therefore
        # the initial weights under a fixed seed, stay reproducible.
        self.W = nn.Parameter(torch.randn(dim, dim))
        self.V = nn.Parameter(torch.randn(dim, dim))
        nn.init.normal_(self.W, 0, 1 / dim)
        nn.init.normal_(self.V, 0, 1 / dim)

    def forward(self, x):
        """Return the first ``d`` entries of ``V x x^T W x_last``.

        Args:
            x: Tensor of shape ``(2 * (d + 1), N)``, or batched
                ``(..., 2 * (d + 1), N)``. Its last column acts as the query.

        Returns:
            Tensor of shape ``(d,)`` for unbatched input, ``(..., d)`` otherwise.
        """
        # Keep the query as a (..., D, 1) column. Matmul broadcasting then
        # handles unbatched and batched inputs with the same code.
        query = x[..., -1:]
        # Multiply right-to-left so every intermediate is a column vector.
        # This costs O(D*N + D^2) instead of the O(D^2*N + D^3) needed to
        # materialise x @ x^T. It is mathematically identical; float results
        # may differ only in rounding.
        h = self.W @ query
        h = x.transpose(-2, -1) @ h
        h = x @ h
        h = self.V @ h
        return h.squeeze(-1)[..., : self.d]


class CoT(nn.Module):
    """Model consisting of a single ``Layer``; ``forward`` delegates to it.

    Args:
        data_num: Passed through to ``Layer``.
        d: Passed through to ``Layer``.
    """

    def __init__(self, data_num, d):
        super().__init__()
        self.layer = Layer(data_num, d)

    def forward(self, x):
        """Run ``x`` through the wrapped ``Layer`` and return its output as-is."""
        return self.layer(x)