{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "95c13c5d",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "import math\n",
    "from datetime import datetime\n",
    "\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import seaborn as sb\n",
    "from scipy import stats\n",
    "\n",
    "# Embed TrueType fonts in saved PDFs so text remains editable.\n",
    "matplotlib.rcParams['pdf.fonttype'] = 42\n",
    "\n",
    "print(\"Analysis of GPU RNN Model Results\")\n",
    "print(\"=\"*40)\n",
    "\n",
    "# Constants\n",
    "tr_len = 23  # width of the analysed cross block: rows [0, tr_len) vs cols [tr_len, 2*tr_len)\n",
    "loss_thresh = 0.04  # runs whose final loss is below this count as converged ('good')\n",
    "decorr_thresh = 0.6  # New threshold for decorrelation analysis\n",
    "\n",
    "data_dir = './0910'  # directory holding the saved .npy result arrays\n",
    "file_chosen = '2025-09-10-10-16'  # Update this to your actual filename\n",
    "print(f\"Loading data from files with timestamp: {file_chosen}\")\n",
    "\n",
    "corr_curve = np.load(f'{data_dir}/corr_curve_neuromamba_{file_chosen}.npy')\n",
    "accuracy_curve_all_test = np.load(f'{data_dir}/accuracy_curve_all_test_neuromamba_{file_chosen}.npy')\n",
    "loss_all = np.load(f'{data_dir}/loss_all_neuromamba_{file_chosen}.npy')\n",
    "\n",
    "# Sanity checks: all arrays must describe the same runs, and the matrices must be\n",
    "# large enough for the [0:tr_len, tr_len:2*tr_len] cross block used throughout.\n",
    "assert corr_curve.ndim == 4, f'expected a 4-D corr_curve, got shape {corr_curve.shape}'\n",
    "assert loss_all.shape[0] == accuracy_curve_all_test.shape[0] == corr_curve.shape[0], 'run counts differ between arrays'\n",
    "assert corr_curve.shape[2] >= tr_len and corr_curve.shape[3] >= 2 * tr_len, f'correlation matrices too small for tr_len={tr_len}'\n",
    "\n",
    "print(\"Data loaded successfully.\")\n",
    "print(f\"Total number of simulations: {corr_curve.shape[0]}\")\n",
    "print(f\"Number of epochs: {corr_curve.shape[1]}\")\n",
    "print(f\"Correlation matrix size: {corr_curve.shape[2]}x{corr_curve.shape[3]}\")\n",
    "print(\"=\"*40)\n",
    "\n",
    "# Basic Statistics\n",
    "print(\"Basic Statistics:\")\n",
    "print(f\"Final mean loss: {np.mean(loss_all[:,-1]):.4f}\")\n",
    "print(f\"Final mean accuracy: {np.mean(accuracy_curve_all_test[:,-1]):.4f}\")\n",
    "\n",
    "# Number of runs with good convergence\n",
    "good_run_mask = loss_all[:, -1] < loss_thresh\n",
    "good_runs = good_run_mask.sum()\n",
    "print(f\"Number of runs with good convergence (loss < {loss_thresh}): {good_runs}\")\n",
    "print(f\"Percentage of good runs: {good_runs/loss_all.shape[0]*100:.2f}%\")\n",
    "print(\"=\"*40)\n",
    "\n",
    "# Last time step mean correlation: cross block of each good run, averaged over both matrix axes\n",
    "corr_avg_last_session = np.mean(corr_curve[good_run_mask, -1, 0:tr_len, tr_len:2*tr_len], axis=(1, 2))\n",
    "if good_runs > 0:\n",
    "    print(\"Last time step mean correlation statistics:\")\n",
    "    print(f\"Mean: {np.mean(corr_avg_last_session):.4f}\")\n",
    "    print(f\"Median: {np.median(corr_avg_last_session):.4f}\")\n",
    "    print(f\"Min: {np.min(corr_avg_last_session):.4f}\")\n",
    "    print(f\"Max: {np.max(corr_avg_last_session):.4f}\")\n",
    "else:\n",
    "    # np.min / np.max raise on an empty array, so report instead of crashing.\n",
    "    print(f\"No runs reached loss < {loss_thresh}; skipping correlation statistics.\")\n",
    "print(\"=\"*40)\n",
    "\n",
    "print(\"Generating plots...\")\n",
    "\n",
    "# Mean last-epoch cross block over every run that met the loss threshold\n",
    "last_epoch_cross = corr_curve[:, -1, 0:tr_len, tr_len:2*tr_len]\n",
    "converged = loss_all[:, -1] < loss_thresh\n",
    "mean_cross_block = last_epoch_cross[converged].mean(axis=0)\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(10, 8))\n",
    "sb.heatmap(mean_cross_block, vmin=-1, vmax=1, cmap='icefire', ax=ax)\n",
    "# Dashed guides at the region boundaries\n",
    "for boundary in (6, 10, 13, 15, 18, 20):\n",
    "    ax.axvline(boundary, linestyle='--', color='gray')\n",
    "    ax.axhline(boundary, linestyle='--', color='gray')\n",
    "ax.set_title('NeuMa average correlation')\n",
    "output_filename = 'NeuMa_average_correlation.pdf'\n",
    "\n",
    "# fig.savefig(output_filename, format='pdf', bbox_inches='tight')\n",
    "plt.show()\n",
    "\n",
    "print(\"Mean correlation matrix plotted.\")\n",
    "print(\"This heatmap shows the average correlation between different time steps across all good runs.\")\n",
    "print(\"The diagonal structure suggests temporal dependencies in the model's representations.\")\n",
    "print(\"=\"*40)\n",
    "\n",
    "# Define regions for correlation analysis\n",
    "regions = [[0, 6], [10, 13], [15, 18], [20, 23]]\n",
    "other_regions = np.array([10, 13, 14, 15, 18, 19, 20, 23])\n",
    "\n",
    "# Mask 1 ('Off-diagonal'): every block pairing two different regions.\n",
    "correlation_matrix_1 = np.zeros((tr_len, tr_len))\n",
    "for i, region_i in enumerate(regions):\n",
    "    for j, region_j in enumerate(regions):\n",
    "        if i != j:\n",
    "            correlation_matrix_1[region_i[0]:region_i[1], region_j[0]:region_j[1]] = 1\n",
    "\n",
    "# Mask 2 ('Pre-R2'): diagonal entries for indices other_regions[3]:other_regions[5] (15:19).\n",
    "correlation_matrix_2 = np.zeros((tr_len, tr_len))\n",
    "correlation_matrix_2[other_regions[3]:other_regions[5], other_regions[3]:other_regions[5]] = np.eye(other_regions[5] - other_regions[3])\n",
    "\n",
    "# Mask 3 ('Pre-R1'): diagonal entries for indices other_regions[0]:other_regions[2] (10:14).\n",
    "correlation_matrix_3 = np.zeros((tr_len, tr_len))\n",
    "correlation_matrix_3[other_regions[0]:other_regions[2], other_regions[0]:other_regions[2]] = np.eye(other_regions[2] - other_regions[0])\n",
    "\n",
    "correlation_matrices = [correlation_matrix_1, correlation_matrix_2, correlation_matrix_3]\n",
    "correlation_names = ['Off-diagonal', 'Pre-R2', 'Pre-R1']\n",
    "\n",
    "# Analyze decorrelation\n",
    "good_run_indices = np.where(loss_all[:,-1] < loss_thresh)[0]\n",
    "\n",
    "# Cross block of every good run at every epoch: (n_good_runs, n_epochs, tr_len, tr_len)\n",
    "good_cross_blocks = corr_curve[good_run_indices, :, 0:tr_len, tr_len:2*tr_len]\n",
    "\n",
    "# all_masks_matrix[m, r, e] = mean of the cross-block entries selected by mask m\n",
    "# for good run r at epoch e. Boolean indexing on the last two axes gives the same\n",
    "# means as a per-run masked-array loop, without building a broadcast mask per run.\n",
    "all_masks_matrix = np.full((len(correlation_matrices), len(good_run_indices), corr_curve.shape[1]), np.nan)\n",
    "for j, mask in enumerate(correlation_matrices):\n",
    "    all_masks_matrix[j] = good_cross_blocks[..., mask.astype(bool)].mean(axis=-1)\n",
    "\n",
    "# Additional filtering: keep a good run only if every mask's final-epoch mean\n",
    "# is below decorr_thresh.\n",
    "final_means = all_masks_matrix[:, :, -1]\n",
    "good_decorr_runs = np.all(final_means < decorr_thresh, axis=0)\n",
    "filtered_indices = good_run_indices[good_decorr_runs]\n",
    "n_filtered = int(np.sum(good_decorr_runs))\n",
    "\n",
    "print(f\"Number of runs passing both loss and decorrelation thresholds: {n_filtered}\")\n",
    "if len(good_run_indices) > 0:\n",
    "    print(f\"Percentage of good runs passing decorrelation threshold: {n_filtered/len(good_run_indices)*100:.2f}%\")\n",
    "else:\n",
    "    print(\"Percentage of good runs passing decorrelation threshold: n/a (no good runs)\")\n",
    "\n",
    "# Plot individual correlation matrices as subplots (at most the first 9 filtered runs)\n",
    "n_plots = min(9, int(np.sum(good_decorr_runs)))\n",
    "\n",
    "if n_plots == 0:\n",
    "    # plt.subplots(0, 0) raises, so report and skip instead.\n",
    "    print(\"No runs passed both thresholds; skipping individual correlation plots.\")\n",
    "else:\n",
    "    n_rows = math.ceil(n_plots / 3)\n",
    "    n_cols = min(n_plots, 3)\n",
    "\n",
    "    fig, axes = plt.subplots(n_rows, n_cols, figsize=(6*n_cols, 5*n_rows))\n",
    "    axes = axes.flatten() if isinstance(axes, np.ndarray) else [axes]\n",
    "\n",
    "    for ax, run_idx in zip(axes, filtered_indices[:n_plots]):\n",
    "        # Index the single run directly rather than slicing every run first.\n",
    "        sb.heatmap(corr_curve[run_idx, -1, 0:tr_len, tr_len:2*tr_len],\n",
    "                   vmin=-1, vmax=1, cmap='icefire', ax=ax)\n",
    "        for plot_line in [6,10,13,15,18,20]:\n",
    "            ax.axvline(plot_line, linestyle='--', color='gray')\n",
    "            ax.axhline(plot_line, linestyle='--', color='gray')\n",
    "        ax.set_title(f'Run {run_idx+1}\\nLoss: {loss_all[run_idx,-1]:.4f}, '\n",
    "                     f'Acc: {accuracy_curve_all_test[run_idx,-1]:.4f}')\n",
    "\n",
    "    # Remove grid cells left unused when n_plots is not a multiple of 3.\n",
    "    for unused_ax in axes[n_plots:]:\n",
    "        fig.delaxes(unused_ax)\n",
    "\n",
    "    plt.tight_layout()\n",
    "    plt.show()\n",
    "\n",
    "    print(\"Individual correlation matrices plotted for filtered runs.\")\n",
    "    print(\"These plots show the variation in correlation patterns across different runs that pass both thresholds.\")\n",
    "print(\"=\"*40)\n",
    "\n",
    "# Mean test-accuracy trajectory over the runs that passed both thresholds\n",
    "filtered_acc = accuracy_curve_all_test[filtered_indices]\n",
    "fig, ax = plt.subplots(figsize=(10, 6))\n",
    "ax.plot(filtered_acc.mean(axis=0))\n",
    "ax.set_title('Average Accuracy Curve (Filtered Runs)')\n",
    "ax.set_xlabel('Epoch')\n",
    "ax.set_ylabel('Accuracy')\n",
    "plt.show()\n",
    "\n",
    "print(\"Average accuracy curve plotted for filtered runs.\")\n",
    "print(f\"Final average accuracy for filtered runs: {filtered_acc[:, -1].mean():.4f}\")\n",
    "print(\"=\"*40)\n",
    "\n",
    "# Correlation distribution: one dot per good run (loss threshold only) at its\n",
    "# last-epoch mean cross-block correlation, plus a bar marking the mean.\n",
    "# To restrict to runs that also pass the decorrelation threshold, index with\n",
    "# filtered_indices instead of good_run_indices.\n",
    "corr_all = np.mean(corr_curve[good_run_indices, -1, 0:tr_len, tr_len:2*tr_len], axis=(1,2))\n",
    "num_points = len(corr_all)\n",
    "rng = np.random.default_rng(0)  # fixed seed so the vertical jitter is reproducible\n",
    "yjitter = rng.normal(0, 0.1, num_points)\n",
    "yall = 6 * np.ones(num_points) + yjitter\n",
    "\n",
    "fig, axs = plt.subplots(1, figsize=(20, 10))\n",
    "# Horizontal bar at the dots' height (y=6). A vertical bar at x=1 would lie\n",
    "# entirely below the y-limits (2-10) and never be visible.\n",
    "axs.barh([6], np.mean(corr_all), height=1, color='lightblue')\n",
    "axs.plot(corr_all, yall, 'o', color='black')\n",
    "axs.set_ylim(2, 10)\n",
    "axs.set_xlim(0, 1)\n",
    "axs.set_yticks([2, 4, 6, 8, 10], ['', '', 'NeuroMamba', '', ''])\n",
    "axs.set_title('Correlation Distribution (Good Runs)')\n",
    "axs.set_xlabel('Mean last-epoch correlation')\n",
    "plt.show()\n",
    "\n",
    "print(\"Correlation distribution plotted for good runs (loss threshold only).\")\n",
    "print(\"This plot shows the distribution of mean correlations across good runs.\")\n",
    "print(f\"Mean correlation: {np.mean(corr_all):.4f}\")\n",
    "print(\"=\"*40)\n",
    "\n",
    "# Show each analysis mask as a heatmap (nonzero = entry included in that region's mean)\n",
    "for idx, (template, label) in enumerate(zip(correlation_matrices, correlation_names), start=1):\n",
    "    fig, ax = plt.subplots(figsize=(8, 6))\n",
    "    sb.heatmap(template, ax=ax)\n",
    "    ax.set_title(f'Correlation Matrix {idx}: {label}')\n",
    "    plt.show()\n",
    "\n",
    "print(\"Correlation matrices plotted.\")\n",
    "print(\"These matrices show the regions we're interested in analyzing:\")\n",
    "print(\"Matrix 1: Off-diagonal regions\")\n",
    "print(\"Matrix 2: Pre-R2 region\")\n",
    "print(\"Matrix 3: Pre-R1 region\")\n",
    "print(\"=\"*40)\n",
    "\n",
    "# Plot decorrelation analysis for filtered runs\n",
    "n_iter = 10  # x-axis spacing: saved-step index * n_iter (presumably epochs per saved step - confirm)\n",
    "subplots_per_fig = 20\n",
    "n_filtered_runs = int(np.sum(good_decorr_runs))\n",
    "num_figures = math.ceil(n_filtered_runs / subplots_per_fig)\n",
    "# Select the filtered runs once instead of re-indexing inside the inner loop.\n",
    "filtered_masks = all_masks_matrix[:, good_decorr_runs]  # (n_masks, n_filtered_runs, n_epochs)\n",
    "epoch_axis = np.arange(all_masks_matrix.shape[2]) * n_iter\n",
    "\n",
    "for fig_num in range(num_figures):\n",
    "    start_idx = fig_num * subplots_per_fig\n",
    "    end_idx = min((fig_num + 1) * subplots_per_fig, n_filtered_runs)\n",
    "    n_panels = end_idx - start_idx\n",
    "    \n",
    "    n_rows = math.ceil(math.sqrt(n_panels))\n",
    "    n_cols = math.ceil(n_panels / n_rows)\n",
    "    \n",
    "    fig, axes = plt.subplots(n_rows, n_cols, figsize=(5*n_cols, 4*n_rows), dpi=300)\n",
    "    axes = axes.flatten() if isinstance(axes, np.ndarray) else [axes]\n",
    "    \n",
    "    for ax, session_n in zip(axes, range(start_idx, end_idx)):\n",
    "        for name, mask_means in zip(correlation_names, filtered_masks):\n",
    "            ax.plot(epoch_axis, mask_means[session_n], 'o-', label=name)\n",
    "        \n",
    "        ax.set_title(f'Run {filtered_indices[session_n]+1}')\n",
    "        ax.set_ylim(0, 1)\n",
    "        ax.set_xlabel('Epoch')\n",
    "        ax.set_ylabel('Mean Correlation')\n",
    "        ax.legend()\n",
    "    \n",
    "    # Remove grid cells left unused (n_rows * n_cols can exceed n_panels).\n",
    "    for unused_ax in axes[n_panels:]:\n",
    "        fig.delaxes(unused_ax)\n",
    "    \n",
    "    plt.tight_layout()\n",
    "    plt.suptitle(f'Decorrelation Analysis (Runs {start_idx+1}-{end_idx} out of {n_filtered_runs} filtered runs)', fontsize=16)\n",
    "    plt.subplots_adjust(top=0.95)\n",
    "    plt.show()\n",
    "\n",
    "print(\"Decorrelation analysis plots completed.\")\n",
    "print(f\"Total figures generated: {num_figures}\")\n",
    "print(\"These plots show how correlations in different regions change over training epochs for filtered runs.\")\n",
    "print(\"Each line represents a different region of interest in the correlation matrix.\")\n",
    "print(\"Decreasing trends indicate successful decorrelation of representations over time.\")\n",
    "print(\"=\"*40)\n",
    "\n",
    "# One heatmap per mask: rows are filtered runs, columns are epochs\n",
    "for m, _template in enumerate(correlation_matrices):\n",
    "    fig, ax = plt.subplots(figsize=(10, 8))\n",
    "    sb.heatmap(all_masks_matrix[m, good_decorr_runs], ax=ax)\n",
    "    ax.set_title(f'Mask {m+1} Heatmap: {correlation_names[m]}')\n",
    "    ax.set_xlabel('Epoch')\n",
    "    ax.set_ylabel('Run')\n",
    "    plt.show()\n",
    "\n",
    "print(\"Mask heatmaps plotted for filtered runs.\")\n",
    "print(\"These heatmaps show the evolution of correlations in different regions across runs and epochs.\")\n",
    "print(\"=\"*40)\n",
    "\n",
    "# Analyze crossing points (filtered runs only): for each run and mask, the first\n",
    "# epoch index at which the mask's mean correlation drops below cross_thresh.\n",
    "cross_thresh = decorr_thresh\n",
    "filtered_masks = all_masks_matrix[:, good_decorr_runs]  # (n_masks, n_filtered_runs, n_epochs)\n",
    "below_thresh = filtered_masks < cross_thresh\n",
    "crossed = below_thresh.any(axis=2)\n",
    "# argmax on a boolean array returns the first True; runs that never cross get n_epochs.\n",
    "first_cross = np.where(crossed, below_thresh.argmax(axis=2), all_masks_matrix.shape[2])\n",
    "where_cross = first_cross.T.astype(float)  # (n_filtered_runs, n_masks)\n",
    "# Keep only runs where every mask crossed. With cross_thresh == decorr_thresh every\n",
    "# filtered run crosses by its final epoch; this matters if cross_thresh is changed.\n",
    "# (A boolean array also stays valid as an index when there are zero filtered runs.)\n",
    "run_plot = crossed.all(axis=0)\n",
    "\n",
    "where_cross = where_cross[run_plot]\n",
    "\n",
    "rng = np.random.default_rng(0)  # fixed seed so the jitter is reproducible\n",
    "n_masks = where_cross.shape[1]\n",
    "n_cross_runs = where_cross.shape[0]\n",
    "y_positions = np.arange(n_masks)[:, np.newaxis] + rng.normal(0, 0.05, (n_masks, n_cross_runs))\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(12, 6))\n",
    "ax.barh(np.arange(n_masks), where_cross.mean(axis=0), color='lightblue')\n",
    "# Each column of where_cross.T is one run: gray lines join its crossings across masks.\n",
    "ax.plot(where_cross.T, y_positions, color='gray')\n",
    "ax.plot(where_cross.T, y_positions, 'o', color='black')\n",
    "ax.set_yticks(np.arange(n_masks), correlation_names)\n",
    "ax.set_title('Crossing Points Analysis (Filtered Runs of NeuMa)')\n",
    "# NOTE(review): crossing points are raw epoch indices, while the per-run decorrelation\n",
    "# plots scale their x-axis by n_iter - confirm which unit 'Epoch' should mean here.\n",
    "ax.set_xlabel('Epoch')\n",
    "output_filename = 'Crossing Points Analysis (NeuMa).pdf'\n",
    "\n",
    "# fig.savefig(output_filename, format='pdf', bbox_inches='tight')\n",
    "plt.show()\n",
    "\n",
    "print(\"Crossing points analysis completed and plotted for filtered runs.\")\n",
    "print(f\"This plot shows when different regions' correlations cross the threshold of {cross_thresh}.\")\n",
    "for m, name in enumerate(correlation_names):\n",
    "    print(f\"Mean crossing epoch for {name}: {np.mean(where_cross[:, m]):.2f}\")\n",
    "print(\"=\"*40)\n",
    "\n",
    "print(\"Analysis completed.\")\n",
    "print(\"Summary:\")\n",
    "print(f\"- Total simulations: {corr_curve.shape[0]}\")\n",
    "print(f\"- Runs with loss < {loss_thresh}: {good_runs}\")\n",
    "print(f\"- Runs passing both loss and decorrelation thresholds: {np.sum(good_decorr_runs)}\")\n",
    "print(f\"- Final mean loss (all runs): {np.mean(loss_all[:,-1]):.4f}\")\n",
    "print(f\"- Final mean loss (filtered runs): {np.mean(loss_all[filtered_indices, -1]):.4f}\")\n",
    "print(f\"- Final mean accuracy (all runs): {np.mean(accuracy_curve_all_test[:,-1]):.4f}\")\n",
    "print(f\"- Final mean accuracy (filtered runs): {np.mean(accuracy_curve_all_test[filtered_indices, -1]):.4f}\")\n",
    "# corr_all is averaged over runs passing the loss threshold only (good_run_indices),\n",
    "# not over the decorrelation-filtered subset, so label it accordingly.\n",
    "print(f\"- Mean last-step correlation (good runs, loss threshold only): {np.mean(corr_all):.4f}\")\n",
    "print(\"The analysis suggests that the model is learning to decorrelate representations over time, with different regions showing distinct patterns of decorrelation.\")\n",
    "print(f\"Decorrelation is particularly strong in runs that pass both the loss threshold of {loss_thresh} and the decorrelation threshold of {decorr_thresh}.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9d26731e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Average the full correlation tensors over the runs that passed both thresholds.\n",
    "mean_corr_curve = np.mean(corr_curve[filtered_indices], axis=0)\n",
    "\n",
    "# Choose 5 epochs to visualize, including the first and the last one\n",
    "num_epochs = mean_corr_curve.shape[0]\n",
    "epochs_to_plot = [0, num_epochs // 4, num_epochs // 2, 3 * num_epochs // 4, num_epochs - 1]\n",
    "\n",
    "# Create subplots: one column per chosen epoch (seaborn is already imported as sb)\n",
    "fig, axs = plt.subplots(1, len(epochs_to_plot), figsize=(20, 4), dpi=600)\n",
    "\n",
    "for ax, epoch in zip(axs, epochs_to_plot):\n",
    "    sb.heatmap(mean_corr_curve[epoch, 0:tr_len, tr_len:2*tr_len], \n",
    "               cmap='icefire', vmin=-1, vmax=1, ax=ax, \n",
    "               cbar=False, xticklabels=False, yticklabels=False, linewidths=0)\n",
    "    ax.set_aspect('equal')  # make each subplot square\n",
    "    ax.set_title(f'Epoch {epoch + 1}')\n",
    "    \n",
    "    # Dotted white guides at the region boundaries\n",
    "    for line_pos in [6, 10, 13, 15, 18, 20]:\n",
    "        ax.axvline(line_pos, linestyle=(0, (2, 5)), color='white', linewidth=1.5)\n",
    "        ax.axhline(line_pos, linestyle=(0, (2, 5)), color='white', linewidth=1.5)\n",
    "    \n",
    "    # Solid white squares outlining the diagonal blocks [6,10), [13,15), [18,20)\n",
    "    for (low, high) in [(6, 10), (13, 15), (18, 20)]:\n",
    "        ax.plot([low, high, high, low, low], [low, low, high, high, low], color='white', linewidth=3)\n",
    "\n",
    "plt.tight_layout()\n",
    "\n",
    "# Get today's date for the filename\n",
    "today = datetime.now().strftime('%Y_%m_%d')\n",
    "\n",
    "# Save under a NeuMa (NeuroMamba) name, consistent with this notebook's other outputs.\n",
    "fig.savefig(f'NeuMa_corr_plot_{today}.pdf', format='pdf', dpi=600)\n",
    "\n",
    "# Display the plot\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7ead6799",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the per-run mean last-epoch cross-block correlation for runs that passed the loss threshold\n",
    "corr_all"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0c68b8e9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Analyze crossing points (filtered runs only): first epoch index at which each\n",
    "# mask's mean correlation drops below cross_thresh, computed with array ops.\n",
    "cross_thresh = decorr_thresh\n",
    "filtered_masks = all_masks_matrix[:, good_decorr_runs]  # (n_masks, n_filtered_runs, n_epochs)\n",
    "below_thresh = filtered_masks < cross_thresh\n",
    "crossed = below_thresh.any(axis=2)\n",
    "# argmax on a boolean array gives the first True; runs that never cross get n_epochs.\n",
    "first_cross = np.where(crossed, below_thresh.argmax(axis=2), all_masks_matrix.shape[2])\n",
    "where_cross = first_cross.T.astype(float)  # (n_filtered_runs, n_masks)\n",
    "run_plot = crossed.all(axis=0)  # keep only runs where every mask crossed\n",
    "\n",
    "where_cross = where_cross[run_plot]\n",
    "\n",
    "# Normalize cross times by each run's latest crossing + 2 (so values stay below 1)\n",
    "max_cross_times = np.max(where_cross, axis=1) + 2\n",
    "normalized_cross = where_cross / max_cross_times[:, np.newaxis]\n",
    "\n",
    "# Reorder the data\n",
    "new_order = [0, 2, 1]  # Off-diagonal, Pre-R1, Pre-R2\n",
    "reordered_normalized_cross = normalized_cross[:, new_order]\n",
    "reordered_correlation_names = [correlation_names[i] for i in new_order]\n",
    "\n",
    "rng = np.random.default_rng(0)  # fixed seed so the jitter is reproducible\n",
    "n_masks = reordered_normalized_cross.shape[1]\n",
    "n_cross_runs = reordered_normalized_cross.shape[0]\n",
    "y_positions = np.arange(n_masks)[:, np.newaxis] + rng.normal(0, 0.05, (n_masks, n_cross_runs))\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(12, 6))\n",
    "ax.barh(np.arange(n_masks), reordered_normalized_cross.mean(axis=0), color='lightblue')\n",
    "# Each column of the transpose is one run: gray lines join its crossings across masks.\n",
    "ax.plot(reordered_normalized_cross.T, y_positions, color='gray')\n",
    "ax.plot(reordered_normalized_cross.T, y_positions, 'o', color='black')\n",
    "ax.set_yticks(np.arange(n_masks), reordered_correlation_names)\n",
    "ax.set_title('Normalized Crossing Points Analysis (Filtered Runs)')\n",
    "ax.set_xlabel('Normalized Epoch')\n",
    "ax.set_xlim(0, 1)\n",
    "plt.show()\n",
    "\n",
    "print(\"Normalized crossing points analysis completed and plotted for filtered runs.\")\n",
    "print(f\"This plot shows when different regions' correlations cross the threshold of {decorr_thresh}, normalized by the longest cross time +2 for each run.\")\n",
    "for name, mean_time in zip(reordered_correlation_names, reordered_normalized_cross.mean(axis=0)):\n",
    "    print(f\"Mean normalized crossing time for {name}: {mean_time:.2f}\")\n",
    "print(\"=\"*40)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d20f02ec",
   "metadata": {},
   "outputs": [],
   "source": [
    "alpha = 0.05  # significance level for the paired t-tests\n",
    "\n",
    "# Calculate mean normalized crossing times\n",
    "mean_crossing_times = np.mean(reordered_normalized_cross, axis=0)\n",
    "\n",
    "print(\"Mean normalized crossing times:\")\n",
    "for name, mean_time in zip(reordered_correlation_names, mean_crossing_times):\n",
    "    print(f\"{name}: {mean_time:.4f}\")\n",
    "print(\"=\"*40)\n",
    "\n",
    "# Paired t-tests between every pair of regions (each run is one paired sample).\n",
    "# Names come from reordered_correlation_names so they always match the column order.\n",
    "# NOTE(review): three tests at alpha with no multiple-comparison correction; consider\n",
    "# Bonferroni (alpha / 3) if these results are reported as a family.\n",
    "pairs = [(0, 1), (0, 2), (1, 2)]\n",
    "pair_names = [(reordered_correlation_names[i], reordered_correlation_names[j]) for i, j in pairs]\n",
    "n = reordered_normalized_cross.shape[0]\n",
    "\n",
    "# Run each test once and reuse its p-value in the summary below.\n",
    "test_results = []\n",
    "print(\"T-test results for normalized crossing times:\")\n",
    "for (i, j), (name1, name2) in zip(pairs, pair_names):\n",
    "    t_stat, p_value = stats.ttest_rel(reordered_normalized_cross[:, i], reordered_normalized_cross[:, j])\n",
    "    test_results.append(((i, j), (name1, name2), p_value))\n",
    "    print(f\"\\nComparing {name1} vs {name2}:\")\n",
    "    print(f\"t-statistic: {t_stat:.4f}\")\n",
    "    print(f\"p-value: {p_value:.4e}\")\n",
    "    print(f\"N: {n}\")\n",
    "    \n",
    "    # Interpret the results (a NaN p-value compares False and reads as not significant)\n",
    "    if p_value < alpha:\n",
    "        if mean_crossing_times[i] < mean_crossing_times[j]:\n",
    "            print(f\"{name1} decorrelates significantly earlier than {name2}\")\n",
    "        else:\n",
    "            print(f\"{name2} decorrelates significantly earlier than {name1}\")\n",
    "    else:\n",
    "        print(f\"No significant difference in decorrelation timing between {name1} and {name2}\")\n",
    "\n",
    "print(f\"\\nNote: A small p-value (< {alpha}) indicates a significant difference between the crossing times.\")\n",
    "print(\"=\"*40)\n",
    "\n",
    "# Overall interpretation\n",
    "print(\"Summary of decorrelation order:\")\n",
    "sorted_indices = np.argsort(mean_crossing_times)\n",
    "for rank, idx in enumerate(sorted_indices, start=1):\n",
    "    print(f\"{rank}. {reordered_correlation_names[idx]} (mean crossing time: {mean_crossing_times[idx]:.4f})\")\n",
    "\n",
    "print(\"\\nStatistically significant differences:\")\n",
    "significant = [(ij, names) for ij, names, p in test_results if p < alpha]\n",
    "for (i, j), (name1, name2) in significant:\n",
    "    if mean_crossing_times[i] < mean_crossing_times[j]:\n",
    "        print(f\"- {name1} decorrelates significantly earlier than {name2}\")\n",
    "    else:\n",
    "        print(f\"- {name2} decorrelates significantly earlier than {name1}\")\n",
    "if not significant:\n",
    "    print(\"- none\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9241707f",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "NM",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
