import os
import numpy as np
from sklearn.preprocessing import normalize
from sklearn.cluster import KMeans
from sklearn.manifold import TSNE, trustworthiness
from sklearn.neighbors import NearestNeighbors
import matplotlib.pyplot as plt
import seaborn as sns

# Global Matplotlib defaults, applied at import time to every figure created
# afterwards: serif (Times New Roman) text at 16 pt for all labels, titles,
# ticks and legends, 5x4-inch figures at 300 dpi, and frameless legends.
# NOTE(review): the elbow-plot helpers in this file call sns.set(style=...),
# which may override some of these values -- confirm that is intended.
plt.rcParams.update({
    'font.family': 'serif',
    'font.serif': ['Times New Roman'],
    'font.size': 16,
    'axes.labelsize': 16,
    'axes.titlesize': 16,
    'legend.fontsize': 16,
    'xtick.labelsize': 16,
    'ytick.labelsize': 16,
    'figure.figsize': (5, 4),  # presumably (24, 6) was an alternative wide layout -- TODO confirm
    'lines.linewidth': 2,
    'lines.markersize': 6,
    'grid.alpha': 0.5,
    'legend.frameon': False,
    'figure.dpi': 300,
})

# Centralized plot output directory (relative to project root); the script
# section at the bottom of this file writes all of its PDF figures here.
PLOT_DIR = 'OffClusBandit/analysis/plot_figs'

# Compatibility shim: expose the standard-library `warnings` module as
# `numpy.warnings` when NumPy itself does not provide that attribute, for
# libraries that still expect it (presumably pyclustering, which
# xmeans_labels imports lazily -- TODO confirm).
# The shim is best-effort: any failure while installing it is ignored on
# purpose so that importing this module never fails because of it.
try:
    import warnings as _warnings  # noqa: F401
    if not hasattr(np, "warnings"):
        np.warnings = _warnings  # type: ignore[attr-defined]
except Exception:
    pass

def load_user_vectors(file_path: str) -> np.ndarray:
    """Load a two-dimensional array of user vectors with ``np.load``.

    Args:
        file_path: Location of the array file on disk.

    Returns:
        The loaded array, unchanged.

    Raises:
        FileNotFoundError: If nothing exists at ``file_path``.
        ValueError: If the loaded array is not two-dimensional.
    """
    path_is_missing = not os.path.exists(file_path)
    if path_is_missing:
        raise FileNotFoundError(f"Data file not found: {file_path}")

    vectors = np.load(file_path)
    if vectors.ndim == 2:
        return vectors
    raise ValueError(f"Invalid data shape: {vectors.shape}")

def pca_to_2d(vectors: np.ndarray, random_state: int = 42):
    """Project ``vectors`` onto their first two principal components.

    Args:
        vectors: Data matrix handed to ``PCA.fit_transform``.
        random_state: Seed forwarded to the PCA estimator.

    Returns:
        A ``(points_2d, pca)`` tuple: the 2-D projection of ``vectors`` and
        the fitted PCA model, which callers can reuse to map other points
        (for example cluster centers) into the same plane.
    """
    from sklearn.decomposition import PCA

    reducer = PCA(n_components=2, random_state=random_state)
    projected = reducer.fit_transform(vectors)
    return projected, reducer

def kmeans_labels(vectors: np.ndarray, n_clusters: int, random_state: int = 42):
    """Cluster ``vectors`` with KMeans using a fixed number of clusters.

    Args:
        vectors: Data matrix to cluster.
        n_clusters: Number of clusters to fit.
        random_state: Seed forwarded to KMeans.

    Returns:
        A ``(labels, centers)`` tuple: the cluster index assigned to each row
        and the fitted cluster centers.
    """
    clusterer = KMeans(
        n_clusters=n_clusters,
        n_init=10,  # ten restarts; KMeans keeps the best one
        random_state=random_state,
    )
    assignments = clusterer.fit_predict(vectors)
    return assignments, clusterer.cluster_centers_

def xmeans_labels(vectors: np.ndarray, random_state: int = 42, init_k: int = 5, n_init: int = 5, kmax: int = 30):
    """Cluster with XMeans (auto-select number of clusters) without feature scaling.

    Runs XMeans from several k-means++ initializations and returns the run
    with the smallest within-cluster SSE measured in the original space.
    Note that runs may end with different numbers of clusters, and SSE is
    compared across them as-is.

    Args:
        vectors: (N,d) data in original feature space
        random_state: base seed used to derive one seed per initialization
        init_k: initial number of seeds for xmeans (values below 2 are raised to 2)
        n_init: number of random initializations (at least one is always run)
        kmax: maximum number of clusters to grow to

    Returns:
        labels: (N,) cluster index per row; -1 for a row that XMeans did not
            place in any cluster
        centers: (k,d) mean of each non-empty cluster, in original space

    Raises:
        ImportError: if pyclustering cannot be imported.
        RuntimeError: if every initialization failed; the last underlying
            error is attached as ``__cause__``.
    """
    try:
        from pyclustering.cluster.xmeans import xmeans
        from pyclustering.cluster.center_initializer import kmeans_plusplus_initializer
    except Exception as e:
        raise ImportError("pyclustering is required for XMeans. Please `pip install pyclustering`.") from e

    X = np.ascontiguousarray(vectors, dtype=float)

    def run_once(seed: int):
        """Run XMeans once from a seeded k-means++ start; return (labels, centers, sse)."""
        # NOTE(review): the seed is only applied to the k-means++ initializer;
        # the xmeans call itself is not given a seed -- confirm whether its
        # splitting step is deterministic for the installed version.
        init_centers = kmeans_plusplus_initializer(X, max(2, init_k), random_state=seed).initialize()
        # try ccore; fallback to python
        try:
            xm = xmeans(X, init_centers, kmax=kmax, ccore=True)
            xm.process()
        except Exception:
            xm = xmeans(X, init_centers, kmax=kmax, ccore=False)
            xm.process()
        # An empty cluster has no mean (it would yield a NaN center), so only
        # non-empty clusters are kept and numbered consecutively.
        clists = [idxs for idxs in xm.get_clusters() if len(idxs) > 0]
        if not clists:
            raise ValueError("XMeans produced no clusters")
        # -1 marks rows not covered by any cluster instead of leaving
        # uninitialized memory in the label array.
        lbls = np.full(X.shape[0], -1, dtype=int)
        cents = []
        sse = 0.0
        for cid, idxs in enumerate(clists):
            members = X[idxs]
            center = members.mean(axis=0)
            lbls[idxs] = cid
            cents.append(center)
            # within-cluster SSE in original space
            diffs = members - center
            sse += float(np.sum(diffs * diffs))
        return lbls, np.vstack(cents), sse

    best = None
    last_error = None
    rng = np.random.RandomState(random_state)
    for _ in range(max(1, n_init)):
        seed = int(rng.randint(0, 10_000_000))
        try:
            candidate = run_once(seed)
        except Exception as err:
            # A single failed initialization is tolerated; remember why so the
            # final error can report it if nothing succeeds.
            last_error = err
            continue
        if (best is None) or (candidate[2] < best[2]):
            best = candidate
    if best is None:
        raise RuntimeError("All XMeans initializations failed") from last_error
    return best[0], best[1]

def plot_clusters(ax, points_2d: np.ndarray, labels: np.ndarray, centers_2d: np.ndarray = None, title: str = ""):
    """Scatter-plot 2-D points coloured by cluster label, with cluster centers.

    Points and centers share one colour scale (``labels.min()`` to
    ``labels.max()`` on the ``tab20`` colormap) so that a center is drawn in
    the same colour as the points of its cluster.

    Args:
        ax: Matplotlib-style axes; ``scatter``, ``set_title``, ``set_xticks``
            and ``set_yticks`` are called on it.
        points_2d: (N, 2) coordinates to plot.
        labels: (N,) cluster label per point. Negative labels (the noise
            marker produced by HDBSCAN elsewhere in this file) are plotted as
            points but get no center.
        centers_2d: Optional (k, 2) center coordinates. When omitted, the
            center of each non-negative label is the mean of its points.
            When given, center ``i`` is assumed to belong to the ``i``-th
            smallest non-negative label if the counts match; otherwise
            centers are coloured by their index on their own scale.
        title: Axes title.

    Returns:
        The object returned by the ``ax.scatter`` call that drew the points.
    """
    points_2d = np.asarray(points_2d)
    labels = np.asarray(labels)

    # Explicit limits equal to what scatter would auto-select for the points;
    # they are reused below so the centers land on the same colours.
    if labels.size > 0:
        vmin, vmax = labels.min(), labels.max()
    else:
        vmin = vmax = None

    scatter = ax.scatter(points_2d[:, 0], points_2d[:, 1], c=labels, s=8, cmap='tab20', alpha=0.8,
                         vmin=vmin, vmax=vmax)

    # Iterate over the labels that actually occur rather than range(k), so
    # non-contiguous label sets do not lose clusters.
    cluster_ids = np.unique(labels[labels >= 0])

    if centers_2d is None:
        center_ids = cluster_ids
        centers_2d = np.array([points_2d[labels == cid].mean(axis=0) for cid in cluster_ids])
    else:
        centers_2d = np.asarray(centers_2d)
        center_ids = cluster_ids if len(centers_2d) == len(cluster_ids) else None

    if len(centers_2d) > 0:
        if center_ids is not None:
            ax.scatter(centers_2d[:, 0], centers_2d[:, 1], c=center_ids, cmap='tab20', s=80, marker='X',
                       edgecolors='k', vmin=vmin, vmax=vmax)
        else:
            # The label of each center is unknown: colour by position only.
            ax.scatter(centers_2d[:, 0], centers_2d[:, 1], c=np.arange(len(centers_2d)), cmap='tab20', s=80,
                       marker='X', edgecolors='k')

    ax.set_title(title)
    ax.set_xticks([])
    ax.set_yticks([])
    return scatter

def plot_svd_elbow(singular_values, selected_d=20, begin_d=5, max_d=30, dataset_name='Dataset', save_path=None):
    """Plot the normalized variance of SVD dimensions ``begin_d``..``max_d``.

    The variance of a dimension is its squared singular value divided by the
    sum of all squared singular values. Dimensions are counted from 1, so
    dimension ``d`` corresponds to ``singular_values[d - 1]`` (the same
    convention as ``determine_optimal_dimension``).

    Args:
        singular_values: 1-D sequence of singular values.
        selected_d: Accepted for interface compatibility; not used by the plot.
        begin_d: First dimension shown (values below 1 are treated as 1).
        max_d: Last dimension shown; clipped to the number of singular values
            so the x and y series always have the same length.
        dataset_name: Name inserted into the plot title.
        save_path: If truthy, the figure is written there as a PDF.

    Returns:
        None. The figure is left open as the current Matplotlib figure.
    """
    singular_values = np.asarray(singular_values, dtype=float)
    variance = singular_values ** 2
    variance_norm = variance / np.sum(variance)

    # Clamp the requested window to the dimensions that actually exist.
    first_d = max(1, begin_d)
    last_d = min(max_d, variance_norm.shape[0])
    dimensions = np.arange(first_d, last_d + 1)
    # 1-based dimension d lives at 0-based index d - 1.
    normalized_variance = variance_norm[first_d - 1:last_d]

    # NOTE(review): sns.set may override the module-level rcParams (fonts,
    # sizes) configured at import time -- confirm that is intended.
    sns.set(style="whitegrid")
    plt.figure(figsize=(5, 4))
    plt.plot(dimensions, normalized_variance, marker='o', color='b', label='Normalized Variance')
    plt.xlabel('Dimension (d)')
    plt.ylabel('Normalized Variance')
    plt.title(f'Elbow Method for SVD Dimensions ({dataset_name})')
    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, format='pdf', bbox_inches='tight')  # Save as PDF
        print(f"Plot saved to {save_path}")
    # plt.show()

def determine_optimal_dimension(singular_values, threshold=0.9):
    """Return the smallest number of leading dimensions reaching ``threshold``.

    The cumulative share of variance (squared singular values) is computed,
    and the first dimension count whose share is at least ``threshold`` is
    selected. The choice is also printed.

    Args:
        singular_values: 1-D sequence of singular values, in the order the
            dimensions should be accumulated.
        threshold: Target cumulative variance share. If it cannot be reached
            (for example a value above 1), all dimensions are selected.

    Returns:
        The selected dimension count as an ``int`` between 1 and
        ``len(singular_values)``.
    """
    singular_values = np.asarray(singular_values, dtype=float)
    cumulative_variance = np.cumsum(singular_values ** 2)
    total_variance = cumulative_variance[-1]
    variance_ratio = cumulative_variance / total_variance
    # searchsorted gives the 0-based index of the first ratio >= threshold;
    # +1 converts it to a count of dimensions.
    first_reaching = int(np.searchsorted(variance_ratio, threshold))
    # Never report more dimensions than there are singular values.
    optimal_d = min(first_reaching + 1, variance_ratio.shape[0])
    print(f"Selected dimension d: {optimal_d} based on {threshold * 100}% cumulative variance.")
    return optimal_d

def extract_user_features(num_users, filename, selected_d=20, plot_elbow=True, save_plot=False, plot_save_path=None,
                          dataset_name='Dataset'):
    """Load a matrix, optionally plot the SVD elbow of its first rows, and return ``selected_d``.

    Args:
        num_users: Number of leading rows of the loaded matrix to analyse.
        filename: Path of the array file passed to ``np.load``.
        selected_d: Dimension forwarded to the elbow plot and returned as-is.
        plot_elbow: When true, compute the singular values of the first
            ``num_users`` rows and draw the elbow plot.
        save_plot: Accepted for interface compatibility; currently not
            consulted -- the plot is written whenever ``plot_save_path`` is
            given.
        plot_save_path: Destination forwarded to ``plot_svd_elbow``.
        dataset_name: Dataset name forwarded to ``plot_svd_elbow``.

    Returns:
        ``selected_d``, unchanged.
    """
    X = np.load(filename)
    A1 = X[:num_users, :]
    if plot_elbow:
        # Only the singular values are needed for the elbow plot, so skip
        # computing U and Vt, and skip the decomposition entirely when no
        # plot is requested.
        s = np.linalg.svd(A1, compute_uv=False)
        plot_svd_elbow(s, selected_d=selected_d, dataset_name=dataset_name, save_path=plot_save_path)
    return selected_d

def calculate_min_gap(cluster_centers):
    """Return the smallest Euclidean distance between any two cluster centers.

    Args:
        cluster_centers: Sequence of center vectors.

    Returns:
        The minimum pairwise distance, or ``float('inf')`` when fewer than
        two centers are given.
    """
    smallest = float('inf')
    for position, first in enumerate(cluster_centers):
        # Compare each center only with the ones after it, so every
        # unordered pair is visited exactly once.
        for second in cluster_centers[position + 1:]:
            gap = np.linalg.norm(first - second)
            if gap < smallest:
                smallest = gap
    return smallest

def perform_kmeans(num_users, n_clusters, feature_filename):
    """Cluster stored feature vectors with KMeans and map every row to its center.

    Prints the minimum distance between cluster centers and the shape of the
    resulting per-row center matrix.

    Args:
        num_users: Accepted for interface compatibility; not used -- all rows
            of the loaded array are clustered.
        n_clusters: Number of clusters to fit.
        feature_filename: Path of the feature array passed to ``np.load``.

    Returns:
        A ``(thetas, centers)`` tuple: ``thetas[i]`` is the center of the
        cluster that row ``i`` was assigned to, and ``centers`` holds the
        fitted cluster centers.
    """
    U = np.load(feature_filename)
    # n_init is pinned (as in kmeans_labels) so the result does not depend on
    # the library's version-specific default number of restarts.
    kmeans = KMeans(n_clusters=n_clusters, n_init=10, random_state=42)
    kmeans.fit(U)
    centers = kmeans.cluster_centers_
    min_gap = calculate_min_gap(centers)
    print(f"Minimum gap between cluster centers: {min_gap}")
    # Index the centers by each row's label to get one center per row.
    thetas = centers[kmeans.labels_]
    print(f"Thetas shape: {thetas.shape}")
    return thetas, centers

def plot_kmeans_elbow(U, max_clusters=20, save_path=None):
    """Plot KMeans inertia against the number of clusters and show the figure.

    Args:
        U: Data matrix to cluster (one row per sample).
        max_clusters: Largest number of clusters to try; clipped to the
            number of rows in ``U`` because KMeans cannot fit more clusters
            than samples.
        save_path: If truthy, the figure is also written there as a PDF.

    Returns:
        None. ``plt.show()`` is called at the end.
    """
    n_samples = U.shape[0] if hasattr(U, "shape") else len(U)
    cluster_range = range(1, min(max_clusters, n_samples) + 1)
    # n_init is pinned (as in kmeans_labels) so the curve does not depend on
    # the library's version-specific default number of restarts.
    inertias = [
        KMeans(n_clusters=k, n_init=10, random_state=42).fit(U).inertia_
        for k in cluster_range
    ]
    # NOTE(review): sns.set may override the module-level rcParams (fonts,
    # sizes) configured at import time -- confirm that is intended.
    sns.set(style="whitegrid")
    plt.figure(figsize=(8, 5))
    sns.lineplot(x=list(cluster_range), y=inertias, marker='o', color='g', label='Inertia')
    # Both axis labels use the same size; the previous 50 pt x-label was
    # wider than the 8-inch figure.
    plt.xlabel('Number of Clusters (k)', fontsize=18, fontweight='bold')
    plt.ylabel('Inertia', fontsize=18, fontweight='bold')
    plt.title('Elbow Method for KMeans', fontsize=18, fontweight='bold')
    plt.xticks(fontsize=14)
    plt.yticks(fontsize=14)
    plt.legend(fontsize=14)
    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, format='pdf')  # Save as PDF
        print(f"Elbow plot saved to {save_path}")
    plt.show()

if __name__ == "__main__":
    # Optional stage, disabled by default: draw the SVD elbow plots for the
    # raw Yelp and MovieLens matrices named below.
    generate_theta = False
    if generate_theta:
        # Ensure output directory exists for elbow plots
        os.makedirs(PLOT_DIR, exist_ok=True)
        elbow_jobs = (
            ('Yelp', 'yelp_1000user_1000item.npy', f'{PLOT_DIR}/yelp_svd_elbow.pdf'),
            ('MovieLens', 'ml_1000user_1000item.npy', f'{PLOT_DIR}/ml_svd_elbow.pdf'),
        )
        for elbow_name, elbow_file, elbow_pdf in elbow_jobs:
            extract_user_features(
                num_users=1000,
                filename=elbow_file,
                selected_d=20,
                plot_elbow=True,
                save_plot=True,
                plot_save_path=elbow_pdf,  # Save as PDF
                dataset_name=elbow_name,
            )

    # =============================
    # Visualization: clustering scatter after PCA to 2D
    # - Load pre-extracted user vectors (same sources as plot_vector.py)
    # - Cluster them (XMeans, then HDBSCAN, then KMeans as fallbacks)
    # - Reduce vectors to 2D via PCA and plot
    # - Save figures as PDF (for paper/report)
    # =============================
    try:
        # Load vectors (adjust paths if needed)
        U_ml = load_user_vectors('OffClusBandit/data/datasets/ml_1000user_d20.npy')
        U_yelp = load_user_vectors('OffClusBandit/data/datasets/yelp_1000user_d20.npy')

        # Optional sampling to avoid overcrowded points and improve speed
        rng = np.random.RandomState(42)

        def sample_rows(X: np.ndarray, k: int) -> np.ndarray:
            """Return ``k`` rows drawn without replacement, or ``X`` itself if it has at most ``k`` rows."""
            if k is None or k >= X.shape[0]:
                return X
            idx = rng.choice(X.shape[0], size=k, replace=False)
            return X[idx]

        def count_clusters(labels) -> int:
            """Count distinct non-negative labels (negative labels mark HDBSCAN noise, not a cluster)."""
            labels = np.asarray(labels)
            return len(np.unique(labels[labels >= 0]))

        sample_n = 1000
        U_ml_s = sample_rows(U_ml, sample_n)
        U_yelp_s = sample_rows(U_yelp, sample_n)

        # Clustering: prefer XMeans (auto k) with more seeds and restarts; fallback to HDBSCAN or KMeans
        auto_clusters = True
        k_fixed = 10
        if auto_clusters:
            try:
                ml_labels, ml_centers = xmeans_labels(U_ml_s, random_state=42, init_k=8, n_init=5, kmax=30)
                yelp_labels, yelp_centers = xmeans_labels(U_yelp_s, random_state=43, init_k=8, n_init=5, kmax=30)
                print("XMeans clustering successful for MovieLens and Yelp")
            except Exception as xmeans_err:
                # Report why the preferred method was skipped instead of
                # falling back silently.
                print(f"XMeans clustering failed ({xmeans_err}); trying HDBSCAN")
                # Try HDBSCAN if available
                try:
                    import hdbscan  # type: ignore

                    # The annotation is quoted so it is never evaluated at
                    # runtime; an unquoted `int | None` raises on Python < 3.10
                    # and would silently trigger the KMeans fallback.
                    def hdbscan_labels(X: np.ndarray, min_cluster_size: int = 15, min_samples: "int | None" = None):
                        """Cluster with HDBSCAN; return labels (-1 = noise) and the mean of each real cluster."""
                        clusterer = hdbscan.HDBSCAN(min_cluster_size=min_cluster_size, min_samples=min_samples)
                        lbls = clusterer.fit_predict(X)
                        # compute centers for non-noise labels
                        valid = np.where(lbls >= 0)[0]
                        if valid.size == 0:
                            raise ValueError("HDBSCAN found only noise")
                        centers = [X[lbls == cid].mean(axis=0) for cid in np.unique(lbls[valid])]
                        return lbls, np.vstack(centers)

                    ml_labels, ml_centers = hdbscan_labels(U_ml_s, min_cluster_size=max(10, U_ml_s.shape[0] // 100))
                    yelp_labels, yelp_centers = hdbscan_labels(U_yelp_s, min_cluster_size=max(10, U_yelp_s.shape[0] // 100))
                    print("HDBSCAN clustering successful for MovieLens and Yelp")
                except Exception as hdbscan_err:
                    print(f"HDBSCAN clustering failed ({hdbscan_err}); using KMeans")
                    # Fallback to fixed KMeans
                    ml_labels, ml_centers = kmeans_labels(U_ml_s, k_fixed, random_state=42)
                    yelp_labels, yelp_centers = kmeans_labels(U_yelp_s, k_fixed, random_state=42)
                    print(f"Fallback to KMeans(k={k_fixed}) for MovieLens and Yelp")
        else:
            ml_labels, ml_centers = kmeans_labels(U_ml_s, k_fixed, random_state=42)
            yelp_labels, yelp_centers = kmeans_labels(U_yelp_s, k_fixed, random_state=42)

        # PCA to 2D
        ml_2d, ml_pca = pca_to_2d(U_ml_s, random_state=42)
        yelp_2d, yelp_pca = pca_to_2d(U_yelp_s, random_state=42)

        # Project cluster centers to the same PCA space
        ml_centers_2d = ml_pca.transform(ml_centers)
        yelp_centers_2d = yelp_pca.transform(yelp_centers)

        # Plot and save
        fig, axes = plt.subplots(1, 2, figsize=(12, 5))
        k_ml = count_clusters(ml_labels)
        k_yelp = count_clusters(yelp_labels)
        plot_clusters(axes[0], ml_2d, ml_labels, ml_centers_2d, title=f"MovieLens vector clustering (k={k_ml})")
        plot_clusters(axes[1], yelp_2d, yelp_labels, yelp_centers_2d, title=f"Yelp vector clustering (k={k_yelp})")
        plt.tight_layout()
        os.makedirs(PLOT_DIR, exist_ok=True)
        plt.savefig(f'{PLOT_DIR}/kmeans_pca_both.pdf', format='pdf', bbox_inches='tight')
        print(f"Saved PCA clustering figures to {PLOT_DIR}/kmeans_pca_both.pdf")

        # -----------------------------
        # Visualization: same clustering with t-SNE to 2D
        # - Run t-SNE embedding on sampled vectors
        # - Plot and save side-by-side figures
        # -----------------------------
        # The iteration count is left at the library default (1000): the
        # keyword for it was renamed from `n_iter` to `max_iter` across
        # scikit-learn releases, so passing either name breaks on some versions.
        tsne_options = dict(n_components=2, perplexity=30.0, learning_rate='auto', init='pca', random_state=42)
        ml_tsne_2d = TSNE(**tsne_options).fit_transform(U_ml_s)
        yelp_tsne_2d = TSNE(**tsne_options).fit_transform(U_yelp_s)

        fig2, axes2 = plt.subplots(1, 2, figsize=(12, 5))
        # No centers are passed: t-SNE has no transform for new points, so
        # plot_clusters derives them from the embedded points.
        plot_clusters(axes2[0], ml_tsne_2d, ml_labels, centers_2d=None, title=f"MovieLens t-SNE (k={k_ml})")
        plot_clusters(axes2[1], yelp_tsne_2d, yelp_labels, centers_2d=None, title=f"Yelp t-SNE (k={k_yelp})")
        plt.tight_layout()
        plt.savefig(f'{PLOT_DIR}/kmeans_tsne_both.pdf', format='pdf', bbox_inches='tight')
        print(f"Saved t-SNE clustering figures to {PLOT_DIR}/kmeans_tsne_both.pdf")
    except Exception as e:
        # Visualization is non-critical; do not stop main pipeline on failure
        print(f"Visualization failed: {e}")
        
