import numpy as np 
import torch
from torch.optim import Adam,SGD
import ot
from typing import Union

##################################################################################################################################################################################################################
##################################################################################################################################################################################################################
# Utility functions for barycentering methods.
##################################################################################################################################################################################################################
##################################################################################################################################################################################################################

def full_to_reduced_system(Ts:tuple) -> tuple:
    """Keep only the spectral components whose eigenvalue has a strictly positive imaginary part.

    Args:
        Ts (tuple): Spectral decomposition, either ``(Ds, Rs, Ls)`` (primal) or
            ``(Ds, Xs, prs, pls)`` (dual). ``Ds`` is masked along its only axis and
            the two eigenvector arrays along their second (column) axis; ``Xs`` is
            passed through untouched.

    Returns:
        tuple: A tuple with the same layout as ``Ts``, restricted to the kept components.

    Raises:
        ValueError: If ``Ts`` does not have 3 or 4 elements.
    """
    n_components = len(Ts)
    if n_components not in (3, 4):
        raise ValueError("The input tuple must have 3 or 4 elements.")

    eigenvalues = Ts[0]
    keep = eigenvalues.imag > 0
    # The last two entries are always the (right, left) eigenvector arrays.
    reduced_vectors = tuple(vectors[:, keep] for vectors in Ts[n_components - 2:])
    if n_components == 4:
        return (eigenvalues[keep], Ts[1]) + reduced_vectors
    return (eigenvalues[keep],) + reduced_vectors

def reduced_to_full_system(Ts:tuple) -> tuple:
    """Rebuild the full spectral decomposition from its positive-frequency half.

    The eigenvalues are extended with their reversed complex conjugates, and the
    two eigenvector arrays with their column-reversed complex conjugates.

    Args:
        Ts (tuple): Reduced decomposition, either ``(Ds, Rs, Ls)`` (primal) or
            ``(Ds, Xs, prs, pls)`` (dual). ``Xs`` is not modified.

    Returns:
        tuple: The full decomposition with the same layout as ``Ts``. When ``Ds`` is a
        torch tensor, every returned component (including ``Xs``) is a fresh leaf tensor
        created with ``Ds``'s dtype and device and the input's ``requires_grad`` flag;
        gradient history is not carried over.

    Raises:
        ValueError: If ``Ts`` does not have 3 or 4 elements.
    """
    if len(Ts) not in (3, 4):
        raise ValueError("The input tuple must have 3 or 4 elements.")
    dual = len(Ts) == 4

    use_torch = isinstance(Ts[0], torch.Tensor)
    if use_torch:
        reference = Ts[0]
        device = reference.device
        requires_grad = [component.requires_grad for component in Ts]
        dtype = reference.dtype
        # Negative-step slicing is not supported by torch, so mirror in numpy.
        parts = [component.detach().cpu().numpy() for component in Ts]
    else:
        parts = list(Ts)

    eigenvalues = parts[0]
    full_eigenvalues = np.hstack((eigenvalues, np.conj(eigenvalues[::-1])))
    # The last two entries are always the (right, left) eigenvector arrays.
    mirrored = [np.hstack((vectors, np.conj(vectors[:, ::-1]))) for vectors in parts[-2:]]

    if dual:
        result = (full_eigenvalues, parts[1], *mirrored)
    else:
        result = (full_eigenvalues, *mirrored)

    if use_torch:
        result = tuple(
            torch.tensor(array, device=device, dtype=dtype, requires_grad=flag)
            for array, flag in zip(result, requires_grad)
        )

    return result

def ot_score(C:torch.Tensor,P:torch.Tensor, p:int=2)->float:
    """Return the transport cost ``(sum_ij C_ij * P_ij) ** (1/p)``.

    Args:
        C (torch.Tensor): Cost matrix, shape (n, m).
        P (torch.Tensor): Transport plan, shape (n, m).
        p (int, optional): Root applied to the total cost. Defaults to 2.

    Returns:
        float: The score, as a 0-dim tensor (so it stays differentiable w.r.t. ``C``).
    """
    total_cost = (C * P).sum()
    return total_cost ** (1 / p)

def ot_plan(C:torch.Tensor,Ws:torch.Tensor=None,Wt:torch.Tensor=None,device:torch.device=torch.device("cpu"))->torch.Tensor:
    """Solve the exact (EMD) optimal transport problem for a given cost matrix.

    Args:
        C (torch.Tensor): Cost matrix, shape (n, m).
        Ws (torch.Tensor, optional): Source marginal, shape (n,). Defaults to uniform.
        Wt (torch.Tensor, optional): Target marginal, shape (m,). Defaults to uniform.
        device (torch.device, optional): Device the returned plan is moved to.
            Defaults to torch.device("cpu").

    Returns:
        torch.Tensor: Optimal transport plan from ``ot.emd``, shape (n, m), on ``device``.
    """
    # Default marginals are uniform and created on the CPU, next to the CPU copy of C.
    source_weights = Ws if Ws is not None else torch.ones(C.shape[0], device="cpu") / C.shape[0]
    target_weights = Wt if Wt is not None else torch.ones(C.shape[1], device="cpu") / C.shape[1]
    plan = ot.emd(source_weights, target_weights, C.to("cpu"))
    return plan.to(device)

def all_ot_plan(
    cost_function:callable,
    Ts: list,
    Tt_lst:list,
    Ts_weights:torch.Tensor=None,
    Tt_weights:list=None,
    device:torch.device=torch.device("cpu"))->list:
    """Compute one optimal transport plan per target spectral decomposition.

    Args:
        cost_function (callable): Called as ``cost_function(*Ts, *Tt)``; must return the
            cost matrix between the source and one target decomposition.
        Ts (list): Source decomposition (Ds, Rs, Ls) or (Ds, Xs, prs, pls).
        Tt_lst (list): Target decompositions, each (Dt, Rt, Lt) or (Dt, Xt, prt, plt).
        Ts_weights (torch.Tensor, optional): Source marginal; ``None`` lets ``ot_plan``
            use a uniform one.
        Tt_weights (list, optional): One target marginal per target; ``None`` entries
            (or ``None`` overall) mean uniform.
        device (torch.device, optional): Device for the returned plans.

    Returns:
        list: The transport plans, in the order of ``Tt_lst``.
    """
    target_weights = Tt_weights if Tt_weights is not None else [None] * len(Tt_lst)
    return [
        ot_plan(cost_function(*Ts, *Tt), Ts_weights, weight, device)
        for Tt, weight in zip(Tt_lst, target_weights)
    ]

def eigenvalue_cost_matrix(Ds:torch.Tensor,Dt:torch.Tensor,real_scale:float=1.0,imag_scale:float=1.0)->torch.Tensor:
    """Compute pairwise distances between source and target eigenvalues.

    The real and imaginary parts are rescaled independently before taking the modulus
    of the difference: ``C[i, j] = |s(Ds[i]) - s(Dt[j])|`` with
    ``s(z) = real_scale * Re(z) + 1j * imag_scale * Im(z)``.

    Args:
        Ds (torch.Tensor): Source eigenvalues, shape (Rs,). Complex or real dtype.
        Dt (torch.Tensor): Target eigenvalues, shape (Rt,). Complex or real dtype.
        real_scale (float, optional): Factor applied to real parts. Defaults to 1.0.
        imag_scale (float, optional): Factor applied to imaginary parts. Defaults to 1.0.

    Returns:
        torch.Tensor: Real cost matrix of shape (Rs, Rt).
    """
    def _rescale(D: torch.Tensor) -> torch.Tensor:
        # torch raises on ``.imag`` for real dtypes; a real tensor has zero imaginary part.
        if not torch.is_complex(D):
            return D * real_scale
        return D.real * real_scale + 1j * D.imag * imag_scale

    Dsn = _rescale(Ds)
    Dtn = _rescale(Dt)
    C = torch.abs(Dsn[:, None] - Dtn[None, :])
    return C

##################################################################################################################################################################################################################
##################################################################################################################################################################################################################
#  Utility functions in the primal setting.
##################################################################################################################################################################################################################
######################################################################################################################################################################################################################

def unitary_grassman_matrix(Ps:torch.Tensor,Pt:torch.Tensor)->torch.Tensor:
    """Return inner products between the unit-normalized columns of ``Ps`` and ``Pt``.

    Args:
        Ps (torch.Tensor): Source vectors stored as columns, shape (l, n_ds), where l is
            the ambient space dimension.
        Pt (torch.Tensor): Target vectors stored as columns, shape (l, n_dt).

    Returns:
        torch.Tensor: Matrix of shape (n_ds, n_dt) whose entry (i, j) is
        ``<Ps[:, i] / |Ps[:, i]|, Pt[:, j] / |Pt[:, j]|>`` (conjugate-linear in the source).
    """
    source_unit = Ps / Ps.norm(dim=0, keepdim=True)
    target_unit = Pt / Pt.norm(dim=0, keepdim=True)
    return torch.einsum('dl,lr->dr', source_unit.conj().T, target_unit)

def eigenvector_chordal_cost_matrix(Rs:torch.Tensor,Ls:torch.Tensor,Rt:torch.Tensor,Lt:torch.Tensor)->torch.Tensor:
    """Chordal cost between source and target (right, left) eigenvector pairs.

    ``C[i, j] = sqrt(1 - clamp(Re(cr[i, j] * cl[i, j]), 0, 1))``, where ``cr`` and ``cl``
    are the normalized inner products of the right and left eigenvectors respectively.

    Args:
        Rs (torch.Tensor): Source right eigenvectors, shape (L, Rs), L: ambient space dimension.
        Ls (torch.Tensor): Source left eigenvectors, shape (L, Rs).
        Rt (torch.Tensor): Target right eigenvectors, shape (L, Rt).
        Lt (torch.Tensor): Target left eigenvectors, shape (L, Rt).

    Returns:
        torch.Tensor: Eigenvector chordal cost matrix, shape (Rs, Rt).
    """
    right_overlap = unitary_grassman_matrix(Rs, Rt)
    left_overlap = unitary_grassman_matrix(Ls, Lt)
    similarity = (right_overlap * left_overlap).real.clamp(min=0, max=1)
    return (1 - similarity).sqrt()

def ChordalCostFunction(
    real_scale:float=1.0,
    imag_scale:float=1.0,
    alpha:float=0.5,
    p:int = 2):
    """Build the primal chordal cost function.

    The returned callable computes
    ``(alpha * eigenvalue_cost + (1 - alpha) * eigenvector_chordal_cost) ** p``.

    Args:
        real_scale (float): Scale applied to the real part of eigenvalues.
        imag_scale (float): Scale applied to the imaginary part of eigenvalues.
        alpha (float): Weight of the eigenvalue cost; ``1 - alpha`` weighs the eigenvector cost.
        p (int): Exponent applied elementwise to the combined cost.

    Returns:
        callable: ``cost_function(Ds, Rs, Ls, Dt, Rt, Lt) -> torch.Tensor``.
    """
    def cost_function(
            Ds:torch.Tensor,
            Rs:torch.Tensor,
            Ls:torch.Tensor,
            Dt:torch.Tensor,
            Rt:torch.Tensor,
            Lt:torch.Tensor) -> torch.Tensor:
        """Return the chordal cost matrix between a source and a target decomposition.

        Args:
            Ds, Rs, Ls (torch.Tensor): Source eigenvalues, right and left eigenvectors.
            Dt, Rt, Lt (torch.Tensor): Target eigenvalues, right and left eigenvectors.

        Returns:
            torch.Tensor: Cost matrix of shape (len(Ds), len(Dt)).
        """
        value_cost = eigenvalue_cost_matrix(Ds, Dt, real_scale=real_scale, imag_scale=imag_scale)
        vector_cost = eigenvector_chordal_cost_matrix(Rs, Ls, Rt, Lt)
        mixed = alpha * value_cost + (1 - alpha) * vector_cost
        return mixed ** p

    return cost_function

##################################################################################################################################################################################################################
##################################################################################################################################################################################################################
#  Utility functions in the dual setting.
##################################################################################################################################################################################################################
######################################################################################################################################################################################################################

def RBF(sigma:float=1.0)->callable:
    """Create an RBF kernel ``k(x, y) = exp(-||x - y||^2 / sigma^2)``.

    Args:
        sigma (float): Bandwidth of the kernel.

    Returns:
        callable: Function mapping point sets X (n, d) and Y (m, d) to their kernel matrix (n, m).
    """
    def rbf_kernel(X:torch.Tensor,Y:torch.Tensor) -> torch.Tensor:
        """Evaluate the kernel between every row of ``X`` and every row of ``Y``.

        Args:
            X (torch.Tensor): First point set, shape (n, d).
            Y (torch.Tensor): Second point set, shape (m, d).

        Returns:
            torch.Tensor: Kernel matrix, shape (n, m).
        """
        differences = X.unsqueeze(1) - Y.unsqueeze(0)
        squared_distances = differences.pow(2).sum(dim=-1)
        return torch.exp(-squared_distances / sigma**2)
    return rbf_kernel

def kernel_unitary_grassman_matrix(
        Xs:torch.Tensor,
        ps:torch.Tensor,
        Xt:torch.Tensor, 
        pt: torch.Tensor,
        K:callable) -> torch.Tensor: 
    """Normalized kernel inner products between source and target coefficient columns.

    Each column of ``ps`` (resp. ``pt``) holds coefficients on the points ``Xs`` (resp. ``Xt``).
    Entry (i, j) of the result is ``ps[:, i]^H K(Xs, Xt) pt[:, j]`` divided by the K-norms
    ``sqrt(ps[:, i]^H K(Xs, Xs) ps[:, i])`` and ``sqrt(pt[:, j]^H K(Xt, Xt) pt[:, j])``.

    Args:
        Xs (torch.Tensor): Source points, passed to ``K``.
        ps (torch.Tensor): Source coefficients; rows must match ``K(Xs, Xs)``.
        Xt (torch.Tensor): Target points, passed to ``K``.
        pt (torch.Tensor): Target coefficients; rows must match ``K(Xt, Xt)``.
        K (callable): Kernel returning the Gram matrix between two point sets.

    Returns:
        torch.Tensor: Matrix of shape (ps.shape[1], pt.shape[1]).
    """
    gram_source = K(Xs, Xs)
    gram_target = K(Xt, Xt)
    gram_cross = K(Xs, Xt)

    source_norms = torch.sqrt(torch.sum(ps.conj() * (gram_source @ ps), dim=0, keepdim=True).real)
    target_norms = torch.sqrt(torch.sum(pt.conj() * (gram_target @ pt), dim=0, keepdim=True).real)
    inner = ps.conj().T @ gram_cross @ pt
    return inner / (source_norms.T * target_norms)

def _kernel_unitary_grassman_matrix(
        ps:torch.Tensor, 
        pt: torch.Tensor,
        Ks:torch.Tensor,
        kt:torch.Tensor,
        Kst:torch.Tensor) -> torch.Tensor: 
    
    pKs = torch.sqrt(torch.sum(ps.conj() * (Ks @ ps),dim=0,keepdim=True).real)
    pKt = torch.sqrt(torch.sum(pt.conj() * (kt @ pt),dim=0,keepdim=True).real)
    pKst = ps.conj().T @ Kst @ pt
    return pKst/ (pKs.T*pKt)

def kernel_eigenvector_chordal_cost_matrix(
        Xs:torch.Tensor,
        prs:torch.Tensor,
        pls:torch.Tensor, 
        Xt: torch.Tensor,
        prt:torch.Tensor, 
        plt:torch.Tensor,
        K:callable):
    """Chordal cost between source and target eigenvector pairs in the kernel (dual) setting.

    ``C[i, j] = sqrt(1 - clamp(Re(cr[i, j] * cl[i, j]), 0, 1))``, where ``cr``/``cl`` are the
    K-normalized inner products of the right (``prs``/``prt``) and left (``pls``/``plt``)
    coefficient columns.

    Args:
        Xs (torch.Tensor): Source control points.
        prs (torch.Tensor): Source right-eigenvector coefficients on ``Xs`` (columns).
        pls (torch.Tensor): Source left-eigenvector coefficients on ``Xs`` (columns).
        Xt (torch.Tensor): Target control points.
        prt (torch.Tensor): Target right-eigenvector coefficients on ``Xt`` (columns).
        plt (torch.Tensor): Target left-eigenvector coefficients on ``Xt`` (columns).
        K (callable): Kernel returning the Gram matrix between two point sets.

    Returns:
        torch.Tensor: Cost matrix of shape (prs.shape[1], prt.shape[1]).
    """
    # Gram matrices are computed once and shared by the right and left overlaps.
    Ks = K(Xs,Xs)
    Kt = K(Xt,Xt)
    Kst = K(Xs,Xt)

    Cr = _kernel_unitary_grassman_matrix(prs,prt,Ks,Kt,Kst)
    Cl = _kernel_unitary_grassman_matrix(pls,plt,Ks,Kt,Kst)
    # Clamp to [0, 1] like the primal eigenvector_chordal_cost_matrix: |Cr|, |Cl| <= 1 holds
    # only up to rounding (and only for a PSD kernel), and a value slightly above 1 would
    # make the sqrt argument negative and the cost NaN.
    C = torch.sqrt(1-torch.clamp((Cr*Cl).real,min=0,max=1))
    return C

def eigenvalue_cost_matrix(Ds:torch.Tensor,Dt:torch.Tensor,real_scale:float=1.0,imag_scale:float=1.0)->torch.Tensor:
    """Compute pairwise distances between source and target eigenvalues.

    The real and imaginary parts are rescaled independently before taking the modulus
    of the difference: ``C[i, j] = |s(Ds[i]) - s(Dt[j])|`` with
    ``s(z) = real_scale * Re(z) + 1j * imag_scale * Im(z)``.

    Args:
        Ds (torch.Tensor): Source eigenvalues, shape (Rs,). Complex or real dtype.
        Dt (torch.Tensor): Target eigenvalues, shape (Rt,). Complex or real dtype.
        real_scale (float, optional): Factor applied to real parts. Defaults to 1.0.
        imag_scale (float, optional): Factor applied to imaginary parts. Defaults to 1.0.

    Returns:
        torch.Tensor: Real cost matrix of shape (Rs, Rt).
    """
    def _rescale(D: torch.Tensor) -> torch.Tensor:
        # torch raises on ``.imag`` for real dtypes; a real tensor has zero imaginary part.
        if not torch.is_complex(D):
            return D * real_scale
        return D.real * real_scale + 1j * D.imag * imag_scale

    Dsn = _rescale(Ds)
    Dtn = _rescale(Dt)
    C = torch.abs(Dsn[:, None] - Dtn[None, :])
    return C

def kernel_chordal_cost_matrix(
    Ds:torch.Tensor,
    Xs:torch.Tensor,
    prs:torch.Tensor, 
    pls:torch.Tensor,
    Dt:torch.Tensor,
    Xt:torch.Tensor,
    prt: torch.Tensor,
    plt:torch.Tensor,
    K:callable,
    real_scale:float=1.0,
    imag_scale:float=1.0,
    alpha:float=0.5,
    p:int = 2):
    """Dual (kernel) chordal cost matrix between two spectral decompositions.

    Returns ``(alpha * eigenvalue_cost + (1 - alpha) * kernel_eigenvector_cost) ** p``.

    Args:
        Ds, Xs, prs, pls (torch.Tensor): Source eigenvalues, control points and
            right/left coefficient matrices.
        Dt, Xt, prt, plt (torch.Tensor): Target counterparts.
        K (callable): Kernel returning the Gram matrix between two point sets.
        real_scale (float): Scale applied to the real part of eigenvalues.
        imag_scale (float): Scale applied to the imaginary part of eigenvalues.
        alpha (float): Weight of the eigenvalue cost.
        p (int): Exponent applied elementwise to the combined cost.

    Returns:
        torch.Tensor: Cost matrix of shape (len(Ds), len(Dt)).
    """
    value_cost = eigenvalue_cost_matrix(Ds, Dt, real_scale=real_scale, imag_scale=imag_scale)
    vector_cost = kernel_eigenvector_chordal_cost_matrix(Xs, prs, pls, Xt, prt, plt, K)
    return (alpha * value_cost + (1 - alpha) * vector_cost) ** p

def KernelChordalCostFunction(
    K:callable,
    real_scale:float=1.0,
    imag_scale:float=1.0,
    alpha:float=0.5,
    p:int = 2):
    """Build the dual (kernel) chordal cost function with the given kernel and weights.

    Args:
        K (callable): Kernel returning the Gram matrix between two point sets.
        real_scale (float): Scale applied to the real part of eigenvalues.
        imag_scale (float): Scale applied to the imaginary part of eigenvalues.
        alpha (float): Weight of the eigenvalue cost; ``1 - alpha`` weighs the eigenvector cost.
        p (int): Exponent applied elementwise to the combined cost.

    Returns:
        callable: ``cost_function(Ds, Xs, prs, pls, Dt, Xt, prt, plt) -> torch.Tensor``.
    """
    def cost_function(
        Ds:torch.Tensor,
        Xs:torch.Tensor,
        prs:torch.Tensor, 
        pls:torch.Tensor,
        Dt:torch.Tensor,
        Xt:torch.Tensor,
        prt: torch.Tensor,
        plt:torch.Tensor):
        """Return the kernel chordal cost matrix between a source and a target decomposition."""
        value_cost = eigenvalue_cost_matrix(Ds, Dt, real_scale=real_scale, imag_scale=imag_scale)
        vector_cost = kernel_eigenvector_chordal_cost_matrix(Xs, prs, pls, Xt, prt, plt, K)
        mixed = alpha * value_cost + (1 - alpha) * vector_cost
        return mixed ** p

    return cost_function

##################################################################################################################################################################################################################
##################################################################################################################################################################################################################
#  Core barycentering methods.
##################################################################################################################################################################################################################
######################################################################################################################################################################################################################


def SpectralOTBarycenterLoss(
        cost_function:callable,
        Tt_lst:list,
        ponderations:torch.Tensor=None,
        device:torch.device=torch.device("cpu"),
        p:int=1)-> callable:
    """Generate the loss function for spectral OT barycenter computation.

    The loss is ``sum_i w_i * ot_score(cost_function(*Ts, *Tt_i), P_i, p)``, with the
    transport plans ``P_i`` supplied by the caller (held fixed during the gradient step).

    Args:
        cost_function (callable): Function to compute the cost matrix between spectral decompositions.
        Tt_lst (list): List of tuples, each containing a target spectral decomposition (Dt, Rt, Lt) or (Dt, Xt, prt, plt).
        ponderations (torch.Tensor, optional): Ponderations associated to target spectral decompositions.
            Any real dtype (or array-like) is accepted. Defaults to None, set to uniform distribution.
        device (torch.device, optional): Device to perform the computation on. Defaults to torch.device("cpu").
        p (int, optional): Power for the OT score. Defaults to 1.

    Returns:
        callable: ``loss_function(Ts, P_lst)`` returning the weighted OT score as a 0-dim tensor.

    Raises:
        AssertionError: If ``ponderations`` has the wrong length, negative entries, or does not sum to 1.
    """
    # Number of target spectral decompositions
    N = len(Tt_lst)
    # Ensure that the target spectral decompositions are on the correct device
    Tt_lst = [tuple([x.to(device) for x in tpl]) for tpl in Tt_lst]

    if ponderations is None:
        # If no weights are provided, use uniform weights
        ponderations = torch.ones(N, device=device) / N
    else:
        # Move to the target device *before* validating so all comparisons happen on one device.
        ponderations = torch.as_tensor(ponderations, device=device)
        assert ponderations.shape[0] == N, "Weights must have the same length as the number of target spectral decompositions."
        assert torch.all(ponderations >= 0), "Weights must be non-negative."
        total = torch.sum(ponderations)
        # torch.isclose requires both operands to share a dtype; comparing against a
        # float32 literal made e.g. float64 weights raise instead of validating.
        assert torch.isclose(total, torch.ones_like(total)), "Weights must sum to 1."

    def loss_function(
            Ts:list,
            P_lst: list)->torch.Tensor:
        # Stack the per-target terms instead of writing into a float32 ``torch.empty``
        # buffer: keeps the cost's precision and never sums uninitialized slots.
        terms = [
            w * ot_score(cost_function(*Ts, *Tt), P, p=p)
            for Tt, P, w in zip(Tt_lst, P_lst, ponderations)
        ]
        if not terms:
            return torch.zeros((), device=device)
        return torch.sum(torch.stack(terms))

    return loss_function

class SpectralOTBarycenter:
    """Spectral optimal-transport barycenter solved by block coordinate descent.

    Each epoch (i) optimizes every trainable component of the barycenter decomposition
    in turn, with the OT plans held fixed, then (ii) recomputes the OT plans between the
    barycenter and each target with the components held fixed.
    """

    def __init__(self,
        cost_function: callable,
        K: callable = None,
        dual: bool = False,
        optimize_control_points: bool = True,
        optimizer: torch.optim.Optimizer = Adam,
        lr: float = 1e-3,
        max_epochs: int = 1000,
        max_iter: int = 1000,
        tol: float = 1e-6,
        scheduler: torch.optim.lr_scheduler = None,
        device: torch.device = torch.device("cpu"),
        verbose: int = 1):
        """Store the solver configuration.

        Args:
            cost_function (callable): Cost between the barycenter and one target decomposition.
            K (callable, optional): Kernel used by the dual constraints; required when ``dual``.
            dual (bool): Use the (Ds, Xs, prs, pls) parameterization instead of (Ds, Rs, Ls).
            optimize_control_points (bool): In ``fit``, controls ``requires_grad`` of ``Xs``
                (dual) or of ``Ds`` (primal).
            optimizer (type): Optimizer class, instantiated once per trainable component.
            lr (float): Learning rate for every optimizer.
            max_epochs (int): Maximum number of outer (plan-update) epochs.
            max_iter (int): Maximum number of gradient steps per component per epoch.
            tol (float): Threshold on the gradient norm (inner) and loss change (outer).
            scheduler (callable, optional): Called with an optimizer to build its LR scheduler.
            device (torch.device): Device all tensors are moved to.
            verbose (int): 0 silent, 1 per-epoch logging, 2 per-iteration logging.
        """
        # check duality requirements
        if dual:
            assert K is not None, "Kernel function K must be provided for dual formulation and identical to the kernel used to define the cost function."

        self.cost_function = cost_function
        self.K = K
        self.dual = dual
        self.optimize_control_points = optimize_control_points
        self.optimizer = optimizer
        self.lr = lr
        self.max_epochs = max_epochs
        self.max_iter = max_iter
        self.tol = tol
        self.scheduler = scheduler
        self.device = device
        self.verbose = verbose

    def _single_coordinate_descent(self, Ts, P_lst, param, optimizer, scheduler, constraint, label, epoch):
        """Run up to ``max_iter`` optimizer steps on ``param`` alone, with the OT plans fixed.

        Stops early once the gradient norm drops below ``tol``. ``constraint``, if given,
        is applied in place to ``param`` after every step.
        """
        if self.verbose >= 2:
            print(f"Epoch {epoch}/{self.max_epochs}, optimizing {label}.")

        for it in range(self.max_iter):
            optimizer.zero_grad()
            # Compute the loss
            loss = self.b_loss(Ts, P_lst)
            # Differentiate w.r.t. this block only; the other components stay fixed.
            grad_param = torch.autograd.grad(loss, param, create_graph=False, retain_graph=False)[0]
            param.grad = grad_param
            optimizer.step()
            if scheduler is not None:
                scheduler.step()
            # Project the updated parameter back onto its constraint set.
            if constraint is not None:
                with torch.no_grad():
                    param.copy_(constraint(param))
            iter_error = torch.norm(grad_param, p=2)

            # Log every 10 iterations and on the last one (``range`` never reaches max_iter).
            if self.verbose >= 2 and (it % 10 == 0 or it == self.max_iter - 1):
                print(f"Iteration {it+1}/{self.max_iter}, {label} loss: {loss.item()}, gradient norm: {iter_error.item()}.")

            if iter_error < self.tol:
                if self.verbose >= 2:
                    print(f"Convergence reached for {label} at iteration {it+1} with error {iter_error.item()}.")
                break

    def _set_constraint_function(self, Ts, label):
        """Return the projection applied to component ``label`` after each step, or None.

        - "prs": rescale columns to unit K-norm, ``x / sqrt(diag(x^H K x))``.
        - "Rs": rescale columns to unit Euclidean norm.
        - "pls": enforce ``x^H K prs = I`` (bi-orthogonality with the current ``prs``).
        - "Ls": enforce ``x^H Rs = I`` (bi-orthogonality with the current ``Rs``).
        """
        constraint = None
        if label == "prs":
            _, Xs, _, _ = Ts
            K_ = self.K(Xs,Xs)
            constraint = lambda x: x/ torch.sqrt(torch.sum(x.conj() * (K_ @ x),dim=0, keepdim=True).real)
        elif label == "Rs":
            constraint = lambda x: x/x.norm(dim=0, keepdim=True)
        elif label == "pls":
            _, Xs, prs, _ = Ts
            K_ = self.K(Xs,Xs)
            A = prs.conj().T @ K_ @ prs
            def constraint(x):
                # Subtract prs @ (B A^{-1})^H so that the result satisfies x'^H K prs = I.
                B = x.conj().T @ K_ @ prs - torch.eye(prs.shape[1], device=self.device)
                r_part = torch.linalg.solve(A, B, left=False)
                proj_x =  x - prs @ r_part.conj().T
                return proj_x
        elif label == "Ls":
            _, Rs, _ = Ts
            A = Rs.conj().T @ Rs
            def constraint(x):
                # Subtract Rs @ (B A^{-1})^H so that the result satisfies x'^H Rs = I.
                B = x.conj().T @ Rs - torch.eye(Rs.shape[1], device=self.device)
                r_part = torch.linalg.solve(A, B, left=False)
                proj_x =  x - Rs @ r_part.conj().T
                return proj_x
        return constraint

    def fit(self,
        Tt_lst: list,
        init: tuple,
        Tt_weights: list = None,
        init_weights: torch.Tensor = None,
        ponderations: torch.Tensor = None):
        """Fit the barycenter of the target spectral decompositions.

        Args:
            Tt_lst (list): Target decompositions, (Dt, Rt, Lt) or (Dt, Xt, prt, plt) each.
            init (tuple): Initial barycenter, (Ds, Rs, Ls) or (Ds, Xs, prs, pls) when dual.
                Tensors already on ``self.device`` share storage with the optimized
                parameters, so they are updated in place.
            Tt_weights (list, optional): Per-target OT marginals (None means uniform).
            init_weights (torch.Tensor, optional): Barycenter OT marginal (None means uniform).
            ponderations (torch.Tensor, optional): Barycenter weights of the targets; uniform by default.

        Returns:
            tuple: ``(Ts, P_lst, epoch_loss)``; ``epoch_loss[0]`` is the initial loss and
            ``epoch_loss[k]`` the loss after epoch k.
        """
        def _as_param(x, trainable):
            # Detach and move *before* enabling gradients so every component is a leaf
            # tensor on self.device; enabling grad first and then moving produced
            # non-leaf tensors, which torch optimizers refuse to optimize.
            return x.detach().to(self.device).requires_grad_(trainable)

        # Initialize the barycenter (source domain)
        if self.dual:
            Ds, Xs, prs, pls = init
            Ts = (
                _as_param(Ds, True),
                _as_param(Xs, self.optimize_control_points),
                _as_param(prs, True),
                _as_param(pls, True),
            )
            labels = ["Ds", "Xs", "prs", "pls"]
        else:
            Ds, Rs, Ls = init
            Ts = (
                _as_param(Ds, self.optimize_control_points),
                _as_param(Rs, True),
                _as_param(Ls, True),
            )
            labels = ["Ds", "Rs", "Ls"]

        Ts_weights = init_weights.to(self.device) if init_weights is not None else None

        # Initialize targets
        Tt_lst = [tuple([x.to(self.device) for x in tpl]) for tpl in Tt_lst]
        if Tt_weights is None:
            Tt_weights = [None]*len(Tt_lst)

        # Initialize OT plans (treated as constants, as in the per-epoch update)
        with torch.no_grad():
            P_lst = all_ot_plan(self.cost_function, Ts, Tt_lst, Ts_weights, Tt_weights, device=self.device)

        # Initialize loss function
        if ponderations is None:
            ponderations = torch.ones(len(Tt_lst), device=self.device) / len(Tt_lst)
        self.b_loss = SpectralOTBarycenterLoss(self.cost_function,Tt_lst,ponderations,device=self.device)

        # One optimizer (and optional scheduler) per trainable component
        optimizers = [None]*len(Ts)
        schedulers = [None]*len(Ts)
        for i, param in enumerate(Ts):
            if param.requires_grad:
                optimizers[i] = self.optimizer(params=[param], lr=self.lr)
                if self.scheduler is not None:
                    schedulers[i] = self.scheduler(optimizers[i])

        # epoch_loss[0] is the initial loss, epoch_loss[k] the loss after epoch k.
        epoch_loss = torch.empty(self.max_epochs+1, device=self.device)
        with torch.no_grad():
            final_loss = self.b_loss(Ts, P_lst).item()
        epoch_loss[0] = final_loss
        if self.verbose >= 1:
            print(f"Starting training with {len(Tt_lst)} target spectral decompositions.")
            print(f"Initial loss: {epoch_loss[0]}")

        n_done = 0  # number of completed epochs
        for epoch in range(self.max_epochs):

            # Optimize each parameter in Ts by block coordinate descent
            for param, optimizer, scheduler, label in zip(Ts, optimizers, schedulers, labels):
                if param.requires_grad:
                    constraint = self._set_constraint_function(Ts, label)
                    self._single_coordinate_descent(Ts, P_lst, param, optimizer, scheduler, constraint, label, epoch)

            # Update the OT plans, then evaluate the loss with the new plans
            with torch.no_grad():
                P_lst = all_ot_plan(self.cost_function, Ts, Tt_lst, Ts_weights, Tt_weights, device=self.device)
                loss = self.b_loss(Ts, P_lst)

            n_done = epoch + 1
            final_loss = loss.item()
            epoch_loss[n_done] = final_loss
            epoch_error = torch.abs(epoch_loss[n_done] - epoch_loss[n_done - 1])

            if self.verbose >= 1:
                print(f"Epoch {epoch+1}/{self.max_epochs}, loss: {loss.item()}, error: {epoch_error.item()}.")
            if epoch_error < self.tol:
                if self.verbose >= 1:
                    print(f"Convergence reached at epoch {epoch+1}/{self.max_epochs} with error {epoch_error.item()}.")
                break

        if self.verbose >= 1:
            print(f"Training completed after {n_done}/{self.max_epochs} epochs with final loss: {final_loss}.")

        self.Ts = Ts
        self.P_lst = P_lst
        self.epoch_loss = epoch_loss[:n_done+1]

        return Ts, P_lst, self.epoch_loss
    
##################################################################################################################################################################################################################
##################################################################################################################################################################################################################
#  Barycenter with Hilbert Schmidt cost.
##################################################################################################################################################################################################################
######################################################################################################################################################################################################################

def HilbertSchmidtCostFunction(
        sampfreq:int,
        step:int=1
    ):
    """Build the Hilbert-Schmidt cost between two spectral decompositions.

    Each decomposition (D, R, L) is turned into the operator
    ``R @ diag(exp(D * step / sampfreq)) @ L^H``, and the cost is the squared
    Frobenius norm of the difference of the two operators.

    Args:
        sampfreq (int): The sampling frequency.
        step (int, optional): The number of time steps. Defaults to 1.

    Returns:
        callable: ``cost_function(Ds, Rs, Ls, Dt, Rt, Lt)`` returning a real scalar tensor.
    """
    def cost_function(
            Ds:torch.Tensor,
            Rs:torch.Tensor,
            Ls:torch.Tensor,
            Dt:torch.Tensor,
            Rt:torch.Tensor,
            Lt:torch.Tensor) -> torch.Tensor:
        """Squared Hilbert-Schmidt (Frobenius) distance between the two induced operators.

        Args:
            Ds, Rs, Ls (torch.Tensor): Source eigenvalues, right and left eigenvectors.
            Dt, Rt, Lt (torch.Tensor): Target eigenvalues, right and left eigenvectors.

        Returns:
            torch.Tensor: Real 0-dim tensor.
        """
        dt = step / sampfreq

        def _operator(D, R, L):
            # R @ diag(exp(D * dt)) @ L^H without materializing the diagonal matrix.
            return R @ (torch.exp(D * dt).view(-1, 1) * L.conj().T)

        difference = _operator(Dt, Rt, Lt) - _operator(Ds, Rs, Ls)
        return (difference.conj() * difference).sum().real
    return cost_function

def StandardBarycenterLoss(
        cost_function:callable,
        Tt_lst:list,
        ponderations:torch.Tensor=None,
        device:torch.device=torch.device("cpu"))-> callable:
    """Generate the weighted loss for a standard (non-OT) barycenter.

    The loss is ``sum_i w_i * cost_function(*Ts, *Tt_i)``.

    Args:
        cost_function (callable): Scalar cost between two spectral decompositions.
        Tt_lst (list): List of tuples, each containing a target spectral decomposition (Dt, Rt, Lt) or (Dt, Xt, prt, plt).
        ponderations (torch.Tensor, optional): Ponderations associated to target spectral decompositions.
            Any real dtype (or array-like) is accepted. Defaults to None, set to uniform distribution.
        device (torch.device, optional): Device to perform the computation on. Defaults to torch.device("cpu").

    Returns:
        callable: ``loss_function(Ts)`` returning the weighted cost as a 0-dim tensor.

    Raises:
        AssertionError: If ``ponderations`` has the wrong length, negative entries, or does not sum to 1.
    """
    # Number of target spectral decompositions
    N = len(Tt_lst)
    # Ensure that the target spectral decompositions are on the correct device
    Tt_lst = [tuple([x.to(device) for x in tpl]) for tpl in Tt_lst]

    if ponderations is None:
        # If no weights are provided, use uniform weights
        ponderations = torch.ones(N, device=device) / N
    else:
        # Move to the target device *before* validating so all comparisons happen on one device.
        ponderations = torch.as_tensor(ponderations, device=device)
        assert ponderations.shape[0] == N, "Weights must have the same length as the number of target spectral decompositions."
        assert torch.all(ponderations >= 0), "Weights must be non-negative."
        total = torch.sum(ponderations)
        # torch.isclose requires both operands to share a dtype; comparing against a
        # float32 literal made e.g. float64 weights raise instead of validating.
        assert torch.isclose(total, torch.ones_like(total)), "Weights must sum to 1."

    def loss_function(Ts:list)->torch.Tensor:
        # Stack the per-target terms instead of writing into a float32 ``torch.empty``
        # buffer, so the cost's precision is kept.
        terms = [w * cost_function(*Ts, *Tt) for Tt, w in zip(Tt_lst, ponderations)]
        if not terms:
            return torch.zeros((), device=device)
        return torch.sum(torch.stack(terms))

    return loss_function

class StandardBarycenter:
    """Spectral OT barycenter solver based on block coordinate descent.

    The barycenter is a spectral decomposition ``Ts``: ``(Ds, Rs, Ls)`` in the primal
    formulation, ``(Ds, Xs, prs, pls)`` in the dual (kernel) formulation. ``fit`` minimizes
    the weighted sum of ``cost_function(*Ts, *Tt)`` over the target decompositions ``Tt``.
    It optimizes one component at a time, each with its own optimizer and optional scheduler,
    while the other components are held fixed. After every step on ``Rs``/``Ls``/``prs``/``pls``,
    the updated component is projected back onto its constraint set
    (see ``_set_constraint_function``).

    Args:
        cost_function (callable): Called as ``cost_function(*Ts, *Tt)``; expected to return a
            differentiable scalar tensor.
        K (callable, optional): Kernel ``K(X, Y)`` returning a Gram matrix. Required when
            ``dual`` is True, and expected to match the kernel used inside ``cost_function``.
        dual (bool): Use the dual formulation ``(Ds, Xs, prs, pls)``. Defaults to False.
        optimize_control_points (bool): In the dual formulation, whether ``Xs`` is optimized.
            NOTE(review): in the primal formulation this flag gates ``Ds`` instead. Confirm
            that this is intended.
        optimizer (type): Optimizer class, instantiated as ``optimizer(params=[param], lr=lr)``
            for every optimized component. Defaults to ``Adam``.
        lr (float): Learning rate passed to every optimizer.
        max_epochs (int): Maximum number of outer coordinate-descent sweeps.
        max_iter (int): Maximum number of gradient steps per component and per sweep.
        tol (float): Stopping threshold on the gradient norm (inner loop) and on the absolute
            loss change between sweeps (outer loop).
        scheduler (callable, optional): Factory called as ``scheduler(optimizer)``. The result is
            stepped after every optimizer step.
        device (torch.device): Device on which the computation runs.
        verbose (int): 0 = silent, >= 1 per-epoch messages, >= 2 also per-iteration messages.
    """

    def __init__(self,
        cost_function: callable,
        K: callable = None,
        dual: bool = False,
        optimize_control_points: bool = True,
        optimizer: torch.optim.Optimizer = Adam,
        lr: float = 1e-3,
        max_epochs: int = 1000,
        max_iter: int = 1000,
        tol: float = 1e-6,
        scheduler: torch.optim.lr_scheduler = None,
        device: torch.device = torch.device("cpu"),
        verbose: int = 1):

        # The dual constraints (see _set_constraint_function) need the kernel's Gram matrix.
        if dual:
            assert K is not None, "Kernel function K must be provided for dual formulation and identical to the kernel used to define the cost function."

        self.cost_function = cost_function
        self.K = K
        self.dual = dual
        self.optimize_control_points = optimize_control_points
        self.optimizer = optimizer
        self.lr = lr
        self.max_epochs = max_epochs
        self.max_iter = max_iter
        self.tol = tol
        self.scheduler = scheduler
        self.device = device
        self.verbose = verbose

    def _single_coordinate_descent(self, Ts, param, optimizer, scheduler, constraint, label, epoch):
        """Run up to ``max_iter`` gradient steps on ``param`` with the other components of ``Ts`` fixed.

        Stops early once the L2 norm of the gradient w.r.t. ``param`` falls below ``tol``.
        ``param`` is updated in place (optimizer step, then the optional ``constraint`` projection).
        """
        if self.verbose >= 2:
            print(f"Epoch {epoch+1}/{self.max_epochs}, optimizing {label}.")

        for it in range(self.max_iter):
            optimizer.zero_grad()
            loss = self.b_loss(Ts)
            # Differentiate w.r.t. this block only; the other components get no gradient.
            grad_param = torch.autograd.grad(loss, param, create_graph=False, retain_graph=False)[0]
            param.grad = grad_param
            optimizer.step()
            if scheduler is not None:
                scheduler.step()
            # Project back onto the constraint set (in place, outside autograd).
            if constraint is not None:
                with torch.no_grad():
                    param.copy_(constraint(param))
            # Gradient norm measured before the step, used as the stopping criterion.
            iter_error = torch.norm(grad_param, p=2)

            if self.verbose >= 2 and (it % 10 == 0 or it == self.max_iter - 1):
                print(f"Iteration {it+1}/{self.max_iter}, {label} loss: {loss.item()}, gradient norm: {iter_error.item()}.")

            if iter_error < self.tol:
                if self.verbose >= 2:
                    print(f"Convergence reached for {label} at iteration {it+1} with error {iter_error.item()}.")
                break

    def _set_constraint_function(self, Ts, label):
        """Return the projection applied to component ``label`` after each step, or None.

        - "prs": rescale each column x_j so that Re(x_j^H K(Xs, Xs) x_j) = 1.
        - "Rs": rescale each column to unit Euclidean norm.
        - "pls": project x so that x^H K(Xs, Xs) prs = I.
        - "Ls": project x so that x^H Rs = I.
        Any other label ("Ds", "Xs") is unconstrained and gets None.
        """
        if label == "prs":
            _, Xs, _, _ = Ts
            # Constant during this block's descent; no autograd graph is needed.
            with torch.no_grad():
                K_ = self.K(Xs, Xs)
            return lambda x: x / torch.sqrt(torch.sum(x.conj() * (K_ @ x), dim=0, keepdim=True).real)
        if label == "Rs":
            return lambda x: x / x.norm(dim=0, keepdim=True)
        if label == "pls":
            _, Xs, prs, _ = Ts
            with torch.no_grad():
                K_ = self.K(Xs, Xs)
                A = prs.conj().T @ K_ @ prs
            eye = torch.eye(prs.shape[1], device=self.device)

            def constraint(x):
                # Solve r A = B, then x - prs r^H satisfies (x - prs r^H)^H K prs = I.
                B = x.conj().T @ K_ @ prs - eye
                r_part = torch.linalg.solve(A, B, left=False)
                return x - prs @ r_part.conj().T
            return constraint
        if label == "Ls":
            _, Rs, _ = Ts
            with torch.no_grad():
                A = Rs.conj().T @ Rs
            eye = torch.eye(Rs.shape[1], device=self.device)

            def constraint(x):
                # Solve r A = B, then x - Rs r^H satisfies (x - Rs r^H)^H Rs = I.
                B = x.conj().T @ Rs - eye
                r_part = torch.linalg.solve(A, B, left=False)
                return x - Rs @ r_part.conj().T
            return constraint
        return None

    def _init_barycenter(self, init):
        """Return fresh leaf copies of ``init`` on ``self.device`` together with their labels."""
        if self.dual:
            Ds, Xs, prs, pls = init
            tensors = (Ds, Xs, prs, pls)
            flags = (True, self.optimize_control_points, True, True)
            labels = ["Ds", "Xs", "prs", "pls"]
        else:
            Ds, Rs, Ls = init
            tensors = (Ds, Rs, Ls)
            # NOTE(review): the primal formulation has no control points, yet this flag gates Ds.
            flags = (self.optimize_control_points, True, True)
            labels = ["Ds", "Rs", "Ls"]
        # detach + clone + .to gives leaf tensors on the target device. Setting requires_grad
        # before .to() would yield non-leaf tensors on a device change, which torch optimizers
        # reject. The clone also leaves the caller's `init` tensors unmodified.
        Ts = tuple(
            x.detach().clone().to(self.device).requires_grad_(flag)
            for x, flag in zip(tensors, flags)
        )
        return Ts, labels

    def fit(self,
        Tt_lst: list,
        init: tuple,
        Tt_weights: list = None,
        init_weights: torch.Tensor = None,
        ponderations: torch.Tensor = None):
        """Compute the barycenter of ``Tt_lst`` by block coordinate descent starting from ``init``.

        Args:
            Tt_lst (list): Target spectral decompositions (tuples of tensors).
            init (tuple): Initial barycenter, ``(Ds, Xs, prs, pls)`` if dual else ``(Ds, Rs, Ls)``.
                It is copied and is not modified.
            Tt_weights (list, optional): Currently unused.
            init_weights (torch.Tensor, optional): Currently unused.
            ponderations (torch.Tensor, optional): Weights of the targets in the loss.
                Defaults to uniform weights.

        Returns:
            tuple: ``(Ts, epoch_loss)``. ``Ts`` is the optimized decomposition. ``epoch_loss[0]``
            is the initial loss and ``epoch_loss[k]`` is the loss after epoch ``k``. Both are
            also stored as ``self.Ts`` and ``self.epoch_loss``.
        """
        Ts, labels = self._init_barycenter(init)

        # Initialize targets
        Tt_lst = [tuple([x.to(self.device) for x in tpl]) for tpl in Tt_lst]
        if Tt_weights is None:
            Tt_weights = [None] * len(Tt_lst)

        if ponderations is None:
            # If no weights are provided, use uniform weights
            ponderations = torch.ones(len(Tt_lst), device=self.device) / len(Tt_lst)
        self.b_loss = StandardBarycenterLoss(self.cost_function, Tt_lst, ponderations, device=self.device)

        # One optimizer (and optional scheduler) per optimized component.
        optimizers = [self.optimizer(params=[p], lr=self.lr) if p.requires_grad else None for p in Ts]
        schedulers = [
            self.scheduler(opt) if opt is not None and self.scheduler is not None else None
            for opt in optimizers
        ]

        # Slot 0 holds the initial loss; slot k holds the loss after epoch k.
        epoch_loss = torch.empty(self.max_epochs + 1, device=self.device)
        with torch.no_grad():
            prev_loss = self.b_loss(Ts).item()
        epoch_loss[0] = prev_loss
        if self.verbose >= 1:
            print(f"Starting training with {len(Tt_lst)} target spectral decompositions.")
            print(f"Initial loss: {prev_loss}")

        n_epochs = 0
        for epoch in range(self.max_epochs):
            # Optimize each component of Ts in turn (block coordinate descent).
            for param, optimizer, scheduler, label in zip(Ts, optimizers, schedulers, labels):
                if param.requires_grad:
                    constraint = self._set_constraint_function(Ts, label)
                    self._single_coordinate_descent(Ts, param, optimizer, scheduler, constraint, label, epoch)

            with torch.no_grad():
                cur_loss = self.b_loss(Ts).item()
            n_epochs = epoch + 1
            epoch_loss[n_epochs] = cur_loss
            # Compare Python floats so the tol test is not limited by the history's float32 storage.
            epoch_error = abs(cur_loss - prev_loss)
            prev_loss = cur_loss

            if self.verbose >= 1:
                print(f"Epoch {epoch+1}/{self.max_epochs}, loss: {cur_loss}, error: {epoch_error}.")
            if epoch_error < self.tol:
                if self.verbose >= 1:
                    print(f"Convergence reached at epoch {epoch+1}/{self.max_epochs} with error {epoch_error}.")
                break

        if self.verbose >= 1:
            print(f"Training completed after {n_epochs}/{self.max_epochs} epochs with final loss: {prev_loss}.")

        self.Ts = Ts
        self.epoch_loss = epoch_loss[:n_epochs + 1]

        return Ts, self.epoch_loss