import numpy as np
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
from sklearn.manifold import TSNE
import seaborn as sns
import random

class SemanticCorrectionVisualizer:
    """PureZero语义校正可视化器 - 优化版"""

    def __init__(self, trainer, figsize=(18, 14), dpi=300):
        """Keep a handle on the trainer and set up the static plot styling.

        Args:
            trainer: Object exposing a ``device`` attribute; stored as-is so
                the data-collection methods can reach its model and learner.
            figsize: Figure size used when the plot is created.
            dpi: Resolution used for figure creation and for saving.
        """
        self.trainer, self.device = trainer, trainer.device
        self.figsize, self.dpi = figsize, dpi

        # High-contrast palette: blue family for seen, orange family for unseen.
        self.colors = dict(
            seen_base='#1E88E5',          # vivid blue
            unseen_base='#FF6D00',        # vivid orange (close to red)
            seen_samples='#64B5F6',       # light blue
            unseen_samples='#FFB74D',     # light orange
            seen_proto_orig='#0D47A1',    # dark blue outline
            unseen_proto_orig='#BF360C',  # dark red/orange outline
            seen_centroid='#42A5F5',      # medium blue
            unseen_centroid='#FF9800',    # medium orange
            seen_proto_corr='#01579B',    # very dark blue
            unseen_proto_corr='#E65100',  # very dark orange
        )

        # One marker shape per kind of point.
        self.markers = dict(
            samples='o',       # filled dot
            proto_orig='o',    # drawn hollow (white face, coloured edge)
            centroid='*',      # star
            proto_corr='D',    # diamond
        )

        # Marker areas: samples slightly enlarged, prototypes kept prominent.
        self.sizes = dict(
            samples=45,
            proto_orig=220,
            centroid=380,
            proto_corr=260,
        )

    def collect_data_for_tsne(self, all_test_embeds, all_test_labels,
                                att_original, att_corrected,
                                n_samples_per_class=20,     # [修改] 增加每个类的样本数
                                max_classes_to_plot=25):    # [重要修改] 增加类别数到25，找回聚类结构
        """
        收集用于t-SNE可视化的数据
        """
        print("=" * 80)
        print("Collecting t-SNE visualization data...")

        # 0. 筛选类别
        # 增加到 25 个已见类 + 25 个未见类 = 50 个类，这能形成很好的 Cluster 结构
        n_seen = len(self.trainer.seenclasses)
        n_sel_seen = min(n_seen, max_classes_to_plot)
        perm_seen = torch.randperm(n_seen)[:n_sel_seen]
        selected_seen_classes = self.trainer.seenclasses[perm_seen].cpu()

        n_unseen = len(self.trainer.unseenclasses)
        n_sel_unseen = min(n_unseen, max_classes_to_plot)
        perm_unseen = torch.randperm(n_unseen)[:n_sel_unseen]
        selected_unseen_classes = self.trainer.unseenclasses[perm_unseen].cpu()

        print(f"选取可视化类别: {n_sel_seen} Seen + {n_sel_unseen} Unseen (Total {n_sel_seen + n_sel_unseen})")

        # L2 normalize
        all_test_embeds = F.normalize(all_test_embeds, dim=1)
        att_original = F.normalize(att_original, dim=1)
        att_corrected = F.normalize(att_corrected, dim=1)

        # 1. Sample test data
        seen_samples, seen_labels, unseen_samples, unseen_labels = \
            self._sample_test_data(all_test_embeds, all_test_labels, n_samples_per_class,
                                 selected_seen_classes, selected_unseen_classes)

        # 2. Extract class prototypes
        seen_proto_orig = att_original[selected_seen_classes].cpu()
        unseen_proto_orig = att_original[selected_unseen_classes].cpu()
        seen_proto_corr = att_corrected[selected_seen_classes].cpu()
        unseen_proto_corr = att_corrected[selected_unseen_classes].cpu()

        # 3. Compute pseudo-centroids
        seen_centroids, seen_centroid_classes, unseen_centroids, unseen_centroid_classes = \
            self._compute_pseudo_centroids(all_test_embeds, selected_seen_classes, selected_unseen_classes)

        data_dict = {
            'seen_samples': seen_samples,
            'seen_sample_labels': seen_labels,
            'unseen_samples': unseen_samples,
            'unseen_sample_labels': unseen_labels,
            'seen_proto_orig': seen_proto_orig,
            'unseen_proto_orig': unseen_proto_orig,
            'seen_proto_corr': seen_proto_corr,
            'unseen_proto_corr': unseen_proto_corr,
            'seen_centroids': seen_centroids,
            'seen_centroid_classes': seen_centroid_classes,
            'unseen_centroids': unseen_centroids,
            'unseen_centroid_classes': unseen_centroid_classes,
        }

        return data_dict

    def _sample_test_data(self, all_test_embeds, all_test_labels, n_samples_per_class,
                          selected_seen_classes, selected_unseen_classes):
        """采样逻辑保持不变，确保取到的是tensor"""
        seen_samples, seen_labels = [], []
        unseen_samples, unseen_labels = [], []

        all_test_embeds = all_test_embeds.cpu()
        all_test_labels = all_test_labels.cpu()

        def get_samples(classes, container_samples, container_labels):
            for cls in classes:
                cls_mask = (all_test_labels == cls)
                cls_indices = torch.nonzero(cls_mask).squeeze()
                if cls_indices.dim() == 0: cls_indices = cls_indices.unsqueeze(0)

                if len(cls_indices) > 0:
                    if len(cls_indices) > n_samples_per_class:
                        sampled_indices = cls_indices[torch.randperm(len(cls_indices))[:n_samples_per_class]]
                    else:
                        sampled_indices = cls_indices
                    container_samples.append(all_test_embeds[sampled_indices])
                    container_labels.extend([cls.item()] * len(sampled_indices))

        get_samples(selected_seen_classes, seen_samples, seen_labels)
        get_samples(selected_unseen_classes, unseen_samples, unseen_labels)

        seen_samples = torch.cat(seen_samples, dim=0) if seen_samples else torch.empty(0, all_test_embeds.size(1))
        unseen_samples = torch.cat(unseen_samples, dim=0) if unseen_samples else torch.empty(0, all_test_embeds.size(1))

        return seen_samples, seen_labels, unseen_samples, unseen_labels

    def _compute_pseudo_centroids(self, all_test_embeds, selected_seen_classes, selected_unseen_classes):
        """计算质心逻辑"""
        pu_scores = self.trainer.pu_learner.predict(all_test_embeds.to(self.device))

        # 严格阈值
        high_conf_seen_threshold = -0.57
        high_conf_unseen_threshold = -1.17

        high_conf_seen_mask = (pu_scores > high_conf_seen_threshold).flatten()
        high_conf_unseen_mask = (pu_scores < high_conf_unseen_threshold).flatten()

        def get_centroids(mask, target_classes_set):
            centroids = []
            centroid_classes = []
            if mask.sum() == 0: return centroids, centroid_classes

            indices = torch.nonzero(mask).squeeze()
            class_embeddings = {cls.item(): [] for cls in target_classes_set}

            # 分批处理以防OOM，这里简单处理
            # 实际上可以直接用 mask & label match
            # 为了更精确，这里模拟原文逻辑：
            # 1. 拿到高置信度样本
            # 2. 预测它们的伪标签
            # 3. 如果伪标签属于我们要画的类，就加入计算

            batch_embeds = all_test_embeds[indices].to(self.device)
            # 预测
            sim = torch.mm(batch_embeds, self.trainer.model.att.t())
            preds = torch.argmax(sim, dim=1)

            preds_cpu = preds.cpu()
            batch_embeds_cpu = batch_embeds.cpu()

            target_list = [c.item() for c in target_classes_set]

            for k, pred_cls in enumerate(preds_cpu):
                pred_item = pred_cls.item()
                if pred_item in class_embeddings:
                    class_embeddings[pred_item].append(batch_embeds_cpu[k].numpy())

            for cls in target_list:
                if len(class_embeddings[cls]) >= 2:
                    embeds = np.array(class_embeddings[cls])
                    cent = np.mean(embeds, axis=0)
                    cent = cent / np.linalg.norm(cent)
                    centroids.append(cent)
                    centroid_classes.append(cls)
            return centroids, centroid_classes

        seen_centroids_list, seen_classes_list = get_centroids(high_conf_seen_mask, selected_seen_classes)
        unseen_centroids_list, unseen_classes_list = get_centroids(high_conf_unseen_mask, selected_unseen_classes)

        seen_centroids = torch.tensor(np.array(seen_centroids_list), dtype=torch.float32) if seen_centroids_list else torch.empty(0, all_test_embeds.size(1))
        unseen_centroids = torch.tensor(np.array(unseen_centroids_list), dtype=torch.float32) if unseen_centroids_list else torch.empty(0, all_test_embeds.size(1))

        return seen_centroids, seen_classes_list, unseen_centroids, unseen_classes_list

    def perform_tsne(self, data_dict, perplexity=30):
        print("\nPerforming t-SNE...")

        all_vectors = []
        segment_info = []
        current_idx = 0

        def add_segment(key, type_name):
            nonlocal current_idx
            if key in data_dict and len(data_dict[key]) > 0:
                vecs = data_dict[key].numpy()
                all_vectors.append(vecs)
                segment_info.append({'type': type_name, 'start': current_idx, 'end': current_idx + len(vecs)})
                current_idx += len(vecs)

        # 顺序很重要：样本 -> 原型 -> 质心 -> 校正后原型
        add_segment('seen_samples', 'seen_samples')
        add_segment('unseen_samples', 'unseen_samples')
        add_segment('seen_proto_orig', 'seen_proto_orig')
        add_segment('unseen_proto_orig', 'unseen_proto_orig')
        add_segment('seen_centroids', 'seen_centroids')
        add_segment('unseen_centroids', 'unseen_centroids')
        add_segment('seen_proto_corr', 'seen_proto_corr')
        add_segment('unseen_proto_corr', 'unseen_proto_corr')

        X = np.vstack(all_vectors)
        print(f"Total points: {len(X)}")

        # [重要] 调整 perplexity
        # 样本变多了，perplexity可以稍微大一点，有助于形成紧致的簇
        # 如果 perplexity 太大，所有点会混在一起；太小，簇会太散
        adj_perplexity = min(40, len(X) // 10)
        print(f"Using perplexity: {adj_perplexity}")

        tsne = TSNE(n_components=2, perplexity=adj_perplexity, learning_rate=200,
                    n_iter=2000, random_state=42, init='pca', verbose=1)
        X_embedded = tsne.fit_transform(X)

        results = {}
        for seg in segment_info:
            results[seg['type']] = X_embedded[seg['start']:seg['end']]

        return results

    def visualize(self, tsne_results, save_path='tsne_semantic_correction.png',
                  title='Semantic Correction Visualization (TransZero + CUB)'):
        """Render the t-SNE scatter plot and save it to ``save_path``.

        Layers are drawn in this order: sample points, original prototypes
        (hollow circles), pseudo-centroids (stars), corrected prototypes
        (diamonds), then an arrow from each original prototype to its
        corrected position when the two are more than 0.8 apart.

        The seaborn "white" style and the serif font settings are applied
        through context managers before the figure is created. Every call
        therefore renders the same way, and the process-wide matplotlib
        settings are restored afterwards. The font settings are layered on
        top of the seaborn style because that style carries its own
        ``font.family`` entry.

        Args:
            tsne_results: Mapping from group name to ``(n, 2)`` coordinates,
                as returned by ``perform_tsne``. Missing groups are skipped.
            save_path: Output image path.
            title: Plot title.
        """
        print("\nGenerating visualization...")

        font_rc = {
            'font.family': 'serif',
            'font.serif': ['Times New Roman'],
            'font.size': 16,
        }

        def scatter_layers(cat):
            # (result key, scatter keyword arguments), bottom layer first.
            return [
                # Samples: alpha 0.6 so clusters read clearly.
                (f'{cat}_samples', dict(
                    c=self.colors[f'{cat}_samples'], marker=self.markers['samples'],
                    s=self.sizes['samples'], alpha=0.6, edgecolors='none', rasterized=True)),
                # Original prototypes: hollow circles, above the samples.
                (f'{cat}_proto_orig', dict(
                    c='white', marker=self.markers['proto_orig'], s=self.sizes['proto_orig'],
                    edgecolors=self.colors[f'{cat}_proto_orig'], linewidths=3, zorder=10)),
                # Pseudo-centroids: stars showing where prototypes are pulled.
                (f'{cat}_centroids', dict(
                    c=self.colors[f'{cat}_centroid'], marker=self.markers['centroid'],
                    s=self.sizes['centroid'], edgecolors='white', linewidths=1,
                    zorder=11, alpha=0.8)),
                # Corrected prototypes: filled diamonds on top.
                (f'{cat}_proto_corr', dict(
                    c=self.colors[f'{cat}_proto_corr'], marker=self.markers['proto_corr'],
                    s=self.sizes['proto_corr'], edgecolors='white', linewidths=2, zorder=12)),
            ]

        with sns.axes_style("white"), plt.rc_context(rc=font_rc):
            fig, ax = plt.subplots(figsize=self.figsize, dpi=self.dpi)
            try:
                ax.set_facecolor('white')
                for spine in ax.spines.values():
                    spine.set_linewidth(1.5)
                    spine.set_color('#333333')

                # Within each layer seen is drawn before unseen.
                for layer_pair in zip(scatter_layers('seen'), scatter_layers('unseen')):
                    for key, style in layer_pair:
                        if key not in tsne_results:
                            continue
                        points = tsne_results[key]
                        ax.scatter(points[:, 0], points[:, 1], **style)

                # Correction arrows, kept bold so they stay visible.
                arrow_params = dict(
                    arrowstyle='-|>',
                    mutation_scale=30,
                    lw=3.0,
                    alpha=0.85,
                    shrinkA=0,
                    shrinkB=0,
                )
                for cat in ('seen', 'unseen'):
                    orig_key, corr_key = f'{cat}_proto_orig', f'{cat}_proto_corr'
                    if orig_key not in tsne_results or corr_key not in tsne_results:
                        continue
                    color = self.colors[orig_key]
                    for start, end in zip(tsne_results[orig_key], tsne_results[corr_key]):
                        # Skip near-stationary prototypes to avoid clutter.
                        if np.linalg.norm(start - end) > 0.8:
                            ax.annotate('', xy=end, xytext=start,
                                        arrowprops=dict(color=color, **arrow_params), zorder=9)

                legend_elements = [
                    Line2D([0], [0], marker='o', color='w', label='Seen Classes',
                           markerfacecolor=self.colors['seen_base'], markersize=14),
                    Line2D([0], [0], marker='o', color='w', label='Unseen Classes',
                           markerfacecolor=self.colors['unseen_base'], markersize=14),
                    Line2D([0], [0], color='#333333', lw=3, label='Semantic Correction',
                           marker='>', markersize=12, markerfacecolor='#333333'),
                ]
                leg = ax.legend(handles=legend_elements, loc='upper right',
                                fontsize=14, framealpha=0.95, edgecolor='#CCCCCC')
                leg.get_frame().set_linewidth(1.0)

                ax.set_xlabel('t-SNE Dimension 1', fontsize=18, labelpad=12)
                ax.set_ylabel('t-SNE Dimension 2', fontsize=18, labelpad=12)
                ax.set_title(title, fontsize=20, fontweight='bold', pad=25)

                fig.tight_layout()
                fig.savefig(save_path, dpi=self.dpi, bbox_inches='tight')
            finally:
                # Close this figure specifically, even if drawing or saving
                # failed, so repeated calls do not accumulate open figures.
                plt.close(fig)

        print(f"Saved optimized visualization to {save_path}")

def run_tsne_visualization_experiment(trainer, save_dir='./Result/tsne_CUB', n_samples_per_class=20):
    """Run the full pipeline: extract features, correct, embed, plot.

    Args:
        trainer: Object providing ``extract_features``, ``dataloader.data``,
            ``pu_learner.predict``, ``correct_att``, ``model.att`` and
            ``device``.
        save_dir: Directory for the output image; created if missing.
        n_samples_per_class: Per-class sample cap forwarded to the
            visualizer.

    Returns:
        ``(visualizer, data_dict, tsne_results)``.
    """
    import os

    os.makedirs(save_dir, exist_ok=True)

    banner = "=" * 80
    print("\n" + banner)
    print("Running Optimized t-SNE Visualization...")
    print(banner)

    # Test-time embeddings and labels, seen split first then unseen.
    _, seen_embeds, unseen_embeds, all_test_logits = trainer.extract_features()
    all_test_embeds = torch.cat([seen_embeds, unseen_embeds], dim=0)
    label_parts = [
        trainer.dataloader.data[split]['labels']
        for split in ('test_seen', 'test_unseen')
    ]
    all_test_labels = torch.cat(label_parts, dim=0)

    # Corrected prototypes first, then a snapshot of the model's attributes.
    pu_scores = trainer.pu_learner.predict(all_test_embeds.to(trainer.device))
    att_corrected = trainer.correct_att(pu_scores, all_test_embeds, all_test_logits)
    att_original = trainer.model.att.clone()

    # 25 classes per group gives enough crowding to show cluster structure.
    visualizer = SemanticCorrectionVisualizer(trainer)
    data_dict = visualizer.collect_data_for_tsne(
        all_test_embeds,
        all_test_labels,
        att_original,
        att_corrected,
        n_samples_per_class=n_samples_per_class,
        max_classes_to_plot=25,
    )
    tsne_results = visualizer.perform_tsne(data_dict)

    visualizer.visualize(
        tsne_results,
        save_path=os.path.join(save_dir, 'tsne_visualization_optimized.png'),
    )

    print("Visualization experiment finished.")
    return visualizer, data_dict, tsne_results