import numpy as np
import torch
import torch.nn as nn

class ShapeConvM(nn.Module):
    """Shapelet-style 1-D convolution layer for (multivariate) time series.

    Each of the ``dim`` convolution kernels is treated as a shapelet.  The
    forward pass returns, for every shapelet and every sliding window of the
    input, the squared Euclidean distance between the two (summed over
    channels and divided by ``kernel_size``), together with a scalar
    similarity penalty computed from the shapelets themselves.
    """

    def __init__(self, dim, kernel_size, channel = 1, stride = 1, dilation = 1):
        # dim: number of kernels / shapelets
        # kernel_size: kernel size (length of shapelets)
        # channel: dimension of time series
        # stride / dilation: forwarded unchanged to the underlying nn.Conv1d
        super().__init__()

        self.dim = dim
        self.kernel_size = kernel_size
        self.stride = stride
        self.channel = channel
        self.dilation = dilation

        self.conv = nn.Conv1d(channel, dim, kernel_size, stride, bias=False, dilation = dilation)

        nn.init.normal_(self.conv.weight)

    def init(self, loader):
        """Initialise the shapelets from class-wise averages of data windows.

        ``loader`` must be re-iterable (it is traversed twice) and yield
        ``(data, label)`` pairs where ``data`` is indexed as
        ``(batch, channel, time)`` and ``label`` is a tensor with one entry
        per sample.

        When ``dim`` exceeds the number of distinct labels, every class gets
        ``num_k = dim // n_classes`` consecutive kernels.  Each sequence is
        cut into ``num_k`` segments of ``len(seq) // num_k`` start positions;
        kernel ``q`` of a class becomes the average of all full-length windows
        starting inside segment ``q`` over that class's samples.  Kernels that
        receive no window at all (leftover kernels when ``dim`` is not a
        multiple of the class count, or sequences shorter than the kernel)
        keep their current weights instead of being divided by zero.

        Otherwise (no labels, or ``dim <= n_classes``) the weights are simply
        re-drawn from a standard normal distribution.
        """
        # Pass 1: collect the label values that actually occur in the data.
        seen = [torch.unique(label) for _, label in loader]
        classes = torch.unique(torch.cat(seen)).tolist() if seen else []

        if not classes or self.dim <= len(classes):
            nn.init.normal_(self.conv.weight)
            return

        num_k = self.dim // len(classes)
        weight = self.conv.weight

        # Accumulators start at 0.5, so a filled kernel ends up as
        # (0.5 + sum of windows) / number of windows.
        # NOTE(review): confirm whether this 0.5 offset is intentional.
        acc = torch.full_like(weight, 0.5)
        cnt = torch.zeros(self.dim, dtype=weight.dtype, device=weight.device)

        with torch.no_grad():
            # Pass 2: accumulate window sums per (class, segment) kernel.
            for data, label in loader:
                length = data.shape[-1]
                n_windows = length - self.kernel_size + 1
                if n_windows <= 0:
                    continue  # sequences too short to hold a single window
                unit = length // num_k
                label = label.reshape(-1)

                for class_idx, class_value in enumerate(classes):
                    mask = label == class_value
                    n_selected = int(mask.sum())
                    if n_selected == 0:
                        continue
                    # Summing the samples first is equivalent to summing the
                    # windows of every sample, because the window sum is linear.
                    class_sum = data[mask][:, :self.channel].to(acc).sum(0)
                    windows = class_sum.unfold(-1, self.kernel_size, 1)  # (channel, n_windows, kernel_size)

                    for q in range(num_k):
                        lo = unit * q
                        hi = min(unit * (q + 1), n_windows)
                        if hi <= lo:
                            break  # this and all later segments hold no full window
                        k = class_idx * num_k + q
                        acc[k] += windows[:, lo:hi].sum(dim=1)
                        cnt[k] += n_selected * (hi - lo)

            filled = (cnt > 0).view(-1, 1, 1)
            averaged = acc / cnt.clamp(min=1).view(-1, 1, 1)
            # In-place copy keeps the Parameter object (and therefore its
            # device, dtype and any optimizer reference) intact.
            weight.copy_(torch.where(filled, averaged, weight))

    def forward(self, x):
        """Return ``(distances, knorm)`` for a batch ``x``.

        ``x`` is passed to ``nn.Conv1d`` and must therefore be laid out as
        ``(batch, channel, time)``.

        ``distances`` has the shape of the convolution output
        ``(batch, dim, windows)`` and holds ``||window - shapelet||^2 /
        kernel_size``, expanded as ``||window||^2 - 2 <window, shapelet> +
        ||shapelet||^2``.

        ``knorm`` is a scalar: for every channel, the Frobenius norm of
        ``exp(-pairwise distance)`` between the shapelets' weights on that
        channel, summed over channels.
        """
        weight = self.conv.weight

        # <window, shapelet> for every window and shapelet.
        h = self.conv(x)

        # ||shapelet||^2, broadcast over batch and window positions.
        kernel_square = weight.square().sum(dim=(1, 2)).view(1, -1, 1)

        # ||window||^2: an all-ones kernel with the same stride/dilation sums
        # the squared input over exactly the samples each shapelet sees.
        ones = x.new_ones((1, self.channel, self.kernel_size))
        x_square = nn.functional.conv1d(x.square(), ones, stride=self.stride, dilation=self.dilation)

        # (channel, dim, kernel_size) -> per-channel (dim, dim) similarity.
        # Transposing (rather than squeezing) keeps the result well defined
        # for any dim / channel / kernel_size, including size-1 dimensions.
        per_channel = weight.transpose(0, 1)
        similarity = torch.exp(-torch.cdist(per_channel, per_channel))
        knorm = torch.linalg.norm(similarity, ord='fro', dim=(-2, -1)).sum()

        return (x_square - 2 * h + kernel_square) / self.kernel_size, knorm

class ShapeConvClassifierM(nn.Module):
    """Classifier that min-pools the output of a ShapeConvM layer and feeds
    the pooled features to a two-layer perceptron."""

    def __init__(self, num_class, channel, hid, dim, kernel_size, stride = 1, dilation = 1):
        # num_class: number of output scores produced by the final linear layer
        # channel / dim / kernel_size / stride / dilation: forwarded to ShapeConvM
        # hid: width of the hidden linear layer
        super().__init__()

        # Shapelet layer; yields one feature series per kernel.
        self.shapeconv = ShapeConvM(dim, kernel_size, channel, stride, dilation = dilation)

        # dim pooled features -> hid -> num_class scores.
        self.cla = nn.Sequential(
            nn.Linear(dim, hid),
            nn.ReLU(),
            nn.Linear(hid, num_class),
        )

    def init(self, loader):
        # Delegate the data-driven initialisation to the shapelet layer.
        self.shapeconv.init(loader)

    def forward(self, x):
        # Returns the shapelet layer's two outputs unchanged, followed by the
        # class scores computed from the minimum over the last axis.
        features, knorm = self.shapeconv(x)
        pooled = features.min(dim=-1).values
        scores = self.cla(pooled)
        return features, knorm, scores
        