{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "import json\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "from collections import defaultdict\n",
    "\n",
    "def process_and_parse_log(input_file):\n",
    "    # 定义正则表达式\n",
    "    patterns = [\n",
    "        r'#wait:\\s+(\\d+)/\\s+(\\d+)\\s+#run:\\s+(\\d+)\\s+#prefill:\\s+(\\d+)\\s+#decode:\\s+(\\d+)\\s+memory_util:\\s+([\\d.]+)\\s+%',\n",
    "        r'Worker rank (\\d+) step time: ([\\d.]+)',\n",
    "        r'sample time: ([\\d.]+)',\n",
    "        r'\\[perf_trace\\] (.+)'  # 新增 perf_trace JSON 匹配\n",
    "    ]\n",
    "\n",
    "    wait_seq_list, wait_token_list = [], []\n",
    "    run_list, prefill_list, decode_list = [], [], []\n",
    "    mem_util_list, sample_time_list = [], []\n",
    "    worker_time_dict = defaultdict(list)\n",
    "\n",
    "    # 用于 perf_trace 数据\n",
    "    prefill_to_compute_all, prefill_computed_all, decode_computed_all = [], [], []\n",
    "    prefill_batch_all, decode_count_all = [], []\n",
    "\n",
    "    with open(input_file, 'r') as f:\n",
    "        for line in f:\n",
    "            line = line.strip()\n",
    "\n",
    "            if m := re.search(patterns[0], line):\n",
    "                wait_seq_list.append(int(m.group(1)))\n",
    "                wait_token_list.append(int(m.group(2)))\n",
    "                run_list.append(int(m.group(3)))\n",
    "                prefill_list.append(int(m.group(4)))\n",
    "                decode_list.append(int(m.group(5)))\n",
    "                mem_util_list.append(float(m.group(6)))\n",
    "                continue\n",
    "\n",
    "            if m := re.search(patterns[1], line):\n",
    "                rank = int(m.group(1))\n",
    "                time = float(m.group(2))\n",
    "                worker_time_dict[rank].append(time)\n",
    "                continue\n",
    "\n",
    "            if m := re.search(patterns[2], line):\n",
    "                sample_time_list.append(float(m.group(1)))\n",
    "                continue\n",
    "\n",
    "            if m := re.search(patterns[3], line):\n",
    "                try:\n",
    "                    log_entry = json.loads(m.group(1))\n",
    "                    prefill_to_compute_all.append(log_entry.get(\"prefill_to_comput_tokens\", []))\n",
    "                    prefill_computed_all.append(log_entry.get(\"prefill_computed_tokens\", []))\n",
    "                    decode_computed_all.append(log_entry.get(\"decode_computed_tokens\", []))\n",
    "                except json.JSONDecodeError:\n",
    "                    print(\"⚠️ JSON decode error:\", line)\n",
    "                continue\n",
    "\n",
    "    # 构建 DataFrame\n",
    "    time_len = len(wait_seq_list)\n",
    "    df = pd.DataFrame({\n",
    "        \"time\": range(time_len),\n",
    "        \"wait_seq\": wait_seq_list,\n",
    "        \"wait_token\": wait_token_list,\n",
    "        \"run\": run_list,\n",
    "        \"prefill\": prefill_list,\n",
    "        \"decode\": decode_list,\n",
    "        \"memory_util\": mem_util_list,\n",
    "        \"sample_time\": sample_time_list[:time_len],  # 避免越界\n",
    "        \"prefill_to_compute\": prefill_to_compute_all[:time_len],\n",
    "        \"prefill_computed\": prefill_computed_all[:time_len],\n",
    "        \"decode_computed\": decode_computed_all[:time_len]\n",
    "    })\n",
    "\n",
    "    for rank, times in worker_time_dict.items():\n",
    "        if len(times) == time_len:\n",
    "            df[f\"worker_{rank}_time\"] = times\n",
    "        else:\n",
    "            print(f\"⚠️ Worker {rank} 时间长度 {len(times)} 与主数据长度 {time_len} 不一致，跳过该列。\")\n",
    "    return df\n",
    "\n",
    "# 示例调用\n",
    "if __name__ == \"__main__\":\n",
    "    # input_log_path = \"/mnt/sda/2022-0526/home/xuhx/projects/gLLM/experiments/profile_model/profile_sharegpt_32b_rate16.log\"\n",
    "    # input_log_path = \"/mnt/sda/2022-0526/home/xuhx/projects/gLLM/experiments/profile_model/profile_splitwise_32b_rate8.log\"\n",
    "    input_log_path=\"/mnt/sda/2022-0526/home/xuhx/projects/gLLM/baseline0_.log\"\n",
    "    df = process_and_parse_log(input_log_path)\n",
    "    df['each_layer_time'] = df['worker_1_time'] / 16\n",
    "    df = df[df['sample_time'] < 10]\n",
    "    df = df[df['worker_1_time'] < 100]\n",
    "    df['cost_ratio'] = df['sample_time'] / df['each_layer_time']\n",
    "    df = df[df['cost_ratio'] < 10]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "模型系数：\n",
      "θ1: 0.107900\n",
      "θ2: -0.000131\n",
      "θ3: 0.107900\n",
      "Intercept: 11.913007\n",
      "\n",
      "测试集 MSE: 22.069616\n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "from sklearn.linear_model import LinearRegression\n",
    "from sklearn.metrics import mean_squared_error\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "# 特征构造函数\n",
    "def compute_features(df):\n",
    "    features = []\n",
    "\n",
    "    for _, row in df.iterrows():\n",
    "        n_list = row['prefill_to_compute']\n",
    "        r_list_prefill = row['prefill_computed']\n",
    "        r_list_decode = row['decode_computed']\n",
    "\n",
    "        x1 = sum(n for n in n_list) + len(r_list_decode)\n",
    "        x2 = sum(n**2 for n in n_list) + len(r_list_decode) + sum(n*r for n, r in zip(n_list, r_list_prefill)) + sum(r_list_decode)\n",
    "        x3 = sum(n for n in n_list) + len(r_list_decode)\n",
    "        # x4 = len(n_list) + len(r_list_decode)\n",
    "\n",
    "\n",
    "        features.append([x1, x2, x3])\n",
    "    return np.array(features)\n",
    "\n",
    "# 构造特征和目标变量\n",
    "X_full = compute_features(df)\n",
    "y_full = df['worker_1_time'].values\n",
    "\n",
    "# 划分训练集和测试集\n",
    "X_train, X_test, y_train, y_test = train_test_split(X_full, y_full, test_size=0.05, random_state=42)\n",
    "\n",
    "# 拟合模型\n",
    "model = LinearRegression(fit_intercept=True)\n",
    "model.fit(X_train,y_train)\n",
    "\n",
    "# 模型参数输出\n",
    "theta_names = ['θ1', 'θ2', 'θ3']\n",
    "print(\"模型系数：\")\n",
    "for name, coef in zip(theta_names, model.coef_):\n",
    "    print(f'{name}: {coef:.6f}')\n",
    "print(f'Intercept: {model.intercept_:.6f}')\n",
    "\n",
    "# 计算预测与 MSE\n",
    "y_pred = model.predict(X_test)\n",
    "mse = mean_squared_error(y_test , y_pred)\n",
    "print(f'\\n测试集 MSE: {mse:.6f}')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "thetas = model.coef_\n",
    "intercept = model.intercept_\n",
    "print(thetas)\n",
    "print(intercept)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def predict_worker_time(n_list, r_list_prefill, r_list_decode):\n",
    "    thetas = model.coef_\n",
    "    intercept = model.intercept_\n",
    "    # thetas = [8.42961956e-06, 8.96638357e-07, 7.64362298e-02, 2.51992516e-04,1.27560378e-02]\n",
    "    # intercept = 11.413191334277132\n",
    "    # 构造特征\n",
    "\n",
    "    x1 = sum(n for n in n_list) + len(r_list_decode)\n",
    "    x2 = sum(n**2 for n in n_list) + len(r_list_decode) + sum(n*r for n, r in zip(n_list, r_list_prefill)) + sum(r_list_decode)\n",
    "    x3 = sum(n for n in n_list) + len(r_list_decode)\n",
    "\n",
    "\n",
    "    # 计算预测值\n",
    "    result = (thetas[0] * x1 +\n",
    "              thetas[1] * x2 +\n",
    "              thetas[2] * x3 +\n",
    "              intercept)    \n",
    "    return result\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def predict_worker_time(prefill, decode):     return 11.238080      + 0.075143 * prefill      + 0.140392 * decode"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "全体 MSE: 49.200088064374434\n"
     ]
    }
   ],
   "source": [
    "df['predict_worker_time'] = df.apply(\n",
    "    lambda row: predict_worker_time(\n",
    "        row['prefill_to_compute'],\n",
    "        row['prefill_computed'],\n",
    "        row['decode_computed']\n",
    "    ),\n",
    "    axis=1\n",
    ")\n",
    "# df['predict_worker_time'] = df.apply(\n",
    "#     lambda row: predict_worker_time(\n",
    "#         row['prefill'],\n",
    "#         row['decode']\n",
    "#     ),\n",
    "#     axis=1\n",
    "# )\n",
    "df['mse'] = ((df['predict_worker_time'] - df['worker_1_time'])) ** 2\n",
    "print(\"全体 MSE:\", df['mse'].mean())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "max(df['predict_worker_time'] - df['worker_1_time']/16)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "df['total_tokens'] = df['prefill'] + df['decode']\n",
    "# 设置图形样式\n",
    "sns.set(style=\"whitegrid\")\n",
    "\n",
    "# 绘制散点图\n",
    "plt.figure(figsize=(10, 6))\n",
    "sns.scatterplot(x='total_tokens', y='worker_1_time', data=df, alpha=0.6, label='Actual Worker 1 Time')\n",
    "sns.scatterplot(x='total_tokens', y='predict_worker_time', data=df, alpha=0.6, label='Predicted worker Time')\n",
    "\n",
    "# 添加标题和标签\n",
    "plt.title('Actual vs Predicted Time by Total Tokens', fontsize=14)\n",
    "plt.xlabel('Total Tokens', fontsize=12)\n",
    "plt.ylabel('Time (s)', fontsize=12)\n",
    "plt.legend()\n",
    "\n",
    "# 显示图形\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "df.to_csv(\"tmp.csv\", index=False)\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "gllm",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
