{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "import os\n",
    "import sys\n",
    "\n",
    "import yaml\n",
    "\n",
    "# Make the local `parametric_pushforward` package importable. root_path is the parent of\n",
    "# the notebook's working directory (expected to be the repository root - confirm).\n",
    "# The membership check keeps re-runs of this cell from appending duplicate entries.\n",
    "root_path = Path.cwd().parent.absolute()\n",
    "if str(root_path) not in sys.path:\n",
    "    sys.path.append(str(root_path))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import ot as pot\n",
    "import torch\n",
    "from geomloss import SamplesLoss\n",
    "from torch.distributions import MultivariateNormal\n",
    "\n",
    "import parametric_pushforward.data_sets as toy_data\n",
    "from parametric_pushforward.obstacles import congestion_cost, geodesic, obstacle_cost_gmm, obstacle_cost_stunnel, obstacle_cost_vneck\n",
    "from parametric_pushforward.opinion import PolarizeDyn\n",
    "from parametric_pushforward.parametric_mlp import MLP, order_state_to_tensor\n",
    "from parametric_pushforward.setup_density_path_problem import (\n",
    "    get_activation,\n",
    "    get_potential_functions,\n",
    "    load_boundary_models,\n",
    "    opinion_dynamics_setup,\n",
    "    setup_prior,\n",
    ")\n",
    "from parametric_pushforward.spline import Assemble_spline\n",
    "from parametric_pushforward.visualization import (\n",
    "    create_particle_animation,\n",
    "    display_bds,\n",
    "    path_visualization_particles,\n",
    "    path_visualization_snapshots,\n",
    "    path_visualization_with_trajectories,\n",
    ")\n",
    "\n",
    "# Preferred GPU. Falls back to cuda:0 when that index does not exist on this machine\n",
    "# (moving tensors to a missing CUDA ordinal raises), and to the CPU without CUDA.\n",
    "GPU_INDEX = 3\n",
    "if torch.cuda.is_available():\n",
    "    device = torch.device(f'cuda:{GPU_INDEX}' if GPU_INDEX < torch.cuda.device_count() else 'cuda:0')\n",
    "else:\n",
    "    device = torch.device('cpu')\n",
    "\n",
    "# geomloss Sinkhorn divergence (p=2, blur=0.05) used to compare sample clouds below.\n",
    "sinkhorn = SamplesLoss(loss = 'sinkhorn', p = 2, blur = 0.05)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed every RNG the notebook samples from (torch CPU/CUDA and NumPy) for reproducibility.\n",
    "SEED = 0\n",
    "torch.manual_seed(SEED)\n",
    "np.random.seed(SEED)\n",
    "torch.cuda.manual_seed(SEED)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment directory that holds config.yaml and the checkpoints/ folder.\n",
    "exp_dir = f'{root_path}/experiments/gmm_seed0'\n",
    "yaml_path = os.path.join(exp_dir, 'config.yaml')\n",
    "# NOTE(review): FullLoader kept as-is; yaml.safe_load would suffice if config.yaml\n",
    "# only uses plain YAML types.\n",
    "with open(yaml_path) as config_file:\n",
    "    config = yaml.load(config_file, Loader=yaml.FullLoader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Boundary data sets and architecture settings from the experiment config.\n",
    "name_data0 = config['data']['source']['name']\n",
    "name_data1 = config['data']['target']['name']\n",
    "\n",
    "arch_cfg = config['architecture']\n",
    "activation = get_activation(arch_cfg['activation'])\n",
    "# Layout consumed by MLP / Assemble_spline: [input_dim, hidden_dim, num_layers, activation].\n",
    "arch = [arch_cfg['input_dim'], arch_cfg['hidden_dim'], arch_cfg['num_layers'], activation]\n",
    "\n",
    "spline_type = config['spline']['type']\n",
    "prior = setup_prior(config, device)\n",
    "\n",
    "# A boundary checkpoint counts as unset when the YAML holds the string 'None' or a real\n",
    "# null (`null` / `~` load as Python None, which a plain == 'None' test would miss).\n",
    "UNSET_CHECKPOINT = (None, 'None')\n",
    "\n",
    "if (config['data']['source']['checkpoint'] in UNSET_CHECKPOINT\n",
    "        or config['data']['target']['checkpoint'] in UNSET_CHECKPOINT):\n",
    "    # No trained boundary models: both endpoints start from the weights of one newly\n",
    "    # constructed MLP.\n",
    "    # NOTE(review): this replaces `arch` with a Softplus activation regardless of\n",
    "    # config['architecture']['activation'], and the spline is built from `arch`.\n",
    "    # Kept unchanged - confirm it matches how the checkpoint was trained.\n",
    "    arch = [arch_cfg['input_dim'], arch_cfg['hidden_dim'], arch_cfg['num_layers'], torch.nn.Softplus()]\n",
    "    model0 = MLP(arch, time_varying=arch_cfg['time_varying']).to(device)\n",
    "    theta0 = order_state_to_tensor(model0.state_dict())\n",
    "    theta1 = theta0.clone()\n",
    "else:\n",
    "    # Trained source/target boundary models, converted with order_state_to_tensor.\n",
    "    state0, state1 = load_boundary_models(config, device)\n",
    "    theta0 = order_state_to_tensor(state0)\n",
    "    theta1 = order_state_to_tensor(state1)\n",
    "\n",
    "# Optional opinion-dynamics term, passed to Assemble_spline as ke_modifier.\n",
    "if config.get('opinion_dynamics', {}).get('active', False):\n",
    "    print('Opinion dynamics active')\n",
    "    opinion_dynamics = opinion_dynamics_setup(config)\n",
    "    ke_modifier = [PolarizeDyn(opinion_dynamics).to(device)]\n",
    "else:\n",
    "    ke_modifier = None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the spline between the boundary parameters theta0 and theta1.\n",
    "# Assemble_spline also returns `t`, the time grid evaluated in the path plots below.\n",
    "num_collocation = config['spline']['num_collocation']\n",
    "potential = get_potential_functions(config['potential_functions'])\n",
    "spline0, t = Assemble_spline(\n",
    "    theta0=theta0,\n",
    "    theta1=theta1,\n",
    "    arch=arch,\n",
    "    data0=name_data0,\n",
    "    data1=name_data1,\n",
    "    ke_modifier=ke_modifier,\n",
    "    potential=potential,\n",
    "    number_of_knots=num_collocation,\n",
    "    spline=spline_type,\n",
    "    device=device,\n",
    "    prior_dist=prior,\n",
    ")\n",
    "\n",
    "spline0.sigma = config['coefficients_potentials']['sigma']\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the trained spline weights. The file is either a bare state_dict or a dict that\n",
    "# stores the EMA weights under 'ema_model'. Other checkpoint names used for this\n",
    "# experiment: geo_inital / initial / spline_i.\n",
    "spline_path = os.path.join(exp_dir, 'checkpoints/spline.pth')\n",
    "spline_ckpt = torch.load(spline_path, map_location=device)\n",
    "# Load once and pick the layout explicitly, instead of retrying after a bare `except:`\n",
    "# that also swallowed unrelated errors (including KeyboardInterrupt).\n",
    "if isinstance(spline_ckpt, dict) and 'ema_model' in spline_ckpt:\n",
    "    state_spline0 = spline_ckpt['ema_model']  # 'direct_model' may also be stored - unverified\n",
    "else:\n",
    "    state_spline0 = spline_ckpt\n",
    "spline0.load_state_dict(state_spline0)\n",
    "\n",
    "spline0.eval()\n",
    "spline0.sigma"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of samples drawn from each boundary density. Despite the name this is the total\n",
    "# batch size; the coupling plot below splits it into 8 equal groups.\n",
    "samples_per_gaussian = 3000\n",
    "data_dim = config['architecture']['input_dim']\n",
    "x0 = torch.from_numpy(toy_data.inf_train_gen(name_data0, batch_size=samples_per_gaussian, dim=data_dim)).float().to(device)\n",
    "x1 = torch.from_numpy(toy_data.inf_train_gen(name_data1, batch_size=samples_per_gaussian, dim=data_dim)).float().to(device)\n",
    "\n",
    "# Evaluate the model: push one shared batch of prior samples through both endpoints.\n",
    "z = spline0.prior_dist.sample((samples_per_gaussian,))\n",
    "x0_aprox = spline0.push_forward(spline0.x0.flatten(), z)\n",
    "x1_aprox = spline0.push_forward(spline0.x1.flatten(), z)\n",
    "\n",
    "# Sinkhorn divergence between data and model samples; .item() prints a plain float\n",
    "# instead of a tensor repr.\n",
    "print('Sinkhorn distance to source density: {}'.format(sinkhorn(x0, x0_aprox).item()))\n",
    "print('Sinkhorn distance to target density: {}'.format(sinkhorn(x1, x1_aprox).item()))\n",
    "\n",
    "# Induced coupling: pull the source data back to latent space, then push those latents\n",
    "# through the target endpoint. Named z0_pullback so it is not confused with z0 below.\n",
    "z0_pullback = spline0.pull_back(spline0.x0.flatten(), x0)\n",
    "y0 = spline0.push_forward(spline0.x1.flatten(), z0_pullback)\n",
    "\n",
    "# Fresh prior samples shared by the path visualizations and the Lagrangian cell.\n",
    "z0 = spline0.prior_dist.sample((samples_per_gaussian,)).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Source density: data samples x0 vs. model samples pushed from the prior.\n",
    "src_data = x0.detach().cpu().numpy()\n",
    "src_model = x0_aprox.detach().cpu().numpy()\n",
    "fig, ax = plt.subplots()\n",
    "ax.scatter(src_data[:, 0], src_data[:, 1], s=1, label='data')\n",
    "ax.scatter(src_model[:, 0], src_model[:, 1], s=1, label='model')\n",
    "ax.set_title('Source density: data vs. model')\n",
    "ax.legend(markerscale=5);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Target density: data samples x1 vs. model samples pushed from the prior.\n",
    "tgt_data = x1.detach().cpu().numpy()\n",
    "tgt_model = x1_aprox.detach().cpu().numpy()\n",
    "fig, ax = plt.subplots()\n",
    "ax.scatter(tgt_data[:, 0], tgt_data[:, 1], s=1, label='data')\n",
    "ax.scatter(tgt_model[:, 0], tgt_model[:, 1], s=1, label='model')\n",
    "ax.set_title('Target density: data vs. model')\n",
    "ax.legend(markerscale=5);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Visual check of the fitted spline via display_bds (presumably its boundary densities).\n",
    "display_bds(\n",
    "    spline0,\n",
    "    device=device,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Induced coupling for checkerboard / 8-Gaussian sources: each source point x0 (dot) and\n",
    "# its image y0 under the learned coupling (cross) share a color per group.\n",
    "# Assumes x0 is ordered so consecutive equal slices belong to one group - TODO confirm.\n",
    "colors = ['r', 'g', 'b', 'c', 'm', 'y', 'k', 'orange']\n",
    "num_groups = len(colors)\n",
    "# Convert once rather than on every slice inside the loop.\n",
    "x0_np = x0.detach().cpu().numpy()\n",
    "y0_np = y0.detach().cpu().numpy()\n",
    "\n",
    "fig, ax = plt.subplots(1, 1, figsize=(5, 5))\n",
    "ax.grid(True, linestyle='--', alpha=0.7)\n",
    "for i, color in enumerate(colors):\n",
    "    group = slice(i * samples_per_gaussian // num_groups, (i + 1) * samples_per_gaussian // num_groups)\n",
    "    ax.scatter(x0_np[group, 0], x0_np[group, 1], color=color, s=1)\n",
    "    ax.scatter(y0_np[group, 0], y0_np[group, 1], color=color, marker='x', s=1)\n",
    "\n",
    "ax.set_title('Induced coupling');"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Path plot: uniform grid of times on [0, 1] used to evaluate and sample the path below.\n",
    "NUM_PATH_TIMES = 50\n",
    "s = torch.linspace(0, 1, NUM_PATH_TIMES).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate the spline on its own grid `t` (presumably endpoints plus interior\n",
    "# interpolation points) and draw sampled trajectories over the density path.\n",
    "interpolation0 = spline0(t)\n",
    "print('Number of interpolation points:', len(t) - 2)\n",
    "plot_bounds = config['visualization']['plot_bounds']\n",
    "samples_path0 = path_visualization_with_trajectories(\n",
    "    interpolation=interpolation0,\n",
    "    arch=arch,\n",
    "    spline=spline0,\n",
    "    x0=plot_bounds['x_min'],\n",
    "    y0=plot_bounds['y_min'],\n",
    "    x1=plot_bounds['x_max'],\n",
    "    y1=plot_bounds['y_max'],\n",
    "    num_samples=50,\n",
    "    time_steps=10,\n",
    "    solver='midpoint',\n",
    "    z=z0,\n",
    "    num_contour_points=100,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scatter the particles returned by the trajectory cell along the optimized path.\n",
    "path_visualization_particles(\n",
    "    samples_path0.detach().cpu(),\n",
    "    spline0,\n",
    "    title=r'Optimized $\\{(T_{t_i})_{\\#}(\\lambda)\\}_{i = 0}^{K+1}$',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate the spline on the finer grid `s` and draw snapshots along the path.\n",
    "interpolation0 = spline0(s)\n",
    "plot_bounds = config['visualization']['plot_bounds']\n",
    "samples_path0 = path_visualization_snapshots(\n",
    "    interpolation=interpolation0,\n",
    "    arch=arch,\n",
    "    spline=spline0,\n",
    "    x0=plot_bounds['x_min'],\n",
    "    y0=plot_bounds['y_min'],\n",
    "    x1=plot_bounds['x_max'],\n",
    "    y1=plot_bounds['y_max'],\n",
    "    num_samples=50,\n",
    "    time_steps=10,\n",
    "    solver='midpoint',\n",
    "    z=z0,\n",
    "    num_contour_points=250,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Animate the snapshot particles (first two axes swapped via permute before plotting).\n",
    "plot_bounds = config['visualization']['plot_bounds']\n",
    "particles = samples_path0.detach().cpu().permute(1, 0, 2)\n",
    "animation = create_particle_animation(\n",
    "    spline0,\n",
    "    particles,\n",
    "    x0=plot_bounds['x_min'],\n",
    "    y0=plot_bounds['y_min'],\n",
    "    x1=plot_bounds['x_max'],\n",
    "    y1=plot_bounds['y_max'],\n",
    "    interval=250,\n",
    ")\n",
    "display(animation)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Lagrangian of the optimized path on trajectories sampled from z0 at times s.\n",
    "# gen_sample_trajectory returns extra terms depending on the active potentials:\n",
    "#   fisher only -> (score, samples); entropy only -> (log density, samples);\n",
    "#   both        -> (log density, score, samples).\n",
    "# One consistent name (norm_score0) is used; the old `norm_socre0` typo made the\n",
    "# entropy-only branch fail with NameError at the lagrangian call.\n",
    "entropy = None\n",
    "norm_score0 = None\n",
    "if spline0.entropy_pot or spline0.fisher_pot:\n",
    "    traj_out = spline0.gen_sample_trajectory(z0, num_samples=len(z0), t_traj=s, time_steps_node=10, solver='midpoint')\n",
    "    if spline0.entropy_pot and spline0.fisher_pot:\n",
    "        entropy, norm_score0, samples_path0 = traj_out\n",
    "    elif spline0.entropy_pot:\n",
    "        entropy, samples_path0 = traj_out\n",
    "    else:\n",
    "        norm_score0, samples_path0 = traj_out\n",
    "# NOTE(review): with no potential active, samples_path0 is not resampled here; it is\n",
    "# whatever the snapshot cell produced. Confirm that is intended.\n",
    "\n",
    "lagrangian0, ke, pe = spline0.lagrangian(samples_path0.to(device), s, log_density=entropy, score=norm_score0)\n",
    "\n",
    "print(lagrangian0, ke, pe)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "PDPO",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
