from typing import Tuple

import torch


class CrossCorrelLoss:
    """Loss comparing feature cross-correlations of a real and a fake batch.

    Each batch is standardized, and the lag-0 cross-correlation of every pair
    of feature channels (lower triangle, diagonal included) is averaged over
    the batch. The loss is the mean absolute difference between the real and
    fake correlation vectors.
    """

    def __init__(self, max_lag: int = 64) -> None:
        # Stored maximum lag. NOTE: ``__call__`` only uses the lag-0
        # correlations, so this value does not change the loss itself.
        self.max_lag = max_lag

    def cacf(self, x: torch.Tensor, max_lag: int, dim: Tuple[int, int] = (0, 1)) -> torch.Tensor:
        """
        Computes the cross-correlation between feature dimension and time dimension

        Parameters
        ----------
        x
            Tensor of shape ``(batch, time, features)``.
        max_lag
            Number of lags to compute (lags ``0 .. max_lag - 1``); must be at
            least 1. Lags greater than or equal to the sequence length yield
            NaN because the lagged product over time is empty.
        dim
            Dimensions over which ``x`` is standardized (zero mean, unit
            unbiased standard deviation) before correlating.

        Returns
        -------
        torch.Tensor
            Shape ``(batch, max_lag, features * (features + 1) // 2)``. Entry
            ``[b, k, p]`` is the average over ``t`` of
            ``x[b, t, i] * x[b, t - k, j]`` for the ``p``-th lower-triangular
            feature pair ``(i, j)`` with ``i >= j``.
        """
        n_features = x.shape[2]
        # (2, P) row/col indices of the lower triangle, created on x's device so
        # the advanced indexing below needs no host-to-device index copy.
        rows, cols = torch.tril_indices(n_features, n_features, device=x.device)

        # Standardize over the requested dimensions.
        x = (x - x.mean(dim, keepdim=True)) / x.std(dim, keepdim=True)

        # Gather the "left" (row) and "right" (column) channel of every pair.
        x_l = x[..., rows]
        x_r = x[..., cols]

        per_lag = []
        for lag in range(max_lag):
            # Pair x_l at time t with x_r at time t - lag.
            if lag > 0:
                y = x_l[:, lag:] * x_r[:, :-lag]
            else:
                y = x_l * x_r
            # Average over the time dimension -> (batch, P).
            per_lag.append(torch.mean(y, 1))

        # (batch, max_lag * P) -> (batch, max_lag, P)
        cacf = torch.cat(per_lag, 1)
        return cacf.reshape(cacf.shape[0], -1, rows.numel())

    def __call__(self, x_real: torch.Tensor, x_fake: torch.Tensor) -> torch.Tensor:
        """
        Mean absolute difference of batch-averaged lag-0 cross-correlations.

        Parameters
        ----------
        x_real
            Reference batch of shape ``(batch, time, features)``.
        x_fake
            Generated batch with the same number of features; the real
            statistics are moved to its device before comparison.

        Returns
        -------
        torch.Tensor
            Scalar (0-dim) loss.
        """
        # Only lag 0 enters the loss (``[0]`` below), so compute just that lag
        # instead of ``self.max_lag`` lags and discarding all but the first.
        # The resulting values are identical.
        cross_correl_real = self.cacf(x_real, max_lag=1).mean(0)[0]
        cross_correl_fake = self.cacf(x_fake, max_lag=1).mean(0)[0]
        loss = torch.abs(cross_correl_fake - cross_correl_real.to(x_fake.device)).mean(0)
        return loss
