{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import re\n",
    "import sys\n",
    "import torch\n",
    "import math\n",
    "import random\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "from src.utils import set_seed\n",
    "from src.args import parse_args\n",
    "from src.task_generator import generate_linear_task, generate_circle_task, generate_moon_task\n",
    "from src.datagenerator import generate_grid_data\n",
    "from src.prompt import batch_prompt_generation\n",
    "#from src.kernel_hsic import kernel_HSIC, linear_HSIC\n",
    "from src.conventional import *\n",
    "\n",
    "\n",
    "from sklearn.metrics import r2_score\n",
    "\n",
    "from tqdm import tqdm\n",
    "from typing import List"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def centering(K):\n",
    "    n = np.shape(K)[0]\n",
    "    unit = np.ones(n)\n",
    "    I = np.eye(n)\n",
    "    H = I - unit/n\n",
    "    return (H @ K) @ H\n",
    "\n",
    "\n",
    "def linear_HSIC(X, Y):\n",
    "    L_X = X @ X.transpose()\n",
    "    L_Y = Y @ Y.transpose()\n",
    "    return np.sum(centering(L_X) * centering(L_Y))\n",
    "\n",
    "\n",
    "def rbf(X, sigma=None):\n",
    "    GX = X @ X.transpose()\n",
    "    KX = np.diag(GX) - GX + (np.diag(GX) - GX).transpose(0,1)\n",
    "    if sigma is None:\n",
    "        # mdist = torch.median(KX[KX != 0])\n",
    "        try:\n",
    "            mdist = np.median(KX[KX != 0])\n",
    "            #mdist = torch.quantile(KX[KX != 0], q=0.75)\n",
    "        except:\n",
    "            #mdist = 5.\n",
    "            mdist = np.zeros(1).to(KX.device)\n",
    "        sigma = math.sqrt(np.clip(a=mdist, a_min=1e-12, a_max=1e+12))\n",
    "        #print(sigma)\n",
    "    KX = KX * (-0.5 / (sigma * sigma))\n",
    "    KX = np.exp(KX)\n",
    "    #print(KX)\n",
    "    return KX\n",
    "\n",
    "def kernel_HSIC(X, Y, sigma=None):\n",
    "    return np.sum(centering(rbf(X, sigma)) * centering(rbf(Y, sigma)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Correlation Analysis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def linear_correlation(X:np.ndarray, Y:np.ndarray) -> float:\n",
    "    \"\"\"\n",
    "    Calculate linear correlation.\n",
    "\n",
    "    Args:\n",
    "        X (np.ndarray): A vector with the size of [1, num_query]\n",
    "        Y (np.ndarray): A vector with the size of [1, num_query]\n",
    "\n",
    "    Returns:\n",
    "        float: the correlation result\n",
    "    \"\"\"\n",
    "    return np.cov(X, Y) / (np.sqrt(X.var() * Y.var()) + 1e-12)\n",
    "    \n",
    "\n",
    "def r_square_correlation(simu_res:np.ndarray, llm_res:np.ndarray) -> float:\n",
    "    \"\"\"\n",
    "    Calculate R^2 correlation.\n",
    "\n",
    "    Args:\n",
    "        simu_res (np.ndarray): A vector with the size of [1, num_query]\n",
    "        llm_res (np.ndarray): A vector with the size of [1, num_query]\n",
    "\n",
    "    Returns:\n",
    "        float: the correlation result\n",
    "    \"\"\"\n",
    "    sst = np.sum(np.square(llm_res - np.mean(llm_res)))\n",
    "    sse = np.sum(np.square(llm_res - simu_res))\n",
    "    return 1 - sse/sst\n",
    "\n",
    "\n",
    "def hsic(X:np.ndarray, Y:np.ndarray) -> float:\n",
    "    \"\"\"\n",
    "    Calculate the HSIC value.\n",
    "\n",
    "    Args:\n",
    "        X (np.ndarray): A vector with the size of [1, num_query]\n",
    "        Y (np.ndarray): A vector with the size of [1, num_query]\n",
    "\n",
    "    Returns:\n",
    "        float: the correlation result\n",
    "    \"\"\"\n",
    "    num_sample = X.shape[1]\n",
    "    return kernel_HSIC(X.transpose(), Y.transpose(), sigma=1.3) / np.square(num_sample - 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "DATAROOT = \"ROOT_PATH_OF_FILES\"\n",
    "\n",
    "llm_pred_paths = {\n",
    "    \"linear\": os.path.join(DATAROOT, \"data_records/pred_results\", \"llama-3_binary_linear_classification_preds.npy\"),\n",
    "    \"circle\": os.path.join(DATAROOT, \"data_records/pred_results\", \"llama-3_binary_circle_classification_preds.npy\"),\n",
    "    \"moon\": os.path.join(DATAROOT, \"data_records/pred_results\", \"llama-3_binary_moon_classification_preds.npy\"),\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_preds(task_mode:str) -> np.ndarray:\n",
    "    \"\"\"\n",
    "    Get the predictions of LLMs from the saved file.\n",
    "    Make sure all data files exist.\n",
    "\n",
    "    Args:\n",
    "        task_mode (str): The task mode. [\"linear\", \"circle\", \"moon\"].\n",
    "\n",
    "    Returns:\n",
    "        np.ndarray: The predictions of the LLMs.\n",
    "    \"\"\"\n",
    "    assert os.path.exists(llm_pred_paths[task_mode]), \"The data file does not exist. Please run `run_llm.py` first.\"\n",
    "    data = np.load(file=llm_pred_paths[task_mode], allow_pickle=True).item()[\"predictions\"]\n",
    "    \n",
    "    return np.array(data)\n",
    "    \n",
    "\n",
    "def select_data(data, labels, target_label):\n",
    "    \n",
    "    targets = []\n",
    "    for i in range(data.shape[0]):\n",
    "        if labels[i] == target_label:\n",
    "            targets.append(data[i])\n",
    "    return np.array(targets)\n",
    "\n",
    "\n",
    "def get_probability(num_methods:int, mode:str=\"standard\") -> List[float]:\n",
    "    \"\"\"\n",
    "    Get the probability for randomly sampling methods.\n",
    "\n",
    "    Args:\n",
    "        num_methods (int): The number of candidates.\n",
    "        mode (str, optional): The type of probability.\n",
    "\n",
    "    Returns:\n",
    "        List[float]: A list of probability.\n",
    "    \"\"\"\n",
    "    if mode == \"standard\":\n",
    "        probs = [0.673, 0.044, 0.003, 0.0, 0.280]\n",
    "    elif mode == \"ml-only\":\n",
    "        probs = [0.8, 0.02, 0.01, 0.04, 0.13]\n",
    "    elif mode == \"uniform\":\n",
    "        probs = [1./num_methods] * num_methods\n",
    "    # elif mode == \"random\":\n",
    "    #     rand_numbers = np.random.rand(num_methods)\n",
    "    #     probs = rand_numbers / rand_numbers.sum()\n",
    "    else:\n",
    "        raise ValueError(\"Unrecognized mode.\")\n",
    "    \n",
    "    return probs\n",
    "\n",
    "\n",
    "def apply_method(model_kw:str, data:np.ndarray, labels:np.ndarray, query:np.ndarray, args):\n",
    "    \"\"\"\n",
    "    Apply machine methods according to the model key words.\n",
    "\n",
    "    Args:\n",
    "        model_kw (str): Model key word.\n",
    "        data (np.ndarray): In-context data.\n",
    "        labels (np.ndarray): Labels of in-context data.\n",
    "        query (np.ndarray): Query data.\n",
    "        args (_type_): Hyperparameters\n",
    "\n",
    "    Raises:\n",
    "        ValueError: _description_\n",
    "\n",
    "    Returns:\n",
    "        _type_: predictions.\n",
    "    \"\"\"\n",
    "    \n",
    "    if model_kw == \"decision_tree\":\n",
    "        pred_labels = decisiontree(data=data, labels=labels, queries=query, seed=args[\"seed\"])\n",
    "    elif model_kw == \"mlp\":\n",
    "        pred_labels = mlp(data=data, labels=labels, queries=query, randseed=args[\"seed\"])\n",
    "    elif model_kw == \"knn\":\n",
    "        pred_labels = knn(data=data, labels=labels, queries=query)\n",
    "    elif model_kw == \"svm\":\n",
    "        pred_labels = svm(data=data, labels=labels, queries=query, seed=args[\"seed\"])\n",
    "    elif model_kw == \"linear_regression\":\n",
    "        pred_labels = linear_regression(data=data, labels=labels, queries=query)\n",
    "    else:\n",
    "        raise ValueError(f\"Unrecognized model.\")\n",
    "    return pred_labels\n",
    "\n",
    "\n",
    "def prob_hybrid(args, models:List[str], data:np.ndarray, labels:np.ndarray, queries:np.ndarray, probList:List[float], prob_type:str=\"standard\"):\n",
    "    \"\"\"\n",
    "    Simulate the behavior of LLMs with various kinds of probabilities.\n",
    "\n",
    "    Args:\n",
    "        args (_type_): Hyperparameters.\n",
    "        models (List[str]): A list of models.\n",
    "        data (np.ndarray): In-context data.\n",
    "        labels (np.ndarray): Labels of in-context data.\n",
    "        queries (np.ndarray): Query data.\n",
    "        probs (List[float]): A set of probabilities. If probs is not None, use probs.\n",
    "        prob_type (str): The type of probability.\n",
    "\n",
    "    Returns:\n",
    "        _type_: predictions\n",
    "    \"\"\"\n",
    "    if probList is not None:\n",
    "        probs = probList\n",
    "    else:\n",
    "        probs = get_probability(num_methods=len(models), mode=prob_type)\n",
    "\n",
    "    preds = []\n",
    "    \n",
    "    for subquery in tqdm(queries):\n",
    "        subquery = np.reshape(subquery, (1, -1))\n",
    "        model = np.random.choice(models, p=probs)\n",
    "        \n",
    "        pred_labels = apply_method(model_kw=model, data=data, labels=labels, query=subquery, args=args)\n",
    "        \n",
    "        preds.append(pred_labels[0])\n",
    "        \n",
    "    return preds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#args = parse_args()\n",
    "seed = 11\n",
    "task_mode = \"linear_classification\"\n",
    "random.seed(seed)\n",
    "np.random.seed(seed)\n",
    "\n",
    "if \"linear\" in task_mode:\n",
    "    data, labels = generate_linear_task(num_classes=2, mode=\"train\", num_samples=128, randseed=seed)\n",
    "elif \"circle\" in task_mode:\n",
    "    data, labels = generate_circle_task(mode=\"train\", noise=0.03, num_samples=128, randseed=seed)\n",
    "elif \"moon\" in task_mode:\n",
    "    data, labels = generate_moon_task(mode=\"train\", num_samples=128, randseed=seed)\n",
    "else:\n",
    "    raise ValueError(\"Unrecognized task mode.\")\n",
    "\n",
    "### generate grid data\n",
    "queries = generate_grid_data(data)\n",
    "\n",
    "models = [\"decision_tree\", \"knn\", \"svm\", \"mlp\", \"linear_regression\"]\n",
    "\n",
    "args = {\n",
    "    \"seed\": seed\n",
    "}\n",
    "\n",
    "### standard simulation\n",
    "standard_preds = prob_hybrid(args=args, models=models, data=data, labels=labels, probList=None, queries=queries, prob_type=\"standard\")\n",
    "\n",
    "### uniform simulation\n",
    "uniform_preds = prob_hybrid(args=args, models=models, data=data, labels=labels, probList=None, queries=queries, prob_type=\"uniform\")\n",
    "\n",
    "### standard pred of LLM\n",
    "llm_std_pred = np.load(file=os.path.join(DATAROOT, \"data_records/pred_results\", \"llama-3_binary_linear_classification_preds.npy\"), allow_pickle=True).item()[\"predictions\"]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "### simu & llm\n",
    "print(f\"The linear correlation between simu and llm in STANDARD mode is: {linear_correlation(X=np.reshape(standard_preds, (1, -1)), Y=np.reshape(llm_std_pred, (1, -1)))}\\n\")\n",
    "\n",
    "print(f\"The R^2 correlation between simu and llm in STANDARD mode is: {r_square_correlation(simu_res=np.reshape(standard_preds, (1, -1)), llm_res=np.reshape(llm_std_pred, (1, -1)))}\\n\")\n",
    "\n",
    "print(f\"The HSIC correlation between simu and llm in STANDARD mode is: {hsic(X=np.reshape(standard_preds, (1, -1)), Y=np.reshape(llm_std_pred, (1, -1)))}\\n\")\n",
    "\n",
    "### uniform simu & llm\n",
    "print(f\"The linear correlation between uniform_simu and llm in UNIFORM mode is: {linear_correlation(X=np.reshape(uniform_preds, (1, -1)), Y=np.reshape(llm_std_pred, (1, -1)))}\\n\")\n",
    "\n",
    "print(f\"The R^2 correlation between uniform_simu and llm in UNIFORM mode is: {r_square_correlation(simu_res=np.reshape(uniform_preds, (1, -1)), llm_res=np.reshape(llm_std_pred, (1, -1)))}\\n\")\n",
    "\n",
    "print(f\"The HSIC correlation between uniform_simu and llm in UNIFORM mode is: {hsic(X=np.reshape(uniform_preds, (1, -1)), Y=np.reshape(llm_std_pred, (1, -1)))}\\n\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "agent_env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
