{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import inspyred\n",
    "from inspyred import ec, benchmarks\n",
    "\n",
    "from catboost import CatBoostClassifier, Pool\n",
    "from sklearn.metrics import accuracy_score, classification_report, confusion_matrix\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "import pandas as pd\n",
    "import re\n",
    "import itertools\n",
    "import matplotlib.pyplot as plt\n",
    "import os\n",
    "import random\n",
    "import numpy as np\n",
    "import pickle\n",
    "import time\n",
    "import psutil\n",
    "import gc\n",
    "import random\n",
    "import pynvml\n",
    "\n",
    "from scipy.stats import chisquare, kstest\n",
    "import math\n",
    "\n",
    "random.seed(42)\n",
    "np.random.seed(42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Function to clean folder names\n",
    "def clean_folder_name(folder_name):\n",
    "    # Remove invalid characters\n",
    "    cleaned_name = re.sub(r'[<>:\"/\\\\|?*]', '', folder_name)\n",
    "    # Remove trailing dots and spaces\n",
    "    cleaned_name = cleaned_name.rstrip('. ')\n",
    "    return cleaned_name\n",
    "\n",
    "def CPU_monitor_memory_usage(pause_threshold=100, resume_threshold=30, poll_seconds=10):\n",
    "    \"\"\"Print the system RAM usage and block while it is critically high.\n",
    "\n",
    "    If `psutil.virtual_memory().percent` is >= `pause_threshold`, a garbage\n",
    "    collection is triggered and execution sleeps in `poll_seconds` steps until\n",
    "    the usage is <= `resume_threshold`. The defaults reproduce the previous\n",
    "    hard-coded values (100 / 30 percent, 10 s).\n",
    "\n",
    "    NOTE(review): a pause threshold of 100% is rarely reached, and since the wait\n",
    "    loop only sleeps, it may never end if this process itself keeps usage above\n",
    "    `resume_threshold` - tune both for the machine in use.\n",
    "    \"\"\"\n",
    "    memory_usage = psutil.virtual_memory().percent\n",
    "    print(f\"CPU Current memory usage: {memory_usage}%\")\n",
    "\n",
    "    if memory_usage >= pause_threshold:\n",
    "        print(\"CPU Memory usage is too high. Pausing execution...\")\n",
    "        gc.collect()  # Trigger garbage collection manually\n",
    "        while memory_usage > resume_threshold:\n",
    "            time.sleep(poll_seconds)\n",
    "            memory_usage = psutil.virtual_memory().percent\n",
    "        print(\"CPU Memory usage is low enough. Resuming execution...\")\n",
    "\n",
    "def monitor_gpu_memory(device_index=0, pause_threshold=95, resume_threshold=30, poll_seconds=10):\n",
    "    \"\"\"Print the GPU memory usage and block while it is critically high.\n",
    "\n",
    "    Uses NVML to read used/total memory of GPU `device_index`. If the usage is\n",
    "    >= `pause_threshold` percent, sleeps in `poll_seconds` steps until it is\n",
    "    <= `resume_threshold` percent. The defaults reproduce the previous\n",
    "    hard-coded values (GPU 0, 95 / 30 percent, 10 s). NVML is shut down on\n",
    "    exit, also when an NVML call inside the `try` raises.\n",
    "    \"\"\"\n",
    "    # Init stays outside the try: shutdown must only follow a successful init.\n",
    "    pynvml.nvmlInit()\n",
    "    try:\n",
    "        handle = pynvml.nvmlDeviceGetHandleByIndex(device_index)\n",
    "\n",
    "        mem_info = pynvml.nvmlDeviceGetMemoryInfo(handle)\n",
    "        total_memory = mem_info.total\n",
    "        memory_usage = (mem_info.used / total_memory) * 100\n",
    "        print(f\"Current GPU memory usage: {memory_usage:.2f}%\")\n",
    "\n",
    "        if memory_usage >= pause_threshold:\n",
    "            print(\"GPU memory usage is too high. Pausing execution...\")\n",
    "            while memory_usage > resume_threshold:\n",
    "                time.sleep(poll_seconds)\n",
    "                used_memory = pynvml.nvmlDeviceGetMemoryInfo(handle).used\n",
    "                memory_usage = (used_memory / total_memory) * 100\n",
    "            print(\"GPU memory usage is low enough. Resuming execution...\")\n",
    "    finally:\n",
    "        pynvml.nvmlShutdown()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def CatList(cat_, list_, type_ = \"and\"):\n",
    "        if type_ == \"and\":\n",
    "            return cat_ in list_\n",
    "        elif type_ == \"or\":\n",
    "            return cat_ not in list_\n",
    "        else:\n",
    "            return True\n",
    "def cats_used(max_cat_others, data, output_cat, min_cats = 2):\n",
    "    if max_cat_others > 0:\n",
    "        cats_values = (data[output_cat].value_counts()/data.shape[0])*100\n",
    "        sum_perc = 0\n",
    "        cats_ = []\n",
    "        for i in range(len(cats_values)):\n",
    "            if len(cats_) < min_cats or sum_perc < (100-max_cat_others):\n",
    "                cats_.append(str(cats_values.index[i]))\n",
    "                sum_perc = sum_perc + cats_values[i]\n",
    "            else:\n",
    "                break\n",
    "\n",
    "    else:\n",
    "        cats_ = list(pd.unique(data[output_cat].tolist()))\n",
    "\n",
    "    return cats_\n",
    "\n",
    "def cats_levels(data, max_cat_others, output_cat, min_cats, prev_cats=None):\n",
    "    \"\"\"Collect the categories of `output_cat` to keep, optionally per parent category.\n",
    "\n",
    "    Without `prev_cats` (None or empty) this is simply\n",
    "    `cats_used(max_cat_others, data, output_cat, min_cats)`.\n",
    "    With `prev_cats` as a mapping `{column_name: [values, ...]}`, `cats_used` is run\n",
    "    on every subset `data[data[column_name] == value]` and the results are\n",
    "    concatenated in iteration order (duplicates are kept). Each column name and\n",
    "    value is printed as it is processed.\n",
    "    \"\"\"\n",
    "    # None instead of a `{}` default avoids a mutable default shared across calls.\n",
    "    if prev_cats is None or len(prev_cats.keys()) == 0:\n",
    "        return cats_used(max_cat_others, data, output_cat, min_cats)\n",
    "\n",
    "    cats_to_use = []\n",
    "    for cat_name, cases in prev_cats.items():\n",
    "        print(cat_name)\n",
    "        for case in cases:\n",
    "            print(case)\n",
    "            data_cat = data[data[cat_name] == case]\n",
    "            cats_to_use.extend(cats_used(max_cat_others, data_cat, output_cat, min_cats))\n",
    "    return cats_to_use\n",
    "\n",
    "\n",
    "def Cats_Filter(instance, list_categories, name_prev_cat = None):\n",
    "    if name_prev_cat != None:\n",
    "        if instance in list_categories:\n",
    "            return instance\n",
    "        else:\n",
    "            return \"Other_TEIS_\"+name_prev_cat\n",
    "    else:\n",
    "        if instance in list_categories:\n",
    "            return instance\n",
    "        else:\n",
    "            return \"Other_TEIS\"\n",
    "\n",
    "def ManageTextFeature(text_):\n",
    "    if type(text_) != str:\n",
    "        return \"No valid text\"\n",
    "    else:\n",
    "        return text_\n",
    "    \n",
    "\n",
    "def classification_report_to_df(report, y_true, y_pred):\n",
    "    \"\"\"Expand a sklearn `classification_report(..., output_dict=True)` into a DataFrame.\n",
    "\n",
    "    Added columns:\n",
    "    - 'Sensitivity' (= recall), 'Specificity' (TN / (TN + FP)) and\n",
    "      'Balanced Accuracy' (their mean) for every row;\n",
    "    - 'Diff <col>' = this report minus the benchmark frame `bch_class_df`\n",
    "      (aligned on row and column labels);\n",
    "    - 'Accuracy': per class, diagonal / row sum of the confusion matrix (the class\n",
    "      recall); for the three summary rows, a copy of their f1-score;\n",
    "    - 'TP', 'FP', 'TN', 'FN' one-vs-rest counts per class (NaN for summary rows).\n",
    "\n",
    "    The 'accuracy' row is rewritten with micro-averaged precision, recall, F1,\n",
    "    sensitivity and specificity.\n",
    "\n",
    "    Reads the globals `bch_class_df`, `topic_dict` (its values give the class row\n",
    "    order) and `iteration` (when > 1, the benchmark's TP/FP/TN/FN columns are\n",
    "    dropped before differencing).\n",
    "\n",
    "    Assumes the report's last three rows are 'accuracy', 'macro avg' and\n",
    "    'weighted avg', i.e. sklearn emitted an 'accuracy' entry.\n",
    "    \"\"\"\n",
    "    global bch_class_df\n",
    "    global topic_dict\n",
    "    global iteration\n",
    "    df = pd.DataFrame(report).transpose()\n",
    "\n",
    "    order_labels = list(topic_dict.values())\n",
    "    summary_rows = ['accuracy', 'macro avg', 'weighted avg']\n",
    "\n",
    "    # Class rows = everything except the three summary rows.\n",
    "    labels = df.index[:-3]\n",
    "    cm = confusion_matrix(y_true, y_pred, labels=labels)\n",
    "\n",
    "    # One-vs-rest counts per class (cm rows = true labels, columns = predictions).\n",
    "    TP = cm.diagonal()\n",
    "    FP = cm.sum(axis=0) - TP\n",
    "    FN = cm.sum(axis=1) - TP\n",
    "    TN = cm.sum() - (FP + FN + TP)\n",
    "\n",
    "    # Micro-averaged sensitivity / specificity, used for the 'accuracy' row.\n",
    "    sens = sum(TP) / (sum(TP)+sum(FN))\n",
    "    spec = sum(TN) / (sum(TN)+sum(FP))\n",
    "\n",
    "    df['Sensitivity'] = df['recall']\n",
    "    df.loc[labels, 'Specificity'] = TN / (TN + FP)\n",
    "    df.loc['accuracy', ['Sensitivity', 'Specificity']] = sens, spec\n",
    "\n",
    "    # Macro and support-weighted averages over the class rows.\n",
    "    df.loc['macro avg', 'Sensitivity'] = df.iloc[:-3]['Sensitivity'].mean()\n",
    "    df.loc['weighted avg', 'Sensitivity'] = np.average(df.iloc[:-3]['Sensitivity'], weights=df.iloc[:-3]['support'])\n",
    "\n",
    "    df.loc['macro avg', 'Specificity'] = df.iloc[:-3]['Specificity'].mean()\n",
    "    df.loc['weighted avg', 'Specificity'] = np.average(df.iloc[:-3]['Specificity'], weights=df.iloc[:-3]['support'])\n",
    "\n",
    "    df['Balanced Accuracy'] = (df['Sensitivity'] + df['Specificity']) / 2\n",
    "\n",
    "    # The raw 'accuracy' row holds the accuracy in every column; replace with micro averages.\n",
    "    df.loc['accuracy', 'precision'] = sum(TP) / (sum(TP) + sum(FP))\n",
    "    df.loc['accuracy', 'recall'] = sum(TP) / (sum(TP) + sum(FN))\n",
    "    df.loc['accuracy', 'f1-score'] = 2* sum(TP) / (2 * sum(TP) + sum(FP) + sum(FN))\n",
    "\n",
    "    if iteration > 1:\n",
    "        bch_class_df_noFr = bch_class_df.drop(columns=['TP', 'FP', 'TN', 'FN'])\n",
    "    else:\n",
    "        bch_class_df_noFr = bch_class_df\n",
    "\n",
    "    diff_df = df - bch_class_df_noFr\n",
    "    diff_df.columns = ['Diff ' + col for col in diff_df.columns]\n",
    "\n",
    "    combined_df = pd.concat([df, diff_df], axis=1)\n",
    "\n",
    "    class_accuracy = cm.diagonal() / cm.sum(axis=1)\n",
    "    combined_df.loc[labels, 'Accuracy'] = class_accuracy\n",
    "    combined_df.loc[summary_rows, 'Accuracy'] = combined_df.loc[summary_rows, 'f1-score']\n",
    "\n",
    "    metrics_df = pd.DataFrame({\n",
    "        \"TP\": TP,\n",
    "        \"FP\": FP,\n",
    "        \"TN\": TN,\n",
    "        \"FN\": FN\n",
    "    }, index=labels)\n",
    "    combined_df = combined_df.merge(metrics_df, left_index=True, right_index=True, how='left')\n",
    "\n",
    "    # Classes in `topic_dict` order, then the summary rows. The 'accuracy' row must\n",
    "    # be kept: callers read classification_df.loc['accuracy', ...]. Summary rows\n",
    "    # already listed in `topic_dict` are not added a second time.\n",
    "    trailing_rows = [row for row in summary_rows if row not in order_labels]\n",
    "    combined_df = combined_df.reindex(order_labels + trailing_rows)\n",
    "\n",
    "    return combined_df\n",
    "\n",
    "\n",
    "def _metric_fitness_objective(classification_df):\n",
    "    \"\"\"Return the fitness tuple `(value, 0)` selected by the global `metric_name`.\n",
    "\n",
    "    \"overall_balanced_accuracy\" -> 'accuracy' row, 'Balanced Accuracy' column;\n",
    "    \"overall_f1-score\" -> 'accuracy' row, 'f1-score' column;\n",
    "    anything else -> row of the global `topic_name`, column named by `metric_name`.\n",
    "    The second objective is always 0.\n",
    "    \"\"\"\n",
    "    if metric_name == \"overall_balanced_accuracy\":\n",
    "        return (classification_df.loc['accuracy', 'Balanced Accuracy'], 0)\n",
    "    if metric_name == \"overall_f1-score\":\n",
    "        return (classification_df.loc['accuracy', 'f1-score'], 0)\n",
    "    return (classification_df.loc[topic_name, metric_name], 0)\n",
    "\n",
    "\n",
    "def fitness_evaluation(data_syn, chromosome, selected_indexes, indices_frozenset, generation, process):\n",
    "    \"\"\"Fitness of one GA individual: a selection of synthetic samples added to training.\n",
    "\n",
    "    Results are memoised in the global `history_IndexesList_dict`, keyed by\n",
    "    `indices_frozenset`. On a hit, the stored classification DataFrame is reused\n",
    "    and only this (generation, process, chromosome) occurrence is recorded.\n",
    "    On a miss:\n",
    "    - the rows of `data_syn` whose 'index_meta' is in `selected_indexes` are\n",
    "      appended to `X_train_r` / `Y_train_r`;\n",
    "    - a CatBoostClassifier is trained with the global `catboost_params`. Its wall\n",
    "      time is added to `sum_GPU_seconds`, and `GPU_limit` is set once\n",
    "      `total_gpu_seconds` is reached;\n",
    "    - the model is scored on the validation split (`X_test_re` / `Y_test_re`)\n",
    "      and on the held-out test split (`X_test_re_Test` / `Y_test_re_Test`);\n",
    "    - both evaluations are appended to `gen_eval_df` / `test_gen_eval_df` and\n",
    "      saved under `nsgaii_results_path`.\n",
    "\n",
    "    Returns:\n",
    "        `(value, 0)` chosen by the global `metric_name` from the validation\n",
    "        results, or `()` if the selection is new and the GPU budget is spent.\n",
    "    \"\"\"\n",
    "    global history_IndexesList_dict\n",
    "    global topic_name\n",
    "    global X_train_r\n",
    "    global Y_train_r\n",
    "    global X_test_re\n",
    "    global Y_test_re\n",
    "    global catboost_params\n",
    "    global nsgaii_results_path\n",
    "    global history_dict_name\n",
    "    global gen_eval_df\n",
    "    global topic_number\n",
    "    global sum_GPU_seconds\n",
    "    global total_gpu_seconds\n",
    "    global GPU_limit\n",
    "    global metric_name\n",
    "    global bch_m0\n",
    "    global X_test_re_Test\n",
    "    global Y_test_re_Test\n",
    "    global test_gen_eval_df\n",
    "    global NSGA_II_results_name\n",
    "\n",
    "    CPU_monitor_memory_usage()\n",
    "    monitor_gpu_memory()\n",
    "\n",
    "    if indices_frozenset in history_IndexesList_dict:\n",
    "        # Already evaluated: record this occurrence and reuse the stored results.\n",
    "        cached = history_IndexesList_dict[indices_frozenset]\n",
    "        cached['generation_process_chromosome'].append((generation, process, chromosome))\n",
    "        return _metric_fitness_objective(cached['classification_df'])\n",
    "\n",
    "    if GPU_limit:\n",
    "        # GPU budget spent: new selections are no longer trained.\n",
    "        return ()\n",
    "\n",
    "    feature_cols = [\"text\", \"area_TEIS\"]\n",
    "    t13_topic = 'Humanitarian aid for Ukraine.'\n",
    "\n",
    "    individual_IndexesList_dict = {}\n",
    "    individual_IndexesList_dict['generation_process_chromosome'] = []\n",
    "    filtered_syn_df = data_syn[data_syn['index_meta'].isin(selected_indexes)]\n",
    "\n",
    "    # Training data = original training split + the selected synthetic rows.\n",
    "    X_train_re = pd.concat([X_train_r, filtered_syn_df.drop(columns=['topic_name'])])\n",
    "    Y_train_re = pd.concat([Y_train_r, filtered_syn_df['topic_name']])\n",
    "\n",
    "    train_pool_re = Pool(\n",
    "        X_train_re[feature_cols],\n",
    "        Y_train_re,\n",
    "        text_features=[\"text\"],\n",
    "        cat_features=[\"area_TEIS\"]\n",
    "    )\n",
    "    valid_pool_re = Pool(\n",
    "        X_test_re[feature_cols],\n",
    "        Y_test_re,\n",
    "        text_features=[\"text\"],\n",
    "        cat_features=[\"area_TEIS\"]\n",
    "    )\n",
    "\n",
    "    # Model training, timed against the global GPU budget.\n",
    "    model_re = CatBoostClassifier(**catboost_params)\n",
    "    start_time = time.time()\n",
    "    model_re.fit(train_pool_re, eval_set=valid_pool_re)\n",
    "    training_time = time.time() - start_time\n",
    "\n",
    "    sum_GPU_seconds += training_time\n",
    "    if sum_GPU_seconds >= total_gpu_seconds:\n",
    "        GPU_limit = True\n",
    "\n",
    "    # Validation performance: the returned fitness is based on this.\n",
    "    predictions = model_re.predict(X_test_re[feature_cols])\n",
    "    accuracy = accuracy_score(Y_test_re, predictions)\n",
    "    report = classification_report(Y_test_re, predictions, digits=3, output_dict=True)\n",
    "    classification_df = classification_report_to_df(report, Y_test_re, predictions)\n",
    "\n",
    "    fitness_score = (accuracy, classification_df.loc[topic_name, 'recall'])\n",
    "\n",
    "    # Save the trained model, classification_df and fitness_score; the label and\n",
    "    # prediction entries are stored as empty placeholder lists.\n",
    "    individual_IndexesList_dict['model'] = model_re\n",
    "    individual_IndexesList_dict['true_labels'] = []\n",
    "    individual_IndexesList_dict['predicted_labels'] = []\n",
    "    individual_IndexesList_dict['classification_df'] = classification_df\n",
    "    individual_IndexesList_dict['fitness_score'] = fitness_score\n",
    "    individual_IndexesList_dict['number_of_syn_sample'] = len(filtered_syn_df)\n",
    "    individual_IndexesList_dict['retraining_time'] = training_time\n",
    "    individual_IndexesList_dict['generation_process_chromosome'].append((generation, process, chromosome))\n",
    "    history_IndexesList_dict[indices_frozenset] = individual_IndexesList_dict\n",
    "    print(fitness_score)\n",
    "\n",
    "    new_row_index = len(gen_eval_df)\n",
    "    class_DF_path = f'{nsgaii_results_path}/Class_DF'\n",
    "    os.makedirs(class_DF_path, exist_ok=True)\n",
    "    class_df_stem = f'{class_DF_path}/{topic_number}_NSGA-II_{new_row_index}_AllEval_ClassDF'\n",
    "    classification_df.to_csv(f'{class_df_stem}.csv', index=True)\n",
    "    classification_df.to_pickle(f'{class_df_stem}.pkl')\n",
    "\n",
    "    gen_eval_row = {\n",
    "        \"topic_name\": topic_name,\n",
    "        \"topic_number\": topic_number,\n",
    "        \"generation\": generation,\n",
    "        'fitness_score': fitness_score,\n",
    "        \"accuracy\": fitness_score[0],\n",
    "        \"topic_recall\": fitness_score[1],\n",
    "        'balanced_fitness_score': (classification_df.loc['accuracy', 'Balanced Accuracy'], classification_df.loc[topic_name, 'Balanced Accuracy']),\n",
    "        'overall_balanced_accuracy': classification_df.loc['accuracy', 'Balanced Accuracy'],\n",
    "        'topic_balanced_accuracy': classification_df.loc[topic_name, 'Balanced Accuracy'],\n",
    "        'balanced_acc_rec_score': (classification_df.loc[topic_name, 'Balanced Accuracy'], classification_df.loc[topic_name, 'recall']),\n",
    "        'topic_F1': classification_df.loc[topic_name, 'f1-score'],\n",
    "        'overall_F1': classification_df.loc['accuracy', 'f1-score'],\n",
    "        'overall_recall': classification_df.loc['accuracy', 'recall'],\n",
    "        \"retraining_time\": training_time,\n",
    "        \"number_of_syn_sample\": len(filtered_syn_df),\n",
    "        \"retrained_dots_list\": filtered_syn_df['index_meta'].tolist(),\n",
    "        'true_labels': [],\n",
    "        'predicted_labels': [],\n",
    "        'chromosome': chromosome,\n",
    "        'classDF_path': f'{class_df_stem}.csv',\n",
    "        'T13_TBA_Imp': classification_df.loc[t13_topic, 'Balanced Accuracy'] - bch_m0.loc[t13_topic, 'Balanced Accuracy']\n",
    "    }\n",
    "    gen_eval_df = pd.concat([gen_eval_df, pd.DataFrame([gen_eval_row])], ignore_index=True)\n",
    "    gen_eval_df.to_csv(f'{nsgaii_results_path}/GenAllEvals_{NSGA_II_results_name}.csv', index=True)\n",
    "    gen_eval_df.to_pickle(f'{nsgaii_results_path}/GenAllEvals_{NSGA_II_results_name}.pkl')\n",
    "\n",
    "    # Held-out test performance: logged only, never used for the returned fitness.\n",
    "    test_predictions = model_re.predict(X_test_re_Test[feature_cols])\n",
    "    test_report = classification_report(Y_test_re_Test, test_predictions, digits=3, output_dict=True)\n",
    "    test_classification_df = classification_report_to_df(test_report, Y_test_re_Test, test_predictions)\n",
    "\n",
    "    test_new_row_index = len(test_gen_eval_df)\n",
    "    test_class_DF_path = f'{nsgaii_results_path}/test_Class_DF'\n",
    "    os.makedirs(test_class_DF_path, exist_ok=True)\n",
    "    test_class_df_stem = f'{test_class_DF_path}/test_{topic_number}_NSGA-II_{test_new_row_index}_AllEval_ClassDF'\n",
    "    test_classification_df.to_csv(f'{test_class_df_stem}.csv', index=True)\n",
    "    test_classification_df.to_pickle(f'{test_class_df_stem}.pkl')\n",
    "\n",
    "    test_gen_eval_row = {\n",
    "        \"topic_name\": topic_name,\n",
    "        \"topic_number\": topic_number,\n",
    "        \"generation\": generation,\n",
    "        'fitness_score': (test_classification_df.loc['accuracy', 'recall'], test_classification_df.loc[topic_name, 'recall']),\n",
    "        \"accuracy\": test_classification_df.loc['accuracy', 'recall'],\n",
    "        \"topic_recall\": test_classification_df.loc[topic_name, 'recall'],\n",
    "        'balanced_fitness_score': (test_classification_df.loc['accuracy', 'Balanced Accuracy'], test_classification_df.loc[topic_name, 'Balanced Accuracy']),\n",
    "        'overall_balanced_accuracy': test_classification_df.loc['accuracy', 'Balanced Accuracy'],\n",
    "        'topic_balanced_accuracy': test_classification_df.loc[topic_name, 'Balanced Accuracy'],\n",
    "        'balanced_acc_rec_score': (test_classification_df.loc[topic_name, 'Balanced Accuracy'], test_classification_df.loc[topic_name, 'recall']),\n",
    "        'topic_F1': test_classification_df.loc[topic_name, 'f1-score'],\n",
    "        'overall_F1': test_classification_df.loc['accuracy', 'f1-score'],\n",
    "        'overall_recall': test_classification_df.loc['accuracy', 'recall'],\n",
    "        \"retraining_time\": training_time,\n",
    "        \"number_of_syn_sample\": len(filtered_syn_df),\n",
    "        \"retrained_dots_list\": filtered_syn_df['index_meta'].tolist(),\n",
    "        'true_labels': [],\n",
    "        'predicted_labels': [],\n",
    "        'chromosome': chromosome,\n",
    "        'classDF_path': f'{test_class_df_stem}.csv',\n",
    "        'T13_TBA_Imp': test_classification_df.loc[t13_topic, 'Balanced Accuracy'] - bch_m0.loc[t13_topic, 'Balanced Accuracy'],\n",
    "        # Test-split recall, like every other metric in this row.\n",
    "        'T13_TR_Imp': test_classification_df.loc[t13_topic, 'recall'] - bch_m0.loc[t13_topic, 'recall']\n",
    "    }\n",
    "    test_gen_eval_df = pd.concat([test_gen_eval_df, pd.DataFrame([test_gen_eval_row])], ignore_index=True)\n",
    "    test_gen_eval_df.to_csv(f'{nsgaii_results_path}/test_GenAllEvals_{NSGA_II_results_name}.csv', index=True)\n",
    "    test_gen_eval_df.to_pickle(f'{nsgaii_results_path}/test_GenAllEvals_{NSGA_II_results_name}.pkl')\n",
    "\n",
    "    return _metric_fitness_objective(classification_df)\n",
    "\n",
    "def polynomial_mutation(individual, mutation_rate, eta=20):\n",
    "    \"\"\"\n",
    "    Perform polynomial mutation on an individual.\n",
    "    Args:\n",
    "    - individual: A list of tuples (index, priority)\n",
    "    - mutation_rate: Probability of mutation per gene.\n",
    "    - eta: Distribution index for mutation (controls the spread).\n",
    "    \"\"\"\n",
    "    mutated_individual = []\n",
    "    for gene in individual:\n",
    "        if random.random() < mutation_rate:\n",
    "            u = random.random()\n",
    "            delta = 0\n",
    "            if u < 0.5:\n",
    "                delta = (2*u)**(1/(eta+1)) - 1\n",
    "            else:\n",
    "                delta = 1 - (2*(1 - u))**(1/(eta+1))\n",
    "\n",
    "            # Mutate the priority value, ensuring it remains an integer within bounds\n",
    "            min_priority, max_priority = 1, len(individual)  # Assuming priority bounds\n",
    "            new_priority = int(min(max_priority, max(min_priority, gene[1] + delta * (max_priority - min_priority))))\n",
    "            mutated_individual.append((gene[0], new_priority))\n",
    "        else:\n",
    "            mutated_individual.append(gene)\n",
    "    return mutated_individual\n",
    "\n",
    "def adaptive_polynomial_mutation(individual, generation, max_generations, initial_rate=0.1):\n",
    "    \"\"\"\n",
    "    Polynomial mutation whose per-gene rate decays linearly over the run.\n",
    "    Args:\n",
    "    - individual: the individual to mutate (list of (index, priority) tuples).\n",
    "    - generation: current generation number.\n",
    "    - max_generations: planned number of generations.\n",
    "    - initial_rate: mutation rate at generation 0.\n",
    "\n",
    "    The rate is initial_rate * (1 - generation / max_generations); it reaches 0 at\n",
    "    generation == max_generations and is negative (no gene mutates) beyond it.\n",
    "    \"\"\"\n",
    "    decayed_rate = initial_rate * (1 - generation / max_generations)\n",
    "    return polynomial_mutation(individual, decayed_rate)\n",
    "\n",
    "def dominates(score1, score2):\n",
    "    return (score1[0] > score2[0] and score1[1] >= score2[1]) or (score1[0] >= score2[0] and score1[1] > score2[1])\n",
    "\n",
    "def all_pareto_observer(history_pareto_selections_list):\n",
    "    \"\"\"Report, save and plot the Pareto front of all selections evaluated so far.\n",
    "\n",
    "    Each entry of `history_pareto_selections_list` is used as\n",
    "    `(selected_indexes, chromosome, fitness)`, with `fitness` a pair\n",
    "    (overall accuracy, topic recall) that is maximised. Entries whose fitness is\n",
    "    empty or contains None are ignored: they cannot be compared.\n",
    "\n",
    "    Side effects, all under the global `nsgaii_results_path`:\n",
    "    - the non-dominated selections (details looked up in the global\n",
    "      `history_IndexesList_dict`) are appended to the global `fold_pfs_df` and\n",
    "      saved as TopicPFs_<NSGA_II_results_name>.csv/.pkl;\n",
    "    - a summary row is appended to the global `gen_stats_df` and saved as\n",
    "      <gen_stats_df_name>.csv/.pkl;\n",
    "    - a scatter plot of all fitnesses, with the front in red, is saved to\n",
    "      <gen_stats_df_name>/All Pareto Selections in History.png and shown.\n",
    "\n",
    "    Returns the updated `gen_stats_df`, or None if there is nothing to report.\n",
    "    \"\"\"\n",
    "    global topic_name\n",
    "    global topic_number\n",
    "    global nsgaii_results_path\n",
    "    global gen_stats_df_name\n",
    "    global gen_stats_df\n",
    "    global NSGA_II_results_name\n",
    "    global fold_pfs_df\n",
    "    global history_IndexesList_dict\n",
    "\n",
    "    def _has_fitness(fitness):\n",
    "        # Empty fitnesses or ones with a None objective would crash min()/dominates().\n",
    "        return fitness is not None and len(fitness) >= 2 and fitness[0] is not None and fitness[1] is not None\n",
    "\n",
    "    evaluated = [sel for sel in history_pareto_selections_list if _has_fitness(sel[2])]\n",
    "    objective_1_values = [sel[2][0] for sel in evaluated]\n",
    "    objective_2_values = [sel[2][1] for sel in evaluated]\n",
    "\n",
    "    # Check if both lists are empty or filled with zeros\n",
    "    if not objective_1_values or not objective_2_values or all(value == 0 for value in objective_1_values + objective_2_values):\n",
    "        print(\"All objectives are None or 0, skipping further processes.\")\n",
    "        return  # Exit the function\n",
    "\n",
    "    # Non-dominated selections, kept in their original order.\n",
    "    pareto_selections_tuples = [\n",
    "        sel for i1, sel in enumerate(evaluated)\n",
    "        if not any(i1 != i2 and dominates(other[2], sel[2]) for i2, other in enumerate(evaluated))\n",
    "    ]\n",
    "    pareto_fitness_tuples = [(sel[2][0], sel[2][1]) for sel in pareto_selections_tuples]\n",
    "\n",
    "    # Calculate statistics\n",
    "    worst_fitness = (min(objective_1_values), min(objective_2_values))\n",
    "    best_fitness = (max(objective_1_values), max(objective_2_values))\n",
    "    mean_OverallAcc = np.mean(objective_1_values)\n",
    "    mean_ClassRecall = np.mean(objective_2_values)\n",
    "\n",
    "    print(f\"All Pareto Selections, Evaluations: {len(history_pareto_selections_list)}\")\n",
    "    print(f\"Best Fitness: {best_fitness}\")\n",
    "    print(f\"Worst Fitness: {worst_fitness}\")\n",
    "    print(f\"Mean Overall Accuracy: {mean_OverallAcc}\")\n",
    "    print(f\"Mean {topic_name} Recall: {mean_ClassRecall}\")\n",
    "\n",
    "    print(\"Pareto Front Selections:---------------------\")\n",
    "    new_PFs_rows = []\n",
    "    for i, sel in enumerate(pareto_selections_tuples):\n",
    "        print(f\"Selection {i+1}: {sel[0]} \\nFitness: {sel[2]}\")\n",
    "        stored = history_IndexesList_dict[frozenset(sel[0])]\n",
    "        new_PFs_rows.append({\n",
    "            \"topic_name\": topic_name,\n",
    "            \"topic_number\": topic_number,\n",
    "            \"generation\": \"Final_Evaluate\",\n",
    "            'fitness_score': sel[2],\n",
    "            \"accuracy\": sel[2][0],\n",
    "            \"topic_recall\": sel[2][1],\n",
    "            \"retraining_time\": stored['retraining_time'],\n",
    "            \"number_of_syn_sample\": len(sel[0]),\n",
    "            \"retrained_dots_list\": sel[0],\n",
    "            'true_labels': stored['true_labels'],\n",
    "            'predicted_labels': stored['predicted_labels'],\n",
    "            'chromosome': sel[1]\n",
    "        })\n",
    "        print('---')\n",
    "    print('------------------------')\n",
    "\n",
    "    # Append all front rows in one concat and save once, instead of once per row.\n",
    "    if new_PFs_rows:\n",
    "        fold_pfs_df = pd.concat([fold_pfs_df, pd.DataFrame(new_PFs_rows)], ignore_index=True)\n",
    "        fold_pfs_df.to_csv(f'{nsgaii_results_path}/TopicPFs_{NSGA_II_results_name}.csv', index=False)\n",
    "        fold_pfs_df.to_pickle(f'{nsgaii_results_path}/TopicPFs_{NSGA_II_results_name}.pkl')\n",
    "\n",
    "    recall_key = f\"Mean {topic_name} Recall\"\n",
    "    new_row = {\n",
    "        \"Generation\": \"All Pareto Selections in History\",\n",
    "        \"Number of Evaluations\": len(history_pareto_selections_list),\n",
    "        \"Best Fitness\": best_fitness,\n",
    "        \"Worst Fitness\": worst_fitness,\n",
    "        \"Mean Overall Accuracy\": mean_OverallAcc,\n",
    "        recall_key: mean_ClassRecall,\n",
    "        \"Pareto Front Selections\": [pareto_fitness_tuples, pareto_selections_tuples]\n",
    "    }\n",
    "    gen_stats_df = pd.concat([gen_stats_df, pd.DataFrame([new_row])], ignore_index=True)\n",
    "    gen_stats_df.to_csv(f'{nsgaii_results_path}/{gen_stats_df_name}.csv', index=False)\n",
    "    gen_stats_df.to_pickle(f'{nsgaii_results_path}/{gen_stats_df_name}.pkl')\n",
    "\n",
    "    # Plotting\n",
    "    fig, ax = plt.subplots(figsize=(12, 7))\n",
    "    ax.scatter(objective_1_values, objective_2_values, c='blue', alpha=0.5, label='Population')\n",
    "    ax.scatter([ft[0] for ft in pareto_fitness_tuples], [ft[1] for ft in pareto_fitness_tuples], c='red', alpha=0.9, label='Pareto Front')\n",
    "    ax.set_xlabel('Objective 1: Overall Accuracy')\n",
    "    ax.set_ylabel(f'Objective 2: {topic_name} Recall')\n",
    "    ax.set_title(f'{topic_number}, Population and Pareto Front for All Pareto Selections in History')\n",
    "    ax.legend()\n",
    "    pareto_plots_dir = f\"{nsgaii_results_path}/{gen_stats_df_name}\"\n",
    "    os.makedirs(pareto_plots_dir, exist_ok=True)\n",
    "    fig.savefig(os.path.join(pareto_plots_dir, \"All Pareto Selections in History.png\"), dpi=200, bbox_inches='tight', pad_inches=0)\n",
    "    plt.show()\n",
    "    # Release the figure so repeated observer calls do not accumulate open figures.\n",
    "    plt.close(fig)\n",
    "\n",
    "    return gen_stats_df\n",
    "\n",
    "\n",
    "def select_string_based_on_probabilities_no_duplicates(nested_list, probabilities, rnd, selected_set):\n",
    "    \"\"\"\n",
    "    Select a string from a list of lists based on given probabilities for each inner list, ensuring no duplicates.\n",
    "\n",
    "    An inner list is drawn with probability proportional to its weight, but only among the\n",
    "    inner lists that still hold a string not in `selected_set`; a string is then drawn\n",
    "    uniformly from that list's remaining strings and recorded in `selected_set`.\n",
    "\n",
    "    Restricting the weighted draw to non-exhausted lists has the same distribution as\n",
    "    re-drawing until a non-exhausted list is hit, but it always terminates. If every list\n",
    "    with a positive weight is exhausted, one of the remaining (zero-weight) lists is picked\n",
    "    uniformly, so a caller that asks for every string in turn still receives all of them.\n",
    "\n",
    "    :param nested_list: List of lists containing strings\n",
    "    :param probabilities: Non-negative weight for each inner list (same length as nested_list; need not sum to 1)\n",
    "    :param rnd: Random instance to use for selection\n",
    "    :param selected_set: Set of already selected strings to avoid duplicates (updated in place)\n",
    "    :return: Selected string\n",
    "    :raises ValueError: if the two lists differ in length or every string is already selected\n",
    "    \"\"\"\n",
    "    if len(nested_list) != len(probabilities):\n",
    "        raise ValueError(\"nested_list and probabilities must have the same length\")\n",
    "\n",
    "    # Unselected strings of every inner list that still has some, paired with its weight.\n",
    "    remaining = []\n",
    "    for inner_list, weight in zip(nested_list, probabilities):\n",
    "        available_choices = [s for s in inner_list if s not in selected_set]\n",
    "        if available_choices:\n",
    "            remaining.append((available_choices, weight))\n",
    "    if not remaining:\n",
    "        raise ValueError(\"All strings have already been selected\")\n",
    "\n",
    "    weighted = [(choices, weight) for choices, weight in remaining if weight > 0]\n",
    "    if weighted:\n",
    "        chosen_list = rnd.choices([choices for choices, _ in weighted], weights=[weight for _, weight in weighted], k=1)[0]\n",
    "    else:\n",
    "        # Only zero-weight lists still have unselected strings: fall back to a uniform pick.\n",
    "        chosen_list = rnd.choice([choices for choices, _ in remaining])\n",
    "\n",
    "    # Select a string from the chosen inner list uniformly\n",
    "    selected_string = rnd.choice(chosen_list)\n",
    "    selected_set.add(selected_string)\n",
    "    return selected_string\n",
    "def move_to_first_list(nested_list, element):\n",
    "    \"\"\"Move `element` into the first inner list of `nested_list` (in place) and return `nested_list`.\n",
    "\n",
    "    One occurrence is removed from the first inner list that contains it (if any), then\n",
    "    the element is appended to the end of nested_list[0]. An element that is already in\n",
    "    the first list therefore moves to its end, and an absent element is simply added.\n",
    "    \"\"\"\n",
    "    owner = next((group for group in nested_list if element in group), None)\n",
    "    if owner is not None:\n",
    "        owner.remove(element)\n",
    "    nested_list[0].append(element)\n",
    "    return nested_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def is_dominated(score1, score2):\n",
    "    \"\"\"Return True if `score1` Pareto-dominates `score2`, with every objective maximised.\n",
    "\n",
    "    Despite the name, the first argument is the potential dominator: callers ask\n",
    "    `is_dominated(other, point)` to test whether `point` is dominated by `other`.\n",
    "    score1 dominates score2 when it is >= on every objective and > on at least one.\n",
    "    Any number of objectives is supported (for two it matches the former pairwise\n",
    "    test); if the lengths differ only the leading, shared objectives are compared.\n",
    "    \"\"\"\n",
    "    pairs = list(zip(score1, score2))\n",
    "    return all(a >= b for a, b in pairs) and any(a > b for a, b in pairs)\n",
    "\n",
    "def identify_pareto_fronts(population):\n",
    "    \"\"\"Split objective vectors into successive Pareto fronts (non-dominated sorting).\n",
    "\n",
    "    Repeatedly extracts every vector that no remaining vector dominates (per\n",
    "    is_dominated) until nothing is left.\n",
    "\n",
    "    :param population: sequence of objective vectors; converted with np.array\n",
    "    :return: list of fronts, each a list of row arrays in their original order\n",
    "             (duplicate vectors are kept and land in the same front)\n",
    "    \"\"\"\n",
    "    remaining = np.array(population)\n",
    "    fronts = []\n",
    "    while len(remaining) > 0:\n",
    "        # A row is on the current front when no remaining row dominates it.\n",
    "        on_front = np.array(\n",
    "            [not any(is_dominated(other, row) for other in remaining) for row in remaining],\n",
    "            dtype=bool,\n",
    "        )\n",
    "        fronts.append([row for row, keep in zip(remaining, on_front) if keep])\n",
    "        remaining = remaining[~on_front]\n",
    "    return fronts\n",
    "\n",
    "def plot_pareto_fronts(fronts, ax=None):\n",
    "    \"\"\"Scatter-plot each Pareto front in its own colour (colours repeat after six fronts).\n",
    "\n",
    "    :param fronts: list of fronts, each a list of points whose first two entries are\n",
    "                   plotted as (objective 1, objective 2)\n",
    "    :param ax: optional matplotlib Axes to draw into. When omitted, a new figure is\n",
    "               created and shown, rather than drawing onto whatever pyplot figure is current.\n",
    "    :return: the Axes that was drawn on\n",
    "    \"\"\"\n",
    "    colors = ['r', 'g', 'b', 'y', 'c', 'm']\n",
    "    owns_figure = ax is None\n",
    "    if owns_figure:\n",
    "        _, ax = plt.subplots()\n",
    "\n",
    "    drew_any = False\n",
    "    for i, front in enumerate(fronts):\n",
    "        points = np.array(front)\n",
    "        if points.size == 0:\n",
    "            continue  # nothing to draw, and 2-D indexing would fail on an empty array\n",
    "        ax.scatter(points[:, 0], points[:, 1], color=colors[i % len(colors)], alpha=0.5, label=f'Front {i+1}')\n",
    "        drew_any = True\n",
    "\n",
    "    ax.set_xlabel('Objective 1')\n",
    "    ax.set_ylabel('Objective 2')\n",
    "    ax.set_title('Pareto fronts')\n",
    "    if drew_any:\n",
    "        ax.legend()\n",
    "    if owns_figure:\n",
    "        plt.show()\n",
    "    return ax\n",
    "\n",
    "def relaxed_chi_square_test(data, bins=10, tolerance=0, alpha=0.025):\n",
    "    \"\"\"Chi-square goodness-of-fit test of `data` against equal counts per histogram bin.\n",
    "\n",
    "    The data are histogrammed into `bins` equal-width bins spanning their range, and the\n",
    "    observed counts are compared with len(data) / bins expected counts per bin.\n",
    "\n",
    "    :param data: 1-D numeric sample\n",
    "    :param bins: number of histogram bins\n",
    "    :param tolerance: added to `alpha`; a positive value raises the p-value threshold\n",
    "    :param alpha: base significance level (formerly hard-coded as 0.025)\n",
    "    :return: (chi-square statistic, p-value, p_value > alpha + tolerance)\n",
    "    \"\"\"\n",
    "    observed_freq, _ = np.histogram(data, bins=bins)\n",
    "    expected_freq = [len(data) / bins] * bins\n",
    "\n",
    "    chi_square_stat, p_value = chisquare(observed_freq, expected_freq)\n",
    "\n",
    "    relaxed_threshold = alpha + tolerance\n",
    "\n",
    "    return chi_square_stat, p_value, p_value > relaxed_threshold\n",
    "\n",
    "def relaxed_ks_test(data, tolerance=0, alpha=0.025):\n",
    "    \"\"\"Kolmogorov-Smirnov test of `data` against a uniform distribution on [min(data), max(data)].\n",
    "\n",
    "    :param data: 1-D numeric sample\n",
    "    :param tolerance: added to `alpha`; a positive value raises the p-value threshold\n",
    "    :param alpha: base significance level (formerly hard-coded as 0.025)\n",
    "    :return: (KS statistic, p-value, p_value > alpha + tolerance)\n",
    "    \"\"\"\n",
    "    # The reference uniform is fitted to the sample itself (loc=min, scale=range).\n",
    "    d_stat, p_value = kstest(data, 'uniform', args=(np.min(data), np.ptp(data)))\n",
    "\n",
    "    relaxed_threshold = alpha + tolerance\n",
    "\n",
    "    return d_stat, p_value, p_value > relaxed_threshold\n",
    "\n",
    "def check_uniform_distribution_with_tolerance(data, chi_tolerance=0, ks_tolerance=0):\n",
    "    \"\"\"Run the chi-square and KS uniformity tests on `data` and combine their verdicts.\n",
    "\n",
    "    :param data: 1-D numeric sample\n",
    "    :param chi_tolerance: tolerance forwarded to relaxed_chi_square_test\n",
    "    :param ks_tolerance: tolerance forwarded to relaxed_ks_test\n",
    "    :return: dict with 'chi_square_stat', 'chi_p_value', 'ks_stat', 'ks_p_value' and\n",
    "             'is_uniform', which is true only when both tests accept uniformity\n",
    "    \"\"\"\n",
    "    chi_stat, chi_p, chi_accepts = relaxed_chi_square_test(data, tolerance=chi_tolerance)\n",
    "    ks_stat, ks_p, ks_accepts = relaxed_ks_test(data, tolerance=ks_tolerance)\n",
    "    return dict(\n",
    "        chi_square_stat=chi_stat,\n",
    "        chi_p_value=chi_p,\n",
    "        ks_stat=ks_stat,\n",
    "        ks_p_value=ks_p,\n",
    "        is_uniform=chi_accepts and ks_accepts,\n",
    "    )\n",
    "\n",
    "def determine_distribution_type(front_sizes):\n",
    "    \"\"\"Classify the profile of Pareto-front sizes with a heuristic.\n",
    "\n",
    "    :param front_sizes: sizes of the successive Pareto fronts, first front first\n",
    "    :return: 'Uniform' if both uniformity tests accept the sizes; otherwise\n",
    "             'Reversed Quadratic' if the largest size is in one of the last two fronts\n",
    "             or the smallest is in the first front (and the first and last sizes differ);\n",
    "             otherwise 'Quadratic' if the largest size is in the first front, or there are\n",
    "             at most three fronts and each is larger than the number of fronts;\n",
    "             otherwise 'Random'\n",
    "    \"\"\"\n",
    "    sizes_array = np.array(front_sizes).flatten()\n",
    "    if check_uniform_distribution_with_tolerance(sizes_array)['is_uniform']:\n",
    "        return 'Uniform'\n",
    "\n",
    "    largest, smallest = max(front_sizes), min(front_sizes)\n",
    "    first, last = front_sizes[0], front_sizes[-1]\n",
    "    # front_sizes[-2] is only read when the last front is not the largest, which\n",
    "    # implies there are at least two fronts.\n",
    "    grows_towards_end = largest == last or largest == front_sizes[-2] or smallest == first\n",
    "    if grows_towards_end and last != first:\n",
    "        return 'Reversed Quadratic'\n",
    "\n",
    "    n_fronts = len(front_sizes)\n",
    "    if largest == first or (n_fronts <= 3 and all(size > n_fronts for size in front_sizes)):\n",
    "        return 'Quadratic'\n",
    "    return 'Random'\n",
    "\n",
    "def generate_tournament_size(S, distribution_type, rng):\n",
    "    \"\"\"Draw a tournament size from a distribution-type-specific discrete distribution.\n",
    "\n",
    "    Candidate sizes and their (unnormalised) weights:\n",
    "      - 'Uniform' / 'Quadratic': sizes 2..S; sizes below 4 weigh 0.6 / min(4, S-1),\n",
    "        larger ones 0.4 / max(1, S-5).\n",
    "      - 'Reversed Quadratic': sizes 2..max(S, 7)-1; with p = max(S-1, 7)//2, sizes >= p\n",
    "        weigh 0.4 and smaller ones 0.6 (both divided by max(1, p-1)).\n",
    "      - 'Random': sizes 2..max(S, 7)-1; with p = max(S//2, 4), sizes < p weigh 0.6 and\n",
    "        the rest 0.4 (both divided by max(1, p-1)).\n",
    "    The weights are normalised to sum to 1 before sampling.\n",
    "\n",
    "    NOTE(review): these are per-size weights, so the 0.6/0.4 split is not the total mass\n",
    "    of each group (e.g. only sizes 2 and 3 are below 4); confirm this is intended.\n",
    "\n",
    "    :param S: bound used to build the candidate sizes\n",
    "    :param distribution_type: 'Uniform', 'Quadratic', 'Reversed Quadratic' or 'Random'\n",
    "    :param rng: numpy Generator used for sampling\n",
    "    :return: the sampled size. For 'Uniform'/'Quadratic' with S < 2 there is no valid size,\n",
    "             so S itself is returned (callers treat a size below 2 as \"stop selecting\")\n",
    "             instead of raising inside rng.choice.\n",
    "    :raises ValueError: for an unknown distribution_type\n",
    "    \"\"\"\n",
    "    if distribution_type in ('Uniform', 'Quadratic'):\n",
    "        if S < 2:\n",
    "            return S\n",
    "        sizes = range(2, S + 1)\n",
    "        probabilities = [0.6 / min(4, (S - 1)) if i < 4 else 0.4 / max(1, (S - 5)) for i in sizes]\n",
    "    elif distribution_type == 'Reversed Quadratic':\n",
    "        pivot = max(S - 1, 7) // 2\n",
    "        sizes = range(2, max(S, 7))\n",
    "        probabilities = [0.4 / max(1, (pivot - 1)) if i >= pivot else 0.6 / max(1, (pivot - 1)) for i in sizes]\n",
    "    elif distribution_type == 'Random':\n",
    "        pivot = max(S // 2, 4)\n",
    "        sizes = range(2, max(S, 7))\n",
    "        probabilities = [0.6 / max(1, (pivot - 1)) if i < pivot else 0.4 / max(1, (pivot - 1)) for i in sizes]\n",
    "    else:\n",
    "        raise ValueError(f\"Unknown distribution_type: {distribution_type!r}\")\n",
    "\n",
    "    # Normalize probabilities to sum to 1\n",
    "    total = sum(probabilities)\n",
    "    probabilities = [p / total for p in probabilities]\n",
    "\n",
    "    return rng.choice(sizes, p=probabilities)\n",
    "\n",
    "def cluster_population(population, fronts, distribution_type, S):\n",
    "    \"\"\"Group population indices into the clusters used by the cluster tournament.\n",
    "\n",
    "    - 'Uniform' / 'Quadratic': one cluster per distinct fitness value, keyed by that value.\n",
    "    - 'Reversed Quadratic': fronts 1..S-1 form clusters 1..S-1 and all later fronts are\n",
    "      merged into cluster S.\n",
    "    - 'Random': front k forms cluster k, for k = 1..len(fronts).\n",
    "    Any other distribution_type yields an empty dict.\n",
    "\n",
    "    Cluster members are always indices into `population`, because callers pick\n",
    "    `population[i]` from a cluster and filter clusters by population index. Front\n",
    "    membership is recovered by matching each individual's fitness values against the\n",
    "    points stored in `fronts`; identical fitness vectors always share a front, so the\n",
    "    match is unambiguous. (Positions *within* a front would point at unrelated\n",
    "    individuals.)\n",
    "\n",
    "    :param population: individuals exposing a `.fitness` attribute\n",
    "    :param fronts: Pareto fronts of the population's fitness values, as returned by\n",
    "                   identify_pareto_fronts([ind.fitness for ind in population])\n",
    "    :param distribution_type: 'Uniform', 'Quadratic', 'Reversed Quadratic' or 'Random'\n",
    "    :param S: number of clusters for 'Reversed Quadratic'\n",
    "    :return: dict mapping cluster key -> list of population indices\n",
    "    :raises KeyError: if an individual's fitness does not appear in `fronts`\n",
    "    \"\"\"\n",
    "    clusters = {}\n",
    "\n",
    "    if distribution_type in ['Uniform', 'Quadratic']:\n",
    "        for i, ind in enumerate(population):\n",
    "            clusters.setdefault(ind.fitness, []).append(i)\n",
    "        return clusters\n",
    "\n",
    "    if distribution_type not in ['Reversed Quadratic', 'Random']:\n",
    "        return clusters\n",
    "\n",
    "    def _fitness_key(values):\n",
    "        # Normalise a fitness vector (tuple, list, array row, ...) into a hashable key.\n",
    "        return tuple(np.asarray(values, dtype=float).ravel().tolist())\n",
    "\n",
    "    front_rank = {}\n",
    "    for rank, front in enumerate(fronts):\n",
    "        for point in front:\n",
    "            front_rank[_fitness_key(point)] = rank\n",
    "    ranks = [front_rank[_fitness_key(ind.fitness)] for ind in population]\n",
    "\n",
    "    if distribution_type == 'Reversed Quadratic':\n",
    "        for key in range(1, S):\n",
    "            clusters[key] = []\n",
    "        clusters[S] = []\n",
    "        for i, rank in enumerate(ranks):\n",
    "            clusters[rank + 1 if rank < S - 1 else S].append(i)\n",
    "    else:  # 'Random'\n",
    "        for key in range(1, len(fronts) + 1):\n",
    "            clusters[key] = []\n",
    "        for i, rank in enumerate(ranks):\n",
    "            clusters[rank + 1].append(i)\n",
    "\n",
    "    return clusters\n",
    "\n",
    "def population_observer(population, prng):\n",
    "    \"\"\"Analyse the population's Pareto structure and derive tournament settings.\n",
    "\n",
    "    Builds the Pareto fronts of the individuals' fitness values and classifies the\n",
    "    front-size profile. It then chooses S:\n",
    "      - 'Uniform'/'Quadratic': the number of distinct fitness values;\n",
    "      - 'Reversed Quadratic': ceil(#fronts / 2);\n",
    "      - 'Random': #fronts.\n",
    "    A tournament size is drawn from a numpy Generator re-seeded with 42 on every call,\n",
    "    and the population is clustered. Side effects: plots the fronts and prints the\n",
    "    distribution type.\n",
    "\n",
    "    :param population: individuals exposing a hashable `.fitness`\n",
    "    :param prng: not used\n",
    "    :return: dict with 'fronts', 'front_sizes', 'distribution_type', 'S',\n",
    "             'tournament_size' and 'clusters'\n",
    "    \"\"\"\n",
    "    fronts = identify_pareto_fronts([member.fitness for member in population])\n",
    "    front_sizes = list(map(len, fronts))\n",
    "    distribution_type = determine_distribution_type(front_sizes)\n",
    "\n",
    "    if distribution_type in ('Uniform', 'Quadratic'):\n",
    "        S = len({member.fitness for member in population})\n",
    "    elif distribution_type == 'Reversed Quadratic':\n",
    "        S = math.ceil(len(fronts) / 2)\n",
    "    elif distribution_type == 'Random':\n",
    "        S = len(fronts)\n",
    "\n",
    "    # Fresh fixed-seed generator: the same S and type always give the same size.\n",
    "    tournament_size = generate_tournament_size(S, distribution_type, np.random.default_rng(seed=42))\n",
    "    clusters = cluster_population(population, fronts, distribution_type, S)\n",
    "\n",
    "    plot_pareto_fronts(fronts)\n",
    "    print(distribution_type)\n",
    "\n",
    "    return {\n",
    "        'fronts': fronts,\n",
    "        'front_sizes': front_sizes,\n",
    "        'distribution_type': distribution_type,\n",
    "        'S': S,\n",
    "        'tournament_size': tournament_size,\n",
    "        'clusters': clusters\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def encode_index_meta_pool(indices, seed=None):\n",
    "    \"\"\"Encode `indices` as a chromosome of (index, priority) pairs.\n",
    "\n",
    "    Each priority is drawn uniformly from 1..len(indices) (inclusive).\n",
    "\n",
    "    :param indices: sequence of indices to encode; order is preserved\n",
    "    :param seed: optional seed. When given, priorities come from a private\n",
    "                 random.Random(seed). This yields the same values as re-seeding the\n",
    "                 module-level generator did, but without resetting the shared `random`\n",
    "                 state for the rest of the notebook. When None, the module-level\n",
    "                 `random` generator is used.\n",
    "    :return: list of (index, priority) tuples\n",
    "    \"\"\"\n",
    "    generator = random if seed is None else random.Random(seed)\n",
    "    upper = len(indices)\n",
    "    return [(index, generator.randint(1, upper)) for index in indices]\n",
    "\n",
    "def decode_chromosome(chromosome):\n",
    "    \"\"\"Decode a chromosome of (index, priority) pairs into the selected indices.\n",
    "\n",
    "    The number of indices kept is the mean priority rounded half up (traditional\n",
    "    rounding). Python's built-in round() rounds halves to even, so a mean of 2.5\n",
    "    would give 2 instead of 3. The indices with the lowest priorities are kept, and\n",
    "    ties keep chromosome order because sorted() is stable.\n",
    "\n",
    "    :param chromosome: sequence of (index, priority) pairs with non-negative priorities\n",
    "    :return: list of selected indices; empty for an empty chromosome\n",
    "    \"\"\"\n",
    "    if not chromosome:\n",
    "        return []\n",
    "    # Round half up: floor(mean + 0.5).\n",
    "    mean_priority = math.floor(sum(x[1] for x in chromosome) / len(chromosome) + 0.5)\n",
    "\n",
    "    # Sort by priority and select indices\n",
    "    sorted_chromosome = sorted(chromosome, key=lambda x: x[1])\n",
    "    selected_indices = [x[0] for x in sorted_chromosome[:mean_priority]]\n",
    "    return selected_indices\n",
    "\n",
    "\n",
    "def custom_generator(random, args):\n",
    "    \"\"\"inspyred generator: build one initial individual from the class index meta pool.\n",
    "\n",
    "    Encodes args[\"args\"][\"class_index_meta_pool\"] with the global `init_counter` as\n",
    "    seed, wraps the chromosome in an inspyred Individual, then increments\n",
    "    `init_counter` so the next call uses a different seed.\n",
    "\n",
    "    :param random: unused\n",
    "    :param args: keyword-argument dict whose \"args\" entry holds \"class_index_meta_pool\"\n",
    "    :return: inspyred.ec.Individual wrapping the encoded chromosome\n",
    "    \"\"\"\n",
    "    global init_counter\n",
    "    meta_pool = args.get(\"args\").get(\"class_index_meta_pool\")\n",
    "\n",
    "    chromosome = encode_index_meta_pool(meta_pool, seed=init_counter)\n",
    "    new_individual = inspyred.ec.Individual(chromosome)\n",
    "    init_counter += 1\n",
    "    return new_individual\n",
    "\n",
    "\n",
    "def evaluate(candidates, args):\n",
    "    \"\"\"inspyred evaluator: score each candidate with fitness_evaluation.\n",
    "\n",
    "    Increments the global `generation_counter` once per call. The global `GPU_limit`\n",
    "    is re-checked before every candidate; once it equals True, the scores gathered so\n",
    "    far are returned. Candidates wrapped in (possibly nested) inspyred Individuals are\n",
    "    unwrapped, at most 10 levels deep, before decoding.\n",
    "\n",
    "    :param candidates: chromosomes, or Individuals wrapping them\n",
    "    :param args: keyword-argument dict whose \"args\" entry holds \"data_syn\"\n",
    "    :return: list of fitness scores, one per evaluated candidate\n",
    "    \"\"\"\n",
    "    global generation_counter, GPU_limit\n",
    "    generation_counter += 1\n",
    "\n",
    "    data_syn = args.get(\"args\").get(\"data_syn\")\n",
    "\n",
    "    def unwrap(item, max_depth=10):\n",
    "        # Peel off nested Individual wrappers; the bound prevents infinite loops.\n",
    "        for _ in range(max_depth):\n",
    "            if not isinstance(item, inspyred.ec.Individual):\n",
    "                break\n",
    "            item = item.candidate\n",
    "        return item\n",
    "\n",
    "    scores = []\n",
    "    for candidate in candidates:\n",
    "        if GPU_limit == True:\n",
    "            break\n",
    "        chromosome = unwrap(candidate)\n",
    "        selected_indexes = decode_chromosome(chromosome)\n",
    "        indices_frozenset = frozenset(selected_indexes)\n",
    "        scores.append(fitness_evaluation(data_syn, chromosome, selected_indexes, indices_frozenset, generation_counter, 'Evaluation'))\n",
    "    return scores\n",
    "\n",
    "\n",
    "def lexicase_elitist_with_cluster_tournament_selection(random, population, args):\n",
    "    \"\"\"Select a mating pool by elitism, lexicase selection and cluster tournaments.\n",
    "\n",
    "    1. Elitism: the individuals with the best overall accuracy (fitness[0]) and the\n",
    "       best class recall (fitness[1]) are always selected.\n",
    "    2. Lexicase: while the pool is smaller than (num_selected + #elites) / 2, per-class\n",
    "       recall cases are drawn in a random, group-weighted order and used to filter the\n",
    "       population; the survivors are selected.\n",
    "    3. Cluster tournament: the pool is topped up to num_selected with tournaments over\n",
    "       the clusters produced by population_observer.\n",
    "    Finally every selected individual is passed to fitness_evaluation tagged 'Selected'.\n",
    "\n",
    "    When Smode != \"Sovl\", an individual is not selected twice during the lexicase phase,\n",
    "    and already-selected individuals are removed from the clusters before each tournament.\n",
    "\n",
    "    :param random: random.Random-like generator (its sample/choice/choices are used)\n",
    "    :param population: current population of inspyred Individuals\n",
    "    :param args: keyword-argument dict whose \"args\" entry holds \"num_selected\" and \"data_syn\"\n",
    "    :return: list of selected individuals; empty when GPU_limit is set\n",
    "    \"\"\"\n",
    "    global history_IndexesList_dict\n",
    "    global generation_counter\n",
    "    global Smode\n",
    "    global topic_number\n",
    "    global topic_name\n",
    "    global topic_name_cases_order\n",
    "    global topic_group_probabilities\n",
    "    global topic_name_sizes_order\n",
    "    global GPU_limit\n",
    "\n",
    "    if GPU_limit == True:\n",
    "        return []\n",
    "\n",
    "    topic_name_cases_order = move_to_first_list(topic_name_cases_order, topic_name)\n",
    "    topic_name_sizes_order = move_to_first_list(topic_name_sizes_order, topic_name)\n",
    "\n",
    "    inner_args = args.get(\"args\")  # Retrieve the nested dictionary\n",
    "    num_selected = inner_args.get(\"num_selected\")\n",
    "    data_syn = inner_args.get(\"data_syn\")\n",
    "    flattened_length = sum(len(sublist) for sublist in topic_name_sizes_order)\n",
    "\n",
    "    # Call population_observer to analyze the population\n",
    "    result = population_observer(population, random)\n",
    "\n",
    "    # Retrieve the clusters from the result\n",
    "    clusters = result['clusters']\n",
    "\n",
    "    fitness_scores = []\n",
    "    chromosomes_indexes_pair = {}\n",
    "    for pop_i, individual in enumerate(population):\n",
    "        max_depth = 10  # Prevent infinite loops by setting a maximum depth\n",
    "        depth = 0\n",
    "        while isinstance(individual, inspyred.ec.Individual) and depth < max_depth:\n",
    "            individual = individual.candidate\n",
    "            depth += 1\n",
    "\n",
    "        selected_indexes = decode_chromosome(individual)\n",
    "        # Convert list to frozenset\n",
    "        indices_frozenset = frozenset(selected_indexes)\n",
    "        fitness_score = history_IndexesList_dict[indices_frozenset]['fitness_score']\n",
    "        fitness_scores.append(fitness_score)\n",
    "        chromosomes_indexes_pair[pop_i] = (individual, indices_frozenset)\n",
    "\n",
    "    # Reserve the best individuals\n",
    "    best_OverallAcc_index = max(range(len(fitness_scores)), key=lambda i: fitness_scores[i][0])\n",
    "    best_ClassRecall_index = max(range(len(fitness_scores)), key=lambda i: fitness_scores[i][1])\n",
    "    elitism_indices = {best_OverallAcc_index, best_ClassRecall_index}\n",
    "    mating_pool = [population[i] for i in elitism_indices]\n",
    "    already_selected = set(elitism_indices)  # Population indices of already selected candidates\n",
    "\n",
    "    # A lexicase round can add nobody (e.g. its winner is already selected and no alternative\n",
    "    # remains). Stop after this many consecutive idle rounds so the loop always terminates\n",
    "    # and the tournament phase below fills the rest of the pool.\n",
    "    max_idle_rounds = max(1, len(population))\n",
    "    idle_rounds = 0\n",
    "    while len(mating_pool) < (num_selected + len(elitism_indices))/2:\n",
    "        pool_size_before = len(mating_pool)\n",
    "        selected_set = set()\n",
    "        lexicase = [select_string_based_on_probabilities_no_duplicates(topic_name_sizes_order, topic_group_probabilities, random, selected_set) for _ in range(flattened_length)]\n",
    "\n",
    "        # Initialize candidates as the entire population\n",
    "        candidates = list(range(len(population)))\n",
    "        cases = lexicase[:]\n",
    "        case_fitness = {}  # Stays empty when no case is evaluated (e.g. a one-individual population)\n",
    "\n",
    "        while cases and len(candidates) > 1:\n",
    "            case = cases.pop(0)\n",
    "            # Get fitness for each candidate\n",
    "            case_fitness = {i: history_IndexesList_dict[chromosomes_indexes_pair[i][1]]['classification_df'].loc[case, 'recall'] for i in candidates}\n",
    "            # Find the maximum fitness for the current case\n",
    "            max_fitness = max(case_fitness.values())\n",
    "            # Filter candidates with the maximum fitness\n",
    "            candidates = [i for i in candidates if case_fitness[i] == max_fitness]\n",
    "\n",
    "            # Ensure uniqueness based on indices_frozenset\n",
    "            unique_candidates = {}\n",
    "            for i in candidates:\n",
    "                indices_frozenset = chromosomes_indexes_pair[i][1]\n",
    "                if indices_frozenset not in unique_candidates:\n",
    "                    unique_candidates[indices_frozenset] = i\n",
    "\n",
    "            # Update candidates to unique ones\n",
    "            candidates = list(unique_candidates.values())\n",
    "\n",
    "        # If we have exactly one candidate, select it\n",
    "        if len(candidates) == 1:\n",
    "            if Smode != \"Sovl\":\n",
    "                # already_selected holds population indices, so test the index itself\n",
    "                # (the Individual object is never a member of that set).\n",
    "                if candidates[0] not in already_selected:\n",
    "                    mating_pool.append(population[candidates[0]])\n",
    "                    already_selected.add(candidates[0])\n",
    "                else:\n",
    "                    # Fall back to the best not-yet-selected candidate of the last case\n",
    "                    sorted_fitness = sorted(case_fitness.items(), key=lambda x: x[1], reverse=True)\n",
    "                    for idx, _ in sorted_fitness:\n",
    "                        if idx not in already_selected:\n",
    "                            mating_pool.append(population[idx])\n",
    "                            already_selected.add(idx)\n",
    "                            break\n",
    "            else:\n",
    "                mating_pool.append(population[candidates[0]])\n",
    "        elif len(candidates) > 1:\n",
    "            for index in candidates:\n",
    "                # Without overlap (Smode != \"Sovl\") skip individuals that are already selected\n",
    "                if Smode != \"Sovl\" and index in already_selected:\n",
    "                    continue\n",
    "                mating_pool.append(population[index])\n",
    "                already_selected.add(index)\n",
    "\n",
    "        # If the mating pool is already filled, break the loop\n",
    "        if len(mating_pool) >= num_selected:\n",
    "            break\n",
    "\n",
    "        idle_rounds = idle_rounds + 1 if len(mating_pool) == pool_size_before else 0\n",
    "        if idle_rounds >= max_idle_rounds:\n",
    "            break\n",
    "\n",
    "    rng = np.random.default_rng(seed=42)\n",
    "    while len(mating_pool) < num_selected:\n",
    "        if Smode != \"Sovl\":\n",
    "            # Create a set of indexes whose corresponding elements in `population` are in `mating_pool`\n",
    "            excluded_indexes = {i for i, individual in enumerate(population) if individual in mating_pool}\n",
    "\n",
    "            # Update dictionary by removing excluded indexes\n",
    "            for coords in list(clusters.keys()):\n",
    "                clusters[coords] = [index for index in clusters[coords] if index not in excluded_indexes]\n",
    "                # Remove the key if the list becomes empty\n",
    "                if not clusters[coords]:\n",
    "                    del clusters[coords]\n",
    "\n",
    "        available_clusters = min(result['S'], len(clusters))\n",
    "        # No tournament can be held over zero clusters, and Uniform/Quadratic sizes are drawn\n",
    "        # from 2..S, which is empty when fewer than two clusters remain.\n",
    "        if not clusters or (available_clusters < 2 and result['distribution_type'] in ['Uniform', 'Quadratic']):\n",
    "            print(\"Warning: Not enough individuals to continue meaningful selection.\")\n",
    "            break\n",
    "        current_tournament_size = generate_tournament_size(available_clusters, result['distribution_type'], rng)\n",
    "        if current_tournament_size < 2:  # Ensuring there's at least a minimal pool for a tournament\n",
    "            print(\"Warning: Not enough individuals to continue meaningful selection.\")\n",
    "            break\n",
    "        selected_clusters = random.sample(list(clusters.keys()), min(len(clusters), current_tournament_size))\n",
    "        if result['distribution_type'] in ['Uniform', 'Quadratic']:\n",
    "            # Clusters are keyed by fitness value: keep every sampled cluster whose key is not\n",
    "            # dominated by another sampled key, and draw one individual from each of them.\n",
    "            non_dominated_indices = []\n",
    "            for f1 in selected_clusters:\n",
    "                if not any(is_dominated(f2, f1) for f2 in selected_clusters):\n",
    "                    non_dominated_indices.append(f1)\n",
    "            for cluster in non_dominated_indices:\n",
    "                # Select a random individual from the best cluster\n",
    "                selected_individual = random.choice(clusters[cluster])\n",
    "                mating_pool.append(population[selected_individual])\n",
    "        elif result['distribution_type'] in ['Reversed Quadratic', 'Random']:\n",
    "            # Clusters are keyed by front number, so the smallest sampled key is the best front\n",
    "            best_cluster = min(selected_clusters)\n",
    "            # Select a random individual from the best cluster\n",
    "            selected_individual = random.choice(clusters[best_cluster])\n",
    "            # Append the selected individual to the mating pool\n",
    "            mating_pool.append(population[selected_individual])\n",
    "\n",
    "        if len(clusters) < current_tournament_size:\n",
    "            break  # If we can't fill the tournament, break the loop\n",
    "\n",
    "        # If the mating pool is already filled, break the loop\n",
    "        if len(mating_pool) >= num_selected:\n",
    "            break\n",
    "\n",
    "    # Pass every selected individual to fitness_evaluation tagged 'Selected' (its result is unused)\n",
    "    for candidate in mating_pool:\n",
    "        max_depth = 10  # Prevent infinite loops by setting a maximum depth\n",
    "        depth = 0\n",
    "        while isinstance(candidate, inspyred.ec.Individual) and depth < max_depth:\n",
    "            candidate = candidate.candidate\n",
    "            depth += 1\n",
    "        selected_indexes = decode_chromosome(candidate)\n",
    "        # Convert list to frozenset\n",
    "        indices_frozenset = frozenset(selected_indexes)\n",
    "        fitness_evaluation(data_syn, candidate, selected_indexes, indices_frozenset, generation_counter, 'Selected')\n",
    "\n",
    "    return mating_pool\n",
    "\n",
    "\n",
    "def lexicase_elitist_with_stochastic_tournament_selection(random, population, args):\n",
    "    \"\"\"Select a mating pool by elitism, lexicase selection and stochastic tournaments.\n",
    "\n",
    "    1. Elitism: the individuals with the best overall accuracy (fitness[0]) and the\n",
    "       best class recall (fitness[1]) are always selected.\n",
    "    2. Lexicase: while the pool is smaller than (num_selected + #elites) / 2, per-class\n",
    "       recall cases are drawn in a random, group-weighted order and used to filter the\n",
    "       population; the survivors are selected.\n",
    "    3. Stochastic tournament: tournaments of random size (2..len(population)) are held,\n",
    "       and both the best-accuracy and the best-recall entrant of each are added. This\n",
    "       repeats until the pool reaches num_selected or a tournament cannot be filled.\n",
    "    Finally every selected individual is passed to fitness_evaluation tagged 'Selected'.\n",
    "\n",
    "    When Smode != \"Sovl\", an individual is not selected twice during the lexicase phase,\n",
    "    and already-selected individuals are filtered out of each tournament.\n",
    "\n",
    "    :param random: random.Random-like generator (its sample/randrange/choice/choices are used)\n",
    "    :param population: current population of inspyred Individuals\n",
    "    :param args: keyword-argument dict whose \"args\" entry holds \"num_selected\" and \"data_syn\"\n",
    "    :return: list of selected individuals; empty when GPU_limit is set\n",
    "    \"\"\"\n",
    "    global history_IndexesList_dict\n",
    "    global generation_counter\n",
    "    global Smode\n",
    "    global topic_number\n",
    "    global topic_name\n",
    "    global topic_name_cases_order\n",
    "    global topic_group_probabilities\n",
    "    global topic_name_sizes_order\n",
    "    global GPU_limit\n",
    "\n",
    "    if GPU_limit == True:\n",
    "        return []\n",
    "\n",
    "    topic_name_cases_order = move_to_first_list(topic_name_cases_order, topic_name)\n",
    "    topic_name_sizes_order = move_to_first_list(topic_name_sizes_order, topic_name)\n",
    "\n",
    "    inner_args = args.get(\"args\")  # Retrieve the nested dictionary\n",
    "    num_selected = inner_args.get(\"num_selected\")\n",
    "    data_syn = inner_args.get(\"data_syn\")\n",
    "    flattened_length = sum(len(sublist) for sublist in topic_name_sizes_order)\n",
    "\n",
    "    fitness_scores = []\n",
    "    chromosomes_indexes_pair = {}\n",
    "    for pop_i, individual in enumerate(population):\n",
    "        max_depth = 10  # Prevent infinite loops by setting a maximum depth\n",
    "        depth = 0\n",
    "        while isinstance(individual, inspyred.ec.Individual) and depth < max_depth:\n",
    "            individual = individual.candidate\n",
    "            depth += 1\n",
    "\n",
    "        selected_indexes = decode_chromosome(individual)\n",
    "        # Convert list to frozenset\n",
    "        indices_frozenset = frozenset(selected_indexes)\n",
    "        fitness_score = history_IndexesList_dict[indices_frozenset]['fitness_score']\n",
    "        fitness_scores.append(fitness_score)\n",
    "        chromosomes_indexes_pair[pop_i] = (individual, indices_frozenset)\n",
    "\n",
    "    # Reserve the best individuals\n",
    "    best_OverallAcc_index = max(range(len(fitness_scores)), key=lambda i: fitness_scores[i][0])\n",
    "    best_ClassRecall_index = max(range(len(fitness_scores)), key=lambda i: fitness_scores[i][1])\n",
    "    elitism_indices = {best_OverallAcc_index, best_ClassRecall_index}\n",
    "    mating_pool = [population[i] for i in elitism_indices]\n",
    "    already_selected = set(elitism_indices)  # Population indices of already selected candidates\n",
    "\n",
    "    # A lexicase round can add nobody (e.g. its winner is already selected and no alternative\n",
    "    # remains). Stop after this many consecutive idle rounds so the loop always terminates\n",
    "    # and the tournament phase below fills the rest of the pool.\n",
    "    max_idle_rounds = max(1, len(population))\n",
    "    idle_rounds = 0\n",
    "    while len(mating_pool) < (num_selected + len(elitism_indices))/2:\n",
    "        pool_size_before = len(mating_pool)\n",
    "        selected_set = set()\n",
    "        lexicase = [select_string_based_on_probabilities_no_duplicates(topic_name_sizes_order, topic_group_probabilities, random, selected_set) for _ in range(flattened_length)]\n",
    "\n",
    "        # Initialize candidates as the entire population\n",
    "        candidates = list(range(len(population)))\n",
    "        cases = lexicase[:]\n",
    "        case_fitness = {}  # Stays empty when no case is evaluated (e.g. a one-individual population)\n",
    "\n",
    "        while cases and len(candidates) > 1:\n",
    "            case = cases.pop(0)\n",
    "            # Get fitness for each candidate\n",
    "            case_fitness = {i: history_IndexesList_dict[chromosomes_indexes_pair[i][1]]['classification_df'].loc[case, 'recall'] for i in candidates}\n",
    "            # Find the maximum fitness for the current case\n",
    "            max_fitness = max(case_fitness.values())\n",
    "            # Filter candidates with the maximum fitness\n",
    "            candidates = [i for i in candidates if case_fitness[i] == max_fitness]\n",
    "\n",
    "            # Ensure uniqueness based on indices_frozenset\n",
    "            unique_candidates = {}\n",
    "            for i in candidates:\n",
    "                indices_frozenset = chromosomes_indexes_pair[i][1]\n",
    "                if indices_frozenset not in unique_candidates:\n",
    "                    unique_candidates[indices_frozenset] = i\n",
    "\n",
    "            # Update candidates to unique ones\n",
    "            candidates = list(unique_candidates.values())\n",
    "\n",
    "        # If we have exactly one candidate, select it\n",
    "        if len(candidates) == 1:\n",
    "            if Smode != \"Sovl\":\n",
    "                # already_selected holds population indices, so test the index itself\n",
    "                # (the Individual object is never a member of that set).\n",
    "                if candidates[0] not in already_selected:\n",
    "                    mating_pool.append(population[candidates[0]])\n",
    "                    already_selected.add(candidates[0])\n",
    "                else:\n",
    "                    # Fall back to the best not-yet-selected candidate of the last case\n",
    "                    sorted_fitness = sorted(case_fitness.items(), key=lambda x: x[1], reverse=True)\n",
    "                    for idx, _ in sorted_fitness:\n",
    "                        if idx not in already_selected:\n",
    "                            mating_pool.append(population[idx])\n",
    "                            already_selected.add(idx)\n",
    "                            break\n",
    "            else:\n",
    "                mating_pool.append(population[candidates[0]])\n",
    "        elif len(candidates) > 1:\n",
    "            for index in candidates:\n",
    "                # Without overlap (Smode != \"Sovl\") skip individuals that are already selected\n",
    "                if Smode != \"Sovl\" and index in already_selected:\n",
    "                    continue\n",
    "                mating_pool.append(population[index])\n",
    "                already_selected.add(index)\n",
    "\n",
    "        # If the mating pool is already filled, break the loop\n",
    "        if len(mating_pool) >= num_selected:\n",
    "            break\n",
    "\n",
    "        idle_rounds = idle_rounds + 1 if len(mating_pool) == pool_size_before else 0\n",
    "        if idle_rounds >= max_idle_rounds:\n",
    "            break\n",
    "\n",
    "    while len(mating_pool) < num_selected:\n",
    "        # randrange(2, len(population) + 1) needs at least two individuals\n",
    "        if len(population) < 2:\n",
    "            print(\"Warning: Not enough individuals to continue meaningful selection.\")\n",
    "            break\n",
    "        current_tournament_size = random.randrange(2, len(population)+1)\n",
    "        tournament_indices = random.sample(range(len(population)), current_tournament_size)\n",
    "        if Smode != \"Sovl\":\n",
    "            # Filter out indices already selected for the mating pool to avoid duplicates\n",
    "            tournament_indices = [i for i in tournament_indices if population[i] not in mating_pool]\n",
    "\n",
    "        # NOTE(review): indices are drawn from the whole population and filtered afterwards.\n",
    "        # With Smode != \"Sovl\" this ends the phase as soon as any drawn individual is already\n",
    "        # in the pool (the elites always are); confirm whether drawing only unselected\n",
    "        # individuals was intended.\n",
    "        if len(tournament_indices) < current_tournament_size:\n",
    "            break  # If we can't fill the tournament, break the loop\n",
    "\n",
    "        # Create the tournament from the selected indices\n",
    "        tournament = [population[i] for i in tournament_indices]\n",
    "        tournament_fitness = [fitness_scores[i] for i in tournament_indices]\n",
    "\n",
    "        # Find the best individuals based on different criteria\n",
    "        best_OverallAcc_tournament_index = tournament_indices[max(range(len(tournament)), key=lambda i: tournament_fitness[i][0])]\n",
    "        best_ClassRecall_tournament_index = tournament_indices[max(range(len(tournament)), key=lambda i: tournament_fitness[i][1])]\n",
    "\n",
    "        # Append the tournament's best by overall accuracy and its best by class recall;\n",
    "        # both are appended even when they are the same individual.\n",
    "        mating_pool.append(population[best_OverallAcc_tournament_index])\n",
    "        mating_pool.append(population[best_ClassRecall_tournament_index])\n",
    "\n",
    "        # If the mating pool is already filled, break the loop\n",
    "        if len(mating_pool) >= num_selected:\n",
    "            break\n",
    "\n",
    "    # Pass every selected individual to fitness_evaluation tagged 'Selected' (its result is unused)\n",
    "    for candidate in mating_pool:\n",
    "        max_depth = 10  # Prevent infinite loops by setting a maximum depth\n",
    "        depth = 0\n",
    "        while isinstance(candidate, inspyred.ec.Individual) and depth < max_depth:\n",
    "            candidate = candidate.candidate\n",
    "            depth += 1\n",
    "        selected_indexes = decode_chromosome(candidate)\n",
    "        # Convert list to frozenset\n",
    "        indices_frozenset = frozenset(selected_indexes)\n",
    "        fitness_evaluation(data_syn, candidate, selected_indexes, indices_frozenset, generation_counter, 'Selected')\n",
    "\n",
    "    return mating_pool\n",
    "\n",
    "\n",
    "\n",
    "def cluster_tournament_selection_with_elitism(random, population, args):\n",
    "    \"\"\"Build a mating pool from elitism plus tournaments over population clusters.\n",
    "\n",
    "    The individual with the highest fitness_score[0] and the one with the highest\n",
    "    fitness_score[1] (both read from `history_IndexesList_dict`) are always kept.\n",
    "    The pool is then filled by sampling cluster keys from the `clusters` mapping\n",
    "    returned by `population_observer` and drawing random members of the chosen\n",
    "    cluster(s). Every individual in the final pool is passed to\n",
    "    `fitness_evaluation` with the 'Selected' tag.\n",
    "\n",
    "    Args:\n",
    "        random: inspyred's random.Random instance (shadows the `random` module here).\n",
    "        population: list of inspyred Individuals.\n",
    "        args: inspyred kwargs; `args['args']` must provide `num_selected` and `data_syn`.\n",
    "\n",
    "    Returns:\n",
    "        The mating pool (list of Individuals), or `[]` when GPU_limit is set.\n",
    "        NOTE(review): the 'Uniform'/'Quadratic' branch can append several\n",
    "        individuals in one round, so the pool may exceed `num_selected`.\n",
    "    \"\"\"\n",
    "    global history_IndexesList_dict\n",
    "    global tournament_size\n",
    "    global generation_counter\n",
    "    global Smode\n",
    "    global GPU_limit\n",
    "\n",
    "    # Bail out early once the GPU limit flag has been raised elsewhere\n",
    "    if GPU_limit == True:\n",
    "        return []\n",
    "\n",
    "    inner_args = args.get(\"args\")  # Retrieve the nested dictionary\n",
    "    num_selected = inner_args.get(\"num_selected\")\n",
    "    data_syn = inner_args.get(\"data_syn\")\n",
    "\n",
    "    # Call population_observer to analyze the population; its result is read below\n",
    "    # for 'clusters' (cluster key -> list of population indexes), 'S' and 'distribution_type'\n",
    "    result = population_observer(population, random)\n",
    "    \n",
    "    # Retrieve the clusters from the result (this dict is mutated in place below)\n",
    "    clusters = result['clusters']\n",
    "\n",
    "    # Look up each individual's cached fitness score (KeyError if it was never evaluated)\n",
    "    fitness_scores = []\n",
    "    for individual in population:\n",
    "        max_depth = 10  # Prevent infinite loops by setting a maximum depth\n",
    "        depth = 0\n",
    "        # Peel off nested Individual wrappers to reach the raw chromosome\n",
    "        while isinstance(individual, inspyred.ec.Individual) and depth < max_depth:\n",
    "            individual = individual.candidate\n",
    "            depth += 1\n",
    "\n",
    "        selected_indexes = decode_chromosome(individual)\n",
    "        # Convert list to frozenset (the key format of history_IndexesList_dict)\n",
    "        indices_frozenset = frozenset(selected_indexes)\n",
    "        fitness_score = history_IndexesList_dict[indices_frozenset]['fitness_score']\n",
    "        fitness_scores.append(fitness_score)\n",
    "    \n",
    "    # Reserve the best individuals: best on objective 0 and best on objective 1\n",
    "    # (a set, so one individual that is best on both is kept only once)\n",
    "    best_OverallAcc_index = max(range(len(fitness_scores)), key=lambda i: fitness_scores[i][0])\n",
    "    best_ClassRecall_index = max(range(len(fitness_scores)), key=lambda i: fitness_scores[i][1])\n",
    "    elitism_indices = {best_OverallAcc_index, best_ClassRecall_index}\n",
    "    mating_pool = [population[i] for i in elitism_indices]\n",
    "\n",
    "    # Re-created with a fixed seed on every call, so the tournament-size draws repeat\n",
    "    # identically each generation. NOTE(review): confirm this is intended.\n",
    "    rng = np.random.default_rng(seed=42)\n",
    "    while len(mating_pool) < num_selected:\n",
    "        if Smode != \"Sovl\":\n",
    "            # Create a set of indexes whose corresponding elements in `population` are in `mating_pool`\n",
    "            excluded_indexes = {i for i, individual in enumerate(population) if individual in mating_pool}\n",
    "\n",
    "            # Update dictionary by removing excluded indexes\n",
    "            for coords in list(clusters.keys()):\n",
    "                clusters[coords] = [index for index in clusters[coords] if index not in excluded_indexes]\n",
    "                # Remove the key if the list becomes empty\n",
    "                if not clusters[coords]:\n",
    "                    del clusters[coords]\n",
    "\n",
    "        # Requested size is capped by the number of remaining clusters\n",
    "        current_tournament_size = generate_tournament_size(min(result['S'], len(clusters)), result['distribution_type'], rng)\n",
    "        if current_tournament_size < 2:  # Ensuring there's at least a minimal pool for a tournament\n",
    "            print(\"Warning: Not enough individuals to continue meaningful selection.\")\n",
    "            break\n",
    "        selected_clusters = random.sample(list(clusters.keys()), min(len(list(clusters.keys())), current_tournament_size))\n",
    "        if result['distribution_type'] in ['Uniform', 'Quadratic']:\n",
    "            # Keep every sampled cluster key that no other sampled key dominates, using the\n",
    "            # module-level is_dominated defined elsewhere in the notebook.\n",
    "            # NOTE(review): this code treats is_dominated(f2, f1) == True as 'f1 is dominated',\n",
    "            # whereas the nested helper in custom_observer uses is_dominated(ind, other) to mean\n",
    "            # 'ind is dominated by other' -- confirm the module-level helper's argument order.\n",
    "            non_dominated_indices = []\n",
    "            for i, f1 in enumerate(selected_clusters):\n",
    "                dominated = False\n",
    "                for j, f2 in enumerate(selected_clusters):\n",
    "                    if is_dominated(f2, f1):\n",
    "                        dominated = True\n",
    "                        break\n",
    "                if not dominated:\n",
    "                    non_dominated_indices.append(f1)\n",
    "            for cluster in non_dominated_indices:\n",
    "                # Select a random individual from each non-dominated cluster\n",
    "                selected_individual = random.choice(clusters[cluster])\n",
    "                mating_pool.append(population[selected_individual])\n",
    "        elif result['distribution_type'] in ['Reversed Quadratic', 'Random']:\n",
    "            # Pick the smallest sampled cluster key (by Python's ordering of the keys)\n",
    "            best_cluster = min(selected_clusters)\n",
    "            # Select a random individual from the best cluster\n",
    "            selected_individual = random.choice(clusters[best_cluster])\n",
    "            # Append the selected individual to the selected_parents list\n",
    "            mating_pool.append(population[selected_individual])\n",
    "        # NOTE(review): for any other distribution_type nothing is appended; with Smode == \"Sovl\"\n",
    "        # the clusters never shrink, so termination then relies solely on the size checks.\n",
    "\n",
    "        if len(clusters) < current_tournament_size:\n",
    "            break  # If we can't fill the tournament, break the loop\n",
    "\n",
    "        # If the mating pool is already filled, break the loop\n",
    "        if len(mating_pool) >= num_selected:\n",
    "            break\n",
    "        \n",
    "    # Re-evaluate every selected individual with the 'Selected' tag; the returned score\n",
    "    # is not used here (presumably fitness_evaluation records it -- confirm)\n",
    "    for candidate in mating_pool:\n",
    "        max_depth = 10  # Prevent infinite loops by setting a maximum depth\n",
    "        depth = 0\n",
    "        while isinstance(candidate, inspyred.ec.Individual) and depth < max_depth:\n",
    "            candidate = candidate.candidate\n",
    "            depth += 1\n",
    "        selected_indexes = decode_chromosome(candidate)\n",
    "        # Convert list to frozenset\n",
    "        indices_frozenset = frozenset(selected_indexes)\n",
    "        fitness_score = fitness_evaluation(data_syn, candidate, selected_indexes, indices_frozenset, generation_counter, 'Selected')\n",
    "        \n",
    "    return mating_pool\n",
    "\n",
    "\n",
    "\n",
    "def nsgaii_tournament_selection_with_priority_and_elitism(random, population, args):\n",
    "    \"\"\"Build a mating pool from elitism plus random-size tournaments.\n",
    "\n",
    "    The individual with the highest fitness_score[0] and the one with the highest\n",
    "    fitness_score[1] (read from `history_IndexesList_dict`) are always kept. Each\n",
    "    tournament then adds its best individual for objective 0 and its best for\n",
    "    objective 1 (the same individual twice if it wins both). Every selected\n",
    "    individual is passed to `fitness_evaluation` with the 'Selected' tag.\n",
    "\n",
    "    Args:\n",
    "        random: inspyred's random.Random instance (shadows the `random` module here).\n",
    "        population: list of inspyred Individuals.\n",
    "        args: inspyred kwargs; `args['args']` must provide `num_selected` and `data_syn`.\n",
    "\n",
    "    Returns:\n",
    "        The mating pool; `[]` when GPU_limit is set or the population is empty.\n",
    "        Since two individuals are added per tournament, the pool may exceed\n",
    "        `num_selected` by one.\n",
    "    \"\"\"\n",
    "    global history_IndexesList_dict\n",
    "    global tournament_size\n",
    "    global generation_counter\n",
    "    global Smode\n",
    "    global GPU_limit\n",
    "    if GPU_limit == True:\n",
    "        return []\n",
    "\n",
    "    inner_args = args.get(\"args\")  # Retrieve the nested dictionary\n",
    "    num_selected = inner_args.get(\"num_selected\")\n",
    "    data_syn = inner_args.get(\"data_syn\")\n",
    "\n",
    "    # Nothing to select from (the max() calls below would raise on an empty sequence)\n",
    "    if not population:\n",
    "        return []\n",
    "\n",
    "    # Look up each individual's cached fitness score (KeyError if it was never evaluated)\n",
    "    fitness_scores = []\n",
    "    for individual in population:\n",
    "        max_depth = 10  # Prevent infinite loops by setting a maximum depth\n",
    "        depth = 0\n",
    "        while isinstance(individual, inspyred.ec.Individual) and depth < max_depth:\n",
    "            individual = individual.candidate\n",
    "            depth += 1\n",
    "\n",
    "        selected_indexes = decode_chromosome(individual)\n",
    "        # Convert list to frozenset (the key format of history_IndexesList_dict)\n",
    "        indices_frozenset = frozenset(selected_indexes)\n",
    "        fitness_scores.append(history_IndexesList_dict[indices_frozenset]['fitness_score'])\n",
    "\n",
    "    # Elitism: best on objective 0 and best on objective 1 (kept once if identical)\n",
    "    best_OverallAcc_index = max(range(len(fitness_scores)), key=lambda i: fitness_scores[i][0])\n",
    "    best_ClassRecall_index = max(range(len(fitness_scores)), key=lambda i: fitness_scores[i][1])\n",
    "    elitism_indices = {best_OverallAcc_index, best_ClassRecall_index}\n",
    "    mating_pool = [population[i] for i in elitism_indices]\n",
    "\n",
    "    while len(mating_pool) < num_selected:\n",
    "        # random.randrange(2, n + 1) raises ValueError for n < 2, so check before drawing\n",
    "        if len(population) < 2:  # Ensuring there's at least a minimal pool for a tournament\n",
    "            print(\"Warning: Not enough individuals to continue meaningful selection.\")\n",
    "            break\n",
    "        current_tournament_size = random.randrange(2, len(population) + 1)\n",
    "        tournament_indices = random.sample(range(len(population)), current_tournament_size)\n",
    "        if Smode != \"Sovl\":\n",
    "            # Filter out indices already selected for the mating pool to avoid duplicates\n",
    "            tournament_indices = [i for i in tournament_indices if population[i] not in mating_pool]\n",
    "\n",
    "        # With Smode != \"Sovl\", drawing any already-selected individual ends selection here\n",
    "        if len(tournament_indices) < current_tournament_size:\n",
    "            break  # If we can't fill the tournament, break the loop\n",
    "\n",
    "        tournament_fitness = [fitness_scores[i] for i in tournament_indices]\n",
    "        positions = range(len(tournament_indices))\n",
    "\n",
    "        # Winners of this tournament on each objective, mapped back to population indexes\n",
    "        best_OverallAcc_tournament_index = tournament_indices[max(positions, key=lambda k: tournament_fitness[k][0])]\n",
    "        best_ClassRecall_tournament_index = tournament_indices[max(positions, key=lambda k: tournament_fitness[k][1])]\n",
    "\n",
    "        # Add the objective-0 winner, then the objective-1 winner (may be the same individual)\n",
    "        mating_pool.append(population[best_OverallAcc_tournament_index])\n",
    "        mating_pool.append(population[best_ClassRecall_tournament_index])\n",
    "\n",
    "        # If the mating pool is already filled, break the loop\n",
    "        if len(mating_pool) >= num_selected:\n",
    "            break\n",
    "\n",
    "    # Re-evaluate every selected individual with the 'Selected' tag; the returned\n",
    "    # score is not needed here\n",
    "    for candidate in mating_pool:\n",
    "        max_depth = 10  # Prevent infinite loops by setting a maximum depth\n",
    "        depth = 0\n",
    "        while isinstance(candidate, inspyred.ec.Individual) and depth < max_depth:\n",
    "            candidate = candidate.candidate\n",
    "            depth += 1\n",
    "        selected_indexes = decode_chromosome(candidate)\n",
    "        indices_frozenset = frozenset(selected_indexes)\n",
    "        fitness_evaluation(data_syn, candidate, selected_indexes, indices_frozenset, generation_counter, 'Selected')\n",
    "\n",
    "    return mating_pool\n",
    "\n",
    "\n",
    "def nsgaii_weight_mapping_crossover(random, candidates, args):\n",
    "    \"\"\"Weight-mapping crossover applied to every pair of candidates.\n",
    "\n",
    "    For each pair, with probability `crossover_rate`, both chromosomes are cut at\n",
    "    the same random interior point. In the tail segments the priority field\n",
    "    (gene[1]) is exchanged by rank: the gene holding parent1's k-th smallest\n",
    "    priority receives parent2's k-th smallest priority and vice versa, while every\n",
    "    gene stays at its original position. Children that evaluate successfully\n",
    "    (tag 'Crossover') are appended as inspyred Individuals.\n",
    "\n",
    "    Args:\n",
    "        random: inspyred's random.Random instance (shadows the `random` module here).\n",
    "        candidates: chromosomes or (possibly nested) Individuals; genes are indexed as\n",
    "            gene[0] / gene[1]. NOTE(review): both parents are assumed to have the same length.\n",
    "        args: inspyred kwargs; `args['args']` must provide `data_syn` and `crossover_rate`.\n",
    "\n",
    "    Returns:\n",
    "        A new offspring list; `[]` when GPU_limit is set. In \"Xnp\" mode it holds the\n",
    "        children plus uncrossed parent pairs; otherwise all candidates followed by the\n",
    "        children. The input list is never modified.\n",
    "    \"\"\"\n",
    "    global generation_counter\n",
    "    global gen_stats_df\n",
    "    global Xmode\n",
    "    global GPU_limit\n",
    "\n",
    "    if GPU_limit == True:\n",
    "        return []\n",
    "\n",
    "    inner_args = args.get(\"args\")  # Retrieve the nested dictionary\n",
    "    data_syn = inner_args.get(\"data_syn\")\n",
    "    crossover_rate = inner_args.get(\"crossover_rate\")\n",
    "\n",
    "    # Copy rather than alias, so appending children never mutates the caller's list\n",
    "    if Xmode == \"Xnp\":\n",
    "        offsprings = []\n",
    "    else:\n",
    "        offsprings = list(candidates)\n",
    "\n",
    "    candidate_pairs = list(itertools.combinations(candidates, 2))\n",
    "\n",
    "    for parent1, parent2 in candidate_pairs:\n",
    "        # Check if crossover should occur\n",
    "        if random.random() > crossover_rate:\n",
    "            # Keep the untouched parents (only needed in \"Xnp\" mode, where they are not already included)\n",
    "            if Xmode == \"Xnp\":\n",
    "                offsprings.extend([parent1, parent2])\n",
    "            continue  # Skip to the next pair\n",
    "\n",
    "        max_depth_1 = 10  # Prevent infinite loops by setting a maximum depth\n",
    "        depth_1 = 0\n",
    "        while isinstance(parent1, inspyred.ec.Individual) and depth_1 < max_depth_1:\n",
    "            parent1 = parent1.candidate\n",
    "            depth_1 += 1\n",
    "        max_depth_2 = 10  # Prevent infinite loops by setting a maximum depth\n",
    "        depth_2 = 0\n",
    "        while isinstance(parent2, inspyred.ec.Individual) and depth_2 < max_depth_2:\n",
    "            parent2 = parent2.candidate\n",
    "            depth_2 += 1\n",
    "\n",
    "        # Select a crossover point that is not at the ends of the lists\n",
    "        cut_point = random.randint(1, len(parent1) - 1)\n",
    "\n",
    "        # Split each parent into a head (kept as is) and a tail (priorities remapped)\n",
    "        parent1_A, parent1_B = parent1[:cut_point], parent1[cut_point:]\n",
    "        parent2_A, parent2_B = parent2[:cut_point], parent2[cut_point:]\n",
    "\n",
    "        # Argsort each tail by priority; sorted() is stable, so ties keep their relative\n",
    "        # order. order[k] is the original position of the k-th smallest-priority gene.\n",
    "        order1 = sorted(range(len(parent1_B)), key=lambda k: parent1_B[k][1])\n",
    "        order2 = sorted(range(len(parent2_B)), key=lambda k: parent2_B[k][1])\n",
    "\n",
    "        # Exchange priorities rank by rank. Both original genes are read before writing,\n",
    "        # so the second child really receives parent1's priority, and each remapped gene\n",
    "        # is written back to its original position in its own parent.\n",
    "        restored_parent1_B = list(parent1_B)\n",
    "        restored_parent2_B = list(parent2_B)\n",
    "        for pos1, pos2 in zip(order1, order2):\n",
    "            gene1, gene2 = parent1_B[pos1], parent2_B[pos2]\n",
    "            restored_parent1_B[pos1] = (gene1[0], gene2[1])\n",
    "            restored_parent2_B[pos2] = (gene2[0], gene1[1])\n",
    "\n",
    "        # Reconstruct the modified parent lists\n",
    "        new_parent1 = parent1_A + restored_parent1_B\n",
    "        new_parent2 = parent2_A + restored_parent2_B\n",
    "\n",
    "        # Evaluate each child; a child whose evaluation fails is reported and dropped\n",
    "        try:\n",
    "            np1_selected_indexes = decode_chromosome(new_parent1)\n",
    "            np1_indices_frozenset = frozenset(np1_selected_indexes)\n",
    "            fitness_evaluation(data_syn, new_parent1, np1_selected_indexes, np1_indices_frozenset, generation_counter, 'Crossover')\n",
    "            # Wrap the child as an inspyred individual\n",
    "            offsprings.append(inspyred.ec.Individual(new_parent1))\n",
    "        except Exception as e:\n",
    "            print(f\"Failed to evaluate new_parent1 '{new_parent1}' due to error: {e}\")\n",
    "\n",
    "        try:\n",
    "            np2_selected_indexes = decode_chromosome(new_parent2)\n",
    "            np2_indices_frozenset = frozenset(np2_selected_indexes)\n",
    "            fitness_evaluation(data_syn, new_parent2, np2_selected_indexes, np2_indices_frozenset, generation_counter, 'Crossover')\n",
    "            # Wrap the child as an inspyred individual\n",
    "            offsprings.append(inspyred.ec.Individual(new_parent2))\n",
    "        except Exception as e:\n",
    "            print(f\"Failed to evaluate new_parent2 '{new_parent2}' due to error: {e}\")\n",
    "\n",
    "    print(len(offsprings))\n",
    "    return offsprings\n",
    "\n",
    "\n",
    "def nsgaii_mutate_individual(random, candidates, args):\n",
    "    \"\"\"Apply adaptive polynomial mutation to every candidate.\n",
    "\n",
    "    Each candidate is unwrapped from (possibly nested) inspyred Individuals and\n",
    "    looked up in `history_IndexesList_dict` (a missing entry raises KeyError).\n",
    "    It is then mutated with `adaptive_polynomial_mutation` and evaluated under the\n",
    "    'Mutation' tag. Mutants whose decoding or evaluation raises are reported and dropped.\n",
    "\n",
    "    Args:\n",
    "        random: inspyred's random.Random instance (not used here).\n",
    "        candidates: chromosomes or wrapped Individuals to mutate.\n",
    "        args: inspyred kwargs; `args['args']` must provide `data_syn`,\n",
    "            `max_generations` and `initial_mutation_rate`.\n",
    "\n",
    "    Returns:\n",
    "        List of inspyred Individuals wrapping the mutants; `[]` when GPU_limit is set.\n",
    "    \"\"\"\n",
    "    global generation_counter\n",
    "    global GPU_limit\n",
    "\n",
    "    if GPU_limit == True:\n",
    "        return []\n",
    "\n",
    "    settings = args.get(\"args\")  # nested settings dictionary\n",
    "    data_syn = settings.get(\"data_syn\")\n",
    "    max_generations = settings.get(\"max_generations\")\n",
    "    initial_mutation_rate = settings.get(\"initial_mutation_rate\")\n",
    "\n",
    "    mutants = []\n",
    "    for wrapped in candidates:\n",
    "        # Peel off at most 10 layers of Individual wrapping\n",
    "        chromosome = wrapped\n",
    "        for _ in range(10):\n",
    "            if not isinstance(chromosome, inspyred.ec.Individual):\n",
    "                break\n",
    "            chromosome = chromosome.candidate\n",
    "\n",
    "        # The parent's stored score is not used afterwards; the lookup is kept so that\n",
    "        # an unevaluated parent still fails loudly with KeyError\n",
    "        history_IndexesList_dict[frozenset(decode_chromosome(chromosome))]['fitness_score']\n",
    "\n",
    "        chromosome = adaptive_polynomial_mutation(chromosome, generation_counter, max_generations, initial_rate=initial_mutation_rate)\n",
    "\n",
    "        try:\n",
    "            mutant_indexes = decode_chromosome(chromosome)\n",
    "            fitness_evaluation(data_syn, chromosome, mutant_indexes, frozenset(mutant_indexes), generation_counter, 'Mutation')\n",
    "            mutants.append(inspyred.ec.Individual(chromosome))\n",
    "        except Exception as e:\n",
    "            print(f\"Failed to evaluate individual '{chromosome}' due to error: {e}\")\n",
    "\n",
    "    return mutants\n",
    "\n",
    "\n",
    "def custom_observer(population, num_generations, num_evaluations, args):\n",
    "    \"\"\"\n",
    "    Custom observer to handle, display, and plot both objectives.\n",
    "\n",
    "    For each generation it:\n",
    "    - prints population statistics;\n",
    "    - re-evaluates and logs every unique selection on the population's Pareto front,\n",
    "      appending to `history_pareto_selections_list` and `gen_pfs_df`;\n",
    "    - appends a summary row to `gen_stats_df`;\n",
    "    - saves both frames as CSV and pickle under `nsgaii_results_path`;\n",
    "    - saves and shows a scatter plot of the two objectives.\n",
    "\n",
    "    Args:\n",
    "        population: evaluated inspyred Individuals; `ind.fitness[0]` is overall\n",
    "            accuracy and `ind.fitness[1]` the topic recall (per the labels below).\n",
    "        num_generations: current generation number.\n",
    "        num_evaluations: number of evaluations performed so far.\n",
    "        args: inspyred kwargs; `args['args']` must provide `data_syn`.\n",
    "    \"\"\"\n",
    "    global topic_name\n",
    "    global topic_number\n",
    "    global generation_counter\n",
    "    global nsgaii_results_path\n",
    "    global gen_stats_df_name\n",
    "    global gen_stats_df\n",
    "    global history_pareto_selections_list\n",
    "    global NSGA_II_results_name\n",
    "    global gen_pfs_df\n",
    "    global history_IndexesList_dict\n",
    "    global GPU_limit\n",
    "\n",
    "    if GPU_limit == True:\n",
    "        return\n",
    "\n",
    "    inner_args = args.get(\"args\")  # Retrieve the nested dictionary\n",
    "    data_syn = inner_args.get(\"data_syn\")\n",
    "\n",
    "    # Extract fitness values\n",
    "    fitness_tuples = [ind.fitness for ind in population]\n",
    "\n",
    "    # Convert to separate lists for plotting\n",
    "    objective_1_values = [ft[0] for ft in fitness_tuples]\n",
    "    objective_2_values = [ft[1] for ft in fitness_tuples]\n",
    "\n",
    "    # Identify Pareto front (both objectives maximized)\n",
    "    def is_dominated(ind, other):\n",
    "        \"\"\" Check if ind is dominated by another individual \"\"\"\n",
    "        return all(o >= i for i, o in zip(ind.fitness, other.fitness)) and any(o > i for i, o in zip(ind.fitness, other.fitness))\n",
    "\n",
    "    pareto_front = [ind for ind in population if not any(is_dominated(ind, other) for other in population if ind != other)]\n",
    "\n",
    "    pareto_fitness_tuples = [ind.fitness for ind in pareto_front]\n",
    "\n",
    "    # Statistics (best/worst compare objective 0 first, then objective 1)\n",
    "    best_individual = max(population, key=lambda ind: (ind.fitness[0], ind.fitness[1]))\n",
    "    best_fitness = best_individual.fitness\n",
    "\n",
    "    worst_individual = min(population, key=lambda ind: (ind.fitness[0], ind.fitness[1]))\n",
    "    worst_fitness = worst_individual.fitness\n",
    "\n",
    "    mean_OverallAcc = sum(ind.fitness[0] for ind in population) / len(population)\n",
    "    mean_ClassRecall = sum(ind.fitness[1] for ind in population) / len(population)\n",
    "\n",
    "    print(f\"Generation: {num_generations}, Evaluations: {num_evaluations}\")\n",
    "    print(f\"Best Fitness: {best_fitness}\")\n",
    "    print(f\"Worst Fitness: {worst_fitness}\")\n",
    "    print(f\"Mean Overall Accuracy: {mean_OverallAcc}\")\n",
    "    print(f\"Mean {topic_name} Recall: {mean_ClassRecall}\")\n",
    "\n",
    "    # Decode and re-evaluate every individual on the Pareto front\n",
    "    pareto_paths = []\n",
    "    pareto_chromosomes = []\n",
    "    pareto_fitnesses = []\n",
    "    for ind in pareto_front:\n",
    "        max_depth = 10  # Prevent infinite loops by setting a maximum depth\n",
    "        depth = 0\n",
    "        while isinstance(ind, inspyred.ec.Individual) and depth < max_depth:\n",
    "            ind = ind.candidate\n",
    "            depth += 1\n",
    "        chromosome = ind\n",
    "        decoded_path = decode_chromosome(chromosome)\n",
    "        pareto_paths.append(decoded_path)\n",
    "        pareto_chromosomes.append(chromosome)\n",
    "\n",
    "        indices_frozenset = frozenset(decoded_path)\n",
    "        fitness_score = fitness_evaluation(data_syn, chromosome, decoded_path, indices_frozenset, generation_counter, 'On Pareto Front')\n",
    "        pareto_fitnesses.append(fitness_score)\n",
    "\n",
    "    # Keep the first occurrence of each selection (order-insensitive key) and remember\n",
    "    # its position, so the matching chromosome/fitness is fetched directly instead of\n",
    "    # via list.index(list(path)), which was quadratic and raised for non-list paths.\n",
    "    first_position = {}\n",
    "    for position, path in enumerate(pareto_paths):\n",
    "        first_position.setdefault(tuple(sorted(path)), position)\n",
    "    unique_positions = list(first_position.values())\n",
    "    unique_pareto_paths = [pareto_paths[p] for p in unique_positions]\n",
    "    unique_pareto_chromosomes = [pareto_chromosomes[p] for p in unique_positions]\n",
    "    unique_pareto_fitnesses = [pareto_fitnesses[p] for p in unique_positions]\n",
    "\n",
    "    print(\"Pareto Front Selections:---------------------\")\n",
    "    gen_PFs_rows = []\n",
    "    for i, (path, chromosome, fitness) in enumerate(zip(unique_pareto_paths, unique_pareto_chromosomes, unique_pareto_fitnesses)):\n",
    "        print(f\"Selection {i+1}: {path} \\nFitness: {fitness}\")\n",
    "        history_pareto_selections_list.append((path, chromosome, fitness))\n",
    "        record = history_IndexesList_dict[frozenset(path)]\n",
    "        classification_df = record['classification_df']\n",
    "        overall_balanced = classification_df.loc['accuracy', 'Balanced Accuracy']\n",
    "        topic_balanced = classification_df.loc[topic_name, 'Balanced Accuracy']\n",
    "        # Collect this selection's metrics into one row\n",
    "        gen_PFs_rows.append({\n",
    "            \"topic_name\": topic_name,\n",
    "            \"topic_number\": topic_number,\n",
    "            \"generation\": num_generations,\n",
    "            'fitness_score': fitness,\n",
    "            \"accuracy\": fitness[0],\n",
    "            \"topic_recall\": fitness[1],\n",
    "            'balanced_fitness_score': (overall_balanced, topic_balanced),\n",
    "            'overall_balanced_accuracy': overall_balanced,\n",
    "            'topic_balanced_accuracy': topic_balanced,\n",
    "            'balanced_acc_rec_score': (topic_balanced, classification_df.loc[topic_name, 'recall']),\n",
    "            'topic_F1': classification_df.loc[topic_name, 'f1-score'],\n",
    "            'overall_F1': classification_df.loc['accuracy', 'f1-score'],\n",
    "            'overall_recall': classification_df.loc['accuracy', 'recall'],\n",
    "            \"retraining_time\": record['retraining_time'],\n",
    "            \"number_of_syn_sample\": len(path),\n",
    "            \"retrained_dots_list\": path,\n",
    "            'true_labels': record['true_labels'],\n",
    "            'predicted_labels': record['predicted_labels'],\n",
    "            'chromosome': chromosome\n",
    "        })\n",
    "        print('---')\n",
    "    print('------------------------')\n",
    "\n",
    "    # Append all Pareto rows at once and persist once per generation, instead of\n",
    "    # re-concatenating and rewriting the CSV and pickle after every single row\n",
    "    if gen_PFs_rows:\n",
    "        gen_pfs_df = pd.concat([gen_pfs_df, pd.DataFrame(gen_PFs_rows)], ignore_index=True)\n",
    "        gen_pfs_df.to_csv(f'{nsgaii_results_path}/GenPFs_{NSGA_II_results_name}.csv', index=False)\n",
    "        gen_pfs_df.to_pickle(f'{nsgaii_results_path}/GenPFs_{NSGA_II_results_name}.pkl')\n",
    "\n",
    "    recall_key = f\"Mean {topic_name} Recall\"\n",
    "    # Collect all generation data into a new DataFrame row\n",
    "    new_row = {\n",
    "        \"Generation\": num_generations,\n",
    "        \"Number of Evaluations\": num_evaluations,\n",
    "        \"Best Fitness\": best_fitness,\n",
    "        \"Worst Fitness\": worst_fitness,\n",
    "        \"Mean Overall Accuracy\": mean_OverallAcc,\n",
    "        recall_key: mean_ClassRecall,\n",
    "        \"Pareto Front Selections\": [unique_pareto_paths, unique_pareto_fitnesses]\n",
    "    }\n",
    "    gen_stats_df = pd.concat([gen_stats_df, pd.DataFrame([new_row])], ignore_index=True)\n",
    "    gen_stats_df.to_csv(f'{nsgaii_results_path}/{gen_stats_df_name}.csv', index=False)\n",
    "    gen_stats_df.to_pickle(f'{nsgaii_results_path}/{gen_stats_df_name}.pkl')\n",
    "\n",
    "    # Plotting\n",
    "    fig = plt.figure(figsize=(12, 7))\n",
    "\n",
    "    plt.scatter(objective_1_values, objective_2_values, c='blue', alpha=0.5, label='Population')\n",
    "    plt.scatter([ft[0] for ft in pareto_fitness_tuples], [ft[1] for ft in pareto_fitness_tuples], c='red', alpha=0.9, label='Pareto Front')\n",
    "\n",
    "    plt.xlabel('Objective 1: Overall Accuracy')\n",
    "    plt.ylabel(f'Objective 2: {topic_name} Recall')\n",
    "    plt.title(f'{topic_number}, Population and Pareto Front at Generation {num_generations}')\n",
    "    plt.legend()\n",
    "    pareto_plots_dir = f\"{nsgaii_results_path}/{gen_stats_df_name}\"\n",
    "    os.makedirs(pareto_plots_dir, exist_ok=True)\n",
    "    filename = f\"Gen_{num_generations}.png\"\n",
    "    plt.savefig(os.path.join(pareto_plots_dir, filename), dpi=200, bbox_inches='tight', pad_inches=0.1)\n",
    "    plt.show()\n",
    "    # Release the figure: this observer runs once per generation inside one long cell\n",
    "    plt.close(fig)\n",
    "\n",
    "\n",
    "def custom_terminator(population, num_generations, num_evaluations, args):\n",
    "    \"\"\"Decide whether the evolution should stop.\n",
    "\n",
    "    Stops immediately when GPU_limit is set, or when every individual decodes to\n",
    "    the same set of selected indexes (the population has collapsed). Otherwise the\n",
    "    decision is delegated to `inspyred.ec.terminators.generation_termination`.\n",
    "\n",
    "    Returns:\n",
    "        True to stop; otherwise the result of generation_termination.\n",
    "    \"\"\"\n",
    "    global GPU_limit\n",
    "\n",
    "    if GPU_limit == True:\n",
    "        return True\n",
    "\n",
    "    def _unwrap(candidate, max_depth=10):\n",
    "        # Follow nested Individual.candidate links, at most `max_depth` times\n",
    "        for _ in range(max_depth):\n",
    "            if not isinstance(candidate, inspyred.ec.Individual):\n",
    "                break\n",
    "            candidate = candidate.candidate\n",
    "        return candidate\n",
    "\n",
    "    distinct_selections = {frozenset(decode_chromosome(_unwrap(ind.candidate))) for ind in population}\n",
    "\n",
    "    # Every individual encodes the same selection: nothing left to explore\n",
    "    if len(distinct_selections) == 1:\n",
    "        return True\n",
    "\n",
    "    return ec.terminators.generation_termination(population, num_generations, num_evaluations, args)\n",
    "\n",
    "\n",
    "\n",
    "def run_nsga2(args, population_size=10, maximize=True, max_generations=20, num_selected=5, seed=42):\n",
    "    \"\"\"Configure and run inspyred's NSGA-II with this notebook's custom operators.\n",
    "\n",
    "    Args:\n",
    "        args: nested settings dict; the operators read it back as `args['args']`.\n",
    "        population_size: individuals per generation.\n",
    "        maximize: whether the objectives are maximized.\n",
    "        max_generations: forwarded to evolve() (read by generation_termination).\n",
    "        num_selected: forwarded to evolve() as a top-level kwarg. NOTE(review): the\n",
    "            selectors defined in this notebook read `num_selected` from the nested\n",
    "            `args` dict instead -- confirm which value is meant to apply.\n",
    "        seed: seed for the random.Random instance that drives the EA.\n",
    "\n",
    "    Returns:\n",
    "        The final population returned by `ea.evolve`.\n",
    "    \"\"\"\n",
    "    ea = ec.emo.NSGA2(random.Random(seed))\n",
    "\n",
    "    # Hooks: per-generation reporting and the stopping rule\n",
    "    ea.observer = [custom_observer]\n",
    "    ea.terminator = custom_terminator\n",
    "    # Operators: parent selection, then crossover followed by mutation\n",
    "    ea.selector = lexicase_elitist_with_cluster_tournament_selection\n",
    "    ea.variator = [nsgaii_weight_mapping_crossover, nsgaii_mutate_individual]\n",
    "\n",
    "    return ea.evolve(\n",
    "        generator=custom_generator,\n",
    "        evaluator=evaluate,\n",
    "        pop_size=population_size,\n",
    "        maximize=maximize,\n",
    "        max_generations=max_generations,\n",
    "        num_selected=num_selected,\n",
    "        args=args,\n",
    "    )\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def dominates(score1, score2):\n",
    "    \"\"\"\n",
    "    Determines if one score dominates another.\n",
    "    A score1 dominates score2 if it is better in all the objectives or equal in some and better in at least one.\n",
    "    \"\"\"\n",
    "    return (score1[0] > score2[0] and score1[1] >= score2[1]) or (score1[0] >= score2[0] and score1[1] > score2[1])\n",
    "\n",
    "def find_pareto_front(df):\n",
    "    \"\"\"\n",
    "    Return a copy of `df` with a 'Pareto' column: 'Yes' for rows whose\n",
    "    'balanced_acc_rec_score' pair is not dominated (per `dominates`) by any other\n",
    "    row, 'No' otherwise. The input frame is left unmodified; O(n^2) comparisons.\n",
    "    \"\"\"\n",
    "    result = df.copy()  # leave the caller's frame untouched\n",
    "    scores = result['balanced_acc_rec_score'].tolist()\n",
    "    is_pareto = np.array([\n",
    "        not any(dominates(other, own) for j, other in enumerate(scores) if j != i)\n",
    "        for i, own in enumerate(scores)\n",
    "    ], dtype=bool)\n",
    "    result['Pareto'] = np.where(is_pareto, 'Yes', 'No')\n",
    "    return result\n",
    "\n",
    "def find_best_values(df):\n",
    "    # Identify the maximum values for each specified column\n",
    "    max_values = {\n",
    "        'accuracy': df['accuracy'].max(),\n",
    "        'topic_recall': df['topic_recall'].max(),\n",
    "        'overall_balanced_accuracy': df['overall_balanced_accuracy'].max(),\n",
    "        'topic_balanced_accuracy': df['topic_balanced_accuracy'].max(),\n",
    "        'topic_F1': df['topic_F1'].max(),\n",
    "        'overall_F1': df['overall_F1'].max(),\n",
    "        'overall_recall': df['overall_recall'].max()\n",
    "    }\n",
    "    \n",
    "    # Function to apply to each row to determine the best columns\n",
    "    def check_best(row):\n",
    "        return [col for col, max_val in max_values.items() if row[col] == max_val]\n",
    "\n",
    "    # Apply the function to each row\n",
    "    df['best'] = df.apply(check_best, axis=1)\n",
    "    \n",
    "    return df\n",
    "\n",
    "def post_process(df, bch_class_df):\n",
    "    \"\"\"Add improvement-over-benchmark and running-summary columns to `df` (in place).\n",
    "\n",
    "    Improvements are each row's metric minus the benchmark value taken from\n",
    "    `bch_class_df` for the current topic (`topic_dict[topic_number]`). Running\n",
    "    summaries are the cumulative maximum and the expanding mean of each improvement,\n",
    "    plus the cumulative sum of 'retraining_time'.\n",
    "\n",
    "    Args:\n",
    "        df: per-selection results with 'topic_recall', 'topic_balanced_accuracy',\n",
    "            'overall_balanced_accuracy', 'overall_F1' and 'retraining_time' columns.\n",
    "        bch_class_df: benchmark report indexed by class name and 'accuracy', with\n",
    "            'recall', 'Balanced Accuracy' and 'f1-score' columns.\n",
    "\n",
    "    Returns:\n",
    "        The same `df`, with the new columns added.\n",
    "    \"\"\"\n",
    "    global topic_number\n",
    "    topic_name = topic_dict[topic_number]\n",
    "\n",
    "    # Benchmark values (all looked up before `df` is touched)\n",
    "    baseline_topic_recall = bch_class_df.loc[topic_name, 'recall']\n",
    "    baseline_topic_balanced = bch_class_df.loc[topic_name, 'Balanced Accuracy']\n",
    "    baseline_overall_balanced = bch_class_df.loc['accuracy', 'Balanced Accuracy']\n",
    "    baseline_overall_F1 = bch_class_df.loc['accuracy', 'f1-score']\n",
    "\n",
    "    # Improvement of every row over the benchmark: (new column, source column, baseline)\n",
    "    improvement_specs = [\n",
    "        ('imp_topic_recall', 'topic_recall', baseline_topic_recall),\n",
    "        ('imp_topic_balanced_accuracy', 'topic_balanced_accuracy', baseline_topic_balanced),\n",
    "        ('imp_overall_balanced_accuracy', 'overall_balanced_accuracy', baseline_overall_balanced),\n",
    "        ('imp_overall_F1', 'overall_F1', baseline_overall_F1),\n",
    "    ]\n",
    "    for new_col, source_col, baseline in improvement_specs:\n",
    "        df[new_col] = df[source_col] - baseline\n",
    "\n",
    "    # Running total of retraining time\n",
    "    df['cumulative_time'] = df['retraining_time'].cumsum()\n",
    "\n",
    "    # Best-so-far and running average of each improvement: (source, max column, mean column)\n",
    "    summary_specs = [\n",
    "        ('imp_topic_recall', 'max_topic_recall_imp', 'average_topic_recall_imp'),\n",
    "        ('imp_topic_balanced_accuracy', 'max_topic_balanced_acc_imp', 'average_topic_balanced_acc_imp'),\n",
    "        ('imp_overall_balanced_accuracy', 'max_overall_balanced_acc_imp', 'average_overall_balanced_acc_imp'),\n",
    "        ('imp_overall_F1', 'max_overall_F1_improvement', 'average_overall_F1_improvement'),\n",
    "    ]\n",
    "    for imp_col, max_col, avg_col in summary_specs:\n",
    "        df[max_col] = df[imp_col].cummax()\n",
    "        df[avg_col] = df[imp_col].expanding().mean()\n",
    "    return df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def bch_classification_report_to_df(report, y_true, y_pred):\n",
    "    global bch_class_df\n",
    "    global topic_dict\n",
    "    df = pd.DataFrame(report).transpose()\n",
    "\n",
    "    # Calculate the confusion matrix\n",
    "    labels = df.index[:-3]  # Exclude 'accuracy', 'macro avg', 'weighted avg'\n",
    "    # Calculate the confusion matrix\n",
    "    cm = confusion_matrix(y_true, y_pred, labels=labels)\n",
    "\n",
    "    # Extracting TP, FP, TN, FN for each class\n",
    "    TP = cm.diagonal()\n",
    "    FP = cm.sum(axis=0) - TP\n",
    "    FN = cm.sum(axis=1) - TP\n",
    "    TN = cm.sum() - (FP + FN + TP)\n",
    "\n",
    "    sens = sum(TP) / (sum(TP)+sum(FN))\n",
    "    spec = sum(TN) / (sum(TN)+sum(FP))\n",
    "    \n",
    "    # Calculate Sensitivity (same as recall)\n",
    "    df['Sensitivity'] = df['recall']\n",
    "    \n",
    "    # Calculate Specificity\n",
    "    tn = cm.sum() - (cm.sum(axis=0) + cm.sum(axis=1) - np.diag(cm))\n",
    "    fp = cm.sum(axis=0) - np.diag(cm)\n",
    "    specificity = tn / (tn + fp)\n",
    "    \n",
    "    # Assign computed specificity to dataframe except for the last three rows\n",
    "    df.loc[df.index[:-3], 'Specificity'] = specificity\n",
    "    \n",
    "    # Handling special cases\n",
    "    # Set 'accuracy' row sensitivity and specificity to the accuracy value\n",
    "    accuracy = df.loc['accuracy', 'precision']  # assuming 'precision' contains the accuracy\n",
    "    df.loc['accuracy', ['Sensitivity', 'Specificity']] = sens, spec\n",
    "    \n",
    "    # Calculate 'macro avg' and 'weighted avg' for sensitivity and specificity\n",
    "    df.loc['macro avg', 'Sensitivity'] = df.iloc[:-3]['Sensitivity'].mean()\n",
    "    df.loc['weighted avg', 'Sensitivity'] = np.average(df.iloc[:-3]['Sensitivity'], weights=df.iloc[:-3]['support'])\n",
    "    \n",
    "    df.loc['macro avg', 'Specificity'] = df.iloc[:-3]['Specificity'].mean()\n",
    "    df.loc['weighted avg', 'Specificity'] = np.average(df.iloc[:-3]['Specificity'], weights=df.iloc[:-3]['support'])\n",
    "\n",
    "    # Calculate Balanced Accuracy for each row, including special averages\n",
    "    df['Balanced Accuracy'] = (df['Sensitivity'] + df['Specificity']) / 2\n",
    "    \n",
    "    return df\n",
    "\n",
    "def train_bch(X_train_re, X_test_re, Y_train_re, Y_test_re, catboost_params, itr0_path):\n",
    "    global X_test_re_Test\n",
    "    global Y_test_re_Test\n",
    "    CPU_monitor_memory_usage()\n",
    "    monitor_gpu_memory()\n",
    "    bch_dict = {}\n",
    "\n",
    "    train_pool_re = Pool(\n",
    "        X_train_re[[\"text\", \"area_TEIS\"]],\n",
    "        Y_train_re,\n",
    "        text_features=[\"text\"],\n",
    "        cat_features=[\"area_TEIS\"]\n",
    "    )\n",
    "    valid_pool_re = Pool(\n",
    "        X_test_re[[\"text\", \"area_TEIS\"]],\n",
    "        Y_test_re,\n",
    "        text_features=[\"text\"],\n",
    "        cat_features=[\"area_TEIS\"]\n",
    "    )\n",
    "\n",
    "    # Model Training\n",
    "    model_re = CatBoostClassifier(**catboost_params)\n",
    "    start_time = time.time()  # Start timing\n",
    "    model_re.fit(train_pool_re, eval_set=valid_pool_re)\n",
    "    training_time = time.time() - start_time  # End timing\n",
    "\n",
    "    # Save the retrain performances\n",
    "    val_predictions = model_re.predict(X_test_re[[\"text\", \"area_TEIS\"]])\n",
    "    val_accuracy = accuracy_score(Y_test_re, val_predictions)\n",
    "    val_report = classification_report(Y_test_re, val_predictions, digits=3, output_dict=True)\n",
    "    print(val_accuracy)\n",
    "    # print(report)\n",
    "    val_classification_df = bch_classification_report_to_df(val_report, Y_test_re, val_predictions)\n",
    "    # print(classification_df)\n",
    "    val_classification_df.to_pickle(f\"{itr0_path}/Validation_Benchmark_M0_Classdf_0.pkl\")\n",
    "    val_classification_df.to_csv(f\"{itr0_path}/Validation_Benchmark_M0_Classdf_0.csv\", index=True)\n",
    "\n",
    "    # Save the retrain performances\n",
    "    predictions = model_re.predict(X_test_re_Test[[\"text\", \"area_TEIS\"]])\n",
    "    accuracy = accuracy_score(Y_test_re_Test, predictions)\n",
    "    report = classification_report(Y_test_re_Test, predictions, digits=3, output_dict=True)\n",
    "    print(accuracy)\n",
    "    # print(report)\n",
    "    classification_df = bch_classification_report_to_df(report, Y_test_re_Test, predictions)\n",
    "    # print(classification_df)\n",
    "\n",
    "    classification_df.to_pickle(f\"{itr0_path}/Benchmark_M0_Classdf_0.pkl\")\n",
    "    classification_df.to_csv(f\"{itr0_path}/Benchmark_M0_Classdf_0.csv\", index=True)\n",
    "\n",
    "    bch_dict['model'] = model_re\n",
    "    bch_dict['classification_df'] = classification_df\n",
    "    bch_dict['accuracy'] = accuracy\n",
    "    bch_dict['retraining_time'] = training_time\n",
    "\n",
    "    return bch_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if __name__ == \"__main__\":\n",
    "    run = 1\n",
    "    rand = 10\n",
    "    \n",
    "    run_path = f\"D:/Step_2_Pathway/Paper_GPU1h_Improve_T13_TBA/B3_Run_{run}\"\n",
    "    itr0_path = f\"{run_path}/Iteration_0\"\n",
    "    os.makedirs(itr0_path, exist_ok=True)\n",
    "\n",
    "    # Load Data\n",
    "    data = pd.read_csv(f'D:/AutoGeTS/Data/tickets_topics.csv',lineterminator='\\n')\n",
    "    data_topic = data.dropna().reset_index()\n",
    "    data_topic = data_topic.rename(columns={'index': 'index_meta'})\n",
    "\n",
    "    X_train_r_both, X_test_re_Test, Y_train_r_both, Y_test_re_Test = train_test_split(data_topic, data_topic.topic_name, test_size = 0.2,random_state = 42)\n",
    "        \n",
    "    # Further split the training set to create a validation set\n",
    "    X_train_r, X_test_re, Y_train_r, Y_test_re = train_test_split(\n",
    "        X_train_r_both, \n",
    "        Y_train_r_both, \n",
    "        test_size=0.2,  # 20% of the initial training set, which is 16% of the original data\n",
    "        random_state=rand\n",
    "    )\n",
    "\n",
    "    catboost_params = {'iterations': 300, 'learning_rate': 0.2, 'depth': 8, 'l2_leaf_reg': 1, \n",
    "                        'bagging_temperature': 1, 'random_strength': 1, 'border_count': 254, \n",
    "                        'eval_metric': 'TotalF1', 'task_type': 'GPU', 'early_stopping_rounds': 20, 'use_best_model': True, 'verbose': 0, 'random_seed': rand}\n",
    "\n",
    "    bch_dict = train_bch(X_train_r, X_test_re, Y_train_r, Y_test_re, catboost_params, itr0_path)\n",
    "    for iteration in [1, 2, 3, 4, 5]:\n",
    "        # if iteration <= 1:\n",
    "        #     continue\n",
    "        for topic_number in [\"T13\"]: # \"T11\", \"T15\"\n",
    "            prev_itr = iteration - 1\n",
    "\n",
    "            gpu_hours = 1\n",
    "\n",
    "            if topic_number in []:\n",
    "                metric_name = \"overall_balanced_accuracy\" # \"recall\", \"Balanced Accuracy\", \"overall_balanced_accuracy\", \"overall_f1-score\"\n",
    "            elif topic_number in [\"T1\"]:\n",
    "                metric_name = \"overall_f1-score\"\n",
    "            elif topic_number in [\"T11\", \"T12\", \"T13\", \"T14\", \"T15\"]:\n",
    "                metric_name = \"Balanced Accuracy\"\n",
    "\n",
    "            nsgaii_results_path = f\"D:/Step_2_Pathway/Paper_GPU1h_Improve_T13_TBA/B3_Run_{run}/Iteration_{iteration}/{topic_number}_{metric_name}_GPU{gpu_hours}h_GA\"\n",
    "            os.makedirs(nsgaii_results_path, exist_ok=True)\n",
    "\n",
    "            fold_pfs_df = pd.DataFrame() \n",
    "            gen_eval_df = pd.DataFrame() \n",
    "            gen_pfs_df = pd.DataFrame() \n",
    "            test_gen_eval_df = pd.DataFrame() \n",
    "\n",
    "            init_counter = 0\n",
    "            generation_counter = -1\n",
    "            history_IndexesList_dict = {}\n",
    "            history_pareto_selections_list = []\n",
    "            gen_stats_df = pd.DataFrame(columns=[\"Generation\", \"Number of Evaluations\", \"Best Fitness\", \"Worst Fitness\", \"Mean Overall Accuracy\", \"Mean Topic Recall\", \"Pareto Front Selections\"])\n",
    "\n",
    "            if iteration > 1:\n",
    "                bch_class_df = pd.read_pickle(f\"D:/Step_2_Pathway/Paper_GPU1h_Improve_T13_TBA/B3_Run_{run}/Iteration_{prev_itr}/Bch_Itr_{prev_itr}.pkl\")\n",
    "                bch_filtered_columns = [col for col in bch_class_df.columns if not col.startswith(\"Diff\") and col != \"Accuracy\"]\n",
    "                bch_class_df = bch_class_df[bch_filtered_columns]\n",
    "            else:\n",
    "                bch_class_df = pd.read_pickle(f\"D:/Step_2_Pathway/Paper_GPU1h_Improve_T13_TBA/B3_Run_{run}/Iteration_0/Benchmark_M0_Classdf_0.pkl\")\n",
    "\n",
    "            bch_m0 = pd.read_pickle(f\"D:/Step_2_Pathway/Paper_GPU1h_Improve_T13_TBA/B3_Run_{run}/Iteration_0/Benchmark_M0_Classdf_0.pkl\")\n",
    "            \n",
    "            if iteration > 1:\n",
    "                prev_itr_X_train_re = pd.read_pickle(f\"D:/Step_2_Pathway/Paper_GPU1h_Improve_T13_TBA/B3_Run_{run}/Iteration_{prev_itr}/X_train_re_itr_{prev_itr}.pkl\")\n",
    "                X_train_r = prev_itr_X_train_re\n",
    "                prev_itr_Y_train_re = pd.read_pickle(f\"D:/Step_2_Pathway/Paper_GPU1h_Improve_T13_TBA/B3_Run_{run}/Iteration_{prev_itr}/Y_train_re_itr_{prev_itr}.pkl\")\n",
    "                Y_train_r = prev_itr_Y_train_re\n",
    "\n",
    "            topic_dict = {\"T1\": \"IT support and assistance.\",\"T2\": \"Account activation and access issues.\",\"T3\": \"Password and device security.\",\n",
    "                    \"T4\": \"Printer issues and troubleshooting.\",\"T5\": \"HP Dock connectivity issues.\",\"T6\": \"Employee documentation and errors.\",\n",
    "                    \"T7\": \"\\\"Access and login issues\\\"\",\"T8\": \"Opening and managing files/devices.\",\"T9\": \"Mobile email and VPN setup.\",\n",
    "                    \"T10\": \"IT support and communication.\",\"T11\": \"Error handling in RPG programming.\",\"T12\": \"Email security and attachments.\",\n",
    "                    \"T13\": \"Humanitarian aid for Ukraine.\",\"T14\": \"Internet connectivity issues in offices.\",\"T15\": \"Improving integration with Infojobs.\", \"Acc\": \"accuracy\"}\n",
    "            \n",
    "            low_recall_topics_order = [[\"Acc\", \"T13\"], [\"T15\", \"T12\", \"T10\", \"T14\"], [\"T6\", \"T8\", \"T9\", \"T7\", \"T5\", \"T4\"], [\"T3\", \"T2\", \"T11\", \"T1\"]]\n",
    "            large_size_topics_order = [[\"Acc\"], [\"T2\", \"T1\", \"T3\"], [\"T5\", \"T7\", \"T6\", \"T10\", \"T4\", \"T9\", \"T8\"], [\"T14\", \"T15\", \"T11\", \"T12\", \"T13\"]]\n",
    "            topic_group_probabilities = [0.4, 0.3, 0.2, 0.1]\n",
    "            topic_name_cases_order = [[topic_dict[topic_number] for topic_number in sublist] for sublist in low_recall_topics_order]\n",
    "            topic_name_sizes_order = [[topic_dict[topic_number] for topic_number in sublist] for sublist in large_size_topics_order]\n",
    "\n",
    "\n",
    "            \"\"\"Section below are changable parameters/inputs----------------------------\"\"\"\n",
    "            # topic_number = \"T13\"\n",
    "            syn_number = 1\n",
    "\n",
    "            total_gpu_seconds = gpu_hours * 60 * 60\n",
    "            \n",
    "            # Added synthetic data path\n",
    "            if topic_number in [\"T1\", \"T2\"]:\n",
    "                data_syn_raw = pd.read_pickle(f'D:/AutoGeTS/Synthetic_Data/{topic_number}-synthesis-{syn_number}.pkl')\n",
    "            else:\n",
    "                data_syn_raw = pd.read_csv(f'D:/AutoGeTS/Synthetic_Data/{topic_number}-synthesis-{syn_number}.csv',lineterminator='\\n')\n",
    "            data_syn = data_syn_raw[[\"index_meta\", \"text\", \"area_TEIS\", 'topic_name', \"sample\"]].dropna()\n",
    "\n",
    "            Xmode = \"\" # \"\", \"Xnp\"\n",
    "            Smode = \"Sovl\"\n",
    "            population_size = 20\n",
    "            num_selected = 20\n",
    "            max_generations = 15\n",
    "            tournament_size = 3\n",
    "            crossover_rate = 0.7\n",
    "            initial_mutation_rate = 0.3\n",
    "            maximize=True\n",
    "\n",
    "            NSGA_II_results_name = f\"{topic_number}_{Xmode}_LexClS-{Smode}_PopSize{population_size}_NumSel{num_selected}_MaxGen{max_generations}_CR{crossover_rate}_MR{initial_mutation_rate}\"\n",
    "            gen_stats_df_name = f\"NSGAII-Gen-Stats_{topic_number}_{Xmode}_LexClS-{Smode}_PopSize{population_size}_NumSel{num_selected}_MaxGen{max_generations}_CR{crossover_rate}_MR{initial_mutation_rate}\"\n",
    "            history_dict_name = f\"NSGAII-Retrain-Dict_{topic_number}_{Xmode}_LexClS-{Smode}_PopSize{population_size}_NumSel{num_selected}_MaxGen{max_generations}_CR{crossover_rate}_MR{initial_mutation_rate}\"\n",
    "            final_pop_name = f\"NSGAII-FinalPop_{topic_number}_{Xmode}_LexClS-{Smode}_PopSize{population_size}_NumSel{num_selected}_MaxGen{max_generations}_CR{crossover_rate}_MR{initial_mutation_rate}\"\n",
    "            history_pareto_lists_name = f\"NSGAII-HistoryPareto-List_{topic_number}_{Xmode}_LexClS-{Smode}_PopSize{population_size}_NumSel{num_selected}_MaxGen{max_generations}_CR{crossover_rate}_MR{initial_mutation_rate}\"\n",
    "            \n",
    "            # catboost_params = {'iterations': 300, 'learning_rate': 0.2, 'depth': 6, 'l2_leaf_reg': 1, \n",
    "            #                    'bagging_temperature': 1, 'random_strength': 1, 'border_count': 254, \n",
    "            #                    'eval_metric': 'TotalF1', 'task_type': 'GPU', 'early_stopping_rounds': 20, 'use_best_model': True, 'verbose': 1, 'random_seed': 0}\n",
    "\n",
    "            # catboost_params = {'iterations': 300, 'learning_rate': 0.2, 'depth': 8, 'l2_leaf_reg': 3, \n",
    "            #                    'bagging_temperature': 1, 'random_strength': 1, 'border_count': 254, \n",
    "            #                    'eval_metric': 'TotalF1', 'task_type': 'GPU', 'early_stopping_rounds': 20, 'use_best_model': True, 'verbose': 1, 'random_seed': 0}\n",
    "\n",
    "            # catboost_params = {'iterations': 300, 'learning_rate': 0.5, 'depth': 6, 'l2_leaf_reg': 10, \n",
    "            #                    'bagging_temperature': 1, 'random_strength': 1, 'border_count': 254, \n",
    "            #                    'eval_metric': 'TotalF1', 'task_type': 'GPU', 'early_stopping_rounds': 20, 'use_best_model': True, 'verbose': 1, 'random_seed': 0}\n",
    "\n",
    "            \"\"\"-----------------------------------\"\"\"\n",
    "\n",
    "            topic_name = topic_dict[topic_number]\n",
    "            clean_topic_name = clean_folder_name(topic_name)\n",
    "\n",
    "            sum_GPU_seconds = 0\n",
    "            GPU_limit = False\n",
    "\n",
    "            class_data_pool = X_train_r[X_train_r['topic_name'] == topic_name]\n",
    "            class_index_meta_pool = class_data_pool['index_meta'].tolist()\n",
    "\n",
    "            args={\"data_syn\": data_syn,\n",
    "                \"max_generations\": max_generations,\n",
    "                \"num_selected\": num_selected,  # Or another suitable size\n",
    "                \"crossover_rate\": crossover_rate,\n",
    "                \"initial_mutation_rate\": initial_mutation_rate,\n",
    "                \"class_index_meta_pool\": class_index_meta_pool,\n",
    "                    }\n",
    "\n",
    "\n",
    "            final_pop = run_nsga2(args, population_size=population_size, maximize=maximize, max_generations=max_generations, num_selected=num_selected, seed=42)\n",
    "            if GPU_limit == True:\n",
    "                gen_eval_df  = find_pareto_front(gen_eval_df)\n",
    "                gen_eval_df = find_best_values(gen_eval_df)\n",
    "                gen_eval_df = post_process(gen_eval_df, bch_class_df)\n",
    "                gen_eval_df.to_csv(f'{nsgaii_results_path}/GenAllEvals_{NSGA_II_results_name}.csv', index=True)\n",
    "                gen_eval_df.to_pickle(f'{nsgaii_results_path}/GenAllEvals_{NSGA_II_results_name}.pkl')\n",
    "                # break\n",
    "                test_gen_eval_df  = find_pareto_front(test_gen_eval_df)\n",
    "                test_gen_eval_df = find_best_values(test_gen_eval_df)\n",
    "                test_gen_eval_df = post_process(test_gen_eval_df, bch_class_df)\n",
    "                test_gen_eval_df.to_csv(f'{nsgaii_results_path}/test_GenAllEvals_{NSGA_II_results_name}.csv', index=True)\n",
    "                test_gen_eval_df.to_pickle(f'{nsgaii_results_path}/test_GenAllEvals_{NSGA_II_results_name}.pkl')\n",
    "                # break\n",
    "            \n",
    "            \"\"\"Extract best model and append synthetics\"\"\"\n",
    "            # Find the index of the row with the largest value in the 'max_overall_balanced_acc_imp' column\n",
    "            index_of_max_imp = test_gen_eval_df['imp_topic_balanced_accuracy'].idxmax()\n",
    "            print(index_of_max_imp)\n",
    "\n",
    "            # Retrieve the row corresponding to this index\n",
    "            row_with_largest_value = test_gen_eval_df.loc[index_of_max_imp]\n",
    "\n",
    "            filtered_syn_df = data_syn[data_syn['index_meta'].isin(row_with_largest_value['retrained_dots_list'])]\n",
    "\n",
    "            X_train_re = pd.concat([X_train_r, filtered_syn_df.drop(columns=['topic_name'])])\n",
    "            Y_train_re = pd.concat([Y_train_r, filtered_syn_df['topic_name']])\n",
    "\n",
    "            train_pool_re = Pool(\n",
    "                X_train_re[[\"text\", \"area_TEIS\"]],\n",
    "                Y_train_re,\n",
    "                text_features=[\"text\"],\n",
    "                cat_features=[\"area_TEIS\"]\n",
    "            )\n",
    "            valid_pool_re = Pool(\n",
    "                X_test_re[[\"text\", \"area_TEIS\"]],\n",
    "                Y_test_re,\n",
    "                text_features=[\"text\"],\n",
    "                cat_features=[\"area_TEIS\"]\n",
    "            )\n",
    "\n",
    "            catboost_params = catboost_params\n",
    "                        \n",
    "            # Model Training\n",
    "            model_re = CatBoostClassifier(**catboost_params)\n",
    "            # start_time = time.time()  # Start timing\n",
    "            model_re.fit(train_pool_re, eval_set=valid_pool_re)\n",
    "            # training_time = time.time() - start_time  # End timing\n",
    "\n",
    "            # Save the retrain performances\n",
    "            predictions = model_re.predict(X_test_re_Test[[\"text\", \"area_TEIS\"]])\n",
    "            accuracy = accuracy_score(Y_test_re_Test, predictions)\n",
    "            report = classification_report(Y_test_re_Test, predictions, digits=6, output_dict=True)\n",
    "            classification_df = classification_report_to_df(report, Y_test_re_Test, predictions)\n",
    "\n",
    "            print(classification_df)\n",
    "\n",
    "            iteration_repo = f\"D:/Step_2_Pathway/Paper_GPU1h_Improve_T13_TBA/B3_Run_{run}/Iteration_{iteration}/\"\n",
    "            os.makedirs(iteration_repo, exist_ok=True)\n",
    "            if classification_df.loc[topic_name, 'Diff Balanced Accuracy'] >= 0:\n",
    "                classification_df.to_csv(f\"{iteration_repo}/Bch_Itr_{iteration}.csv\", index=True)\n",
    "                classification_df.to_pickle(f\"{iteration_repo}/Bch_Itr_{iteration}.pkl\")\n",
    "\n",
    "                X_train_re.to_pickle(f\"{iteration_repo}/X_train_re_itr_{iteration}.pkl\")\n",
    "                Y_train_re.to_pickle(f\"{iteration_repo}/Y_train_re_itr_{iteration}.pkl\")\n",
    "            else:\n",
    "                iteration_noimprove_repo = f\"D:/Step_2_Pathway/Paper_GPU1h_Improve_T13_TBA/B3_Run_{run}/Iteration_{iteration}/Itr_No_Improve\"\n",
    "                os.makedirs(iteration_noimprove_repo, exist_ok=True)\n",
    "                classification_df.to_csv(f\"{iteration_noimprove_repo}/Bch_Itr_{iteration}.csv\", index=True)\n",
    "                classification_df.to_pickle(f\"{iteration_noimprove_repo}/Bch_Itr_{iteration}.pkl\")\n",
    "\n",
    "                X_train_re.to_pickle(f\"{iteration_noimprove_repo}/X_train_re_itr_{iteration}.pkl\")\n",
    "                Y_train_re.to_pickle(f\"{iteration_noimprove_repo}/Y_train_re_itr_{iteration}.pkl\")\n",
    "\n",
    "                bch_class_df.to_csv(f\"{iteration_repo}/Bch_Itr_{iteration}.csv\", index=True)\n",
    "                bch_class_df.to_pickle(f\"{iteration_repo}/Bch_Itr_{iteration}.pkl\")\n",
    "\n",
    "                prev_itr_X_train_re.to_pickle(f\"{iteration_repo}/X_train_re_itr_{iteration}.pkl\")\n",
    "                prev_itr_Y_train_re.to_pickle(f\"{iteration_repo}/Y_train_re_itr_{iteration}.pkl\")\n",
    "            # with open(f'{nsgaii_results_path}/{history_dict_name}.pkl', 'wb') as file:\n",
    "            #     pickle.dump(history_IndexesList_dict, file)\n",
    "\n",
    "            # with open(f'{nsgaii_results_path}/{final_pop_name}.pkl', 'wb') as file:\n",
    "            #     pickle.dump(final_pop, file)\n",
    "\n",
    "            # with open(f'{nsgaii_results_path}/{history_pareto_lists_name}.pkl', 'wb') as file:\n",
    "            #     pickle.dump(history_pareto_selections_list, file)\n",
    "            \n",
    "            # gen_stats_df = all_pareto_observer(history_pareto_selections_list)\n",
    "            "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(gen_eval_df.loc[0, 'overall_balanced_accuracy'])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
