
import time
# from models1.losses import hierarchical_contrastive_loss,cluster_loss
import math
import torch
import torch.nn.functional as F
from torch import nn

# from timm.layers import Mlp, DropPath
# from timm.layers.helpers import to_2tuple
import sys
from models.sdtw import SoftDTW_align
# from module import hierarchical_clustering

import torch
from torch import nn
import torch.nn.functional as F
from random import sample, choices
import numpy as np



import torch
import torch.nn.parallel
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import copy
import math
import shutil
import sys
import numpy as np
import scipy.sparse as sp
import scipy.spatial as ss
from collections import defaultdict
from sklearn import metrics
from pynndescent import NNDescent
from tqdm import tqdm
import matplotlib.pyplot as plt


def hierarchical_contrastive_loss(z1, z2, alpha=1.0, temporal_unit=0, norm_flag=1):
    """Average instance/temporal contrastive losses over a max-pooling hierarchy.

    At every level the two views are compared, then both are max-pooled along
    dim 1 with kernel size 2, until dim 1 has length 1.

    Args:
        z1, z2: two views of the same batch; dim 0 is the batch axis and dim 1
            the axis that is pooled (time, per the shape comments in the
            sibling loss functions).
        alpha: weight of the instance loss; ``1 - alpha`` weights the temporal
            loss. A term whose weight is zero is not evaluated at all.
        temporal_unit: first hierarchy level at which the temporal loss is added.
        norm_flag: when truthy, L2-normalise both views along the last dim first.

    Returns:
        Scalar tensor: the accumulated loss divided by the number of levels.
    """
    if norm_flag:
        z1 = F.normalize(z1, p=2, dim=-1)
        z2 = F.normalize(z2, p=2, dim=-1)
    tau = 1.0 if norm_flag else 1

    use_instance = alpha != 0
    use_temporal = 1 - alpha != 0

    total = torch.tensor(0., device=z1.device)
    depth = 0

    while z1.size(1) > 1:
        if use_instance:
            total += alpha * instance_contrastive_loss(z1, z2, tau=tau)
        if depth >= temporal_unit and use_temporal:
            total += (1 - alpha) * temporal_contrastive_loss(z1, z2, tau=tau)

        depth += 1
        # Halve dim 1 for the next, coarser level.
        z1 = F.max_pool1d(z1.transpose(1, 2), kernel_size=2).transpose(1, 2)
        z2 = F.max_pool1d(z2.transpose(1, 2), kernel_size=2).transpose(1, 2)

    # Coarsest level: a single step is left, so only the instance term applies.
    if z1.size(1) == 1:
        if use_instance:
            total += alpha * instance_contrastive_loss(z1, z2, tau=tau)
        depth += 1

    return total / depth


def instance_contrastive_loss(z1, z2, tau=1):
    """Contrast each sample with the other samples of the batch, per time step.

    Args:
        z1, z2: B x T x C views of the same batch.
        tau: temperature dividing the dot-product similarities.

    Returns:
        Scalar tensor; zero when the batch holds a single sample.
    """
    batch = z1.size(0)
    if batch == 1:
        return z1.new_tensor(0.)

    stacked = torch.cat([z1, z2], dim=0).transpose(0, 1)  # T x 2B x C
    sim = torch.matmul(stacked, stacked.transpose(1, 2)) / tau  # T x 2B x 2B

    # Remove the diagonal (self-similarity): T x 2B x (2B-1).
    below = torch.tril(sim, diagonal=-1)[:, :, :-1]
    above = torch.triu(sim, diagonal=1)[:, :, 1:]
    nll = -F.log_softmax(below + above, dim=-1)

    rows = torch.arange(batch, device=z1.device)
    first_to_second = nll[:, rows, batch + rows - 1].mean()
    second_to_first = nll[:, batch + rows, rows].mean()
    return (first_to_second + second_to_first) / 2


def temporal_contrastive_loss(z1, z2, tau=1):
    """Contrast each time step with the other time steps of the same sample.

    Args:
        z1, z2: B x T x C views of the same batch.
        tau: temperature dividing the dot-product similarities.

    Returns:
        Scalar tensor; zero when there is a single time step.
    """
    steps = z1.size(1)
    if steps == 1:
        return z1.new_tensor(0.)

    stacked = torch.cat([z1, z2], dim=1)  # B x 2T x C
    sim = torch.matmul(stacked, stacked.transpose(1, 2)) / tau  # B x 2T x 2T

    # Remove the diagonal (self-similarity): B x 2T x (2T-1).
    below = torch.tril(sim, diagonal=-1)[:, :, :-1]
    above = torch.triu(sim, diagonal=1)[:, :, 1:]
    nll = -F.log_softmax(below + above, dim=-1)

    pos = torch.arange(steps, device=z1.device)
    first_to_second = nll[:, pos, steps + pos - 1].mean()
    second_to_first = nll[:, steps + pos, pos].mean()
    return (first_to_second + second_to_first) / 2


@torch.no_grad()
def concat_all_gather(tensor):
    """Gather `tensor` from every process and concatenate along dim 0.

    When torch.distributed has not been initialised, a clone of the input is
    returned instead. Runs under ``torch.no_grad()``.
    """
    if not torch.distributed.is_initialized():
        return tensor.clone()

    world_size = torch.distributed.get_world_size()
    gathered = [torch.ones_like(tensor) for _ in range(world_size)]
    torch.distributed.all_gather(gathered, tensor, async_op=False)
    return torch.cat(gathered, dim=0)


@torch.no_grad()
def _batch_unshuffle_ddp(x, idx_unshuffle):
    """Return a copy of `x` with its rows re-ordered by `idx_unshuffle`."""
    restored = x.clone()[idx_unshuffle]
    return restored


@torch.no_grad()
def _batch_shuffle_ddp(x):
    """Randomly permute the rows of `x`.

    Returns:
        (shuffled, idx_unshuffle): the permuted tensor and the index tensor
        that restores the original order (``shuffled[idx_unshuffle] == x``).
    """
    perm = torch.randperm(x.shape[0]).to(x.device)
    inverse = torch.argsort(perm)
    return x[perm], inverse


def cluster_loss(im_q, im_k=None, is_eval=False, cluster_result=None, c=None, index=None):
    """
    Build instance-level and prototype-level contrastive logits/labels for one batch.

    Input:
        im_q: a batch of query embeddings (normalised along dim 1 here)
        im_k: a batch of key embeddings (normalised along dim 1 here, no gradient)
        is_eval: when True, only the normalised im_k is returned
        cluster_result: dict read through the keys 'im2cluster' and 'centroids'
            (one entry per partition)
        c: label table; it is transposed and indexed as c.T[partition] --
            presumably the (n_samples, n_partitions) array returned by
            hierarchical_clustering; verify against the caller
        index: dataset indices of the samples in this batch
    Output:
        logits, labels, proto_logits, proto_labels
        - logits: [B, 2*posi + 2*negi] instance similarities divided by tempi
        - labels: numpy array of the same shape, 1 for the first 2*posi columns
        - proto_logits / proto_labels: lists with one entry per processed
          partition (only partition 0 is processed, see the `break` below)

    NOTE: uses random.sample / random.choices, so the output depends on the
    global `random` state.
    """
    posi = 3  # positives sampled per instance
    negi = 4  # negatives sampled per instance
    posp = 3  # positive prototypes per instance
    negp = 4  # negative prototypes per instance
    m = 0.999  # NOTE(review): never read; the name `m` is rebound by loops below
    tempi = 0.2  # temperature for instance logits
    tempp = 0.3  # temperature for prototype logits
    usetemp = True
    mlp = False  # NOTE(review): unused in this function

    if is_eval:
        k = nn.functional.normalize(im_k, dim=1)
        return k

    # compute key features
    with torch.no_grad():  # no gradient to keys

        # shuffle / normalise / unshuffle; the normalisation is row-wise, so the
        # round trip leaves k equal to normalize(im_k)
        k, idx_unshuffle = _batch_shuffle_ddp(im_k)

        k = nn.functional.normalize(k, dim=1)

        # undo shuffle
        k = _batch_unshuffle_ddp(k, idx_unshuffle)

    # compute query features
    q = nn.functional.normalize(im_q, dim=1)

    proto_labels = []
    proto_logits = []

    """instance-level contrastive learning only uses the 0-th partition"""
    p0_label = {}  # dataset index -> cluster label in partition 0
    label_index = {}  # cluster label -> dataset indices in this batch with that label
    index_u = {}  # dataset index -> row position inside the batch
    for u in range(0, index.shape[0]):
        index_u[index[u].item()] = u
        p0_label[index[u].item()] = cluster_result['im2cluster'][0][index[u]].item()
    # group dataset indices by their partition-0 cluster label
    for key, value in p0_label.items():
        label_index.setdefault(value, []).append(key)

    posid = {}
    negid = {}
    neg_instances = [[] for _ in range(len(p0_label))]
    pos_instances = [[] for _ in range(len(p0_label))]
    all_instances = [[] for _ in range(len(p0_label))]

    for i in p0_label:
        # candidate positives: batch members sharing i's cluster (includes i itself);
        # padded with i when fewer than posi are available
        posid[i] = label_index[p0_label[i]].copy()
        if (len(posid[i])) < posi:
            for _ in range(0, posi - len(posid[i])):
                posid[i].append(i)

        # candidate negatives: every other batch member
        negid[i] = [x for x in index.tolist() if x not in posid[i]]

        if (len(posid[i])) > posi:
            posid[i] = sample(posid[i], posi)  # down-sample without replacement
        if len(negid[i]) > 0:
            negid[i] = choices(negid[i], k=negi)  # sample with replacement
        # gather the representations of the selected ids:
        # each positive/negative id contributes two rows (q and k views)
        for m in range(len(posid[i])):
            if posid[i][m] == i:
                # the sample itself: use its key view twice instead of q
                pos_instances[index_u[i]].append(k[index_u[posid[i][m]]])
                pos_instances[index_u[i]].append(
                    k[index_u[posid[i][m]]])
            else:
                pos_instances[index_u[i]].append(q[index_u[posid[i][m]]])
                pos_instances[index_u[i]].append(k[index_u[posid[i][m]]])
        pos_instances[index_u[i]] = torch.stack(pos_instances[index_u[i]])

        for n in range(len(negid[i])):
            neg_instances[index_u[i]].append(q[index_u[negid[i][n]]])
            neg_instances[index_u[i]].append(k[index_u[negid[i][n]]])
        # Added: check whether the negative samples are empty

        if len(neg_instances[index_u[i]]) > 0:
            neg_instances[index_u[i]] = torch.stack(neg_instances[index_u[i]])
            all_instances[index_u[i]] = torch.cat([pos_instances[index_u[i]], neg_instances[index_u[i]]], dim=0)
        else:
            all_instances[index_u[i]] = pos_instances[index_u[i]]
            # NOTE(review): this rebinds negi for the rest of the call; if only some
            # samples lack negatives, rows have different lengths and the
            # torch.stack below would fail -- confirm intended handling
            negi = 0

    all_instances = torch.stack(all_instances)  # [batch_size, 2*posi+2*negi, dim]

    # q: [n,c],  all: [n,m+r,c],  compute logits
    # q[n,c] -> newq[n,1,c]     all[n,m+r,c] -> all[n,c,m+r]
    # newq[n,1,c] x all[n,c,m+r] = logits[n,1,m+r]

    # NOTE(review): torch.reshape re-interprets the flat memory as
    # [batch_size, dim, 2*posi+2*negi]; it is not a transpose of the last two
    # axes. If a transpose was intended (as the comments suggest), this mixes
    # features across instances -- confirm.
    all_instances = torch.reshape(all_instances,
                                  (all_instances.shape[0], all_instances.shape[2], all_instances.shape[1]))

    # logits of instances
    newq = q.unsqueeze(1)
    # q[batch_size,dim] -> newq[batch_size,1,dim]

    logits = torch.einsum('nab,nbc->nac', [newq, all_instances])
    # [batch_size,1,dim] x [batch_size, dim, 2*posi+2*negi] = [batch_size,1,2*posi+2*negi]

    logits = logits.squeeze(1)

    # apply temperature
    if usetemp:
        logits /= tempi

    # labels of instances
    temp_label = np.zeros(posi * 2 + negi * 2)
    temp_label[0: posi * 2] = 1
    labels = np.tile(temp_label, (q.shape[0], 1))  # [B,2posi+2negi] each row has 2posi label1 and 2negi label0

    # cluster-level contrastive learning iterates over the partitions,
    # but stops after the first one (n == 0)
    for n, (im2cluster, prototypes) in enumerate(
            zip(cluster_result['im2cluster'], cluster_result['centroids'])):
        if n == 1:
            break
        # cluster id (in partition n) of every sample in the batch
        pos_proto_id = im2cluster[index]

        pos_prototypes = prototypes[pos_proto_id]

        all_proto_id = [i for i in range(im2cluster.max() + 1)]

        new = pos_proto_id.split(1, 0)  # tuple of 1-element tensors, one per sample
        neg_proto_id = []
        new_pos_proto_id = []  # the cluster label each instance belonging to

        pos_next_partition_label = {}
        neg_next_partition_label = [{} for _ in range(len(pos_proto_id))]

        pdict = {}  # cluster id in partition n -> its label in partition n+1

        for i in range(len(pos_proto_id)):
            new_pos_proto_id.append(new[i].tolist())  # pos prototype
            """sample negative prototypes - random select neg"""

            """mask fake negative prototypes"""
            neg_proto_id.append(list(set(all_proto_id) - set(new[i].tolist())))
            m = c.T  # row p of m holds the labels of partition p
            # NOTE(review): new[i] is a tensor while pdict keys are ints, so this
            # membership test looks like it is always True (the parent label is
            # recomputed for every sample) -- confirm
            if new[i] not in pdict.keys():
                pos_next_partition_label[i] = m[n + 1][np.argwhere(m[n] == int(new[i]))[0][0]]
                pdict[int(new[i])] = pos_next_partition_label[i]

            # negatives that share the positive's parent cluster are "fake" negatives
            mask_list = []
            for j in range(0, len(neg_proto_id[i])):
                if neg_proto_id[i][j] not in pdict.keys():
                    neg_next_partition_label[i][j] = m[n + 1][np.argwhere(m[n] == int(neg_proto_id[i][j]))[0][0]]
                    pdict[int(neg_proto_id[i][j])] = neg_next_partition_label[i][j]
                    if pos_next_partition_label[i] == neg_next_partition_label[i][j]:
                        mask_list.append(neg_proto_id[i][j])  # all fake negs that need to be masked
                else:
                    if pdict[int(new[i])] == pdict[neg_proto_id[i][j]]:
                        mask_list.append(neg_proto_id[i][j])  # all fake negs that need to be masked

            # move fake negatives over to the positive list
            for a in range(0, len(mask_list)):
                neg_proto_id[i].remove(mask_list[a])
                new_pos_proto_id[i].append(mask_list[a])

            """1- random sample n negative prototypes after masking"""
            if len(neg_proto_id[i]) != 0:
                neg_proto_id[i] = choices(neg_proto_id[i], k=negp)

            # pos prototype : 1 current centroid + (pos-1) other centroids with same parent)
            # pad with the sample's own prototype when there are too few
            if len(new_pos_proto_id[i]) <= posp - 1:
                for _ in range(0, posp - len(new_pos_proto_id[i])):
                    new_pos_proto_id[i].append(new_pos_proto_id[i][0])
            new_pos_proto_id[i] = [new_pos_proto_id[i][0]] + sample(new_pos_proto_id[i][1:], posp - 1)

        # NOTE(review): buffers are created with .cuda(), so this path requires a GPU
        neg_prototypes = torch.zeros([pos_prototypes.shape[0], negp, pos_prototypes.shape[1]]).cuda()
        new_pos_prototypes = torch.zeros([pos_prototypes.shape[0], posp, pos_prototypes.shape[1]]).cuda()

        """2- all negative prototypes after masking"""
        for i in range(len(new_pos_proto_id)):
            new_pos_prototypes[i] = prototypes[new_pos_proto_id[i]]

        # samples without any remaining negative keep all-zero negative prototypes
        for i in range(len(neg_proto_id)):
            if len(neg_proto_id[i]) != 0:
                neg_prototypes[i] = prototypes[neg_proto_id[i]]
        proto_selected = torch.cat([new_pos_prototypes, neg_prototypes], dim=1)  # [batch_size, pos+neg, dim]

        # compute cluster-wise logits/prototypes
        # q[batch_size, dim]    proto_selected[batch_size, posp+negp, dim]   compute logits_proto
        # q[n,c] -> newq[n,1,c]     all[n,m+r,c] -> all[n,c,m+r]

        newq = q.unsqueeze(1)

        # NOTE(review): same reshape-vs-transpose concern as for all_instances above
        proto_selected = torch.reshape(proto_selected,
                                       (proto_selected.shape[0], proto_selected.shape[2], proto_selected.shape[1]))
        # newq[n,1,c] x all[n,c,m+r] = logits[n,1,m+r]

        logits_proto = torch.einsum('nab,nbc->nac', [newq, proto_selected])
        # [batch_size,1,dim] x [batch_size, dim, posp+negp] = [batch_size, 1, posp+negp]

        logits_proto = logits_proto.squeeze(1)

        # labels of prototypes
        temp_proto_label = np.zeros(posp + negp)
        temp_proto_label[0: posp] = 1
        labels_proto = np.tile(temp_proto_label,
                               (q.shape[0], 1))  # [B,posp+negp] each row has posp label1 and negp label0

        # scaling temperatures for the selected prototypes
        if usetemp:
            logits_proto /= tempp

        proto_labels.append(labels_proto)
        proto_logits.append(logits_proto)

    return logits, labels, proto_logits, proto_labels





def compute_features(data_loader, model, args):
    """Embed every sample and its two augmentations with `model` in eval mode.

    The returned CPU tensor has ``3 * len(dataset)`` rows of width
    ``args.low_dim``: rows ``[0, N)`` hold the raw-data features, ``[N, 2N)``
    the first augmentation and ``[2N, 3N)`` the second, each written at the
    dataset index supplied by the loader. Requires CUDA.
    """
    model.eval()
    n_samples = len(data_loader.dataset)
    features = torch.zeros(3 * n_samples, args.low_dim).cuda()

    for data, target, aug1, aug2, index in tqdm(data_loader):
        with torch.no_grad():
            # Add a trailing axis to each view before moving it to the GPU.
            data = data.unsqueeze(3).float().cuda(non_blocking=True)
            target = target.long().cuda(non_blocking=True)
            aug1 = aug1.unsqueeze(3).float().cuda(non_blocking=True)
            aug2 = aug2.unsqueeze(3).float().cuda(non_blocking=True)

            features[index] = model(data, is_eval=True)
            features[index + n_samples] = model(aug1, is_eval=True)
            features[index + 2 * n_samples] = model(aug2, is_eval=True)

    return features.cpu()


def cool_mean(data, partition, max_dis_list=None):
    """Return the per-cluster mean of the rows of `data`.

    Args:
        data: (n_samples, n_features) array.
        partition: (n_samples,) integer cluster labels. They are used directly
            as column indices, so they must be 0..k-1.
        max_dis_list: optional iterable of row indices to exclude ("mask")
            from the means. ``None`` disables masking.

    Returns:
        (k, n_features) array of cluster means over the non-masked rows.
        A cluster whose rows are all masked yields a zero vector.

    NOTE: masking zeroes the selected rows of `data` IN PLACE, exactly as
    before; callers in this file observe that side effect.

    The divisor is now the real number of remaining rows per cluster. The
    previous ``counts - 1`` was only right when exactly one distinct row was
    masked in every cluster; it was wrong for an empty mask list, several or
    duplicate masked rows in one cluster, or clusters with no masked row.
    """
    n_samples = data.shape[0]
    partition = np.asarray(partition)
    labels, counts = np.unique(partition, return_counts=True)
    n_clusters = len(labels)

    if max_dis_list is not None:
        # De-duplicate so a row listed twice is only discounted once.
        masked = np.unique(np.asarray(max_dis_list, dtype=np.int64))
        data[masked] = 0
        masked_per_cluster = np.bincount(partition[masked], minlength=n_clusters)
        # Clamp to 1 to avoid 0/0 when every member of a cluster is masked.
        counts = np.maximum(counts - masked_per_cluster, 1)

    # One-hot membership matrix (n_samples x k); its transpose sums rows per cluster.
    membership = sp.csr_matrix(
        (np.ones(n_samples, dtype='float32'), (np.arange(0, n_samples), partition)),
        shape=(n_samples, n_clusters))
    cluster_sum = membership.T @ data

    return cluster_sum / counts[..., np.newaxis]

def hierarchical_clustering(x, initial_rank=None, distance='cosine', ensure_early_exit=True, verbose=True,
                            ann_threshold=40000,layers=2):
    """
    FINCH-style hierarchical clustering with outlier masking of the centroids.

    x: input matrix with features in rows.(n_samples, n_features)
    initial_rank: Nx1 first integer neighbor indices (optional). (n_samples, 1)
        Only used for the first (layer 0) neighbour search; reset to None afterwards.
    distance: one of ['cityblock', 'cosine', 'euclidean', 'l1', 'l2', 'manhattan']
    ensure_early_exit: [Optional flag] may help in large, high dim datasets, ensure purity of merges and helps early exit
    verbose: print verbose output.
    ann_threshold: int (default 40000), forwarded to clust_rank: sample counts
        above it use PyNNDescent, otherwise exact pairwise distances.
    layers: number of merge rounds run after layer 0; also the number of
        partitions/centroid sets stored in `results`.

    Returns (c, num_clust, partition_clustering, lowest_level_centroids, req_c, results):
        c: NxP matrix. cluster label for every partition P. array(n_samples, n_partitions)
        num_clust: number of clusters per partition (list).
        partition_clustering: list of arrays with labels indicating the centroids
            cluster participation per level.
        lowest_level_centroids: feature coordinates of the lowest level centroids.
            array(num_clust[0], n_features)
        req_c: labels of required clusters (Nx1). only set if req_clust is not None
            (req_clust is a local fixed to None below, so req_c is always None).
        results: dict with 'im2cluster' (list of CUDA LongTensors, one per layer),
            'centroids' (list of CUDA tensors) and 'density' (left empty here).

    NOTE(review): requires CUDA (.cuda() calls), writes 'mat_beforemask.txt' and
    'mat_aftermask.txt' to the working directory on every merge round, and the
    masking step (cool_mean with a mask list) zeroes rows of `x` in place, so
    the array visible to the caller after `x.astype` is modified.
    """
    # fixed configuration (not exposed as parameters)
    req_clust = None
    mask_mode = 'mask_farthest'
    mask_layer0 = True
    replace_centroids =True
    mask_others =True
    dist_threshold = 0.7  # only used by mask_mode == 'mask_threshold'
    proportion = 0.5  # only used by mask_mode == 'mask_proportion'



    results = {'im2cluster': [], 'centroids': [], 'density': []}

    x = x.astype(np.float32)
    min_sim = None

    # calculate pairwise similarity orig_dis to find the nearest neighbor and obtain the adj matrix
    adj, orig_dist, first_neighbors, _ = clust_rank(
        x,
        initial_rank,
        distance,
        verbose=verbose,
        ann_threshold=ann_threshold
    )

    initial_rank = None

    # eg: 520---119---31---6
    # obtain clusters by connecting nodes using the adj matrix obtained by cluster_rank
    u, num_clust = get_clust(adj, [], min_sim)

    # c: layer-0 label of every sample; mat: layer-0 centroids (means of the raw data)
    c, mat = get_merge([], u, x)

    """find the points farthest from the centroids in each cluster and mask these points in next round of clustering"""
    # orig_dist:  distance between the original samples
    # recalculate distance between points and centroids
    # x:(2617, 128)  mat:(521,128)  group:(2617,)

    # step1: define cluster dict, key: cluster_label, value: id list of the cluster_label
    # outliers_dist dict, key: cluster_label, value: initialised with the same ids and
    # overwritten with point-to-centroid distances in 'mask_proportion' mode
    cluster = defaultdict(list)
    outliers_dist = defaultdict(list)

    for i in range(0, len(u)):  # u: current partition, c: all partitions
        cluster[u[i]].append(i)
        outliers_dist[u[i]].append(i)

    # step 2: compute euclidean(x[cluster[i]],mat[i]) -> find max centroids_dist

    # max_dist_list is used to access the points farthest from centroids in each cluster,
    # these points will be masked, and then the centroids will be recalculated
    max_dis_list = []
    min_dis_dict = dict()  # cluster id -> index of the member closest to the centroid

    """mask strategy"""
    # mode 1: mask one point farthest from the centroid of each cluster
    if mask_mode == 'mask_farthest':
        for i in range(0, num_clust):  # calculate the distance between points and the centroids within each cluster
            maxd = 0
            mind = sys.maxsize
            for j in range(0, len(cluster[i])):
                d = ss.distance.euclidean(mat[i], x[cluster[i][j]])
                # ties resolve to the later member for both min and max
                if mind >= d:
                    mind = d
                    minindex = cluster[i][j]
                if maxd <= d:
                    maxd = d
                    maxindex = cluster[i][j]
            max_dis_list.append(maxindex)
            min_dis_dict[i] = minindex

    # mode 2: mask the points whose distance from the centroid is above the specified threshold in each cluster
    elif mask_mode == 'mask_threshold':
        for i in range(0, num_clust):
            mind = sys.maxsize
            for j in range(0, len(cluster[i])):
                d = ss.distance.euclidean(mat[i], x[cluster[i][j]])
                if mind >= d:
                    mind = d
                    minindex = cluster[i][j]
                if d > dist_threshold:
                    max_dis_list.append(cluster[i][j])
            min_dis_dict[i] = minindex

    # mode 3: mask the points with the specified proportion farthest from the centroid of each cluster
    elif mask_mode == 'mask_proportion':
        for i in range(0, num_clust):
            mind = sys.maxsize
            for j in range(0, len(cluster[i])):
                d = ss.distance.euclidean(mat[i], x[cluster[i][j]])
                if mind >= d:
                    mind = d
                    minindex = cluster[i][j]
                outliers_dist[i][j] = d  # save the distance between jth point and the centroid in ith cluster
            t = copy.deepcopy(outliers_dist[i])
            # repeatedly pick the current largest distance and zero it out
            for _ in range(round(len(outliers_dist[i]) * proportion)):
                dist = max(t)
                index = t.index(dist)
                t[index] = 0
                max_dis_list.append(cluster[i][index])
            t = []
            min_dis_dict[i] = minindex

    # step 3: obtain the centroids according to the partition and raw data
    """ Recalculate the centroids at layer 0 (only mask points at the first step of clustering)"""
    if mask_layer0 is True:
        # NOTE: cool_mean zeroes the masked rows of x in place
        mat = cool_mean(x, u, max_dis_list)

    # clustering at layer 0 (bottom layer) end

    # begin clustering at following layers through the while loop

    """ replace computed prototypes with raw data """
    if replace_centroids is True:
        # each centroid becomes the (possibly already zeroed) member closest to the mean
        for i in range(0, num_clust):
            mat[i] = x[min_dis_dict[i]]

    lowest_level_centroids = mat

    ''' save centroids of the bottom layer (layer 0)'''
    lowest_centroids = torch.Tensor(lowest_level_centroids).cuda()
    results['centroids'].append(lowest_centroids)

    if verbose:
        print('Level/Partition 0: {} clusters'.format(num_clust))

    if ensure_early_exit:
        if orig_dist.shape[-1] > 2:
            # largest first-neighbour distance at layer 0; get_clust uses it as a cut-off
            min_sim = np.max(orig_dist * adj.toarray())

    exit_clust = 2
    c_ = c  # transfer value first and then mask

    k = 1
    num_clust = [num_clust]  # int->list
    partition_clustering = []
    # runs exactly `layers` merge rounds (the early-exit condition is disabled)
    while k < layers+1:
        # cluster the current centroids
        adj, orig_dist, first_neighbors, knn_index = clust_rank(
            mat,
            initial_rank,
            distance,
            verbose=verbose,
            ann_threshold=ann_threshold
        )

        u, num_clust_curr = get_clust(adj, orig_dist, min_sim)  # u = group

        partition_clustering.append(u)  # all partitions (u: current partition)

        # propagate the merge down to the samples and recompute centroids from x
        c_, mat = get_merge(c_, u, x)
        c = np.column_stack((c, c_))

        num_clust.append(num_clust_curr)
        exit_clust = num_clust[-2] - num_clust_curr  # NOTE(review): computed but not used

        if verbose:
            print('Level/Partition {}: {} clusters'.format(k, num_clust[k]))

        ''' save the controids of the bottom args.layers '''
        max_dis_list = []
        min_dis_dict = dict()

        # NOTE(review): the loops below index `cluster`, which still holds the
        # layer-0 membership lists, with the cluster ids of the current level.
        # It looks like membership should be rebuilt from c_ here -- confirm.
        """mask strategy"""
        # mode 1: mask one point farthest from the centroid of each cluster
        if mask_mode == 'mask_farthest':
            for i in range(0, mat.shape[0]):
                maxd = 0
                mind = sys.maxsize
                for j in range(0, len(cluster[i])):
                    d = ss.distance.euclidean(mat[i], x[cluster[i][j]])
                    if mind >= d:
                        mind = d
                        minindex = cluster[i][j]
                    if maxd <= d:
                        maxd = d
                        maxindex = cluster[i][j]
                max_dis_list.append(maxindex)
                min_dis_dict[i] = minindex

        # mode 2: mask the points whose distance from the centroid is above the specified threshold in each cluster
        elif mask_mode == 'mask_threshold':
            for i in range(0, mat.shape[0]):
                mind = sys.maxsize
                for j in range(0, len(cluster[i])):
                    d = ss.distance.euclidean(mat[i], x[cluster[i][j]])
                    if mind >= d:
                        mind = d
                        minindex = cluster[i][j]
                    if d > dist_threshold:
                        max_dis_list.append(cluster[i][j])
                min_dis_dict[i] = minindex

        # mode 3: mask the points with the specified proportion farthest from the centroid of each cluster
        elif mask_mode == 'mask_proportion':
            for i in range(0, mat.shape[0]):
                mind = sys.maxsize
                for j in range(0, len(cluster[i])):
                    d = ss.distance.euclidean(mat[i], x[cluster[i][j]])
                    if mind >= d:
                        mind = d
                        minindex = cluster[i][j]
                    outliers_dist[i][j] = d
                t = copy.deepcopy(outliers_dist[i])
                for _ in range(round(len(outliers_dist[i]) * proportion)):
                    dist = max(t)
                    index = t.index(dist)
                    t[index] = 0
                    max_dis_list.append(cluster[i][index])
                t = []
                min_dis_dict[i] = minindex

        """ replace computed prototypes with raw data """
        if replace_centroids is True:
            for i in range(0, mat.shape[0]):
                mat[i] = x[min_dis_dict[i]]

        """ Recalculate the centroids at top layers (except layer 0)"""
        if mask_others is True:
            # debug dumps to the current working directory, overwritten every round;
            # the recomputed means replace the raw-data prototypes assigned just above
            np.savetxt('mat_beforemask.txt', mat, delimiter=',', fmt='%s')
            mat = cool_mean(x, c_, max_dis_list)
            np.savetxt('mat_aftermask.txt', mat, delimiter=',', fmt='%s')

        ''' save the controids at args.layers '''
        # if args.layers=3 means: save 533 131 32  from [533, 131, 32, 7, 2]
        # (the centroids of the last round are not stored)
        if k < layers:
            centroids = torch.Tensor(mat).cuda()
            results['centroids'].append(centroids)

        k += 1

    # req_clust is fixed to None above, so this branch never runs
    if req_clust is not None:
        print(f'req_clust:{req_clust}')
        print('yes')
        if req_clust not in num_clust:
            print('notinyes')
            ind = [i for i, v in enumerate(num_clust) if v >= req_clust]
            req_c = req_numclust(c[:, ind[-1]], x, req_clust, distance)
        else:
            print('inyes')
            req_c = c[:, num_clust.index(req_clust)]
    else:
        req_c = None

    """ save multiple partitions """
    # store the first `layers` columns of c (c has layers + 1 columns in total)
    for i in range(0, layers):
        im2cluster = [int(n[i]) for n in c]
        im2cluster = torch.LongTensor(im2cluster).cuda()
        results['im2cluster'].append(im2cluster)

    return c, num_clust, partition_clustering, lowest_level_centroids, req_c, results


def save_checkpoint(state, is_best, filename='checkpoint.pth.tar', best_filename='model_best.pth.tar'):
    """Serialise `state` with torch.save and optionally keep a "best" copy.

    Args:
        state: object to serialise (anything torch.save accepts).
        is_best: when truthy, copy the saved checkpoint to `best_filename`.
        filename: path of the checkpoint file.
        best_filename: path of the best-model copy. Defaults to the previously
            hard-coded 'model_best.pth.tar' in the working directory, so
            existing callers behave as before.
    """
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, best_filename)


class AverageMeter(object):
    """Track the latest value and the running (weighted) average of a metric."""

    def __init__(self, name, fmt=':f'):
        # `fmt` is a format-spec suffix (e.g. ':.3f') applied to val and avg in __str__.
        self.name, self.fmt = name, fmt
        self.reset()

    def reset(self):
        """Clear the latest value, average, sum and count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record `val`, observed `n` times, and refresh the running average."""
        self.sum += val * n
        self.count += n
        self.val = val
        self.avg = self.sum / self.count

    def __str__(self):
        template = '{{name}} {{val{0}}} ({{avg{0}}})'.format(self.fmt)
        return template.format(**vars(self))


class ProgressMeter(object):
    """Print a '[batch/total]' counter followed by a list of meters."""

    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        """Print the prefix, the batch counter and every meter, tab-separated."""
        header = self.prefix + self.batch_fmtstr.format(batch)
        print('\t'.join([header, *map(str, self.meters)]))

    def _get_batch_fmtstr(self, num_batches):
        # Pad the running batch number to the width of the total.
        width = len(str(num_batches // 1))
        slot = '{:%dd}' % width
        return '[{}/{}]'.format(slot, slot.format(num_batches))


def adjust_learning_rate(optimizer, epoch, args):
    """Set the learning rate of every param group for `epoch`.

    With ``args.cos`` the base rate ``args.lr`` follows a half-cosine over
    ``args.epochs``; otherwise it is multiplied by 0.1 for every milestone in
    ``args.schedule`` that `epoch` has reached.
    """
    new_lr = args.lr
    if args.cos:
        # cosine lr schedule
        new_lr = new_lr * (0.5 * (1. + math.cos(math.pi * epoch / args.epochs)))
    else:
        # stepwise lr schedule
        for milestone in args.schedule:
            factor = 0.1 if epoch >= milestone else 1.
            new_lr = new_lr * factor

    for group in optimizer.param_groups:
        group['lr'] = new_lr


def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k

    Args:
        output: (batch, n_classes) scores.
        target: (batch,) ground-truth class indices.
        topk: iterable of k values to evaluate.

    Returns:
        List with one 1-element tensor per k: the percentage of samples whose
        target is among the k highest-scoring classes.
    """
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()  # (maxk, batch), rows ordered from best to k-th best
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            # `correct` inherits the transposed (non-contiguous) layout of `pred`,
            # so .view(-1) raises for k > 1; reshape copies only when needed.
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res


def clust_rank(
        mat,
        initial_rank=None,
        metric='cosine',
        verbose=False,
        ann_threshold=40000):
    """Link every row of `mat` to its first neighbour.

    Returns:
        (adjacency, orig_dist, initial_rank, knn_index) where
        - adjacency: (s, s) CSR matrix with a 1 at (i, initial_rank[i]);
        - orig_dist: ``[]`` when `initial_rank` was supplied, the dense pairwise
          distance matrix (diagonal set to 1e12) for the exact search, or the
          NNDescent neighbour distances (column 0 set to 1e12) otherwise;
        - initial_rank: first-neighbour index of every row;
        - knn_index: the NNDescent index when it was used, else None.
    """
    knn_index = None
    n_rows = mat.shape[0]

    if initial_rank is None:
        if n_rows <= ann_threshold:
            # Small enough: exact pairwise distances, self-distance pushed out of reach.
            orig_dist = metrics.pairwise.pairwise_distances(mat, mat, metric=metric)
            np.fill_diagonal(orig_dist, 1e12)
            initial_rank = np.argmin(orig_dist, axis=1)
        else:
            # Too many rows: approximate neighbours with PyNNDescent.
            if verbose:
                print('Using PyNNDescent to compute 1st-neighbours at this step ...')
            knn_index = NNDescent(
                mat,
                n_neighbors=2,
                metric=metric,
                verbose=verbose)
            result, orig_dist = knn_index.neighbor_graph
            initial_rank = result[:, 1]
            orig_dist[:, 0] = 1e12
            if verbose:
                print('Step PyNNDescent done ...')
    else:
        # Caller supplied the neighbours; no distances are available.
        orig_dist = []

    # Adjacency matrix built from the first-neighbour links.
    edge_weights = np.ones_like(initial_rank, dtype=np.float32)
    sources = np.arange(0, n_rows)
    sparce_adjacency_matrix = sp.csr_matrix(
        (edge_weights, (sources, initial_rank)),
        shape=(n_rows, n_rows))

    return sparce_adjacency_matrix, orig_dist, initial_rank, knn_index


def get_clust(a, orig_dist, min_sim=None):
    """Group nodes into clusters: the weakly connected components of `a`.

    Args:
        a: sparse adjacency matrix of first-neighbour links. It is modified in
            place when `min_sim` is given.
        orig_dist: dense distance matrix aligned with `a`; only read when
            `min_sim` is given.
        min_sim: optional cut-off; links whose distance exceeds it are removed
            before the components are computed.

    Returns:
        (u, num_clust): component label per node and the number of components.
    """
    if min_sim is not None:
        a[np.where((orig_dist * a.toarray()) > min_sim)] = 0
        # Assigning 0 keeps the entries as explicitly stored zeros, and
        # connected_components works on the stored structure, so the links
        # would still connect their nodes. Drop them for real.
        if hasattr(a, 'eliminate_zeros'):
            a.eliminate_zeros()

    num_clust, u = sp.csgraph.connected_components(csgraph=a, directed=True, connection='weak', return_labels=True)

    return u, num_clust


def get_merge(partition, group, data):
    """Propagate a merge to the samples and recompute the cluster means.

    Args:
        partition: current per-sample labels, or an empty sequence on the
            first call.
        group: labels produced by the latest clustering step; per sample on
            the first call, otherwise per current cluster.
        data: (n, dim) sample features.

    Returns:
        (labels, mat): the updated per-sample labels (n,) and the centroid
        matrix (k, dim) computed by cool_mean without masking.
    """
    if len(partition) == 0:
        merged = group
    else:
        # Rank each sample's current label, then look up that cluster's new group.
        inverse = np.unique(partition, return_inverse=True)[1]
        merged = group[inverse]

    centroids = cool_mean(data, merged, max_dis_list=None)
    return merged, centroids


def update_adj(adj, d):
    """Keep only the (up to) two lowest-distance edges of ``adj``.

    The original comment describes this as "keep one merge at a time"; the
    two retained entries are presumably the two directions of the closest
    pair — TODO confirm against the caller.

    Unlike the previous implementation, an adjacency with fewer than two
    stored edges no longer raises ``IndexError``: whatever edges exist (zero
    or one) are kept.

    Args:
        adj: Sparse adjacency matrix; its non-zero positions are the
            candidate edges.
        d: Distance array indexable by the ``(rows, cols)`` tuple returned
            by ``adj.nonzero()``.

    Returns:
        A ``scipy.sparse.lil_matrix`` with the same shape as ``adj`` that
        holds ``1`` at the retained edge positions and zero elsewhere.
    """
    rows, cols = adj.nonzero()
    # Positions (within the edge list) of the smallest distances, at most two.
    closest = np.argsort(d[rows, cols])[:2]

    pruned = sp.lil_matrix(adj.shape)
    if closest.size:
        pruned[rows[closest], cols[closest]] = 1
    return pruned


def req_numclust(c, data, req_clust, distance):
    """Merge clusters one step at a time until ``req_clust`` remain.

    Starting from the labels ``c``, the loop runs
    ``len(np.unique(c)) - req_clust`` times. Each pass rebuilds the
    first-neighbour graph on the current ``mat`` returned by ``get_merge``,
    prunes it with ``update_adj``, relabels via ``get_clust`` and folds the
    result back into the running partition.

    Args:
        c: Initial cluster labels.
        data: Samples forwarded to ``get_merge``.
        req_clust: Requested number of clusters.
        distance: Metric name forwarded to ``clust_rank``.

    Returns:
        The label array after the merges.
    """
    print('update when req_clust is specified')
    merges_needed = len(np.unique(c)) - req_clust
    labels, centroids = get_merge([], c, data)

    for _step in range(merges_needed):
        graph, dists, _, _ = clust_rank(centroids, initial_rank=None, metric=distance)
        graph = update_adj(graph, dists)
        # No distance cut-off here: only connectivity of the pruned graph matters.
        components, _ = get_clust(graph, [], min_sim=None)
        labels, centroids = get_merge(labels, components, data)

    return labels


class AugSem(nn.Module):
    def __init__(self, d_model=64, num_prototypes=1, seq_len=64, gamma=0.01, n_fft=32, hop_length=16, win_length=32,
                 lambda_prototype=0.1, lambda_alignment=0.1, args=None, prototype_dim=64, alignment_dim=64):
        """Set up the STFT settings, the two soft-DTW aligners and loss weights.

        Args:
            d_model, num_prototypes, seq_len, prototype_dim: Accepted for
                interface compatibility; not used by this constructor.
            gamma: Stored as ``self.gamma``. When ``args`` is None it is also
                used as the smoothing value of both soft-DTW aligners.
            n_fft, hop_length, win_length: STFT settings, used only when
                ``args`` is None (otherwise the values come from ``args``).
            lambda_prototype: Weight of the prototype term in ``compute_loss``.
            lambda_alignment: Stored as ``self.lambda_alignment``.
            args: Configuration object. It must provide ``n_fft``,
                ``hop_length``, ``win_length``, ``dtw_gamma1`` and
                ``dtw_gamma2``; ``magnitude_noise_std``, ``phase_adj`` and
                ``mask_prob`` are optional here. Other methods additionally
                read ``row_threshold``, ``column_threshold`` and
                ``non_aligned_noise_std`` from it.
            alignment_dim: Stored as ``self.alignment_dim``.

        When ``args`` is None a fallback namespace is built from the
        constructor arguments. It now also carries ``dtw_gamma1``,
        ``dtw_gamma2`` (both set to ``gamma``) and ``phase_adj`` (1.0, the
        same default used below), because the previous placeholder lacked
        them and construction without ``args`` always failed with
        ``AttributeError``. The fallback deliberately does not invent values
        for ``row_threshold``, ``column_threshold`` or
        ``non_aligned_noise_std``.
        """
        super(AugSem, self).__init__()

        if args is None:  # pragma: no cover
            # Function-scope import so the block does not depend on new file-level imports.
            from types import SimpleNamespace

            args = SimpleNamespace(
                n_fft=n_fft,
                hop_length=hop_length,
                win_length=win_length,
                magnitude_noise_std=0.1,
                mask_prob=0.5,
                phase_adj=1.0,
                dtw_gamma1=gamma,
                dtw_gamma2=gamma,
            )
        self.args = args

        self.gamma = gamma
        self.n_fft = self.args.n_fft
        self.hop_length = self.args.hop_length
        self.win_length = self.args.win_length

        # Two aligners with separate smoothing values: ``sdtw`` feeds the
        # time-shift estimate and the prototype loss, ``sdtw1`` produces the
        # matrix that is binarised into an aligned / non-aligned mask.
        self.sdtw = SoftDTW_align(True, gamma=self.args.dtw_gamma1, normalize=False)
        self.sdtw1 = SoftDTW_align(True, gamma=self.args.dtw_gamma2, normalize=False)

        self.lambda_prototype = lambda_prototype
        self.lambda_alignment = lambda_alignment
        self.alignment_dim = alignment_dim
        self.magnitude_noise_std = getattr(self.args, 'magnitude_noise_std', 0.1)
        self.phase_adj = getattr(self.args, 'phase_adj', 1.0)

        self.mask_prob = getattr(self.args, 'mask_prob', 0.5)



    def compute_loss(self, reconstructed_signal, input_signal, prototype_signal=None, num_prototypes=None):

        recon_loss = F.mse_loss(reconstructed_signal, input_signal)
        if prototype_signal is None:  # pragma: no cover
            return recon_loss


        batch_size, num_vars, seq_len = input_signal.size()


        x_flat = reconstructed_signal.reshape(-1, seq_len)
        x_expanded = x_flat  # type: ignore
        X = x_expanded.reshape(-1, seq_len).unsqueeze(-1)
        X = X.requires_grad_()
        distances = self.sdtw(X, prototype_signal)
        prototype_loss = torch.mean(distances)
        total_loss = recon_loss + self.lambda_prototype * prototype_loss
        return total_loss

    def compute_time_shifts(self, alignment_matrix, window_size, hop_length, use_threshold=False, threshold=0.5,
                            epsilon=1e-8,
                            weight_scale=30):

        batch_size, seqx_len, seqy_len = alignment_matrix.shape
        device = alignment_matrix.device

        shifts = torch.arange(seqy_len, device=device).view(1, 1, seqy_len) - \
                 torch.arange(seqx_len, device=device).view(1, seqx_len, 1)
        shifts = shifts.expand(batch_size, -1, -1).float()

        alignment_sum = alignment_matrix.sum(dim=2) + epsilon
        time_shift_per_row = (alignment_matrix * shifts).sum(dim=2) / alignment_sum

        row_max_confidence = alignment_matrix.max(dim=2).values  # Renamed from row_max
        confidence = row_max_confidence / alignment_sum
        confidence = torch.clamp(confidence, min=0.0)

        if use_threshold:
            confident_mask = confidence > threshold  # Renamed from confident
            confidence_weights = confident_mask.float() * 1.0 + (~confident_mask).float() * 0.1
            # Apply softmax to these binary-like weights per window
            unfolded_weights = confidence_weights.unfold(dimension=1, size=window_size, step=hop_length)
            if unfolded_weights.numel() > 0:  # Guard against empty tensor for softmax
                window_confidence_softmax = F.softmax(unfolded_weights, dim=2)
            else:  # Should not happen if seqx_len >= window_size
                window_confidence_softmax = unfolded_weights
        else:
            confidence_scaled = confidence * weight_scale  # Renamed from confidence_weights
            unfolded_scaled_confidence = confidence_scaled.unfold(dimension=1, size=window_size, step=hop_length)
            if unfolded_scaled_confidence.numel() > 0:
                window_confidence_softmax = F.softmax(unfolded_scaled_confidence, dim=2)
            else:
                window_confidence_softmax = unfolded_scaled_confidence

        time_shift_unfolded = time_shift_per_row.unfold(dimension=1, size=window_size, step=hop_length)

        # Ensure dimensions match for multiplication if one of them became empty due to unfold
        if time_shift_unfolded.shape == window_confidence_softmax.shape and time_shift_unfolded.numel() > 0:
            window_time_shift = (time_shift_unfolded * window_confidence_softmax).sum(dim=2)
        elif time_shift_unfolded.numel() == 0:  # No windows produced
            window_time_shift = torch.zeros(batch_size, 0, device=device)  # Return empty tensor of correct rank
        else:  # Shapes mismatch post-unfold (e.g. one empty, one not - should not happen if seqx_len processed correctly)
            # This case implies an issue with unfold or input sizes, returning sum of unfolded time shifts as a fallback
            print(
                f"Warning: Mismatch or empty tensor in compute_time_shifts. time_shift_unfolded: {time_shift_unfolded.shape}, window_confidence_softmax: {window_confidence_softmax.shape}")
            window_time_shift = time_shift_unfolded.sum(dim=2)  # Fallback, likely not what's desired but avoids crash

        return window_time_shift

    def extract_subsequence_features_with_mask(self, x: torch.Tensor, alignment_01: torch.Tensor) -> torch.Tensor:

        device = x.device
        B, T_x, V = x.shape

        alignment_01_proc = alignment_01
        if alignment_01.dim() == 4 and alignment_01.shape[1] == 1:
            alignment_01_proc = alignment_01.squeeze(1)
        elif alignment_01.dim() != 3:
            raise ValueError(
                f"Unexpected alignment_01 shape: {alignment_01.shape}, expected 3D or 4D with squeezeable dim.")

        if alignment_01_proc.shape[0] != B or alignment_01_proc.shape[1] != T_x:
            # Allow broadcasting if V=1 for alignment_01_proc, e.g. alignment for one var applied to all.
            # This check is too strict if alignment_01_proc might be [B, T_x, T_y] and x is [B, T_x, V_x]
            # The current use case: x is [B_eff, T_x, 1], alignment_01 is [B_eff, T_x, T_y]
            # So this check is fine for current usage.
            raise ValueError(
                f"Processed alignment_01 shape {alignment_01_proc.shape} incompatible with x shape {x.shape}")

        row_segments = torch.cumsum(alignment_01_proc, dim=-1)
        row_break_mask = (row_segments >= self.args.row_threshold).float() * alignment_01_proc
        col_segments = torch.cumsum(alignment_01_proc, dim=-2)
        col_segments = col_segments * alignment_01_proc
        col_break_mask = (col_segments >= self.args.column_threshold).float()
        row_broken = row_break_mask.any(dim=-1)
        col_contribution_to_break = alignment_01_proc * col_break_mask
        row_involved_in_col_break = col_contribution_to_break.any(dim=-1)
        is_broken_row = (row_broken | row_involved_in_col_break).float()
        final_mask = 1.0 - is_broken_row




        final_mask_expanded = final_mask.unsqueeze(-1).expand_as(x)
        aligned_x = x * final_mask_expanded



        ##nonalign
        ##noise
        non_aligned_selector_expanded = (1.0 - final_mask_expanded)  # 1 for non-aligned, 0 for aligned
        original_non_aligned_parts = x * non_aligned_selector_expanded
        noise_perturbation = torch.randn_like(original_non_aligned_parts) * self.args.non_aligned_noise_std
        noised_non_aligned_parts = original_non_aligned_parts * (1 + noise_perturbation)
        ##mask
        rand_mask_t_for_non_aligned = (torch.rand(B, T_x, device=device) < self.args.mask_prob).float()
        rand_mask_expanded_for_non_aligned = rand_mask_t_for_non_aligned.unsqueeze(-1).expand_as(x)

        # Apply the random mask to the noised_non_aligned_parts
        # The noised_non_aligned_parts is already zero where final_mask_expanded was 1 (aligned regions).
        # So, this operation correctly masks only the non-aligned, noised regions.
        x_non_aligned_output = noised_non_aligned_parts * rand_mask_expanded_for_non_aligned

        return aligned_x, x_non_aligned_output

    def aug(self, x, prototype, original_prototype_was_multivariate,
            freq_aug=None):


        device = x.device

        if original_prototype_was_multivariate:
            batch_size, num_vars, seq_len = x.size()
            batch_size1, _, seq_len_U = prototype.size()

            device = x.device


            X = x.permute(0, 2, 1).to(device).detach()  # [N, seq_len, num_vars]
            Y1 = prototype.permute(0, 2, 1).to(device).detach()  # [P, seq_len_U, num_vars]
            N, P = X.shape[0], Y1.shape[0]


            Y2 = Y1
            if seq_len_U != seq_len:
                if seq_len_U > seq_len:
                    Y2 = Y1[:, :seq_len, :]
                else:
                    pad_amt = seq_len - seq_len_U
                    Y2 = F.pad(Y1, (0, 0, 0, pad_amt))
                seq_len_U = seq_len

            with torch.no_grad():

                D = seq_len * num_vars
                X_flat = X.reshape(N, D)
                Y_flat = Y2.reshape(P, D)

                X_rep = X_flat.unsqueeze(1).expand(N, P, D).reshape(-1, D)
                Y_rep = Y_flat.unsqueeze(0).expand(N, P, D).reshape(-1, D)

                dist = (X_rep - Y_rep).norm(dim=1).view(N, P)  # [N, P]


                del X_rep, Y_rep

                torch.cuda.empty_cache()

            min_index = dist.argmin(dim=1)  # [N]
            Y = Y1[min_index]  # [N, seq_len, num_vars]


            X = X.requires_grad_()
            start = time.time()

            align = self.sdtw.align(X, Y)
            alignment_01 = self.sdtw1.align(X, Y)

        else:  # original_prototype_was_multivariate is True
            batch_size, num_vars, seq_len = x.size()
            batch_size1, seq_len_U = prototype.size()
            if batch_size1 > batch_size:
                prototype = prototype[:batch_size, :]

            device = x.device

            # X : [N,  seq_len, 1]
            X = x.reshape(-1, seq_len).unsqueeze(-1).to(device)  # N = batch_size * num_vars
            # Y1: [P,  seq_len_U, 1]
            Y1 = prototype.unsqueeze(-1).to(device)  # P = batch_size1

            N, P = X.shape[0], Y1.shape[0]
            Y2 = Y1

            if seq_len_U != seq_len:
                if seq_len_U > seq_len:
                    Y2 = Y1[:, :seq_len, :]
                else:
                    pad_amt = seq_len - seq_len_U
                    Y2 = F.pad(Y1, (0, 0, 0, pad_amt))
                seq_len_U = seq_len


            with torch.no_grad():

                X_rep = X.unsqueeze(1).expand(N, P, seq_len, 1).reshape(-1, seq_len)  # [N*P, seq_len,     1]

                Y_rep = Y2.unsqueeze(0).expand(N, P, seq_len_U, 1).reshape(-1, seq_len_U)  # [N*P, seq_len_U, 1]

                dist = (X_rep - Y_rep).norm(dim=1).view(N, P)  # [N, P]


                del X_rep, Y_rep

                torch.cuda.empty_cache()

            min_index = dist.argmin(dim=1)  # [N]
            Y = Y1[min_index]  # [N, seq_len, num_vars]
            Y_return = Y[0,:,0]

            X = X.requires_grad_()
            align = self.sdtw.align(X, Y)

            alignment_01 = self.sdtw1.align(X, Y)


        end = time.time()
        # print(end-start)


        def freq():

            x_stft = torch.stft(
                x.reshape(-1, seq_len),
                n_fft=self.n_fft,
                hop_length=self.hop_length,
                win_length=self.win_length,
                window=torch.ones(self.win_length).to(x.device),
                center=False,
                return_complex=True,
                onesided=False
            )
            x_magnitude = torch.abs(x_stft)  # [batch_size * num_vars, freq_bins, time_steps]
            x_phase = torch.angle(x_stft)  # [batch_size * num_vars, freq_bins, time_steps]


            time_steps = x_stft.size(2)
            frame_indices = (torch.arange(time_steps).to(device) * self.hop_length + self.win_length // 2).long()
            frame_indices = torch.clamp(frame_indices, max=seq_len - 1)  # 确保索引不超出范围


            delta_t = self.compute_time_shifts(align, self.win_length,
                                               self.hop_length)  # [batch_size * num_vars * num_prototypes, time_steps]
            # delta_t=torch.tensor([[6]]).cuda()

            # lambda_scale = self.compute_lambda_scale(alignment_matrix, frame_indices)  # [batch_size * num_vars * num_prototypes, time_steps]


            # delta_t = delta_t.unsqueeze(1)  # [batch_size * num_vars * num_prototypes, 1, time_steps]
            if original_prototype_was_multivariate:
                delta_t_expanded_for_phi = delta_t.unsqueeze(1).repeat(1, num_vars, 1).reshape(-1, 1, time_steps)
            else:
                delta_t_expanded_for_phi = delta_t.unsqueeze(1).reshape(-1, 1, time_steps)
            


            freqs = torch.fft.fftfreq(self.n_fft, d=1.0)[:self.n_fft].to(device)  # [freq_bins]

            freqs = freqs.unsqueeze(0).unsqueeze(2)  # [1, freq_bins, 1]


            # Broadcasting: (1, F, 1) * ((B*V), 1, T_freq) -> ((B*V), F, T_freq)
            phi_adjust = -2 * torch.pi * freqs * delta_t_expanded_for_phi

            x_phase_adjusted = x_phase + self.args.phase_adj * phi_adjust  # [batch_size * num_vars * num_prototypes, freq_bins, time_steps]

            ##
            magnitude_noise = torch.randn_like(x_magnitude) * self.args.magnitude_noise_std
            x_magnitude_noisy = x_magnitude * (1.0 + magnitude_noise)

            # Reconstruct complex STFT from magnitude and phase
            x_stft_reconstructed = x_magnitude_noisy * torch.exp(
                1j * x_phase_adjusted)  # [batch_size * num_vars, freq_bins, time_steps]

            # Convert to a supported type before mean calculation
            x_stft_reconstructed = x_stft_reconstructed.to(
                torch.complex64)  # Convert to ComplexFloat for CUDA compatibility


            x_stft_reconstructed_mean = x_stft_reconstructed.reshape(-1, 1, x_stft_reconstructed.shape[-2],
                                                             x_stft_reconstructed.shape[-1]).mean(
                dim=1)  # Now this should work on CUDA

            x_adjusted_signal = torch.istft(
                x_stft_reconstructed_mean,
                n_fft=self.n_fft,
                hop_length=self.hop_length,
                win_length=self.win_length,
                window=torch.ones(self.win_length).to(x.device),
                center=False,
                length=seq_len,
                onesided=False
            )  # [batch_size_num_vars, seq_len]

            x_adjusted_signal = x_adjusted_signal[:, :seq_len]


            x_adjusted = x_adjusted_signal.reshape(batch_size, num_vars, seq_len)  # [batch_size, num_vars, seq_len]

            return x_adjusted


        if freq_aug:
            x_adjusted = freq()
            x_adjusted = x_adjusted
            return x_adjusted
        else:



            eps = 1e-8
            alpha=1




            row_min_values = alignment_01.min(dim=2).values.unsqueeze(-1)
            row_max_values = alignment_01.max(dim=2).values.unsqueeze(-1)

            norm_mat = (alignment_01 - row_min_values) / (row_max_values - row_min_values + eps)

            stretched = norm_mat ** alpha


            row_threshold = (row_min_values.squeeze(-1) + row_max_values.squeeze(-1)) / 50.0

            row_threshold = row_threshold.unsqueeze(-1)  # [B, X, 1]
            binary_mat = (stretched > row_threshold).int()
            # colab_show_dtw(binary_mat)

            _, _, col = binary_mat.shape


            cumsum_mat = binary_mat.cumsum(dim=2)


            last_ones = (cumsum_mat == cumsum_mat.max(dim=2, keepdim=True).values).int()

            last_ones_idx = last_ones.argmax(dim=2)

            j_indices = torch.arange(col, device=binary_mat.device).view(1, 1, col)
            ###------------------------------------
            last_ones_idx1 = torch.roll(last_ones_idx, shifts=1, dims=-1)
            last_ones_idx1[:, 0] = 0
            avg_ones_idx = last_ones_idx1
            ##----------------------------------------
            avg_ones_idx = avg_ones_idx.unsqueeze(-1)  # [B, X, 1]

            dtw_mask = j_indices >= avg_ones_idx  # [B, X, Y]，确保 j >= last_ones_idx[i]


            binary_mat = binary_mat * dtw_mask




            x_align, x_nonalign = self.extract_subsequence_features_with_mask(
                x=X,  # [B', T]
                alignment_01=binary_mat.squeeze(1)  # [B', T, T_U]
            )

            x_align, x_nonalign = x_align.reshape(batch_size, num_vars, seq_len), x_nonalign.reshape(batch_size, num_vars, seq_len)


            return x_align, x_nonalign