import torch
import numpy as np
import os
import umap
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from sklearn.cluster import DBSCAN
from sklearn.ensemble import IsolationForest
def set_seed(seed):
    """Seed the global PyTorch and NumPy random number generators.

    Args:
        seed: Value passed unchanged to ``torch.manual_seed`` and then to
            ``np.random.seed``.
    """
    # Same order as before: PyTorch first, NumPy second.
    for seed_fn in (torch.manual_seed, np.random.seed):
        seed_fn(seed)

# Executed when the module is loaded: fixes the PyTorch and NumPy global seeds
# to 42 before any of the plotting helpers below are called.
set_seed(42)

def visualize_latent_space(all_latents, topk_indices, save_dir):
    """Save a 2-D PCA scatter plot of the latents with top-k samples highlighted.

    Every sample is flattened to a vector, all samples are projected onto
    their first two principal components, and the figure is written to
    ``<save_dir>/latent_space_pca_highlight.png``.

    Args:
        all_latents: Per-sample latents whose first axis indexes samples.
            May be a NumPy array, a ``torch.Tensor`` (moved to CPU and
            detached here) or anything ``np.asarray`` understands.
        topk_indices: Integer indices (or a boolean mask) along the first
            axis of ``all_latents`` selecting the samples to highlight.
            ``None`` highlights nothing.
        save_dir: Output directory; created if it does not exist.
    """
    os.makedirs(save_dir, exist_ok=True)

    # Accept tensors as well as arrays (as visualize_latent_space_outlier
    # does) and flatten with a single reshape instead of a per-sample
    # numpy -> torch -> numpy round trip.
    if isinstance(all_latents, torch.Tensor):
        all_latents = all_latents.detach().cpu().numpy()
    latents_np = np.asarray(all_latents)
    latents_np = latents_np.reshape(latents_np.shape[0], -1)

    pca = PCA(n_components=2)
    latents_pca = pca.fit_transform(latents_np)

    mask = np.zeros(len(latents_np), dtype=bool)
    # Guard against None: ``mask[None] = True`` would broadcast over a new
    # axis and mark *every* sample as top-k.
    if topk_indices is not None:
        mask[topk_indices] = True
    main_points = latents_pca[~mask]
    highlight_points = latents_pca[mask]

    fig = plt.figure(figsize=(10, 8))
    try:
        plt.scatter(main_points[:, 0], main_points[:, 1],
                    c='royalblue', alpha=0.6,
                    label='General Samples', s=40)
        plt.scatter(highlight_points[:, 0], highlight_points[:, 1],
                    c='crimson', alpha=0.8,
                    label='TopK Samples', s=80, edgecolor='gold', linewidth=0.5)
        plt.title('Latent Space Visualization with TopK Highlight', fontsize=14)
        plt.xlabel('Principal Component 1', fontsize=12)
        plt.ylabel('Principal Component 2', fontsize=12)
        plt.grid(alpha=0.2)
        plt.legend()
        plt.savefig(os.path.join(save_dir, "latent_space_pca_highlight.png"),
                    dpi=150, bbox_inches='tight')
    finally:
        # Release the figure even if plotting or saving fails.
        plt.close(fig)

def visualize_latent_space_with_umap(all_latents, all_labels, save_dir,
                                     topk_indices=None, highlight_label=None, num="00"):
    """Save a label-coloured 2-D UMAP plot of the latents.

    Samples are flattened, reduced with PCA (at most 50 components, capped by
    the number of samples and features) and then embedded in 2-D with UMAP.
    The figure is written to ``<save_dir>/latent_space_umap_<num>.png``.

    Args:
        all_latents: Per-sample latents whose first axis indexes samples.
            NumPy array, ``torch.Tensor`` or anything ``np.asarray`` accepts.
        all_labels: One label per sample, in the same order as
            ``all_latents``.
        save_dir: Output directory; created if it does not exist.
        topk_indices: Optional indices (or boolean mask) of samples that are
            drawn larger with a red edge instead of as regular points.
        highlight_label: Optional label whose samples are all drawn as
            purple stars instead of in the per-label colour.
        num: Suffix used in the output file name.
    """
    os.makedirs(save_dir, exist_ok=True)

    # Accept tensors as well as arrays and flatten with one reshape rather
    # than a per-sample numpy -> torch -> numpy round trip.
    if isinstance(all_latents, torch.Tensor):
        all_latents = all_latents.detach().cpu().numpy()
    latents_np = np.asarray(all_latents)
    latents_np = latents_np.reshape(latents_np.shape[0], -1)
    if isinstance(all_labels, torch.Tensor):
        all_labels = all_labels.detach().cpu().numpy()
    all_labels = np.asarray(all_labels)

    # Dimensionality reduction: PCA cannot use more components than
    # min(n_samples, n_features).
    n_components = min(50, *latents_np.shape)
    latents_pca = PCA(n_components=n_components).fit_transform(latents_np)
    # n_jobs=1: umap-learn forces single-threaded execution whenever
    # random_state is set and warns if another value is requested.
    reducer = umap.UMAP(n_components=2, random_state=42, n_neighbors=15,
                        min_dist=0.1, n_jobs=1)
    latents_umap = reducer.fit_transform(latents_pca)

    # One colour per label; tab20 once tab10 would run out of colours.
    unique_labels = np.unique(all_labels)
    cmap = plt.cm.tab20 if len(unique_labels) > 10 else plt.cm.tab10
    colors = cmap(np.linspace(0, 1, len(unique_labels)))
    label_color_map = {label: colors[i] for i, label in enumerate(unique_labels)}

    topk_mask = np.zeros(len(all_labels), dtype=bool)
    if topk_indices is not None:
        topk_mask[topk_indices] = True

    fig = plt.figure(figsize=(12, 10))
    try:
        # Regular (non top-k) samples, one scatter per label.
        for label in unique_labels:
            if highlight_label is not None and label == highlight_label:
                continue
            main_points = latents_umap[(all_labels == label) & ~topk_mask]
            if len(main_points) == 0:
                # Skip empty groups so they do not add empty legend entries.
                continue
            plt.scatter(main_points[:, 0], main_points[:, 1],
                        c=[label_color_map[label]], alpha=0.6,
                        label=f'Label {label}', s=10)

        # All samples of the highlighted label, top-k or not.
        if highlight_label is not None and highlight_label in unique_labels:
            highlight_points = latents_umap[all_labels == highlight_label]
            plt.scatter(highlight_points[:, 0], highlight_points[:, 1],
                        c='purple', alpha=0.8,
                        label=f'Generated (Label={highlight_label})', s=50,
                        marker='*', edgecolor='black', linewidth=1.5)

        # Top-k samples, coloured by label and outlined in red.
        if topk_indices is not None:
            topk_points = latents_umap[topk_mask]
            topk_labels = all_labels[topk_mask]
            for label in unique_labels:
                if highlight_label is not None and label == highlight_label:
                    continue
                points = topk_points[topk_labels == label]
                if len(points) == 0:
                    continue
                plt.scatter(points[:, 0], points[:, 1],
                            c=[label_color_map[label]], alpha=0.8,
                            label=f'TopK Label {label}', s=50,
                            edgecolor='red', linewidth=1.5, marker='o')

        plt.title('Latent Space Visualization with Labels and TopK Highlight', fontsize=14)
        plt.xlabel('UMAP Component 1', fontsize=12)
        plt.ylabel('UMAP Component 2', fontsize=12)
        plt.grid(alpha=0.2)
        # Anchor the legend's upper-left corner just right of the axes.
        # loc='best' with a point anchor lets matplotlib pick which corner is
        # anchored, so the legend could end up covering the plot.
        plt.legend(loc='upper left', bbox_to_anchor=(1.05, 1), fontsize=10)
        plt.tight_layout()
        plt.savefig(os.path.join(save_dir, f"latent_space_umap_{num}.png"),
                    dpi=150, bbox_inches='tight')
    finally:
        # Release the figure even if plotting or saving fails.
        plt.close(fig)

def visualize_latent_space_outlier(
    all_latents, all_labels, save_dir, num="00", topk_indices=None,
    highlight_label=None, remove_outliers=False,
    outlier_method='dbscan', **outlier_params
):
    """Save a label-coloured 2-D UMAP plot, optionally dropping outliers.

    Samples are flattened, reduced with PCA (at most 50 components) and
    embedded in 2-D with UMAP. When ``remove_outliers`` is set, an outlier
    detector is fitted on the 2-D embedding and the regular points it rejects
    are not drawn; top-k points are always drawn.

    The file name is ``latent_space`` plus ``_filtered`` when
    ``remove_outliers`` is set plus ``_highlight_<num>`` when
    ``highlight_label`` is given, with a ``.png`` extension.

    Args:
        all_latents: Per-sample latents whose first axis indexes samples
            (``torch.Tensor`` or array-like).
        all_labels: One label per sample (``torch.Tensor`` or array-like).
        save_dir: Output directory; created if it does not exist.
        num: Suffix used in the file name when ``highlight_label`` is given.
        topk_indices: Indices (or boolean mask) of samples to draw again with
            a red edge. ``None`` or an empty sequence draws none.
        highlight_label: Optional label whose regular samples are drawn as
            stars; top-k markers are not drawn for this label.
        remove_outliers: Whether to filter regular points with the detector.
        outlier_method: ``'dbscan'`` or ``'isolation_forest'``.
        **outlier_params: Keyword arguments forwarded to the detector.

    Raises:
        ValueError: If ``remove_outliers`` is set and ``outlier_method`` is
            not one of the supported names.
    """
    # Fail before the expensive embedding instead of hitting an unbound
    # ``inlier_mask`` afterwards.
    if remove_outliers and outlier_method not in ('dbscan', 'isolation_forest'):
        raise ValueError(
            f"Unknown outlier_method {outlier_method!r}; "
            "expected 'dbscan' or 'isolation_forest'"
        )

    def process_data(data):
        # detach() so tensors that require grad can be converted too.
        if isinstance(data, torch.Tensor):
            return data.detach().cpu().numpy()
        return np.asarray(data)

    all_latents = process_data(all_latents)
    all_labels = process_data(all_labels)
    if topk_indices is None:
        topk_indices = []
    topk_labels = all_labels[topk_indices]

    # Embed every sample exactly once. Top-k samples are a subset of
    # ``all_latents``, so their coordinates are looked up afterwards rather
    # than stacking duplicate rows into the PCA/UMAP fit.
    flat_latents = all_latents.reshape(all_latents.shape[0], -1)
    # PCA cannot use more components than min(n_samples, n_features).
    pca = PCA(n_components=min(50, *flat_latents.shape))
    reducer = umap.UMAP(n_components=2, random_state=42)
    all_umap = reducer.fit_transform(pca.fit_transform(flat_latents))
    # Taken before filtering so top-k points are never dropped as outliers.
    topk_umap = all_umap[topk_indices]

    if remove_outliers:
        if outlier_method == 'dbscan':
            # DBSCAN labels noise points with -1.
            inlier_mask = DBSCAN(**outlier_params).fit_predict(all_umap) != -1
        else:
            # IsolationForest returns 1 for inliers and -1 for outliers.
            inlier_mask = IsolationForest(**outlier_params).fit_predict(all_umap) == 1

        all_umap = all_umap[inlier_mask]
        all_labels = all_labels[inlier_mask]

    # One colour per label; tab20 once tab10 would run out of colours.
    unique_labels = np.unique(np.concatenate([all_labels, topk_labels]))
    cmap = plt.cm.tab20 if len(unique_labels) > 10 else plt.cm.tab10
    colors = cmap(np.linspace(0, 1, len(unique_labels)))
    label_colors = {label: colors[i] for i, label in enumerate(unique_labels)}

    fig = plt.figure(figsize=(14, 12))
    try:
        # Regular samples, one scatter per non-highlighted label.
        for label in unique_labels:
            if highlight_label is not None and label == highlight_label:
                continue
            mask = (all_labels == label)
            if np.sum(mask) > 0:
                plt.scatter(all_umap[mask, 0], all_umap[mask, 1],
                            c=[label_colors[label]], alpha=0.6, s=15,
                            edgecolor='white', linewidth=0.5,
                            label=f'Label {label}')

        # Samples of the highlighted label, drawn as stars above the rest.
        if highlight_label is not None:
            highlight_mask = (all_labels == highlight_label)
            if np.sum(highlight_mask) > 0:
                plt.scatter(all_umap[highlight_mask, 0], all_umap[highlight_mask, 1],
                            c=[label_colors[highlight_label]], marker='*', s=50,
                            edgecolor='black', linewidth=1.2,
                            zorder=3, label=f'Highlight ({highlight_label})')

        # Top-k samples, outlined in red; skipped for the highlighted label.
        for label in unique_labels:
            if highlight_label is not None and label == highlight_label:
                continue
            mask = (topk_labels == label)
            if np.sum(mask) == 0:
                continue
            plt.scatter(topk_umap[mask, 0], topk_umap[mask, 1],
                        c=[label_colors[label]], s=40,
                        edgecolor='red', linewidth=1.5,
                        zorder=2, label=f'TopK ({label})')

        title = "Latent Space Visualization"
        if remove_outliers:
            title += " (Filtered)"
        plt.title(title, fontsize=14)
        plt.xlabel('UMAP-1', fontsize=12)
        plt.ylabel('UMAP-2', fontsize=12)
        plt.grid(alpha=0.3)

        # Keep the first handle for every legend text.
        handles, labels = plt.gca().get_legend_handles_labels()
        unique_legend = {}
        for handle, text in zip(handles, labels):
            unique_legend.setdefault(text, handle)
        plt.legend(list(unique_legend.values()), list(unique_legend.keys()),
                   loc='upper left', bbox_to_anchor=(1, 1),
                   fontsize=8, ncol=2 if len(unique_legend) > 15 else 1)

        filename = "latent_space"
        if remove_outliers:
            filename += "_filtered"
        if highlight_label is not None:
            filename += f"_highlight_{num}"
        os.makedirs(save_dir, exist_ok=True)
        plt.savefig(os.path.join(save_dir, f"{filename}.png"),
                    dpi=200, bbox_inches='tight')
    finally:
        # Release the figure even if plotting or saving fails.
        plt.close(fig)