import os
import numpy as np
import matplotlib
matplotlib.use('Agg')  # Set backend before importing pyplot
import matplotlib.pyplot as plt
from matplotlib.patches import Ellipse
from sklearn.manifold import TSNE
from sklearn.neighbors import NearestNeighbors
from sklearn.decomposition import PCA
import warnings
# Visualize CLIP model features (t-SNE, PCA and cosine-similarity heatmap) and save the plots as image files (PNG by default).
def load_data(features_path, labels_path):
    """Load a feature matrix and its per-sample labels from ``.npy`` files.

    Args:
        features_path: path to a NumPy file holding the features, one row
            per sample.
        labels_path: path to a NumPy file holding one label per sample. A
            column vector of shape (n, 1) is flattened to shape (n,).

    Returns:
        Tuple ``(features, labels)`` of NumPy arrays.

    Raises:
        ValueError: if either array is a scalar or the number of labels does
            not match the number of feature rows.
    """
    features = np.load(features_path)
    labels = np.load(labels_path)
    # The plotting code builds boolean masks with ``labels == label``; a
    # column-vector label array would produce 2-D masks, so flatten it.
    if labels.ndim == 2 and labels.shape[1] == 1:
        labels = labels.ravel()
    # Fail fast here instead of with an opaque indexing error while plotting.
    if features.ndim == 0 or labels.ndim == 0:
        raise ValueError("features and labels must be arrays, not scalars")
    if len(features) != len(labels):
        raise ValueError(
            f"features has {len(features)} rows but labels has {len(labels)} entries"
        )
    return features, labels

def _compute_tsne(features: np.ndarray) -> np.ndarray:
    """Embed ``features`` into 2-D with scikit-learn's t-SNE.

    Perplexity grows with the sample count (5-30 below 500 samples, 40 below
    5000, 50 otherwise) and is capped below the sample count, which
    scikit-learn requires. The run uses PCA initialisation and a fixed seed
    (``random_state=42``). All warnings raised during fitting are suppressed.

    Args:
        features: 2-D array with one row per sample.

    Returns:
        Array of shape (n_samples, 2) with the t-SNE coordinates.

    Raises:
        ValueError: if fewer than 2 samples are given.
    """
    import inspect

    n = features.shape[0]
    if n < 2:
        raise ValueError(f"t-SNE needs at least 2 samples, got {n}")

    # Adaptive perplexity: grows with sample size.
    if n < 500:
        perplexity = max(5, min(30, n // 3))
    elif n < 5000:
        perplexity = 40
    else:
        perplexity = 50
    # scikit-learn rejects perplexity >= n_samples (only hit for tiny inputs).
    perplexity = min(perplexity, n - 1)

    # `n_iter` was renamed to `max_iter` in scikit-learn 1.5 and removed in
    # 1.7; pick whichever keyword the installed version accepts.
    iter_kwarg = 'max_iter' if 'max_iter' in inspect.signature(TSNE).parameters else 'n_iter'

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        tsne = TSNE(
            n_components=2,
            random_state=42,
            init='pca',
            learning_rate='auto',
            perplexity=perplexity,
            early_exaggeration=12.0,
            n_iter_without_progress=400,
            angle=0.5,
            verbose=0,
            **{iter_kwarg: 1500},
        )
        return tsne.fit_transform(features)


def _draw_cov_ellipse(ax, points: np.ndarray, color: tuple, alpha: float = 0.15, scale: float = 2.0):
    """Shade the covariance ellipse of a 2-D point cloud on ``ax``.

    The ellipse is centred on the mean of ``points`` and rotated to the
    eigenvector of the largest eigenvalue of their sample covariance. Its full
    axis lengths are ``2 * scale * sqrt(eigenvalue)``, with eigenvalues floored
    at 1e-12. Nothing is drawn for fewer than 3 points.

    Args:
        ax: matplotlib Axes to draw on.
        points: array of shape (n_points, 2).
        color: face and edge colour of the ellipse.
        alpha: opacity of the patch.
        scale: multiplier on the per-axis standard deviation.
    """
    if not points.shape[0] >= 3:
        return

    centre = np.mean(points, axis=0)
    eigvals, eigvecs = np.linalg.eigh(np.cov(points, rowvar=False))

    # eigh yields ascending eigenvalues; reorder so the major axis comes first.
    descending = np.argsort(eigvals)[::-1]
    eigvals = eigvals[descending]
    eigvecs = eigvecs[:, descending]

    major_x, major_y = eigvecs[0, 0], eigvecs[1, 0]
    angle_deg = np.degrees(np.arctan2(major_y, major_x))

    axis_lengths = 2 * scale * np.sqrt(np.maximum(eigvals, 1e-12))
    patch = Ellipse(
        xy=centre,
        width=axis_lengths[0],
        height=axis_lengths[1],
        angle=angle_deg,
        facecolor=color,
        edgecolor=color,
        alpha=alpha,
        lw=1.5,
    )
    ax.add_patch(patch)


def visualize_features(features, labels, class_names=None, highlight_ratio: float = 0.01, point_size: int = 12, out_path: str = 'features_visualization.png'):
    """Plot a 2-D t-SNE embedding of the features, coloured by label, and save it.

    Rows are L2-normalised and embedded with ``_compute_tsne``. Each label is
    drawn as its own scatter series with a covariance ellipse. The points with
    the largest mean k-nearest-neighbour distance in the embedding are circled.
    Two files are written: ``out_path`` (at ``savefig.dpi`` = 400) and
    ``<root>_hq<ext>`` at 600 dpi.

    Note: this changes the global matplotlib style and ``rcParams``.

    Args:
        features: array of shape (n_samples, n_features).
        labels: sequence of n_samples class labels.
        class_names: optional sequence mapping an integer label to a display
            name. Labels that are not valid non-negative integer indices into
            it are shown as ``str(label)``.
        highlight_ratio: fraction of points to circle as most isolated;
            ``<= 0`` disables highlighting.
        point_size: marker size of the scatter points.
        out_path: destination of the standard-resolution image.
    """
    # Unified style (first available candidate wins) and font sizes.
    preferred_styles = ('seaborn-v0_8-white', 'seaborn-whitegrid', 'seaborn', 'ggplot')
    try:
        available = set(plt.style.available)
        plt.style.use(next((s for s in preferred_styles if s in available), 'default'))
    except (OSError, ValueError):
        pass  # Styling is cosmetic; keep the current style if it cannot be applied.
    plt.rcParams.update({
        'figure.dpi': 150,
        'savefig.dpi': 400,
        'font.size': 22,
        'figure.titlesize': 22,
        'axes.titlesize': 22,
        'axes.labelsize': 22,
        'legend.fontsize': 22,
        'xtick.labelsize': 22,
        'ytick.labelsize': 22,
    })

    # Accept plain lists: `labels == label` below must produce an element-wise mask.
    labels = np.asarray(labels)

    # L2 normalization (minimal impact if already normalized).
    eps = 1e-12
    feats_norm = features / (np.linalg.norm(features, axis=1, keepdims=True) + eps)

    feats_2d = _compute_tsne(feats_norm)
    n_points = feats_2d.shape[0]

    unique_labels = np.unique(labels)
    n_colors = max(1, len(unique_labels))
    try:
        cmap = matplotlib.colormaps['tab20'].resampled(n_colors)
    except AttributeError:  # matplotlib < 3.6; plt.cm.get_cmap was removed in 3.9
        cmap = plt.cm.get_cmap('tab20', n_colors)

    def legend_name(label):
        # Only integer labels within range index class_names; this also avoids
        # the ambiguous truth value of a numpy-array class_names.
        if (class_names is not None
                and isinstance(label, (int, np.integer))
                and 0 <= label < len(class_names)):
            return str(class_names[label])
        return str(label)

    fig, ax = plt.subplots(figsize=(13.5, 9))
    try:
        # Scatter + class covariance ellipses.
        for idx, label in enumerate(unique_labels):
            pts = feats_2d[labels == label]
            color = cmap(idx)
            ax.scatter(pts[:, 0], pts[:, 1], s=point_size, color=color, alpha=0.75,
                       label=legend_name(label), edgecolors='none')
            _draw_cov_ellipse(ax, pts, color=color, alpha=0.12, scale=2.5)

        # Circle the most isolated points: largest mean distance to their k
        # nearest neighbours in the 2-D embedding.
        if highlight_ratio > 0 and n_points >= 3:
            try:
                # n_neighbors (= k + 1, including the point itself) must not exceed n_points.
                k = min(15, max(5, n_points // 100), n_points - 1)
                nn = NearestNeighbors(n_neighbors=k + 1, metric='euclidean')
                nn.fit(feats_2d)
                dists, _ = nn.kneighbors(feats_2d)
                # Column 0 is each point's zero distance to itself.
                sparsity_score = dists[:, 1:].mean(axis=1)
                top_k = min(n_points, max(1, int(highlight_ratio * n_points)))
                top_idx = np.argsort(sparsity_score)[-top_k:]
                ax.scatter(feats_2d[top_idx, 0], feats_2d[top_idx, 1], s=point_size * 3,
                           facecolors='none', edgecolors='black', linewidths=1.5,
                           label=f'Top {highlight_ratio * 100:g}% dissimilar')
            except ValueError as err:
                # The highlight is optional; report instead of silently skipping.
                warnings.warn(f"Skipping dissimilarity highlight: {err}")

        ax.set_title('CLIP Feature Landscape — t-SNE 2D (local neighborhoods preserved)')
        ax.set_xlabel('t-SNE axis 1')
        ax.set_ylabel('t-SNE axis 2')
        ax.grid(alpha=0.15, linestyle='--', linewidth=0.5)
        # Place legend outside the axes to avoid covering points.
        ax.legend(loc='upper left', bbox_to_anchor=(1.02, 1.0), borderaxespad=0.0, frameon=False)
        fig.tight_layout()

        fig.savefig(out_path, bbox_inches='tight')
        root, ext = os.path.splitext(out_path)
        hq_path = root + '_hq' + ext
        fig.savefig(hq_path, bbox_inches='tight', dpi=600)
        print(f"Saved visualization to {out_path} and {hq_path}")
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)


def visualize_pca_overview(features: np.ndarray, labels: np.ndarray, class_names=None, point_size: int = 12, out_path: str = 'features_pca.png'):
    """Plot a 2-D PCA projection of the L2-normalised features, coloured by label.

    The axis labels carry each component's explained-variance percentage. Two
    files are written: ``out_path`` (at ``savefig.dpi`` = 400) and
    ``<root>_hq<ext>`` at 600 dpi.

    Note: this changes the global matplotlib style and ``rcParams``.

    Args:
        features: array of shape (n_samples, n_features).
        labels: sequence of n_samples class labels.
        class_names: optional sequence mapping an integer label to a display
            name. Labels that are not valid non-negative integer indices into
            it are shown as ``str(label)``.
        point_size: marker size of the scatter points.
        out_path: destination of the standard-resolution image.
    """
    # Accept plain lists: `labels == label` below must produce an element-wise mask.
    labels = np.asarray(labels)

    # Normalize + PCA to 2D to show global structure.
    eps = 1e-12
    feats = features / (np.linalg.norm(features, axis=1, keepdims=True) + eps)
    pca = PCA(n_components=2, random_state=42)
    emb = pca.fit_transform(feats)

    # Style fallback; the v0_8 name is what matplotlib >= 3.6 ships.
    preferred_styles = ('seaborn-v0_8-whitegrid', 'seaborn-whitegrid', 'seaborn')
    try:
        available = set(plt.style.available)
        plt.style.use(next((s for s in preferred_styles if s in available), 'default'))
    except (OSError, ValueError):
        pass  # Styling is cosmetic; keep the current style if it cannot be applied.

    plt.rcParams.update({
        'figure.dpi': 150,
        'savefig.dpi': 400,
        'font.size': 22,
        'figure.titlesize': 22,
        'axes.titlesize': 22,
        'axes.labelsize': 22,
        'legend.fontsize': 22,
        'xtick.labelsize': 22,
        'ytick.labelsize': 22,
    })

    unique_labels = np.unique(labels)
    n_colors = max(1, len(unique_labels))
    try:
        cmap = matplotlib.colormaps['tab20'].resampled(n_colors)
    except AttributeError:  # matplotlib < 3.6; plt.cm.get_cmap was removed in 3.9
        cmap = plt.cm.get_cmap('tab20', n_colors)

    def legend_name(label):
        # Only integer labels within range index class_names; this also avoids
        # the ambiguous truth value of a numpy-array class_names.
        if (class_names is not None
                and isinstance(label, (int, np.integer))
                and 0 <= label < len(class_names)):
            return str(class_names[label])
        return str(label)

    var_ratio = getattr(pca, 'explained_variance_ratio_', None)
    if var_ratio is not None and len(var_ratio) >= 2:
        pc1_label = f"PC 1 ({var_ratio[0]*100:.1f}% var)"
        pc2_label = f"PC 2 ({var_ratio[1]*100:.1f}% var)"
    else:
        pc1_label, pc2_label = 'PC 1', 'PC 2'

    fig, ax = plt.subplots(figsize=(13.5, 9))
    try:
        for idx, label in enumerate(unique_labels):
            pts = emb[labels == label]
            ax.scatter(pts[:, 0], pts[:, 1], s=point_size, color=cmap(idx), alpha=0.8,
                       label=legend_name(label), edgecolors='none')
        ax.set_title('CLIP Feature Overview — PCA 2D (global structure)')
        ax.set_xlabel(pc1_label)
        ax.set_ylabel(pc2_label)
        ax.grid(alpha=0.15, linestyle='--', linewidth=0.5)
        ax.legend(loc='upper left', bbox_to_anchor=(1.02, 1.0), frameon=False)
        fig.tight_layout()

        fig.savefig(out_path, bbox_inches='tight')
        root, ext = os.path.splitext(out_path)
        hq_path = root + '_hq' + ext
        fig.savefig(hq_path, bbox_inches='tight', dpi=600)
        print(f"Saved PCA overview to {out_path} and {hq_path}")
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)


def visualize_similarity_heatmap(features: np.ndarray, out_path: str = 'features_similarity_heatmap.png'):
    """Plot the pairwise cosine-similarity matrix of the features as a heatmap.

    Rows and columns are reordered so similar samples sit together. The
    ordering comes from average-linkage hierarchical clustering on cosine
    distance (1 - similarity). If SciPy is unavailable or clustering raises
    ValueError, it falls back to sorting by the first principal component.
    The colour range spans the observed minimum to maximum similarity. Two
    files are written: ``out_path`` (at ``savefig.dpi`` = 400) and
    ``<root>_hq<ext>`` at 600 dpi.

    Builds a dense n_samples x n_samples matrix, so memory grows quadratically.
    Note: this changes the global matplotlib style and ``rcParams``.

    Args:
        features: array of shape (n_samples, n_features).
        out_path: destination of the standard-resolution image.

    Raises:
        ValueError: if ``features`` has no rows.
    """
    # Normalize + cosine similarity matrix.
    eps = 1e-12
    feats = features / (np.linalg.norm(features, axis=1, keepdims=True) + eps)
    n = feats.shape[0]
    if n == 0:
        raise ValueError("features must contain at least one sample")
    sim = np.clip(feats @ feats.T, -1.0, 1.0)

    # Sort (hierarchical clustering preferred, fallback to PCA 1D sorting).
    order = np.arange(n)
    if n > 1:
        try:
            import scipy.cluster.hierarchy as sch
            import scipy.spatial.distance as ssd
            # checks=False: the diagonal of 1 - sim may be a tiny non-zero value.
            condensed = ssd.squareform(1.0 - sim, checks=False)
            linkage_matrix = sch.linkage(condensed, method='average')
            # leaves_list takes the linkage matrix itself, not a dendrogram dict.
            order = sch.leaves_list(linkage_matrix)
        except (ImportError, ValueError):
            pca1 = PCA(n_components=1, random_state=42).fit_transform(feats).ravel()
            order = np.argsort(pca1)

    sim_sorted = sim[np.ix_(order, order)]
    sim_min = float(np.min(sim_sorted))
    sim_max = float(np.max(sim_sorted))
    if abs(sim_max - sim_min) < 1e-6:
        sim_max = sim_min + 1e-6  # Avoid rendering issues when vmin == vmax.

    # Style fallback; the v0_8 name is what matplotlib >= 3.6 ships.
    preferred_styles = ('seaborn-v0_8-white', 'seaborn-white')
    try:
        available = set(plt.style.available)
        plt.style.use(next((s for s in preferred_styles if s in available), 'default'))
    except (OSError, ValueError):
        pass  # Styling is cosmetic; keep the current style if it cannot be applied.

    plt.rcParams.update({
        'figure.dpi': 150,
        'savefig.dpi': 400,
        'font.size': 22,
        'figure.titlesize': 22,
        'axes.titlesize': 22,
        'axes.labelsize': 22,
        'legend.fontsize': 22,
        'xtick.labelsize': 22,
        'ytick.labelsize': 22,
    })

    fig, ax = plt.subplots(figsize=(12.5, 10))
    try:
        # Colorbar range adapts to the actual min/max values.
        im = ax.imshow(sim_sorted, cmap='viridis', vmin=sim_min, vmax=sim_max,
                       interpolation='nearest', aspect='auto')
        ax.set_title('Pairwise Cosine Similarity')
        ax.set_xlabel('Samples')
        ax.set_ylabel('Samples')
        cbar = fig.colorbar(im, ax=ax)
        cbar.set_label('Cosine similarity (higher = more similar)', fontsize=22)
        # Six equally spaced colorbar ticks from min to max.
        ticks = np.linspace(sim_min, sim_max, num=6)
        cbar.set_ticks(ticks)
        cbar.set_ticklabels([f"{t:.2f}" for t in ticks])
        cbar.ax.tick_params(labelsize=12)
        ax.set_xticks([])
        ax.set_yticks([])
        fig.tight_layout()

        fig.savefig(out_path, bbox_inches='tight')
        root, ext = os.path.splitext(out_path)
        hq_path = root + '_hq' + ext
        fig.savefig(hq_path, bbox_inches='tight', dpi=600)
        print(f"Saved similarity heatmap to {out_path} and {hq_path}")
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)

def print_label_statistics(labels):
    """Print the number of samples per distinct label, then the overall totals.

    Labels are listed in the sorted order returned by ``np.unique``.
    """
    values, frequencies = np.unique(labels, return_counts=True)
    report = ["Label Distribution:"]
    report.extend(
        f"Label {value}: {freq} samples" for value, freq in zip(values, frequencies)
    )
    report.append(f"Total {len(labels)} samples, {len(values)} unique labels")
    print("\n".join(report))

# Script entry point: load the saved CLIP features/labels, print the label
# distribution, then render the t-SNE visualization.
if __name__ == "__main__":
    feature_file = "clip_output_features.npy"
    label_file = "clip_output_labels.npy"
    clip_features, clip_labels = load_data(feature_file, label_file)
    print_label_statistics(clip_labels)
    visualize_features(clip_features, clip_labels)