{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Mean Squared Error: 0.7397491834205001\n",
      "Coefficients: [0.         0.04108204 0.05519134]\n",
      "Intercept: 12.987725363753507\n",
      "Predicted worker2_time: 26.06511263778271\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/mnt/sda/2022-0526/home/xuhx/software/miniconda3/envs/gllm/lib/python3.10/site-packages/sklearn/utils/validation.py:2739: UserWarning: X does not have valid feature names, but PolynomialFeatures was fitted with feature names\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "from sklearn.preprocessing import PolynomialFeatures\n",
    "from sklearn.linear_model import LinearRegression\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.metrics import mean_squared_error\n",
    "import time\n",
    "\n",
    "# Configuration: tune these instead of editing the code below.\n",
    "# To model sampling time instead, use FEATURES = ['decode'] and TARGET = 'sample_time'.\n",
    "CSV_PATH = 'csv/profile_sharegpt.csv'\n",
    "TRIM_ROWS = 10          # rows dropped from each end of the profile (treated as outliers)\n",
    "MAX_SAMPLE_TIME = 10    # keep rows with sample_time below this\n",
    "MAX_WORKER_TIME = 100   # keep rows with worker_1_time below this\n",
    "FEATURES = ['prefill', 'decode']\n",
    "TARGET = 'worker_1_time'\n",
    "DEGREE = 1              # polynomial degree; raise it for a more flexible fit\n",
    "SEED = 42\n",
    "NEW_POINT = {'prefill': 231, 'decode': 65}  # example input for a one-off prediction\n",
    "\n",
    "# Step 1: load and clean the data\n",
    "df = pd.read_csv(CSV_PATH)\n",
    "# Slice with an explicit end: iloc[n:-n] would return an empty frame when n == 0.\n",
    "df = df.iloc[TRIM_ROWS:len(df) - TRIM_ROWS].reset_index(drop=True)\n",
    "df = df[(df['sample_time'] < MAX_SAMPLE_TIME) & (df['worker_1_time'] < MAX_WORKER_TIME)]\n",
    "assert not df.empty, 'every row was filtered out; check the thresholds above'\n",
    "\n",
    "# Features and target\n",
    "X = df[FEATURES]\n",
    "y = df[TARGET]\n",
    "\n",
    "# Hold out 20% of the rows for evaluation\n",
    "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=SEED)\n",
    "\n",
    "# Expand the features into polynomial terms (degree 1 = plain linear model)\n",
    "poly = PolynomialFeatures(degree=DEGREE)\n",
    "X_train_poly = poly.fit_transform(X_train)\n",
    "X_test_poly = poly.transform(X_test)\n",
    "\n",
    "# Fit ordinary least squares on the expanded features\n",
    "model = LinearRegression()\n",
    "model.fit(X_train_poly, y_train)\n",
    "\n",
    "# Evaluate on the held-out rows\n",
    "y_pred = model.predict(X_test_poly)\n",
    "mse = mean_squared_error(y_test, y_pred)\n",
    "print(f'Mean Squared Error: {mse}')\n",
    "\n",
    "# Fitted coefficients (one per polynomial term, bias column first) and intercept\n",
    "print(\"Coefficients:\", model.coef_)\n",
    "print(\"Intercept:\", model.intercept_)\n",
    "\n",
    "# One-off prediction. Keep the training column names: poly was fitted on a\n",
    "# DataFrame, and passing a bare ndarray triggers sklearn's\n",
    "# \"X does not have valid feature names\" UserWarning.\n",
    "new_data = pd.DataFrame([NEW_POINT])[FEATURES]\n",
    "new_data_poly = poly.transform(new_data)\n",
    "predicted_time = model.predict(new_data_poly)\n",
    "print(f\"Predicted {TARGET}: {predicted_time[0]}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generated Python function:\n",
      "def predict_worker_time(prefill, decode):     return 12.987725      + 0.041082 * prefill      + 0.055191 * decode\n"
     ]
    }
   ],
   "source": [
    "# Feature names the model was trained on. poly was fitted on a DataFrame, so it\n",
    "# stores them in feature_names_in_ and they follow the training cell automatically.\n",
    "feature_names = list(poly.feature_names_in_)\n",
    "poly_feature_names = poly.get_feature_names_out(feature_names)\n",
    "\n",
    "# Fitted coefficients (aligned with poly_feature_names) and intercept\n",
    "coefs = model.coef_\n",
    "intercept = model.intercept_\n",
    "\n",
    "\n",
    "def _term_to_python(term):\n",
    "    \"\"\"Rewrite a PolynomialFeatures term name as a Python expression.\n",
    "\n",
    "    get_feature_names_out writes powers as 'x^2' and products as 'x y'. In\n",
    "    Python '^' is bitwise XOR and 'x y' is a syntax error, so both are\n",
    "    rewritten, e.g. 'prefill^2 decode' -> 'prefill**2 * decode'.\n",
    "    \"\"\"\n",
    "    return ' * '.join(factor.replace('^', '**') for factor in term.split(' '))\n",
    "\n",
    "\n",
    "# Build the source of a standalone prediction function, one term per line.\n",
    "# Coefficients are written with 10 significant digits rather than 6 decimal\n",
    "# places, so the tiny weights of higher-degree terms are not rounded to zero.\n",
    "func_lines = [\n",
    "    f\"def predict_worker_time({', '.join(feature_names)}):\",\n",
    "    f\"    return ({intercept:.10g}\",\n",
    "]\n",
    "for name, coef in zip(poly_feature_names, coefs):\n",
    "    if name == '1':  # bias column; LinearRegression fits the intercept separately\n",
    "        continue\n",
    "    sign = '-' if coef < 0 else '+'\n",
    "    func_lines.append(f\"            {sign} {abs(coef):.10g} * {_term_to_python(name)}\")\n",
    "func_lines[-1] += ')'\n",
    "\n",
    "function_code = \"\\n\".join(func_lines)\n",
    "print(\"Generated Python function:\")\n",
    "print(function_code)\n",
    "exec(function_code)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Time every call of the generated predict_worker_time on the held-out rows.\n",
    "# The model only uses the training feature columns; there is no 'memory_util'\n",
    "# feature, so the test frame is used as-is instead of being reindexed.\n",
    "prediction_times = []\n",
    "predictions = []\n",
    "\n",
    "X_test_df = pd.DataFrame(X_test)  # no-op when X_test is already a DataFrame\n",
    "\n",
    "# Assumes X_test's column order matches predict_worker_time's parameters\n",
    "# (both come from the training features).\n",
    "for row in X_test_df.itertuples(index=False, name=None):\n",
    "    # perf_counter is the clock intended for timing short intervals;\n",
    "    # time.time() may be coarse and can jump with system clock changes.\n",
    "    start_time = time.perf_counter()\n",
    "    pred = predict_worker_time(*row)\n",
    "    end_time = time.perf_counter()\n",
    "\n",
    "    prediction_times.append(end_time - start_time)\n",
    "    predictions.append(pred)\n",
    "\n",
    "# Sanity check: the generated function must agree with the fitted model\n",
    "# (up to the rounding of the coefficients written into its source).\n",
    "np.testing.assert_allclose(predictions, model.predict(X_test_poly), rtol=1e-3)\n",
    "\n",
    "# Inputs, ground truth, predictions and per-call latency side by side\n",
    "results_df = X_test_df.assign(\n",
    "    actual_time=y_test,\n",
    "    predicted_time=predictions,\n",
    "    prediction_time=prediction_times,\n",
    ")\n",
    "\n",
    "# Show the results\n",
    "print(\"\\n预测结果统计：\")\n",
    "print(\"=\"*50)\n",
    "print(f\"平均预测时间: {np.mean(prediction_times)*1000:.3f} 毫秒\")\n",
    "print(f\"最短预测时间: {min(prediction_times)*1000:.3f} 毫秒\")\n",
    "print(f\"最长预测时间: {max(prediction_times)*1000:.3f} 毫秒\")\n",
    "print(\"\\n详细预测结果：\")\n",
    "print(results_df.to_string(index=False))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "gllm",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
