{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# challenge \"mean\" baseline\n",
    "\n",
    "This notebook creates baseline embeddings by bilinear interpolation and averaging of the modalities.\n",
    "\n",
    "We use the ChallengeDataset to load the data. The datacubes of the challenge data are of shapes (1, 4, 27, 264, 264), (number of samples, number of timesteps, number of channels, height, width).\n",
    "\n",
    "The embedding works as follow:\n",
    "1. Subsample each channel to 8x8 pixels using bilinear interpolation -> shape (1, 4, 27, 8, 8)\n",
    "2. Average B01 through B09 for both S2L1C and L2 L2A along the channel dimension. Average B11 and B12 along the channel dimension. Average S1 channels along the channel dimension. Concatenate the three averages and B10 along channel dimension -> shape (1, 4, 4, 8, 8)\n",
    "3. Flatten into 1024 element vector -> shape (1024,)\n",
    "\n",
    "After embedding, a submission file is created in the expected format for the challenge. If you use this code, verify that it produces the right number of decimals for your output.\n",
    "\n",
    "At the end, a function to test that a submission file is readable for evaluation is provided.\n",
    "\n",
    "Note that parts of this notebook is simplified for demonstration purposes. However, the datasets and dataloaders, as well as the verification of the submission file are intended to be directly usable and true to the data and the expected submission file formats."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "from concurrent.futures import ThreadPoolExecutor\n",
    "from scipy.ndimage import zoom\n",
    "from torchvision import transforms\n",
    "\n",
    "from data.dataset import ChallengeDataset, S2L1C_MEAN, S2L1C_STD, S2L2A_MEAN, S2L2A_STD, S1GRD_MEAN, S1GRD_STD"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Configurations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Order of modalities.\n",
    "# In this demo, modalities are ordered the same as the default order in the SSL4EOS12 dataset class.\n",
    "# Modalities are loaded in the order provided here.\n",
    "# Change the order based on your needs.\n",
    "modalities = ['s2l1c', 's2l2a', 's1']\n",
    "\n",
    "# Path to challenge data folder, i.e. the folder containing the s1, s2l1c and s2l2a subfolders.\n",
    "path_to_data = '/path/to/challenge/data/'\n",
    "\n",
    "# Path to where the submission file should be written.\n",
    "path_to_output_file = 'path/to/output/file.csv'\n",
    "\n",
    "write_result_to_file = True  # Set to True to trigger saving of the csv at the end.\n",
    "\n",
    "# Create data transformation\n",
    "# Get mean and standard deviations for the modalities in the same order as the modalities\n",
    "# Note that we will use the `shift_s2_channels` flag in the challenge dataset, and we should \n",
    "# therefore use the mean and standard deviation of the SSL4EO-S12 dataset.\n",
    "mean_data = S2L1C_MEAN + S2L2A_MEAN + S1GRD_MEAN\n",
    "std_data = S2L1C_STD + S2L2A_STD + S1GRD_STD\n",
    "\n",
    "data_transform = transforms.Compose([\n",
    "    # Add additional transformation here\n",
    "    transforms.Normalize(mean=mean_data, std=std_data)\n",
    "])\n",
    "\n",
    "# Note that both ChallengeDataset and SSL4EOS12Dataset outputs torch tensors, so there is no need to a ToTensor transform."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Length of train dataset: 5149\n",
      "Modality s2l1c shape: torch.Size([1, 4, 13, 264, 264])\n",
      "Modality s2l2a shape: torch.Size([1, 4, 12, 264, 264])\n",
      "Modality s1 shape: torch.Size([1, 4, 2, 264, 264])\n"
     ]
    }
   ],
   "source": [
    "# Concatenate modalities\n",
    "# dataloader output is {'data': concatenated_data, 'file_name': file_name}\n",
    "# The data has shapes [n_samples, n_seasons, n_channels, height, width] (for concatenated_data [1, 4, 27, 264, 264])\n",
    "\n",
    "dataset_challenge = ChallengeDataset(path_to_data, \n",
    "                                  modalities = modalities, \n",
    "                                  dataset_name='bands', \n",
    "                                  transform=data_transform, \n",
    "                                  concat=False,\n",
    "                                  output_file_name=True,\n",
    "                                  shift_s2_channels=True\n",
    "                                 )\n",
    "\n",
    "# Print dataset length\n",
    "print(f\"Length of train dataset: {len(dataset_challenge)}\")\n",
    "\n",
    "# Print shape of first sample\n",
    "for m, d in dataset_challenge[0]['data'].items():\n",
    "    print(f'Modality {m} shape:', d.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Create submission file\n",
    "\n",
    "In this section, we create a submission by randomly generating embeddings of the correct size.\n",
    "Finally, we create a submission file.\n",
    "\n",
    "We use the ChallengeDataset since we can easily get the sample ID (file name) from the this."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_submission_from_dict(emb_dict):\n",
    "    \"\"\"Assume dictionary has format {hash-id0: embedding0, hash-id1: embedding1, ...}\n",
    "    \"\"\"\n",
    "    df_submission = pd.DataFrame.from_dict(emb_dict, orient='index')\n",
    "    \n",
    "    # Reset index with name 'id'\n",
    "    df_submission.index.name = 'id'\n",
    "    df_submission.reset_index(drop=False, inplace=True)\n",
    "        \n",
    "    return df_submission\n",
    "        "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Compress by bilinear transform and channel averaging\n",
    "\n",
    "In this section, we create a submission file by processing each sample accordingly:\n",
    "1. Subsampling each channel to 8x8 pixels using bilinear interpolation\n",
    "2. Average channels B01 to B09 for both L1C and L2A, average B11 and B12, and average S1 channels. Together with B10, this turns into 4 new channels.\n",
    "3. Flatten into 1024 element vector.\n",
    "\n",
    "We use the dataloader based on the ChallengeDataset since we can easily get the sample ID (file name) from the dataloader."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Correlation analysis show that L1C and L2A channels B01 to B09 are correlated, B11 and B12 are correlated, \n",
    "# and S1 VV and VH are correlated, so we average these, leaving (together with B10) 4 averaged channels.\n",
    "\n",
    "def embed(data, file_name, emb_len=1024):\n",
    "    # Bilinear interpolation of each channel separately.\n",
    "    rescaled_mod = {m: zoom(d, (1, 1, 1, 8/d.shape[3], 8/d.shape[4]), order=1) for m, d in data.items()}\n",
    "\n",
    "    # Calculate mean of correlated channels.\n",
    "    b1_b9 = np.mean(np.concatenate((rescaled_mod['s2l1c'][:, :, 0:9, :, :], \n",
    "                                   rescaled_mod['s2l2a'][:, :, 0:9, :, :]), axis=2), \n",
    "                    axis=2, keepdims=True)\n",
    "    b10 = rescaled_mod['s2l1c'][:, :, 9:10, :, :]\n",
    "    b11_b12 = np.mean(np.concatenate((rescaled_mod['s2l1c'][:, :, 10:, :, :], \n",
    "                                     rescaled_mod['s2l2a'][:, :, 10:, :, :]), axis=2), \n",
    "                      axis=2, keepdims=True)\n",
    "    s1 = np.mean(rescaled_mod['s1'], axis=2, keepdims=True)\n",
    "\n",
    "    # Concatenate aggregated channels\n",
    "    emb = np.concatenate((b1_b9, b10, b11_b12, s1), axis=2)\n",
    "\n",
    "    # Flatten\n",
    "    emb = emb.flatten()\n",
    "\n",
    "    return {'file_name': file_name, 'embedding': emb}\n",
    "\n",
    "\n",
    "def mean_embedding_parallel(dataset, n_workers=4, n_samples=None):\n",
    "    \n",
    "    # Initialize result embeddings\n",
    "    embeddings = {}\n",
    "\n",
    "    # Run embedding in parallel\n",
    "    with ThreadPoolExecutor(max_workers=n_workers) as executor:\n",
    "        futures = []\n",
    "        \n",
    "        for ind, data_file_name in enumerate(dataset):\n",
    "            data = data_file_name['data']\n",
    "            # print(data)\n",
    "            file_name = data_file_name['file_name']\n",
    "            # Submit the batch for processing\n",
    "            future = executor.submit(embed, data, file_name)\n",
    "            futures.append(future)\n",
    "\n",
    "            if (n_samples is not None) and (ind-1 > n_samples):\n",
    "                break\n",
    "        \n",
    "        # Extract results\n",
    "        for future in futures:\n",
    "            res = future.result()\n",
    "            # Compile embeddings\n",
    "            embeddings[res['file_name']] = res['embedding']\n",
    "    return embeddings\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_workers = 1\n",
    "if n_workers != 1:\n",
    "    # Embed data\n",
    "    embeddings = mean_embedding_parallel(dataset_challenge, n_workers=n_workers, n_samples=10)\n",
    "else:\n",
    "    embeddings = {}\n",
    "    for ind, data_file_name in enumerate(dataset_challenge):\n",
    "        data = data_file_name['data']\n",
    "        file_name = data_file_name['file_name']\n",
    "        emb = embed(data, file_name, 1024)\n",
    "        embeddings[file_name] = emb['embedding']\n",
    "        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create submission file\n",
    "submission_file = create_submission_from_dict(embeddings)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of embeddings: 5149\n"
     ]
    }
   ],
   "source": [
    "print('Number of embeddings:', len(submission_file))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>id</th>\n",
       "      <th>0</th>\n",
       "      <th>1</th>\n",
       "      <th>2</th>\n",
       "      <th>3</th>\n",
       "      <th>4</th>\n",
       "      <th>5</th>\n",
       "      <th>6</th>\n",
       "      <th>7</th>\n",
       "      <th>8</th>\n",
       "      <th>...</th>\n",
       "      <th>1014</th>\n",
       "      <th>1015</th>\n",
       "      <th>1016</th>\n",
       "      <th>1017</th>\n",
       "      <th>1018</th>\n",
       "      <th>1019</th>\n",
       "      <th>1020</th>\n",
       "      <th>1021</th>\n",
       "      <th>1022</th>\n",
       "      <th>1023</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>fec24d0cda8793ff55e1059c7b88763fee8d58d3decf78...</td>\n",
       "      <td>-0.136362</td>\n",
       "      <td>-0.371835</td>\n",
       "      <td>-0.404560</td>\n",
       "      <td>-0.508725</td>\n",
       "      <td>-0.460369</td>\n",
       "      <td>-0.422099</td>\n",
       "      <td>-0.455290</td>\n",
       "      <td>-0.303729</td>\n",
       "      <td>-0.100668</td>\n",
       "      <td>...</td>\n",
       "      <td>0.395265</td>\n",
       "      <td>0.711939</td>\n",
       "      <td>0.601892</td>\n",
       "      <td>0.383944</td>\n",
       "      <td>0.874982</td>\n",
       "      <td>0.449806</td>\n",
       "      <td>0.952038</td>\n",
       "      <td>-0.268883</td>\n",
       "      <td>0.884533</td>\n",
       "      <td>0.243575</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>67960f4c8870a8aa52f295da0f0fea6d708c3cee2555a4...</td>\n",
       "      <td>0.259305</td>\n",
       "      <td>0.083660</td>\n",
       "      <td>0.043898</td>\n",
       "      <td>0.332839</td>\n",
       "      <td>0.073862</td>\n",
       "      <td>-0.279074</td>\n",
       "      <td>-0.163957</td>\n",
       "      <td>0.097987</td>\n",
       "      <td>0.054454</td>\n",
       "      <td>...</td>\n",
       "      <td>0.113527</td>\n",
       "      <td>-0.575751</td>\n",
       "      <td>-0.560006</td>\n",
       "      <td>-0.238343</td>\n",
       "      <td>-0.913553</td>\n",
       "      <td>-0.952944</td>\n",
       "      <td>-0.011693</td>\n",
       "      <td>-0.664440</td>\n",
       "      <td>0.862798</td>\n",
       "      <td>-0.504407</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>9688abfaebaea5dca2ec8bde771a7bf1e2bba8e661b777...</td>\n",
       "      <td>0.068881</td>\n",
       "      <td>0.112759</td>\n",
       "      <td>0.085232</td>\n",
       "      <td>0.119378</td>\n",
       "      <td>0.012920</td>\n",
       "      <td>-0.103890</td>\n",
       "      <td>0.019192</td>\n",
       "      <td>0.134928</td>\n",
       "      <td>-0.128380</td>\n",
       "      <td>...</td>\n",
       "      <td>0.388513</td>\n",
       "      <td>-0.447894</td>\n",
       "      <td>-1.262257</td>\n",
       "      <td>-1.520254</td>\n",
       "      <td>-0.984263</td>\n",
       "      <td>-1.121416</td>\n",
       "      <td>-0.635569</td>\n",
       "      <td>-1.050879</td>\n",
       "      <td>-1.350882</td>\n",
       "      <td>-0.926634</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>fa3ae237ee6e2ee569c20a1e088112cf2105300d9272cc...</td>\n",
       "      <td>-1.164994</td>\n",
       "      <td>-1.179528</td>\n",
       "      <td>-1.185304</td>\n",
       "      <td>-1.183173</td>\n",
       "      <td>-1.179835</td>\n",
       "      <td>-1.183128</td>\n",
       "      <td>-1.183431</td>\n",
       "      <td>-1.183904</td>\n",
       "      <td>-1.148454</td>\n",
       "      <td>...</td>\n",
       "      <td>-1.174436</td>\n",
       "      <td>-1.286493</td>\n",
       "      <td>-1.486834</td>\n",
       "      <td>-0.839548</td>\n",
       "      <td>0.361805</td>\n",
       "      <td>0.279468</td>\n",
       "      <td>-0.059674</td>\n",
       "      <td>-0.799558</td>\n",
       "      <td>-0.876158</td>\n",
       "      <td>-1.462009</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>430590d31e38c5b345a92dc7d9eb8d126c01abced0cf1a...</td>\n",
       "      <td>-0.166036</td>\n",
       "      <td>-0.311182</td>\n",
       "      <td>-0.300327</td>\n",
       "      <td>-0.343975</td>\n",
       "      <td>-0.384960</td>\n",
       "      <td>-0.244595</td>\n",
       "      <td>-0.299571</td>\n",
       "      <td>-0.286590</td>\n",
       "      <td>-0.221417</td>\n",
       "      <td>...</td>\n",
       "      <td>1.354896</td>\n",
       "      <td>0.118833</td>\n",
       "      <td>0.745980</td>\n",
       "      <td>1.308391</td>\n",
       "      <td>0.539959</td>\n",
       "      <td>0.529650</td>\n",
       "      <td>0.233003</td>\n",
       "      <td>0.646347</td>\n",
       "      <td>0.746715</td>\n",
       "      <td>0.449681</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>5 rows × 1025 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "                                                  id         0         1  \\\n",
       "0  fec24d0cda8793ff55e1059c7b88763fee8d58d3decf78... -0.136362 -0.371835   \n",
       "1  67960f4c8870a8aa52f295da0f0fea6d708c3cee2555a4...  0.259305  0.083660   \n",
       "2  9688abfaebaea5dca2ec8bde771a7bf1e2bba8e661b777...  0.068881  0.112759   \n",
       "3  fa3ae237ee6e2ee569c20a1e088112cf2105300d9272cc... -1.164994 -1.179528   \n",
       "4  430590d31e38c5b345a92dc7d9eb8d126c01abced0cf1a... -0.166036 -0.311182   \n",
       "\n",
       "          2         3         4         5         6         7         8  ...  \\\n",
       "0 -0.404560 -0.508725 -0.460369 -0.422099 -0.455290 -0.303729 -0.100668  ...   \n",
       "1  0.043898  0.332839  0.073862 -0.279074 -0.163957  0.097987  0.054454  ...   \n",
       "2  0.085232  0.119378  0.012920 -0.103890  0.019192  0.134928 -0.128380  ...   \n",
       "3 -1.185304 -1.183173 -1.179835 -1.183128 -1.183431 -1.183904 -1.148454  ...   \n",
       "4 -0.300327 -0.343975 -0.384960 -0.244595 -0.299571 -0.286590 -0.221417  ...   \n",
       "\n",
       "       1014      1015      1016      1017      1018      1019      1020  \\\n",
       "0  0.395265  0.711939  0.601892  0.383944  0.874982  0.449806  0.952038   \n",
       "1  0.113527 -0.575751 -0.560006 -0.238343 -0.913553 -0.952944 -0.011693   \n",
       "2  0.388513 -0.447894 -1.262257 -1.520254 -0.984263 -1.121416 -0.635569   \n",
       "3 -1.174436 -1.286493 -1.486834 -0.839548  0.361805  0.279468 -0.059674   \n",
       "4  1.354896  0.118833  0.745980  1.308391  0.539959  0.529650  0.233003   \n",
       "\n",
       "       1021      1022      1023  \n",
       "0 -0.268883  0.884533  0.243575  \n",
       "1 -0.664440  0.862798 -0.504407  \n",
       "2 -1.050879 -1.350882 -0.926634  \n",
       "3 -0.799558 -0.876158 -1.462009  \n",
       "4  0.646347  0.746715  0.449681  \n",
       "\n",
       "[5 rows x 1025 columns]"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "submission_file.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write submission\n",
    "if write_result_to_file:\n",
    "    submission_file.to_csv(path_to_output_file, index=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Verify submission file integrity\n",
    "\n",
    "Below we provide a snippet from a function which will read your embeddings and test for the same errors that the evaluation will check for. The function is similar to how the submission files are loaded.\n",
    "\n",
    "The intention of this function is to help to verify that a submission has the right structure and contents, check for missing embeddings or NaN values, prior to submission.\n",
    "\n",
    "The function is intended to be a support. Successfully completing this function does not guarantee fault-free submission file, but is an indication that the most common errors are not present."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "def test_submission(path_to_submission: str, \n",
    "                    expected_embedding_ids: set, \n",
    "                    embedding_dim: int = 1024):\n",
    "    # Load data\n",
    "    df = pd.read_csv(path_to_submission, header=0)\n",
    "\n",
    "    # Verify that id is in columns\n",
    "    if 'id' not in df.columns:\n",
    "        raise ValueError(f\"\"\"Submission file must contain column 'id'.\"\"\")\n",
    "\n",
    "    # Temporarily set index to 'id'\n",
    "    df.set_index('id', inplace=True)\n",
    "\n",
    "    # Check that all samples are included\n",
    "    submitted_embeddings = set(df.index.to_list())\n",
    "    n_missing_embeddings = len(expected_embedding_ids.difference(submitted_embeddings))\n",
    "    if n_missing_embeddings > 0:\n",
    "        raise ValueError(f\"\"\"Submission is missing {n_missing_embeddings} embeddings.\"\"\")\n",
    "    \n",
    "    # Check that embeddings have the correct length\n",
    "    if len(df.columns) != embedding_dim:\n",
    "        raise ValueError(f\"\"\"{embedding_dim} embedding dimensions, but provided embeddings have {len(df.columns)} dimensions.\"\"\")\n",
    "\n",
    "    # Convert columns to float\n",
    "    try:\n",
    "        for col in df.columns:\n",
    "            df[col] = df[col].astype(float)\n",
    "    except Exception as e:\n",
    "        raise ValueError(f\"\"\"Failed to convert embedding values to float.\n",
    "    Check embeddings for any not-allowed character, for example empty strings, letters, etc.\n",
    "    Original error message: {e}\"\"\")\n",
    "\n",
    "    # Check if any NaNs \n",
    "    if df.isna().any().any():\n",
    "        raise ValueError(f\"\"\"Embeddings contain NaN values.\"\"\")\n",
    "\n",
    "    # Successful completion of the function\n",
    "    return True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# We use the created embeddings as the list of all samples.\n",
    "# This can be done since we are sure to have fully looped through the dataset.\n",
    "# A better way would be to find all the IDs in the challenge data separately, e.g. from the dataloader.\n",
    "embedding_ids = set(embeddings.keys())\n",
    "embedding_dim = 1024\n",
    "\n",
    "# Test submission\n",
    "assert test_submission(path_to_output_file, embedding_ids, embedding_dim)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
