# file: prism/evaluation/visualization.py
from abc import ABC, abstractmethod
import functools
import warnings
warnings.filterwarnings("ignore", category=UserWarning)

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from matplotlib.patches import Circle
import numpy as np
import pandas as pd
from pytorch_lightning.utilities.rank_zero import rank_zero_only
from sklearn.decomposition import PCA
from sklearn.linear_model import LinearRegression
from sklearn.manifold import TSNE
from sklearn.model_selection import train_test_split
import torch

from prism.core.registry import VISUALIZATIONS

# Resolution (dots per inch) passed to ``fig.savefig`` for every figure this module writes.
DPI = 500


def _visualization_skipper(func):
    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        try:
            return func(self, *args, **kwargs)
        except Exception as e:
            epoch = kwargs.get('epoch', 'N/A')
            print(f"Skipped '{self.__class__.__name__}' (epoch {epoch}): {type(e).__name__} - {e}")
            return None

    return wrapper


def _plot_image(ax, img_tensor, title):
    img = img_tensor.detach().cpu().clone()
    if img.dim() == 4 and img.shape[0] == 1:
        img = img.squeeze(0)

    img = (img + 1) / 2.0
    img = torch.clamp(img, 0, 1)

    if img.shape[0] == 3:
        img = img.permute(1, 2, 0)
        ax.imshow(img)
    else:
        ax.imshow(img.squeeze(), cmap='gray')

    ax.set_title(title)
    ax.axis('off')


def _calculate_partial_correlation(train_df, test_df, x, y, covar):
    """Return the held-out partial correlation of columns *x* and *y* given *covar*.

    Two linear regressions (column ``y`` on ``covar`` and column ``x`` on
    ``covar``) are fitted on ``train_df``. Their residuals are evaluated on
    ``test_df`` and the Pearson correlation of those residuals is returned.
    """
    controls_fit = train_df[covar].values
    controls_eval = test_df[covar].values

    def _held_out_residual(column):
        # Fit on the training split, measure what the controls fail to explain
        # on the test split.
        model = LinearRegression().fit(controls_fit, train_df[column].values)
        return test_df[column].values - model.predict(controls_eval)

    residual_y = _held_out_residual(y)
    residual_x = _held_out_residual(x)

    corr_matrix = np.corrcoef(residual_x.flatten(), residual_y.flatten())
    return corr_matrix[0, 1]


class BaseVisualization(ABC):
    """Common base for the visualizations registered in ``VISUALIZATIONS``.

    Subclasses implement :meth:`run`. Every concrete ``run`` is wrapped with
    ``rank_zero_only`` when the subclass is created, so it executes only on
    global rank 0.
    """

    def __init__(self, config):
        # Full config plus shortcuts to the sections the subclasses read.
        self.config = config
        self.eval_cfg = config.evaluation
        self.model_cfg = config.model
        self.data_cfg = config.data

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # Overriding ``run`` replaces the base method object, so the
        # ``@rank_zero_only`` on the abstract ``run`` below never reaches the
        # subclass implementation. Re-apply it to each concrete override.
        run = cls.__dict__.get('run')
        if (
            callable(run)
            and not isinstance(run, (staticmethod, classmethod))
            and not getattr(run, '__isabstractmethod__', False)
        ):
            cls.run = rank_zero_only(run)

    @abstractmethod
    @rank_zero_only
    def run(self, trainer, pl_module, plot_dir, epoch, **kwargs):
        """Render this visualization into ``plot_dir`` for the given ``epoch``."""
        raise NotImplementedError


@VISUALIZATIONS.register("intervention")
class InterventionVisualization(BaseVisualization):
    """Swap the residual (non-target) latent slice between images of different classes.

    For each pair ``(x_A, x_B)`` the figure shows both originals, their plain
    reconstructions, and the reconstructions obtained after exchanging the
    non-target slice of the two latent codes.
    """

    @staticmethod
    def _find_contrasting_pairs(labels, num_items, max_pairs):
        """Greedily pair indices whose labels differ; each index is used at most once.

        Scans ``i`` in increasing order and pairs it with the first unused
        ``j > i`` that has a different label, stopping after ``max_pairs`` pairs.
        """
        pairs = []
        used = set()
        for i in range(num_items):
            if len(pairs) >= max_pairs:
                break
            if i in used:
                continue
            for j in range(i + 1, num_items):
                if j not in used and labels[j] != labels[i]:
                    pairs.append((i, j))
                    used.update((i, j))
                    break
        return pairs

    @torch.no_grad()
    @_visualization_skipper
    def run(self, trainer, pl_module, plot_dir, epoch, data, y_targets, **kwargs):
        num_plots_to_generate = self.eval_cfg.get('num_visualization_samples', 1)
        if num_plots_to_generate <= 0:
            return

        num_items = len(data)
        # Extract the labels once instead of calling .item() inside the O(n^2) scan.
        labels = [y.item() for y in y_targets[:num_items]]
        image_pairs = self._find_contrasting_pairs(labels, num_items, num_plots_to_generate)

        if len(image_pairs) < num_plots_to_generate:
            print(f"  - [WARNING] InterventionVisualization: Could only find {len(image_pairs)} unique pairs out of {num_plots_to_generate} requested.")

        if not image_pairs:
            print("  - [ERROR] InterventionVisualization: Failed to find any suitable image pairs.")
            return

        latent_cfg = self.model_cfg.latent_space
        nontarget_slice = slice(latent_cfg.nontarget_slice_start, latent_cfg.nontarget_slice_stop)

        encoder, generator = pl_module.encoder, pl_module.generator
        device = pl_module.device
        previous_modes = (encoder.training, generator.training)
        encoder.eval()
        generator.eval()
        try:
            for n, (idx_a, idx_b) in enumerate(image_pairs):
                img_a = data[idx_a:idx_a + 1].to(device)
                img_b = data[idx_b:idx_b + 1].to(device)
                lbl_a, lbl_b = int(labels[idx_a]), int(labels[idx_b])

                z_a, z_b = encoder(torch.cat([img_a, img_b])).chunk(2)

                # Keep z_1 from one image and take z_0 from the other. Writing into
                # a clone (instead of concatenating [target, nontarget]) is correct
                # wherever the two slices sit inside the latent vector.
                z_swap_a = z_a.clone()
                z_swap_a[:, nontarget_slice] = z_b[:, nontarget_slice]
                z_swap_b = z_b.clone()
                z_swap_b[:, nontarget_slice] = z_a[:, nontarget_slice]

                recon_a, recon_b, recon_swap_a, recon_swap_b = generator(
                    torch.cat([z_a, z_b, z_swap_a, z_swap_b])
                ).chunk(4)

                fig, axes = plt.subplots(2, 3, figsize=(9, 6), constrained_layout=True)
                try:
                    fig.suptitle(f'Disentanglement via Subspace Swapping (Epoch {epoch})')
                    _plot_image(axes[0, 0], img_a, f'Original $x_A$ ($y_1={lbl_a}$)')
                    _plot_image(axes[0, 1], recon_a, r'Recon. $\hat{x}_A = G(z_{1,A}, z_{0,A})$')
                    _plot_image(axes[0, 2], recon_swap_a, r'Swapped $\hat{x}_{A \to B} = G(z_{1,A}, z_{0,B})$')
                    _plot_image(axes[1, 0], img_b, f'Original $x_B$ ($y_1={lbl_b}$)')
                    _plot_image(axes[1, 1], recon_b, r'Recon. $\hat{x}_B = G(z_{1,B}, z_{0,B})$')
                    _plot_image(axes[1, 2], recon_swap_b, r'Swapped $\hat{x}_{B \to A} = G(z_{1,B}, z_{0,A})$')

                    save_path = plot_dir / f'disentanglement_check_epoch_{epoch:03d}_{n}.png'
                    fig.savefig(save_path, dpi=DPI)
                finally:
                    # Always release the figure, even if drawing/saving fails.
                    plt.close(fig)
        finally:
            # Leave the modules in the train/eval mode they had on entry.
            encoder.train(previous_modes[0])
            generator.train(previous_modes[1])


@VISUALIZATIONS.register("subspace_replacement")
class SubspaceReplacementVisualization(BaseVisualization):
    """Show what each latent subspace alone contributes to the reconstruction.

    For batches of up to 8 images the rows are: originals, full
    reconstructions, reconstructions with ``z_0`` replaced by its mean over
    ``z_full``, and reconstructions with ``z_1`` replaced by its mean.
    """

    @torch.no_grad()
    @_visualization_skipper
    def run(self, trainer, pl_module, plot_dir, epoch, z_full, data, y_targets, **kwargs):
        num_items = data.size(0)
        num_samples_per_viz = min(8, num_items)
        num_batches = self.eval_cfg.get('num_visualization_samples', 1)

        latent_cfg = self.model_cfg.latent_space
        target_slice = slice(latent_cfg.target_slice_start, latent_cfg.target_slice_stop)
        nontarget_slice = slice(latent_cfg.nontarget_slice_start, latent_cfg.nontarget_slice_stop)

        encoder, generator = pl_module.encoder, pl_module.generator
        device = pl_module.device
        previous_modes = (encoder.training, generator.training)
        encoder.eval()
        generator.eval()
        try:
            # Subspace means over all of z_full do not depend on the batch: compute once.
            mean_z1 = z_full[:, target_slice].mean(dim=0).to(device)
            mean_z0 = z_full[:, nontarget_slice].mean(dim=0).to(device)

            for n in range(num_batches):
                start_idx = n * num_samples_per_viz
                if start_idx >= num_items:
                    break
                end_idx = start_idx + num_samples_per_viz

                images = data[start_idx:end_idx].to(device)
                labels = y_targets[start_idx:end_idx]

                z = encoder(images)
                z_target_only = z.clone()
                z_target_only[:, nontarget_slice] = mean_z0
                z_nontarget_only = z.clone()
                z_nontarget_only[:, target_slice] = mean_z1

                recons_orig = generator(z)
                recons_target = generator(z_target_only)
                recons_nontarget = generator(z_nontarget_only)

                num_vis = images.shape[0]
                # squeeze=False keeps ``axes`` 2-D even when only one image is left,
                # otherwise ``axes[row, i]`` would fail on a 1-D array.
                fig, axes = plt.subplots(4, num_vis, figsize=(num_vis * 1.5, 6),
                                         squeeze=False, constrained_layout=True)
                try:
                    fig.suptitle(f'Subspace Mean-Value Replacement (Epoch {epoch})')

                    for i in range(num_vis):
                        _plot_image(axes[0, i], images[i], f"$x_i$ ($y_1={labels[i].item()}$)")
                        _plot_image(axes[1, i], recons_orig[i], r"$\hat{x}_i = G(z_{1,i}, z_{0,i})$")
                        _plot_image(axes[2, i], recons_target[i], r"$\hat{x}_i = G(z_{1,i}, \mathbb{E}[z_0])$")
                        _plot_image(axes[3, i], recons_nontarget[i], r"$\hat{x}_i = G(\mathbb{E}[z_1], z_{0,i})$")

                    save_path = plot_dir / f'subspace_replacement_check_epoch_{epoch:03d}_{n}.png'
                    fig.savefig(save_path, dpi=DPI)
                finally:
                    plt.close(fig)
        finally:
            # Leave the modules in the train/eval mode they had on entry.
            encoder.train(previous_modes[0])
            generator.train(previous_modes[1])


@VISUALIZATIONS.register("latent_traversal")
class LatentTraversalVisualization(BaseVisualization):
    """Traverse the residual subspace ``z_0`` along its principal components.

    PCA is fitted on the flattened non-target slice of ``z_full``. For each
    starting image, every component is swept over ``±traversal_std`` standard
    deviations of the data projected onto it, and decoded with the generator.
    """

    @torch.no_grad()
    @_visualization_skipper
    def run(self, trainer, pl_module, plot_dir, epoch, z_full, data, **kwargs):
        num_batches = min(self.eval_cfg.get('num_visualization_samples', 1), data.size(0))
        if num_batches <= 0:
            return

        latent_cfg = self.model_cfg.latent_space
        nontarget_slice = slice(latent_cfg.nontarget_slice_start, latent_cfg.nontarget_slice_stop)

        # The PCA basis and per-component spread depend only on z_full, so they
        # are computed once instead of per starting image / per component.
        z_nontarget = z_full[:, nontarget_slice].reshape(z_full.shape[0], -1).detach().cpu().numpy()
        pca = PCA(n_components=self.eval_cfg.pca_components, random_state=42).fit(z_nontarget)
        n_components = pca.n_components_
        proj_std = pca.transform(z_nontarget).std(axis=0)
        num_steps, n_std = self.eval_cfg.traversal_steps, self.eval_cfg.traversal_std

        encoder, generator = pl_module.encoder, pl_module.generator
        device = pl_module.device
        previous_modes = (encoder.training, generator.training)
        encoder.eval()
        generator.eval()
        try:
            for n in range(num_batches):
                z_start = encoder(data[n:n + 1].to(device))
                nontarget_shape = z_start[:, nontarget_slice].shape[1:]
                # One traversal step per leading index, broadcast over all remaining
                # latent dims (a plain view(-1, 1) only broadcasts for 2-D latents).
                step_shape = (num_steps,) + (1,) * (z_start.dim() - 1)

                # squeeze=False keeps ``axes`` 2-D for one component or one step.
                fig, axes = plt.subplots(n_components, num_steps,
                                         figsize=(num_steps * 1.5, n_components * 2),
                                         squeeze=False, constrained_layout=True)
                try:
                    fig.suptitle(f'Latent Traversal of PCA on Residual Subspace $z_0$ (Epoch {epoch})')

                    for i, ax_row in enumerate(axes):
                        pc_dir = torch.zeros_like(z_start)
                        pc_dir[:, nontarget_slice] = torch.from_numpy(pca.components_[i]).to(device).view(nontarget_shape)

                        spread = float(n_std * proj_std[i])
                        range_vals = torch.linspace(-spread, spread, num_steps, device=device)
                        z_traversed = z_start + range_vals.view(step_shape) * pc_dir
                        recons = generator(z_traversed).detach().cpu()

                        for ax, recon in zip(ax_row, recons):
                            _plot_image(ax, recon, "")
                        ax_row[num_steps // 2].set_title(f'PC {i + 1} (Explains {pca.explained_variance_ratio_[i]:.1%})')

                    save_path = plot_dir / f'pca_traversal_epoch_{epoch:03d}_{n}.png'
                    fig.savefig(save_path, dpi=DPI)
                finally:
                    plt.close(fig)
        finally:
            # Leave the modules in the train/eval mode they had on entry.
            encoder.train(previous_modes[0])
            generator.train(previous_modes[1])


@VISUALIZATIONS.register("clustering")
class ClusteringVisualization(BaseVisualization):
    """Scatter 2-D embeddings (PCA or t-SNE) of ``z_1`` and ``z_0``, coloured by target label."""

    @torch.no_grad()
    @_visualization_skipper
    def run(self, trainer, pl_module, plot_dir, epoch, z_full, y_targets, **kwargs):
        method = self.eval_cfg.cluster_method.lower()

        if method == 'pca':
            reducer_z1, reducer_z0 = PCA(n_components=2, random_state=42), PCA(n_components=2, random_state=42)
        elif method == 'tsne':
            tsne_params = {'n_components': 2, 'perplexity': self.eval_cfg.tsne_perplexity, 'random_state': 42}
            reducer_z1, reducer_z0 = TSNE(**tsne_params), TSNE(**tsne_params)
        else:
            raise ValueError(f"Unsupported cluster method: {method}")

        latent_cfg = self.model_cfg.latent_space
        target_slice = slice(latent_cfg.target_slice_start, latent_cfg.target_slice_stop)
        nontarget_slice = slice(latent_cfg.nontarget_slice_start, latent_cfg.nontarget_slice_stop)

        num_samples = z_full.shape[0]
        # detach() so tensors that still require grad can be converted to numpy.
        z1 = z_full[:, target_slice].detach().cpu().numpy().reshape(num_samples, -1)
        z0 = z_full[:, nontarget_slice].detach().cpu().numpy().reshape(num_samples, -1)
        labels = y_targets.detach().cpu().numpy()

        z1_reduced = reducer_z1.fit_transform(z1)
        z0_reduced = reducer_z0.fit_transform(z0)

        axis_name = method.upper()
        panels = (
            (z1_reduced, r'Task-Relevant Subspace $z_1$ (Expect Clusters by Class)'),
            (z0_reduced, r'Residual Subspace $z_0$ (Expect No Clusters by Class)'),
        )

        fig, axes = plt.subplots(1, 2, figsize=(14, 6), constrained_layout=True)
        try:
            fig.suptitle(f'{axis_name} Latent Space Cluster Analysis (Epoch {epoch})')

            scatters = []
            for ax, (points, title) in zip(axes, panels):
                scatters.append(ax.scatter(points[:, 0], points[:, 1], c=labels, cmap='tab10', alpha=0.6, s=10))
                ax.set_title(title)
                ax.set_xlabel(f"{axis_name} Component 1")
                ax.set_ylabel(f"{axis_name} Component 2")
                ax.grid(True)

            # num=None yields exactly one handle per label value present, with the
            # colour the scatter used for it, plus a matching text label. (An int
            # ``num`` picks locator levels that need not coincide with the labels.)
            handles, handle_labels = scatters[0].legend_elements(num=None)
            axes[0].legend(handles=handles, labels=handle_labels, title=r"Class Label $y_1$")

            save_path = plot_dir / f'cluster_analysis_{method}_epoch_{epoch:03d}.png'
            fig.savefig(save_path, dpi=DPI)
        finally:
            plt.close(fig)


@VISUALIZATIONS.register("correlation")
class CorrelationVisualization(BaseVisualization):
    """Bubble chart of partial correlations between latent dims and style factors.

    Each latent dimension's correlation with each style factor is estimated
    while controlling for all other latent dimensions and one-hot target
    labels (regressions fit on an 80% split, residuals evaluated on the rest).
    Bubble radius grows with ``sqrt(|pcorr|)``; blue means positive, red
    negative.
    """

    @torch.no_grad()
    @_visualization_skipper
    def run(self, trainer, pl_module, plot_dir, epoch, z_full, y_targets, y_style, **kwargs):
        # Tolerate trainers/datamodules that do not expose a style map.
        style_feature_map = getattr(getattr(trainer, 'datamodule', None), 'style_feature_map', None)
        if y_style is None or style_feature_map is None:
            print(f"Skipping correlation analysis for epoch {epoch}: No style labels provided.")
            return

        num_samples = z_full.shape[0]
        # detach()/cpu() like the other visualizations, so GPU or grad tensors work.
        z_flat = z_full.detach().cpu().reshape(num_samples, -1).numpy()

        if z_flat.shape[1] > 64:
            print(f"Skipping correlation analysis for epoch {epoch}: Latent dimension ({z_flat.shape[1]}) exceeds threshold of 64.")
            return

        df = pd.DataFrame(z_flat, columns=[f'z_{i}' for i in range(z_flat.shape[1])])

        style_names = list(style_feature_map.keys())
        # Reshape so a single 1-D style vector also fills a one-column assignment.
        df[style_names] = y_style.detach().cpu().numpy().reshape(num_samples, -1)
        target_dummies = pd.get_dummies(y_targets.detach().cpu().numpy(), prefix='target_id').astype(int)
        df = pd.concat([df, target_dummies], axis=1)
        train_df, test_df = train_test_split(df, test_size=0.2, random_state=42)

        latent_cfg = self.model_cfg.latent_space
        target_slice = slice(latent_cfg.target_slice_start, latent_cfg.target_slice_stop)
        nontarget_slice = slice(latent_cfg.nontarget_slice_start, latent_cfg.nontarget_slice_stop)
        z_target_cols = [f'z_{i}' for i in range(target_slice.start, target_slice.stop)]
        z_nontarget_cols = [f'z_{i}' for i in range(nontarget_slice.start, nontarget_slice.stop)]
        all_z_cols = z_target_cols + z_nontarget_cols
        target_id_cols = list(target_dummies.columns)

        results = []
        for style_name in style_names:
            for z_name in all_z_cols:
                control_vars = [c for c in all_z_cols if c != z_name] + target_id_cols
                corr = _calculate_partial_correlation(train_df, test_df, z_name, style_name, control_vars)
                results.append({'style_factor': style_name, 'z': z_name, 'pcorr': corr})

        pcorr_df = pd.DataFrame(results)
        pcorr_matrix = pcorr_df.pivot(index='style_factor', columns='z', values='pcorr').reindex(index=style_names, columns=all_z_cols)

        fig, ax = plt.subplots(figsize=(12, max(5, len(style_names) * 1.2)), constrained_layout=True)
        try:
            for r_idx, style_name in enumerate(style_names):
                for c_idx, z_name in enumerate(all_z_cols):
                    pcorr_val = pcorr_matrix.loc[style_name, z_name]
                    if not np.isfinite(pcorr_val):
                        # Undefined correlation (e.g. zero-variance residuals): draw nothing
                        # rather than a circle with a NaN radius.
                        continue
                    color = 'darkblue' if pcorr_val > 0 else 'darkred'
                    radius = np.sqrt(abs(pcorr_val)) * 0.45
                    ax.add_patch(Circle((c_idx, r_idx), radius, color=color, alpha=0.7))

            ax.set_aspect('equal', adjustable='box')
            ax.set_xlim(-0.5, len(all_z_cols) - 0.5)
            ax.set_ylim(-0.5, len(style_names) - 0.5)
            ax.set_yticks(range(len(style_names)))
            ax.set_yticklabels(style_names, fontsize=12)
            ax.invert_yaxis()

            target_labels = [f'$z_1^{{({i})}}$' for i in range(len(z_target_cols))]
            nontarget_labels = [f'$z_0^{{({i})}}$' for i in range(len(z_nontarget_cols))]
            ax.set_xticks(range(len(all_z_cols)))
            ax.set_xticklabels(target_labels + nontarget_labels, rotation=90, fontsize=10)

            ax.set_title(f'Partial Correlation between Latent Dims and Style Factors (Epoch {epoch})', fontsize=14)
            # Vertical separator between the z_1 and z_0 columns.
            ax.axvline(x=len(z_target_cols) - 0.5, color='black', linestyle='-', linewidth=1.5)
            fig.text(0.31, 0.05, r'Task-Relevant Subspace $z_1$', ha='center', va='center', fontsize=12, backgroundcolor='white')
            fig.text(0.73, 0.05, r'Residual Subspace $z_0$', ha='center', va='center', fontsize=12, backgroundcolor='white')

            save_path = plot_dir / f'correlation_analysis_epoch_{epoch:03d}.png'
            fig.savefig(save_path, dpi=DPI)
        finally:
            plt.close(fig)