{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import pandas as pd\n",
        "import numpy as np\n",
        "from lifelines import KaplanMeierFitter\n",
        "from lifelines.statistics import logrank_test\n",
        "import plotly.graph_objects as go\n",
        "from plotly.subplots import make_subplots\n",
        "import plotly.io as pio\n",
        "import warnings\n",
        "warnings.filterwarnings('ignore')\n",
        "\n",
        "# Per-gene result table; it decides which genes get plotted below.\n",
        "results_df = pd.read_csv(\"/workdir/execution_outputs/top_N_differentially_expressed_genes.csv\")\n",
        "print(f\"Loaded results for {len(results_df)} genes\")\n",
        "\n",
        "# Survival table plus the expression matrix, indexed by its gene column\n",
        "# (which arrives under the name 'unnamed_0').\n",
        "survival = pd.read_csv(\"/workdir/survival_KIRC.csv\")\n",
        "mrna = (\n",
        "    pd.read_csv(\"/workdir/KIRC_mRNA_top_column_cleaned.csv\")\n",
        "    .rename(columns={'unnamed_0': 'gene'})\n",
        "    .set_index('gene')\n",
        ")\n",
        "\n",
        "# Flip to one row per sample so expression can be joined onto the survival table.\n",
        "mrna_transposed = mrna.T.rename_axis('sample_name').reset_index()\n",
        "survival_mrna = survival.merge(mrna_transposed, on='sample_name', how='inner')\n",
        "\n",
        "# Rank genes by adjusted Cox p-value (ascending); ties are broken by how far\n",
        "# the hazard ratio sits from 1 (largest distance first).\n",
        "results_df = (\n",
        "    results_df\n",
        "    .assign(hr_distance=(results_df['hazard_ratio'] - 1).abs())\n",
        "    .sort_values(by=['cox_padj', 'hr_distance'], ascending=[True, False])\n",
        ")\n",
        "\n",
        "# Write the re-sorted table back over the file it was loaded from.\n",
        "results_df.to_csv('/workdir/execution_outputs/top_N_differentially_expressed_genes.csv', index=False)\n",
        "\n",
        "top_5_genes = results_df['gene'].iloc[:5].tolist()\n",
        "print(f\"\\nTop 5 genes for survival curve visualization:\")\n",
        "for position, gene_name in enumerate(top_5_genes, start=1):\n",
        "    # Statistics are taken from the first row matching the gene name.\n",
        "    first_match = results_df.loc[results_df['gene'] == gene_name].iloc[0]\n",
        "    hr_text = f\"{first_match['hazard_ratio']:.3f}\"\n",
        "    padj_text = f\"{first_match['cox_padj']:.2e}\"\n",
        "    print(f\"{position}. {gene_name} (HR: {hr_text}, Cox p-adj: {padj_text})\")\n",
        "\n",
        "def create_survival_plot(gene_name, survival_data, title_suffix=\"\"):\n",
        "    \"\"\"Build a Kaplan-Meier figure comparing high vs. low expression of one gene.\n",
        "\n",
        "    Samples are split at the median of ``survival_data[gene_name]``: values at or\n",
        "    above the median form the high group, values below it the low group. A\n",
        "    Kaplan-Meier curve and its confidence band are drawn for each group. The\n",
        "    title shows the gene's hazard ratio, confidence interval and Cox p-value as\n",
        "    stored in ``results_df``, plus a log-rank p-value computed here.\n",
        "\n",
        "    Args:\n",
        "        gene_name: Column of ``survival_data`` holding the gene's expression; it\n",
        "            must also occur in the ``gene`` column of the module-level\n",
        "            ``results_df``.\n",
        "        survival_data: DataFrame with ``survival_times`` and ``event_observed``\n",
        "            columns plus a column named ``gene_name``.\n",
        "        title_suffix: Text appended to the first line of the figure title.\n",
        "\n",
        "    Returns:\n",
        "        The assembled plotly figure; it is neither displayed nor saved here.\n",
        "\n",
        "    Raises:\n",
        "        KeyError: if ``gene_name`` is not a column of ``survival_data``.\n",
        "        IndexError: if ``gene_name`` has no row in ``results_df``.\n",
        "    \"\"\"\n",
        "    # Median split; samples exactly at the median land in the high group.\n",
        "    gene_expr = survival_data[gene_name].values\n",
        "    median_expr = np.median(gene_expr)\n",
        "    high_mask = gene_expr >= median_expr\n",
        "    low_mask = gene_expr < median_expr\n",
        "\n",
        "    high_rows = survival_data.loc[high_mask]\n",
        "    low_rows = survival_data.loc[low_mask]\n",
        "\n",
        "    # Log-rank test between the two expression groups.\n",
        "    logrank_results = logrank_test(\n",
        "        high_rows['survival_times'],\n",
        "        low_rows['survival_times'],\n",
        "        high_rows['event_observed'],\n",
        "        low_rows['event_observed']\n",
        "    )\n",
        "\n",
        "    # Statistics for the title come from the module-level results_df\n",
        "    # (first row matching the gene).\n",
        "    gene_stats = results_df[results_df['gene'] == gene_name].iloc[0]\n",
        "\n",
        "    fig = go.Figure()\n",
        "\n",
        "    # (label, rows, group size, curve colour, band colour); high is drawn first.\n",
        "    groups = [\n",
        "        ('High Expression', high_rows, high_mask.sum(), 'red', 'rgba(255,0,0,0.2)'),\n",
        "        ('Low Expression', low_rows, low_mask.sum(), 'blue', 'rgba(0,0,255,0.2)'),\n",
        "    ]\n",
        "    for label, rows, n_samples, line_color, band_color in groups:\n",
        "        kmf = KaplanMeierFitter()\n",
        "        kmf.fit(rows['survival_times'], rows['event_observed'], label=label)\n",
        "\n",
        "        times = kmf.timeline.tolist()\n",
        "        ci_lower = kmf.confidence_interval_.iloc[:, 0].tolist()\n",
        "        ci_upper = kmf.confidence_interval_.iloc[:, 1].tolist()\n",
        "\n",
        "        # Survival curve for this group.\n",
        "        fig.add_trace(go.Scatter(\n",
        "            x=kmf.timeline,\n",
        "            y=kmf.survival_function_[label],\n",
        "            mode='lines',\n",
        "            name=f'{label} (n={n_samples})',\n",
        "            line=dict(color=line_color, width=2),\n",
        "            showlegend=True\n",
        "        ))\n",
        "\n",
        "        # Confidence band: upper bound left-to-right followed by the lower bound\n",
        "        # right-to-left gives a closed outline that 'toself' fills.\n",
        "        fig.add_trace(go.Scatter(\n",
        "            x=times + times[::-1],\n",
        "            y=ci_upper + ci_lower[::-1],\n",
        "            fill='toself',\n",
        "            fillcolor=band_color,\n",
        "            line=dict(color='rgba(255,255,255,0)'),\n",
        "            name=f'{label} CI',\n",
        "            showlegend=False\n",
        "        ))\n",
        "\n",
        "    fig.update_layout(\n",
        "        title=f'{gene_name} Survival Analysis{title_suffix}<br>' +\n",
        "              f'HR: {gene_stats[\"hazard_ratio\"]:.3f} ' +\n",
        "              f'(95% CI: {gene_stats[\"ci_lower\"]:.3f}-{gene_stats[\"ci_upper\"]:.3f})<br>' +\n",
        "              f'Cox p-value: {gene_stats[\"cox_pvalue\"]:.2e}, ' +\n",
        "              f'Log-rank p-value: {logrank_results.p_value:.2e}',\n",
        "        xaxis_title='Time (days)',\n",
        "        yaxis_title='Survival Probability',\n",
        "        template='plotly_white',\n",
        "        width=800,\n",
        "        height=600,\n",
        "        font=dict(size=12),\n",
        "        legend=dict(x=0.7, y=0.95)\n",
        "    )\n",
        "\n",
        "    # The Figure API spells these update_xaxes / update_yaxes (plural);\n",
        "    # the singular forms do not exist and raise AttributeError.\n",
        "    fig.update_xaxes(showgrid=True)\n",
        "    fig.update_yaxes(showgrid=True, range=[0, 1])\n",
        "\n",
        "    return fig\n",
        "\n",
        "# Render one standalone HTML figure per selected gene.\n",
        "print(f\"\\nCreating survival plots for top 5 genes...\")\n",
        "\n",
        "for plot_number, gene in enumerate(top_5_genes, start=1):\n",
        "    print(f\"Creating plot {plot_number}/5: {gene}\")\n",
        "\n",
        "    # The rank shown in the title is read from the first results_df row for this gene.\n",
        "    gene_rank = results_df.loc[results_df['gene'] == gene, 'rank'].iloc[0]\n",
        "    fig = create_survival_plot(gene, survival_mrna, f\" (Rank #{gene_rank})\")\n",
        "\n",
        "    # Replace '/' in the gene name so it cannot act as a path separator.\n",
        "    html_name = f\"survival_plot_{gene.replace('/', '_')}.html\"\n",
        "    fig.write_html(f'/workdir/execution_outputs/{html_name}')\n",
        "    print(f\"  Saved: {html_name}\")\n",
        "\n",
        "# Combined figure: one Kaplan-Meier panel per selected gene on a 3x2 grid.\n",
        "print(f\"\\nCreating combined survival plot...\")\n",
        "\n",
        "# Each panel title carries the gene's hazard ratio from its first results_df row.\n",
        "panel_titles = [\n",
        "    f'{gene} (HR: {results_df.loc[results_df[\"gene\"] == gene, \"hazard_ratio\"].iloc[0]:.3f})'\n",
        "    for gene in top_5_genes\n",
        "]\n",
        "\n",
        "fig_combined = make_subplots(\n",
        "    rows=3, cols=2,\n",
        "    subplot_titles=panel_titles,\n",
        "    vertical_spacing=0.08,\n",
        "    horizontal_spacing=0.08\n",
        ")\n",
        "\n",
        "# Grid cell for each gene, in order; cell (3, 2) is left unused.\n",
        "positions = [(1,1), (1,2), (2,1), (2,2), (3,1)]\n",
        "\n",
        "for i, gene in enumerate(top_5_genes):\n",
        "    row, col = positions[i]\n",
        "\n",
        "    # Median split: at-or-above the median is \"High\", below it is \"Low\".\n",
        "    gene_expr = survival_mrna[gene].values\n",
        "    median_expr = np.median(gene_expr)\n",
        "    high_mask = gene_expr >= median_expr\n",
        "    low_mask = gene_expr < median_expr\n",
        "\n",
        "    # (label, mask, curve colour, legend group); the high curve is added first.\n",
        "    for label, mask, color, group in (\n",
        "        ('High', high_mask, 'red', 'high'),\n",
        "        ('Low', low_mask, 'blue', 'low'),\n",
        "    ):\n",
        "        kmf = KaplanMeierFitter()\n",
        "        kmf.fit(\n",
        "            survival_mrna.loc[mask, 'survival_times'],\n",
        "            survival_mrna.loc[mask, 'event_observed'],\n",
        "            label=label\n",
        "        )\n",
        "\n",
        "        fig_combined.add_trace(\n",
        "            go.Scatter(\n",
        "                x=kmf.timeline,\n",
        "                y=kmf.survival_function_[label],\n",
        "                mode='lines',\n",
        "                # Only the first panel adds legend entries, so the group sizes\n",
        "                # shown in the legend are those of the first gene.\n",
        "                name=f'{label} (n={mask.sum()})' if i == 0 else None,\n",
        "                line=dict(color=color, width=2),\n",
        "                showlegend=(i == 0),\n",
        "                legendgroup=group\n",
        "            ),\n",
        "            row=row, col=col\n",
        "        )\n",
        "\n",
        "fig_combined.update_layout(\n",
        "    title='Top 5 Genes: Survival Analysis (High vs Low Expression)',\n",
        "    template='plotly_white',\n",
        "    height=900,\n",
        "    width=1200,\n",
        "    font=dict(size=10)\n",
        ")\n",
        "\n",
        "# Label the axes of every listed panel, leaving the unused cell (3, 2) untouched.\n",
        "# The Figure API spells these update_xaxes / update_yaxes (plural); the singular\n",
        "# forms do not exist and raise AttributeError.\n",
        "for row, col in positions:\n",
        "    fig_combined.update_xaxes(title_text='Time (days)', showgrid=True, row=row, col=col)\n",
        "    fig_combined.update_yaxes(title_text='Survival Probability', showgrid=True, range=[0, 1], row=row, col=col)\n",
        "\n",
        "# Save combined plot\n",
        "fig_combined.write_html('/workdir/execution_outputs/combined_survival_plots_top5.html')\n",
        "print(f\"Saved: combined_survival_plots_top5.html\")\n",
        "\n",
        "# Tabulate headline statistics for the selected genes; each gene's values are\n",
        "# read from its first matching row in results_df.\n",
        "first_rows = [results_df.loc[results_df['gene'] == gene].iloc[0] for gene in top_5_genes]\n",
        "\n",
        "summary_data = [\n",
        "    {\n",
        "        'Gene': gene,\n",
        "        'Rank': int(stats['rank']),\n",
        "        'Hazard_Ratio': round(stats['hazard_ratio'], 3),\n",
        "        'CI_Lower': round(stats['ci_lower'], 3),\n",
        "        'CI_Upper': round(stats['ci_upper'], 3),\n",
        "        # p-values are stored as scientific-notation strings\n",
        "        'Cox_P_Value': f\"{stats['cox_pvalue']:.2e}\",\n",
        "        'Cox_P_Adjusted': f\"{stats['cox_padj']:.2e}\",\n",
        "        'LogRank_P_Value': f\"{stats['logrank_pvalue']:.2e}\",\n",
        "        'LogRank_P_Adjusted': f\"{stats['logrank_padj']:.2e}\",\n",
        "        'Median_Survival_High': stats['median_survival_high'],\n",
        "        'Median_Survival_Low': stats['median_survival_low']\n",
        "    }\n",
        "    for gene, stats in zip(top_5_genes, first_rows)\n",
        "]\n",
        "\n",
        "summary_df = pd.DataFrame(summary_data)\n",
        "summary_df.to_csv('/workdir/execution_outputs/top_5_genes_summary.csv', index=False)\n",
        "\n",
        "print(\"\\nSummary of Top 5 Genes:\")\n",
        "print(summary_df.to_string(index=False))\n",
        "print(\"\\nSaved: top_5_genes_summary.csv\")\n",
        "\n",
        "# Closing report of what this cell wrote.\n",
        "print(\"\\nAll files created successfully!\")\n",
        "for report_line in (\n",
        "    \"Main outputs:\",\n",
        "    \"- top_N_differentially_expressed_genes.csv: All 100 genes analyzed\",\n",
        "    \"- combined_survival_plots_top5.html: Combined survival plots\",\n",
        "    \"- Individual survival plots for each of the top 5 genes\",\n",
        "    \"- top_5_genes_summary.csv: Summary statistics for top 5 genes\",\n",
        "):\n",
        "    print(report_line)\n"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.8.0"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 4
}