import numpy as np
import torch
import torch.nn as nn

class ShapeConv(nn.Module):
    """Shapelet layer: mean squared distance between learned shapes and input windows.

    Each of the ``dim`` output channels owns one length-``kernel_size`` shape, stored
    as the weight of a single-input-channel, bias-free ``nn.Conv1d``. ``forward``
    expands ``||window - shape||^2 = ||window||^2 - 2 <window, shape> + ||shape||^2``
    so that the cross term is computed by the convolution.

    Args:
        dim: number of shapes (output channels).
        kernel_size: number of samples in each shape.
        stride: step between the starts of consecutive windows.
        dilation: spacing between the samples that make up one window.
    """

    def __init__(self, dim, kernel_size, stride = 1, dilation = 1):
        super().__init__()

        self.dim = dim
        self.kernel_size = kernel_size
        self.stride = stride
        self.dilation = dilation

        self.conv = nn.Conv1d(1, dim, kernel_size, stride, bias=False, dilation = dilation)

        nn.init.normal_(self.conv.weight, mean = 0.5, std = 0.25)

    # ------------------------------------------------------------------ helpers

    def _random_starts(self, length, num_sample):
        """Draw ``num_sample`` window starts uniformly from ``0..length - kernel_size``.

        ``torch.randint``'s ``high`` bound is exclusive, hence the ``+ 1`` so the
        last full window can be selected as well.

        Raises:
            ValueError: if ``length`` is shorter than ``kernel_size``.
        """
        if length < self.kernel_size:
            raise ValueError(
                f"series length {length} is shorter than kernel_size {self.kernel_size}"
            )
        return torch.randint(low = 0, high = length - self.kernel_size + 1, size = (num_sample, ))

    def _crop(self, rows, starts):
        """Stack ``rows[j][starts[j]:starts[j] + kernel_size]`` for every j."""
        k = self.kernel_size
        return torch.stack([row[s:s + k] for row, s in zip(rows, starts.tolist())])

    @staticmethod
    def _kmeans_centers(samples, n_clusters):
        """Fit scikit-learn KMeans on ``samples`` and return the cluster centers as a tensor.

        ``samples`` must be 2-D (one window per row), as KMeans requires.
        """
        from sklearn.cluster import KMeans
        kmeans = KMeans(n_clusters=n_clusters).fit(samples.numpy())
        return torch.tensor(kmeans.cluster_centers_)

    # ------------------------------------------------------------ initialisers

    def init_cluster(self, loader, sample = True, num_sample = 10000, num_class = 0):
        """Initialise the shapes by clustering.

        If ``sample`` is True, draw ``num_sample`` random windows from the dataset,
        cluster them into ``dim`` groups with KMeans, and copy the centers into the
        conv weight. Otherwise, cluster the whole series ``loader.dataset.data`` into
        ``num_class`` groups and use the cluster ids as pseudo-labels for ``init``.

        NOTE(review): ``dat[ids][0]`` assumes the dataset supports tensor indexing and
        returns the data tensor first (e.g. ``TensorDataset``) — confirm.
        """
        dat = loader.dataset
        if sample:
            tot = len(dat)
            length = dat[0][0].shape[0]
            ids = torch.randint(low = 0, high = tot, size = (num_sample, ))
            pos = self._random_starts(length, num_sample)
            samples = self._crop(dat[ids][0], pos)
            centers = self._kmeans_centers(samples, self.dim)
            # Write in place: keeps the weight's device/dtype and keeps the same
            # Parameter object, so an optimizer created earlier still sees it.
            self.conv.weight.data[:] = centers.unsqueeze(1)
        else:
            from sklearn.cluster import KMeans
            kmeans = KMeans(n_clusters=num_class).fit(dat.data.numpy())
            cluster_ids_x = torch.tensor(kmeans.labels_)
            self.init(loader, cluster_ids_x)

    def init_cut(self, loader, cut, num_sample = 10000):
        """Initialise the shapes segment by segment.

        The time axis of ``loader.dataset.data`` is split into ``cut`` equal segments.
        From each segment ``num_sample // cut`` random windows are drawn and clustered
        into ``dim // cut`` centers. These fill the matching contiguous block of
        filters. Filters beyond ``cut * (dim // cut)`` are left untouched.
        """
        dat = loader.dataset.data
        num_k = self.dim // cut
        length = dat.shape[-1] // cut
        per_segment = num_sample // cut
        for i in range(cut):
            seqs = dat[:, length * i:length * (i + 1)]
            ids = torch.randint(low = 0, high = dat.shape[0], size = (per_segment, ))
            pos = self._random_starts(length, per_segment)
            samples = self._crop(seqs[ids], pos)
            centers = self._kmeans_centers(samples, num_k)
            self.conv.weight.data[i * num_k:(i + 1) * num_k] = centers.unsqueeze(1)

    def init(self, loader, man_label = None, device = torch.device('cpu')):
        """Initialise the shapes with class-wise average windows.

        Each series is split into ``num_k`` equal chunks. For every (class, chunk)
        pair, one filter is set to the mean of all stride-spaced ``kernel_size``
        windows in that chunk over the class's series. ``num_k`` is
        ``dim // num_classes``, reduced if needed so that every chunk still holds at
        least one window. If ``dim`` is smaller than the number of classes, the
        weights are left unchanged.

        Args:
            loader: object whose ``.dataset`` yields ``(data, label)`` pairs. When
                ``man_label`` is None it must also expose ``.dataset.label`` and
                ``.indices`` (Subset-like).
            man_label: optional tensor with one label per dataset item. Any label
                values are accepted; they are ranked in sorted order.
            device: device the per-class stacks are moved to before averaging.

        Raises:
            ValueError: if the series are shorter than ``kernel_size``.
        """
        dataset = loader.dataset
        if man_label is None:
            lab = torch.tensor([dataset.dataset.label[i] for i in dataset.indices])
        else:
            lab = man_label
        # Map arbitrary label values onto 0..C-1. This is the identity for labels
        # that are already 0..C-1.
        labels, class_idx = torch.unique(lab, return_inverse=True)
        num_labels = len(labels)

        data_class = [[] for _ in range(num_labels)]
        for i, (data, _) in enumerate(dataset):
            data_class[int(class_idx[i])].append(data)
        data_class = [torch.stack(group).to(device) for group in data_class]

        if self.dim < num_labels:
            # Not enough filters for one per class: keep the current weights.
            return

        length = data_class[0].shape[-1]
        # Largest chunk count <= dim // C whose chunks fit at least one window.
        num_k = min(self.dim // num_labels, length // self.kernel_size)
        if num_k == 0:
            raise ValueError(
                f"series length {length} is shorter than kernel_size {self.kernel_size}"
            )
        unit = length // num_k
        for i, data in enumerate(data_class):
            for q in range(num_k):
                seqs = data[:, unit * q:unit * (q + 1)].unfold(-1, self.kernel_size, self.stride)
                self.conv.weight.data[i * num_k + q] = seqs.mean((0, 1)).to(self.conv.weight.device)

    # ----------------------------------------------------------------- forward

    def forward(self, x):
        """Compute window-to-shape distances.

        Args:
            x: input of shape (batch, 1, length); the conv has a single input channel.

        Returns:
            A tuple ``(dist, knorm)``:
                dist: (batch, dim, num_windows) tensor, where
                    ``dist[b, c, t] = mean_j (x[b, 0, t*stride + j*dilation] - w[c, j]) ** 2``.
                knorm: scalar Frobenius norm of ``exp(-pairwise Euclidean distance)``
                    between the shapes. It presumably serves as a diversity
                    regulariser — confirm with the caller.
        """
        weight = self.conv.weight  # (dim, 1, kernel_size)
        cross = self.conv(x)  # <window, shape> for every window: (batch, dim, W)

        # Gather exactly the samples the dilated conv sees: a window spanning
        # dilation*(k-1)+1 positions, then keep every dilation-th element.
        span = self.dilation * (self.kernel_size - 1) + 1
        windows = x.unfold(-1, span, self.stride)[..., ::self.dilation]
        x_square = windows.square().sum(-1)  # (batch, 1, W), broadcast over dim

        kernel_square = weight.square().sum(-1).unsqueeze(0)  # (1, dim, 1)

        # Index the channel axis explicitly so dim == 1 or kernel_size == 1 keeps 2-D.
        shapes = weight[:, 0, :]  # (dim, kernel_size)
        knorm = torch.linalg.norm(torch.exp(-torch.cdist(shapes, shapes)), 'fro')
        return (x_square - 2 * cross + kernel_square) / self.kernel_size, knorm
    
class ShapeConvClassifier(nn.Module):
    """Shapelet-distance features followed by a two-layer MLP head.

    The ``ShapeConv`` layer produces a (batch, dim, num_windows) distance map. For
    each shape, the minimum over windows is taken, giving a (batch, dim) feature
    vector that the MLP maps to ``num_class`` logits.
    """

    def __init__(self, num_class, hid, dim, kernel_size, dropout=0.0, stride = 1, dilation = 1):
        super().__init__()

        self.shapeconv = ShapeConv(dim, kernel_size, stride, dilation = dilation)

        self.cla = nn.Sequential(
            nn.Linear(dim, hid),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hid, num_class),
        )

    def init(self, loader):
        """Delegate to ``ShapeConv.init`` using the dataset's own labels."""
        self.shapeconv.init(loader)

    def forward(self, x):
        """Return ``(distance_map, knorm, logits)``.

        ``logits`` is computed from the per-shape minimum distance over all windows.
        """
        distance_map, knorm = self.shapeconv(x)
        closest = distance_map.min(dim=-1).values
        logits = self.cla(closest)
        return distance_map, knorm, logits
