{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a8cabc7a",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "from trainer import *\n",
    "from dataloader import *\n",
    "from model import StrategicClassifierForWarmup, StrategicClassifierFiniteSet\n",
    "import numpy as np\n",
    "from model_utils import HingeLoss, BasicStrategicHingeLoss, AmbiguousStrategicHingeLoss\n",
    "import torch\n",
    "from datetime import datetime\n",
    "import random\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "69a82da9",
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_and_label_points(\n",
    "    n_points=1000,\n",
    "    x_low=-3.0,\n",
    "    x_high=3.0,\n",
    "    y_low=-3.0,\n",
    "    y_high=3.0,\n",
    "    shift=0.1,\n",
    "    seed=101,\n",
    "    cost_scaling = 1.0,\n",
    "    device=\"cpu\",\n",
    "    dtype=torch.float32\n",
    "):\n",
    "    if seed is not None:\n",
    "        np.random.seed(seed)\n",
    "\n",
    "    W = [[1,0], [1,1], [1, -1]]\n",
    "    b = [-1, 2, 2]\n",
    "    W = np.asarray(W)\n",
    "    b = np.asarray(b)\n",
    "\n",
    "    w_chosen = W[0]\n",
    "    b_chosen = b[0]\n",
    "\n",
    "    X = np.column_stack([\n",
    "        np.random.uniform(x_low, x_high, size=n_points),        # x ∈ [x_low, x_high]\n",
    "        np.random.uniform(y_low, y_high, size=n_points) # y ∈ [y_low, y_high]\n",
    "    ])\n",
    "\n",
    "    margins = X @ W.T + b[None, :]\n",
    "\n",
    "    cond_chosen = margins[:, 0] >= 0\n",
    "\n",
    "    two_norm = (2.0 / cost_scaling) * np.linalg.norm(w_chosen)\n",
    "    cond_intersection = np.all(margins >= -two_norm, axis=1)\n",
    "\n",
    "    positive = cond_chosen | cond_intersection\n",
    "    y = np.where(positive, 1, -1)\n",
    "\n",
    "    X_moved = X.copy()\n",
    "    X_moved[positive, 0] += shift / 2\n",
    "    X_moved[~positive, 0] -= shift / 2\n",
    "\n",
    "    X_final = X_moved\n",
    "    y_final = y\n",
    "\n",
    "    return X_final, y_final"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "62c4cd2a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_data(X, y, title=\"Training Data\"):\n",
    "    # X = X.detach().numpy()\n",
    "    # y = y.detach().numpy()\n",
    "\n",
    "    plt.figure(figsize=(6, 6))\n",
    "    plt.scatter(X[y == -1][:, 0], X[y == -1][:, 1], c='red', label='Class -1', alpha=0.6)\n",
    "    plt.scatter(X[y == 1][:, 0], X[y == 1][:, 1], c='blue', label='Class 1', alpha=0.6)\n",
    "    plt.xlabel(\"x₁\", fontsize=12)\n",
    "    plt.ylabel(\"x₂\", fontsize=12)\n",
    "    plt.title(title, fontsize=14)\n",
    "    plt.grid(True)\n",
    "    plt.legend()\n",
    "    plt.axis(\"equal\")\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "82df9a64",
   "metadata": {},
   "outputs": [],
   "source": [
    "def set_seed(seed):\n",
    "    torch.manual_seed(seed)\n",
    "    np.random.seed(seed)\n",
    "    random.seed(seed)\n",
    "    if torch.cuda.is_available():\n",
    "        torch.cuda.manual_seed_all(seed)\n",
    "\n",
    "set_seed(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "988c996b",
   "metadata": {},
   "outputs": [],
   "source": [
    "cost_scaling = 0.5\n",
    "X,y = generate_and_label_points(n_points=1000, seed=101, shift=1.0, cost_scaling=cost_scaling, x_high=4.0, x_low=-6.0, y_high=10.0, y_low=-10.0)\n",
    "plot_data(X, y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "275ea48e",
   "metadata": {},
   "outputs": [],
   "source": [
    "X = torch.tensor(X, dtype=torch.float32)\n",
    "y = torch.tensor(y, dtype=torch.float32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c5fa5b94",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = SCPIDataset(X, y)\n",
    "dl_train, dl_val, dl_test = create_dataloaders(dataset, batch_size=1000, test_ratio=0.4, val_ratio=0.1, data_seed=100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "179074eb",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_3_classifiers = StrategicClassifierFiniteSet(d=2, num_classifiers=3, dev=0.4, cost_scaling=cost_scaling)\n",
    "loss = AmbiguousStrategicHingeLoss()\n",
    "\n",
    "opt = torch.optim.Adam(model_3_classifiers.parameters(), lr=0.006)\n",
    "\n",
    "for name, p in model_3_classifiers.named_parameters():\n",
    "    print(name, p)\n",
    "\n",
    "trainer = StrategicTrainer(\n",
    "    model=model_3_classifiers,\n",
    "    loss_fn=loss,\n",
    "    optimizer=opt,\n",
    "    reg_classifier=0.001,\n",
    "    reg_auxiliary=0.001,\n",
    "    device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),\n",
    ")\n",
    "metrics = trainer.fit(\n",
    "    dl_train=dl_train,\n",
    "    dl_val=dl_val,\n",
    "    num_epochs=500,\n",
    "    early_stopping=None\n",
    ")\n",
    "trainer.predict(dl_test)\n",
    "\n",
    "print(metrics)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1c7b1e30",
   "metadata": {},
   "outputs": [],
   "source": [
    "for name, p in model_3_classifiers.named_parameters():\n",
    "    print(name, p)\n",
    "\n",
    "print(metrics)\n",
    "print(metrics[\"train_loss\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cbee3997",
   "metadata": {},
   "outputs": [],
   "source": [
    "set_seed(1)\n",
    "str_classification_model = StrategicClassifierForWarmup(d=2, cost_scaling=cost_scaling)\n",
    "loss = BasicStrategicHingeLoss(scale_loss=cost_scaling)\n",
    "opt = torch.optim.Adam(str_classification_model.parameters(), lr=0.006)\n",
    "trainer = StrategicTrainer(\n",
    "    model=str_classification_model,\n",
    "    loss_fn=loss,\n",
    "    optimizer=opt,\n",
    "    reg_classifier=0.001,\n",
    "    write_metrics=False,\n",
    "    device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    ")\n",
    "trainer.fit(\n",
    "    dl_train=dl_train,\n",
    "    dl_val=dl_val,\n",
    "    num_epochs=500,\n",
    "    early_stopping=None,\n",
    "    no_val = True\n",
    ")\n",
    "results = trainer.predict(dl_test)\n",
    "print(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1abc2ee3",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(results)\n",
    "for name, param in str_classification_model.named_parameters():\n",
    "    print(name, param.data)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d7254573",
   "metadata": {},
   "outputs": [],
   "source": [
    "set_seed(1)\n",
    "model_2_classifiers = StrategicClassifierFiniteSet(d=2, num_classifiers=2, dev=0.4, cost_scaling=cost_scaling)\n",
    "loss = AmbiguousStrategicHingeLoss()\n",
    "\n",
    "opt = torch.optim.Adam(model_2_classifiers.parameters(), lr=0.006)\n",
    "\n",
    "for name, p in model_2_classifiers.named_parameters():\n",
    "    print(name, p)\n",
    "\n",
    "trainer = StrategicTrainer(\n",
    "    model=model_2_classifiers,\n",
    "    loss_fn=loss,\n",
    "    optimizer=opt,\n",
    "    reg_classifier=0.001,\n",
    "    reg_auxiliary=0.001,\n",
    "    device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),\n",
    ")\n",
    "metrics = trainer.fit(\n",
    "    dl_train=dl_train,\n",
    "    dl_val=dl_val,\n",
    "    num_epochs=500,\n",
    "    early_stopping=None\n",
    ")\n",
    "trainer.predict(dl_test)\n",
    "\n",
    "print(metrics)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0a68b451",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(metrics)\n",
    "for name, param in model_2_classifiers.named_parameters():\n",
    "    print(name, param.data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a16867a6",
   "metadata": {},
   "outputs": [],
   "source": [
    "set_seed(1)\n",
    "naive_model = StrategicClassifierForWarmup(d=2, cost_scaling=cost_scaling)\n",
    "loss = HingeLoss()\n",
    "opt = torch.optim.Adam(naive_model.parameters(), lr=0.006)\n",
    "trainer = StrategicTrainer(\n",
    "    model=naive_model,\n",
    "    loss_fn=loss,\n",
    "    optimizer=opt,\n",
    "    reg_classifier=0.001,\n",
    "    write_metrics=False,\n",
    "    device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    ")\n",
    "trainer.fit(\n",
    "    dl_train=dl_train,\n",
    "    dl_val=dl_val,\n",
    "    num_epochs=500,\n",
    "    early_stopping=None,\n",
    "    no_val = True\n",
    ")\n",
    "results = trainer.predict(dl_test)\n",
    "print(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4a6c021e",
   "metadata": {},
   "outputs": [],
   "source": [
    "results\n",
    "for name, param in naive_model.named_parameters():\n",
    "    print(name, param.data)\n",
    "print(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a0ec7fbd",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import torch\n",
    "from matplotlib.legend_handler import HandlerTuple # Import this for the combined legend\n",
    "\n",
    "def plot_model_classifiers(model, X, y, title=\"Model Decision Boundaries\", type=\"complex\"):\n",
    "    # 1. Prepare Data\n",
    "    X_np = X.detach().cpu().numpy()\n",
    "    y_np = y.detach().cpu().numpy()\n",
    "\n",
    "    plt.figure(figsize=(6, 4))\n",
    "\n",
    "    # 2. Plot Data Points\n",
    "    plt.scatter(X_np[y_np == -1][:, 0], X_np[y_np == -1][:, 1], \n",
    "                c='red', label='Class -1', alpha=0.5, s=25, edgecolors='none')\n",
    "    plt.scatter(X_np[y_np == 1][:, 0], X_np[y_np == 1][:, 1], \n",
    "                c='blue', label='Class 1', alpha=0.5, s=25, edgecolors='none')\n",
    "\n",
    "    # 3. Determine View Limits with Padding\n",
    "    x_min, x_max = X_np[:, 0].min(), X_np[:, 0].max()\n",
    "    y_min, y_max = X_np[:, 1].min(), X_np[:, 1].max()\n",
    "    \n",
    "    pad_x = (x_max - x_min) * 0.1\n",
    "    pad_y = (y_max - y_min) * 0.1\n",
    "    \n",
    "    view_x_min, view_x_max = x_min - pad_x, x_max + pad_x\n",
    "    view_y_min, view_y_max = y_min - pad_y, y_max + pad_y\n",
    "\n",
    "    # 4. Helper to plot using Bounding Box Intersections\n",
    "    def get_line_coords(w, b):\n",
    "        \"\"\"Calculates line coordinates within the view box\"\"\"\n",
    "        if abs(w[0]) < 1e-5 and abs(w[1]) < 1e-5: return None, None\n",
    "        \n",
    "        borders = [\n",
    "            ('bottom', view_y_min, 'horizontal'),\n",
    "            ('top',    view_y_max, 'horizontal'),\n",
    "            ('left',   view_x_min, 'vertical'),\n",
    "            ('right',  view_x_max, 'vertical')\n",
    "        ]\n",
    "        \n",
    "        intersections = []\n",
    "        for name, val, kind in borders:\n",
    "            if kind == 'horizontal': # w0*x + w1*val + b = 0\n",
    "                if abs(w[0]) > 1e-5:\n",
    "                    x_int = -(w[1] * val + b) / w[0]\n",
    "                    if view_x_min <= x_int <= view_x_max:\n",
    "                        intersections.append((x_int, val))\n",
    "            else: # w0*val + w1*y + b = 0\n",
    "                if abs(w[1]) > 1e-5:\n",
    "                    y_int = -(w[0] * val + b) / w[1]\n",
    "                    if view_y_min <= y_int <= view_y_max:\n",
    "                        intersections.append((val, y_int))\n",
    "        \n",
    "        intersections = list(set(intersections))\n",
    "        if len(intersections) >= 2:\n",
    "            intersections.sort(key=lambda p: p[0]) \n",
    "            return intersections[0], intersections[-1]\n",
    "        return None, None\n",
    "\n",
    "    def plot_line_and_arrow(w_param, b_param, color, label, linewidth, is_main=False, linestyle='-'):\n",
    "        w = w_param.detach().cpu().numpy()\n",
    "        b = b_param.detach().cpu().numpy()\n",
    "        \n",
    "        p1, p2 = get_line_coords(w, b)\n",
    "        \n",
    "        if p1 and p2:\n",
    "            # Plot Line\n",
    "            plt.plot([p1[0], p2[0]], [p1[1], p2[1]], color=color, linestyle=linestyle, \n",
    "                     label=label, alpha=1.0, linewidth=linewidth)\n",
    "            \n",
    "            # Only draw arrows for solid lines (main classifiers)\n",
    "            if linestyle == '-':\n",
    "                start_x = (p1[0] + p2[0]) / 2\n",
    "                start_y = (p1[1] + p2[1]) / 2\n",
    "                \n",
    "                w_norm = w / np.linalg.norm(w)\n",
    "                scale = 1.0 \n",
    "                \n",
    "                plt.arrow(start_x, start_y, \n",
    "                          w_norm[0] * scale, w_norm[1] * scale,\n",
    "                          head_width=0.3 if is_main else 0.25, \n",
    "                          head_length=0.3 if is_main else 0.25, \n",
    "                          fc=color, ec=color, zorder=10)\n",
    "\n",
    "    # 5. Draw Classifiers\n",
    "    \n",
    "    # Check if we are in the complex scenario (k=2/3 case)\n",
    "    if hasattr(model, 'w_chosen') and hasattr(model, 'classifiers_disguise'):\n",
    "        \n",
    "        title = f\"(k={model.num_classifiers}) Model Decision Boundary\"\n",
    "        \n",
    "        # Calculate shift based on w_chosen norm\n",
    "        w_ref = model.w_chosen.detach().cpu().numpy()\n",
    "        norm_val = np.linalg.norm(w_ref)\n",
    "        shift_amount = 2 * norm_val \n",
    "\n",
    "        # A. Plot Possible (Disguise) Classifiers\n",
    "        for i, (w_d, b_d) in enumerate(zip(model.classifiers_disguise, model.b_disguise)):\n",
    "            # Solid Line (Possible)\n",
    "            lbl = 'Possible' if i == 0 else None\n",
    "            plot_line_and_arrow(w_d, b_d, color='green', label=lbl, linewidth=2.5, is_main=False)\n",
    "            \n",
    "            # Dashed Line (Decision Boundary)\n",
    "            # We label the FIRST one 'Decision Boundary'\n",
    "            dash_lbl = 'Decision Boundary' if i == 0 else None\n",
    "            plot_line_and_arrow(w_d, b_d + shift_amount, color='green', label=dash_lbl, \n",
    "                                linewidth=2.0, linestyle='--', is_main=False)\n",
    "\n",
    "        # B. Plot Realized (Chosen) Classifier\n",
    "        # Solid Line (Realized)\n",
    "        plot_line_and_arrow(model.w_chosen, model.b_chosen, \n",
    "                            color='black', label='Realized', linewidth=2.5, is_main=True)\n",
    "        \n",
    "        # Dashed Line (Decision Boundary)\n",
    "        # We ALWAYS label this 'Decision Boundary' too (it will be combined in legend)\n",
    "        plot_line_and_arrow(model.w_chosen, model.b_chosen + shift_amount, \n",
    "                            color='black', label='Decision Boundary', linewidth=2.0, linestyle='--', is_main=True)\n",
    "\n",
    "    # Simple Model Case (Single w)\n",
    "    elif hasattr(model, 'w'):\n",
    "        w_ref = model.w.detach().cpu().numpy()\n",
    "        norm_val = np.linalg.norm(w_ref)\n",
    "        shift_amount = 2 * norm_val \n",
    "\n",
    "        plot_line_and_arrow(model.w, model.b, \n",
    "                            color='black', label='Classifier', linewidth=2.5, is_main=True)\n",
    "        \n",
    "        plot_line_and_arrow(model.w, model.b + shift_amount, \n",
    "                            color='black', label='Decision Boundary', linewidth=2.0, linestyle='--', is_main=True)\n",
    "\n",
    "    # 6. Final Formatting\n",
    "    plt.xlabel(\"x₁\")\n",
    "    plt.ylabel(\"x₂\")\n",
    "    plt.title(title)\n",
    "    \n",
    "    plt.xlim(view_x_min, view_x_max)\n",
    "    plt.ylim(view_y_min, view_y_max)\n",
    "    \n",
    "    # --- Custom Legend Logic ---\n",
    "    # Retrieve all handles and labels\n",
    "    handles, labels = plt.gca().get_legend_handles_labels()\n",
    "    \n",
    "    # Dictionary to group handles by label\n",
    "    label_dict = {}\n",
    "    for h, l in zip(handles, labels):\n",
    "        if l not in label_dict:\n",
    "            label_dict[l] = []\n",
    "        label_dict[l].append(h)\n",
    "    \n",
    "    # Rebuild final lists\n",
    "    final_handles = []\n",
    "    final_labels = []\n",
    "    \n",
    "    for l, h_list in label_dict.items():\n",
    "        if l == \"Decision Boundary\":\n",
    "            # For \"Decision Boundary\", we combine ALL handles (Green + Black) into a tuple\n",
    "            final_handles.append(tuple(h_list))\n",
    "            final_labels.append(l)\n",
    "        else:\n",
    "            # For others (Possible, Realized, Classes), just take the first handle\n",
    "            final_handles.append(h_list[0])\n",
    "            final_labels.append(l)\n",
    "\n",
    "    # Create legend with HandlerTuple to display both lines next to \"Decision Boundary\"\n",
    "    plt.legend(final_handles, final_labels, loc='lower left', \n",
    "               handler_map={tuple: HandlerTuple(ndivide=None)})\n",
    "    \n",
    "    plt.grid(True, alpha=0.3)\n",
    "    plt.gca().set_aspect('equal', adjustable='datalim')\n",
    "\n",
    "    plt.savefig(f'decision_boundary_{type}.eps', format='eps', bbox_inches='tight')\n",
    "    plt.savefig(f'decision_boundary_{type}.pdf', format='pdf', bbox_inches='tight')\n",
    "    plt.savefig(f'decision_boundary_{type}.png', format='png', dpi=300, bbox_inches='tight')\n",
    "\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b6df7a38",
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_model_classifiers(model_3_classifiers, X, y, title=\"3-Classifiers Model Decision Boundaries\", type=\"3_classifiers\")\n",
    "plot_model_classifiers(model_2_classifiers, X, y, title=\"2-Classifiers Model Decision Boundaries\", type=\"2_classifiers\")\n",
    "plot_model_classifiers(naive_model, X, y, title=\"Naive Model Decision Boundary\", type=\"naive\")\n",
    "plot_model_classifiers(str_classification_model, X, y, title=\"Strategic Classifier Decision Boundary\", type=\"strategic\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
