# -*- coding: utf-8 -*-
"""
PyTorch implementations matching the original NumPy-based utils for FGW.
Follows the same logic and API, but uses torch tensors and GPU when available.
"""

import os
from typing import List, Dict, Optional, Tuple
import networkx as nx
import numpy as np
import torch
import torch.nn.functional as F
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity  # kept for safe_cosine_similarity (exact behavior)
from dotenv import load_dotenv, find_dotenv
import time

from kt_gen.knowledge_graph.utils.pot_gpu.fused_gromov_sinkhorn import (
    fused_gromov_wasserstein_sinkhorn,
    fused_gromov_wasserstein2_sinkhorn,
)




# ==== Model path (unchanged)
# Import-time side effect: load environment variables from the .env file
# located by find_dotenv() before MODEL_PATH is read below.
load_dotenv(find_dotenv())
# MODEL_PATH selects the SentenceTransformer checkpoint used by default in
# compute_embeddings; because of the `or`, an unset *or empty* variable falls
# back to the MiniLM model identifier.
model_path = os.getenv("MODEL_PATH") or "sentence-transformers/all-MiniLM-L6-v2"


# ========== Embeddings ==========

def compute_embeddings(graphs: List[nx.Graph],
                       model: Optional[SentenceTransformer] = None,
                       device: Optional[torch.device] = None,
                       dtype: torch.dtype = torch.float64) -> List[torch.Tensor]:
    """
    Encode the "text" attribute of every node of every graph.

    Args:
        graphs: graphs whose nodes all carry a "text" attribute.
        model: encoder to use; when None a SentenceTransformer is loaded from
            the module-level ``model_path``.
        device: device handed to ``model.encode``; defaults to CUDA when
            available, CPU otherwise.
        dtype: dtype the encoder output is cast to (float64 by default).

    Returns:
        One tensor per graph, rows following the graph's node iteration order.
        Rows are requested L2-normalized from the encoder
        (``normalize_embeddings=True``).
    """
    encoder = model if model is not None else SentenceTransformer(model_path, trust_remote_code=False)

    if device is None:
        device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    def _encode_graph(graph: nx.Graph) -> torch.Tensor:
        # One sentence per node, in node iteration order.
        sentences = [attrs["text"] for _, attrs in graph.nodes(data=True)]
        encoded = encoder.encode(
            sentences,
            normalize_embeddings=True,   # encoder returns unit-norm rows
            convert_to_tensor=True,      # stay in torch, no numpy round-trip
            device=str(device),
        )
        # Cast after encoding so downstream FGW code sees the requested dtype.
        return encoded.to(dtype=dtype)

    return [_encode_graph(graph) for graph in graphs]


# ========== Distance normalizations (torch) ==========

def _ensure_tensor(x, dtype=torch.float64, device=None) -> torch.Tensor:
    """
    Return ``x`` as a torch tensor of the requested dtype.

    Tensors keep their current device unless ``device`` is given; any other
    input is converted with ``torch.as_tensor`` and moved to ``device`` only
    when one is specified.
    """
    if not isinstance(x, torch.Tensor):
        converted = torch.as_tensor(x, dtype=dtype)
        if device is None:
            return converted
        return converted.to(device)

    # Already a tensor: only adjust dtype / device.
    target_device = x.device if device is None else device
    return x.to(dtype=dtype, device=target_device)


def normalize_distance(D: torch.Tensor, factor: float = 1.5) -> torch.Tensor:
    """
    Replace infinite entries of a distance matrix by a finite penalty.

    The penalty is ``factor`` times the 0.95-quantile of the entries that are
    finite and strictly positive; when there is no such entry it is 1.0.
    The input is converted through ``_ensure_tensor`` and is not modified.
    For a square 2-D matrix the diagonal is reset to 0 afterwards.
    NaN entries are left untouched.
    """
    source = _ensure_tensor(D)

    usable = torch.isfinite(source) & (source > 0)
    if usable.any():
        fill = torch.quantile(source[usable], 0.95) * factor
    else:
        fill = torch.tensor(1.0, dtype=source.dtype, device=source.device)

    # torch.where builds a new tensor, so the caller's data is never touched.
    result = torch.where(torch.isinf(source), fill, source)

    # keep zeros on diagonal
    is_square_matrix = result.ndim == 2 and result.shape[0] == result.shape[1]
    if is_square_matrix:
        result.fill_diagonal_(0)
    return result


def normalize_distance_exp(D: torch.Tensor, factor: float = 1.5, sim_power: float = 1.0) -> torch.Tensor:
    """
    Replace infinite and non-positive entries by the constant ``exp(2) * factor``.

    The input is converted through ``_ensure_tensor`` and is not modified.
    For a square 2-D matrix the diagonal is reset to 0 afterwards. NaN entries
    are left untouched. ``sim_power`` is accepted but not used by this function.
    """
    source = _ensure_tensor(D)
    fill = torch.exp(torch.tensor(2.0, dtype=source.dtype, device=source.device)) * factor

    # Infinite values and values <= 0 all receive the same penalty.
    needs_penalty = torch.isinf(source) | (source <= 0)
    result = torch.where(needs_penalty, fill, source)

    is_square_matrix = result.ndim == 2 and result.shape[0] == result.shape[1]
    if is_square_matrix:
        result.fill_diagonal_(0.0)
    return result


# ========== Structure distances (torch) ==========

def compute_structure_distances(graphs: List[nx.Graph], factor: float = 1.5) -> List[torch.Tensor]:
    """
    All-pairs shortest-path matrices, one per graph.

    Each matrix comes from ``nx.floyd_warshall_numpy`` (float64), is wrapped
    as a torch tensor and passed through ``normalize_distance`` with the given
    ``factor`` so that unreachable pairs get a finite penalty.
    """
    def _shortest_paths(graph: nx.Graph) -> torch.Tensor:
        raw = np.array(nx.floyd_warshall_numpy(graph), dtype=np.float64)
        return normalize_distance(torch.from_numpy(raw), factor=factor)

    return [_shortest_paths(graph) for graph in graphs]


def compute_similarity_weighted_structure_distances(
    graphs: List[nx.Graph],
    embeddings: List[torch.Tensor],
    factor: float = 1.5,
    similarity_power: float = 1.0
) -> List[torch.Tensor]:
    """
    Build one similarity-weighted structure matrix per graph.

    For every edge (u, v) the weight is
    ``1 / ((exp(sim) / e) ** similarity_power + 1e-6)`` where ``sim`` is the
    cosine similarity of the two node embeddings. The weighted *adjacency*
    matrix (not shortest paths) is then passed through
    ``normalize_distance_exp``.

    Args:
        graphs: input graphs; directedness is preserved.
        embeddings: one embedding matrix per graph, rows in node iteration order.
        factor: penalty factor forwarded to ``normalize_distance_exp``.
        similarity_power: exponent applied to the transformed similarity.

    Returns:
        One float64 tensor per graph.
    """
    results: List[torch.Tensor] = []

    for graph, node_emb in zip(graphs, embeddings):
        node_emb = _ensure_tensor(node_emb)

        # Same node set (with attributes) as the source graph, edges re-weighted below.
        weighted = nx.DiGraph() if graph.is_directed() else nx.Graph()
        weighted.add_nodes_from(graph.nodes(data=True))

        # Row of each node inside the embedding matrix.
        row_of = {node: position for position, node in enumerate(graph.nodes)}

        for src, dst in graph.edges():
            src_vec = node_emb[row_of[src]].unsqueeze(0)
            dst_vec = node_emb[row_of[dst]].unsqueeze(0)
            sim = F.cosine_similarity(src_vec, dst_vec).item()
            # exp(sim) / e maps a similarity of 1 to exactly 1.
            scaled_sim = np.exp(sim) / np.exp(1)
            # The 1e-6 term guards the division.
            edge_weight = 1.0 / ((scaled_sim ** similarity_power) + 1e-6)
            weighted.add_edge(src, dst, weight=float(edge_weight))

        adjacency = nx.adjacency_matrix(weighted, weight='weight').toarray().astype(np.float64)
        results.append(
            normalize_distance_exp(torch.from_numpy(adjacency), factor=factor, sim_power=similarity_power)
        )

    return results


# ========== Cosine helper (unchanged semantics) ==========

def safe_cosine_similarity(vec1, vec2) -> float:
    """
    Cosine similarity of two vectors with a fallback for zero vectors.

    When either vector has zero norm the constant ``1 / exp(2)`` is returned
    instead of calling sklearn's ``cosine_similarity``.
    """
    first, second = np.asarray(vec1), np.asarray(vec2)
    norms = (np.linalg.norm(first), np.linalg.norm(second))
    if any(norm == 0 for norm in norms):
        # Cosine is undefined here; use the fixed fallback value.
        return 1 / np.exp(2)
    similarity_matrix = cosine_similarity([first], [second])
    return float(similarity_matrix[0, 0])


# ========== FGW / GW using Sinkhorn (torch) ==========

# The stabilized Sinkhorn FGW solvers used below
# (fused_gromov_wasserstein_sinkhorn / fused_gromov_wasserstein2_sinkhorn)
# are imported at the top of this file.

def compute_fgw(
    F1: torch.Tensor,
    D1: torch.Tensor,
    F2: torch.Tensor,
    D2: torch.Tensor,
    alpha: float = 0.5,
    *,
    eps: float = 2e-2,
    sinkhorn_max_iter: int = 1000,
    sinkhorn_tol: float = 1e-9
) -> Tuple[float, torch.Tensor]:
    """
    Fused Gromov-Wasserstein between two attributed graphs (Sinkhorn solver).

    Args:
        F1, F2: node feature matrices (one row per node).
        D1, D2: structure matrices of the two graphs.
        alpha: trade-off forwarded to the FGW solvers.
        eps, sinkhorn_max_iter, sinkhorn_tol: forwarded to the FGW solvers.

    Returns:
        ``(dist, T)``: the FGW value as a Python float and the transport plan.

    Notes:
        - All inputs are cast to float64 on the device of ``F1``.
        - Node weights are uniform on both sides.
        - NaNs in D1 / D2 become ``exp(2) * 1.5``; ``torch.nan_to_num`` is
          called with its default handling of infinities.
        - The feature cost is the pairwise Euclidean distance (``torch.cdist``).
        - The plan and the distance come from two separate solver calls with
          identical arguments; the outer iteration cap is fixed at 200.
    """
    dtype = torch.float64
    device = F1.device

    F1, F2, D1, D2 = (
        _ensure_tensor(tensor, dtype=dtype, device=device)
        for tensor in (F1, F2, D1, D2)
    )

    # Uniform marginals over the nodes of each graph.
    ns, nt = F1.shape[0], F2.shape[0]
    p = torch.full((ns,), 1.0/ns, dtype=dtype, device=device)
    q = torch.full((nt,), 1.0/nt, dtype=dtype, device=device)

    # NaN handling (same constants as the structure-distance penalty).
    factor = 1.5
    nan_fill = (torch.exp(torch.tensor(2.0, dtype=dtype, device=device)) * factor).item()
    D1 = torch.nan_to_num(D1, nan=nan_fill)
    D2 = torch.nan_to_num(D2, nan=nan_fill)

    # Feature cost M
    M = torch.cdist(F1, F2, p=2).to(dtype)

    # Both solver calls share exactly the same configuration.
    solver_kwargs = dict(
        alpha=alpha,
        eps=eps,
        numItermax=200,
        sinkhorn_max_iter=sinkhorn_max_iter,
        sinkhorn_tol=sinkhorn_tol,
        log=False,
    )
    T = fused_gromov_wasserstein_sinkhorn(M, D1, D2, p, q, **solver_kwargs)
    dist = fused_gromov_wasserstein2_sinkhorn(M, D1, D2, p, q, **solver_kwargs)
    return float(dist.item()), T


def compute_gw(D1: torch.Tensor, D2: torch.Tensor) -> Tuple[float, torch.Tensor]:
    """
    Gromov-Wasserstein between two structure matrices with uniform node weights.

    Implemented through the FGW Sinkhorn solvers with an all-zero feature cost
    and ``alpha=1.0`` (``eps`` fixed at 2e-2). Inputs are cast to float64 on
    the device of ``D1``.

    Returns:
        ``(gw_dist, T)``: the distance as a Python float and the transport plan.
    """
    dtype = torch.float64
    device = D1.device

    D1, D2 = (_ensure_tensor(matrix, dtype=dtype, device=device) for matrix in (D1, D2))
    ns, nt = D1.shape[0], D2.shape[0]

    # Uniform marginals and a zero feature cost: only structure contributes.
    p = torch.full((ns,), 1.0/ns, dtype=dtype, device=device)
    q = torch.full((nt,), 1.0/nt, dtype=dtype, device=device)
    M = torch.zeros((ns, nt), dtype=dtype, device=device)

    # Plan + distance, computed with identical solver settings.
    solver_kwargs = dict(alpha=1.0, eps=2e-2, log=False)
    T = fused_gromov_wasserstein_sinkhorn(M, D1, D2, p, q, **solver_kwargs)
    dist = fused_gromov_wasserstein2_sinkhorn(M, D1, D2, p, q, **solver_kwargs)
    return float(dist.item()), T


# ========== Graph helpers (unchanged behavior) ==========

def assign_depths(G: nx.DiGraph):
    """
    Store a "depth" attribute on every node of a DAG, in place.

    A node without predecessors has depth 0; any other node has depth
    1 + the maximum depth of its predecessors.

    Raises:
        ValueError: if ``G`` is not a directed acyclic graph.
    """
    if not nx.is_directed_acyclic_graph(G):
        raise ValueError("Le graphe doit être un DAG pour que la profondeur soit bien définie.")

    depth_of = {}
    # Topological order guarantees every predecessor is processed first.
    for node in nx.topological_sort(G):
        parent_depths = [depth_of[parent] for parent in G.predecessors(node)]
        level = 1 + max(parent_depths) if parent_depths else 0
        depth_of[node] = level
        G.nodes[node]["depth"] = level


def compute_node_features(G, embeddings: Dict[str, np.ndarray]) -> torch.Tensor:
    """
    Stack the embeddings of the nodes of ``G`` into one float64 tensor.

    Rows follow the iteration order of ``G.nodes``; every node must be a key
    of ``embeddings``.
    """
    ordered_vectors = [embeddings[node] for node in G.nodes]
    stacked = np.array(ordered_vectors, dtype=np.float64)
    return torch.from_numpy(stacked)


def compute_structure_matrix(G) -> torch.Tensor:
    """
    Return the raw Floyd-Warshall distance matrix of ``G`` as a float64 tensor.

    Unlike ``compute_structure_distances`` no normalization is applied here.
    """
    shortest_paths = nx.floyd_warshall_numpy(G)
    as_float64 = np.array(shortest_paths, dtype=np.float64)
    return torch.from_numpy(as_float64)


# ========== Fusion helpers used by genetic optimized traversal (torch) ==========

def _fuse_text_compact(node_ids: List[str], G: nx.DiGraph, compact_fusion: bool = True) -> str:
    """
    Merge the "text" attributes of several nodes into one string.

    Without ``compact_fusion`` the texts are joined with " | " in the given
    order. With it, nodes are grouped by the second "_"-separated field of
    ids starting with "section_" (others go to the "unknown" group), groups
    are ordered numerically, and items inside a group are ordered by the 4th
    and 6th "_"-separated fields of their id. Fields that are missing or not
    integers sort last.
    """
    # Same behavior as _create_fused_text_compact in kg_fgw.py; kept here for test_fusion
    sort_last = 10**9

    def _as_rank(value):
        try:
            return int(value)
        except Exception:
            return sort_last

    def _id_field_rank(node_id, position):
        fields = node_id.split("_")
        if len(fields) > position:
            return _as_rank(fields[position])
        return sort_last

    if not compact_fusion:
        return " | ".join(G.nodes[node_id]["text"] for node_id in node_ids)

    # Group (node_id, text) pairs per section, preserving first-seen order.
    grouped = {}
    for node_id in node_ids:
        text = G.nodes[node_id]["text"]
        section = node_id.split("_")[1] if node_id.startswith("section_") else "unknown"
        if section not in grouped:
            grouped[section] = []
        grouped[section].append((node_id, text))

    chunks = []
    for section in sorted(grouped, key=_as_rank):
        ordered = sorted(
            grouped[section],
            key=lambda entry: (_id_field_rank(entry[0], 3), _id_field_rank(entry[0], 5)),
        )
        joined = " ".join(text for _, text in ordered)
        chunks.append(f"From section_{section}: {joined}")
    return "\n----\n".join(chunks)


def _delete_rows_cols(M: torch.Tensor, indices_desc: List[int]) -> torch.Tensor:
    """
    Return a copy of the square matrix ``M`` without the rows and columns
    listed in ``indices_desc`` (the torch counterpart of calling ``np.delete``
    once per axis). ``M`` itself is left unchanged.
    """
    survivors = torch.ones(M.shape[0], dtype=torch.bool, device=M.device)
    survivors[indices_desc] = False
    # Drop rows first, then the matching columns.
    without_rows = M[survivors]
    return without_rows[:, survivors]


def test_fusion(tmp_set: List[str], candidate_set: List[str], G: nx.DiGraph, G_fus: nx.DiGraph,
                map_index_node: Dict[str, int],
                embedding_current_graph: List[torch.Tensor], dists_current_graph: List[torch.Tensor],
                model: SentenceTransformer, current: str, depth: int, uuid_fus: str,
                initial_prob: float, alpha: float, structure_distance: str,
                compact_fusion: bool = False,
                eps: float = 2e-2, sinkhorn_max_iter: int = 1000, sinkhorn_tol: float = 1e-9
               ) -> Tuple[float, List[torch.Tensor], List[torch.Tensor], nx.DiGraph, Dict[str, int]]:
    """
    Simulate fusing the nodes of ``tmp_set`` into one node and score the result.

    Works on copies only: ``G_fus``, ``map_index_node`` and the tensors held in
    ``embedding_current_graph`` / ``dists_current_graph`` are not modified.

    Args:
        tmp_set: ids of the nodes to fuse (children of ``current`` in ``G``).
        candidate_set: unused here; kept for interface compatibility.
        G: original graph, source of node texts and edge attributes.
        G_fus: current fused graph; a copy receives the new fused node.
        map_index_node: node id -> row index in the embedding / distance matrices.
        embedding_current_graph: single-element list holding the embedding matrix.
        dists_current_graph: single-element list holding the structure matrix.
        model: encoder used to embed the fused text.
        current, depth, uuid_fus: build the fused id ``fused_{current}_{depth}_{uuid_fus}``.
        initial_prob: default "prob" for edges lacking that attribute.
        alpha, eps, sinkhorn_max_iter, sinkhorn_tol: forwarded to ``compute_fgw``.
        structure_distance: ``"similarity_weighted"`` keeps the incrementally
            updated matrix; any other value recomputes shortest paths on the
            fused graph.
        compact_fusion: forwarded to ``_fuse_text_compact``.

    Returns:
        ``(fgw_dist, embeddings, dists, fused_graph, index_map)`` where
        ``fgw_dist`` is the FGW value between the current graph and the
        candidate fused graph; the two lists each hold one tensor.

    NOTE(review): entries of the fused-away nodes stay in the returned index
    map (with shifted indices) and the fused node is not added to it; when the
    fused node already exists only its text is refreshed. Callers presumably
    account for this -- confirm before relying on the map.
    """
    new_embeddings_fus = [embedding_current_graph[0].clone()]
    new_distance_struct = [dists_current_graph[0].clone()]
    G_fus_copy = G_fus.copy()
    map_index_node_copy = dict(map_index_node)

    fused_text = _fuse_text_compact(tmp_set, G, compact_fusion)
    tmp_node_id = f"fused_{current}_{depth}_{uuid_fus}"

    # remove nodes in tmp_set (highest index first)
    indices_to_remove = sorted([map_index_node_copy[n] for n in tmp_set], reverse=True)
    if indices_to_remove:
        # Drop the embedding rows with a boolean mask instead of rebuilding a
        # Python set for every row index.
        current_emb = new_embeddings_fus[0]
        keep_rows = torch.ones(current_emb.shape[0], dtype=torch.bool, device=current_emb.device)
        keep_rows[indices_to_remove] = False
        new_embeddings_fus[0] = current_emb[keep_rows]

        new_distance_struct[0] = _delete_rows_cols(new_distance_struct[0], indices_to_remove)

        # Shift every index located after a removed row.
        for n2, idx in list(map_index_node_copy.items()):
            for cut in indices_to_remove:
                if idx > cut:
                    idx -= 1
            map_index_node_copy[n2] = idx

    # add fused node in graph and embeddings
    if tmp_node_id in G_fus_copy.nodes:
        G_fus_copy.nodes[tmp_node_id]["text"] = fused_text
    else:
        G_fus_copy.add_node(tmp_node_id, text=fused_text, type="fused")

        # Encode the fused text and bring it onto the device of the existing
        # embeddings so torch.cat cannot fail on a device mismatch.
        new_embedding = model.encode(
            [fused_text], normalize_embeddings=True, convert_to_tensor=True
        ).to(dtype=torch.float64, device=new_embeddings_fus[0].device)
        new_embeddings_fus[0] = torch.cat([new_embeddings_fus[0], new_embedding], dim=0)

        # Extend the distance matrix with one penalty row/column for the fused
        # node. The penalty is passed to torch.full as a plain float.
        dist = new_distance_struct[0]
        pen = float(torch.exp(torch.tensor(2.0, dtype=dist.dtype, device=dist.device)) * 1.5)
        num_nodes = dist.shape[0]
        col = torch.full((num_nodes, 1), pen, dtype=dist.dtype, device=dist.device)
        row = torch.full((1, num_nodes + 1), pen, dtype=dist.dtype, device=dist.device)
        dist = torch.cat([dist, col], dim=1)
        dist = torch.cat([dist, row], dim=0)
        dist[-1, -1] = 0.0
        new_distance_struct[0] = dist

    # add edges for fused node: it inherits the outgoing edges of its members
    for n in tmp_set:
        for gc in G.successors(n):
            edge_attrs = G.edges[n, gc]
            edge_structure_distance = edge_attrs.get("structure_distance", 1.0)
            G_fus_copy.add_edge(
                tmp_node_id, gc,
                prob=edge_attrs.get("prob", initial_prob),
                structure_distance=edge_structure_distance
            )
            if gc in map_index_node_copy:
                # The fused node occupies the last row of the matrix.
                new_distance_struct[0][-1, map_index_node_copy[gc]] = edge_structure_distance

    # edge from current to fused node: average of the replaced edges
    if tmp_set:
        replaced_edges = [G.edges[current, gc] for gc in tmp_set]
        new_prob = float(np.mean([attrs.get("prob", initial_prob) for attrs in replaced_edges]))
        new_structure_distance = float(
            np.mean([attrs.get("structure_distance", 1.0) for attrs in replaced_edges])
        )
    else:
        new_prob = initial_prob
        new_structure_distance = 1.0
    G_fus_copy.add_edge(current, tmp_node_id, prob=new_prob, structure_distance=new_structure_distance)

    if current in map_index_node_copy:
        new_distance_struct[0][map_index_node_copy[current], -1] = new_structure_distance

    # remove old nodes
    for child_id in tmp_set:
        if G_fus_copy.has_edge(current, child_id):
            G_fus_copy.remove_edge(current, child_id)
        if G_fus_copy.has_node(child_id):
            G_fus_copy.remove_node(child_id)

    # structure distances
    if structure_distance != "similarity_weighted":
        dists = compute_structure_distances([G_fus_copy])
    else:
        dists = new_distance_struct  # keep incremental matrix

    # FGW distance between the current graph and the candidate fused graph
    time_fgw = time.time()
    fgw_dist, _ = compute_fgw(
        embedding_current_graph[0], dists_current_graph[0],
        new_embeddings_fus[0], dists[0], alpha=alpha,
        eps=eps, sinkhorn_max_iter=sinkhorn_max_iter, sinkhorn_tol=sinkhorn_tol
    )

    print(f"[INFO] Time for the fgw computation : {time.time() - time_fgw:.4f} seconds")

    return fgw_dist, new_embeddings_fus, dists, G_fus_copy, map_index_node_copy


def dichotomie_fusion(sorted_children_proba, candidate_set, candidate_vecs,
                      G, G_fus, map_index_node,
                      embedding_current_graph, dists_current_graph,
                      model, current, depth, uuid_fus,
                      initial_prob, alpha, structure_distance,
                      distance_restante, compact_fusion=False, enable_logging=False,
                      eps: float = 2e-2, sinkhorn_max_iter: int = 1000, sinkhorn_tol: float = 1e-9):
    """
    Binary search for the largest prefix of children that can be fused within budget.

    For a prefix length ``k`` of ``sorted_children_proba`` (pairs
    ``(child_id, score)``), the first ``k`` children are appended to
    ``candidate_set`` and evaluated with ``test_fusion``. A probe is accepted
    when its FGW distance is ``<= distance_restante``; the search then tries a
    longer prefix, otherwise a shorter one. The search presumes the FGW
    distance does not decrease as more children are fused.

    Args:
        sorted_children_proba: ``(child_id, score)`` pairs in priority order.
        candidate_set, candidate_vecs: ids / scores already selected (lists).
        distance_restante: remaining FGW budget.
        enable_logging: print per-probe and summary information.
        Remaining arguments are forwarded unchanged to ``test_fusion``.

    Returns:
        ``(candidate_set, candidate_vecs, embeddings, dists, fused_graph,
        index_map)`` of the best accepted probe. When no probe is accepted
        (or there is nothing to probe) the inputs are returned, with
        ``candidate_vecs`` and ``map_index_node`` shallow-copied.
    """
    left, right = 1, len(sorted_children_proba)

    # Fallback result: nothing fused.
    best_candidate_set = candidate_set
    best_candidate_vecs = candidate_vecs[:]
    best_embedding = embedding_current_graph
    best_distance = dists_current_graph
    best_G_fus = G_fus
    best_map_index_node = dict(map_index_node)

    # Distance of the most recent probe; stays None when the loop never runs
    # (empty sorted_children_proba), which previously crashed the summary log.
    last_fgw_dist = None

    while left <= right:
        mid = (left + right) // 2
        prefix = sorted_children_proba[:mid]
        tmp_set = candidate_set + [cid for cid, _ in prefix]
        tmp_vecs = candidate_vecs + [s for _, s in prefix]

        fgw_dist, new_emb, new_dist, new_G_fus, map_index_node_copy = test_fusion(
            tmp_set, candidate_set, G, G_fus, map_index_node,
            embedding_current_graph, dists_current_graph,
            model, current, depth, uuid_fus,
            initial_prob, alpha, structure_distance, compact_fusion,
            eps=eps, sinkhorn_max_iter=sinkhorn_max_iter, sinkhorn_tol=sinkhorn_tol
        )
        last_fgw_dist = fgw_dist
        if enable_logging:
            print(f"[DEBUG] mid={mid}, fgw_dist={fgw_dist:.4f}, distance_restante={distance_restante:.4f}")

        if fgw_dist <= distance_restante:
            # Within budget: remember this probe and try fusing more children.
            best_candidate_set = tmp_set
            best_candidate_vecs = tmp_vecs
            best_embedding = new_emb
            best_distance = new_dist
            best_G_fus = new_G_fus
            best_map_index_node = map_index_node_copy
            left = mid + 1
        else:
            right = mid - 1

    if enable_logging:
        # Reports the distance of the last probe, which is not necessarily
        # the accepted one.
        if last_fgw_dist is None:
            print("[INFO] FGW distance: n/a (no fusion evaluated)")
        else:
            print(f"[INFO] FGW distance: {last_fgw_dist:.4f}")
        print(f"[INFO] Optimal set containing: {best_candidate_set} (FGW<=budget)")
        print(f"[INFO] Total candidate set: {sorted_children_proba}")

    return best_candidate_set, best_candidate_vecs, best_embedding, best_distance, best_G_fus, best_map_index_node
