import torch
import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
import seaborn as sns


class Visualization:
    """Plotting helpers for embeddings, prototypes, losses and environments.

    Every ``plot_*`` method builds a new matplotlib figure, optionally saves
    it to ``save_path`` (300 dpi, tight bounding box) and then calls
    ``plt.show()``. None of the methods return a value.
    """

    def __init__(self, color_palette='tab10'):
        """Store the class colormap name and apply seaborn's whitegrid style.

        Args:
            color_palette: Matplotlib colormap name used when coloring points
                by class.
        """
        self.color_palette = color_palette
        sns.set_style("whitegrid")

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _to_numpy(values):
        """Convert a tensor / array-like to a NumPy array (``None`` passes).

        Tensors are detached first so tensors that still require grad can be
        plotted without raising.
        """
        if values is None:
            return None
        if isinstance(values, torch.Tensor):
            return values.detach().cpu().numpy()
        return np.asarray(values)

    @staticmethod
    def _concat_embeddings(embs_by_env):
        """Concatenate the per-environment tensors of a dict along dim 0.

        Rows follow the dict's iteration order.
        """
        stacked = torch.cat(list(embs_by_env.values()), dim=0)
        return stacked.detach().cpu().numpy()

    @staticmethod
    def _safe_perplexity(n_samples, requested=30):
        """Return a t-SNE perplexity that is strictly below ``n_samples``.

        scikit-learn rejects ``perplexity >= n_samples``, so small batches
        need the requested value lowered.

        Raises:
            ValueError: If fewer than two samples are available.
        """
        if n_samples < 2:
            raise ValueError(
                f't-SNE needs at least 2 samples, got {n_samples}')
        return min(requested, n_samples - 1)

    @staticmethod
    def _color_codes(values):
        """Return values usable as scatter ``c=``.

        Numeric/boolean input is returned unchanged (as an array); anything
        else (e.g. string ids) is mapped to integer codes in sorted order.
        """
        arr = np.asarray(values)
        if arr.dtype.kind in 'biuf':
            return arr
        _, codes = np.unique(arr, return_inverse=True)
        return codes.reshape(-1)

    @staticmethod
    def _require_length(name, values, expected):
        """Raise ValueError when ``values`` does not have ``expected`` rows."""
        if len(values) != expected:
            raise ValueError(
                f'{name} has {len(values)} entries but {expected} '
                'embedding rows were provided')

    @classmethod
    def _tsne_2d(cls, data, perplexity=30):
        """Project ``data`` to 2-D with t-SNE (fixed ``random_state=42``)."""
        data = np.asarray(data)
        reducer = TSNE(
            n_components=2,
            perplexity=cls._safe_perplexity(len(data), perplexity),
            random_state=42,
        )
        return reducer.fit_transform(data)

    @classmethod
    def _scatter_panel(cls, ax, points, colors, cmap, title,
                       cbar_label=None, size=30):
        """Draw one t-SNE scatter panel on ``ax``.

        ``cmap`` and the colorbar are only applied when ``colors`` is given;
        without color data a colormap/colorbar would be meaningless.
        """
        scatter_kwargs = {'alpha': 0.6, 's': size}
        if colors is not None:
            scatter_kwargs['c'] = cls._color_codes(colors)
            scatter_kwargs['cmap'] = cmap
        scatter = ax.scatter(points[:, 0], points[:, 1], **scatter_kwargs)
        ax.set_title(title)
        ax.set_xlabel('t-SNE 1')
        ax.set_ylabel('t-SNE 2')
        if colors is not None and cbar_label is not None:
            ax.figure.colorbar(scatter, ax=ax, label=cbar_label)
        return scatter

    @staticmethod
    def _finalize(fig, save_path):
        """Apply tight layout, save if a path was given, then show."""
        fig.tight_layout()
        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.show()

    # ------------------------------------------------------------------
    # Public plotting API
    # ------------------------------------------------------------------

    def plot_embeddings_tsne(self, embeddings, labels=None, env_labels=None,
                             title='t-SNE Visualization', save_path=None):
        """Plot a 2-D t-SNE projection of ``embeddings``.

        With ``env_labels`` two panels are drawn (colored by class and by
        environment); otherwise a single panel colored by class.

        Args:
            embeddings: Tensor or array-like passed to t-SNE; presumably
                shaped (n_samples, n_features).
            labels: Optional per-sample class labels. When omitted the class
                panel is drawn uncolored and without a colorbar.
            env_labels: Optional per-sample environment labels.
            title: Figure title (prefix of the panel titles in two-panel
                mode).
            save_path: Optional output path for the figure.

        Raises:
            ValueError: If fewer than two samples are given.
        """
        embeddings = self._to_numpy(embeddings)
        labels = self._to_numpy(labels)
        env_labels = self._to_numpy(env_labels)

        # Dimensionality reduction (perplexity is lowered for tiny inputs).
        embeddings_2d = self._tsne_2d(embeddings, perplexity=30)

        if env_labels is not None:
            panels = [
                (labels, self.color_palette, f'{title} - By Class', 'Class'),
                (env_labels, 'Set2', f'{title} - By Environment',
                 'Environment'),
            ]
        else:
            panels = [(labels, self.color_palette, title, 'Class')]

        # squeeze=False keeps ``axes`` 2-D regardless of the panel count.
        fig, axes = plt.subplots(1, len(panels), figsize=(12, 5),
                                 squeeze=False)
        for ax, (colors, cmap, panel_title, cbar_label) in zip(axes[0],
                                                               panels):
            self._scatter_panel(ax, embeddings_2d, colors, cmap, panel_title,
                                cbar_label=cbar_label, size=30)

        self._finalize(fig, save_path)

    def plot_invariant_variant_comparison(self, invariant_embs, variant_embs,
                                          labels, env_labels, save_path=None):
        """Compare invariant and variant representations in a 2x2 t-SNE grid.

        Rows are invariant / variant representations; columns are colored by
        class / by environment.

        Args:
            invariant_embs: Dict mapping environment id to a tensor of
                invariant embeddings.
            variant_embs: Dict mapping environment id to a tensor of variant
                embeddings; must yield the same total number of rows as
                ``invariant_embs``.
            labels: Class labels for one environment. They are tiled once per
                environment, so every environment is assumed to hold the same
                samples in the same order.
            env_labels: Unused; kept for interface compatibility. Environment
                ids are taken from the keys of ``invariant_embs``.
            save_path: Optional output path for the figure.

        Raises:
            ValueError: If label or embedding counts are inconsistent, or if
                fewer than two samples are given.
        """
        inv_all = self._concat_embeddings(invariant_embs)
        var_all = self._concat_embeddings(variant_embs)

        # Tile (not element-wise repeat) so labels line up with the
        # environment-major concatenation above.
        labels_all = np.tile(self._to_numpy(labels), len(invariant_embs))
        env_ids_all = np.array([
            env_id
            for env_id, embs in invariant_embs.items()
            for _ in range(len(embs))
        ])

        self._require_length('labels (tiled per environment)', labels_all,
                             len(inv_all))
        self._require_length('variant_embs', var_all, len(inv_all))

        inv_2d = self._tsne_2d(inv_all)
        var_2d = self._tsne_2d(var_all)

        fig, axes = plt.subplots(2, 2, figsize=(12, 10))
        panels = (
            (axes[0, 0], inv_2d, labels_all, self.color_palette,
             'Invariant Representations (by Class)'),
            (axes[0, 1], inv_2d, env_ids_all, 'Set2',
             'Invariant Representations (by Environment)'),
            (axes[1, 0], var_2d, labels_all, self.color_palette,
             'Variant Representations (by Class)'),
            (axes[1, 1], var_2d, env_ids_all, 'Set2',
             'Variant Representations (by Environment)'),
        )
        for ax, points, colors, cmap, panel_title in panels:
            self._scatter_panel(ax, points, colors, cmap, panel_title,
                                size=20)

        self._finalize(fig, save_path)

    def plot_prototype_alignment(self, invariant_embs, prototypes, labels,
                                 save_path=None):
        """Plot representations and class prototypes in a shared PCA space.

        Args:
            invariant_embs: Dict mapping environment id to a tensor of
                invariant embeddings.
            prototypes: Dict mapping class id to a dict; entries that contain
                a ``'mean'`` vector are drawn as stars, others are skipped.
            labels: Class labels for one environment (tiled once per
                environment, as the embeddings are concatenated).
            save_path: Optional output path for the figure.
        """
        # Local import: only this method needs it.
        from matplotlib.collections import LineCollection

        fig, ax = plt.subplots(figsize=(10, 8))

        # Map all representations to 2D
        all_embs = self._concat_embeddings(invariant_embs)
        all_labels = np.tile(self._to_numpy(labels), len(invariant_embs))

        pca = PCA(n_components=2)
        embs_2d = pca.fit_transform(all_embs)

        rep_scatter = ax.scatter(embs_2d[:, 0], embs_2d[:, 1],
                                 c=all_labels, cmap=self.color_palette,
                                 alpha=0.4, s=20, label='Representations')

        # Project every prototype that has a mean into the same PCA space.
        prototype_points = []
        prototype_labels = []
        for class_id, proto in prototypes.items():
            if 'mean' in proto:
                mean_vec = self._to_numpy(proto['mean']).reshape(1, -1)
                prototype_points.append(pca.transform(mean_vec)[0])
                prototype_labels.append(class_id)

        if prototype_points:
            proto_points = np.array(prototype_points)
            # Share the normalization with the representation scatter so a
            # class gets the same color in both layers.
            ax.scatter(proto_points[:, 0], proto_points[:, 1],
                       c=prototype_labels, cmap=self.color_palette,
                       norm=rep_scatter.norm,
                       s=200, marker='*', edgecolors='black',
                       linewidths=2, label='Prototypes')

            # Connect each point to its class prototype. Only prototypes that
            # were actually drawn are eligible, so classes whose entry lacks
            # 'mean' are skipped instead of raising.
            proto_row = {cid: row for row, cid in enumerate(prototype_labels)}
            segments = [
                (embs_2d[i], proto_points[proto_row[cid]])
                for i, cid in enumerate(all_labels.tolist())
                if cid in proto_row
            ]
            if segments:
                # One collection instead of one Line2D artist per point.
                ax.add_collection(LineCollection(
                    segments, colors='gray', alpha=0.1, linewidths=0.5))

        ax.set_title('Prototype Alignment Visualization')
        ax.set_xlabel('PCA 1')
        ax.set_ylabel('PCA 2')
        ax.legend()

        self._finalize(fig, save_path)

    def plot_loss_curves(self, loss_history, save_path=None):
        """Plot the recorded loss series in a 2x2 grid.

        Args:
            loss_history: Mapping that may contain the keys ``'total_loss'``,
                ``'classification_loss'``, ``'alignment_loss'`` and
                ``'disentangle_loss'``, each a sequence of values. Panels for
                missing keys are left empty.
            save_path: Optional output path for the figure.
        """
        panels = (
            ('total_loss', 'Total Loss'),
            ('classification_loss', 'Classification Loss'),
            ('alignment_loss', 'Alignment Loss'),
            ('disentangle_loss', 'Disentanglement Loss'),
        )

        fig, axes = plt.subplots(2, 2, figsize=(12, 8))

        # axes.flat walks the grid row by row, matching the panel order.
        for ax, (key, panel_title) in zip(axes.flat, panels):
            if key not in loss_history:
                continue
            ax.plot(loss_history[key])
            ax.set_title(panel_title)
            ax.set_xlabel('Step')
            ax.set_ylabel('Loss')
            ax.grid(True, alpha=0.3)

        self._finalize(fig, save_path)

    def plot_environment_distributions(self, env_embeddings, save_path=None):
        """Plot the first two embedding dimensions of every environment.

        Each panel shows the samples, their mean (red cross) and three
        covariance ellipses whose semi-axes are 1, 2 and 3 standard
        deviations along the principal directions.

        Args:
            env_embeddings: Non-empty dict mapping environment id to a tensor
                or array of embeddings. Embeddings with a single feature are
                plotted against zero and get no ellipse.
            save_path: Optional output path for the figure.

        Raises:
            ValueError: If ``env_embeddings`` is empty.
        """
        # Local import: only this method needs it.
        from matplotlib.patches import Ellipse

        num_envs = len(env_embeddings)
        if num_envs == 0:
            raise ValueError(
                'env_embeddings must contain at least one environment')

        fig, axes = plt.subplots(1, num_envs, figsize=(5 * num_envs, 4),
                                 squeeze=False)

        for ax, (env_id, embs) in zip(axes[0], env_embeddings.items()):
            embs_np = self._to_numpy(embs)
            if embs_np.ndim == 1:
                embs_np = embs_np.reshape(-1, 1)

            # Only show first two dimensions
            has_second_dim = embs_np.shape[1] >= 2
            xs = embs_np[:, 0]
            ys = embs_np[:, 1] if has_second_dim else np.zeros_like(xs)

            ax.scatter(xs, ys, alpha=0.5, s=10)

            if len(xs) > 0:
                center = (xs.mean(), ys.mean())
                ax.scatter(center[0], center[1], color='red', s=100,
                           marker='x', linewidths=2)

            # A covariance needs two real dimensions and >= 2 samples.
            if has_second_dim and len(xs) >= 2:
                cov = np.cov(np.stack([xs, ys]))
                eigvals, eigvecs = np.linalg.eigh(cov)
                # Round-off can make a zero eigenvalue slightly negative.
                eigvals = np.clip(eigvals, 0.0, None)
                # Orientation of the first eigenvector; ``width`` below is
                # measured along that same direction.
                angle = np.degrees(np.arctan2(eigvecs[1, 0], eigvecs[0, 0]))
                full_axes = 2 * np.sqrt(eigvals)

                for n_std in range(1, 4):
                    ax.add_patch(Ellipse(
                        xy=center,
                        width=n_std * full_axes[0],
                        height=n_std * full_axes[1],
                        angle=angle, alpha=0.2 / n_std, color='red'))

            ax.set_title(f'Environment {env_id}')
            ax.set_xlabel('Dimension 1')
            ax.set_ylabel('Dimension 2')
            ax.grid(True, alpha=0.3)

        self._finalize(fig, save_path)

