{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Put the parent of the working directory on sys.path (presumably the repository\n",
    "# root, so that the `parametric_pushforward` package imported below resolves).\n",
    "import sys\n",
    "from pathlib import Path\n",
    "\n",
    "root_path = Path.cwd().parent.absolute()\n",
    "\n",
    "# Guard the append so that re-running this cell does not keep growing sys.path.\n",
    "if str(root_path) not in sys.path:\n",
    "    sys.path.append(str(root_path))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import torch\n",
    "from torch.distributions import MultivariateNormal\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import yaml\n",
    "from geomloss import SamplesLoss\n",
    "\n",
    "import parametric_pushforward.data_sets as toy_data\n",
    "from parametric_pushforward.parametric_mlp import order_state_to_tensor\n",
    "from parametric_pushforward.visualization import disimilarity_snapshots, disimilarity_plot\n",
    "from parametric_pushforward.spline import Assemble_spline\n",
    "from parametric_pushforward.obstacles import (\n",
    "    obstacle_cost_stunnel,\n",
    "    obstacle_cost_vneck,\n",
    "    obstacle_cost_gmm,\n",
    "    congestion_cost,\n",
    "    geodesic,\n",
    ")\n",
    "from parametric_pushforward.opinion import PolarizeDyn, proj_pca\n",
    "from parametric_pushforward.setup_density_path_problem import (\n",
    "    load_boundary_models,\n",
    "    get_activation,\n",
    "    opinion_dynamics_setup,\n",
    "    get_potential_functions,\n",
    ")\n",
    "\n",
    "# The notebook targets GPU index 2. torch.cuda.is_available() only reports that\n",
    "# *some* GPU exists, so fall back to GPU 0 on hosts with fewer than three devices\n",
    "# instead of failing later with an invalid device ordinal.\n",
    "if torch.cuda.is_available():\n",
    "    gpu_index = 2 if torch.cuda.device_count() > 2 else 0\n",
    "    device = torch.device(f'cuda:{gpu_index}')\n",
    "else:\n",
    "    device = torch.device('cpu')\n",
    "\n",
    "# Sinkhorn sample loss (p=2, blur=0.05); not referenced by the cells below, kept\n",
    "# available for ad-hoc comparisons between sample sets.\n",
    "sinkhorn = SamplesLoss(loss='sinkhorn', p=2, blur=0.05)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed every random number generator used below with one shared value so that\n",
    "# sampling is repeatable after Restart & Run All.\n",
    "SEED = 0\n",
    "\n",
    "torch.manual_seed(SEED)\n",
    "np.random.seed(SEED)\n",
    "torch.cuda.manual_seed(SEED)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment directory (relative to the repository root) and its YAML config.\n",
    "exp_dir = f'{root_path}/experiments/opinion_1000d'\n",
    "yaml_path = os.path.join(exp_dir, 'config.yaml')\n",
    "\n",
    "# NOTE(review): FullLoader can construct more than plain data types; only point\n",
    "# this at config files you trust (yaml.safe_load is the stricter option).\n",
    "with open(yaml_path) as config_file:\n",
    "    config = yaml.load(config_file, Loader=yaml.FullLoader)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Names of the boundary datasets. The checkpoint tags are not referenced again\n",
    "# in this notebook.\n",
    "name_data0 = config['data']['source']['name']\n",
    "checkpt0 = 'final'  # alternative: 'checkpoint_1999'\n",
    "name_data1 = config['data']['target']['name']\n",
    "checkpt1 = 'final'  # alternative: 'checkpoint_1999'\n",
    "\n",
    "# Network description handed to the spline: [input_dim, hidden_dim, num_layers, activation].\n",
    "arch_cfg = config['architecture']\n",
    "input_dim = arch_cfg['input_dim']\n",
    "arch_dims = [input_dim, arch_cfg['hidden_dim'], arch_cfg['num_layers']]\n",
    "activation = get_activation(arch_cfg['activation'])\n",
    "arch = [*arch_dims, activation]\n",
    "\n",
    "spline_type = config['spline']['type']\n",
    "\n",
    "# Zero-mean, identity-covariance Gaussian on the input space.\n",
    "prior = MultivariateNormal(\n",
    "    torch.zeros(input_dim).to(device),\n",
    "    torch.eye(input_dim).to(device),\n",
    ")\n",
    "\n",
    "# Boundary model states, converted to tensors for the spline end points.\n",
    "state0, state1 = load_boundary_models(config, device)\n",
    "theta0 = order_state_to_tensor(state0)\n",
    "theta1 = order_state_to_tensor(state1)\n",
    "\n",
    "# Optional opinion-dynamics term, enabled through config['opinion_dynamics']['active'].\n",
    "opinion_active = config.get('opinion_dynamics', {}).get('active', False)\n",
    "ke_modifier = None\n",
    "if opinion_active:\n",
    "    print('Opinion dynamics active')\n",
    "    opinion_dynamics = opinion_dynamics_setup(config)\n",
    "    ke_modifier = [PolarizeDyn(opinion_dynamics).to(device)]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Assemble the spline between the two boundary parameter tensors.\n",
    "num_collocation = config['spline']['num_collocation']\n",
    "potential_fns = get_potential_functions(config['potential_functions'])\n",
    "\n",
    "spline0, t = Assemble_spline(\n",
    "    theta0=theta0,\n",
    "    theta1=theta1,\n",
    "    arch=arch,\n",
    "    data0=name_data0,\n",
    "    data1=name_data1,\n",
    "    ke_modifier=ke_modifier,\n",
    "    potential=potential_fns,\n",
    "    number_of_knots=num_collocation,\n",
    "    spline=spline_type,\n",
    "    device=device,\n",
    "    prior_dist=prior,\n",
    ")\n",
    "\n",
    "# sigma is taken from the config rather than from the constructor call.\n",
    "spline0.sigma = config['coefficients_potentials']['sigma']\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Restore the trained spline weights.\n",
    "spline_path = os.path.join(exp_dir, 'checkpoints/spline.pth')\n",
    "\n",
    "# Read the file once. It holds either a bare state_dict or a wrapper dict that\n",
    "# stores the weights under 'ema_model'. Checking for the key explicitly (instead\n",
    "# of a bare except) lets genuine failures such as a missing file or mismatched\n",
    "# keys surface with their real error message.\n",
    "checkpoint = torch.load(spline_path, map_location=device)\n",
    "if isinstance(checkpoint, dict) and 'ema_model' in checkpoint:\n",
    "    state_spline0 = checkpoint['ema_model']\n",
    "else:\n",
    "    state_spline0 = checkpoint\n",
    "spline0.load_state_dict(state_spline0)\n",
    "\n",
    "spline0.eval()\n",
    "# Show sigma as the cell output to check the value after loading.\n",
    "spline0.sigma"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "samples = 5000\n",
    "t_node = 10\n",
    "\n",
    "# Draw source samples and map them back with the parameters stored in spline0.x0.\n",
    "x0_np = toy_data.inf_train_gen(\n",
    "    name_data0,\n",
    "    batch_size=samples,\n",
    "    dim=config['architecture']['input_dim'],\n",
    ")\n",
    "x0 = torch.from_numpy(x0_np).float().to(device)\n",
    "z0 = spline0.pull_back(spline0.x0.flatten(), x0)\n",
    "\n",
    "# Alternatives kept for reference:\n",
    "# x1 = torch.from_numpy(toy_data.inf_train_gen(name_data1, batch_size=samples,dim = config['architecture']['input_dim'])).float().to(device)\n",
    "# y0 = spline0.push_forward(spline0.x1.flatten(),z0)\n",
    "# z0 = spline0.prior_dist.sample((samples,)).to(device)\n",
    "# z1 = z0.clone()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Overlay the pulled-back samples and fresh prior draws (first two coordinates)\n",
    "# so the two point clouds can be compared visually.\n",
    "fig_pb, ax_pb = plt.subplots()\n",
    "\n",
    "z0_xy = z0[:, :2].cpu().detach().numpy()\n",
    "ax_pb.scatter(z0_xy[:, 0], z0_xy[:, 1], s=1, label='pullback')\n",
    "\n",
    "# Samples from the prior\n",
    "prior_samples = prior.sample((samples,)).to(device)\n",
    "prior_xy = prior_samples[:, :2].cpu().detach().numpy()\n",
    "ax_pb.scatter(prior_xy[:, 0], prior_xy[:, 1], s=1, label='prior')\n",
    "\n",
    "ax_pb.set_title('Samples from the prior and pullback')\n",
    "ax_pb.set_xlabel('x0')\n",
    "ax_pb.set_ylabel('x1')\n",
    "# Without a legend the two overlaid clouds cannot be told apart.\n",
    "ax_pb.legend(markerscale=5)\n",
    "\n",
    "# savefig does not create missing directories, so make sure the target exists.\n",
    "figures_dir = os.path.join(exp_dir, 'figures')\n",
    "os.makedirs(figures_dir, exist_ok=True)\n",
    "fig_pb.savefig(os.path.join(figures_dir, 'pullback_samples.png'), dpi=300)\n",
    "\n",
    "# The figure is only written to disk; close it so it is not rendered inline.\n",
    "plt.close(fig_pb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Path plot: ten equally spaced times on [0, 1] at which the path is evaluated.\n",
    "s = torch.linspace(start=0, end=1, steps=10).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# interpolation0 = spline0(s)\n",
    "# samples_path0 = path_visualization_snapshots(interpolation=interpolation0,arch = arch,\n",
    "# spline = spline0,\n",
    "# x0 = config['visualization']['plot_bounds']['x_min'],\n",
    "# y0 = config['visualization']['plot_bounds']['y_min'],\n",
    "# x1 = config['visualization']['plot_bounds']['x_max'],\n",
    "# y1 = config['visualization']['plot_bounds']['y_max'],\n",
    "# num_samples = 50,\n",
    "# time_steps = 10,solver = 'midpoint',\n",
    "# z = z0,num_contour_points = 250)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sample trajectories along the spline for every latent point in z0, evaluated\n",
    "# at the times in s.\n",
    "samples_path = spline0.gen_sample_trajectory(\n",
    "    z0,\n",
    "    num_samples=len(z0),\n",
    "    t_traj=s,\n",
    "    time_steps_node=10,\n",
    "    solver='midpoint',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Lagrangian of the sampled trajectories on the time grid s (shown as cell output).\n",
    "lagrangian_pdpo = spline0.lagrangian(samples_path, s)\n",
    "lagrangian_pdpo"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE: rebinds `samples` (5000 earlier in the notebook); every cell below this\n",
    "# point uses the larger sample count.\n",
    "samples = 10000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fresh prior draws pushed forward with the parameters stored in spline0.x1.\n",
    "# NOTE: this rebinds z0, which previously held the pulled-back source samples.\n",
    "z0 = spline0.prior_dist.sample((samples,)).to(device)\n",
    "terminal_dist = spline0.push_forward(spline0.x1.flatten(), z0)\n",
    "\n",
    "# Reference samples drawn directly from the target dataset.\n",
    "x1_np = toy_data.inf_train_gen(\n",
    "    name_data1,\n",
    "    batch_size=samples,\n",
    "    dim=config['architecture']['input_dim'],\n",
    ")\n",
    "x1 = torch.from_numpy(x1_np).float().to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Insert a singleton axis at position 1: (N, D) -> (N, 1, D).\n",
    "x1_reshape = x1.unsqueeze(1)\n",
    "terminal_dist_reshape = terminal_dist.unsqueeze(1)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load samples from gsbm\n",
    "\n",
    "import pickle\n",
    "\n",
    "gsbm_dir = str(root_path)+ '/results_gsbm/opinion_1000d/'\n",
    "\n",
    "# Join directory and file name as separate components; the previous single\n",
    "# pre-concatenated argument made os.path.join a no-op.\n",
    "gsbm_path = os.path.join(gsbm_dir, 'xs.pickle')\n",
    "\n",
    "# NOTE(review): pickle.load can execute arbitrary code during deserialisation;\n",
    "# only open result files produced by trusted runs.\n",
    "with open(gsbm_path, 'rb') as f:\n",
    "    gsbm_sol = pickle.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Entry stored under the 'xs' key of the loaded GSBM result.\n",
    "xs_gsbm = gsbm_sol.__getitem__('xs')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect the entry stored under 't' in the GSBM result (shown as cell output).\n",
    "t_gsbm = gsbm_sol['t']\n",
    "t_gsbm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate the same Lagrangian on the GSBM result, after moving both tensors to\n",
    "# the device used by the spline.\n",
    "xs_gsbm_dev = xs_gsbm.to(device)\n",
    "t_gsbm_dev = gsbm_sol['t'].to(device)\n",
    "spline0.lagrangian(xs_gsbm_dev, t_gsbm_dev)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 2 x 3 comparison figure: scatter plots on the top row, dissimilarity plots of\n",
    "# the same point sets on the bottom row.\n",
    "fig, ax = plt.subplots(2, 3, figsize=(8, 5))\n",
    "ax = ax.flatten()\n",
    "\n",
    "# Extra spacing between rows and columns of panels.\n",
    "plt.subplots_adjust(hspace=0.5, wspace=0.5)\n",
    "\n",
    "lim_plot_low = -20\n",
    "lim_plot_high = 20\n",
    "\n",
    "# Each point set is passed through proj_pca separately.\n",
    "terminal_dist_pca = proj_pca(terminal_dist_reshape)[0]\n",
    "x1_pca = proj_pca(x1_reshape)[0]\n",
    "xs_gsbm_pca = proj_pca(xs_gsbm)[0]\n",
    "\n",
    "# (panel title, projected points, index along axis 1 that is plotted)\n",
    "panels = [\n",
    "    ('True', x1_pca, 0),\n",
    "    ('PDPO', terminal_dist_pca, 0),\n",
    "    ('GSBM', xs_gsbm_pca, -1),\n",
    "]\n",
    "\n",
    "# Top row: first two projected coordinates on a common axis range.\n",
    "for panel_ax, (panel_title, panel_pts, panel_idx) in zip(ax[:3], panels):\n",
    "    panel_xy = panel_pts[:, panel_idx].cpu().detach().numpy()\n",
    "    panel_ax.scatter(panel_xy[:, 0], panel_xy[:, 1], s=1)\n",
    "    panel_ax.set_title(panel_title)\n",
    "    panel_ax.set_xlim(lim_plot_low, lim_plot_high)\n",
    "    panel_ax.set_ylim(lim_plot_low, lim_plot_high)\n",
    "\n",
    "# Bottom row: dissimilarity plot for the point set shown directly above.\n",
    "for panel_ax, (panel_title, panel_pts, panel_idx) in zip(ax[3:], panels):\n",
    "    disimilarity_plot(panel_pts[:, panel_idx, :], panel_ax)\n",
    "    panel_ax.set_title('Dissimilarity plot')\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dissimilarity snapshots for the trajectories sampled from the spline.\n",
    "disimilarity_snapshots(\n",
    "    samples_path\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same snapshots for the GSBM result, for side-by-side comparison.\n",
    "disimilarity_snapshots(\n",
    "    xs_gsbm\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "PDPO",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
