"""
Gene-Relation Clustering (GRC)

Balanced gene clustering combining:
- Semantic embeddings (GeneCompass)
- PPI priors (STRING)
- Optimal transport for balanced assignment
"""

import os
import pickle
import numpy as np
import pandas as pd
from scipy.spatial.distance import cdist
from typing import Dict, List, Tuple, Optional
from collections import defaultdict

# POT (imported as `ot`) is treated as an optional dependency: the import is
# attempted once at module load and HAS_OT records whether it succeeded, so
# code that needs optimal transport can check the flag and fail with a clear
# error instead of an ImportError at import time.
try:
    import ot
    HAS_OT = True
except ImportError:
    HAS_OT = False
    print("Warning: POT not installed. pip install POT")


def load_gene_embeddings(embedding_path: str,
                         gene_names: List[str],
                         fill_missing: bool = True,
                         random_seed: int = 42) -> Tuple[np.ndarray, float]:
    """Load per-gene embedding vectors (GeneCompass or a similar source).

    The pickle at ``embedding_path`` must be a dict holding the gene -> vector
    mapping under either ``'gene_symbol_embeddings'`` or ``'gene_embeddings'``.

    Args:
        embedding_path: Path to the pickled embedding dictionary.
        gene_names: Genes to look up; output rows follow this order.
        fill_missing: If True, genes absent from the mapping get a small
            random vector (standard normal scaled by 0.01). If False, any
            missing gene raises ``ValueError``.
        random_seed: Seed for the filler vectors.

    Returns:
        ``(embeddings, coverage)`` where ``embeddings`` is a float32 array
        with one row per entry of ``gene_names`` and ``coverage`` is the
        fraction of genes found in the mapping (0.0 for an empty gene list).

    Raises:
        ValueError: If the pickle has neither expected key, or if genes are
            missing while ``fill_missing`` is False.
    """
    # NOTE(security): pickle.load can execute arbitrary code; only point this
    # at embedding files from a trusted source.
    with open(embedding_path, 'rb') as f:
        data = pickle.load(f)

    if 'gene_symbol_embeddings' in data:
        gene_emb_dict = data['gene_symbol_embeddings']
    elif 'gene_embeddings' in data:
        gene_emb_dict = data['gene_embeddings']
    else:
        raise ValueError(f"Unknown format. Keys: {data.keys()}")

    vectors = [
        np.asarray(gene_emb_dict[gene]) if gene in gene_emb_dict else None
        for gene in gene_names
    ]
    missing = [i for i, vec in enumerate(vectors) if vec is None]
    n_genes = len(vectors)

    # Guard the empty gene list, which previously raised ZeroDivisionError.
    coverage = (n_genes - len(missing)) / n_genes if n_genes else 0.0

    # Prefer the width of the vectors actually present: filler rows must
    # stack with them, and the stored/default 'embedding_dim' may disagree.
    reference = next((vec for vec in vectors if vec is not None), None)
    if reference is not None and reference.ndim >= 1:
        embed_dim = int(reference.shape[-1])
    else:
        embed_dim = int(data.get('embedding_dim', 768))

    if missing:
        if not fill_missing:
            preview = ", ".join(str(gene_names[i]) for i in missing[:5])
            raise ValueError(
                f"{len(missing)} gene(s) have no embedding (e.g. {preview}); "
                "pass fill_missing=True to substitute random vectors"
            )
        # A private RandomState yields the same numbers as seeding the global
        # generator, without clobbering the caller's global RNG state.
        rng = np.random.RandomState(random_seed)
        for i in missing:
            vectors[i] = rng.randn(embed_dim) * 0.01

    if not vectors:
        return np.zeros((0, embed_dim), dtype=np.float32), coverage

    embeddings = np.array(vectors, dtype=np.float32)
    return embeddings, coverage


def load_ppi_network(ppi_path: str,
                     protein_info_path: str,
                     gene_names: List[str],
                     min_score: int = 700) -> np.ndarray:
    """Load STRING PPI links and build a symmetric gene-gene adjacency matrix.

    Args:
        ppi_path: Space-separated links file with columns ``protein1``,
            ``protein2`` and ``combined_score``.
        protein_info_path: Tab-separated file with columns
            ``#string_protein_id`` and ``preferred_name``, used to translate
            protein ids into gene names.
        gene_names: Genes defining the row/column order of the matrix.
        min_score: Links with ``combined_score`` below this are dropped.

    Returns:
        A float32 ``(n_genes, n_genes)`` array whose ``[i, j]`` and ``[j, i]``
        entries hold ``combined_score / 1000`` for every retained link whose
        two endpoints both map to a gene in ``gene_names``; all other entries
        are 0. If a pair appears more than once, the last row in the file wins.
    """
    protein_info = pd.read_csv(protein_info_path, sep='\t')
    protein_to_gene = dict(zip(
        protein_info['#string_protein_id'],
        protein_info['preferred_name']
    ))

    gene_to_idx = {gene: idx for idx, gene in enumerate(gene_names)}
    n_genes = len(gene_names)
    ppi_adj = np.zeros((n_genes, n_genes), dtype=np.float32)

    ppi_df = pd.read_csv(ppi_path, sep=' ')
    ppi_df = ppi_df[ppi_df['combined_score'] >= min_score]

    # Collapse protein -> gene -> index into one lookup so each endpoint
    # column can be translated with a single vectorised map.
    protein_to_idx = {
        protein: gene_to_idx[gene]
        for protein, gene in protein_to_gene.items()
        if gene in gene_to_idx
    }
    if ppi_df.empty or not protein_to_idx:
        return ppi_adj

    idx1 = ppi_df['protein1'].map(protein_to_idx)
    idx2 = ppi_df['protein2'].map(protein_to_idx)
    keep = (idx1.notna() & idx2.notna()).to_numpy()
    if not keep.any():
        return ppi_adj

    i = idx1.to_numpy()[keep].astype(np.intp)
    j = idx2.to_numpy()[keep].astype(np.intp)
    scores = ppi_df['combined_score'].to_numpy()[keep] / 1000.0

    # NumPy does not define which write wins when fancy-index assignment hits
    # the same cell twice, so deduplicate unordered pairs first. keep='last'
    # reproduces the "later row overwrites earlier row" rule of a row loop.
    edges = pd.DataFrame({
        'lo': np.minimum(i, j),
        'hi': np.maximum(i, j),
        'score': scores,
    }).drop_duplicates(subset=['lo', 'hi'], keep='last')

    lo = edges['lo'].to_numpy()
    hi = edges['hi'].to_numpy()
    weight = edges['score'].to_numpy()
    ppi_adj[lo, hi] = weight
    ppi_adj[hi, lo] = weight

    return ppi_adj


def compute_ot_cost_matrix(embeddings: np.ndarray,
                           centroids: np.ndarray,
                           ppi_adj: np.ndarray,
                           labels: np.ndarray) -> np.ndarray:
    """
    Build the gene-by-cluster cost used for the transport step.

    The cost is the sum of two terms:
    - squared Euclidean distance from each gene to each centroid, divided
      by the largest such distance;
    - a PPI term: for each non-empty cluster, a gene's total PPI weight to
      genes outside the cluster minus its weight to genes inside it,
      min-max rescaled over the whole matrix when its range exceeds 1e-8.
      Columns of clusters with no current members stay at 0 before rescaling.
    """
    gene_count, cluster_count = embeddings.shape[0], centroids.shape[0]

    # Semantic term.
    embed_cost = cdist(embeddings, centroids, metric='sqeuclidean')
    embed_cost = embed_cost / (embed_cost.max() + 1e-8)

    # Interaction term, filled one cluster column at a time.
    ppi_cost = np.zeros((gene_count, cluster_count), dtype=np.float32)
    for cluster_id in range(cluster_count):
        members = (labels == cluster_id)
        if not np.sum(members) > 0:
            continue
        weight_outside = ppi_adj[:, ~members].sum(axis=1)
        weight_inside = ppi_adj[:, members].sum(axis=1)
        ppi_cost[:, cluster_id] = weight_outside - weight_inside

    highest = ppi_cost.max()
    lowest = ppi_cost.min()
    spread = highest - lowest
    if spread > 1e-8:
        ppi_cost = (ppi_cost - lowest) / (spread + 1e-8)

    return embed_cost + ppi_cost


def balanced_ot_clustering(embeddings: np.ndarray,
                           ppi_adj: np.ndarray,
                           n_clusters: int,
                           n_iterations: int = 50,
                           ot_reg: float = 0.05,
                           random_seed: int = 42,
                           verbose: bool = True) -> Tuple[np.ndarray, np.ndarray]:
    """Cluster genes by alternating an OT assignment step and a centroid update.

    Clusters are initialised with scikit-learn KMeans. Each iteration then
    builds the dual-prior cost, solves a transport problem between uniform
    gene and cluster marginals, hard-assigns every gene to the cluster that
    receives most of its mass, and recomputes centroids as member means.

    Args:
        embeddings: Gene embedding matrix, one row per gene.
        ppi_adj: Gene-gene PPI weight matrix aligned with ``embeddings``.
        n_clusters: Number of clusters.
        n_iterations: Maximum number of assignment/update rounds.
        ot_reg: Entropic regularisation strength passed to ``ot.sinkhorn``.
        random_seed: Seed for the KMeans initialisation.
        verbose: If True, print progress every 5 iterations and on convergence.

    Returns:
        ``(labels, centroids)``: the cluster index per gene and the final
        centroid matrix.

    Raises:
        RuntimeError: If the POT library is not installed.
    """
    if not HAS_OT:
        raise RuntimeError("POT library required")

    n_genes = embeddings.shape[0]

    # init with kmeans++
    from sklearn.cluster import KMeans
    kmeans = KMeans(n_clusters=n_clusters, n_init=10, random_state=random_seed)
    kmeans.fit(embeddings)
    centroids = kmeans.cluster_centers_.copy()
    labels = kmeans.labels_.copy()

    # Uniform marginals: every gene carries equal mass and every cluster is
    # asked to receive an equal share, which is what pushes towards balance.
    a = np.ones(n_genes) / n_genes
    b = np.ones(n_clusters) / n_clusters

    for t in range(n_iterations):
        cost = compute_ot_cost_matrix(
            embeddings, centroids, ppi_adj, labels
        )

        try:
            transport_plan = ot.sinkhorn(
                a, b, cost, ot_reg,
                numItermax=1000, stopThr=1e-9, warn=False
            )
        except Exception:
            # Best-effort fallback below; a bare `except:` would also have
            # swallowed KeyboardInterrupt / SystemExit.
            transport_plan = None

        if transport_plan is not None:
            transport_plan = np.asarray(transport_plan)
            # A non-finite plan, or a gene row with no mass, makes the argmax
            # below meaningless (it would silently pick cluster 0).
            degenerate = (
                not np.all(np.isfinite(transport_plan))
                or bool(np.any(transport_plan.sum(axis=1) <= 0))
            )
            if degenerate:
                transport_plan = None

        if transport_plan is None:
            # Exact (unregularised) solver as the fallback.
            transport_plan = ot.emd(a, b, cost)

        # Hard assignment: cluster receiving the most mass from each gene.
        new_labels = np.argmax(transport_plan, axis=1)
        changed = np.sum(new_labels != labels)
        labels = new_labels

        # Clusters that ended up empty keep their previous centroid.
        for j in range(n_clusters):
            mask = labels == j
            if np.sum(mask) > 0:
                centroids[j] = embeddings[mask].mean(axis=0)

        if verbose and (t + 1) % 5 == 0:
            sizes = np.bincount(labels, minlength=n_clusters)
            print(f"  Iter {t+1}: changed={changed}, sizes=[{sizes.min()}-{sizes.max()}]")

        if changed == 0:
            if verbose:
                print(f"  Converged at iter {t+1}")
            break

    return labels, centroids


def reorder_genes_by_cluster(gene_names: List[str],
                             labels: np.ndarray) -> Tuple[List[str], np.ndarray]:
    """Reorder genes so same-cluster genes are adjacent.

    Args:
        gene_names: Gene names, aligned with ``labels``.
        labels: Cluster index of each gene.

    Returns:
        ``(sorted_gene_names, sorted_labels)`` ordered by ascending cluster
        index. Genes within a cluster keep their original relative order.
        Empty input yields an empty list and an empty array.
    """
    labels = np.asarray(labels)
    # A stable sort keeps the within-cluster order deterministic; the default
    # argsort algorithm makes no such guarantee for tied labels.
    gene_order = np.argsort(labels, kind='stable')
    sorted_gene_names = [gene_names[i] for i in gene_order]
    sorted_labels = labels[gene_order]
    return sorted_gene_names, sorted_labels


class GeneRelationClustering:
    """
    GRC: balanced gene clustering with semantic + PPI priors.

    Typical flow: load the priors, fit, then read or persist the labels.

    Usage:
        grc = GeneRelationClustering(n_clusters=32)
        grc.load_priors(emb_path, ppi_path, info_path, genes)
        labels = grc.fit()
        sorted_genes, sorted_labels = grc.reorder_genes()
    """

    # Hyper-parameters persisted by save_results() besides n_clusters.
    _SAVED_HYPERPARAMS = ('n_iterations', 'ot_reg', 'ppi_min_score', 'random_seed')

    def __init__(self,
                 n_clusters: int = 32,
                 n_iterations: int = 50,
                 ot_reg: float = 0.05,
                 ppi_min_score: int = 700,
                 random_seed: int = 42):
        """Store hyper-parameters; no data is loaded here.

        Args:
            n_clusters: Number of clusters.
            n_iterations: Maximum number of clustering iterations.
            ot_reg: Regularisation strength forwarded to the OT solver.
            ppi_min_score: Minimum PPI combined score kept when loading links.
            random_seed: Seed used for missing-embedding filler and clustering.
        """
        self.n_clusters = n_clusters
        self.n_iterations = n_iterations
        self.ot_reg = ot_reg
        self.ppi_min_score = ppi_min_score
        self.random_seed = random_seed

        # Populated by load_priors().
        self.embeddings = None
        self.ppi_adj = None
        self.gene_names = None
        # Populated by fit() or load_results().
        self.labels = None
        self.centroids = None

    def load_priors(self,
                    embedding_path: str,
                    ppi_path: str,
                    protein_info_path: str,
                    gene_names: List[str]) -> None:
        """Load embeddings and PPI.

        Args:
            embedding_path: Pickled gene-embedding file.
            ppi_path: PPI links file.
            protein_info_path: Protein-id to gene-name mapping file.
            gene_names: Genes to cluster; fixes the row order of all priors.
        """
        self.gene_names = gene_names

        print(f"Loading priors for {len(gene_names)} genes...")

        self.embeddings, coverage = load_gene_embeddings(
            embedding_path, gene_names,
            fill_missing=True, random_seed=self.random_seed
        )
        print(f"  Embedding coverage: {coverage*100:.1f}%")

        self.ppi_adj = load_ppi_network(
            ppi_path, protein_info_path, gene_names,
            min_score=self.ppi_min_score
        )
        # Count genes that have at least one retained interaction.
        ppi_genes = np.sum(np.any(self.ppi_adj > 0, axis=1))
        print(f"  Genes with PPI: {ppi_genes}")

    def fit(self, verbose: bool = True) -> np.ndarray:
        """Run clustering.

        Args:
            verbose: If True, print progress, cluster sizes and PPI coherence.

        Returns:
            The cluster label of each gene (also stored on ``self.labels``).

        Raises:
            ValueError: If the priors have not been loaded.
        """
        if self.embeddings is None or self.ppi_adj is None:
            raise ValueError("Call load_priors() first")

        if verbose:
            print(f"\nGRC clustering (K={self.n_clusters})...")

        self.labels, self.centroids = balanced_ot_clustering(
            self.embeddings,
            self.ppi_adj,
            self.n_clusters,
            n_iterations=self.n_iterations,
            ot_reg=self.ot_reg,
            random_seed=self.random_seed,
            verbose=verbose
        )

        if verbose:
            sizes = np.bincount(self.labels, minlength=self.n_clusters)
            print(f"  Sizes: min={sizes.min()}, max={sizes.max()}, mean={sizes.mean():.1f}")
            ppi_coherence = self._compute_ppi_coherence()
            print(f"  PPI coherence: {ppi_coherence:.4f}")

        return self.labels

    def reorder_genes(self) -> Tuple[List[str], np.ndarray]:
        """Reorder genes by cluster.

        Raises:
            ValueError: If no labels are available yet.
        """
        if self.labels is None:
            raise ValueError("Call fit() first")
        return reorder_genes_by_cluster(self.gene_names, self.labels)

    def _compute_ppi_coherence(self) -> float:
        """Mean PPI density within clusters.

        For every cluster with at least two genes, density is the summed PPI
        weight among its members divided by the number of ordered member
        pairs. Returns 0.0 when no cluster has two or more genes.

        Raises:
            ValueError: If labels or the PPI matrix are not available.
        """
        if self.labels is None or self.ppi_adj is None:
            raise ValueError("Call load_priors() and fit() first")

        coherences = []
        for j in range(self.n_clusters):
            mask = self.labels == j
            n_in = np.sum(mask)
            if n_in > 1:
                cluster_ppi = self.ppi_adj[np.ix_(mask, mask)]
                density = cluster_ppi.sum() / (n_in * (n_in - 1) + 1e-8)
                coherences.append(density)
        # Convert the NumPy scalar so the annotated return type holds.
        return float(np.mean(coherences)) if coherences else 0.0

    def save_results(self, output_path: str) -> None:
        """Save clustering results.

        Writes a pickle containing the labels, centroids, gene names, cluster
        sizes and the hyper-parameters of this instance.

        Raises:
            ValueError: If there are no labels to save.
        """
        # Fail before touching the filesystem, with the same message that
        # reorder_genes() uses for the unfitted state.
        if self.labels is None:
            raise ValueError("Call fit() first")

        results = {
            'n_clusters': self.n_clusters,
            'labels': self.labels,
            'centroids': self.centroids,
            'gene_names': self.gene_names,
            'cluster_sizes': np.bincount(self.labels, minlength=self.n_clusters),
        }
        # Extra keys are ignored by older readers, so adding them is safe.
        for key in self._SAVED_HYPERPARAMS:
            results[key] = getattr(self, key)

        with open(output_path, 'wb') as f:
            pickle.dump(results, f)
        print(f"Saved to {output_path}")

    @classmethod
    def load_results(cls, path: str) -> 'GeneRelationClustering':
        """Load from file.

        Restores labels, centroids, gene names and any saved hyper-parameters.
        Files that only contain ``n_clusters`` fall back to the constructor
        defaults for the rest. Embeddings and the PPI matrix are not stored,
        so they remain unset on the returned instance.
        """
        # NOTE(security): pickle.load can execute arbitrary code; only load
        # result files from a trusted source.
        with open(path, 'rb') as f:
            results = pickle.load(f)

        hyperparams = {
            key: results[key] for key in cls._SAVED_HYPERPARAMS if key in results
        }
        grc = cls(
            n_clusters=results['n_clusters'],
            **hyperparams
        )
        grc.labels = results['labels']
        grc.centroids = results['centroids']
        grc.gene_names = results['gene_names']

        return grc
