# file: user_extensions/baselines/fader_networks/visualizations.py
import matplotlib.pyplot as plt
import numpy as np
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
import torch

from prism.core.registry import VISUALIZATIONS
from prism.evaluation.visualization import BaseVisualization, DPI, _visualization_skipper, _plot_image


@VISUALIZATIONS.register("BaselineClustering")
class BaselineClusteringVisualization(BaseVisualization):
    """2-D PCA / t-SNE scatter plot of the full latent space ``z``, coloured by class label.

    The plot title states that no clustering by class is expected in ``z``.
    """

    @_visualization_skipper
    def run(self, trainer, pl_module, plot_dir, epoch, z_full, y_targets, **kwargs):
        """Project ``z_full`` to two dimensions and save a scatter plot coloured by ``y_targets``.

        Args:
            trainer: Unused; part of the shared visualization signature.
            pl_module: Unused; part of the shared visualization signature.
            plot_dir: Output directory; joined with ``/`` to build the PNG path.
            epoch: Epoch number used in the figure title and the file name.
            z_full: Latent tensor whose leading dim is the sample count; flattened to
                ``(N, -1)`` before projection.
            y_targets: Per-sample labels used as scatter colours (presumably a 1-D tensor of
                integer class ids of length ``N`` -- TODO confirm against caller).
            **kwargs: Ignored.

        Raises:
            ValueError: If ``eval_cfg.cluster_method`` is neither ``'pca'`` nor ``'tsne'``.
        """
        method = self.eval_cfg.cluster_method.lower()
        # detach() so tensors that still track gradients can be converted to NumPy.
        z = z_full.detach().cpu().numpy().reshape(z_full.shape[0], -1)
        labels = y_targets.detach().cpu().numpy()

        if method == 'pca':
            reducer = PCA(n_components=2, random_state=42)
        elif method == 'tsne':
            perplexity = self.eval_cfg.tsne_perplexity
            # scikit-learn's TSNE rejects perplexity >= n_samples; clamp it so a small
            # evaluation batch does not abort the run.
            if z.shape[0] > 1:
                perplexity = min(perplexity, z.shape[0] - 1)
            reducer = TSNE(n_components=2, perplexity=perplexity, random_state=42)
        else:
            raise ValueError(f"Unsupported cluster method: {method}")

        z_reduced = reducer.fit_transform(z)

        fig, ax = plt.subplots(1, 1, figsize=(8, 6), constrained_layout=True)
        try:
            fig.suptitle(f'{method.upper()} Latent Space Analysis (Epoch {epoch})')

            scatter = ax.scatter(z_reduced[:, 0], z_reduced[:, 1], c=labels, cmap='tab10', alpha=0.6, s=10)
            ax.set_title(r'Full Latent Space $z$ (Expect No Clusters by Class)')
            ax.set_xlabel(f"{method.upper()} Component 1")
            ax.set_ylabel(f"{method.upper()} Component 2")

            # num=None returns exactly one handle per unique colour value, in sorted order,
            # so each handle can be paired with the label value it really represents.
            # An integer ``num`` goes through a tick locator and may yield a different
            # number of handles at interpolated values, mislabeling the legend.
            handles, _ = scatter.legend_elements(num=None)
            present_labels = [str(value) for value in np.unique(labels).tolist()]
            ax.legend(handles=handles, labels=present_labels, title=r"Class Label $y$")
            ax.grid(True)

            save_path = plot_dir / f'baseline_cluster_analysis_{method}_epoch_{epoch:03d}.png'
            fig.savefig(save_path, dpi=DPI)
        finally:
            # Always release the figure, even if plotting or saving fails.
            plt.close(fig)


@VISUALIZATIONS.register("FaderConditionalGeneration")
class FaderConditionalGenerationVisualization(BaseVisualization):
    """Grid of generator outputs: each source image decoded under every class label.

    Row ``i`` shows source image ``i`` followed by ``generator(encoder(image_i), j)`` for
    each ``j`` in ``0 .. num_classes - 1``.
    """

    @torch.no_grad()
    @_visualization_skipper
    def run(self, trainer, pl_module, plot_dir, epoch, data, **kwargs):
        """Render the conditional-generation grid and save it as a PNG under ``plot_dir``.

        The encoder and generator are switched to eval mode for rendering. Their previous
        train/eval state is restored afterwards, even if rendering raises.

        Args:
            trainer: Unused; part of the shared visualization signature.
            pl_module: Module exposing ``encoder``, ``generator`` and ``device``.
            plot_dir: Output directory; joined with ``/`` to build the PNG path.
            epoch: Epoch number used in the figure title and the file name.
            data: Batch of source images. At most ``eval_cfg.num_visualization_samples``
                (default 5) leading rows are used.
            **kwargs: Ignored.
        """
        num_classes = self.data_cfg.num_classes
        num_source_images = min(self.eval_cfg.get('num_visualization_samples', 5), data.size(0))
        # Nothing to draw: empty batch or a non-positive sample count in the config.
        if num_source_images <= 0:
            return

        encoder, generator = pl_module.encoder, pl_module.generator
        # Remember the caller's modes so this plot never leaves the networks in eval mode
        # (e.g. dropout / batch-norm disabled) for whatever runs next.
        encoder_was_training = encoder.training
        generator_was_training = generator.training
        encoder.eval()
        generator.eval()
        try:
            self._render_grid(pl_module, plot_dir, epoch, data, num_source_images, num_classes)
        finally:
            encoder.train(encoder_was_training)
            generator.train(generator_was_training)

    def _render_grid(self, pl_module, plot_dir, epoch, data, num_source_images, num_classes):
        """Encode the source images, decode each under every class label, and save the figure."""
        device = pl_module.device
        source_images = data[:num_source_images].to(device)
        z = pl_module.encoder(source_images)

        # squeeze=False keeps ``axes`` 2-D for every grid size, so ``axes[i]`` is always a row.
        fig, axes = plt.subplots(
            nrows=num_source_images,
            ncols=num_classes + 1,
            figsize=((num_classes + 1) * 2, num_source_images * 2.2),
            constrained_layout=True,
            squeeze=False,
        )
        try:
            fig.suptitle(f'Fader Network Conditional Generation (Epoch {epoch})', fontsize=16)

            for i, ax_row in enumerate(axes):
                _plot_image(ax_row[0], source_images[i], "Original")
                ax_row[0].set_ylabel(f"Source {i+1}", rotation=0, size='large', labelpad=40)

                z_i = z[i:i + 1]  # keep a batch dimension of 1 for the generator
                for j in range(num_classes):
                    target_label = torch.tensor([j], device=device)
                    x_rec = pl_module.generator(z_i, target_label)
                    # Drop the singleton batch dim so the reconstruction is passed in the same
                    # per-image form as ``source_images[i]`` above.
                    _plot_image(ax_row[j + 1], x_rec[0], f"Target Attr: {j}")

            # Emphasise the column headers on the first row.
            for j, ax in enumerate(axes[0, 1:]):
                ax.set_title(f"Target Attr: {j}", fontsize=12)

            save_path = plot_dir / f'conditional_generation_epoch_{epoch:03d}.png'
            fig.savefig(save_path, dpi=DPI)
        finally:
            # Always release the figure, even if generation or saving fails.
            plt.close(fig)