{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7f784a52",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.linear_model import LinearRegression\n",
    "import numpy as np\n",
    "import os\n",
    "import torch\n",
    "import plotly.express as px\n",
    "import plotly.io as pio\n",
    "import plotly.express as px\n",
    "import plotly.graph_objects as go\n",
    "\n",
    "# X-axis values shared by every curve below: they are plotted on the x axis and\n",
    "# used as the regressor in the linear fits. Presumably these are the training\n",
    "# steps of the saved checkpoints -- TODO confirm.\n",
    "time_step = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1000, 2000, 3000, 4000, 5000, 6000, 7000, 8000, 9000, 10000, 11000, 12000, 13000, 14000, 15000, 20000, 25000, 30000, 35000, 40000, 45000, 50000, 55000, 60000, 65000, 70000, 80000, 90000, 100000, 110000, 120000, 130000, 140000, 143000]\n",
    "\n",
    "\n",
    "# Nested mapping indexed below as dic[model_name][task_name]['mrr'].\n",
    "# NOTE(review): torch.load unpickles the file, so only load files you trust.\n",
    "# NOTE(review): this is an absolute, user-specific path; the notebook will not\n",
    "# run elsewhere without editing it.\n",
    "dic = torch.load(\"/users/qyu10/lab/circuits-over-time/compiled_metric_dict.pt\")\n",
    "\n",
    "# NOTE(review): despite the logit_diff_ prefix, these hold the 'mrr' entries for\n",
    "# pythia-1.4b. None of them is referenced again in this cell.\n",
    "logit_diff_ioi = dic['pythia-1.4b']['ioi']['mrr']\n",
    "logit_diff_greater_than = dic['pythia-1.4b']['greater_than']['mrr']\n",
    "logit_diff_sentiment_cont = dic['pythia-1.4b']['sentiment_cont']['mrr']\n",
    "logit_diff_sentiment_class = dic['pythia-1.4b']['sentiment_class']['mrr']\n",
    "\n",
    "def moving_average(input_tensor, window_size):\n",
    "    \"\"\"Return the sliding-window mean of a 1-D tensor.\n",
    "\n",
    "    A zero is prepended before taking the cumulative sum, so subtracting two\n",
    "    cumulative sums that are ``window_size`` apart yields the sum of one\n",
    "    window. For ``1 <= window_size <= len(input_tensor)`` the result has\n",
    "    ``len(input_tensor) - window_size + 1`` entries, and entry ``j`` is the\n",
    "    mean of ``input_tensor[j : j + window_size]``.\n",
    "    \"\"\"\n",
    "    leading_zero = torch.zeros(1, device=input_tensor.device)\n",
    "    running_total = torch.cat([leading_zero, input_tensor]).cumsum(dim=0)\n",
    "    window_sums = running_total[window_size:] - running_total[:-window_size]\n",
    "    return window_sums / window_size\n",
    "\n",
    "\n",
    "def get_start(input_tensor, START_THRESHOLD, END_THRESHOLD, window_size=3):\n",
    "    \"\"\"Locate where a curve starts rising and where the rise levels off.\n",
    "\n",
    "    Both positions are found on the first differences of ``input_tensor``\n",
    "    (``input_tensor[i + 1] - input_tensor[i]``). The start is the first\n",
    "    difference larger than ``START_THRESHOLD``. The end is taken from the first\n",
    "    ``window_size``-point rolling mean of the differences, computed from the\n",
    "    start onwards, that is smaller than ``END_THRESHOLD``; the index reported\n",
    "    is that of the last difference inside that window.\n",
    "\n",
    "    Args:\n",
    "        input_tensor (torch.Tensor): Input tensor.\n",
    "        START_THRESHOLD (float): Threshold a difference must exceed to mark the start.\n",
    "        END_THRESHOLD (float): Threshold the rolling mean must drop below to mark the end.\n",
    "        window_size (int): Number of differences averaged in each rolling window.\n",
    "\n",
    "    Returns:\n",
    "        tuple: ``(first_index, end_index)``. ``first_index`` is ``None`` when no\n",
    "        difference exceeds ``START_THRESHOLD``. ``end_index`` falls back to\n",
    "        ``len(input_tensor) - 1`` when there is no start or fewer than\n",
    "        ``window_size`` differences remain from the start, and to\n",
    "        ``len(input_tensor)`` when the rolling mean never drops below\n",
    "        ``END_THRESHOLD``.\n",
    "    \"\"\"\n",
    "    deltas = input_tensor[1:] - input_tensor[:-1]\n",
    "    last_index = len(input_tensor) - 1\n",
    "\n",
    "    # No difference exceeds the start threshold: there is no start to report.\n",
    "    rising = torch.nonzero(deltas > START_THRESHOLD).view(-1)\n",
    "    if len(rising) == 0:\n",
    "        return None, last_index\n",
    "\n",
    "    first_index = rising[0].item()\n",
    "\n",
    "    # Too few differences left after the start to fill a single window.\n",
    "    if first_index + window_size > len(deltas):\n",
    "        return first_index, last_index\n",
    "\n",
    "    smoothed = moving_average(deltas[first_index:], window_size)\n",
    "    levelled_off = torch.nonzero(smoothed < END_THRESHOLD).view(-1)\n",
    "    if len(levelled_off) == 0:\n",
    "        return first_index, len(input_tensor)\n",
    "\n",
    "    # Shift the window position back into the coordinates of the full\n",
    "    # differences tensor, pointing at the last difference of that window.\n",
    "    end_index = levelled_off[0].item() + first_index + window_size - 1\n",
    "    return first_index, end_index\n",
    "\n",
    "\n",
    "def line_with_gradient(tensor, time_step, intercept, coefficient, x_start, x_end, renderer=None, width=1200, height=500, **kwargs):\n",
    "    \"\"\"Plot a curve against ``time_step`` and overlay a fitted straight line in red.\n",
    "\n",
    "    Args:\n",
    "        tensor: Y values of the curve; converted with ``np.array`` before plotting.\n",
    "        time_step: X values of the curve.\n",
    "        intercept: Intercept of the overlaid line.\n",
    "        coefficient: Indexable whose first element is the slope of the overlaid line.\n",
    "        x_start: Start of the slice of ``time_step`` over which the line is drawn.\n",
    "        x_end: End (exclusive) of the slice of ``time_step`` over which the line is drawn.\n",
    "        renderer: Passed to ``fig.show``.\n",
    "        width: Figure width passed to ``fig.update_layout``.\n",
    "        height: Figure height passed to ``fig.update_layout``.\n",
    "        **kwargs: Forwarded to ``px.line`` (for example ``title`` or ``log_x``).\n",
    "\n",
    "    Returns:\n",
    "        None. The figure is displayed with ``fig.show``.\n",
    "    \"\"\"\n",
    "    # Convert tensor to numpy for plotting\n",
    "    y_values = np.array(tensor)\n",
    "\n",
    "    # Base curve\n",
    "    fig = px.line(x=time_step, y=y_values, **kwargs)\n",
    "\n",
    "    # Evaluate the straight line only over the selected slice of the x axis\n",
    "    x_values = time_step[x_start:x_end]\n",
    "    slope = coefficient[0]\n",
    "    y_line = [slope * x + intercept for x in x_values]\n",
    "\n",
    "    # The colour goes through ``line``: ``fillcolor`` only styles an area fill,\n",
    "    # which a plain ``mode='lines'`` trace without ``fill`` does not draw.\n",
    "    fig.add_trace(go.Scatter(x=x_values, y=y_line, mode='lines', name='Superimposed Line', line=dict(color='red')))\n",
    "\n",
    "    # Fixed-size layout\n",
    "    fig.update_layout(\n",
    "        autosize=False,\n",
    "        width=width,\n",
    "        height=height\n",
    "    )\n",
    "\n",
    "    # Show the figure with the optional renderer\n",
    "    fig.show(renderer=renderer)\n",
    "\n",
    "# For every model/task pair: find the rising segment of the MRR curve, fit a\n",
    "# straight line to it, record [slope, start, end] in ``coef`` and plot the fit.\n",
    "model_names = ['pythia-70m', 'pythia-14m', 'pythia-410m', 'pythia-1.4b', 'pythia-160m', 'pythia-12b', 'pythia-31m', 'pythia-2.8b']\n",
    "task_names = ['ioi', 'greater_than']\n",
    "\n",
    "coef = {}\n",
    "for model_name in model_names:\n",
    "    per_task = coef[model_name] = {}\n",
    "    for task in task_names:\n",
    "        mrr = dic[model_name][task]['mrr']\n",
    "\n",
    "        # The end threshold is 2% of the curve's maximum.\n",
    "        end_threshold = 0.02 * torch.max(mrr)\n",
    "        start, end = get_start(mrr, 0.001, end_threshold)\n",
    "\n",
    "        # Regress the metric on the step values over the detected segment.\n",
    "        steps = np.array(time_step[start:end]).reshape(-1, 1)\n",
    "        model = LinearRegression()\n",
    "        model.fit(steps, mrr[start:end])\n",
    "\n",
    "        per_task[task] = [model.coef_[0], start, end]\n",
    "\n",
    "        line_with_gradient(mrr, time_step, model.intercept_, model.coef_, start, end, title=f\"{model_name}_{task}\", log_x=True)\n",
    "\n",
    "from pprint import pprint\n",
    "pprint(coef)\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "682b2a2f",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
