import copy
from abc import abstractmethod, ABC

from BACKEND import cp

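# Pairwise (pseudo)metrics on sample batches of shape (N, n_in, T), used for sampling in S-SWIM.
# Exception: PairwiseDot returns a cosine similarity rather than a distance.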

class SamplingMetric(ABC):
    """Base class for weighted pairwise (pseudo)metrics over a batch of samples.

    Subclasses implement ``compute_dist(A)``, returning an ``(N, N)`` matrix
    for a batch ``A`` whose first axis indexes the N samples. Metrics can be
    combined with ``+``. Calling the result returns the weighted sum
    ``sum(m.weight * m.compute_dist(A) for m in metrics)`` over every metric
    that was combined.
    """

    def __init__(self, weight=1.):
        # A freshly built metric is a "sum" that contains only itself.
        self.metrics = [self]
        self.weight = weight

    def __add__(self, other):
        """Return a new metric that evaluates to ``self(A) + other(A)``.

        Neither operand is modified. Every metric already combined into
        ``other`` is kept, so ``a + (b + c)`` includes ``c``.
        """
        if not isinstance(other, SamplingMetric):
            return NotImplemented
        # The shallow copy keeps self's attributes (weight, helper objects) but
        # gets its own metrics list, so the left operand is not mutated.
        combined = copy.copy(self)
        combined.metrics = self.metrics + other.metrics
        return combined

    def __call__(self, A):
        """Return the weighted sum of all combined metrics as an (N, N) float64 array."""
        N = A.shape[0]
        pairwise = cp.zeros((N, N), dtype=cp.float64)
        for metric in self.metrics:
            # Each metric's weight is applied here, not inside compute_dist.
            pairwise += metric.weight * metric.compute_dist(A)
        return pairwise

    @abstractmethod
    def __str__(self):
        pass

    @abstractmethod
    def compute_dist(self, A):
        """Return the unweighted (N, N) pairwise matrix for the batch ``A``."""
        pass

class FourierMag(SamplingMetric):
    """Distance between the magnitude spectra of two samples.

    Each sample ``A[i]`` (shape ``(n_in, T)``) is transformed with a real FFT
    along its last axis. Only the absolute values of the coefficients are
    compared, so phase is ignored.
    """

    def __init__(self, weight=1.):
        super().__init__(weight)

    def __str__(self):
        return "FourierMag"

    def compute_dist(self, A):
        """Return an (N, N) matrix of magnitude-spectrum distances.

        For a pair (i, j) the value is
        ``sqrt(sum_k m_k * sum_c (|F_kc(A_i)| - |F_kc(A_j)|)^2) / sqrt(T)``.
        Here ``F_kc`` is rFFT bin k of channel c (magnitudes cast to float32).
        ``m_k`` is how often bin k occurs in the full two-sided spectrum:
        1 for DC and, when T is even, the Nyquist bin; 2 otherwise.
        """
        N, n_in, T = A.shape
        acc = cp.zeros((N, N))

        # (n_freq, N, n_in): one (N, n_in) slab of magnitudes per frequency bin.
        spectra = cp.abs(cp.fft.rfft(A, axis=-1)).astype(cp.float32).transpose(2, 0, 1)
        n_freq = spectra.shape[0]

        # Every bin is mirrored by the conjugate half of the spectrum, except
        # DC and (for even T) the Nyquist bin.
        multiplicity = cp.full(n_freq, 2.0, dtype=cp.float32)
        multiplicity[0] = 1.0
        if T % 2 == 0:
            multiplicity[-1] = 1.0

        for k in range(n_freq):
            slab = spectra[k]
            delta = slab[:, None, :] - slab[None, :, :]  # (N, N, n_in)
            acc += multiplicity[k] * cp.square(delta).sum(axis=2)

        scale = 1 / cp.sqrt(T)
        return scale * cp.sqrt(acc)

class FourierAngle(SamplingMetric):
    """Distance between the phase spectra of two samples.

    Every rFFT coefficient ``z`` is reduced to its unit phasor ``z / |z|``
    (0 when ``z == 0``). Only phase information is compared; magnitudes are
    ignored.
    """

    def __str__(self):
        return "FourierAngle"

    def __init__(self, weight=1.):
        super().__init__(weight)

    def compute_dist(self, A):
        """Return an (N, N) matrix of phase-spectrum distances.

        For a pair (i, j) the value is
        ``sqrt(sum_k m_k * sum_c |u_kc(A_i) - u_kc(A_j)|^2) / sqrt(T)``.
        Here ``u_kc`` is the unit phasor of rFFT bin k in channel c.
        ``m_k`` is the bin's multiplicity in the full two-sided spectrum:
        1 for DC and, when T is even, the Nyquist bin; 2 otherwise.

        Note: only exactly-zero coefficients map to 0. Coefficients that are
        tiny but non-zero (e.g. FFT round-off) still get a unit phasor whose
        angle is essentially arbitrary.
        """
        N, n_in, T = A.shape
        distance_matrix = cp.zeros((N, N))
        AF_T = cp.fft.rfft(A, axis=-1).astype(cp.complex64).transpose(2, 0, 1)  # (fT, N, n_in)
        fT = AF_T.shape[0]

        mult_rfft = cp.ones(fT, dtype=cp.float32)  # DC bin appears once
        if T % 2 == 0:
            mult_rfft[1:fT - 1] = 2.0  # Nyquist appears once; the rest twice by symmetry
        else:
            mult_rfft[1:fT] = 2.0  # No Nyquist bin; everything but DC appears twice

        # Unit phasors. Where |z| == 0, z itself is 0, so dividing by 1 gives 0.
        # This avoids evaluating 0/0 (nan plus a RuntimeWarning) and masking afterwards.
        mag = cp.abs(AF_T)
        phasors = AF_T / cp.where(mag > 0, mag, 1)

        for t in range(fT):
            data_t = phasors[t]  # (N, n_in)

            # Pairwise differences via broadcasting: (N, N, n_in), complex.
            diff = data_t[:, cp.newaxis, :] - data_t[cp.newaxis, :, :]

            # Squared Euclidean norm over channels. The modulus must be squared,
            # not just taken, so the final sqrt yields an L2 distance (Parseval
            # scaling), matching FourierMag and PairwiseL2.
            norms = cp.square(cp.abs(diff)).sum(axis=2)  # (N, N)

            distance_matrix += mult_rfft[t] * norms

        return 1 / cp.sqrt(T) * cp.sqrt(distance_matrix)

class PairwiseL2(SamplingMetric):
    """Euclidean (Frobenius) distance between whole samples."""

    def __str__(self):
        return "L2"

    def __init__(self, weight=1.):
        super().__init__(weight)

    def compute_dist(self, A):
        """
        Compute pairwise distances between samples in a 3D array A of shape (N, n_in, T).
        Distance between samples i and j is defined as:
            sqrt( sum over t=0 to T-1 of ||A[i, :, t] - A[j, :, t]||_2^2 ),
        i.e. the Frobenius norm of A[i] - A[j] (not a sum of per-timestep norms).
        Complex input is supported: the modulus of each difference is squared.

        Parameters:
        A (cp.ndarray): Input array of shape (N, n_in, T), real or complex

        Returns:
        cp.ndarray: Real distance matrix of shape (N, N)
        """
        N, n_in, T = A.shape
        distance_matrix = cp.zeros((N, N))

        A_T = A.transpose(2, 0, 1)  # Shape: (T, N, n_in)

        # Loop over time steps so only an (N, N, n_in) difference tensor is
        # materialized at once, rather than (N, N, n_in, T).
        for t in range(T):
            data_t = A_T[t]  # Shape: (N, n_in)

            # Compute pairwise differences using broadcasting
            diff = data_t[:, cp.newaxis, :] - data_t[cp.newaxis, :, :]  # Shape: (N, N, n_in)

            # Squared Euclidean norms; abs first so complex differences give real values
            norms = cp.square(cp.abs(diff)).sum(axis=2)  # Shape: (N, N)

            # Accumulate squared distances
            distance_matrix += norms

        return cp.sqrt(distance_matrix)

class PairwiseDot(SamplingMetric):
    """Cosine similarity between flattened samples.

    Note that this is a similarity (larger means more alike), not a distance.
    """

    def __init__(self, weight=1.):
        super().__init__(weight)

    def __str__(self):
        return "Dot"

    def compute_dist(self, A):
        """
        :param A: (N, n_in, T)
        :return: (N, N) matrix of <A_i, A_j> / (||A_i|| * ||A_j|| + 1e-6),
                 with every sample flattened to a vector of length n_in * T
        """
        N, n_in, T = A.shape

        flat = A.reshape(N, n_in * T)
        lengths = cp.linalg.norm(flat, axis=1)
        # Small constant keeps all-zero samples from dividing by zero.
        denom = cp.outer(lengths, lengths) + 1e-6
        gram = flat @ flat.T

        return gram / denom

class BandedFourierL2(SamplingMetric):
    """L2 distance between samples after projecting onto a set of Fourier frequencies."""

    def __str__(self):
        # Label shows the band's extent, e.g. "BandedFourier[1, 5]".
        return f"BandedFourier[{self.band.min()}, {self.band.max()}]"

    def __init__(self, band: cp.ndarray, weight=1.):
        super().__init__(weight=weight)
        # 1-D array of frequencies in cycles per window of T samples (need not be integers).
        self.band = band
        self.l2 = PairwiseL2()

    def compute_dist(self, A):
        """Return an (N, N) matrix of band-limited spectral L2 distances.

        Each channel of every sample is projected onto the complex exponentials
        ``exp(-2*pi*i * f * t / T)`` for f in ``self.band``. PairwiseL2 is then
        applied to the resulting (N, n_in, B) complex coefficients, and the
        result is divided by sqrt(T).
        """
        N, n_in, T = A.shape
        ts = cp.arange(T) / T
        w = cp.exp(-2j * cp.pi * ts[:, None] * self.band[None, :]).astype(cp.complex64)  # (T, B)
        # (N, n_in, T) @ (T, B) -> (N, n_in, B)
        return self.l2(A @ w) / cp.sqrt(T)

class CosineDistance(SamplingMetric):
    """One minus the cosine similarity computed by PairwiseDot."""

    def __init__(self, weight=1.):
        super().__init__(weight=weight)
        # Evaluated with unit weight. This metric's own weight is applied to the
        # returned matrix by the caller (SamplingMetric.__call__).
        self.dot = PairwiseDot()

    def __str__(self):
        return "CosineDistance"

    def compute_dist(self, A):
        similarity = self.dot(A)
        return 1 - similarity


