import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor


def get_cross_correlation(ts: Tensor, padding: int = 32) -> Tensor:
    """Compute lagged cross-correlations between every pair of features.

    For each feature pair ``(i, j)`` with ``i < j`` (in ``np.triu_indices``
    order) and each batch sample independently, feature ``j`` is slid over
    feature ``i`` (zero-padded by ``padding`` on both sides) using
    ``F.conv1d``, which computes cross-correlation (no kernel flip).

    Args:
        ts: Tensor of shape ``[batch, seq_len, n_features]``.
        padding: Zeros added on each side of feature ``i``; the number of
            lags produced is ``2 * padding + 1``.

    Returns:
        Tensor of shape
        ``[batch, n_features * (n_features - 1) // 2, 2 * padding + 1]``.
        Entry ``[b, p, k]`` equals
        ``sum_t ts[b, t + (k - padding), i] * ts[b, t, j]`` over valid ``t``
        for the ``p``-th pair ``(i, j)``, so index ``padding`` is zero lag.
        If there are no pairs (``n_features < 2``) or the batch is empty, an
        empty tensor of that shape is returned instead of raising.
    """
    batch, seq_len, n_features = ts.shape
    rows, cols = np.triu_indices(n_features, 1)
    n_pairs = len(rows)

    if n_pairs == 0 or batch == 0:
        # torch.stack / grouped conv cannot handle zero pairs or samples.
        return ts.new_zeros((batch, n_pairs, 2 * padding + 1))

    row_idx = torch.as_tensor(rows, dtype=torch.long, device=ts.device)
    col_idx = torch.as_tensor(cols, dtype=torch.long, device=ts.device)

    # [batch, n_pairs, seq_len]: the series being scanned and the sliding kernel.
    signal = ts[:, :, row_idx].transpose(1, 2)
    kernel = ts[:, :, col_idx].transpose(1, 2)

    n_groups = batch * n_pairs
    # One group per (sample, pair): each output channel only sees its own
    # input channel, avoiding the batch x batch products the per-pair
    # full convolution computed and then discarded via torch.diagonal.
    cross_corr = F.conv1d(
        signal.reshape(1, n_groups, seq_len),
        kernel.reshape(n_groups, 1, seq_len),
        padding=padding,
        groups=n_groups,
    )
    return cross_corr.reshape(batch, n_pairs, -1)
