{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from utils import plot_images\n",
    "\n",
    "def find_clusters(image):\n",
    "    \"\"\"Group the zero-valued pixels of ``image[0]`` into 4-connected clusters.\n",
    "\n",
    "    Args:\n",
    "        image: Indexed as ``image[0, row, col]`` -- e.g. a torch tensor or numpy\n",
    "            array of shape ``(1, H, W)``. Only channel 0 is inspected; it must\n",
    "            support ``== 0`` followed by ``.tolist()``.\n",
    "\n",
    "    Returns:\n",
    "        One list of ``(row, col)`` tuples per connected region of zeros. Regions\n",
    "        appear in row-major order of their first pixel; pixels inside a region are\n",
    "        in depth-first order, exploring up, down, left, then right.\n",
    "    \"\"\"\n",
    "    # Read the zero mask once as nested lists: per-element tensor indexing is slow.\n",
    "    zero_mask = (image[0] == 0).tolist()\n",
    "    height, width = image.shape[1], image.shape[2]\n",
    "    visited = set()\n",
    "    clusters = []\n",
    "\n",
    "    for start_row in range(height):\n",
    "        for start_col in range(width):\n",
    "            if (start_row, start_col) in visited or not zero_mask[start_row][start_col]:\n",
    "                continue\n",
    "            cluster = []\n",
    "            # Explicit stack instead of recursion: a large region could exceed\n",
    "            # Python's recursion limit and raise RecursionError.\n",
    "            stack = [(start_row, start_col)]\n",
    "            while stack:\n",
    "                row, col = stack.pop()\n",
    "                if (row, col) in visited or not zero_mask[row][col]:\n",
    "                    continue\n",
    "                visited.add((row, col))\n",
    "                cluster.append((row, col))\n",
    "                # Push in reverse so neighbours pop as up, down, left, right --\n",
    "                # the same visiting order as a recursive depth-first search.\n",
    "                if col < width - 1:\n",
    "                    stack.append((row, col + 1))\n",
    "                if col > 0:\n",
    "                    stack.append((row, col - 1))\n",
    "                if row < height - 1:\n",
    "                    stack.append((row + 1, col))\n",
    "                if row > 0:\n",
    "                    stack.append((row - 1, col))\n",
    "            clusters.append(cluster)\n",
    "\n",
    "    return clusters\n",
    "\n",
    "def adapt_cluster(cluster, x, y):\n",
    "    \"\"\"Trim or extend ``cluster`` to ``x`` pixels along the row or column axis.\n",
    "\n",
    "    Two candidates are built from the first ``x`` pixels of the cluster:\n",
    "\n",
    "    * row-major order, extended by increasing the row index (\"vertical\");\n",
    "    * column-major order, extended by increasing the column index (\"horizontal\").\n",
    "\n",
    "    A candidate shorter than ``x`` is only extended when the offset between its\n",
    "    first and last pixel along the other axis is below ``y``; otherwise it stays\n",
    "    short. Extension pixels are not checked against the cluster or image bounds.\n",
    "\n",
    "    Args:\n",
    "        cluster: List of ``(row, col)`` tuples. It is not modified.\n",
    "        x: Target number of pixels.\n",
    "        y: Exclusive limit on the secondary-axis offset that still allows extension.\n",
    "\n",
    "    Returns:\n",
    "        The candidate with fewer pixels (the vertical one on ties), or ``[]`` when\n",
    "        ``cluster`` is empty or ``x <= 0``.\n",
    "    \"\"\"\n",
    "    if len(cluster) == 0 or x <= 0:\n",
    "        return []\n",
    "\n",
    "    def adapt_for_orientation(sorted_cluster, primary, secondary, limit):\n",
    "        adapted = sorted_cluster[:x]\n",
    "        if len(adapted) >= x:\n",
    "            return adapted\n",
    "        # Extension only moves along `primary`, so the secondary offset checked\n",
    "        # against `limit` is identical for every new pixel: test it once.\n",
    "        if adapted[-1][secondary] - adapted[0][secondary] >= limit:\n",
    "            return adapted\n",
    "        next_pos = list(adapted[-1])\n",
    "        while len(adapted) < x:\n",
    "            next_pos[primary] += 1\n",
    "            adapted.append(tuple(next_pos))\n",
    "        return adapted\n",
    "\n",
    "    # sorted() rather than list.sort(): the caller's list must keep its order.\n",
    "    vertical_adapted = adapt_for_orientation(\n",
    "        sorted(cluster, key=lambda pos: (pos[0], pos[1])), 0, 1, y)\n",
    "    horizontal_adapted = adapt_for_orientation(\n",
    "        sorted(cluster, key=lambda pos: (pos[1], pos[0])), 1, 0, y)\n",
    "\n",
    "    if len(vertical_adapted) > len(horizontal_adapted):\n",
    "        return horizontal_adapted\n",
    "    return vertical_adapted\n",
    "\n",
    "def find_optimal_structure(image, x, y):\n",
    "    \"\"\"Return the smallest adapted zero-pixel cluster of ``image`` as (row, col) tuples.\n",
    "\n",
    "    Each cluster from ``find_clusters(image)`` is passed through\n",
    "    ``adapt_cluster(cluster, x, y)``; the candidate with the fewest pixels wins,\n",
    "    ties going to the earliest cluster. Returns ``[]`` when there are no clusters.\n",
    "    \"\"\"\n",
    "    best = None\n",
    "    for candidate in (adapt_cluster(c, x, y) for c in find_clusters(image)):\n",
    "        if best is None or len(candidate) < len(best):\n",
    "            best = candidate\n",
    "    if best is None:\n",
    "        return []\n",
    "    return [(r, c) for r, c in best]\n",
    "\n",
    "\n",
    "\n",
    "# Example usage: two rectangular zero regions on a background of ones\n",
    "image = torch.ones([1, 64, 64])\n",
    "image[:, 4:20, 4:16] = 0\n",
    "image[:, 20:25, 20:38] = 0\n",
    "x = 40  # Target number of pixels in the structure\n",
    "y = 2   # Minor axis length\n",
    "\n",
    "structure_positions = find_optimal_structure(image, x, y)\n",
    "print(\"Positions of the structure:\", structure_positions)\n",
    "\n",
    "# Index rows and columns as two separate lists: passing the list of (row, col)\n",
    "# tuples directly is treated as one (N, 2) index on dim 1 and zeroes whole rows.\n",
    "# Positions past the image edge (adapt_cluster may extend outward) are skipped.\n",
    "_, height, width = image.shape\n",
    "in_bounds = [(r, c) for r, c in structure_positions if 0 <= r < height and 0 <= c < width]\n",
    "new_image = image.clone()\n",
    "if in_bounds:\n",
    "    rows, cols = zip(*in_bounds)\n",
    "    new_image[:, list(rows), list(cols)] = 0\n",
    "\n",
    "plot_images(torch.stack([image, new_image]))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"| system |: Starting imports...\")\n",
    "\n",
    "import cvxpy as cp\n",
    "import torch\n",
    "import numpy as np\n",
    "from utils import plot_images\n",
    "\n",
    "\n",
    "def create_optimized_structure(image, x, y):\n",
    "    \"\"\"Project a binary image onto a nearby one whose zeros obey a row/column limit.\n",
    "\n",
    "    Solves a mixed-integer program with cvxpy, one boolean variable per pixel:\n",
    "\n",
    "    * objective: sum of ``|result - image[0]|`` (the number of flipped pixels for a\n",
    "      0/1 image) plus ``penalty_weight`` times the summed absolute differences of\n",
    "      horizontally and vertically adjacent pixels;\n",
    "    * constraint: for every index ``i``, row ``i`` or column ``i`` of the result\n",
    "      contains at most ``y`` zeros.\n",
    "\n",
    "    Args:\n",
    "        image: Tensor of shape ``(1, N, N)``; ``image[0]`` is converted with\n",
    "            ``.numpy()``, so it has to be a CPU tensor.\n",
    "        x: Currently unused (see the TODO below).\n",
    "        y: Maximum number of zeros in the row or column satisfied for each index.\n",
    "\n",
    "    Returns:\n",
    "        NumPy array of shape ``(N, N)`` with the solution rounded to integers.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If the image is not square (rows and columns are paired by index).\n",
    "        RuntimeError: If the solver returns without a solution value.\n",
    "    \"\"\"\n",
    "    print(\"| cvxpy |: Setting up projection...\")\n",
    "\n",
    "    shape = image.shape[1:]\n",
    "    n_rows, n_cols = shape\n",
    "    if n_rows != n_cols:\n",
    "        raise ValueError(f\"expected a square image, got {n_rows}x{n_cols}\")\n",
    "\n",
    "    variables = cp.Variable(shape, boolean=True)\n",
    "    image_reshaped = image[0].numpy()\n",
    "\n",
    "    # Objective: minimize the number of bit flips\n",
    "    bit_flip_term = cp.sum(cp.abs(variables - image_reshaped))\n",
    "\n",
    "    # Penalty for discontinuities between 4-neighbours (vectorized)\n",
    "    penalty_weight = 0.1\n",
    "    vertical_diff = variables[:-1, :] - variables[1:, :]\n",
    "    horizontal_diff = variables[:, :-1] - variables[:, 1:]\n",
    "    discontinuity_penalty = penalty_weight * (cp.sum(cp.abs(vertical_diff)) + cp.sum(cp.abs(horizontal_diff)))\n",
    "\n",
    "    objective = cp.Minimize(bit_flip_term + discontinuity_penalty)\n",
    "\n",
    "    print(\"| cvxpy |: Compiling constraints...\")\n",
    "    # TODO: `x` is unused -- a total-size constraint\n",
    "    # (cp.sum(variables) == n_rows * n_cols - x) was tried and disabled.\n",
    "    constraints = []\n",
    "    for i in range(n_rows):\n",
    "        row_ok = cp.Variable(boolean=True)\n",
    "        col_ok = cp.Variable(boolean=True)\n",
    "        # Big-M: with row_ok == 1, row i needs at least n_cols - y ones (<= y zeros);\n",
    "        # with row_ok == 0 the bound drops to -y and always holds. Same for columns.\n",
    "        constraints.append(cp.sum(variables[i, :]) >= (n_cols - y) - (1 - row_ok) * n_cols)\n",
    "        constraints.append(cp.sum(variables[:, i]) >= (n_rows - y) - (1 - col_ok) * n_rows)\n",
    "        # At least one of row i / column i must satisfy its limit.\n",
    "        constraints.append(row_ok + col_ok >= 1)\n",
    "\n",
    "    problem = cp.Problem(objective, constraints)\n",
    "    print(\"| cvxpy |: Solving projection...\")\n",
    "    problem.solve(verbose=True)\n",
    "\n",
    "    # variables.value stays None when the solver produces no solution (e.g. an\n",
    "    # infeasible problem); report the status instead of failing inside np.round.\n",
    "    if variables.value is None:\n",
    "        raise RuntimeError(f\"cvxpy returned no solution (status: {problem.status})\")\n",
    "\n",
    "    result = np.round(variables.value)  # Round to nearest integer (0 or 1)\n",
    "    print(\"| cvxpy |: Projection complete.\")\n",
    "\n",
    "    return result\n",
    "\n",
    "# Example usage\n",
    "print(\"| system |: Starting problem...\")\n",
    "image = torch.ones([1, 64, 64])  # Example binary image\n",
    "image[:, 4:30, 4:20] = 0\n",
    "x = 105  # Intended structure size; create_optimized_structure does not use it yet\n",
    "y = 5   # Maximum width/height of the structure\n",
    "\n",
    "optimized_structure = create_optimized_structure(image.clone(), x, y)\n",
    "print(\"Optimized Structure:\\n\", optimized_structure)\n",
    "\n",
    "print(\"| system |: Plotting images...\")\n",
    "\n",
    "# Bring the (H, W) NumPy result back to the input's shape and dtype so both\n",
    "# images stack cleanly, instead of hard-coding the 64x64 size.\n",
    "projected = torch.from_numpy(optimized_structure).reshape(image.shape).to(image.dtype)\n",
    "plot_images(torch.stack([image, projected]))\n",
    "\n",
    "print(\"| system |: Complete. Exiting.\")\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "torch-cvxpy",
   "language": "python",
   "name": "torch-cvxpy"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
