from enum import Enum
from typing import Optional

import jax.numpy as jnp
import numpy as np
import oineus as oin
import torch

from wgf import *

class Filtration(Enum):
    """Kinds of filtration that a loss configuration can select."""

    # Lower-star filtration
    LS = "ls"
    # Vietoris-Rips filtration
    VR = "vr"
    # Alpha filtration
    ALPHA = "alpha"


from dataclasses import dataclass

@dataclass
class LossConfig:
    """Settings read (and partly updated) by ``topological_loss``.

    ``topological_loss`` picks its target for the persistence diagram in this
    order: a stored template (``use_template``), a fresh Wasserstein-gradient-
    flow step (``using_wgf``), or plain simplification towards the diagonal.
    After a WGF step it stores the generated diagram in ``template_dgm`` and
    switches ``use_template`` on, so this object is mutable state, not just
    static configuration.
    """

    # Which filtration is built from the input points.
    filtration: Filtration = Filtration.VR
    # Forwarded as ``max_dim`` to the oineus filtration builders.
    max_dim: int = 2
    # Homology dimension of the diagram that is extracted and optimised.
    dgm_dim: int = 1
    # Forwarded as ``n_threads`` to the oineus filtration builders.
    n_threads: int = 16
    # Forwarded as ``max_radius`` when building a Vietoris-Rips filtration.
    vr_max_radius: float = 2.0
    # If True (and ``use_template`` is False), ``wgf.next_measure`` is used to
    # generate the target diagram.
    using_wgf: bool = False
    # If True, the current diagram is matched against ``template_dgm``.
    use_template: bool = False
    # Arguments of ``TopologyOptimizer.simplify`` used when neither a template
    # nor WGF is active.
    eps_to_diagonal: float = 0.2
    DenoiseStrategy: oin.DenoiseStrategy = oin.DenoiseStrategy.Midway
    # Forwarded as ``wasserstein_q`` to ``TopologyOptimizer.match``.
    wass_q: float = 2.0
    # Object providing ``next_measure``; required only when ``using_wgf``.
    wgf: Optional[WGF] = None
    # Target diagram (list of ``oin.DiagramPoint_double``) used when
    # ``use_template`` is True.  Declared here so it can be supplied at
    # construction time and so that reading it never raises AttributeError.
    template_dgm: Optional[list] = None


def topological_loss(pts: torch.Tensor, c: LossConfig) -> tuple:
    """Compute a differentiable topological loss for ``pts``.

    The filtration and its persistence diagram are computed by oineus on a
    detached CPU copy of ``pts``.  oineus then reports which simplices should
    move and to which target values; the loss is rebuilt from ``pts`` itself
    so gradients flow back to the input.

    The target is chosen from the configuration, in this order:

    1. ``c.use_template``: match the diagram against ``c.template_dgm``.
    2. ``c.using_wgf``: ask ``c.wgf.next_measure`` for the next diagram, match
       against it, store it in ``c.template_dgm`` and set ``c.use_template``
       to True (so later calls reuse that template).
    3. otherwise: ``TopologyOptimizer.simplify`` with ``c.eps_to_diagonal``
       and ``c.DenoiseStrategy``.

    Parameters:
    pts: Input tensor (point cloud for VR; function values for LS).
    c: Loss configuration; may be mutated as described above.

    Returns:
    A tuple ``(loss, current_dgm, template_dgm)``.  ``template_dgm`` is the
    diagram generated by the WGF step in this call, and ``None`` in every
    other case (including when a stored template is reused).

    Raises:
    NotImplementedError: if ``c.filtration`` is not VR or LS.
    ValueError: if a template or WGF object is requested but not configured.
    """
    # oineus runs on CPU float64 data; ``astype`` already produces a copy.
    pts_np = pts.detach().cpu().numpy().astype(np.float64)

    if c.filtration == Filtration.VR:
        fil, longest_edges = oin.get_vr_filtration_and_critical_edges(
            pts_np, max_dim=c.max_dim, max_radius=c.vr_max_radius, n_threads=c.n_threads)
    elif c.filtration == Filtration.LS:
        fil, max_value_vertices = oin.get_freudenthal_filtration_and_critical_vertices(
            pts_np, negate=False, wrap=False, max_dim=c.max_dim, n_threads=c.n_threads)
    else:
        # Previously fell through and crashed later on an unbound ``fil``.
        raise NotImplementedError(f"unsupported filtration: {c.filtration!r}")

    top_opt = oin.TopologyOptimizer(fil)
    dgm = top_opt.compute_diagram(include_inf_points=False)
    current_dgm = dgm.in_dimension(c.dgm_dim)

    # The chosen strategy determines how the points of the persistence diagram
    # move; ``indices`` are the simplices to move and ``values`` their targets.
    template_dgm = None
    if c.use_template:
        stored_template = getattr(c, "template_dgm", None)
        if stored_template is None:
            raise ValueError("c.use_template is True but c.template_dgm is not set")
        indices, values = top_opt.match(
            stored_template, dim=c.dgm_dim, wasserstein_q=c.wass_q,
            return_wasserstein_distance=False)
    elif c.using_wgf:
        if c.wgf is None:
            raise ValueError("c.using_wgf is True but c.wgf is not set")
        current_measure = torch.tensor([(b, d) for (b, d) in current_dgm], dtype=torch.float64)
        next_measure = c.wgf.next_measure(current_measure)
        template_dgm = [oin.DiagramPoint_double(b, d) for (b, d) in next_measure]
        # Flip the flag only once the template really exists, so a failure
        # above cannot leave the config pointing at a missing template.
        c.template_dgm = template_dgm
        c.use_template = True
        indices, values = top_opt.match(
            template_dgm, dim=c.dgm_dim, wasserstein_q=c.wass_q,
            return_wasserstein_distance=False)
    else:
        indices, values = top_opt.simplify(c.eps_to_diagonal, c.DenoiseStrategy, c.dgm_dim)

    critical_sets = top_opt.singletons(indices, values)
    crit_indices, crit_values = top_opt.combine_loss(critical_sets, oin.ConflictStrategy.Max)

    if c.filtration == Filtration.VR:
        # Each critical simplex is represented by its longest edge; compare the
        # squared edge length (computed from ``pts``) with the squared target.
        crit_edges = longest_edges[crit_indices, :]
        edge_src, edge_dst = crit_edges[:, 0], crit_edges[:, 1]
        target_lengths = torch.tensor(crit_values, device=pts.device)
        sq_lengths = torch.sum((pts[edge_src, :] - pts[edge_dst, :]) ** 2, dim=1)
        loss = torch.sum(torch.abs(sq_lengths - target_lengths ** 2))
    else:
        # Filtration.LS (any other value was rejected above): each critical
        # simplex is represented by its max-value vertex in the flattened input.
        crit_indices = np.array(crit_indices, dtype=np.int32)
        crit_vertices = torch.LongTensor(max_value_vertices[crit_indices]).to(pts.device)
        target_values = torch.Tensor(crit_values).to(pts.device)
        flat_pts = pts.reshape(-1)
        loss = torch.sum((flat_pts[crit_vertices] - target_values) ** 2)

    return loss, current_dgm, template_dgm


# ========================================
# Other Losses


def repulsion_loss(points, epsilon=1e-6):
    """
    Inverse-squared-distance repulsion energy of a point set.

    Every ordered pair (i, j) with i != j contributes
    1 / (||p_i - p_j||^2 + epsilon), so each unordered pair is counted twice.

    Parameters:
    points: Tensor whose first axis indexes the points and whose last axis
            holds the coordinates, e.g. shape (N, 2)
    epsilon: Small constant added to the squared distance to avoid dividing
             by zero for coincident points

    Returns:
    Scalar tensor with the summed repulsion
    """
    n_points = points.shape[0]

    # Broadcast to all pairwise coordinate differences, then reduce the
    # coordinate axis to get the (N, N) matrix of squared distances.
    pairwise_delta = points[:, None] - points[None, :]
    sq_dist = (pairwise_delta ** 2).sum(dim=-1)

    # Self-pairs get an infinite distance so their contribution is exactly 0.
    self_pairs = torch.eye(n_points, device=points.device).bool()
    sq_dist = sq_dist.masked_fill(self_pairs, float('inf'))

    return (1.0 / (sq_dist + epsilon)).sum()