{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "3a9f5c21",
   "metadata": {},
   "source": [
    "# Wavelet-KAN contrastive embedding with cGAN class balancing\n",
    "\n",
    "Pipeline:\n",
    "1. Load `clean_data_O1.csv`, min-max scale the numeric columns and top up every `ml_label` class with samples from a conditional GAN.\n",
    "2. Train a Conv1d encoder with a wavelet-KAN projection head on pairs of noise-augmented views, using a Sinkhorn-augmented contrastive loss.\n",
    "3. After each epoch, evaluate with weighted kNN (top-1/2/3 accuracy, precision, recall, F1, d-index) and write the per-epoch metrics to CSV.\n",
    "\n",
    "Requires `geomloss` in addition to PyTorch, scikit-learn, pandas and tqdm.\n",
    "\n",
    "**Caveat:** the kNN evaluation uses the training rows (including the GAN-generated ones) as both memory bank and test set, so the reported metrics are not a held-out evaluation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7d177a43",
   "metadata": {
    "id": "7d177a43"
   },
   "outputs": [],
   "source": [
    "from IPython.display import display, HTML\n",
    "\n",
    "# Widen the notebook container to the full browser width\n",
    "display(HTML(\"<style>.container { width:100% !important; }</style>\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7d506fbf",
   "metadata": {
    "id": "7d506fbf"
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import math\n",
    "\n",
    "class KANLinear(nn.Module):\n",
    "    \"\"\"Linear layer with an extra learnable wavelet branch (Wav-KAN style).\n",
    "\n",
    "    forward(x) = BatchNorm1d(wavelet_transform(x) + F.linear(x, weight1))\n",
    "\n",
    "    Args:\n",
    "        in_features: size of each input sample.\n",
    "        out_features: size of each output sample.\n",
    "        wavelet_type: mother wavelet used by wavelet_transform; one of 'haar',\n",
    "            'coiflet', 'biorthogonal', 'daubechies', 'mexican_hat', 'morlet',\n",
    "            'dog', 'meyer', 'symlet4', 'shannon'.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, in_features, out_features, wavelet_type='haar'):\n",
    "        super(KANLinear, self).__init__()\n",
    "        self.in_features = in_features\n",
    "        self.out_features = out_features\n",
    "        self.wavelet_type = wavelet_type\n",
    "\n",
    "        # One scale, translation and mixing weight per (output, input) pair\n",
    "        self.scale = nn.Parameter(torch.ones(out_features, in_features))\n",
    "        self.translation = nn.Parameter(torch.zeros(out_features, in_features))\n",
    "        self.wavelet_weights = nn.Parameter(torch.Tensor(out_features, in_features))\n",
    "        # Weight of the plain linear branch\n",
    "        self.weight1 = nn.Parameter(torch.Tensor(out_features, in_features))\n",
    "\n",
    "        nn.init.kaiming_uniform_(self.wavelet_weights, a=math.sqrt(5))\n",
    "        nn.init.kaiming_uniform_(self.weight1, a=math.sqrt(5))\n",
    "\n",
    "        # Batch normalization of the combined output\n",
    "        self.bn = nn.BatchNorm1d(out_features)\n",
    "\n",
    "    def wavelet_transform(self, x):\n",
    "        \"\"\"Evaluate the wavelet at (x - translation) / scale and sum the weighted responses over inputs.\n",
    "\n",
    "        Args:\n",
    "            x: tensor of shape (batch, in_features).\n",
    "\n",
    "        Returns:\n",
    "            Tensor of shape (batch, out_features).\n",
    "\n",
    "        Raises:\n",
    "            ValueError: if ``self.wavelet_type`` is not a supported wavelet.\n",
    "        \"\"\"\n",
    "        # (batch, in) -> (batch, 1, in) so it broadcasts against the (batch, out, in) parameters\n",
    "        x_expanded = x.unsqueeze(1) if x.dim() == 2 else x\n",
    "        translation_expanded = self.translation.unsqueeze(0).expand(x.size(0), -1, -1)\n",
    "        scale_expanded = self.scale.unsqueeze(0).expand(x.size(0), -1, -1)\n",
    "        x_scaled = (x_expanded - translation_expanded) / scale_expanded\n",
    "\n",
    "        if self.wavelet_type == 'haar':\n",
    "            wavelet = self.haar_wavelet(x_scaled)\n",
    "        elif self.wavelet_type == 'coiflet':\n",
    "            wavelet = self.coiflet_wavelet(x_scaled)\n",
    "        elif self.wavelet_type == 'biorthogonal':\n",
    "            wavelet = self.biorthogonal_wavelet(x_scaled)\n",
    "        elif self.wavelet_type == 'daubechies':\n",
    "            wavelet = self.daubechies_wavelet(x_scaled)\n",
    "        elif self.wavelet_type == 'mexican_hat':\n",
    "            term1 = ((x_scaled ** 2)-1)\n",
    "            term2 = torch.exp(-0.5 * x_scaled ** 2)\n",
    "            wavelet = (2 / (math.sqrt(3) * math.pi**0.25)) * term1 * term2\n",
    "        elif self.wavelet_type == 'morlet':\n",
    "            omega0 = 5.0  # Central frequency\n",
    "            real = torch.cos(omega0 * x_scaled)\n",
    "            envelope = torch.exp(-0.5 * x_scaled ** 2)\n",
    "            wavelet = envelope * real\n",
    "        elif self.wavelet_type == 'dog':\n",
    "            dog = -x_scaled * torch.exp(-0.5 * x_scaled ** 2)\n",
    "            wavelet = dog\n",
    "        elif self.wavelet_type == 'meyer':\n",
    "            pi = math.pi\n",
    "            v = torch.abs(x_scaled)\n",
    "            wavelet = torch.sin(pi * v) * self.meyer_aux(v)\n",
    "        elif self.wavelet_type == 'symlet4':\n",
    "            wavelet = self.symlet4_wavelet(x_scaled)\n",
    "        elif self.wavelet_type == 'shannon':\n",
    "            pi = math.pi\n",
    "            sinc = torch.sinc(x_scaled / pi)\n",
    "            # Hamming window taken over the input dimension\n",
    "            window = torch.hamming_window(x_scaled.size(-1), periodic=False, dtype=x_scaled.dtype, device=x_scaled.device)\n",
    "            wavelet = sinc * window\n",
    "        else:\n",
    "            raise ValueError(\"Unsupported wavelet type\")\n",
    "\n",
    "        # Weight every (output, input) response and sum over the input dimension\n",
    "        wavelet_weighted = wavelet * self.wavelet_weights.unsqueeze(0).expand_as(wavelet)\n",
    "        wavelet_output = wavelet_weighted.sum(dim=2)\n",
    "        return wavelet_output\n",
    "\n",
    "    def haar_wavelet(self, x):\n",
    "        \"\"\"Haar wavelet: +1 on [0, 0.5), -1 on [0.5, 1), 0 elsewhere.\"\"\"\n",
    "        # Both lobes are bounded below too, so negative inputs map to 0\n",
    "        positive_lobe = (x >= 0) & (x < 0.5)\n",
    "        negative_lobe = (x >= 0.5) & (x < 1.)\n",
    "        return torch.where(positive_lobe, 1., torch.where(negative_lobe, -1., 0.))\n",
    "\n",
    "    def coiflet_wavelet(self, x):\n",
    "        \"\"\"Polynomial sum_i c[i] * x**i with the coefficients listed below.\"\"\"\n",
    "        c = [0.038580, -0.126969, -0.077974, 0.417845, 0.812866, 0.468429]\n",
    "        return sum(c[i] * x**i for i in range(len(c)))\n",
    "\n",
    "    def biorthogonal_wavelet(self, x):\n",
    "        \"\"\"Smoothstep polynomial 3x^2 - 2x^3.\"\"\"\n",
    "        return 3*x**2 - 2*x**3\n",
    "\n",
    "    def symlet4_wavelet(self, x):\n",
    "        \"\"\"Polynomial sum_i coeffs[i] * x**i using the sym4 filter coefficients.\"\"\"\n",
    "        # Symlet 4 (sym4) wavelet coefficients\n",
    "        coeffs = [\n",
    "            -0.075765714789273,\n",
    "            0.029635527645999,\n",
    "            0.497618667632015,\n",
    "            0.803738751805916,\n",
    "            0.297857795605542,\n",
    "            -0.099219543576847,\n",
    "            -0.012603967262038,\n",
    "            0.032223100604042\n",
    "        ]\n",
    "\n",
    "        # Polynomial evaluation in x (not a convolution) with the coefficients above\n",
    "        return sum(coeffs[i] * x**i for i in range(len(coeffs)))\n",
    "\n",
    "    def daubechies_wavelet(self, x):\n",
    "        \"\"\"Cubic polynomial in x built from the four Daubechies-4 (db2) filter coefficients.\"\"\"\n",
    "        sqrt3 = math.sqrt(3)\n",
    "        coeff1 = (1 + sqrt3) / (4 * math.sqrt(2))\n",
    "        coeff2 = (3 + sqrt3) / (4 * math.sqrt(2))\n",
    "        coeff3 = (3 - sqrt3) / (4 * math.sqrt(2))\n",
    "        coeff4 = (1 - sqrt3) / (4 * math.sqrt(2))\n",
    "        return coeff1 - coeff2 * x + coeff3 * x**2 - coeff4 * x**3\n",
    "\n",
    "    def meyer_aux(self, v):\n",
    "        \"\"\"Meyer window: 1 for v <= 1/2, 0 for v >= 1, cos(pi/2 * nu(2v - 1)) in between.\"\"\"\n",
    "        pi = math.pi\n",
    "        return torch.where(v <= 1/2, torch.ones_like(v), torch.where(v >= 1, torch.zeros_like(v), torch.cos(pi / 2 * self.nu(2 * v - 1))))\n",
    "\n",
    "    def nu(self, t):\n",
    "        \"\"\"Meyer smoothing polynomial t^4 * (35 - 84t + 70t^2 - 20t^3).\"\"\"\n",
    "        return t**4 * (35 - 84*t + 70*t**2 - 20*t**3)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return BatchNorm1d(wavelet branch + linear branch) for x of shape (batch, in_features).\"\"\"\n",
    "        wavelet_output = self.wavelet_transform(x)\n",
    "        base_output = F.linear(x, self.weight1)\n",
    "        combined_output = wavelet_output + base_output\n",
    "        return self.bn(combined_output)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2eb7b10c",
   "metadata": {
    "id": "2eb7b10c"
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import DataLoader, TensorDataset\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from sklearn.preprocessing import StandardScaler, RobustScaler, MaxAbsScaler, PowerTransformer, MinMaxScaler, LabelEncoder\n",
    "\n",
    "class Generator(nn.Module):\n",
    "    \"\"\"Conditional generator: (noise z, class label) -> feature vector in (-1, 1) via Tanh.\"\"\"\n",
    "\n",
    "    def __init__(self, latent_dim, num_features, num_classes):\n",
    "        super(Generator, self).__init__()\n",
    "        self.label_emb = nn.Embedding(num_classes, num_classes)\n",
    "        self.model = nn.Sequential(\n",
    "            nn.Linear(latent_dim + num_classes, 128),\n",
    "            nn.ReLU(True),\n",
    "            nn.Linear(128, num_features),\n",
    "            nn.Tanh()\n",
    "        )\n",
    "\n",
    "    def forward(self, z, labels):\n",
    "        c = self.label_emb(labels)\n",
    "        x = torch.cat([z, c], 1)\n",
    "        output = self.model(x)\n",
    "        return output\n",
    "\n",
    "class Discriminator(nn.Module):\n",
    "    \"\"\"Conditional discriminator: (features, class label) -> probability that the pair is real (Sigmoid).\"\"\"\n",
    "\n",
    "    def __init__(self, num_features, num_classes):\n",
    "        super(Discriminator, self).__init__()\n",
    "        self.label_embedding = nn.Embedding(num_classes, num_classes)\n",
    "        self.model = nn.Sequential(\n",
    "            nn.Linear(num_features + num_classes, 128),\n",
    "            nn.ReLU(True),\n",
    "            nn.Linear(128, 1),\n",
    "            nn.Sigmoid()\n",
    "        )\n",
    "\n",
    "    def forward(self, x, labels):\n",
    "        c = self.label_embedding(labels)\n",
    "        x = torch.cat([x, c], 1)\n",
    "        output = self.model(x)\n",
    "        return output\n",
    "\n",
    "def train_cgan(generator, discriminator, data_loader, device, num_epochs, latent_dim, num_classes):\n",
    "    \"\"\"Train the conditional GAN in place (Adam, BCE loss), printing the last batch's losses each epoch.\"\"\"\n",
    "    optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))\n",
    "    optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))\n",
    "    criterion = nn.BCELoss()\n",
    "\n",
    "    generator.train()\n",
    "    discriminator.train()\n",
    "\n",
    "    for epoch in range(num_epochs):\n",
    "        for data, labels in data_loader:\n",
    "            batch_size = data.size(0)\n",
    "\n",
    "            real_data = data.to(device)\n",
    "            real_labels = labels.to(device)\n",
    "            valid = torch.ones(batch_size, 1, device=device)\n",
    "            fake = torch.zeros(batch_size, 1, device=device)\n",
    "\n",
    "            # Train Discriminator\n",
    "            optimizer_D.zero_grad()\n",
    "            real_loss = criterion(discriminator(real_data, real_labels), valid)\n",
    "            z = torch.randn(batch_size, latent_dim, device=device)\n",
    "            gen_labels = torch.randint(0, num_classes, (batch_size,), device=device)\n",
    "            # detach(): the discriminator update must not backpropagate into the generator\n",
    "            fake_data = generator(z, gen_labels).detach()\n",
    "            fake_loss = criterion(discriminator(fake_data, gen_labels), fake)\n",
    "            d_loss = (real_loss + fake_loss) / 2\n",
    "            d_loss.backward()\n",
    "            optimizer_D.step()\n",
    "\n",
    "            # Train Generator\n",
    "            optimizer_G.zero_grad()\n",
    "            z = torch.randn(batch_size, latent_dim, device=device)\n",
    "            g_loss = criterion(discriminator(generator(z, gen_labels), gen_labels), valid)\n",
    "            g_loss.backward()\n",
    "            optimizer_G.step()\n",
    "\n",
    "        print(f'Epoch [{epoch+1}/{num_epochs}] Discriminator Loss: {d_loss.item()}, Generator Loss: {g_loss.item()}')\n",
    "\n",
    "def generate_synthetic_data(generator, num_samples, label, num_features, latent_dim, device):\n",
    "    \"\"\"Sample ``num_samples`` generator outputs for class ``label`` as a NumPy array.\n",
    "\n",
    "    ``num_features`` is unused; it is kept so existing calls keep working.\n",
    "    \"\"\"\n",
    "    z = torch.randn(num_samples, latent_dim, device=device)\n",
    "    labels = torch.full((num_samples,), label, dtype=torch.long, device=device)\n",
    "    with torch.no_grad():\n",
    "        synthetic_data = generator(z, labels).cpu().numpy()\n",
    "    return synthetic_data\n",
    "\n",
    "def load_and_preprocess_data(path, target_class_count=1500, latent_dim=100, num_epochs=100, batch_size=32):\n",
    "    \"\"\"Load ``path``, min-max scale the numeric columns and top up small classes with cGAN samples.\n",
    "\n",
    "    Every class with fewer than ``target_class_count`` rows receives generated rows until it\n",
    "    reaches that count; larger classes are left unchanged.\n",
    "\n",
    "    Returns:\n",
    "        (features, labels): float array of shape (n_rows, n_features) and the integer-encoded\n",
    "        ``ml_label`` values, with the synthetic rows appended after the real ones.\n",
    "    \"\"\"\n",
    "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "    df = pd.read_csv(path)\n",
    "    # NOTE(review): many of these columns share names with what look like class names\n",
    "    # (e.g. 'Blip', 'Whistle', 'None_of_the_Above'); if ml_label is derived from them,\n",
    "    # the label leaks into the features -- confirm.\n",
    "    numerical_columns = [\n",
    "        'peak_time', 'peak_time_ns', 'start_time', 'start_time_ns', 'duration',\n",
    "        'peak_frequency', 'central_freq', 'bandwidth', 'amplitude', 'snr', 'q_value',\n",
    "        '1400Ripples', '1080Lines', 'Air_Compressor', 'Blip', 'Chirp', 'Extremely_Loud',\n",
    "        'Helix', 'Koi_Fish', 'Light_Modulation', 'Low_Frequency_Burst', 'Low_Frequency_Lines',\n",
    "        'No_Glitch', 'None_of_the_Above', 'Paired_Doves', 'Power_Line', 'Repeating_Blips',\n",
    "        'Scattered_Light', 'Scratchy', 'Tomte', 'Violin_Mode', 'Wandering_Line', 'Whistle'\n",
    "    ]\n",
    "\n",
    "    df_numerical = df[numerical_columns]\n",
    "    scaler = MinMaxScaler()\n",
    "    df_numerical = scaler.fit_transform(df_numerical)\n",
    "\n",
    "    string_labels = df['ml_label']\n",
    "    label_encoder = LabelEncoder()\n",
    "    labels = label_encoder.fit_transform(string_labels)\n",
    "\n",
    "    num_features = df_numerical.shape[1]\n",
    "    num_classes = np.unique(labels).size\n",
    "\n",
    "    tensor_data = torch.tensor(df_numerical.astype(np.float32))\n",
    "    tensor_labels = torch.tensor(labels)\n",
    "    dataset = TensorDataset(tensor_data, tensor_labels)\n",
    "    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)\n",
    "\n",
    "    generator = Generator(latent_dim, num_features, num_classes).to(device)\n",
    "    discriminator = Discriminator(num_features, num_classes).to(device)\n",
    "\n",
    "    train_cgan(generator, discriminator, data_loader, device, num_epochs, latent_dim, num_classes)\n",
    "\n",
    "    # NOTE(review): real rows are scaled to [0, 1], while the generator's Tanh output spans (-1, 1).\n",
    "    # Collect the blocks and stack once instead of re-stacking the whole array per class.\n",
    "    feature_blocks, label_blocks = [df_numerical], [labels]\n",
    "    for class_label in range(num_classes):\n",
    "        current_count = np.sum(labels == class_label)\n",
    "        if current_count < target_class_count:\n",
    "            num_samples_needed = target_class_count - current_count\n",
    "            feature_blocks.append(generate_synthetic_data(generator, num_samples_needed, class_label, num_features, latent_dim, device))\n",
    "            label_blocks.append(np.full(num_samples_needed, class_label))\n",
    "\n",
    "    return np.vstack(feature_blocks), np.concatenate(label_blocks)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "92eff356",
   "metadata": {
    "id": "92eff356"
   },
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch.utils.data import DataLoader, Dataset\n",
    "from sklearn.preprocessing import LabelEncoder\n",
    "from sklearn.preprocessing import StandardScaler, MinMaxScaler, MaxAbsScaler, PowerTransformer, RobustScaler\n",
    "from tqdm import tqdm\n",
    "np.random.seed(0)\n",
    "torch.manual_seed(0)  # the cGAN and the contrastive model are stochastic too\n",
    "\n",
    "def augment_data(data, noise_level=0.1):\n",
    "    \"\"\"Return ``data`` plus i.i.d. Gaussian noise with standard deviation ``noise_level``.\"\"\"\n",
    "    noise = np.random.normal(0, noise_level, data.shape)\n",
    "    return data + noise\n",
    "\n",
    "def create_pairs(data):\n",
    "    \"\"\"For every row build (augmented view 1, augmented view 2, original row).\n",
    "\n",
    "    For a 2-D ``data`` array the result has shape (n_rows, 3, n_features).\n",
    "    \"\"\"\n",
    "    augmented_original_pairs = []\n",
    "    for row in data:\n",
    "        augmented_row1 = augment_data(row)\n",
    "        augmented_row2 = augment_data(row)\n",
    "        augmented_original_pairs.append((augmented_row1, augmented_row2, row))\n",
    "    return np.array(augmented_original_pairs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c8ce8478",
   "metadata": {
    "id": "c8ce8478",
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch.utils.data import DataLoader, Dataset\n",
    "from tqdm import tqdm\n",
    "\n",
    "class TrainDataset(Dataset):\n",
    "    \"\"\"Yields the two augmented views (pos_1, pos_2) of each row as float tensors.\"\"\"\n",
    "\n",
    "    def __init__(self, data):\n",
    "        # data: dict with equal-length 'pos_1' and 'pos_2' sequences\n",
    "        self.data = data\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.data['pos_1'])\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        pos_1 = self.data['pos_1'][idx]\n",
    "        pos_2 = self.data['pos_2'][idx]\n",
    "        return torch.tensor(pos_1, dtype=torch.float), torch.tensor(pos_2, dtype=torch.float)\n",
    "\n",
    "class MemoryDataset(Dataset):\n",
    "    \"\"\"Yields (row, label) pairs used to build the kNN memory bank.\"\"\"\n",
    "\n",
    "    def __init__(self, data):\n",
    "        # data: dict with equal-length 'true_data' and 'target' sequences\n",
    "        self.data = data\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.data['true_data'])\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        true_data = self.data['true_data'][idx]\n",
    "        target = self.data['target'][idx]\n",
    "        return torch.tensor(true_data, dtype=torch.float), torch.tensor(target, dtype=torch.long)\n",
    "\n",
    "class TestDataset(Dataset):\n",
    "    \"\"\"Yields (row, label) pairs to be classified by kNN.\"\"\"\n",
    "\n",
    "    def __init__(self, data):\n",
    "        # data: dict with equal-length 'true_data' and 'target' sequences\n",
    "        self.data = data\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.data['true_data'])\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        raw_data = self.data['true_data'][idx]\n",
    "        target = self.data['target'][idx]\n",
    "        return torch.tensor(raw_data, dtype=torch.float), torch.tensor(target, dtype=torch.long)\n",
    "\n",
    "class NumericalModel(nn.Module):\n",
    "    \"\"\"Conv1d feature extractor followed by a two-layer wavelet-KAN projection head.\n",
    "\n",
    "    Each input row of length ``input_dim`` is treated as a one-channel 1-D signal.\n",
    "    forward(x) returns (L2-normalized 512-d features, L2-normalized ``feature_dim``-d projection).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, input_dim, feature_dim=128, wavelet_type='mexican_hat'):\n",
    "        super(NumericalModel, self).__init__()\n",
    "\n",
    "        self.channels = 1  # Number of channels in the input\n",
    "        self.length = input_dim  # Length of the sequence\n",
    "\n",
    "        # Length-preserving convolutions; the two max-pools floor-halve the length twice, hence length // 4\n",
    "        self.feature_extractor = nn.Sequential(\n",
    "            nn.Conv1d(self.channels, 256, kernel_size=5, stride=1, padding=2),\n",
    "            nn.ReLU(),\n",
    "            nn.MaxPool1d(kernel_size=2, stride=2),\n",
    "            nn.Conv1d(256, 256, kernel_size=5, stride=1, padding=2),\n",
    "            nn.ReLU(),\n",
    "            nn.MaxPool1d(kernel_size=2, stride=2),\n",
    "            nn.Flatten(),\n",
    "            nn.Linear(256 * (self.length // 4), 512),\n",
    "            nn.ReLU()\n",
    "        )\n",
    "\n",
    "        # Projection head built from wavelet-based KANLinear layers\n",
    "        self.projection_head = nn.Sequential(\n",
    "            KANLinear(512, 256, wavelet_type=wavelet_type),\n",
    "            nn.BatchNorm1d(256),\n",
    "            nn.ReLU(inplace=True),\n",
    "            KANLinear(256, feature_dim, wavelet_type=wavelet_type)\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        # Reshape input to (batch size, channels, length)\n",
    "        x = x.view(-1, self.channels, self.length)\n",
    "        feature = self.feature_extractor(x)\n",
    "        out = self.projection_head(feature)\n",
    "        return F.normalize(feature, dim=-1), F.normalize(out, dim=-1)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "385ce5c9",
   "metadata": {
    "id": "385ce5c9"
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch.utils.data import DataLoader\n",
    "from tqdm import tqdm\n",
    "import numpy as np\n",
    "from geomloss import SamplesLoss\n",
    "\n",
    "def get_negative_mask(batch_size):\n",
    "    \"\"\"Return a (2B, 2B) boolean mask that is False at each row's own column and its positive pair's column.\n",
    "\n",
    "    Rows/columns 0..B-1 are the first views and B..2B-1 the second views.\n",
    "    \"\"\"\n",
    "    eye = torch.eye(batch_size, dtype=torch.bool)\n",
    "    positives = torch.cat((eye, eye), dim=1)  # (B, 2B): True at columns i and i + B\n",
    "    return torch.cat((~positives, ~positives), dim=0)\n",
    "\n",
    "def train(net, data_loader, train_optimizer, device, epoch, epochs, batch_size):\n",
    "    \"\"\"Run one training epoch and return the mean per-sample loss.\n",
    "\n",
    "    ``batch_size`` is kept for backward compatibility; the running average is weighted by\n",
    "    the actual size of each batch so a smaller final batch is accounted for correctly.\n",
    "    \"\"\"\n",
    "    net.train()\n",
    "    total_loss, total_num = 0.0, 0\n",
    "    train_bar = tqdm(data_loader)\n",
    "    # Sinkhorn divergence between positive scores and negative sums; built once, reused per batch\n",
    "    loss_wasserstein = SamplesLoss(loss=\"sinkhorn\", p=2)\n",
    "\n",
    "    for pos_1, pos_2 in train_bar:\n",
    "        pos_1, pos_2 = pos_1.to(device), pos_2.to(device)\n",
    "        _, out_1 = net(pos_1)\n",
    "        _, out_2 = net(pos_2)\n",
    "\n",
    "        # neg score: exp of all pairwise similarities, excluding self and positive-pair entries\n",
    "        out = torch.cat([out_1, out_2], dim=0)\n",
    "        actual_batch_size = pos_1.size(0)\n",
    "        neg = torch.exp(torch.mm(out, out.t().contiguous()))\n",
    "        mask = get_negative_mask(actual_batch_size).to(device)\n",
    "        neg = neg.masked_select(mask).view(2 * actual_batch_size, -1)\n",
    "\n",
    "        # pos score\n",
    "        pos = torch.exp(torch.sum(out_1 * out_2, dim=-1))\n",
    "        pos = torch.cat([pos, pos], dim=0)\n",
    "        Ng = neg.sum(dim=-1)\n",
    "\n",
    "        # loss_wasserstein distance\n",
    "        loss_distances = loss_wasserstein(pos.unsqueeze(1), Ng.unsqueeze(1))\n",
    "\n",
    "        # contrastive loss\n",
    "        # NOTE(review): the contrastive term -log(pos / (pos + Ng)) is *subtracted*, so minimizing\n",
    "        # `loss` increases it; standard contrastive training adds it. Confirm this is intended.\n",
    "        loss = torch.log(loss_distances) - (-torch.log(pos / (pos + Ng))).mean()\n",
    "\n",
    "        train_optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        train_optimizer.step()\n",
    "\n",
    "        total_num += actual_batch_size\n",
    "        total_loss += loss.item() * actual_batch_size\n",
    "        train_bar.set_description(f'Train Epoch: [{epoch}/{epochs}] Loss: {total_loss / total_num:.4f}')\n",
    "\n",
    "    return total_loss / max(total_num, 1)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "97c3e3de",
   "metadata": {
    "id": "97c3e3de"
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from tqdm import tqdm\n",
    "from sklearn.metrics import multilabel_confusion_matrix\n",
    "import numpy as np\n",
    "\n",
    "def d_index_score_cal(y_true, y_pred):\n",
    "    \"\"\"Mean one-vs-rest d-index over the labels present in y_true and y_pred.\n",
    "\n",
    "    Per label: d = log2(1 + accuracy) + log2(1 + (sensitivity + specificity) / 2).\n",
    "    If every element of y_pred is a list, it is first reduced to 1/0 hit flags.\n",
    "    Returns 0 for empty input.\n",
    "\n",
    "    NOTE(review): a -1 'no match' sentinel (see create_exact_match_list) is counted\n",
    "    as an extra label by multilabel_confusion_matrix and enters the average.\n",
    "    \"\"\"\n",
    "    if len(y_true) == 0 or len(y_pred) == 0:\n",
    "        return 0\n",
    "\n",
    "    # Ensure y_pred is not a list of lists\n",
    "    if all(isinstance(pred, list) for pred in y_pred):\n",
    "        y_pred = [1 if true_label in pred else 0 for true_label, pred in zip(y_true, y_pred)]\n",
    "    else:\n",
    "        # If y_pred is already flat, ensure it's the same length as y_true\n",
    "        assert len(y_true) == len(y_pred), \"y_true and y_pred must have the same length\"\n",
    "\n",
    "    conf_matrix = multilabel_confusion_matrix(y_true, y_pred)\n",
    "    d_indices = []\n",
    "    for i in range(len(conf_matrix)):\n",
    "        tn, fp, fn, tp = conf_matrix[i].ravel()\n",
    "        sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 1e-7\n",
    "        specificity = tn / (tn + fp) if (tn + fp) > 0 else 1e-7\n",
    "        accuracy = (tp + tn) / (tp + tn + fp + fn)\n",
    "        d = np.log2(1 + accuracy) + np.log2(1 + (sensitivity + specificity) / 2)\n",
    "        d_indices.append(d)\n",
    "    avg_d_index = sum(d_indices) / len(d_indices) if len(d_indices) > 0 else 0\n",
    "    return avg_d_index\n",
    "\n",
    "\n",
    "def create_exact_match_list(top2_predictions, targets):\n",
    "    \"\"\"For each sample return its target if it appears in the prediction list, else -1.\"\"\"\n",
    "    exact_matches = []\n",
    "    for pred_pair, target in zip(top2_predictions, targets):\n",
    "        # Find the matching prediction, append -1 if no match is found\n",
    "        match = next((pred for pred in pred_pair if pred == target), -1)\n",
    "        exact_matches.append(match)\n",
    "    return exact_matches\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "429d9569",
   "metadata": {
    "id": "429d9569"
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from tqdm import tqdm\n",
    "\n",
    "def flatten_predictions(predictions, true_labels):\n",
    "    \"\"\"Count hits over top-k prediction lists.\n",
    "\n",
    "    TP = samples whose label is in their list, FN = the remaining samples,\n",
    "    FP = total number of non-matching predictions across all lists.\n",
    "    \"\"\"\n",
    "    TP, FP, FN = 0, 0, 0\n",
    "    for i, preds in enumerate(predictions):\n",
    "        true_label = true_labels[i]\n",
    "\n",
    "        # Count true positives, false positives, false negatives\n",
    "        if true_label in preds:\n",
    "            TP += 1\n",
    "        else:\n",
    "            FN += 1\n",
    "        FP += len([p for p in preds if p != true_label])\n",
    "\n",
    "    return  TP, FP, FN\n",
    "\n",
    "def test_knn(net, memory_data_loader, test_data_loader, k, c):\n",
    "    \"\"\"Weighted kNN evaluation in the encoder's feature space.\n",
    "\n",
    "    Each test sample takes its ``k`` most similar memory-bank samples (dot product of the\n",
    "    features returned first by ``net``) and sums exp(similarity) votes over ``c`` classes.\n",
    "    Returns a dict with top-1/2/3 accuracy (%), d-index, precision, recall and F1.\n",
    "    \"\"\"\n",
    "    net.eval()\n",
    "    # Run on whatever device the model lives on instead of assuming CUDA\n",
    "    device = next(net.parameters()).device\n",
    "    total_top1, total_top2, total_top3, total_num = 0.0, 0.0, 0.0, 0\n",
    "    original_labels_list = []\n",
    "    top1_predictions, top2_predictions, top3_predictions = [], [], []\n",
    "    feature_bank = []\n",
    "\n",
    "    with torch.no_grad():\n",
    "        # Generate feature bank\n",
    "        for data, _ in tqdm(memory_data_loader, desc='Feature extracting'):\n",
    "            feature = net(data.to(device, non_blocking=True))[0]\n",
    "            feature_bank.append(feature)\n",
    "        feature_bank = torch.cat(feature_bank, dim=0).t().contiguous()\n",
    "\n",
    "        feature_labels = torch.tensor([label for _, label in memory_data_loader.dataset], device=feature_bank.device)\n",
    "\n",
    "        test_bar = tqdm(test_data_loader)\n",
    "        for data, target in test_bar:\n",
    "            data, target = data.to(device, non_blocking=True), target.to(device, non_blocking=True)\n",
    "            feature = net(data)[0]\n",
    "\n",
    "            total_num += data.size(0)\n",
    "            sim_matrix = torch.mm(feature, feature_bank)\n",
    "            sim_weight, sim_indices = sim_matrix.topk(k=k, dim=-1)\n",
    "            sim_labels = torch.gather(feature_labels.expand(data.size(0), -1), dim=-1, index=sim_indices)\n",
    "            sim_weight = (sim_weight).exp()\n",
    "\n",
    "            one_hot_label = torch.zeros(data.size(0) * k, c, device=sim_labels.device)\n",
    "            one_hot_label = one_hot_label.scatter(dim=-1, index=sim_labels.view(-1, 1).long(), value=1.0)\n",
    "            pred_scores = torch.sum(one_hot_label.view(data.size(0), -1, c) * sim_weight.unsqueeze(dim=-1), dim=1)\n",
    "            pred_labels = pred_scores.argsort(dim=-1, descending=True)\n",
    "\n",
    "            for i in range(data.size(0)):\n",
    "                original_labels_list.append(target[i].item())\n",
    "                top1_predictions.append([pred_labels[i, 0].item()])\n",
    "                top2_predictions.append(pred_labels[i, :2].tolist())\n",
    "                top3_predictions.append(pred_labels[i, :3].tolist())\n",
    "\n",
    "            total_top1 += torch.sum((pred_labels[:, :1] == target.unsqueeze(dim=-1)).any(dim=-1).float()).item()\n",
    "            total_top2 += torch.sum((pred_labels[:, :2] == target.unsqueeze(dim=-1)).any(dim=-1).float()).item()\n",
    "            total_top3 += torch.sum((pred_labels[:, :3] == target.unsqueeze(dim=-1)).any(dim=-1).float()).item()\n",
    "\n",
    "    TP_top1, FP_top1, FN_top1 = flatten_predictions(top1_predictions, original_labels_list)\n",
    "    TP_top2, FP_top2, FN_top2 = flatten_predictions(top2_predictions, original_labels_list)\n",
    "    TP_top3, FP_top3, FN_top3 = flatten_predictions(top3_predictions, original_labels_list)\n",
    "\n",
    "    # Flatten predictions for d-index calculation\n",
    "    flat_top1_predictions = [pred[0] for pred in top1_predictions]\n",
    "    flat_top2_predictions = create_exact_match_list(top2_predictions, original_labels_list)\n",
    "    flat_top3_predictions = create_exact_match_list(top3_predictions, original_labels_list)\n",
    "\n",
    "    precision_top1, recall_top1, f1_top1 = calculate_metrics(TP_top1, FP_top1, FN_top1)\n",
    "    precision_top2, recall_top2, f1_top2 = calculate_metrics(TP_top2, FP_top2, FN_top2)\n",
    "    precision_top3, recall_top3, f1_top3 = calculate_metrics(TP_top3, FP_top3, FN_top3)\n",
    "\n",
    "    # Calculate d-indices\n",
    "    d_index_top1 = d_index_score_cal(original_labels_list, flat_top1_predictions)\n",
    "    d_index_top2 = d_index_score_cal(original_labels_list, flat_top2_predictions)\n",
    "    d_index_top3 = d_index_score_cal(original_labels_list, flat_top3_predictions)\n",
    "\n",
    "    return {\n",
    "        'top1': {'accuracy': total_top1 / total_num * 100, 'd_index': d_index_top1, 'precision': precision_top1, 'recall': recall_top1, 'f1': f1_top1},\n",
    "        'top2': {'accuracy': total_top2 / total_num * 100, 'd_index': d_index_top2, 'precision': precision_top2, 'recall': recall_top2, 'f1': f1_top2},\n",
    "        'top3': {'accuracy': total_top3 / total_num * 100, 'd_index': d_index_top3, 'precision': precision_top3, 'recall': recall_top3, 'f1': f1_top3}\n",
    "    }\n",
    "\n",
    "def calculate_metrics(TP, FP, FN):\n",
    "    \"\"\"Return (precision, recall, F1), using 0 wherever a denominator is 0.\"\"\"\n",
    "    precision = TP / (TP + FP) if (TP + FP) > 0 else 0\n",
    "    recall = TP / (TP + FN) if (TP + FN) > 0 else 0\n",
    "    f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0\n",
    "    return precision, recall, f1\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bceb7d15",
   "metadata": {
    "id": "bceb7d15",
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import torch\n",
    "from torch import nn\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "\n",
    "def main():\n",
    "    \"\"\"Balance the data with the cGAN, train the contrastive model, evaluate with kNN after\n",
    "    every epoch and write the per-epoch metrics to one CSV per (lr, batch size).\"\"\"\n",
    "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "    print(\"Using device:\", device)\n",
    "\n",
    "    epochs = 100\n",
    "    k = 30\n",
    "\n",
    "    path = 'clean_data_O1.csv'\n",
    "\n",
    "    df_numerical, labels = load_and_preprocess_data(path)\n",
    "    num_classes = len(np.unique(labels))\n",
    "    # Augmented views are drawn once here, so every epoch sees the same noisy pairs\n",
    "    augmented_original_pairs = create_pairs(df_numerical)\n",
    "\n",
    "    learning_rates = [1e-3]\n",
    "    batch_sizes = [128]\n",
    "\n",
    "    for lr in learning_rates:\n",
    "        for batch_size in batch_sizes:\n",
    "            # Metrics for each epoch of this (lr, batch size) run\n",
    "            all_metrics = []\n",
    "            print(f\"Learning Rate: {lr}, Batch Size: {batch_size}\")\n",
    "\n",
    "            # Dataset and DataLoader setup\n",
    "            train_data = {\"pos_1\": augmented_original_pairs[:, 0], \"pos_2\": augmented_original_pairs[:, 1]}\n",
    "            train_dataset = TrainDataset(train_data)\n",
    "            # drop_last: a trailing batch of a single sample makes BatchNorm1d raise in train mode\n",
    "            train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)\n",
    "\n",
    "            memory_data = {\"true_data\": augmented_original_pairs[:, 2], 'target': labels}\n",
    "            memory_dataset = MemoryDataset(memory_data)\n",
    "            memory_data_loader = DataLoader(memory_dataset, batch_size=batch_size, shuffle=False)\n",
    "\n",
    "            # NOTE(review): the test set is the same rows as the memory bank (training data,\n",
    "            # including GAN-generated rows), so these metrics are not a held-out evaluation.\n",
    "            test_data = {\"true_data\": augmented_original_pairs[:, 2], 'target': labels}\n",
    "            test_dataset = TestDataset(test_data)\n",
    "            test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)\n",
    "\n",
    "            model = NumericalModel(input_dim=df_numerical.shape[1], feature_dim=128).to(device)\n",
    "            model = nn.DataParallel(model)\n",
    "            optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-6)\n",
    "\n",
    "            for epoch in range(epochs):\n",
    "                train_loss = train(model, train_loader, optimizer, device, epoch, epochs, batch_size)\n",
    "                test_results = test_knn(model, memory_data_loader, test_loader, k, num_classes)\n",
    "\n",
    "                # Collect metrics (column order: Top1 Accuracy, D-Index, Precision, Recall, F1, then Top2, Top3)\n",
    "                epoch_metrics = {'Epoch': epoch + 1, 'Learning Rate': lr, 'Batch Size': batch_size}\n",
    "                for rank in (1, 2, 3):\n",
    "                    metrics = test_results[f'top{rank}']\n",
    "                    epoch_metrics[f'Top{rank} Accuracy'] = metrics['accuracy']\n",
    "                    epoch_metrics[f'Top{rank} D-Index'] = metrics['d_index']\n",
    "                    epoch_metrics[f'Top{rank} Precision'] = metrics['precision']\n",
    "                    epoch_metrics[f'Top{rank} Recall'] = metrics['recall']\n",
    "                    epoch_metrics[f'Top{rank} F1'] = metrics['f1']\n",
    "                all_metrics.append(epoch_metrics)\n",
    "\n",
    "                for rank in (1, 2, 3):\n",
    "                    metrics = test_results[f'top{rank}']\n",
    "                    print(f\"Top {rank} Metrics:\")\n",
    "                    print(f\"  Accuracy: {metrics['accuracy']:.8f}%\")\n",
    "                    print(f\"  D-Index: {metrics['d_index']:.8f}\")\n",
    "                    print(f\"  Precision: {metrics['precision']:.8f}\")\n",
    "                    print(f\"  Recall: {metrics['recall']:.8f}\")\n",
    "                    # Blank line between the top-k sections, none after the last one\n",
    "                    print(f\"  F1 Score: {metrics['f1']:.8f}\" + (\"\\n\" if rank < 3 else \"\"))\n",
    "\n",
    "            # Convert the list of metrics to a DataFrame and save it as a CSV file\n",
    "            df_metrics = pd.DataFrame(all_metrics)\n",
    "            filename = f'model_metrics_no_drop_alpha_wave_kan_{lr}_{batch_size}.csv'\n",
    "            df_metrics.to_csv(filename, index=False)\n",
    "            # Only the metrics table is written; model weights are not saved\n",
    "            print(f\"Metrics saved to {filename}\")\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    main()\n"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}