{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0696a88d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "from pathlib import Path\n",
    "\n",
    "# Make the repository root (parent of the current working directory) importable\n",
    "root_path = Path.cwd().parent.absolute()\n",
    "sys.path.append(str(root_path))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4e3237b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch as torch\n",
    "from torch.distributions import MultivariateNormal\n",
    "\n",
    "import numpy as np \n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import parametric_pushforward.data_sets as toy_data\n",
    "from parametric_pushforward.parametric_mlp import order_state_to_tensor\n",
    "from parametric_pushforward.visualization import path_visualization_with_trajectories,path_visualization_particles,create_particle_animation\n",
    "from parametric_pushforward.spline import Assemble_spline\n",
    "from parametric_pushforward.obstacles import obstacle_cost_stunnel, obstacle_cost_vneck, obstacle_cost_gmm,congestion_cost,geodesic\n",
    "from parametric_pushforward.opinion import PolarizeDyn\n",
    "from parametric_pushforward.setup_density_path_problem import load_boundary_models,get_activation,opinion_dynamics_setup,get_potential_functions,setup_prior\n",
    "\n",
    "from geomloss import SamplesLoss\n",
    "\n",
    "import os\n",
    "import yaml\n",
    "\n",
    "# Preferred GPU index. A fixed 'cuda:3' fails with 'invalid device ordinal' on machines\n",
    "# with fewer GPUs, so fall back to cuda:0 in that case, and to the CPU without CUDA.\n",
    "GPU_INDEX = 3\n",
    "if torch.cuda.is_available():\n",
    "    device = torch.device(f'cuda:{GPU_INDEX}' if torch.cuda.device_count() > GPU_INDEX else 'cuda:0')\n",
    "else:\n",
    "    device = torch.device('cpu')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b774a6d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed the torch (CPU and CUDA) and NumPy generators for reproducible sampling\n",
    "SEED = 0\n",
    "torch.manual_seed(SEED)\n",
    "np.random.seed(SEED)\n",
    "torch.cuda.manual_seed(SEED)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "17cb15c7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment to analyse; the settings used below are read from its config.yaml\n",
    "exp_dir = str(root_path) + '/experiments/gmm_seed0'\n",
    "yaml_path = os.path.join(exp_dir, 'config.yaml')\n",
    "with open(yaml_path) as config_file:\n",
    "    config = yaml.load(config_file, Loader=yaml.FullLoader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a94b88cb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Names of the source / target data sets\n",
    "name_data0 = config['data']['source']['name']\n",
    "name_data1 = config['data']['target']['name']\n",
    "# Checkpoint tags (not referenced elsewhere in this notebook); alternative: 'checkpoint_1999'\n",
    "checkpt0 = 'final'\n",
    "checkpt1 = 'final'\n",
    "\n",
    "# Architecture list passed around below: [input_dim, hidden_dim, num_layers, activation]\n",
    "arch_cfg = config['architecture']\n",
    "arch_dims = [arch_cfg['input_dim'], arch_cfg['hidden_dim'], arch_cfg['num_layers']]\n",
    "activation = get_activation(arch_cfg['activation'])\n",
    "arch = [*arch_dims, activation]\n",
    "\n",
    "spline_type = config['spline']['type']\n",
    "prior = setup_prior(config, device)\n",
    "\n",
    "# Boundary model states and their tensor form (via order_state_to_tensor)\n",
    "state0, state1 = load_boundary_models(config, device)\n",
    "theta0 = order_state_to_tensor(state0)\n",
    "theta1 = order_state_to_tensor(state1)\n",
    "\n",
    "# Optional opinion-dynamics modifier, handed to Assemble_spline as ke_modifier\n",
    "use_opinion_dynamics = config.get('opinion_dynamics', {}).get('active', False)\n",
    "if use_opinion_dynamics:\n",
    "    print('Opinion dynamics active')\n",
    "    opinion_dynamics = opinion_dynamics_setup(config)\n",
    "    ke_modifier = [PolarizeDyn(opinion_dynamics).to(device)]\n",
    "else:\n",
    "    ke_modifier = None\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b212938",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Assemble the spline connecting the boundary parameters theta0 and theta1\n",
    "num_collocation = config['spline']['num_collocation']\n",
    "potentials = get_potential_functions(config['potential_functions'])\n",
    "spline0, t = Assemble_spline(\n",
    "    theta0=theta0,\n",
    "    theta1=theta1,\n",
    "    arch=arch,\n",
    "    data0=name_data0,\n",
    "    data1=name_data1,\n",
    "    ke_modifier=ke_modifier,\n",
    "    potential=potentials,\n",
    "    number_of_knots=num_collocation,\n",
    "    spline=spline_type,\n",
    "    device=device,\n",
    "    prior_dist=prior,\n",
    ")\n",
    "\n",
    "spline0.sigma = config['coefficients_potentials']['sigma']\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "90746f3f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Trained spline weights; alternative checkpoint files: geo_initial / initial / spline\n",
    "spline_path = os.path.join(exp_dir, 'checkpoints/spline.pth')\n",
    "checkpoint_spline0 = torch.load(spline_path, map_location=device)\n",
    "# The file is either a bare state_dict or a training checkpoint whose weights sit under\n",
    "# 'ema_model' (use 'direct_model' instead if that entry is present and wanted).\n",
    "# Checking the key loads the file once and lets genuine load_state_dict errors surface.\n",
    "if isinstance(checkpoint_spline0, dict) and 'ema_model' in checkpoint_spline0:\n",
    "    state_spline0 = checkpoint_spline0['ema_model']\n",
    "else:\n",
    "    state_spline0 = checkpoint_spline0\n",
    "spline0.load_state_dict(state_spline0)\n",
    "\n",
    "spline0.eval()\n",
    "spline0.sigma"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "97959321",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Path evaluation: 50 time points on [0, 1] and the number of particles to sample\n",
    "s = torch.linspace(0., 1, 50).to(device)\n",
    "num_samples = 5000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "be67f6c3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draw prior samples, evaluate the spline on the time grid and plot trajectories\n",
    "z0 = spline0.prior_dist.sample((num_samples,)).to(device)\n",
    "interpolation0 = spline0(s)\n",
    "plot_bounds = config['visualization']['plot_bounds']\n",
    "samples_path0 = path_visualization_with_trajectories(\n",
    "    interpolation=interpolation0,\n",
    "    arch=arch,\n",
    "    spline=spline0,\n",
    "    x0=plot_bounds['x_min'],\n",
    "    y0=plot_bounds['y_min'],\n",
    "    x1=plot_bounds['x_max'],\n",
    "    y1=plot_bounds['y_max'],\n",
    "    num_samples=50,\n",
    "    time_steps=40,\n",
    "    solver='midpoint',\n",
    "    z=z0,\n",
    "    num_contour_points=100,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d8123739",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Particle plot of the PDPO path (tensor detached and moved to CPU for plotting)\n",
    "path_visualization_particles(samples_path0.detach().cpu(), spline0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6152f4ed",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Re-simulate the particle paths. gen_sample_trajectory returns extra leading outputs\n",
    "# depending on which of spline0.entropy_pot / spline0.fisher_pot is enabled.\n",
    "# Both optional outputs start as None so every branch leaves them defined.\n",
    "traj_kwargs = dict(num_samples=len(z0), t_traj=s, time_steps_node=10, solver='midpoint')\n",
    "entropy = None\n",
    "norm_score0 = None\n",
    "if spline0.fisher_pot and not spline0.entropy_pot:\n",
    "    norm_score0, samples_path0 = spline0.gen_sample_trajectory(z0, **traj_kwargs)\n",
    "elif spline0.entropy_pot and not spline0.fisher_pot:\n",
    "    entropy, samples_path0 = spline0.gen_sample_trajectory(z0, **traj_kwargs)\n",
    "elif spline0.entropy_pot and spline0.fisher_pot:\n",
    "    entropy, norm_score0, samples_path0 = spline0.gen_sample_trajectory(z0, **traj_kwargs)\n",
    "# With neither flag set, samples_path0 from the trajectory-plot cell is used as-is.\n",
    "\n",
    "lagrangian0, ke, pe = spline0.lagrangian(samples_path0.to(device), s, log_density=entropy, score=norm_score0)\n",
    "\n",
    "print(lagrangian0, ke, pe)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3a250c14",
   "metadata": {},
   "source": [
    "# GSBM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1c051e18",
   "metadata": {},
   "outputs": [],
   "source": [
    "# GSBM baseline: load the pickled sample paths for one direction ('fwd')\n",
    "import pickle\n",
    "import jax\n",
    "import numpy as np\n",
    "\n",
    "path_gsbm = str(root_path) + '/results_gsbm/gmm'  # alternative: append '/seed0'\n",
    "direction = 'fwd'\n",
    "gsbm_file = path_gsbm + '/xs' + direction + '.pickle'\n",
    "with open(gsbm_file, 'rb') as fh:\n",
    "    samples = pickle.load(fh)  # only unpickle files you produced yourself\n",
    "\n",
    "samples_gsbm = samples['xs'].detach().cpu()\n",
    "print(samples_gsbm.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a726da83",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Particle plot of the GSBM baseline paths\n",
    "path_visualization_particles(samples_gsbm, spline0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9e05c206",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Lagrangian cost of the GSBM paths. Inputs are moved to the spline's device, as done\n",
    "# for the PDPO paths; as_tensor also accepts t stored as a list or NumPy array.\n",
    "spline0.lagrangian(samples['xs'].to(device), torch.as_tensor(samples['t'], device=device))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "db6d9770",
   "metadata": {},
   "source": [
    "# APAC-NET (disabled)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a41b603b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Disabled: APAC-NET baseline loading (every line below is commented out).\n",
    "# Re-enabling it requires results_apacnet/seed0/xs.pickle under the repository root.\n",
    "# # APAC-NET\n",
    "\n",
    "# path_apac = str(root_path)+ '/results_apacnet/seed0'\n",
    "\n",
    "# with open(path_apac + '/xs.pickle', 'rb') as f:\n",
    "#     samples_apac_net = pickle.load(f)\n",
    "\n",
    "# samples_apac_net = torch.from_numpy(samples_apac_net)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "79279fa9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Disabled: APAC-NET particle plot (needs samples_apac_net from the APAC-NET loading cell).\n",
    "# path_visualization_particles(samples_apac_net,spline0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "042b46a8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Disabled: Lagrangian cost of the APAC-NET paths on a uniform time grid over [0, 1].\n",
    "# t_apac_net = torch.linspace(0,1,samples_apac_net.shape[1])\n",
    "# spline0.lagrangian(samples_apac_net,t_apac_net)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2024a076",
   "metadata": {},
   "source": [
    "# NLOT"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "59e4b30c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "path_nlot = str(root_path)+ '/results_nlot/gmm/seed0'\n",
    "\n",
    "# NLOT paths are stored as a NumPy array; axis 1 indexes time in the cells below.\n",
    "# Cast to float32 so they match the float32 reference samples x0 / x1 in the OT cells\n",
    "# (and, presumably, the spline's parameters). This is a no-op if already float32.\n",
    "samples_nlot = torch.from_numpy(np.load(path_nlot + '/xs.npy')).float()\n",
    "\n",
    "# Uniform time grid on [0, 1] with one point per stored time step\n",
    "t_nlot = torch.linspace(0,1,samples_nlot.shape[1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "06c7b38d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Particle plot of the NLOT baseline paths\n",
    "path_visualization_particles(samples_nlot, spline0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5e1dd401",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Lagrangian cost of the NLOT paths. Inputs are moved to the spline's device (as done\n",
    "# for the PDPO paths) and cast to float32 in case the .npy file held float64.\n",
    "spline0.lagrangian(samples_nlot.float().to(device), t_nlot.to(device))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ce26ea95",
   "metadata": {},
   "source": [
    "# Animations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b115cb39",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Animate the PDPO particles; swap the first two axes so the time axis comes first\n",
    "plot_bounds = config['visualization']['plot_bounds']\n",
    "pdpo_frames = samples_path0.detach().cpu().permute(1, 0, 2)\n",
    "animation = create_particle_animation(\n",
    "    spline0,\n",
    "    pdpo_frames,\n",
    "    x0=plot_bounds['x_min'],\n",
    "    y0=plot_bounds['y_min'],\n",
    "    x1=plot_bounds['x_max'],\n",
    "    y1=plot_bounds['y_max'],\n",
    "    interval=250,\n",
    ")\n",
    "display(animation)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d4f7a3d5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Animate the GSBM particles; swap the first two axes so the time axis comes first\n",
    "plot_bounds = config['visualization']['plot_bounds']\n",
    "gsbm_frames = samples['xs'].detach().cpu().permute(1, 0, 2)\n",
    "animation = create_particle_animation(\n",
    "    spline0,\n",
    "    gsbm_frames,\n",
    "    x0=plot_bounds['x_min'],\n",
    "    y0=plot_bounds['y_min'],\n",
    "    x1=plot_bounds['x_max'],\n",
    "    y1=plot_bounds['y_max'],\n",
    "    interval=250,\n",
    ")\n",
    "display(animation)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2270263e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Disabled: APAC-NET particle animation (needs samples_apac_net from the APAC-NET loading cell).\n",
    "# animation = create_particle_animation(spline0,samples_apac_net.cpu().permute(1,0,2),x0 = config['visualization']['plot_bounds']['x_min'],\n",
    "# y0 = config['visualization']['plot_bounds']['y_min'],\n",
    "# x1 = config['visualization']['plot_bounds']['x_max'],\n",
    "# y1 = config['visualization']['plot_bounds']['y_max'],interval=250)\n",
    "# display(animation)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bdeee9ee",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Animate the NLOT particles; swap the first two axes so the time axis comes first\n",
    "plot_bounds = config['visualization']['plot_bounds']\n",
    "nlot_frames = samples_nlot.cpu().permute(1, 0, 2)\n",
    "animation = create_particle_animation(\n",
    "    spline0,\n",
    "    nlot_frames,\n",
    "    x0=plot_bounds['x_min'],\n",
    "    y0=plot_bounds['y_min'],\n",
    "    x1=plot_bounds['x_max'],\n",
    "    y1=plot_bounds['y_max'],\n",
    "    interval=250,\n",
    ")\n",
    "display(animation)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b1c23ad2",
   "metadata": {},
   "source": [
    "# Compare boundary marginals (OT distance to reference source / target samples)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dd1b225d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import ot as pot\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7215f7d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reference samples from the true source / target distributions\n",
    "data_dim = config['architecture']['input_dim']\n",
    "x0 = torch.from_numpy(toy_data.inf_train_gen(name_data0, batch_size=num_samples, dim=data_dim)).float()\n",
    "x1 = torch.from_numpy(toy_data.inf_train_gen(name_data1, batch_size=num_samples, dim=data_dim)).float()\n",
    "# Uniform empirical weights normalised by the actual sample counts, so each sums to 1\n",
    "# even if a generator returns a different number of points than requested\n",
    "a, b = torch.ones(x0.shape[0]) / x0.shape[0], torch.ones(x1.shape[0]) / x1.shape[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dab8b86a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# EMD between PDPO's initial / terminal particles and the reference source / target samples.\n",
    "# Uniform weights are sized to each point set, so the marginals stay valid if counts differ.\n",
    "x0_PDPO = samples_path0[:,0,:].detach().cpu()\n",
    "x1_PDPO = samples_path0[:,-1,:].detach().cpu()\n",
    "M0 = pot.dist(x0, x0_PDPO, metric='euclidean')\n",
    "M1 = pot.dist(x1_PDPO, x1, metric='euclidean')\n",
    "w_ref0 = torch.ones(len(x0)) / len(x0)\n",
    "w_ref1 = torch.ones(len(x1)) / len(x1)\n",
    "w_pdpo0 = torch.ones(len(x0_PDPO)) / len(x0_PDPO)\n",
    "w_pdpo1 = torch.ones(len(x1_PDPO)) / len(x1_PDPO)\n",
    "OT_MAX_ITER = 500000  # same budget for every method so the distances are comparable\n",
    "ot0 = pot.emd2(w_ref0, w_pdpo0, M0, numItermax=OT_MAX_ITER)\n",
    "ot1 = pot.emd2(w_pdpo1, w_ref1, M1, numItermax=OT_MAX_ITER)\n",
    "print('OT distance PDPO:{},{}'.format(ot0, ot1))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "82e799b0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# EMD between GSBM's initial / terminal particles and the reference source / target samples.\n",
    "x0_GSBM = samples['xs'][:,0,:].detach().cpu()\n",
    "x1_GSBM = samples['xs'][:,-1,:].detach().cpu()\n",
    "\n",
    "M0 = pot.dist(x0, x0_GSBM, metric='euclidean')\n",
    "M1 = pot.dist(x1_GSBM, x1, metric='euclidean')\n",
    "\n",
    "# Uniform weights sized to each point set (the GSBM sample count need not equal num_samples)\n",
    "w_ref0 = torch.ones(len(x0)) / len(x0)\n",
    "w_ref1 = torch.ones(len(x1)) / len(x1)\n",
    "w_gsbm0 = torch.ones(len(x0_GSBM)) / len(x0_GSBM)\n",
    "w_gsbm1 = torch.ones(len(x1_GSBM)) / len(x1_GSBM)\n",
    "# Same iteration budget as the PDPO comparison; the emd2 default (100000) can stop\n",
    "# before optimality on problems of this size, which would bias the comparison.\n",
    "OT_MAX_ITER = 500000\n",
    "ot0 = pot.emd2(w_ref0, w_gsbm0, M0, numItermax=OT_MAX_ITER)\n",
    "ot1 = pot.emd2(w_gsbm1, w_ref1, M1, numItermax=OT_MAX_ITER)\n",
    "print('OT distance GSBM:{},{}'.format(ot0, ot1))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "162e7483",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Source marginal: GSBM initial particles vs. reference source samples\n",
    "fig, ax = plt.subplots()\n",
    "# ax.scatter(x0_PDPO[:, 0], x0_PDPO[:, 1], label='PDPO', s=1)\n",
    "ax.scatter(x0_GSBM[:, 0], x0_GSBM[:, 1], label='GSBM', s=1)\n",
    "ax.scatter(x0[:, 0], x0[:, 1], label='source', s=1)\n",
    "\n",
    "ax.legend()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "93c982cf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Target marginal: reference target samples vs. GSBM and PDPO terminal particles\n",
    "fig, ax = plt.subplots(figsize=(20, 15))\n",
    "for pts, label in [(x1, 'target'), (x1_GSBM, 'GSBM'), (x1_PDPO, 'PDPO')]:\n",
    "    ax.scatter(pts[:, 0], pts[:, 1], label=label, s=10)\n",
    "ax.legend()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b472228c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Disabled: OT distances for the APAC-NET endpoints (needs samples_apac_net from the APAC-NET loading cell).\n",
    "# x0_APAC = samples_apac_net[:,0,:].detach().cpu()\n",
    "# x1_APAC = samples_apac_net[:,-1,:].detach().cpu()\n",
    "\n",
    "# M0 = pot.dist(x0_APAC,x0,metric = 'euclidean')\n",
    "# M1 = pot.dist(x1_APAC,x1,metric = 'euclidean')\n",
    "# print('OT distance APAC:{},{}'.format(pot.emd2(a,b,M0),pot.emd2(a,b,M1)))\n",
    "\n",
    "# # print('Sinkhorn distance APAC:{},{}'.format(sinkhorn(x0,x0_APAC),sinkhorn(x1,x1_APAC)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "25b93de2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# EMD between NLOT's initial / terminal particles and the reference source / target samples.\n",
    "# Cast to float32 to match x0 / x1 (the .npy file may hold float64).\n",
    "x0_nlot = samples_nlot[:,0,:].detach().cpu().float()\n",
    "x1_nlot = samples_nlot[:,-1,:].detach().cpu().float()\n",
    "# Same argument order as the PDPO / GSBM cells (EMD is symmetric, so the value is unchanged)\n",
    "M0 = pot.dist(x0, x0_nlot, metric='euclidean')\n",
    "M1 = pot.dist(x1_nlot, x1, metric='euclidean')\n",
    "# Uniform weights sized to each point set (the NLOT sample count need not equal num_samples)\n",
    "w_ref0 = torch.ones(len(x0)) / len(x0)\n",
    "w_ref1 = torch.ones(len(x1)) / len(x1)\n",
    "w_nlot0 = torch.ones(len(x0_nlot)) / len(x0_nlot)\n",
    "w_nlot1 = torch.ones(len(x1_nlot)) / len(x1_nlot)\n",
    "# Same iteration budget as the PDPO comparison so the distances are comparable\n",
    "OT_MAX_ITER = 500000\n",
    "ot0 = pot.emd2(w_ref0, w_nlot0, M0, numItermax=OT_MAX_ITER)\n",
    "ot1 = pot.emd2(w_nlot1, w_ref1, M1, numItermax=OT_MAX_ITER)\n",
    "print('OT distance NLOT:{},{}'.format(ot0, ot1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "02b89e46",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Target marginal: NLOT terminal particles vs. reference target samples\n",
    "fig, ax = plt.subplots()\n",
    "ax.scatter(x1_nlot[:, 0], x1_nlot[:, 1], label='NLOT', s=1)\n",
    "ax.scatter(x1[:, 0], x1[:, 1], label='target', s=1)\n",
    "# Draw the legend so the 'NLOT' / 'target' labels are actually shown\n",
    "ax.legend()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "PDPO",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
