{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import pandas as pd\n",
        "import numpy as np\n",
        "from sklearn.cluster import KMeans\n",
        "from sklearn.preprocessing import StandardScaler\n",
        "from lifelines import KaplanMeierFitter, CoxPHFitter\n",
        "from lifelines.statistics import logrank_test\n",
        "import plotly.graph_objects as go\n",
        "import warnings\n",
        "warnings.filterwarnings('ignore')\n",
        "\n",
        "print(\"=== Loading preprocessed data ===\")\n",
        "# Load the data (recreate since we need it in this script)\n",
        "mrna_data = pd.read_csv(\"/workdir/KIRC_mRNA_top_column_cleaned.csv\")\n",
        "mrna_transposed = mrna_data.set_index('unnamed_0').T\n",
        "survival_data = pd.read_csv(\"/workdir/survival_KIRC.csv\")\n",
        "\n",
        "# Get matched samples\n",
        "matched_samples = list(set(mrna_transposed.index).intersection(set(survival_data['sample_name'])))\n",
        "matched_expression_data = mrna_transposed.loc[matched_samples].copy()\n",
        "matched_survival_data = survival_data[survival_data['sample_name'].isin(matched_samples)].copy()\n",
        "matched_survival_data = matched_survival_data.set_index('sample_name').loc[matched_samples].reset_index()\n",
        "\n",
        "# Calculate variance and get initial features\n",
        "gene_variances = matched_expression_data.var(axis=0).sort_values(ascending=False)\n",
        "n_initial_features = min(1000, len(gene_variances))\n",
        "initial_features = gene_variances.head(n_initial_features).index.tolist()\n",
        "\n",
        "print(f\"Working with {len(matched_samples)} patients and {len(initial_features)} initial features\")\n",
        "\n",
        "print(\"\\n=== Step 5-7: Iterative optimization for survival separation ===\")\n",
        "\n",
        "# Function to evaluate clustering quality based on survival separation\n",
        "def evaluate_clustering(expression_data, features, survival_data, random_state=42):\n",
        "    \"\"\"Evaluate clustering quality based on survival separation\"\"\"\n",
        "    try:\n",
        "        # Standardize features\n",
        "        scaler = StandardScaler()\n",
        "        X_scaled = scaler.fit_transform(expression_data[features])\n",
        "        \n",
        "        # Perform K-means clustering\n",
        "        kmeans = KMeans(n_clusters=2, random_state=random_state, n_init=10)\n",
        "        cluster_labels = kmeans.fit_predict(X_scaled)\n",
        "        \n",
        "        # Get survival data for each cluster\n",
        "        cluster_0_indices = np.where(cluster_labels == 0)[0]\n",
        "        cluster_1_indices = np.where(cluster_labels == 1)[0]\n",
        "        \n",
        "        if len(cluster_0_indices) < 5 or len(cluster_1_indices) < 5:\n",
        "            return None, None, float('inf')  # Invalid clustering\n",
        "        \n",
        "        # Extract survival data for each cluster\n",
        "        times_0 = survival_data.iloc[cluster_0_indices]['survival_times']\n",
        "        events_0 = survival_data.iloc[cluster_0_indices]['event_observed']\n",
        "        times_1 = survival_data.iloc[cluster_1_indices]['survival_times']\n",
        "        events_1 = survival_data.iloc[cluster_1_indices]['event_observed']\n",
        "        \n",
        "        # Perform log-rank test\n",
        "        results = logrank_test(times_0, times_1, event_observed_A=events_0, event_observed_B=events_1)\n",
        "        p_value = results.p_value\n",
        "        \n",
        "        return cluster_labels, kmeans, p_value\n",
        "        \n",
        "    except Exception as e:\n",
        "        print(f\"Error in clustering evaluation: {e}\")\n",
        "        return None, None, float('inf')\n",
        "\n",
        "# Initial clustering\n",
        "print(\"Performing initial clustering...\")\n",
        "initial_clusters, initial_model, initial_p_value = evaluate_clustering(\n",
        "    matched_expression_data, initial_features, matched_survival_data\n",
        ")\n",
        "\n",
        "print(f\"Initial p-value: {initial_p_value:.6f}\")\n",
        "\n",
        "# Initialize optimization variables\n",
        "best_features = initial_features.copy()\n",
        "best_clusters = initial_clusters.copy()\n",
        "best_p_value = initial_p_value\n",
        "best_model = initial_model\n",
        "iteration_results = []\n",
        "\n",
        "print(\"\\nStarting iterative optimization...\")\n",
        "\n",
        "# Iterative optimization (simplified approach)\n",
        "max_iterations = 5  # Reduced for efficiency\n",
        "for iteration in range(max_iterations):\n",
        "    print(f\"\\nIteration {iteration + 1}/{max_iterations}\")\n",
        "    \n",
        "    # Current feature set size\n",
        "    current_n_features = len(best_features)\n",
        "    step_size = max(10, current_n_features // 20)  # Remove/add 5% of features\n",
        "    \n",
        "    # Get unused features sorted by variance\n",
        "    unused_features = [gene for gene in gene_variances.index if gene not in best_features]\n",
        "    unused_top = unused_features[:step_size * 2]  # Get top unused features\n",
        "    \n",
        "    if len(unused_top) == 0:\n",
        "        print(\"No more unused features available\")\n",
        "        break\n",
        "    \n",
        "    # Try different feature modifications\n",
        "    modifications = [\n",
        "        (\"add_high_var\", best_features + unused_top[:step_size]),\n",
        "        (\"replace_low_var\", best_features[:-step_size] + unused_top[:step_size]),\n",
        "    ]\n",
        "    \n",
        "    for mod_name, new_features in modifications:\n",
        "        if len(new_features) < 50:  # Minimum feature threshold\n",
        "            continue\n",
        "            \n",
        "        # Evaluate new feature set\n",
        "        new_clusters, new_model, new_p_value = evaluate_clustering(\n",
        "            matched_expression_data, new_features, matched_survival_data\n",
        "        )\n",
        "        \n",
        "        iteration_results.append({\n",
        "            'iteration': iteration + 1,\n",
        "            'modification': mod_name,\n",
        "            'n_features': len(new_features),\n",
        "            'p_value': new_p_value\n",
        "        })\n",
        "        \n",
        "        print(f\"  {mod_name}: p-value = {new_p_value:.6f} (features: {len(new_features)})\")\n",
        "        \n",
        "        # Update if improvement found\n",
        "        if new_p_value < best_p_value and new_clusters is not None:\n",
        "            best_features = new_features.copy()\n",
        "            best_clusters = new_clusters.copy()\n",
        "            best_p_value = new_p_value\n",
        "            best_model = new_model\n",
        "            print(f\"  *** New best p-value: {best_p_value:.6f} ***\")\n",
        "\n",
        "print(f\"\\nOptimization completed!\")\n",
        "print(f\"Best p-value: {best_p_value:.6f}\")\n",
        "print(f\"Best feature count: {len(best_features)}\")\n",
        "\n",
        "# Final clustering with best features\n",
        "print(\"\\n=== Step 8: Final clustering and survival analysis ===\")\n",
        "\n",
        "# Standardize final features\n",
        "scaler = StandardScaler()\n",
        "X_final = scaler.fit_transform(matched_expression_data[best_features])\n",
        "\n",
        "# Final clustering\n",
        "final_clusters = best_clusters.copy()\n",
        "\n",
        "# Add cluster labels to survival data\n",
        "matched_survival_data['cluster'] = final_clusters\n",
        "cluster_0_data = matched_survival_data[matched_survival_data['cluster'] == 0]\n",
        "cluster_1_data = matched_survival_data[matched_survival_data['cluster'] == 1]\n",
        "\n",
        "print(f\"Cluster 0: {len(cluster_0_data)} patients\")\n",
        "print(f\"Cluster 1: {len(cluster_1_data)} patients\")\n",
        "\n",
        "# Kaplan-Meier analysis\n",
        "kmf_0 = KaplanMeierFitter()\n",
        "kmf_1 = KaplanMeierFitter()\n",
        "\n",
        "kmf_0.fit(cluster_0_data['survival_times'], cluster_0_data['event_observed'], label='Cluster 0')\n",
        "kmf_1.fit(cluster_1_data['survival_times'], cluster_1_data['event_observed'], label='Cluster 1')\n",
        "\n",
        "# Log-rank test\n",
        "logrank_result = logrank_test(\n",
        "    cluster_0_data['survival_times'], cluster_1_data['survival_times'],\n",
        "    event_observed_A=cluster_0_data['event_observed'], \n",
        "    event_observed_B=cluster_1_data['event_observed']\n",
        ")\n",
        "\n",
        "print(f\"Log-rank test p-value: {logrank_result.p_value:.6f}\")\n",
        "print(f\"Log-rank test statistic: {logrank_result.test_statistic:.4f}\")\n",
        "\n",
        "# Cox proportional hazards model\n",
        "cox_data = matched_survival_data[['survival_times', 'event_observed', 'cluster']].copy()\n",
        "cox_data.columns = ['duration', 'event', 'cluster']\n",
        "\n",
        "cph = CoxPHFitter()\n",
        "cph.fit(cox_data, duration_col='duration', event_col='event')\n",
        "\n",
        "print(\"\\nCox Proportional Hazards Results:\")\n",
        "print(cph.summary)\n",
        "\n",
        "hazard_ratio = np.exp(cph.params_['cluster'])\n",
        "ci_lower = np.exp(cph.confidence_intervals_.loc['cluster', 'coef lower 95%'])\n",
        "ci_upper = np.exp(cph.confidence_intervals_.loc['cluster', 'coef upper 95%'])\n",
        "\n",
        "print(f\"Hazard Ratio: {hazard_ratio:.4f} (95% CI: {ci_lower:.4f} - {ci_upper:.4f})\")\n",
        "\n",
        "# Save intermediate results\n",
        "print(\"\\nSaving analysis results...\")\n",
        "results_summary = {\n",
        "    'best_p_value': best_p_value,\n",
        "    'n_best_features': len(best_features),\n",
        "    'cluster_0_size': len(cluster_0_data),\n",
        "    'cluster_1_size': len(cluster_1_data),\n",
        "    'logrank_p_value': logrank_result.p_value,\n",
        "    'hazard_ratio': hazard_ratio,\n",
        "    'ci_lower': ci_lower,\n",
        "    'ci_upper': ci_upper,\n",
        "    'cluster_0_median_survival': cluster_0_data['survival_times'].median(),\n",
        "    'cluster_1_median_survival': cluster_1_data['survival_times'].median()\n",
        "}\n",
        "\n",
        "pd.Series(results_summary).to_csv(\"/workdir/execution_outputs/analysis_summary.csv\")\n",
        "pd.DataFrame(iteration_results).to_csv(\"/workdir/execution_outputs/iteration_results.csv\")\n",
        "\n",
        "print(\"Clustering optimization completed successfully!\")\n"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.8.0"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 4
}