{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn.linear_model import LinearRegression\n",
    "import scipy\n",
    "from matplotlib import pyplot as plt\n",
    "from sklearn.decomposition import PCA\n",
    "from tqdm import tqdm\n",
    "from scipy.linalg.lapack import zggev\n",
    "from scipy.linalg import block_diag\n",
    "import pandas as pd\n",
    "from sklearn.cross_decomposition import CCA, PLSCanonical\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "from sklearn.model_selection import train_test_split\n",
    "from direct_effect_analysis import * \n",
    "from utils import *\n",
    "from mvlearn.embed import GCCA\n",
    "from pgmpy.estimators import PC\n",
    "from plotnine import ggplot, aes, geom_line, geom_ribbon, scale_x_log10, scale_x_continuous, labs, theme, facet_wrap, ggsave\n",
    "from sklearn.ensemble import RandomForestRegressor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experimental grid: the estimators under comparison and the names of the\n",
    "# profiles applied to B and to Sigma.\n",
    "algorithms = ['PCA', 'genCCA', 'OptiDet', 'Roy', 'RoyS']\n",
    "B_conds = ['d', '1', '1/d', '1/d^2']\n",
    "Sigma_conds = ['d', '1', '1/d', '1/d^2']\n",
    "\n",
    "# results[B_cond][Sigma_cond][algo] starts out as an empty dict for every grid\n",
    "# point; each leaf is a distinct object.\n",
    "results = {\n",
    "    b_name: {\n",
    "        sigma_name: {algo_name: {} for algo_name in algorithms}\n",
    "        for sigma_name in Sigma_conds\n",
    "    }\n",
    "    for b_name in B_conds\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model sizes handed to parameters() and generate_data_Sigma_nonlinear().\n",
    "# NOTE(review): the meaning of p, r, q and rk is defined by those helpers -- confirm there.\n",
    "p = 10\n",
    "r = 10\n",
    "d = 100\n",
    "q = 1\n",
    "rk = 10\n",
    "\n",
    "# First positional argument of generate_data_Sigma_nonlinear (presumably the sample size).\n",
    "N = 1000\n",
    "\n",
    "# Ambient dimensions visited by the experiment sweep.\n",
    "dimensions = [2, 5, 10, 20, 50, 100, 200]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/4 [00:00<?, ?it/s]"
     ]
    }
   ],
   "source": [
    "# Sweep the ambient dimension for every (B profile, Sigma profile, algorithm)\n",
    "# combination and record the absolute correlation between the transformed\n",
    "# output and the first column of Y_x_train.\n",
    "a, b, c = 0.1, 0.1, 0.8  # forwarded to generate_data_Sigma_nonlinear as a=, b=, c=\n",
    "B = 1  # number of replicates drawn per configuration\n",
    "\n",
    "# Seed NumPy's global RNG so that reruns are comparable.\n",
    "# NOTE(review): this only helps if the project helpers draw from np.random -- confirm.\n",
    "np.random.seed(0)\n",
    "\n",
    "\n",
    "def scale_profile(cond, dim):\n",
    "    \"\"\"Return the length-`dim` 1-D array named by `cond`.\n",
    "\n",
    "    With k = 1, ..., dim: 'd' -> k, '1' -> 1, '1/d' -> 1/k, '1/d^2' -> 1/k**2.\n",
    "    Indexing starts at 1 so that the 'd' profile ends at d (as the facet labels\n",
    "    'Diag(...,d)' / '[...,d]' state) and never contains a zero entry, which\n",
    "    would make the diagonal Sigma singular.\n",
    "\n",
    "    Raises ValueError for an unknown `cond` instead of silently reusing the\n",
    "    arrays left over from a previous iteration.\n",
    "    \"\"\"\n",
    "    k = np.arange(1, dim + 1)\n",
    "    if cond == 'd':\n",
    "        return k\n",
    "    if cond == '1':\n",
    "        return np.ones(dim, dtype=int)\n",
    "    if cond == '1/d':\n",
    "        return 1 / k\n",
    "    if cond == '1/d^2':\n",
    "        return 1 / k**2\n",
    "    raise ValueError(f'unknown profile: {cond!r}')\n",
    "\n",
    "\n",
    "for B_cond in tqdm(B_conds):\n",
    "    for Sigma_cond in Sigma_conds:\n",
    "        for algo in algorithms:\n",
    "            CORRS = []  # one list of replicate correlations per dimension\n",
    "            # `dim` (not `d`) so the module-level `d` is not rebound by the loop.\n",
    "            for dim in dimensions:\n",
    "                dim = int(dim)\n",
    "                CORR = []\n",
    "                for _rep in range(B):\n",
    "\n",
    "                    # Generate parameters; the first return value is discarded\n",
    "                    # because A_x is set from the B profile below.\n",
    "                    _, A_z, beta, gamma = parameters(r, p, q, dim, False)\n",
    "\n",
    "                    # (1, dim) row vector following the selected B profile.\n",
    "                    A_x = scale_profile(B_cond, dim)[None, :]\n",
    "\n",
    "                    # Diagonal (dim, dim) matrix following the selected Sigma profile.\n",
    "                    if Sigma_cond == '1':\n",
    "                        Sigma = np.identity(dim)\n",
    "                    else:\n",
    "                        Sigma = np.diag(scale_profile(Sigma_cond, dim))\n",
    "\n",
    "                    # Generate the sample for this replicate.\n",
    "                    X_train, Y_train, Z_train, Y_x_train = generate_data_Sigma_nonlinear(\n",
    "                        N, p, r, dim, beta, gamma, A_x, A_z, Sigma, a=a, b=b, c=c, nl_param=2\n",
    "                    )\n",
    "\n",
    "                    # Fit the estimator and transform the same sample it was fitted on.\n",
    "                    dea = DirectEffectAnalysis(\n",
    "                        type=algo,\n",
    "                        alpha=1e-5,\n",
    "                        regressor_0=RandomForestRegressor(n_estimators=100),\n",
    "                        regressor_1=RandomForestRegressor(n_estimators=100),\n",
    "                    )\n",
    "                    dea.fit(X_train, Y_train, Z_train)\n",
    "                    Y_hat = dea.transform(X_train, Y_train, Z_train)\n",
    "\n",
    "                    # Absolute correlation, so the sign of the recovered component is ignored.\n",
    "                    corr = np.abs(np.corrcoef(Y_hat, Y_x_train[:, 0])[0, 1])\n",
    "                    CORR.append(corr)\n",
    "                CORRS.append(CORR)\n",
    "            # Nested list indexed [dimension][replicate].\n",
    "            results[B_cond][Sigma_cond][algo] = CORRS"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_889/2443074691.py:32: SettingWithCopyWarning: \n",
      "A value is trying to be set on a copy of a slice from a DataFrame.\n",
      "Try using .loc[row_indexer,col_indexer] = value instead\n",
      "\n",
      "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
      "/tmp/ipykernel_889/2443074691.py:45: FutureWarning: Using print(plot) to draw and show the plot figure is deprecated and will be removed in a future version. Use plot.show().\n",
      "INFO:matplotlib.font_manager:Fontsize 0.00 < 1.0 pt not allowed by FreeType. Setting fontsize = 1 pt\n"
     ]
    }
   ],
   "source": [
    "# Summarise the sweep stored in `results` and draw one facet per\n",
    "# (Sigma profile, B profile) pair.\n",
    "\n",
    "# Legend names for the algorithms.\n",
    "algo_names = {'PCA':'PCA', 'genCCA':'T_C', 'OptiDet':'T_D', 'Roy':'T_F', 'RoyS':'T_S'}\n",
    "\n",
    "# One record per (B profile, Sigma profile, algorithm, dimension).\n",
    "records = []\n",
    "for B_cond in B_conds:\n",
    "    for Sigma_cond in Sigma_conds:\n",
    "        for algo in algorithms:\n",
    "            # Rows follow `dimensions`, columns are the replicates.\n",
    "            values = np.array(results[B_cond][Sigma_cond][algo])\n",
    "            # 5th / 50th / 95th percentile across replicates, computed in one call.\n",
    "            lower, median, upper = np.percentile(values, [5, 50, 95], axis=1)\n",
    "\n",
    "            for i, dim in enumerate(dimensions):\n",
    "                records.append({\n",
    "                    'Dimension': dim,\n",
    "                    'Median': median[i],\n",
    "                    'Lower': lower[i],\n",
    "                    'Upper': upper[i],\n",
    "                    'Algorithm': algo_names[algo],\n",
    "                    'Sigma': Sigma_cond,\n",
    "                    'B': B_cond\n",
    "                })\n",
    "\n",
    "df = pd.DataFrame(records)\n",
    "\n",
    "# PCA is left out of the figure. The explicit .copy() makes df_plot an\n",
    "# independent frame, so adding a column below does not raise\n",
    "# SettingWithCopyWarning.\n",
    "df_plot = df[df['Algorithm'] != 'PCA'].copy()\n",
    "\n",
    "# Facet title combining the Sigma and B profile names.\n",
    "df_plot['Facet_Label'] = 'Sigma: Diag(...,' + df_plot['Sigma'].astype(str) + '), B: [...,' + df_plot['B'].astype(str) + ']'\n",
    "\n",
    "# The figure is bound to `corr_plot`, not `p`: `p` is a model dimension used by\n",
    "# the simulation cell, and overwriting it would break a rerun of that cell.\n",
    "corr_plot = (\n",
    "    ggplot(df_plot, aes(x='Dimension', y='Median', color='Algorithm', fill='Algorithm'))\n",
    "    + geom_line(size=1.5)\n",
    "    + geom_ribbon(aes(ymin='Lower', ymax='Upper'), alpha=0.2)  # 5%-95% band\n",
    "    + scale_x_log10(breaks=dimensions)  # log-scaled x-axis with ticks at the swept dimensions\n",
    "    + labs(x='Dimension (d)', y='Correlation')\n",
    "    + facet_wrap('~Facet_Label', ncol=4)\n",
    "    + theme(legend_position='right', figure_size=(12, 8))\n",
    ")\n",
    "\n",
    "# print(plot) is deprecated in plotnine; show() is the supported way to render.\n",
    "corr_plot.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
