import torch

from abc import ABC, abstractmethod
from src.wavelets import calculate_wavelet_coeffs

class WOT(ABC):
    """
    Abstract class for Wavelet Optimal Transport.

    Parameters
    ----------
    X1 : torch.tensor
        Source data matrix.
    X2 : torch.tensor
        Target data matrix.
    n_scales : int
        Number of scales to use for wavelet transform.
    w_op : str
        Wavelet operator to use. Options are "heat", "mexican_hat", "itersine", 
        "simple_tight", "half_cosine_kernel" or "meyer".
    T : torch.tensor
        Coupling matrix. If None, then the coupling matrix will be calculated.
    """
    def __init__(self, X1, X2, n_scales=20, w_op="heat", T=None, rbf_norm=True, dist="geodesic"):
        self.X1 = torch.tensor(X1, dtype=torch.float)
        self.X2 = torch.tensor(X2, dtype=torch.float)

        self.n_scales = n_scales

        # get wavelet coefficients matrices
        self.X1_wavelet_coeffs = calculate_wavelet_coeffs(self.X1, self.n_scales, w_op=w_op, rbf_norm=rbf_norm, dist=dist)
        self.X2_wavelet_coeffs = calculate_wavelet_coeffs(self.X2, self.n_scales, w_op=w_op, rbf_norm=rbf_norm, dist=dist)

        self.w_op = w_op
        self.T = T

    def solve(self, p=None, q=None):
        """
        Solves the optimal transport problem.

        Parameters
        ----------
        p : torch.tensor
            Source distribution.
        q : torch.tensor
            Target distribution.
        """
        if not p:
            self.p = torch.ones((self.X1_wavelet_coeffs[0].shape[1], ))
            self.p /= torch.numel(self.p)
        else:
            self.p = p

        if not q:
            self.q = torch.ones((self.X2_wavelet_coeffs[0].shape[1], ))
            self.q /= torch.numel(self.q)
        else:
            self.q = q

    def project(self, to_X2=True):
        """
        Barycentric projection of the source or target data matrix.

        Parameters
        ----------
        to_X2 : bool
            If True, project the source data matrix to the target data matrix. 
            Otherwise, project the target data matrix to the source data matrix.
        
        Returns
        ----------
        X_proj : torch.tensor
            Projected data matrix.
        """
        if to_X2:
            weights = torch.sum(self.T, axis = 1)
            X_proj = (self.T @ self.X2) / weights[:, None]
        else:
            weights = torch.sum(self.T, axis = 0)
            X_proj = (torch.t(self.T) @ self.X1) / weights[:, None]

        return X_proj