{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "zIo-TghKazyj"
   },
   "outputs": [],
   "source": [
    "!pip install easy-tpp\n",
    "!pip install pyyaml\n",
    "\n",
    "import os\n",
    "import numpy as np\n",
    "import pickle\n",
    "import torch\n",
    "import pandas as pd\n",
    "from easy_tpp.config_factory import Config\n",
    "from easy_tpp.runner import Runner\n",
    "import warnings\n",
    "\n",
    "\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "print(f\"Using device: {device}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "mRIDt4oqyNbp"
   },
   "outputs": [],
   "source": [
    "\n",
    "!pip install easy-tpp\n",
    "!pip install pyyaml\n",
    "\n",
    "import os\n",
    "import numpy as np\n",
    "import pickle\n",
    "import multiprocessing\n",
    "import torch\n",
    "import pandas as pd\n",
    "from easy_tpp.config_factory import Config\n",
    "from easy_tpp.runner import Runner\n",
    "import warnings\n",
    "import shutil\n",
    "\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "print(f\"Using device: {device}\")\n",
    "\n",
    "\n",
    "\n",
    "TOTAL_SAMPLES = 50000\n",
    "T = 100.0\n",
    "DT = 0.01\n",
    "MU = .1\n",
    "SIGMA = 1.0\n",
    "U_BOUND = 12.0\n",
    "\n",
    "def exponential_decay_drift(t):\n",
    "    return MU * np.exp(-0.5 * t)\n",
    "\n",
    "def simulate_batch(seed, n_samples, mode, lower_bound=None):\n",
    "    rng = np.random.default_rng(seed)\n",
    "    batch_results = []\n",
    "    sq_dt = np.sqrt(DT)\n",
    "\n",
    "    for _ in range(n_samples):\n",
    "        t = 0.0\n",
    "        x = 0.0\n",
    "        crossings = []\n",
    "\n",
    "        while t < T:\n",
    "            if mode == 'time_varying':\n",
    "                drift = exponential_decay_drift(t)\n",
    "            else:\n",
    "                drift = MU\n",
    "\n",
    "            dx = drift * DT + SIGMA * sq_dt * rng.normal()\n",
    "            x_new = x + dx\n",
    "            t += DT\n",
    "\n",
    "            if mode == 'reflected':\n",
    "                if x_new < 0: x_new = -x_new\n",
    "                if x_new >= U_BOUND:\n",
    "                    crossings.append(t)\n",
    "                    x_new = 0.0\n",
    "            elif mode == 'two_boundary':\n",
    "                if x_new >= U_BOUND:\n",
    "                    crossings.append(t)\n",
    "                    x_new = 0.0\n",
    "                elif lower_bound is not None and x_new <= lower_bound:\n",
    "                    x_new = 0.0\n",
    "            elif mode == 'time_varying':\n",
    "                if x_new >= U_BOUND:\n",
    "                    crossings.append(t)\n",
    "                    x_new = 0.0\n",
    "            elif mode == 'single_boundary':\n",
    "                if x_new >= U_BOUND:\n",
    "                    crossings.append(t)\n",
    "                    x_new = 0.0\n",
    "\n",
    "            x = x_new\n",
    "\n",
    "        batch_results.append(crossings)\n",
    "    return batch_results\n",
    "\n",
    "def run_parallel_simulations(mode, n_total, n_processes=None, lower_bound=None):\n",
    "    if n_processes is None: n_processes = multiprocessing.cpu_count()\n",
    "    samples_per_worker = [n_total // n_processes] * n_processes\n",
    "    samples_per_worker[-1] += n_total % n_processes\n",
    "\n",
    "    ss = np.random.SeedSequence()\n",
    "    child_seeds = ss.spawn(n_processes)\n",
    "\n",
    "    tasks = [(child_seeds[i], samples_per_worker[i], mode, lower_bound) for i in range(n_processes)]\n",
    "\n",
    "    results_nested = [simulate_batch(*task) for task in tasks]\n",
    "    return [item for sublist in results_nested for item in sublist]\n",
    "\n",
    "def format_for_easytpp(raw_data, split_name):\n",
    "    \"\"\"\n",
    "    FIXED: Converts to List[List[Dict]] structure.\n",
    "    EasyTPP expects: [ [ {event1}, {event2} ], [ {event1} ... ] ]\n",
    "    \"\"\"\n",
    "    formatted_seqs = []\n",
    "    for path in raw_data:\n",
    "        if len(path) < 2: continue\n",
    "        path = sorted(path)\n",
    "\n",
    "        deltas = [path[0]] + [path[i] - path[i-1] for i in range(1, len(path))]\n",
    "        types = [0] * len(path)\n",
    "\n",
    "        seq_events = []\n",
    "        for t, dt, k in zip(path, deltas, types):\n",
    "            event_dict = {\n",
    "                \"time_since_start\": t,\n",
    "                \"time_since_last_event\": dt,\n",
    "                \"type_event\": k\n",
    "            }\n",
    "            seq_events.append(event_dict)\n",
    "\n",
    "        formatted_seqs.append(seq_events)\n",
    "\n",
    "    return {\n",
    "        \"dim_process\": 1,\n",
    "        split_name: formatted_seqs\n",
    "    }\n",
    "\n",
    "def generate_and_save_data():\n",
    "    if os.path.exists('data'):\n",
    "        shutil.rmtree('data')\n",
    "    os.makedirs('data')\n",
    "\n",
    "    scenarios = {\n",
    "        'reflected': {'mode': 'reflected', 'lower_bound': None},\n",
    "        'two_boundary': {'mode': 'two_boundary', 'lower_bound': -1.0},\n",
    "        'time_varying': {'mode': 'time_varying', 'lower_bound': None},\n",
    "        'single_boundary': {'mode': 'single_boundary', 'lower_bound': None}\n",
    "    }\n",
    "\n",
    "    data_paths = {}\n",
    "    print(\"Regenerating data with correct structure...\")\n",
    "\n",
    "    for name, params in scenarios.items():\n",
    "        print(f\"  Simulating: {name}...\")\n",
    "        raw_data = run_parallel_simulations(n_total=TOTAL_SAMPLES, **params)\n",
    "\n",
    "        n = len(raw_data)\n",
    "        idx = np.arange(n)\n",
    "        np.random.shuffle(idx)\n",
    "\n",
    "        splits = {\n",
    "            'train': idx[:int(0.6*n)],\n",
    "            'dev': idx[int(0.6*n):int(0.8*n)],\n",
    "            'test': idx[int(0.8*n):]\n",
    "        }\n",
    "\n",
    "        path_dict = {}\n",
    "        for split_name, indices in splits.items():\n",
    "            split_raw_data = [raw_data[i] for i in indices]\n",
    "\n",
    "            # Format strictly for EasyTPP\n",
    "            final_data = format_for_easytpp(split_raw_data, split_name)\n",
    "\n",
    "            file_path = f'data/{name}_{split_name}.pkl'\n",
    "            with open(file_path, 'wb') as f:\n",
    "                pickle.dump(final_data, f)\n",
    "\n",
    "            path_dict[split_name] = file_path\n",
    "\n",
    "        data_paths[name] = path_dict\n",
    "        print(f\"  Saved {name}.\")\n",
    "\n",
    "    return data_paths\n",
    "\n",
    "data_paths = generate_and_save_data()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "DZ6vBS2OuGzg"
   },
   "outputs": [],
   "source": [
    "\n",
    "!pip install easy-tpp\n",
    "!pip install pyyaml\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from easy_tpp.model.torch_model.torch_basemodel import TorchBaseModel\n",
    "from easy_tpp.config_factory import Config\n",
    "from easy_tpp.runner import Runner\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "\n",
    "class FPTRenewal(TorchBaseModel):\n",
    "    \"\"\"\n",
    "    First Passage Time Renewal Process with Polynomial -> Neural Drift.\n",
    "    \"\"\"\n",
    "    def __init__(self, model_config):\n",
    "        super(FPTRenewal, self).__init__(model_config)\n",
    "\n",
    "\n",
    "        self.u = nn.Parameter(torch.tensor(5.0))\n",
    "\n",
    "        self.poly_degree = 3\n",
    "        self.time_scale = 0.01\n",
    "\n",
    "\n",
    "        self.drift_net = nn.Sequential(\n",
    "            nn.Linear(self.poly_degree, 16),\n",
    "            nn.Tanh(),\n",
    "            nn.Linear(16, 1)\n",
    "        )\n",
    "\n",
    "        self.num_integration_steps = 20\n",
    "\n",
    "    def get_drift_integral_and_mu(self, t_tensor):\n",
    "        \"\"\"\n",
    "        Computes mu(t) and integral_0^t mu(s) ds numerically.\n",
    "        \"\"\"\n",
    "        B, S = t_tensor.shape\n",
    "\n",
    "        steps = torch.linspace(0, 1, self.num_integration_steps, device=t_tensor.device)\n",
    "        steps = steps.view(1, 1, -1).expand(B, S, -1)\n",
    "\n",
    "        s_points = steps * t_tensor.unsqueeze(-1) # [B, S, K]\n",
    "\n",
    "        s_flat = s_points.view(-1, 1)\n",
    "\n",
    "\n",
    "        s_scaled = s_flat * self.time_scale\n",
    "\n",
    "\n",
    "        poly_terms = []\n",
    "        for p in range(1, self.poly_degree + 1):\n",
    "            poly_terms.append(torch.pow(s_scaled, p))\n",
    "\n",
    "        s_poly = torch.cat(poly_terms, dim=-1)\n",
    "\n",
    "        mu_flat = self.drift_net(s_poly)\n",
    "\n",
    "        mu_points = mu_flat.view(B, S, self.num_integration_steps)\n",
    "\n",
    "        if self.num_integration_steps > 1:\n",
    "            integral = torch.trapz(mu_points, dim=-1) * (t_tensor / (self.num_integration_steps - 1))\n",
    "        else:\n",
    "            integral = mu_points.squeeze(-1) * t_tensor\n",
    "\n",
    "        mu_t = mu_points[..., -1]\n",
    "\n",
    "        return integral, mu_t\n",
    "\n",
    "    def compute_pdf(self, t_tensor):\n",
    "        \"\"\"\n",
    "        Computes the FPT PDF f(t).\n",
    "        \"\"\"\n",
    "        t = t_tensor.clamp(min=1e-5)\n",
    "\n",
    "        integral_mu, mu_t = self.get_drift_integral_and_mu(t)\n",
    "\n",
    "\n",
    "        numerator_inner = self.u - integral_mu + t * mu_t\n",
    "\n",
    "        exponent = - (self.u - integral_mu)**2 / (2 * t)\n",
    "\n",
    "        base = 1.0 / (torch.sqrt(torch.tensor(2 * torch.pi, device=t.device)) * t.pow(1.5))\n",
    "\n",
    "        pdf = base * numerator_inner * torch.exp(exponent)\n",
    "\n",
    "        pdf = torch.relu(pdf)\n",
    "\n",
    "        return pdf\n",
    "\n",
    "    def compute_survival(self, t_tensor):\n",
    "        \"\"\"\n",
    "        Compute survival function numerically.\n",
    "        \"\"\"\n",
    "        B, S = t_tensor.shape\n",
    "        steps = torch.linspace(0, 1, self.num_integration_steps, device=t_tensor.device)\n",
    "        steps = steps.view(1, 1, -1).expand(B, S, -1)\n",
    "\n",
    "        x_points = steps * t_tensor.unsqueeze(-1)\n",
    "        x_flat = x_points.view(-1, 1)\n",
    "\n",
    "\n",
    "        pdf_flat = self.compute_pdf(x_flat)\n",
    "        pdf_points = pdf_flat.view(B, S, self.num_integration_steps)\n",
    "\n",
    "        cdf = torch.trapz(pdf_points, dim=-1) * (t_tensor / (self.num_integration_steps - 1))\n",
    "\n",
    "        survival = 1.0 - cdf\n",
    "        return survival.clamp(min=1e-6, max=1.0)\n",
    "\n",
    "    def loglike_loss(self, batch):\n",
    "        time_seqs, time_delta_seqs, type_seqs, batch_non_pad_mask, _ = batch\n",
    "\n",
    "        dt = time_delta_seqs[:, 1:]\n",
    "        mask = batch_non_pad_mask[:, 1:]\n",
    "\n",
    "        pdf_vals = self.compute_pdf(dt)\n",
    "        log_pdf = torch.log(pdf_vals + 1e-6)\n",
    "\n",
    "        event_mask = torch.logical_and(mask, type_seqs[:, 1:] != self.pad_token_id)\n",
    "        loss_event = - (log_pdf * event_mask).sum()\n",
    "\n",
    "        loss = loss_event\n",
    "        num_events = event_mask.sum().item()\n",
    "\n",
    "        return loss, num_events\n",
    "\n",
    "    def predict_one_step_at_every_event(self, batch):\n",
    "        time_seq, time_delta_seq, event_seq, _, _ = batch\n",
    "\n",
    "        horizon = 10.0\n",
    "        steps = torch.linspace(0, horizon, 50, device=time_delta_seq.device)\n",
    "        pdf_vals = self.compute_pdf(steps.unsqueeze(0))\n",
    "        expected_time = torch.trapz(steps.unsqueeze(0) * pdf_vals, steps)\n",
    "\n",
    "        dtimes_pred = expected_time.expand(time_delta_seq.shape[0], time_delta_seq.shape[1]-1)\n",
    "        types_pred = torch.zeros_like(dtimes_pred, dtype=torch.long)\n",
    "\n",
    "        return dtimes_pred, types_pred\n",
    "\n",
    "\n",
    "\n",
    "def run_experiment(data_paths):\n",
    "    models = [\n",
    "        'FPTRenewal',\n",
    "        'RMTPP',\n",
    "        'NHP',\n",
    "        'IntensityFree'\n",
    "    ]\n",
    "\n",
    "    results = []\n",
    "\n",
    "    base_config_template = \"\"\"\n",
    "pipeline_config_id: runner_config\n",
    "\n",
    "data:\n",
    "  {dataset_name}:\n",
    "    data_format: pkl\n",
    "    train_dir: {train_path}\n",
    "    valid_dir: {dev_path}\n",
    "    test_dir: {test_path}\n",
    "    data_specs:\n",
    "      num_event_types: 1\n",
    "      pad_token_id: 1\n",
    "      padding_side: right\n",
    "\n",
    "{model_name}_experiment:\n",
    "  base_config:\n",
    "    stage: train\n",
    "    backend: torch\n",
    "    dataset_id: {dataset_name}\n",
    "    runner_id: std_tpp\n",
    "    model_id: {model_name}\n",
    "    base_dir: './checkpoints/{dataset_name}/{model_name}'\n",
    "  trainer_config:\n",
    "    batch_size: 64\n",
    "    max_epoch: 10\n",
    "    shuffle: True\n",
    "    optimizer: adam\n",
    "    learning_rate: 1.e-3\n",
    "    valid_freq: 1\n",
    "    use_tfb: False\n",
    "    metrics: ['acc', 'rmse']\n",
    "    seed: 2024\n",
    "    gpu: 0\n",
    "  model_config:\n",
    "    hidden_size: 32\n",
    "    time_emb_size: 16\n",
    "    num_layers: 1\n",
    "    num_heads: 2\n",
    "    mc_num_sample_per_step: 20\n",
    "    loss_integral_num_sample_per_step: 20\n",
    "    dropout: 0.1\n",
    "    use_ln: False\n",
    "    model_specs:\n",
    "      num_mlp_layers: 2\n",
    "      num_mix_components: 3\n",
    "      ode_num_sample_per_step: 5\n",
    "    thinning:\n",
    "      num_seq: 10\n",
    "      num_sample: 1\n",
    "      num_exp: 50\n",
    "      look_ahead_time: 10\n",
    "      patience_counter: 5\n",
    "      over_sample_rate: 5\n",
    "      num_samples_boundary: 5\n",
    "      dtime_max: 5\n",
    "\"\"\"\n",
    "\n",
    "    for scenario, paths in data_paths.items():\n",
    "        print(f\"\\n{'='*50}\\nEvaluating Scenario: {scenario}\\n{'='*50}\")\n",
    "\n",
    "        for model in models:\n",
    "            print(f\"\\n--- Training {model} on {scenario} ---\")\n",
    "\n",
    "            try:\n",
    "                config_str = base_config_template.format(\n",
    "                    dataset_name=scenario,\n",
    "                    train_path=paths['train'],\n",
    "                    dev_path=paths['dev'],\n",
    "                    test_path=paths['test'],\n",
    "                    model_name=model\n",
    "                )\n",
    "\n",
    "                config_filename = f\"config_{scenario}_{model}.yaml\"\n",
    "                with open(config_filename, \"w\") as f:\n",
    "                    f.write(config_str)\n",
    "\n",
    "                exp_id = f\"{model}_experiment\"\n",
    "                config = Config.build_from_yaml_file(config_filename, experiment_id=exp_id)\n",
    "\n",
    "                # Device Logic\n",
    "                if torch.cuda.is_available():\n",
    "                    gpu_id = 0\n",
    "                else:\n",
    "                    gpu_id = -1\n",
    "\n",
    "                config.trainer_config.gpu = gpu_id\n",
    "                config.model_config.gpu = gpu_id\n",
    "\n",
    "                runner = Runner.build_from_config(config)\n",
    "                runner.run()\n",
    "\n",
    "                # Evaluate\n",
    "                config.base_config.stage = 'gen'\n",
    "                config.trainer_config.gpu = gpu_id\n",
    "                config.model_config.gpu = gpu_id\n",
    "\n",
    "                runner_eval = Runner.build_from_config(config)\n",
    "                valid_loader = runner_eval._data_loader.valid_loader()\n",
    "                metrics = runner_eval._evaluate_model(valid_loader)\n",
    "\n",
    "                res_entry = {\n",
    "                    'Scenario': scenario,\n",
    "                    'Model': model,\n",
    "                    'LogLikelihood': metrics.get('loglike', np.nan),\n",
    "                    'Accuracy': metrics.get('acc', np.nan),\n",
    "                    'RMSE': metrics.get('rmse', np.nan)\n",
    "                }\n",
    "                results.append(res_entry)\n",
    "                print(f\"Result: {res_entry}\")\n",
    "\n",
    "            except Exception as e:\n",
    "                print(f\"Failed to run {model} on {scenario}: {e}\")\n",
    "                results.append({'Scenario': scenario, 'Model': model, 'LogLikelihood': 'Failed'})\n",
    "\n",
    "    return pd.DataFrame(results)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Y7RWcFBTHeLN"
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from easy_tpp.model.torch_model.torch_nhp import NHP\n",
    "\n",
    "def nhp_forward_patched(self, batch):\n",
    "    '''\n",
    "    Patched forward method to handle sequence length N=1 (common during generation start).\n",
    "    '''\n",
    "    t_BN, dt_BN, marks_BN, _, _ = batch\n",
    "    B, N = dt_BN.shape\n",
    "    left_hs = []\n",
    "    right_states = []\n",
    "\n",
    "    all_event_emb_BNP = self.layer_type_emb(marks_BN)\n",
    "    c_t, c_bar_t, delta_t, o_t = self.get_init_state(B)\n",
    "\n",
    "    for i in range(N):\n",
    "        ct_d_t, h_d_t = self.rnn_cell.decay(c_t, c_bar_t, delta_t, o_t, dt_BN[..., i][..., None])\n",
    "\n",
    "        event_emb_t = all_event_emb_BNP[..., i, :]\n",
    "        c_t, c_bar_t, delta_t, o_t = self.rnn_cell(\n",
    "            x_i=event_emb_t,\n",
    "            hidden_ti_minus=h_d_t,\n",
    "            ct_ti_minus=ct_d_t,\n",
    "            c_bar_im1=c_bar_t,\n",
    "        )\n",
    "\n",
    "        left_hs.append(h_d_t)\n",
    "        right_states.append(torch.cat((c_t, c_bar_t, delta_t, o_t), dim=-1))\n",
    "\n",
    "    if len(left_hs) > 1:\n",
    "        left_hiddens = torch.stack(left_hs[1:], dim=-2)\n",
    "    else:\n",
    "\n",
    "        left_hiddens = torch.zeros(B, 0, self.hidden_size, device=self.device)\n",
    "\n",
    "    right_hiddens = torch.stack(right_states, dim=-2)\n",
    "    return left_hiddens, right_hiddens\n",
    "\n",
    "print(\"Applying NHP forward patch...\")\n",
    "NHP.forward = nhp_forward_patched\n",
    "print(\"Patch applied successfully.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "DjUyZlqfxAxX"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "\n",
    "BASE_DATA_DIR = \"\"\n",
    "\n",
    "def load_paths_from_drive(base_dir):\n",
    "    \"\"\"\n",
    "    Scans the directory to reconstruct the data_paths dictionary.\n",
    "    Expects filenames like: 'reflected_train.pkl', 'single_boundary_test.pkl', etc.\n",
    "    \"\"\"\n",
    "    scenarios = ['reflected', 'two_boundary', 'time_varying', 'single_boundary']\n",
    "    splits = ['train', 'dev', 'test']\n",
    "    data_paths = {}\n",
    "\n",
    "    print(f\"Scanning directory: {base_dir}\")\n",
    "\n",
    "    for scenario in scenarios:\n",
    "        scenario_paths = {}\n",
    "        missing_files = False\n",
    "\n",
    "        for split in splits:\n",
    "\n",
    "            filename = f\"{scenario}_{split}.pkl\"\n",
    "            full_path = os.path.join(base_dir, filename)\n",
    "\n",
    "            if os.path.exists(full_path):\n",
    "                scenario_paths[split] = full_path\n",
    "            else:\n",
    "                print(f\"  [Missing] Could not find: {filename}\")\n",
    "                missing_files = True\n",
    "\n",
    "\n",
    "        if not missing_files:\n",
    "            data_paths[scenario] = scenario_paths\n",
    "            print(f\"  [Loaded] Scenario '{scenario}' ready.\")\n",
    "\n",
    "    return data_paths\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    data_paths = load_paths_from_drive(BASE_DATA_DIR)\n",
    "\n",
    "    if data_paths:\n",
    "        print(\"\\nStarting Experiments using data from Drive...\\n\")\n",
    "\n",
    "\n",
    "        df_results = run_experiment(data_paths)\n",
    "\n",
    "        print(\"\\nFinal Test Performance Report:\\n\")\n",
    "        pd.set_option('display.max_rows', None)\n",
    "        print(df_results)\n",
    "\n",
    "        if not df_results.empty and 'Test LogLikelihood' in df_results.columns:\n",
    "            try:\n",
    "                pivot_df = df_results.pivot(index='Model', columns='Scenario', values='Test LogLikelihood')\n",
    "                print(\"\\n--- Test Log-Likelihood Comparison ---\\n\")\n",
    "                print(pivot_df)\n",
    "            except:\n",
    "                pass\n",
    "    else:\n",
    "        print(\"No complete scenarios found\")"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "machine_shape": "hm",
   "private_outputs": true,
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
