# -*- coding: utf-8 -*-
"""
Pure PyTorch implementation of the Sinkhorn algorithm
"""

import torch
from typing import Optional, Union, Tuple, Dict


def sinkhorn(
    a: torch.Tensor,
    b: torch.Tensor,
    M: torch.Tensor,
    reg: float,
    method: str = 'sinkhorn',
    numItermax: int = 1000,
    stopThr: float = 1e-9,
    verbose: bool = False,
    log: bool = False,
    warn: bool = True,
    **kwargs
) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict]]:
    """
    Solve the entropic regularization optimal transport problem

    The solver runs the classic Sinkhorn scaling iterations on the Gibbs
    kernel ``K = exp(-M / reg)``. Both histograms are normalized to sum to
    one, and every intermediate quantity is clamped into ``[1e-16, 1e16]``
    to keep the divisions finite. The function is best-effort: numerical
    failures stop the iterations early and, if the final plan is not finite,
    the independent coupling ``outer(a, b)`` is returned instead.

    Parameters
    ----------
    a : torch.Tensor, shape (ns,)
        Source distribution (normalized internally)
    b : torch.Tensor, shape (nt,)
        Target distribution (normalized internally)
    M : torch.Tensor, shape (ns, nt)
        Cost matrix
    reg : float
        Regularization parameter, must be strictly positive
    method : str
        Sinkhorn method ('sinkhorn' or 'sinkhorn_log'). Accepted for API
        compatibility only: the value is not inspected and the standard
        scaling iterations are always used.
    numItermax : int
        Maximum number of iterations
    stopThr : float
        Stopping threshold on the summed L1 violation of both marginals
        (checked every 10 iterations)
    verbose : bool
        Print the error every 100 iterations
    log : bool
        Return log dictionary
    warn : bool
        Print a message when a numerical problem is detected
    **kwargs
        Ignored; accepted so callers may pass extra solver options.

    Returns
    -------
    gamma : torch.Tensor
        Optimal transport plan
    log : dict, optional
        Log dictionary if log=True, with keys 'err' (errors recorded at the
        convergence checks), 'u' and 'v' (scaling vectors) and 'niter'
        (index of the last iteration started, 0 when no iteration ran)

    Raises
    ------
    ValueError
        If ``reg`` is not strictly positive.
    """
    if reg <= 0:
        raise ValueError(f"Regularization must be positive, got {reg}")

    log_dict = {'err': []} if log else None

    epsilon = 1e-16
    upper = 1 / epsilon

    # Normalize distributions, then keep them strictly positive so that the
    # scaling updates never divide zero by zero.
    a = torch.clamp(a / a.sum(), min=epsilon)
    b = torch.clamp(b / b.sum(), min=epsilon)

    # Kernel matrix with numerical stability
    K = torch.exp(-M / reg)
    if not torch.isfinite(K).all():
        if warn:
            print("Warning: Kernel has numerical issues")
        # torch.clamp propagates NaN, so NaN entries have to be replaced
        # explicitly before clamping the kernel into the safe range.
        K = torch.nan_to_num(K, nan=epsilon, posinf=upper, neginf=epsilon)
        K = torch.clamp(K, min=epsilon, max=upper)

    # Initial scaling vectors (uniform)
    u = torch.clamp(torch.ones_like(a) / a.shape[0], min=epsilon)
    v = torch.clamp(torch.ones_like(b) / b.shape[0], min=epsilon)

    # Defined up front so the log is well-formed even when numItermax <= 0.
    niter = 0
    for i in range(numItermax):
        niter = i

        # Sinkhorn updates with numerical stability
        try:
            # v = b / (K.T @ u)
            Ktu = torch.clamp(K.T @ u, min=epsilon)
            v = torch.clamp(b / Ktu, min=epsilon, max=upper)

            # u = a / (K @ v)
            Kv = torch.clamp(K @ v, min=epsilon)
            u = torch.clamp(a / Kv, min=epsilon, max=upper)
        except RuntimeError as e:
            # Best effort: stop iterating and let the final step decide
            # whether a usable plan can still be produced.
            if warn:
                print(f"Numerical issue in Sinkhorn iteration {i}: {e}")
            break

        # Check for numerical issues
        if not (torch.isfinite(u).all() and torch.isfinite(v).all()):
            if warn:
                print(f"NaN/inf detected at iteration {i}")
            break

        # Check convergence only every 10 iterations to limit the overhead
        if i % 10 == 0:
            gamma = u.unsqueeze(1) * K * v.unsqueeze(0)

            # Summed L1 violation of the two marginal constraints
            err_u = torch.sum(torch.abs(torch.sum(gamma, dim=1) - a))
            err_v = torch.sum(torch.abs(torch.sum(gamma, dim=0) - b))
            err = err_u + err_v

            if log:
                log_dict['err'].append(err.item())

            if verbose and i % 100 == 0:
                print(f"Iteration {i}, error: {err.item():.6e}")

            if err < stopThr:
                break

    # Final transport plan
    try:
        gamma = u.unsqueeze(1) * K * v.unsqueeze(0)

        # Final numerical stability check
        if not torch.isfinite(gamma).all():
            if warn:
                print("Warning: Final transport plan has numerical issues, using fallback")
            gamma = torch.outer(a, b)

        # Ensure non-negativity
        gamma = torch.clamp(gamma, min=0)
    except RuntimeError as e:
        if warn:
            print(f"Error computing final transport plan: {e}")
        gamma = torch.outer(a, b)

    if log:
        log_dict['u'] = u
        log_dict['v'] = v
        log_dict['niter'] = niter
        return gamma, log_dict
    return gamma
