import os

import matplotlib.pyplot as plt
import numpy as np
import torch
from sklearn.manifold import TSNE


def _to_numpy(tensor):
    """Return *tensor* as a NumPy array, detached from autograd and moved to CPU."""
    # ``Tensor.numpy()`` raises for CUDA tensors and for tensors that require grad.
    return tensor.detach().cpu().numpy()


def visualize_shift(
    file_g1_before="group1_before.pt",
    file_g1_after="group1_after.pt",
    file_g2_before="group2_before.pt",
    file_g2_after="group2_after.pt",
    n_samples_to_plot=100,
    random_seed=42,
    output_path="shift_tsne.svg",
):
    """
    Load four datasets (two groups, before/after), embed them jointly with
    t-SNE, and save a scatter plot showing how each sample moved.

    All four files must contain 2-D tensors of identical shape whose first
    axis indexes samples. Row ``i`` of a group's "before" tensor is paired with
    row ``i`` of the same group's "after" tensor when the shift lines are
    drawn. The same random subset of row indices is used for all four datasets.
    All points are embedded in a single t-SNE fit, so the four groups share one
    coordinate system.

    Args:
        file_g1_before (str): Path for group 1's initial data.
        file_g1_after (str): Path for group 1's shifted data.
        file_g2_before (str): Path for group 2's initial data.
        file_g2_after (str): Path for group 2's shifted data.
        n_samples_to_plot (int): Number of samples to draw from each dataset.
            Values larger than the available sample count are clamped to it.
        random_seed (int): Seed for the sample selection and for t-SNE.
        output_path (str): Destination of the SVG figure.

    Raises:
        ValueError: If ``n_samples_to_plot`` is not positive, if the four
            datasets differ in shape, or if they are not 2-D.
    """
    if n_samples_to_plot <= 0:
        raise ValueError(
            f"n_samples_to_plot must be positive, got {n_samples_to_plot}."
        )

    # A local legacy RandomState yields the same draws as seeding the global
    # generator, but leaves global NumPy RNG state untouched.
    rng = np.random.RandomState(random_seed)

    # --- 1. Load Data ---
    print("\nLoading data...")
    # The data files must already exist; this function does not create them.
    # map_location="cpu" lets tensors saved on a GPU load on CPU-only machines.
    paths = (file_g1_before, file_g1_after, file_g2_before, file_g2_after)
    datasets = [torch.load(path, map_location="cpu") for path in paths]

    reference_shape = datasets[0].shape
    if any(data.shape != reference_shape for data in datasets[1:]):
        raise ValueError("All four datasets must have the same shape.")
    # t-SNE needs (n_samples, n_features) input; other ranks would either fail
    # deep inside sklearn or silently mis-pair rows when splitting below.
    if len(reference_shape) != 2:
        raise ValueError(
            "Datasets must be 2-D (n_samples, n_features); "
            f"got shape {tuple(reference_shape)}."
        )

    n_total_samples = reference_shape[0]
    if n_samples_to_plot > n_total_samples:
        print(
            f"Warning: n_samples_to_plot ({n_samples_to_plot}) is > total samples ({n_total_samples}). Using all samples."
        )
        n_samples_to_plot = n_total_samples

    # --- 2. Sample Data ---
    print(f"Sampling {n_samples_to_plot} points for visualization...")
    sample_indices = rng.choice(n_total_samples, n_samples_to_plot, replace=False)
    index_tensor = torch.as_tensor(sample_indices, dtype=torch.long)

    # Stack all four subsets so a single t-SNE fit embeds them consistently.
    combined_samples = np.vstack(
        [_to_numpy(data[index_tensor]) for data in datasets]
    )

    # --- 3. Apply Dimensionality Reduction ---
    print("Performing t-SNE... (this may take a moment)")
    # sklearn requires perplexity < number of points (4 * n_samples_to_plot).
    perplexity = min(30, 4 * n_samples_to_plot - 1)
    tsne = TSNE(
        n_components=2,
        random_state=random_seed,
        perplexity=perplexity,
    )
    transformed_tsne = tsne.fit_transform(combined_samples)

    # Rows come back in the same order they were stacked: 4 equal parts.
    tsne_g1_before, tsne_g1_after, tsne_g2_before, tsne_g2_after = np.split(
        transformed_tsne, 4
    )

    # --- 4. Visualize Results ---
    print("Generating plot...")
    # Wide, short figure sized for a single-column layout.
    fig, ax = plt.subplots(1, 1, figsize=(8, 4))

    scatter_specs = (
        (tsne_g1_before, "dodgerblue", "text-embed.-3-large | raw"),
        (tsne_g1_after, "tomato", "text-embed.-3-large | trained"),
        (tsne_g2_before, "mediumseagreen", "text-embed.-v4 | raw"),
        (tsne_g2_after, "orange", "text-embed.-v4 | trained"),
    )
    for points, color, label in scatter_specs:
        ax.scatter(
            points[:, 0],
            points[:, 1],
            c=color,
            label=label,
            alpha=0.7,
            edgecolors="w",
        )

    # Connect each sample's before/after positions (faint lines per group).
    for g1_start, g1_end, g2_start, g2_end in zip(
        tsne_g1_before, tsne_g1_after, tsne_g2_before, tsne_g2_after
    ):
        ax.plot(
            [g1_start[0], g1_end[0]],
            [g1_start[1], g1_end[1]],
            "b-",
            alpha=0.15,
        )
        ax.plot(
            [g2_start[0], g2_end[0]],
            [g2_start[1], g2_end[1]],
            "g-",
            alpha=0.15,
        )

    ax.set_xlabel("t-SNE Dimension 1", fontsize=12)
    ax.set_ylabel("t-SNE Dimension 2", fontsize=12)
    ax.tick_params(axis="x", labelsize=10)
    ax.tick_params(axis="y", labelsize=10)
    ax.legend(fontsize=11)
    ax.grid(True, linestyle="--", alpha=0.6)

    fig.tight_layout()
    fig.savefig(output_path, format="svg")
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)


if __name__ == "__main__":
    # Script entry point: plot the shift between the raw and trained
    # embeddings of both models, using the default sampling settings.
    embedding_files = {
        "file_g1_before": "initial_embeddings.pt",
        "file_g1_after": "final_embeddings.pt",
        "file_g2_before": "initial_embeddings_qwen.pt",
        "file_g2_after": "final_embeddings_qwen.pt",
    }
    visualize_shift(**embedding_files)
