{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn.linear_model import LinearRegression\n",
    "import scipy\n",
    "from matplotlib import pyplot as plt\n",
    "from sklearn.decomposition import PCA\n",
    "from tqdm import tqdm\n",
    "from scipy.linalg.lapack import zggev\n",
    "from scipy.linalg import block_diag\n",
    "import pandas as pd\n",
    "from sklearn.cross_decomposition import CCA, PLSCanonical\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "from sklearn.model_selection import train_test_split\n",
    "from direct_effect_analysis import * \n",
    "from utils import *\n",
    "from mvlearn.embed import GCCA\n",
    "from pgmpy.estimators import PC\n",
    "from plotnine import ggplot, aes, geom_line, geom_ribbon, scale_x_log10, scale_x_continuous, labs, theme, facet_wrap, ggsave, theme_bw, element_text, facet_grid, scale_color_manual, scale_fill_manual"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Methods under comparison, and the scaling regimes used to build A_x (B_conds) and the diagonal Sigma (Sigma_conds).\n",
    "algorithms = ['T_D', 'T_F', 'T_S', 'PCA', 'pCCA']\n",
    "B_conds = ['d', '1', '1/d', '1/d^2']\n",
    "Sigma_conds = list(B_conds)  # same four regimes, kept as an independent list\n",
    "\n",
    "# results[B_cond][Sigma_cond][algo] is filled in by the simulation cell.\n",
    "results = {\n",
    "    B_cond: {Sigma_cond: {algo: {} for algo in algorithms} for Sigma_cond in Sigma_conds}\n",
    "    for B_cond in B_conds\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sizes forwarded to parameters() / generate_data_Sigma() in the simulation cell.\n",
    "# NOTE(review): the meaning of p, r, q (and rk, which no cell here reads) is defined in utils / direct_effect_analysis - confirm there.\n",
    "p, r, q, rk = 10, 10, 1, 10\n",
    "d = 100  # default d; the simulation uses each value in `dimensions` instead\n",
    "N = 100  # first argument of generate_data_Sigma; also used when choosing alpha\n",
    "dimensions = [2, 5, 10, 20, 50, 100]  # values of d swept in the simulation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reproducibility. NOTE(review): this only fixes the draws if parameters() / generate_data_Sigma() use numpy's global RNG - confirm.\n",
    "SEED = 0\n",
    "np.random.seed(SEED)\n",
    "\n",
    "a, b, c = 0.1, 0.1, 0.8  # forwarded to generate_data_Sigma as a=, b=, c=\n",
    "n_reps = 10  # Monte Carlo repetitions per (B_cond, Sigma_cond, algo, dimension); was `B`, easily confused with B_cond\n",
    "\n",
    "\n",
    "def scaling_vector(cond, dim):\n",
    "    \"\"\"Return the length-`dim` profile for one of the regimes in B_conds / Sigma_conds.\n",
    "\n",
    "    With k = 1..dim: 'd' -> k, '1' -> 1, '1/d' -> 1/k, '1/d^2' -> 1/k**2.\n",
    "    'd' starts at k = 1 (previously 0), matching the 1/d regimes and keeping Sigma non-singular.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `cond` is not a known regime (previously A_x / Sigma silently kept a stale value).\n",
    "    \"\"\"\n",
    "    k = np.arange(1, dim + 1, dtype=float)\n",
    "    profiles = {'d': k, '1': np.ones(dim), '1/d': 1 / k, '1/d^2': 1 / k**2}\n",
    "    if cond not in profiles:\n",
    "        raise ValueError(f'unknown scaling regime: {cond!r}')\n",
    "    return profiles[cond]\n",
    "\n",
    "\n",
    "for B_cond in tqdm(B_conds):\n",
    "    for Sigma_cond in Sigma_conds:\n",
    "        for algo in algorithms:\n",
    "            CORRS = []  # one list of n_reps correlations per dimension\n",
    "            for dim in dimensions:\n",
    "                dim = int(dim)\n",
    "                # A_x, Sigma and alpha depend only on the regimes and dim, so build them once per dimension.\n",
    "                A_x = scaling_vector(B_cond, dim)[None, :]  # shape (1, dim)\n",
    "                Sigma = np.diag(scaling_vector(Sigma_cond, dim))\n",
    "                # Larger alpha when dim is large relative to N (presumably a regularisation strength - confirm in DirectEffectAnalysis).\n",
    "                alpha = 1e2 if dim > 50 and N < 200 else 1e-5\n",
    "                CORR = []\n",
    "                for rep in range(n_reps):\n",
    "                    # parameters() is called once per repetition; its first return value is not used here.\n",
    "                    _, A_z, beta, gamma = parameters(r, p, q, dim, False)\n",
    "                    X_train, Y_train, Z_train, Y_x_train = generate_data_Sigma(N, p, r, dim, beta, gamma, A_x, A_z, Sigma, a=a, b=b, c=c)\n",
    "                    # Pass the alpha chosen above (it used to be hard-coded to 1e-5, leaving the branch dead).\n",
    "                    dea = DirectEffectAnalysis(type=algo, alpha=alpha)\n",
    "                    dea.fit(X_train, Y_train, Z_train)\n",
    "                    Y_hat = dea.transform(X_train, Y_train, Z_train)\n",
    "                    # Absolute correlation between the estimate and the first column of Y_x_train.\n",
    "                    corr = np.abs(np.corrcoef(Y_hat, Y_x_train[:, 0])[0, 1])\n",
    "                    CORR.append(corr)\n",
    "                CORRS.append(CORR)\n",
    "            results[B_cond][Sigma_cond][algo] = CORRS"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Summarise the repetitions: median and 5th/95th percentiles of |corr| for every dimension.\n",
    "algo_names = {'T_D':'T_D', 'T_F':'T_F', 'T_S':'T_S', 'PCA':'PCA', 'pCCA':'pCCA'}  # legend label per algorithm\n",
    "\n",
    "records = []\n",
    "for B_cond in B_conds:\n",
    "    for Sigma_cond in Sigma_conds:\n",
    "        for algo in algorithms:\n",
    "            values = np.asarray(results[B_cond][Sigma_cond][algo])  # rows: dimensions, columns: repetitions\n",
    "            lower, median, upper = np.percentile(values, [5, 50, 95], axis=1)\n",
    "            for i, dim in enumerate(dimensions):\n",
    "                records.append({\n",
    "                    'Dimension': dim,\n",
    "                    'Median': median[i],\n",
    "                    'Lower': lower[i],\n",
    "                    'Upper': upper[i],\n",
    "                    'Algorithm': algo_names[algo],\n",
    "                    'Sigma': Sigma_cond,\n",
    "                    'B': B_cond,\n",
    "                })\n",
    "\n",
    "# Facet order follows the order in which the regimes were simulated.\n",
    "df = pd.DataFrame(records)\n",
    "df['B'] = pd.Categorical(df['B'], categories=B_conds)\n",
    "df['Sigma'] = pd.Categorical(df['Sigma'], categories=Sigma_conds)\n",
    "\n",
    "# Work on a new frame instead of aliasing df (the old `df2 = df` made the column assignment mutate df too).\n",
    "# NOTE: Facet_Label is not used by facet_grid below, which labels panels with label_both.\n",
    "df_plot = df.assign(\n",
    "    Facet_Label='$\\\\mathbf{\\\\sigma}=' + df['Sigma'].astype(str) + ',\\\\hspace{1} \\\\mathbf{b}_i=' + df['B'].astype(str) + '$'\n",
    ")\n",
    "\n",
    "# One opaque colour per algorithm, shared by lines and ribbons (ribbon transparency comes from alpha=0.2).\n",
    "# The previous mapping used CSS 'rgba(...)' strings, which matplotlib cannot parse, and keys 'Alg2'/'Alg3' that match no algorithm.\n",
    "palette = {algo_names[algo]: colour for algo, colour in zip(algorithms, ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd'])}\n",
    "\n",
    "# Named fig_plot rather than p so the simulation parameter p is not overwritten.\n",
    "fig_plot = (ggplot(df_plot, aes(x='Dimension', y='Median', color='Algorithm', fill='Algorithm'))\n",
    "     + geom_line(size=1.5)\n",
    "     + geom_ribbon(aes(ymin='Lower', ymax='Upper', fill='Algorithm', color='Algorithm'), alpha=0.2)\n",
    "     + scale_color_manual(values=palette)\n",
    "     + scale_fill_manual(values=palette)\n",
    "     + scale_x_log10(breaks=dimensions)\n",
    "     + labs(x='Dimension (d)', y='Absolute Correlation')\n",
    "     + facet_grid(rows='B', cols='Sigma', labeller='label_both')\n",
    "     + theme_bw()\n",
    "     + theme(\n",
    "         legend_position='bottom',\n",
    "         figure_size=(12, 12),\n",
    "         axis_title=element_text(size=23),\n",
    "         axis_text=element_text(size=22),\n",
    "         axis_text_x=element_text(angle=55, hjust=1),\n",
    "         legend_title=element_text(size=22),\n",
    "         legend_text=element_text(size=25),\n",
    "         strip_text=element_text(size=18)\n",
    "     )\n",
    ")\n",
    "\n",
    "fig_plot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
