{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "root_path = Path().cwd().parent.absolute()\n",
    "import sys \n",
    "sys.path.append(str(root_path))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch as torch\n",
    "from torch.distributions import MultivariateNormal\n",
    "device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import parametric_pushforward.data_sets as toy_data\n",
    "from parametric_pushforward.parametric_mlp import order_state_to_tensor, MLP,torch_wrapper,ParameterizedMLP,ParameterizedWrapper\n",
    "from parametric_pushforward.setup_density_path import get_activation\n",
    "\n",
    "from flow_matching.train_fm import ACTIVATION_FNS\n",
    "\n",
    "from torchdyn.core import NeuralODE\n",
    "\n",
    "\n",
    "\n",
    "from geomloss import SamplesLoss\n",
    "\n",
    "import yaml\n",
    "import os\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.manual_seed(0)\n",
    "np.random.seed(0)\n",
    "torch.cuda.manual_seed(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "exp_dir = str(root_path)+'/models/gaussian0_s[2,64,4,softplus]'\n",
    "yaml_path = os.path.join(exp_dir,'config.yaml')\n",
    "with open(yaml_path) as file:\n",
    "    config = yaml.load(file,Loader = yaml.FullLoader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_set = config['data']['type']\n",
    "input_dim = config['model']['input_dim']\n",
    "hidden_dim = config['model']['hidden_dim']\n",
    "num_layers = config['model']['num_layers']\n",
    "activation = ACTIVATION_FNS[config['model']['activation_fn']]\n",
    "time_varying = config['model']['time_varying']\n",
    "\n",
    "arch = [input_dim,hidden_dim,num_layers,activation]\n",
    "\n",
    "model = MLP(arch=arch,time_varying=time_varying)\n",
    "\n",
    "state_path = os.path.join(exp_dir,'final.pth') #  checkpoint_999.pth #final\n",
    "state_tensor = torch.load(state_path,map_location=device)['model_state_dict']\n",
    "\n",
    "model.load_state_dict(state_dict=state_tensor)\n",
    "\n",
    "model = torch_wrapper(model)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "activation = get_activation(config['model']['activation_fn'])\n",
    "arch = [input_dim,hidden_dim,num_layers,activation]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "theta = order_state_to_tensor(state_tensor)\n",
    "parametric_model = ParameterizedWrapper(ParameterizedMLP(arch,time_varying=time_varying),theta=theta)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "referece = MultivariateNormal(loc = torch.zeros(input_dim).to(device),covariance_matrix=torch.eye(input_dim).to(device))\n",
    "\n",
    "t_node = torch.linspace(0,1,10).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "node = NeuralODE(model,solver = 'midpoint').to(device)\n",
    "parametric_node = NeuralODE(parametric_model,solver = 'midpoint').to(device)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bs = 5000\n",
    "\n",
    "z = referece.sample((bs,)).to(device)\n",
    "trajecotry = node.trajectory(z,t_span=t_node)\n",
    "par_trajectory = parametric_node.trajectory(z,t_span=t_node)\n",
    "\n",
    "node_samples = trajecotry[-1].detach().cpu()\n",
    "par_node_samples = par_trajectory[-1].detach().cpu()\n",
    "\n",
    "true_samples = torch.from_numpy(toy_data.inf_train_gen(data=data_set,batch_size=bs,dim = config['model']['input_dim']))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compare pushforward samples\n",
    "fig = plt.figure(figsize=(10,10))\n",
    "plt.scatter(true_samples[:,0],true_samples[:,1],label = 'True samples')\n",
    "plt.scatter(node_samples[:,0],node_samples[:,1],label='Generated samples',marker = '*')\n",
    "\n",
    "plt.legend()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "true_samples.shape,node_samples.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "loss = SamplesLoss(loss = 'sinkhorn', p = 2, blur = 0.05)\n",
    "L = loss(node_samples,true_samples)\n",
    "L.item()\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "PDPO",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
