import numpy as np
import torch
import torch.nn as nn

class ShapeConv(nn.Module):
    # ShapeConv layer for univariate TS
    """Shapelet convolution layer for univariate time series.

    Each of the ``dim`` filters of a bias-free ``nn.Conv1d`` is treated as a
    learnable shapelet of length ``kernel_size``. ``forward`` returns the squared
    Euclidean distance (divided by ``kernel_size``) between every shapelet and
    every sliding window of the input, using the expansion
    ||x - w||^2 = ||x||^2 - 2 <x, w> + ||w||^2, where <x, w> is the convolution.
    """

    def __init__(self, dim, kernel_size, stride = 1, dilation = 1):
        """
        params:
            dim: the number of shapelets
            kernel_size: the length of the shapelets
            stride: the stride of convolution
            dilation: the dilation of convolution
        """
        super().__init__()

        self.dim = dim
        self.kernel_size = kernel_size
        self.stride = stride
        self.dilation = dilation

        self.conv = nn.Conv1d(1, dim, kernel_size, stride, bias=False, dilation = dilation)

        nn.init.normal_(self.conv.weight, mean = 0.5, std = 0.25)

    # This init depends on the dataset
    # def init(self, loader, man_label = None, device = torch.device('cpu')):
    #     # init with the mean of each class
    #     dataset = loader.dataset
    #     if man_label == None:
    #         lab = dataset.label
    #     else:
    #         lab = man_label
    #     labels = torch.unique(lab)
    #     data_class = [[] for _ in range(len(labels))]
    #     for i, (data, label) in enumerate(dataset):
    #         data_class[int(lab[i])].append(data)
    #     for i in range(len(labels)):
    #         data_class[i] = torch.stack(data_class[i]).to(device)
    #     # print(data_class[0].shape)
    #     if self.dim >= len(labels):
    #         num_k = self.dim // len(labels)
    #         unit = data.shape[-1] // num_k
    #         if unit < self.kernel_size:
    #             while unit < self.kernel_size:
    #                 num_k -= 1
    #                 unit = data.shape[-1] // num_k
    #         for i in range(len(labels)):
    #             data = data_class[i]
    #             for q in range(num_k):
    #                 seqs = data[:, unit * q:unit * (q + 1)].unfold(-1, self.kernel_size, self.stride)
    #                 self.conv.weight.data[i * num_k + q] = seqs.mean((0, 1)).to(self.conv.weight.device)
    #     else:
    #         pass

    def forward(self, x):
        """
        params:
            x: the input data, shape (batch_size, 1, seq_length)
        return:
            dist: the normalized distance between the input and the shapelets,
                shape (batch_size, dim, n_windows) with
                n_windows = (seq_length - dilation * (kernel_size - 1) - 1) // stride + 1
                (i.e. seq_length - kernel_size + 1 for stride = dilation = 1).
                Take the minimum along the last dimension to get the
                shapelet-transformed distance.
            knorm: Frobenius norm of exp(-pairwise Euclidean distance between
                shapelets), scalar, used in diversity loss
        """
        weight = self.conv.weight  # (dim, 1, kernel_size)

        # <x_window, w> for every shapelet and window: (B, dim, n_windows)
        h = self.conv(x)

        # ||w||^2 per shapelet, shaped (1, dim, 1) to broadcast over batch and windows
        kernel_square = weight.square().sum(dim = (1, 2)).view(1, -1, 1)

        # The windows must match the conv's receptive field: a dilated kernel
        # covers dilation * (kernel_size - 1) + 1 samples, of which every
        # `dilation`-th sample is actually used.
        span = self.dilation * (self.kernel_size - 1) + 1
        windows = x.unfold(-1, span, self.stride)[..., ::self.dilation]  # (B, 1, n_windows, kernel_size)
        x_square = windows.square().sum(-1)  # (B, 1, n_windows), broadcasts over dim

        # flatten(1) (rather than squeeze) keeps a 2-D (dim, kernel_size) matrix
        # even when dim == 1 or kernel_size == 1, as cdist requires.
        flat = weight.flatten(1)
        knorm = torch.linalg.norm(torch.exp(-torch.cdist(flat, flat)), 'fro')

        dist = (x_square - 2 * h + kernel_square) / self.kernel_size
        return dist, knorm


class ShapeConvM(nn.Module):
    # ShapeConv layer for multivariate TS
    """Shapelet convolution layer for multivariate time series.

    Each of the ``dim`` filters of a bias-free ``nn.Conv1d`` is a learnable
    multichannel shapelet of shape (channel, kernel_size). ``forward`` returns
    the squared Euclidean distance (summed over channels, divided by
    ``kernel_size``) between each shapelet and each sliding window of the input,
    via ||x - w||^2 = ||x||^2 - 2 <x, w> + ||w||^2.
    """

    def __init__(self, dim, kernel_size, channel = 1, stride = 1, dilation = 1):
        """
        params:
            dim: the number of shapelets
            kernel_size: the length of the shapelets
            channel: the dimension of the time series
            stride: the stride of convolution
            dilation: the dilation of convolution
        """
        super().__init__()

        self.dim = dim
        self.kernel_size = kernel_size
        self.stride = stride
        self.channel = channel
        self.dilation = dilation

        self.conv = nn.Conv1d(channel, dim, kernel_size, stride, bias=False, dilation = dilation)

        nn.init.normal_(self.conv.weight)
    
    # This init depends on the dataset
    # def init(self, loader):
    #     labels = torch.tensor([0])
    #     for data, label in loader:
    #         labels = torch.unique(torch.cat([labels, torch.unique(label)]))
    #     if self.dim > len(labels):
    #         num_k = self.dim // len(labels)
    #         weight_tmp = torch.zeros((self.dim, self.channel, self.kernel_size)) + 0.5
    #         weight_cnt = torch.zeros((self.dim, self.channel))
    #         for i in range(len(labels)):
    #             for d in range(self.channel):
    #                 for data, label in loader:
    #                     for j, lab in enumerate(label):
    #                         if lab == i:
    #                             seq = data[j, d]
    #                             unit = len(seq) // num_k
    #                             for q in range(num_k):
    #                                 tmp = torch.zeros(self.kernel_size)
    #                                 cnt = 0
    #                                 for p in range(unit * q, unit * (q + 1)):
    #                                     if p + self.kernel_size > len(seq): break
    #                                     tmp = tmp + seq[p:p+self.kernel_size]
    #                                     cnt += 1
    #                                 weight_tmp[i * num_k + q, d] += tmp
    #                                 weight_cnt[i * num_k + q, d] += cnt
    #                                 # print(i+q, weight_tmp[i * num_k + q], weight_cnt[i * num_k + q])
    #         weight_tmp /= weight_cnt.unsqueeze(2).repeat(1, 1, self.kernel_size)
    #         # weight_tmp = weight_tmp.unsqueeze(1)
    #         self.conv.weight = torch.nn.Parameter(weight_tmp)
    #     else:
    #         nn.init.normal_(self.conv.weight)

    def forward(self, x):
        """
        params:
            x: the input data, shape (batch_size, channel, seq_length)
        return:
            dist: the normalized distance between the input and the shapelets,
                shape (batch_size, dim, n_windows) with
                n_windows = (seq_length - dilation * (kernel_size - 1) - 1) // stride + 1
                (i.e. seq_length - kernel_size + 1 for stride = dilation = 1).
                Take the minimum along the last dimension to get the
                shapelet-transformed distance.
            knorm: for each channel, the Frobenius norm of
                exp(-pairwise Euclidean distance between the shapelets' rows
                for that channel), summed over channels; scalar, used in
                diversity loss
        """
        weight = self.conv.weight  # (dim, channel, kernel_size)

        # <x_window, w> summed over channels: (B, dim, n_windows)
        h = self.conv(x)

        # ||w||^2 per shapelet (over channels and time), shaped (1, dim, 1) to broadcast
        kernel_square = weight.square().sum(dim = (1, 2)).view(1, -1, 1)

        # Windows covering the dilated receptive field, keeping every
        # `dilation`-th sample so they line up with the conv output positions.
        span = self.dilation * (self.kernel_size - 1) + 1
        windows = x.unfold(-1, span, self.stride)[..., ::self.dilation]  # (B, channel, n_windows, kernel_size)
        x_square = windows.square().sum(dim = (1, 3)).unsqueeze(1)  # (B, 1, n_windows)

        # Diversity term: compare shapelets to each other within each channel.
        # Putting the channel axis first makes cdist batch over channels, giving
        # (channel, dim, dim) shapelet-vs-shapelet distances. Batching over the
        # shapelet axis instead would compare channels within a single shapelet.
        per_channel = weight.transpose(0, 1).contiguous()  # (channel, dim, kernel_size)
        similarity = torch.exp(-torch.cdist(per_channel, per_channel))
        knorm = torch.linalg.norm(similarity, ord = 'fro', dim = (-2, -1)).sum()

        dist = (x_square - 2 * h + kernel_square) / self.kernel_size
        return dist, knorm
