import math
import warnings
from functools import partial
from typing import Optional, Union, Tuple

import numpy as np
import ot as pot
import torch

try:
    from .oat_distance import efficient_batch_oat_distance
except ImportError:
    from oat_distance import efficient_batch_oat_distance


class OATPlanSampler:
    """Sample minibatch couplings from an optimal-transport plan with an OAT cost.

    Each state is a pair of tensors ``(x, v)`` (presumably positions and
    velocities — confirm against callers). The pairwise cost between two
    batches comes from ``efficient_batch_oat_distance``. The transport plan is
    solved with the POT solver selected by ``method``, and index pairs are then
    drawn with probability proportional to the plan entries.
    """

    def __init__(
        self,
        method: str = "sinkhorn",
        reg: float = 0.05,
        reg_m: float = 1.0,
        normalize_cost: bool = False,
        num_threads: Union[int, str] = 1,
        warn: bool = True,
        squared_cost: bool = True,
    ) -> None:
        """Configure the OT solver.

        Args:
            method: ``"exact"`` (``pot.emd``), ``"sinkhorn"`` (``pot.sinkhorn``),
                ``"unbalanced"`` (``pot.unbalanced.sinkhorn_knopp_unbalanced``)
                or ``"partial"`` (``pot.partial.entropic_partial_wasserstein``).
            reg: Entropic regularization strength; not used by ``"exact"``.
            reg_m: Marginal relaxation; only used by ``"unbalanced"``.
            normalize_cost: If True, divide the cost matrix by its maximum
                (when that maximum is positive) before solving.
            num_threads: Forwarded as ``numThreads`` to ``pot.emd``; only used
                by ``"exact"``.
            warn: Whether to warn when a zero-mass plan is replaced by the
                uniform plan. Non-finite plans are always reported.
            squared_cost: Forwarded as ``squared`` to
                ``efficient_batch_oat_distance``.

        Raises:
            ValueError: If ``method`` is not one of the names above.
        """
        # Bind solver hyper-parameters once so every call is ``ot_fn(a, b, M)``.
        if method == "exact":
            self.ot_fn = partial(pot.emd, numThreads=num_threads)
        elif method == "sinkhorn":
            self.ot_fn = partial(pot.sinkhorn, reg=reg)
        elif method == "unbalanced":
            self.ot_fn = partial(pot.unbalanced.sinkhorn_knopp_unbalanced, reg=reg, reg_m=reg_m)
        elif method == "partial":
            self.ot_fn = partial(pot.partial.entropic_partial_wasserstein, reg=reg)
        else:
            raise ValueError(f"Unknown method: {method}")

        self.method = method
        self.reg = reg
        self.reg_m = reg_m
        self.normalize_cost = normalize_cost
        self.warn = warn
        self.squared_cost = squared_cost

    @staticmethod
    def _flatten_batch(t: torch.Tensor) -> torch.Tensor:
        """Collapse every non-leading dimension of ``t`` into one axis (rank > 2 only)."""
        return t.reshape(t.shape[0], -1) if t.dim() > 2 else t

    def get_oat_cost_matrix(
        self,
        z0: Tuple[torch.Tensor, torch.Tensor],
        z1: Tuple[torch.Tensor, torch.Tensor]
    ) -> np.ndarray:
        """Return the pairwise OAT cost between two batches as a NumPy array.

        Args:
            z0: Source batch ``(x0, v0)``; the leading dimension indexes samples.
            z1: Target batch ``(x1, v1)``; the leading dimension indexes samples.

        Returns:
            The detached, CPU cost matrix produced by
            ``efficient_batch_oat_distance``. It is expected to be
            ``(len(x0), len(x1))``, since it is passed to the solver with
            marginals of those sizes. If ``normalize_cost`` is set and the
            maximum is positive, the matrix is divided by that maximum.
        """
        x0, v0 = z0
        x1, v1 = z1

        # Flatten each tensor on its own rank so a mismatched pair is still
        # handed to the distance as 2-D (assumed to be what it expects).
        z0_flat = (self._flatten_batch(x0), self._flatten_batch(v0))
        z1_flat = (self._flatten_batch(x1), self._flatten_batch(v1))

        M = efficient_batch_oat_distance(z0_flat, z1_flat, squared=self.squared_cost)
        M = M.detach().cpu().numpy()

        if self.normalize_cost:
            M_max = M.max()
            # Skip all-zero (or NaN) costs to avoid dividing by zero.
            if M_max > 0:
                M = M / M_max

        return M

    def _sanitize_plan(self, p: np.ndarray, M: np.ndarray) -> np.ndarray:
        """Return ``p`` unchanged, or a uniform plan if ``p`` is degenerate.

        A plan is degenerate when it has NaN/inf entries or (numerically) zero
        total mass. Either case would make ``sample_map`` fail: NaN
        probabilities are rejected by ``np.random.choice``, and zero mass
        divides by zero. The NaN case must be checked explicitly because
        ``NaN < 1e-8`` is False.

        Args:
            p: Transport plan returned by the solver.
            M: Cost matrix the plan was solved for (used only in the warning).

        Returns:
            ``p`` itself, or ``np.ones_like(p) / p.size`` as a fallback.
        """
        if not np.all(np.isfinite(p)):
            # Always reported: this usually indicates solver under/overflow.
            warnings.warn(
                "OT plan is not finite (OAT cost mean, max: "
                f"{M.mean()}, {M.max()}); reverting to uniform plan.",
                stacklevel=2,
            )
            return np.ones_like(p) / p.size

        if np.abs(p.sum()) < 1e-8:
            if self.warn:
                warnings.warn(
                    "Numerical errors in OT plan, reverting to uniform plan.",
                    stacklevel=2,
                )
            return np.ones_like(p) / p.size

        return p

    def get_map(
        self,
        z0: Tuple[torch.Tensor, torch.Tensor],
        z1: Tuple[torch.Tensor, torch.Tensor]
    ) -> np.ndarray:
        """Solve the OT problem between ``z0`` and ``z1`` with uniform marginals.

        Args:
            z0: Source batch ``(x0, v0)``.
            z1: Target batch ``(x1, v1)``.

        Returns:
            The transport plan from ``self.ot_fn``. If that plan is non-finite
            or has (near-)zero mass, a uniform plan of the same shape is
            returned instead.
        """
        # Only the batch sizes are needed here; the cost uses the full states.
        n0 = z0[0].shape[0]
        n1 = z1[0].shape[0]
        a, b = pot.unif(n0), pot.unif(n1)

        M = self.get_oat_cost_matrix(z0, z1)
        p = self.ot_fn(a, b, M)
        return self._sanitize_plan(p, M)

    def sample_map(self, pi, batch_size, replace=True):
        """Draw ``batch_size`` index pairs ``(i, j)`` with probability ``∝ pi[i, j]``.

        Args:
            pi: 2-D transport plan (non-negative, finite, positive total mass).
            batch_size: Number of pairs to draw.
            replace: Whether to sample with replacement.

        Returns:
            A tuple ``(i, j)`` of integer index arrays into rows and columns of
            ``pi``.
        """
        p = pi.flatten()
        # Normalize because unbalanced/partial plans need not sum to one.
        p = p / p.sum()
        choices = np.random.choice(pi.size, p=p, size=batch_size, replace=replace)
        # Convert flat indices back to (row, column) of the row-major plan.
        return np.divmod(choices, pi.shape[1])

    def sample_plan(self, z0, z1, replace=True):
        """Resample ``z0`` and ``z1`` so that their rows are OT-coupled pairs.

        Args:
            z0: Source batch ``(x0, v0)``.
            z1: Target batch ``(x1, v1)``.
            replace: Whether pairs are sampled with replacement.

        Returns:
            ``((x0[i], v0[i]), (x1[j], v1[j]))`` with ``len(x0)`` sampled pairs.
        """
        pi = self.get_map(z0, z1)
        x0, v0 = z0
        x1, v1 = z1
        i, j = self.sample_map(pi, x0.shape[0], replace=replace)
        return (x0[i], v0[i]), (x1[j], v1[j])