# python rsa.py --emb1 embeddings/embeddings_dolph2vec_dolphin_reef_balanced.npy --emb2 embeddings/embeddings_aves_bio_dolphin_reef_balanced.npy --emb3 embeddings/embeddings_biolingual_dolphin_reef_balanced.npy --labels embeddings/labels_dolphin_reef_balanced.npy --output rsa_plot --metric cosine --model_name1 Dolph2Vec --model_name2 Aves-bio --model_name3 BioLingual

import argparse

import numpy as np
from scipy.spatial.distance import pdist, squareform
from scipy.stats import spearmanr
from sklearn.preprocessing import StandardScaler, normalize
import matplotlib.pyplot as plt
import seaborn as sns
import json
from matplotlib import gridspec


class RSAAnalyzer:
    """Representational similarity analysis (RSA) over three embedding sets.

    For each embedding set a pairwise similarity matrix (``1 - distance``) is
    built over the samples; the upper triangles of those matrices are then
    compared with Spearman rank correlation (set 1 vs. set 2, set 1 vs. set 3).
    """

    def __init__(self, metric="cosine", normalize_method=None):
        """
        Args:
            metric: distance metric name forwarded to ``scipy``'s ``pdist``
                (e.g. 'cosine' or 'euclidean').
            normalize_method: 'zscore', 'l2', 'minmax', or None. Any other
                value (including the string 'none') leaves the data unchanged.
        """
        self.metric = metric
        self.normalize_method = normalize_method

    def _normalize(self, X):
        """Return ``X`` normalized according to ``self.normalize_method``.

        Unrecognized methods fall through and return ``X`` itself, untouched.
        """
        if self.normalize_method == "zscore":
            return StandardScaler().fit_transform(X)
        elif self.normalize_method == "l2":
            return normalize(X, norm="l2")
        elif self.normalize_method == "minmax":
            # Imported lazily: MinMaxScaler is not among the module-level imports.
            from sklearn.preprocessing import MinMaxScaler

            return MinMaxScaler().fit_transform(X)
        return X

    def _compute_rdm(self, X):
        """Return the square pairwise *similarity* matrix of the rows of ``X``.

        Despite the "rdm" name, the result is ``1 - distance`` rather than the
        dissimilarity itself. The diagonal is therefore 1 (zero self-distance).
        """
        dist_matrix = squareform(pdist(X, metric=self.metric))
        return 1 - dist_matrix

    def _rsa_correlation(self, rdm1, rdm2):
        """Spearman-correlate the strict upper triangles of two square matrices.

        Returns the ``spearmanr`` result, which unpacks as
        ``(correlation, p_value)``.
        """
        # k=1 skips the diagonal; the matrices are symmetric, so the upper
        # triangle contains every sample pair exactly once.
        idx = np.triu_indices_from(rdm1, k=1)
        return spearmanr(rdm1[idx], rdm2[idx])

    def analyze(
        self,
        embeddings1,
        embeddings2,
        embeddings3,
        labels,
        plot=True,
        title1="",
        title2="",
        title3="",
        output_path=None,
    ):
        """
        Perform RSA between the first embedding set and each of the other two.

        Args:
            embeddings1, embeddings2, embeddings3: np.ndarray of shape
                (n_samples, dim); the dimensionality may differ between sets,
                the number of samples may not.
            labels: array of shape (n_samples,), used to group samples so that
                same-label samples are adjacent in the matrices.
            plot: bool, whether to visualize the similarity matrices
            title1, title2, title3: panel titles for the three heatmaps
            output_path: if truthy, the figure is saved to ``<output_path>.png``

        Returns:
            results: dict with keys 'rsa_12' and 'rsa_13', each holding the
                Spearman 'correlation' and its 'p_value' as Python floats.

        Raises:
            ValueError: if any embedding set does not have exactly one row per
                label.
        """
        labels = np.asarray(labels)
        n_samples = len(labels)

        # Without this check a longer embedding array would be silently
        # truncated by the label-based reordering below.
        named_inputs = (
            ("embeddings1", embeddings1),
            ("embeddings2", embeddings2),
            ("embeddings3", embeddings3),
        )
        for name, emb in named_inputs:
            if len(emb) != n_samples:
                raise ValueError(
                    f"{name} has {len(emb)} rows but {n_samples} labels were given"
                )

        # Normalize
        X1 = self._normalize(embeddings1)
        X2 = self._normalize(embeddings2)
        X3 = self._normalize(embeddings3)

        # Reorder by label; a stable sort keeps the original sample order
        # within each label group, making the layout deterministic.
        sort_idx = np.argsort(labels, kind="stable")
        X1 = X1[sort_idx]
        X2 = X2[sort_idx]
        X3 = X3[sort_idx]
        labels_sorted = labels[sort_idx]
        unique_labels = np.unique(labels_sorted)

        # Full RSA
        rdm1_full = self._compute_rdm(X1)
        rdm2_full = self._compute_rdm(X2)
        rdm3_full = self._compute_rdm(X3)

        rsa_12, p_12 = self._rsa_correlation(rdm1_full, rdm2_full)
        rsa_13, p_13 = self._rsa_correlation(rdm1_full, rdm3_full)

        # Plain floats keep the results trivially JSON-serializable.
        results = {
            "rsa_12": {
                "correlation": float(rsa_12),
                "p_value": float(p_12),
            },
            "rsa_13": {
                "correlation": float(rsa_13),
                "p_value": float(p_13),
            },
        }

        if plot:
            self._plot_rdms(
                rdm1_full,
                rdm2_full,
                rdm3_full,
                unique_labels,
                labels_sorted,
                rsa_12,
                rsa_13,
                title1,
                title2,
                title3,
                output_path,
            )

        return results

    def export(self, results, filepath):
        """Write ``results`` to ``filepath`` as indented JSON."""
        with open(filepath, "w") as f:
            json.dump(results, f, indent=4)

    def _draw_rdm(self, ax, rdm, title, label_ticks, label_names, vmin, vmax):
        """Draw one similarity heatmap with per-label tick annotations on ``ax``."""
        sns.heatmap(
            rdm,
            ax=ax,
            cmap="RdBu_r",
            xticklabels=False,
            yticklabels=False,
            vmin=vmin,
            vmax=vmax,
            cbar=False,
            square=True,
        )
        ax.set_title(title, fontsize=12)
        ax.xaxis.set_ticks_position('top')
        ax.xaxis.set_label_position('top')
        ax.set_xticks(label_ticks)
        ax.set_yticks(label_ticks)
        ax.set_xticklabels(label_names, rotation=90)
        ax.set_yticklabels(label_names, rotation=0)

    def _plot_rdms(
        self,
        rdm1_full,
        rdm2_full,
        rdm3_full,
        labels,
        labels_sorted,
        rsa12,
        rsa13,
        title1,
        title2,
        title3,
        output_path,
        vmin=-0.51,
        vmax=0.51,
    ):
        """Plot the three similarity matrices side by side with one colorbar.

        Args:
            rdm1_full, rdm2_full, rdm3_full: square similarity matrices whose
                rows/columns follow ``labels_sorted``.
            labels: the distinct labels, in the order their groups appear in
                ``labels_sorted``.
            labels_sorted: per-sample labels in matrix order.
            rsa12, rsa13: currently unused; accepted so existing callers keep
                working.
            title1, title2, title3: panel titles.
            output_path: if truthy, the figure is saved to
                ``<output_path>.png`` at 300 dpi.
            vmin, vmax: colour scale limits shared by all three heatmaps.
        """
        fig = plt.figure(figsize=(13, 5))  # Extra width for colorbar

        # Three equally wide heatmap panels plus a narrow colorbar column.
        gs = gridspec.GridSpec(1, 4, width_ratios=[1, 1, 1, 0.05])
        heatmap_axes = [fig.add_subplot(gs[i]) for i in range(3)]
        cbar_ax = fig.add_subplot(gs[3])

        # One tick per label, centred on that label's contiguous run of samples.
        counts = np.array(
            [np.count_nonzero(labels_sorted == lbl) for lbl in labels]
        )
        label_ticks = np.cumsum(counts) - counts / 2
        label_names = [str(lbl) for lbl in labels]

        panels = zip(
            heatmap_axes,
            (rdm1_full, rdm2_full, rdm3_full),
            (title1, title2, title3),
        )
        for ax, rdm, title in panels:
            self._draw_rdm(ax, rdm, title, label_ticks, label_names, vmin, vmax)

        # All panels share vmin/vmax, so any heatmap's mesh can drive the
        # colorbar; the heatmap mesh is the axes' first collection.
        fig.colorbar(
            heatmap_axes[-1].collections[0], cax=cbar_ax, label="Similarity"
        )

        fig.suptitle("RSA Matrices", fontsize=14, weight='bold')
        fig.tight_layout()

        if output_path:
            fig.savefig(f"{output_path}.png", dpi=300)
        plt.show()
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)


def main():
    """Command-line entry point: run RSA over three embedding files.

    Loads three embedding arrays and a label array from ``.npy`` files, runs
    ``RSAAnalyzer.analyze`` on them (plotting unless ``--no-plot`` is given),
    and writes the resulting correlations as JSON to the ``--output`` path.
    The plot, when produced, is saved by the analyzer to ``<output>.png``.
    """
    parser = argparse.ArgumentParser(
        description="Perform RSA analysis between three embedding sets"
    )
    parser.add_argument(
        "--emb1", required=True, help="Path to first embeddings file in .npy format"
    )
    parser.add_argument(
        "--emb2", required=True, help="Path to second embeddings file in .npy format"
    )
    parser.add_argument(
        "--emb3", required=True, help="Path to third embeddings file in .npy format"
    )
    parser.add_argument(
        "--model_name1", required=True, help="Title of model1 for visualization"
    )
    parser.add_argument(
        "--model_name2", required=True, help="Title of model2 for visualization"
    )
    parser.add_argument(
        "--model_name3", required=True, help="Title of model3 for visualization"
    )
    parser.add_argument(
        "--labels", required=True, help="Path to labels file in .npy format"
    )
    parser.add_argument("--output", required=True, help="Path to output file name")
    parser.add_argument(
        "--metric",
        default="cosine",
        choices=["cosine", "euclidean"],
        help="Distance metric for RDM computation",
    )
    parser.add_argument(
        "--normalize_method",
        default="zscore",
        choices=["zscore", "minmax", "l2", "none"],
        help="Normalization method",
    )
    parser.add_argument(
        "--no-plot", action="store_true", help="Disable RDM visualization"
    )

    args = parser.parse_args()

    # Load data
    emb1 = np.load(args.emb1)
    emb2 = np.load(args.emb2)
    emb3 = np.load(args.emb3)
    labels = np.load(args.labels)

    # The CLI spells "no normalization" as the string "none"; hand the
    # analyzer an explicit None instead of relying on its fall-through.
    normalize_method = (
        None if args.normalize_method == "none" else args.normalize_method
    )

    # Create analyzer
    analyzer = RSAAnalyzer(
        metric=args.metric,
        normalize_method=normalize_method,
    )

    # Run analysis
    results = analyzer.analyze(
        emb1,
        emb2,
        emb3,
        labels,
        plot=not args.no_plot,
        title1=args.model_name1,
        title2=args.model_name2,
        title3=args.model_name3,
        output_path=args.output,
    )

    # Export results (JSON is written to the bare --output path, no suffix added)
    analyzer.export(results, args.output)
    print(f"Results saved to {args.output}")


# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
