import math
import warnings
from functools import partial
from typing import Optional, Union

import numpy as np
import ot as pot
import torch

# Import logging utilities
try:
    from .logging_utils import get_logger, log_numerical_error, log_ot_plan_info
except ImportError:
    # Minimal stand-ins used when the package-level logging helpers cannot be
    # imported (for example when this module is loaded outside its package).

    def get_logger(name="torchcfm"):
        """Return the standard-library logger registered under ``name``."""
        import logging as _logging

        return _logging.getLogger(name)

    def log_numerical_error(logger, error_type, details, context=""):
        """Emit one ERROR record describing a numerical failure."""
        message = f"NUMERICAL_ERROR: {error_type} | Context: {context} | Details: {details}"
        logger.error(message)

    def log_ot_plan_info(logger, plan_sum, cost_matrix_info, parameters):
        """Emit one INFO record summarising an OT plan and the inputs that produced it."""
        summary = f"OT Plan sum: {plan_sum:.2e} | Cost info: {cost_matrix_info} | Params: {parameters}"
        logger.info(summary)


class OTPlanSampler:
    """OTPlanSampler implements sampling coordinates according to an OT plan (wrt squared Euclidean
    cost) with different implementations of the plan calculation."""

    def __init__(
        self,
        method: str,
        reg: float = 0.05,
        reg_m: Union[float, tuple] = (1.0, 1.0),
        normalize_cost: bool = True,
        num_threads: Union[int, str] = 1,
        warn: bool = False,
    ) -> None:
        """Initialize the OTPlanSampler class.

        Parameters
        ----------
        method: str
            choose which optimal transport solver you would like to use.
            Currently supported are ["exact", "sinkhorn", "unbalanced",
            "unbalanced_knopp", "gpu_unbalanced", "partial"] OT solvers.
        reg: float, optional
            regularization parameter to use for Sinkhorn-based iterative solvers.
        reg_m: float or tuple of two floats, optional
            marginal relaxation weight(s) forwarded to the unbalanced Sinkhorn
            solvers. A pair gives one weight per marginal (source, target).
        normalize_cost: bool, optional
            normalizes the cost matrix so that the maximum cost is 1. Helps
            stabilize Sinkhorn-based solvers. Note that it is enabled by
            default in this class.
        num_threads: int or str, optional
            number of threads to use for the "exact" OT solver. If "max", uses
            the maximum number of threads.
        warn: bool, optional
            controls what happens when the computed plan has numerical errors
            (non-finite entries or a total mass of ~0). If True, a warning is
            emitted and a uniform plan is used instead; if False, a
            RuntimeError is raised.

        Raises
        ------
        ValueError
            if ``method`` is not one of the supported names.
        """
        self.logger = get_logger("torchcfm.ot")

        # ot_fn takes (a, b, M): a, b are marginals and M is a cost matrix.
        # The unbalanced variants are called with log=True, so they return
        # (plan, log_dict) instead of just the plan.
        if method == "exact":
            self.ot_fn = partial(pot.emd, numThreads=num_threads)
        elif method == "sinkhorn":
            self.ot_fn = partial(pot.sinkhorn, reg=reg)
        elif method == "unbalanced_knopp":
            self.ot_fn = partial(pot.unbalanced.sinkhorn_knopp_unbalanced, reg=reg, reg_m=reg_m, log=True)
        elif method == "unbalanced":
            self.ot_fn = partial(pot.unbalanced.sinkhorn_unbalanced, reg=reg, reg_m=reg_m, log=True)
        elif method == "gpu_unbalanced":
            # NOTE(review): confirm the installed POT version exposes
            # ot.gpu.sinkhorn_unbalanced; the attribute lookup happens here.
            self.ot_fn = partial(pot.gpu.sinkhorn_unbalanced, reg=reg, reg_m=reg_m, log=True)
        elif method == "partial":
            self.ot_fn = partial(pot.partial.entropic_partial_wasserstein, reg=reg)
        else:
            raise ValueError(f"Unknown method: {method}")
        self.reg = reg
        self.reg_m = reg_m
        self.normalize_cost = normalize_cost
        self.warn = warn
        self.method = method

        self.logger.info(f"OTPlanSampler initialized with method={method}, reg={reg}, reg_m={reg_m}")

    @staticmethod
    def _sample_index(weights):
        """Draw one index with probability proportional to the non-negative ``weights``.

        The weights are normalized exactly by their sum, so rows or columns
        with very small mass still form a valid distribution. When the weights
        sum to zero or are not finite, the index is drawn uniformly instead of
        letting ``np.random.choice`` raise.
        """
        weights = np.asarray(weights, dtype=np.float64)
        total = weights.sum()
        if not np.isfinite(total) or total <= 0:
            return np.random.choice(weights.shape[0])
        return np.random.choice(weights.shape[0], p=weights / total)

    def get_map(self, x0, x1):
        """Compute the OT plan (wrt squared Euclidean cost) between a source and a target
        minibatch.

        Parameters
        ----------
        x0 : Tensor, shape (bs, *dim)
            represents the source minibatch
        x1 : Tensor, shape (bs, *dim)
            represents the target minibatch

        Returns
        -------
        p : numpy array, shape (bs, bs)
            represents the OT plan between minibatches
        log_uv : dict or None
            the log dictionary returned by solvers that are called with
            ``log=True`` (the unbalanced methods); None for the other solvers.

        Raises
        ------
        RuntimeError
            if the plan has non-finite entries or sums to (almost) zero and
            ``self.warn`` is False. With ``self.warn`` True, a warning is
            emitted and a uniform plan is returned instead.
        """
        a, b = pot.unif(x0.shape[0]), pot.unif(x1.shape[0])
        # Flatten trailing feature dimensions so cdist sees (bs, features).
        if x0.dim() > 2:
            x0 = x0.reshape(x0.shape[0], -1)
        if x1.dim() > 2:
            x1 = x1.reshape(x1.shape[0], -1)
        M = torch.cdist(x0, x1) ** 2

        if self.normalize_cost:
            M = M / (M.max() + 1e-12)  # Prevent division by zero

        # Summary statistics of the cost matrix, used only for logging.
        cost_info = {
            "mean": M.mean().item(),
            "max": M.max().item(),
            "min": M.min().item(),
            # The unbiased std of a single element is undefined (torch returns
            # NaN and warns), so report 0.0 for a 1x1 cost matrix.
            "std": M.std().item() if M.numel() > 1 else 0.0,
        }

        params = {
            "method": self.method,
            "reg": self.reg,
            "reg_m": self.reg_m,
            "normalize_cost": self.normalize_cost,
            "batch_size": x0.shape[0],
        }

        self.logger.debug(f"Computing OT plan with cost matrix: {cost_info}")

        result = self.ot_fn(a, b, M.detach().cpu().numpy())

        # Solvers called with log=True return (plan, log_dict); others return the plan.
        if isinstance(result, tuple) and len(result) == 2:
            p, log_uv = result
        else:
            p, log_uv = result, None

        finite = bool(np.all(np.isfinite(p)))
        if not finite:
            error_details = {
                "plan_shape": p.shape,
                "plan_sum": p.sum(),
                "plan_min": p.min(),
                "plan_max": p.max(),
                "cost_info": cost_info,
                "parameters": params,
            }
            log_numerical_error(
                self.logger,
                "OT_PLAN_NOT_FINITE",
                error_details,
                f"get_map method={self.method}",
            )
            # Full dumps can be large, so they go to the debug level only.
            self.logger.debug("Non-finite OT plan details | plan=%s | x0=%s | x1=%s", p, x0, x1)
            log_ot_plan_info(self.logger, p.sum(), cost_info, params)

        # A NaN sum compares False with every threshold, so the zero-sum check
        # only makes sense for finite plans; non-finite plans are handled below.
        zero_sum = finite and np.abs(p.sum()) < 1e-8
        if zero_sum:
            error_details = {
                "plan_sum": p.sum(),
                "plan_shape": p.shape,
                "cost_info": cost_info,
                "parameters": params,
            }
            log_numerical_error(
                self.logger,
                "OT_PLAN_ZERO_SUM",
                error_details,
                f"get_map method={self.method}",
            )

        # Both failure modes make the plan unusable for sampling.
        if not finite or zero_sum:
            if self.warn:
                warnings.warn("Numerical errors in OT plan, reverting to uniform plan.")
                p = np.ones_like(p) / p.size
            elif zero_sum:
                raise RuntimeError("!! Numerical error in OT plan: plan sum is zero. Process terminated.")
            else:
                raise RuntimeError("!! Numerical error in OT plan: plan is not finite. Process terminated.")
        return p, log_uv

    def sample_map(self, pi, batch_size, replace=False):
        r"""Draw source and target samples from pi  $(x,z) \sim \pi$

        Parameters
        ----------
        pi : numpy array, shape (bs, bs)
            represents the OT plan between minibatches
        batch_size : int
            number of (source, target) index pairs to draw
        replace : bool
            represents sampling with or without replacement from the OT plan.
            Always: One target can be coupled with multiple sources, multiple source can be coupled with one target.
            True: source-target pair can be sampled again.
            False: source-target pair can be sampled only once.

        Returns
        -------
        (i_s, i_j) : tuple of numpy arrays, each of shape (batch_size,)
            represents the indices of source and target data samples from $\pi$
        """
        # Sample flat cell indices of the plan, then map them back to (row, col).
        p = pi.flatten()
        p = p / p.sum()
        choices = np.random.choice(
            pi.shape[0] * pi.shape[1], p=p, size=batch_size, replace=replace
        )
        return np.divmod(choices, pi.shape[1])

    def sample_map_fixed_source(self, pi, batch_size):
        r"""Draw source and target samples from pi  $(x,z) \sim \pi$ keeping source order.

        Parameters
        ----------
        pi : numpy array, shape (bs, bs)
            represents the OT plan between minibatches
        batch_size : int
            number of leading rows (sources) of ``pi`` to pair

        Returns
        -------
        (i_s, i_j) : tuple of numpy arrays, each of shape (batch_size,)
            represents the indices of source and target data samples from $\pi$
            i_s: identity indices for source
            i_j: new target indices, drawn from each normalized row of pi
            (multiple sources can be coupled with one target). A row with
            zero mass falls back to a uniformly drawn target.
        """
        i_s = np.arange(batch_size)
        i_j = np.array([self._sample_index(pi[row]) for row in range(batch_size)])
        return i_s, i_j

    def sample_map_fixed_target(self, pi, batch_size):
        r"""Draw source and target samples from pi  $(x,z) \sim \pi$ keeping target order.

        Parameters
        ----------
        pi : numpy array, shape (bs, bs)
            represents the OT plan between minibatches
        batch_size : int
            number of leading columns (targets) of ``pi`` to pair

        Returns
        -------
        (i_s, i_j) : tuple of numpy arrays, each of shape (batch_size,)
            represents the indices of source and target data samples from $\pi$
            i_s: new source indices, drawn from each normalized column of pi
            (multiple targets can be coupled with one source). A column with
            zero mass falls back to a uniformly drawn source.
            i_j: identity indices for target
        """
        i_j = np.arange(batch_size)
        i_s = np.array([self._sample_index(pi[:, col]) for col in range(batch_size)])
        return i_s, i_j

    def sample_map_inverse_prob(self, pi, batch_size, replace=False):
        r"""Draw source and target pairs with probability proportional to $1 - \pi$.

        Pairs that the (normalized) plan considers unlikely are favoured.

        Parameters
        ----------
        pi : numpy array, shape (bs, bs)
            represents the OT plan between minibatches
        batch_size : int
            number of (source, target) index pairs to draw
        replace : bool
            represents sampling with or without replacement from the OT plan.
            Always: One target can be coupled with multiple sources, multiple source can be coupled with one target.
            True: source-target pair can be sampled again.
            False: source-target pair can be sampled only once.

        Returns
        -------
        (i_s, i_j) : tuple of numpy arrays, each of shape (batch_size,)
            represents the indices of source and target data samples
        """
        p = pi.flatten()
        if p.sum() == 0:  # empty plan: start from uniform weights
            p = np.ones_like(p)
        else:
            p = p / p.sum()  # normalize
        p = 1.0 - p  # inverse probability
        if p.sum() == 0:  # e.g. a single-cell plan: fall back to uniform
            p = np.ones_like(p)
        p /= p.sum()  # re-normalize
        choices = np.random.choice(
            pi.shape[0] * pi.shape[1], p=p, size=batch_size, replace=replace
        )
        return np.divmod(choices, pi.shape[1])

    def sample_plan(self, x0, x1, replace=False):
        r"""Compute the OT plan $\pi$ (wrt squared Euclidean cost) between a source and a target
        minibatch and draw source and target samples from pi $(x,z) \sim \pi$

        Parameters
        ----------
        x0 : Tensor, shape (bs, *dim)
            represents the source minibatch
        x1 : Tensor, shape (bs, *dim)
            represents the target minibatch
        replace : bool
            represents sampling with or without replacement from the OT plan

        Returns
        -------
        x0[i] : Tensor, shape (bs, *dim)
            represents the source minibatch drawn from $\pi$
        x1[j] : Tensor, shape (bs, *dim)
            represents the target minibatch drawn from $\pi$
        """
        pi, _ = self.get_map(x0, x1)
        i, j = self.sample_map(pi, x0.shape[0], replace=replace)
        return x0[i], x1[j]

    def sample_plan_with_labels(self, x0, x1, y0=None, y1=None, replace=False):
        r"""Compute the OT plan $\pi$ (wrt squared Euclidean cost) between a source and a target
        minibatch and draw source and target labeled samples from pi $(x,z) \sim \pi$

        Parameters
        ----------
        x0 : Tensor, shape (bs, *dim)
            represents the source minibatch
        x1 : Tensor, shape (bs, *dim)
            represents the target minibatch
        y0 : Tensor, shape (bs), optional
            represents the source label minibatch
        y1 : Tensor, shape (bs), optional
            represents the target label minibatch
        replace : bool
            represents sampling with or without replacement from the OT plan

        Returns
        -------
        x0[i] : Tensor, shape (bs, *dim)
            represents the source minibatch drawn from $\pi$
        x1[j] : Tensor, shape (bs, *dim)
            represents the target minibatch drawn from $\pi$
        y0[i] : Tensor, shape (bs,) or None
            represents the source label minibatch drawn from $\pi$
        y1[j] : Tensor, shape (bs,) or None
            represents the target label minibatch drawn from $\pi$
        """
        pi, _ = self.get_map(x0, x1)
        i, j = self.sample_map(pi, x0.shape[0], replace=replace)
        return (
            x0[i],
            x1[j],
            y0[i] if y0 is not None else None,
            y1[j] if y1 is not None else None,
        )

    def sample_plan_with_weights_and_indices(self, x0, x1, replace=False, fixed_source=False, fixed_target=False):
        r"""Compute the OT plan $\pi$ (wrt squared Euclidean cost) between a source and a target
        minibatch and draw source and target samples from pi $(x,z) \sim \pi$

        Parameters
        ----------
        x0 : Tensor, shape (bs, *dim)
            represents the source minibatch
        x1 : Tensor, shape (bs, *dim)
            represents the target minibatch
        replace : bool
            represents sampling with or without replacement from the OT plan.
            Only used when neither ``fixed_source`` nor ``fixed_target`` is set.
            Always: One target can be coupled with multiple sources, multiple source can be coupled with one target.
            True: source-target pair can be sampled again.
            False: source-target pair can be sampled only once.
        fixed_source : bool
            if True, sample is sampled from per row of pi. index of source is fixed.
        fixed_target : bool
            if True (and ``fixed_source`` is False), sample is sampled from per
            column of pi. index of target is fixed.

        Returns
        -------
        x0[i] : Tensor, shape (bs, *dim)
            represents the source minibatch drawn from $\pi$
        x1[j] : Tensor, shape (bs, *dim)
            represents the target minibatch drawn from $\pi$
        pi : Tensor, shape (bs, bs)
            the OT plan as a tensor
        u : Tensor or None
            ``exp(log_uv['logu'])`` from the solver's log dict; None when the
            solver returns no log (non-unbalanced methods)
        v : Tensor or None
            ``exp(log_uv['logv'])`` from the solver's log dict; None when the
            solver returns no log
        i : numpy array, shape (bs,)
            sampled source indices
        j : numpy array, shape (bs,)
            sampled target indices
        """
        pi, log_uv = self.get_map(x0, x1)
        if log_uv is not None:
            u = torch.exp(torch.tensor(log_uv['logu']))
            v = torch.exp(torch.tensor(log_uv['logv']))
        else:
            u = None
            v = None
        if fixed_source:
            i, j = self.sample_map_fixed_source(pi, x0.shape[0])
        elif fixed_target:
            i, j = self.sample_map_fixed_target(pi, x0.shape[0])
        else:
            i, j = self.sample_map(pi, x0.shape[0], replace=replace)

        return x0[i], x1[j], torch.tensor(pi), u, v, i, j

    def sample_trajectory(self, X):
        """Compute the OT trajectories between different sample populations moving from the source
        to the target distribution.

        Parameters
        ----------
        X : Tensor, (bs, times, *dim)
            different populations of samples moving from the source to the target distribution.

        Returns
        -------
        to_return : numpy array, (bs, times, *dim)
            represents the OT sampled trajectories over time.
        """
        times = X.shape[1]
        # One plan per pair of consecutive time points.
        pis = [self.get_map(X[:, t], X[:, t + 1])[0] for t in range(times - 1)]

        # Each trajectory starts at its own index at time 0. It is extended by
        # drawing its successor from the plan row of its current point.
        indices = [np.arange(X.shape[0])]
        for pi in pis:
            indices.append(np.array([self._sample_index(pi[i]) for i in indices[-1]]))

        to_return = [X[:, t][indices[t]] for t in range(times)]
        return np.stack(to_return, axis=1)


def wasserstein(
    x0: torch.Tensor,
    x1: torch.Tensor,
    method: Optional[str] = None,
    reg: float = 0.05,
    power: int = 2,
    **kwargs,
) -> float:
    """Compute the Wasserstein (1 or 2) distance (wrt Euclidean cost) between a source and a target
    distributions.

    Parameters
    ----------
    x0 : Tensor, shape (bs, *dim)
        represents the source minibatch
    x1 : Tensor, shape (bs, *dim)
        represents the target minibatch
    method : str (default : None)
        Use exact Wasserstein ("exact" or None) or an entropic regularization ("sinkhorn")
    reg : float (default : 0.05)
        Entropic regularization coefficients
    power : int (default : 2)
        power of the Wasserstein distance (1 or 2)
    **kwargs
        accepted for call-site compatibility; currently ignored.

    Returns
    -------
    ret : float
        Wasserstein distance

    Raises
    ------
    ValueError
        if ``power`` is not 1 or 2, or ``method`` is not recognized.
    """
    # Raise instead of ``assert`` so the check survives ``python -O``.
    if power not in (1, 2):
        raise ValueError(f"power must be 1 or 2, got {power}")
    # ot_fn takes (a, b, M): a, b are marginals and M is a cost matrix.
    if method == "exact" or method is None:
        ot_fn = pot.emd2
    elif method == "sinkhorn":
        ot_fn = partial(pot.sinkhorn2, reg=reg)
    else:
        raise ValueError(f"Unknown method: {method}")

    # Uniform weights over each minibatch; trailing dims are flattened into features.
    a, b = pot.unif(x0.shape[0]), pot.unif(x1.shape[0])
    if x0.dim() > 2:
        x0 = x0.reshape(x0.shape[0], -1)
    if x1.dim() > 2:
        x1 = x1.reshape(x1.shape[0], -1)
    M = torch.cdist(x0, x1)
    if power == 2:
        M = M**2
    ret = ot_fn(a, b, M.detach().cpu().numpy(), numItermax=int(1e7))
    if power == 2:
        # The solver returns the optimal cost <P, M> = W_2^2; take the root.
        return math.sqrt(ret)
    # Return a plain Python float for both powers, as annotated.
    return float(ret)
