{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7d177a43",
   "metadata": {"id": "7d177a43"},
   "outputs": [],
   "source": [
    "# Widen the notebook container so wide outputs are not clipped.\n",
    "# display/HTML are imported from IPython.display; the IPython.core.display path is deprecated.\n",
    "from IPython.display import display, HTML\n",
    "\n",
    "display(HTML(\"<style>.container { width:100% !important; }</style>\"))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "05404b49",
   "metadata": {"id": "05404b49", "outputId": "a2e7a50a-06e4-4a2d-8f97-2afbf62939e1"},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Pandas version: 2.0.3\n",
      "NumPy version: 1.24.3\n",
      "PyTorch version: 2.1.1\n",
      "scikit-learn version: 1.3.0\n",
      "tqdm version: 4.66.1\n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import torch\n",
    "import sklearn\n",
    "# Aliased: later cells bind the bare name `tqdm` to the progress-bar class, so\n",
    "# re-running this cell must not rebind it to the module.\n",
    "import tqdm as tqdm_module\n",
    "\n",
    "print(\"Pandas version:\", pd.__version__)\n",
    "print(\"NumPy version:\", np.__version__)\n",
    "print(\"PyTorch version:\", torch.__version__)\n",
    "print(\"scikit-learn version:\", sklearn.__version__)\n",
    "print(\"tqdm version:\", tqdm_module.__version__)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "04ed4c15",
   "metadata": {"id": "04ed4c15"},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import math\n",
    "\n",
    "class KANLinear(nn.Module):\n",
    "    \"\"\"Linear layer with an additional learnable wavelet branch.\n",
    "\n",
    "    forward(x) returns BatchNorm1d(wavelet_transform(x) + F.linear(x, weight1)).\n",
    "\n",
    "    Args:\n",
    "        in_features: size of the last input dimension.\n",
    "        out_features: size of the last output dimension.\n",
    "        wavelet_type: 'haar', 'coiflet', 'biorthogonal', 'daubechies',\n",
    "            'mexican_hat', 'morlet', 'dog', 'meyer', 'symlet4' or 'shannon'.\n",
    "            An unknown name raises ValueError when the layer is first applied.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, in_features, out_features, wavelet_type='haar'):\n",
    "        super(KANLinear, self).__init__()\n",
    "        self.in_features = in_features\n",
    "        self.out_features = out_features\n",
    "        self.wavelet_type = wavelet_type\n",
    "\n",
    "        # Per-(output, input) scale and shift applied before the wavelet\n",
    "        self.scale = nn.Parameter(torch.ones(out_features, in_features))\n",
    "        self.translation = nn.Parameter(torch.zeros(out_features, in_features))\n",
    "        # torch.empty: the values are overwritten by the Kaiming init below\n",
    "        self.wavelet_weights = nn.Parameter(torch.empty(out_features, in_features))\n",
    "        self.weight1 = nn.Parameter(torch.empty(out_features, in_features))\n",
    "\n",
    "        nn.init.kaiming_uniform_(self.wavelet_weights, a=math.sqrt(5))\n",
    "        nn.init.kaiming_uniform_(self.weight1, a=math.sqrt(5))\n",
    "\n",
    "        # Batch normalization\n",
    "        self.bn = nn.BatchNorm1d(out_features)\n",
    "\n",
    "    def wavelet_transform(self, x):\n",
    "        \"\"\"Evaluate the wavelet on the shifted/scaled input and sum over input features.\n",
    "\n",
    "        Args:\n",
    "            x: 2-D tensor (batch, in_features); tensors of other rank are used as is.\n",
    "\n",
    "        Returns:\n",
    "            For 2-D input, a tensor of shape (batch, out_features).\n",
    "\n",
    "        Raises:\n",
    "            ValueError: if self.wavelet_type is not supported.\n",
    "        \"\"\"\n",
    "        x_expanded = x.unsqueeze(1) if x.dim() == 2 else x\n",
    "        # (1, out, in) parameters broadcast against (batch, 1, in) inputs\n",
    "        x_scaled = (x_expanded - self.translation.unsqueeze(0)) / self.scale.unsqueeze(0)\n",
    "\n",
    "        if self.wavelet_type == 'haar':\n",
    "            wavelet = self.haar_wavelet(x_scaled)\n",
    "        elif self.wavelet_type == 'coiflet':\n",
    "            wavelet = self.coiflet_wavelet(x_scaled)\n",
    "        elif self.wavelet_type == 'biorthogonal':\n",
    "            wavelet = self.biorthogonal_wavelet(x_scaled)\n",
    "        elif self.wavelet_type == 'daubechies':\n",
    "            wavelet = self.daubechies_wavelet(x_scaled)\n",
    "        elif self.wavelet_type == 'mexican_hat':\n",
    "            term1 = ((x_scaled ** 2)-1)\n",
    "            term2 = torch.exp(-0.5 * x_scaled ** 2)\n",
    "            wavelet = (2 / (math.sqrt(3) * math.pi**0.25)) * term1 * term2\n",
    "        elif self.wavelet_type == 'morlet':\n",
    "            omega0 = 5.0  # Central frequency\n",
    "            real = torch.cos(omega0 * x_scaled)\n",
    "            envelope = torch.exp(-0.5 * x_scaled ** 2)\n",
    "            wavelet = envelope * real\n",
    "        elif self.wavelet_type == 'dog':\n",
    "            wavelet = -x_scaled * torch.exp(-0.5 * x_scaled ** 2)\n",
    "        elif self.wavelet_type == 'meyer':\n",
    "            v = torch.abs(x_scaled)\n",
    "            wavelet = torch.sin(math.pi * v) * self.meyer_aux(v)\n",
    "        elif self.wavelet_type == 'symlet4':\n",
    "            wavelet = self.symlet4_wavelet(x_scaled)\n",
    "        elif self.wavelet_type == 'shannon':\n",
    "            sinc = torch.sinc(x_scaled / math.pi)\n",
    "            # Hamming window taken along the last (input-feature) dimension\n",
    "            window = torch.hamming_window(x_scaled.size(-1), periodic=False, dtype=x_scaled.dtype, device=x_scaled.device)\n",
    "            wavelet = sinc * window\n",
    "        else:\n",
    "            raise ValueError(\"Unsupported wavelet type\")\n",
    "\n",
    "        wavelet_weighted = wavelet * self.wavelet_weights.unsqueeze(0)\n",
    "        return wavelet_weighted.sum(dim=2)\n",
    "\n",
    "    def haar_wavelet(self, x):\n",
    "        \"\"\"Haar mother wavelet: +1 on [0, 0.5), -1 on [0.5, 1), 0 elsewhere.\n",
    "\n",
    "        The previous version had no lower bound and returned +1 for every x < 0.\n",
    "        \"\"\"\n",
    "        first_half = ((x >= 0) & (x < 0.5)).to(x.dtype)\n",
    "        second_half = ((x >= 0.5) & (x < 1)).to(x.dtype)\n",
    "        return first_half - second_half\n",
    "\n",
    "    def coiflet_wavelet(self, x):\n",
    "        \"\"\"Polynomial in x whose i-th coefficient is the i-th value of the list below.\"\"\"\n",
    "        c = [0.038580, -0.126969, -0.077974, 0.417845, 0.812866, 0.468429]\n",
    "        return sum(coeff * x**power for power, coeff in enumerate(c))\n",
    "\n",
    "    def biorthogonal_wavelet(self, x):\n",
    "        \"\"\"Cubic 3x^2 - 2x^3.\"\"\"\n",
    "        return 3*x**2 - 2*x**3\n",
    "\n",
    "    def symlet4_wavelet(self, x):\n",
    "        \"\"\"Polynomial in x built from the Symlet 4 (sym4) coefficients.\"\"\"\n",
    "        coeffs = [\n",
    "            -0.075765714789273,\n",
    "            0.029635527645999,\n",
    "            0.497618667632015,\n",
    "            0.803738751805916,\n",
    "            0.297857795605542,\n",
    "            -0.099219543576847,\n",
    "            -0.012603967262038,\n",
    "            0.032223100604042\n",
    "        ]\n",
    "\n",
    "        # This is a polynomial evaluation (coefficient i multiplies x**i), not a convolution\n",
    "        return sum(coeff * x**power for power, coeff in enumerate(coeffs))\n",
    "\n",
    "    def daubechies_wavelet(self, x):\n",
    "        \"\"\"Cubic in x with alternating-sign coefficients derived from sqrt(3).\"\"\"\n",
    "        sqrt3 = math.sqrt(3)\n",
    "        norm = 4 * math.sqrt(2)\n",
    "        coeff1 = (1 + sqrt3) / norm\n",
    "        coeff2 = (3 + sqrt3) / norm\n",
    "        coeff3 = (3 - sqrt3) / norm\n",
    "        coeff4 = (1 - sqrt3) / norm\n",
    "        return coeff1 - coeff2 * x + coeff3 * x**2 - coeff4 * x**3\n",
    "\n",
    "    def meyer_aux(self, v):\n",
    "        \"\"\"Window used by the 'meyer' branch: 1 for v <= 1/2, 0 for v >= 1, cosine taper between.\"\"\"\n",
    "        taper = torch.cos(math.pi / 2 * self.nu(2 * v - 1))\n",
    "        return torch.where(v <= 1/2, torch.ones_like(v), torch.where(v >= 1, torch.zeros_like(v), taper))\n",
    "\n",
    "    def nu(self, t):\n",
    "        \"\"\"Polynomial t^4 * (35 - 84t + 70t^2 - 20t^3) used inside meyer_aux.\"\"\"\n",
    "        return t**4 * (35 - 84*t + 70*t**2 - 20*t**3)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return the batch-normalised sum of the wavelet branch and the linear branch.\"\"\"\n",
    "        wavelet_output = self.wavelet_transform(x)\n",
    "        base_output = F.linear(x, self.weight1)\n",
    "        return self.bn(wavelet_output + base_output)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cb892b00",
   "metadata": {"id": "cb892b00"},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import DataLoader, TensorDataset\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "# LabelEncoder is used by load_and_preprocess_data below; import it in this cell\n",
    "# instead of relying on a later cell having been run.\n",
    "from sklearn.preprocessing import StandardScaler, RobustScaler, MaxAbsScaler, PowerTransformer, MinMaxScaler, LabelEncoder\n",
    "\n",
    "class Generator(nn.Module):\n",
    "    \"\"\"Conditional generator: maps (noise z, class label) to a feature vector in [-1, 1] (Tanh).\"\"\"\n",
    "\n",
    "    def __init__(self, latent_dim, num_features, num_classes):\n",
    "        super(Generator, self).__init__()\n",
    "        self.label_emb = nn.Embedding(num_classes, num_classes)\n",
    "        self.model = nn.Sequential(\n",
    "            nn.Linear(latent_dim + num_classes, 128),\n",
    "            nn.ReLU(True),\n",
    "            nn.Linear(128, num_features),\n",
    "            nn.Tanh()\n",
    "        )\n",
    "\n",
    "    def forward(self, z, labels):\n",
    "        \"\"\"Concatenate z with the label embedding and run the MLP.\"\"\"\n",
    "        c = self.label_emb(labels)\n",
    "        return self.model(torch.cat([z, c], 1))\n",
    "\n",
    "class Discriminator(nn.Module):\n",
    "    \"\"\"Conditional discriminator: maps (features, class label) to a probability (Sigmoid).\"\"\"\n",
    "\n",
    "    def __init__(self, num_features, num_classes):\n",
    "        super(Discriminator, self).__init__()\n",
    "        self.label_embedding = nn.Embedding(num_classes, num_classes)\n",
    "        self.model = nn.Sequential(\n",
    "            nn.Linear(num_features + num_classes, 128),\n",
    "            nn.ReLU(True),\n",
    "            nn.Linear(128, 1),\n",
    "            nn.Sigmoid()\n",
    "        )\n",
    "\n",
    "    def forward(self, x, labels):\n",
    "        \"\"\"Concatenate x with the label embedding and run the MLP.\"\"\"\n",
    "        c = self.label_embedding(labels)\n",
    "        return self.model(torch.cat([x, c], 1))\n",
    "\n",
    "def train_cgan(generator, discriminator, data_loader, device, num_epochs, latent_dim, num_classes):\n",
    "    \"\"\"Train the conditional GAN in place with Adam and binary cross-entropy.\n",
    "\n",
    "    Args:\n",
    "        generator, discriminator: modules already moved to `device`.\n",
    "        data_loader: yields (features, integer labels) batches.\n",
    "        device: device for the batches and the sampled noise.\n",
    "        num_epochs: number of passes over data_loader.\n",
    "        latent_dim: size of the noise vector.\n",
    "        num_classes: labels for fake samples are drawn from [0, num_classes).\n",
    "    \"\"\"\n",
    "    optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))\n",
    "    optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))\n",
    "    criterion = nn.BCELoss()\n",
    "\n",
    "    generator.train()\n",
    "    discriminator.train()\n",
    "\n",
    "    for epoch in range(num_epochs):\n",
    "        d_loss = g_loss = None\n",
    "        for data, labels in data_loader:\n",
    "            batch_size = data.size(0)\n",
    "\n",
    "            real_data = data.to(device)\n",
    "            real_labels = labels.to(device)\n",
    "            valid = torch.ones(batch_size, 1, device=device)\n",
    "            fake = torch.zeros(batch_size, 1, device=device)\n",
    "\n",
    "            # Train Discriminator\n",
    "            optimizer_D.zero_grad()\n",
    "            real_loss = criterion(discriminator(real_data, real_labels), valid)\n",
    "            z = torch.randn(batch_size, latent_dim, device=device)\n",
    "            gen_labels = torch.randint(0, num_classes, (batch_size,), device=device)\n",
    "            # detach(): the discriminator step must not back-propagate into the generator\n",
    "            fake_data = generator(z, gen_labels).detach()\n",
    "            fake_loss = criterion(discriminator(fake_data, gen_labels), fake)\n",
    "            d_loss = (real_loss + fake_loss) / 2\n",
    "            d_loss.backward()\n",
    "            optimizer_D.step()\n",
    "\n",
    "            # Train Generator\n",
    "            optimizer_G.zero_grad()\n",
    "            z = torch.randn(batch_size, latent_dim, device=device)\n",
    "            g_loss = criterion(discriminator(generator(z, gen_labels), gen_labels), valid)\n",
    "            g_loss.backward()\n",
    "            optimizer_G.step()\n",
    "\n",
    "        # Nothing to report if the loader produced no batch\n",
    "        if d_loss is not None:\n",
    "            print(f'Epoch [{epoch+1}/{num_epochs}] Discriminator Loss: {d_loss.item()}, Generator Loss: {g_loss.item()}')\n",
    "\n",
    "def generate_synthetic_data(generator, num_samples, label, num_features, latent_dim, device):\n",
    "    \"\"\"Sample `num_samples` rows of class `label` from the generator as a NumPy array.\n",
    "\n",
    "    `num_features` is accepted for interface compatibility and is not used.\n",
    "    \"\"\"\n",
    "    # Callers pass NumPy integers; convert so the torch size arguments are plain ints\n",
    "    num_samples = int(num_samples)\n",
    "    z = torch.randn(num_samples, latent_dim, device=device)\n",
    "    labels = torch.full((num_samples,), int(label), dtype=torch.long, device=device)\n",
    "    with torch.no_grad():\n",
    "        synthetic_data = generator(z, labels).cpu().numpy()\n",
    "    return synthetic_data\n",
    "\n",
    "def load_and_preprocess_data(path,target_class_count=1500, latent_dim=100, num_epochs=100, batch_size=32):\n",
    "    \"\"\"Load the CSV, scale the features and top up small classes with cGAN samples.\n",
    "\n",
    "    Args:\n",
    "        path: CSV file containing the columns listed below plus 'ml_label'.\n",
    "        target_class_count: classes with fewer rows are filled up to this count.\n",
    "        latent_dim: noise dimension of the generator.\n",
    "        num_epochs: cGAN training epochs.\n",
    "        batch_size: cGAN training batch size.\n",
    "\n",
    "    Returns:\n",
    "        (features, labels): 2-D array of scaled features and 1-D array of encoded\n",
    "        labels, with the synthetic rows appended after the real ones.\n",
    "    \"\"\"\n",
    "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "    df = pd.read_csv(path)\n",
    "    numerical_columns = [\n",
    "        'peak_time', 'peak_time_ns', 'start_time', 'start_time_ns', 'duration',\n",
    "        'peak_frequency', 'central_freq', 'bandwidth', 'amplitude', 'snr', 'q_value',\n",
    "        '1400Ripples', '1080Lines', 'Air_Compressor', 'Blip', 'Chirp', 'Extremely_Loud',\n",
    "        'Helix', 'Koi_Fish', 'Light_Modulation', 'Low_Frequency_Burst', 'Low_Frequency_Lines',\n",
    "        'No_Glitch', 'None_of_the_Above', 'Paired_Doves', 'Power_Line', 'Repeating_Blips',\n",
    "        'Scattered_Light', 'Scratchy', 'Tomte', 'Violin_Mode', 'Wandering_Line', 'Whistle'\n",
    "    ]\n",
    "\n",
    "    # NOTE(review): MinMaxScaler maps to [0, 1] by default while the Generator ends\n",
    "    # in Tanh ([-1, 1]) -- confirm that the two ranges are meant to differ.\n",
    "    scaler = MinMaxScaler()\n",
    "    df_numerical = scaler.fit_transform(df[numerical_columns])\n",
    "\n",
    "    labels = LabelEncoder().fit_transform(df['ml_label'])\n",
    "\n",
    "    num_features = df_numerical.shape[1]\n",
    "    num_classes = np.unique(labels).size\n",
    "\n",
    "    tensor_data = torch.tensor(df_numerical.astype(np.float32))\n",
    "    tensor_labels = torch.tensor(labels)\n",
    "    dataset = TensorDataset(tensor_data, tensor_labels)\n",
    "    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)\n",
    "\n",
    "    generator = Generator(latent_dim, num_features, num_classes).to(device)\n",
    "    discriminator = Discriminator(num_features, num_classes).to(device)\n",
    "\n",
    "    train_cgan(generator, discriminator, data_loader, device, num_epochs, latent_dim, num_classes)\n",
    "\n",
    "    # Collect the synthetic blocks and stack once instead of re-stacking per class\n",
    "    feature_blocks = [df_numerical]\n",
    "    label_blocks = [labels]\n",
    "    class_counts = np.bincount(labels, minlength=num_classes)\n",
    "    for class_label in range(num_classes):\n",
    "        num_samples_needed = int(target_class_count - class_counts[class_label])\n",
    "        if num_samples_needed > 0:\n",
    "            feature_blocks.append(generate_synthetic_data(generator, num_samples_needed, class_label, num_features, latent_dim, device))\n",
    "            label_blocks.append(np.full(num_samples_needed, class_label))\n",
    "\n",
    "    return np.vstack(feature_blocks), np.concatenate(label_blocks)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "92eff356",
   "metadata": {"id": "92eff356"},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch.utils.data import DataLoader, Dataset\n",
    "from sklearn.preprocessing import LabelEncoder, PowerTransformer, MinMaxScaler\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "from tqdm import tqdm\n",
    "\n",
    "# Create pairs of augmented data\n",
    "def create_pairs(data):\n",
    "    \"\"\"Return an array holding, per input row, (augmented view 1, augmented view 2, original row).\"\"\"\n",
    "    # Tuple elements are evaluated left to right, so view 1 is drawn before view 2\n",
    "    triples = [(augment_data(row), augment_data(row), row) for row in data]\n",
    "    return np.array(triples)\n",
    "\n",
    "# Define an augmentation function\n",
    "def augment_data(data, noise_level=0.1):\n",
    "    \"\"\"Add zero-mean Gaussian noise with standard deviation `noise_level` to `data`.\"\"\"\n",
    "    return data + np.random.normal(0, noise_level, data.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c8ce8478",
   "metadata": {"id": "c8ce8478"},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch.utils.data import DataLoader, Dataset\n",
    "from tqdm import tqdm\n",
    "\n",
    "# Dataset class\n",
    "class TrainDataset(Dataset):\n",
    "    \"\"\"Pairs of augmented views read from data['pos_1'] and data['pos_2'].\"\"\"\n",
    "\n",
    "    def __init__(self, data):\n",
    "        self.data = data\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.data['pos_1'])\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        pos_1 = self.data['pos_1'][idx]\n",
    "        pos_2 = self.data['pos_2'][idx]\n",
    "        return torch.tensor(pos_1, dtype=torch.float), torch.tensor(pos_2, dtype=torch.float)\n",
    "\n",
    "class MemoryDataset(Dataset):\n",
    "    \"\"\"(features, label) items read from data['true_data'] and data['target'].\"\"\"\n",
    "\n",
    "    def __init__(self, data):\n",
    "        self.data = data\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.data['true_data'])\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        true_data = self.data['true_data'][idx]\n",
    "        target = self.data['target'][idx]\n",
    "        return torch.tensor(true_data, dtype=torch.float), torch.tensor(target, dtype=torch.long)\n",
    "\n",
    "class TestDataset(MemoryDataset):\n",
    "    \"\"\"Same item layout as MemoryDataset; kept as a separate name for the evaluation split.\"\"\"\n",
    "\n",
    "\n",
    "class NumericalModel(nn.Module):\n",
    "    \"\"\"1-D CNN encoder followed by a KANLinear projection head.\n",
    "\n",
    "    forward(x) returns (L2-normalised 512-d feature, L2-normalised projection of\n",
    "    size feature_dim).\n",
    "\n",
    "    Args:\n",
    "        input_dim: number of values per sample.\n",
    "        feature_dim: size of the projection output.\n",
    "        wavelet_type: wavelet passed to both KANLinear layers.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, input_dim, feature_dim=128, wavelet_type='mexican_hat'):\n",
    "        super(NumericalModel, self).__init__()\n",
    "\n",
    "        self.channels = 1  # Number of channels in the input\n",
    "        self.length = input_dim  # Length of the sequence\n",
    "\n",
    "        # Adjusted feature extractor for 1D data with Conv1D and MaxPool1D\n",
    "        self.feature_extractor = nn.Sequential(\n",
    "            nn.Conv1d(self.channels, 256, kernel_size=5, stride=1, padding=2),\n",
    "            nn.ReLU(),\n",
    "            nn.MaxPool1d(kernel_size=2, stride=2),\n",
    "            nn.Conv1d(256, 256, kernel_size=5, stride=1, padding=2),\n",
    "            nn.ReLU(),\n",
    "            nn.MaxPool1d(kernel_size=2, stride=2),\n",
    "            nn.Flatten(),\n",
    "            # two stride-2 poolings shrink the length to length // 4\n",
    "            nn.Linear(256 * (self.length // 4), 512),\n",
    "            nn.ReLU()\n",
    "        )\n",
    "\n",
    "        # Updated projection head with wavelet-based KANLinear\n",
    "        self.projection_head = nn.Sequential(\n",
    "            KANLinear(512, 256, wavelet_type=wavelet_type),\n",
    "            nn.BatchNorm1d(256),\n",
    "            nn.ReLU(inplace=True),\n",
    "            KANLinear(256, feature_dim, wavelet_type=wavelet_type)\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        # Reshape input to (batch size, channels, length)\n",
    "        x = x.view(-1, self.channels, self.length)\n",
    "        feature = self.feature_extractor(x)\n",
    "        out = self.projection_head(feature)\n",
    "        return F.normalize(feature, dim=-1), F.normalize(out, dim=-1)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "385ce5c9",
   "metadata": {"id": "385ce5c9"},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch.utils.data import DataLoader\n",
    "from tqdm import tqdm\n",
    "import numpy as np\n",
    "from geomloss import SamplesLoss\n",
    "np.random.seed(0)\n",
    "torch.manual_seed(0)  # also seed torch (weight init, shuffling, GAN noise)\n",
    "\n",
    "# Function to create a mask for negative samples\n",
    "def get_negative_mask(batch_size):\n",
    "    \"\"\"Boolean mask of shape (2B, 2B) that is False at (i, i) and (i, i + B) in each half.\"\"\"\n",
    "    negative_mask = torch.ones((batch_size, 2 * batch_size), dtype=torch.bool)\n",
    "    rows = torch.arange(batch_size)\n",
    "    negative_mask[rows, rows] = False\n",
    "    negative_mask[rows, rows + batch_size] = False\n",
    "    return torch.cat((negative_mask, negative_mask), 0)\n",
    "\n",
    "# Training function\n",
    "def train_drop_sub(net, data_loader, train_optimizer, device, epoch, epochs,batch_size, args):\n",
    "    \"\"\"Run one training epoch and return the sample-weighted mean loss.\n",
    "\n",
    "    Args:\n",
    "        net: model returning (feature, projection) for a batch.\n",
    "        data_loader: yields (pos_1, pos_2) pairs of augmented views.\n",
    "        train_optimizer: optimizer over net's parameters.\n",
    "        device: device the batches are moved to.\n",
    "        epoch, epochs: only used for the progress-bar text.\n",
    "        batch_size: nominal batch size (kept for interface compatibility; the\n",
    "            running average uses the actual size of every batch).\n",
    "        args: object with a `drop_alpha` attribute, the fraction of B*B used as\n",
    "            the number of most-similar negatives to zero out.\n",
    "    \"\"\"\n",
    "    net.train()\n",
    "    total_loss, total_num = 0.0, 0\n",
    "    train_bar = tqdm(data_loader)\n",
    "    # Built once per epoch instead of once per batch\n",
    "    loss_wasserstein = SamplesLoss(loss=\"sinkhorn\", p=2)\n",
    "\n",
    "    for pos_1, pos_2 in train_bar:\n",
    "        pos_1, pos_2 = pos_1.to(device), pos_2.to(device)\n",
    "        actual_batch_size = pos_1.size(0)\n",
    "        _, out_1 = net(pos_1)\n",
    "        _, out_2 = net(pos_2)\n",
    "\n",
    "        # alpha * B*B entries (the size of a view-1 x view-2 similarity matrix) are dropped\n",
    "        num_to_remove = int(args.drop_alpha * actual_batch_size * actual_batch_size)\n",
    "\n",
    "        # neg score\n",
    "        out = torch.cat([out_1, out_2], dim=0)\n",
    "        neg = torch.exp(torch.mm(out, out.t().contiguous()))\n",
    "        mask = get_negative_mask(actual_batch_size).to(device)\n",
    "        neg = neg.masked_select(mask).view(2 * actual_batch_size, -1)\n",
    "\n",
    "        # Remove top alpha similar entries from neg; a small last batch can give\n",
    "        # num_to_remove == 0, in which case topk would return nothing to index.\n",
    "        if num_to_remove > 0:\n",
    "            similar_values, _ = torch.topk(neg.view(-1), num_to_remove, largest=True)\n",
    "            threshold_similar = similar_values[-1]\n",
    "            neg = neg.masked_fill(neg >= threshold_similar, 0)\n",
    "\n",
    "        # pos score\n",
    "        pos = torch.exp(torch.sum(out_1 * out_2, dim=-1))\n",
    "        pos = torch.cat([pos, pos], dim=0)\n",
    "        Ng = neg.sum(dim=-1)\n",
    "\n",
    "        # loss_wasserstein distance\n",
    "        loss_distances = loss_wasserstein(pos.unsqueeze(1), Ng.unsqueeze(1))\n",
    "\n",
    "        # contrastive loss\n",
    "        # NOTE(review): the second term subtracts the InfoNCE-style loss, i.e. the\n",
    "        # optimizer is rewarded for increasing it -- confirm the sign is intended.\n",
    "        loss = 0.7*torch.log(loss_distances) - 0.3*(-torch.log(pos / (pos + Ng))).mean()\n",
    "\n",
    "        train_optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        train_optimizer.step()\n",
    "\n",
    "        total_num += actual_batch_size\n",
    "        # weight by the real batch size so a short final batch does not skew the mean\n",
    "        total_loss += loss.item() * actual_batch_size\n",
    "        train_bar.set_description(f'Train Epoch: [{epoch}/{epochs}] Loss: {total_loss / total_num:.4f}')\n",
    "\n",
    "    return total_loss / total_num if total_num else 0.0\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "97c3e3de",
   "metadata": {"id": "97c3e3de"},
   "outputs": [],
   "source": [
    "import torch\n",
    "from tqdm import tqdm\n",
    "from sklearn.metrics import multilabel_confusion_matrix\n",
    "import numpy as np\n",
    "\n",
    "def d_index_score_cal(y_true, y_pred):\n",
    "    \"\"\"Average D-index over the per-class matrices of multilabel_confusion_matrix.\n",
    "\n",
    "    Per class: d = log2(1 + accuracy) + log2(1 + (sensitivity + specificity) / 2).\n",
    "    If every element of y_pred is a list, y_pred is first reduced to 1/0 flags\n",
    "    (1 when the true label is inside the list). Returns 0 for empty input.\n",
    "    \"\"\"\n",
    "    if len(y_true) == 0 or len(y_pred) == 0:\n",
    "        return 0\n",
    "\n",
    "    # Ensure y_pred is not a list of lists\n",
    "    if all(isinstance(pred, list) for pred in y_pred):\n",
    "        y_pred = [1 if true_label in pred else 0 for true_label, pred in zip(y_true, y_pred)]\n",
    "    else:\n",
    "        # If y_pred is already flat, ensure it's the same length as y_true\n",
    "        assert len(y_true) == len(y_pred), \"y_true and y_pred must have the same length\"\n",
    "\n",
    "    d_indices = []\n",
    "    for class_matrix in multilabel_confusion_matrix(y_true, y_pred):\n",
    "        tn, fp, fn, tp = class_matrix.ravel()\n",
    "        sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 1e-7\n",
    "        specificity = tn / (tn + fp) if (tn + fp) > 0 else 1e-7\n",
    "        accuracy = (tp + tn) / (tp + tn + fp + fn)\n",
    "        d_indices.append(np.log2(1 + accuracy) + np.log2(1 + (sensitivity + specificity) / 2))\n",
    "    return sum(d_indices) / len(d_indices) if d_indices else 0\n",
    "\n",
    "\n",
    "def create_exact_match_list(top2_predictions, targets):\n",
    "    \"\"\"Per sample, return the target if it is among the predictions, otherwise -1.\"\"\"\n",
    "    return [target if target in preds else -1 for preds, target in zip(top2_predictions, targets)]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "429d9569",
   "metadata": {"id": "429d9569"},
   "outputs": [],
   "source": [
    "import torch\n",
    "from tqdm import tqdm\n",
    "\n",
    "def flatten_predictions(predictions, true_labels):\n",
    "    \"\"\"Count (TP, FP, FN) over top-k prediction lists.\n",
    "\n",
    "    A sample is a TP if its true label is in its list, otherwise an FN; every\n",
    "    listed label that differs from the true label counts as an FP.\n",
    "    \"\"\"\n",
    "    TP, FP, FN = 0, 0, 0\n",
    "    for preds, true_label in zip(predictions, true_labels):\n",
    "        if true_label in preds:\n",
    "            TP += 1\n",
    "        else:\n",
    "            FN += 1\n",
    "        FP += sum(1 for p in preds if p != true_label)\n",
    "\n",
    "    return  TP, FP, FN\n",
    "\n",
    "def test_knn(net, memory_data_loader, test_data_loader, k, c):\n",
    "    \"\"\"Weighted k-NN evaluation of the encoder features.\n",
    "\n",
    "    Args:\n",
    "        net: model whose first output is the feature used for the similarity.\n",
    "        memory_data_loader: (data, label) batches that form the feature bank.\n",
    "        test_data_loader: (data, label) batches to classify.\n",
    "        k: number of neighbours (clamped to the size of the feature bank).\n",
    "        c: number of classes.\n",
    "\n",
    "    Returns:\n",
    "        dict with keys 'top1', 'top2', 'top3'; each maps to a dict with\n",
    "        'accuracy' (percent), 'd_index', 'precision', 'recall' and 'f1'.\n",
    "    \"\"\"\n",
    "    net.eval()\n",
    "    # Run on whatever device the model lives on instead of assuming CUDA\n",
    "    device = next(net.parameters()).device\n",
    "    total_top1, total_top2, total_top3, total_num = 0.0, 0.0, 0.0, 0\n",
    "    original_labels_list = []\n",
    "    top1_predictions, top2_predictions, top3_predictions = [], [], []\n",
    "    feature_bank, feature_labels = [], []\n",
    "\n",
    "    with torch.no_grad():\n",
    "        # Generate feature bank; labels are taken from the same batches so they\n",
    "        # stay aligned with the features whatever the loader order is\n",
    "        for data, label in tqdm(memory_data_loader, desc='Feature extracting'):\n",
    "            feature_bank.append(net(data.to(device, non_blocking=True))[0])\n",
    "            feature_labels.append(label)\n",
    "        feature_bank = torch.cat(feature_bank, dim=0).t().contiguous()\n",
    "        feature_labels = torch.cat(feature_labels, dim=0).to(feature_bank.device)\n",
    "        # topk cannot return more neighbours than the bank holds\n",
    "        k = min(k, feature_labels.size(0))\n",
    "\n",
    "        test_bar = tqdm(test_data_loader)\n",
    "        for data, target in test_bar:\n",
    "            data, target = data.to(device, non_blocking=True), target.to(device, non_blocking=True)\n",
    "            feature = net(data)[0]\n",
    "\n",
    "            total_num += data.size(0)\n",
    "            sim_matrix = torch.mm(feature, feature_bank)\n",
    "            sim_weight, sim_indices = sim_matrix.topk(k=k, dim=-1)\n",
    "            sim_labels = torch.gather(feature_labels.expand(data.size(0), -1), dim=-1, index=sim_indices)\n",
    "            sim_weight = (sim_weight).exp()\n",
    "\n",
    "            one_hot_label = torch.zeros(data.size(0) * k, c, device=sim_labels.device)\n",
    "            one_hot_label = one_hot_label.scatter(dim=-1, index=sim_labels.view(-1, 1).long(), value=1.0)\n",
    "            pred_scores = torch.sum(one_hot_label.view(data.size(0), -1, c) * sim_weight.unsqueeze(dim=-1), dim=1)\n",
    "            pred_labels = pred_scores.argsort(dim=-1, descending=True)\n",
    "\n",
    "            # One host transfer per batch instead of one .item() per sample\n",
    "            original_labels_list.extend(target.tolist())\n",
    "            for ranked in pred_labels[:, :3].tolist():\n",
    "                top1_predictions.append(ranked[:1])\n",
    "                top2_predictions.append(ranked[:2])\n",
    "                top3_predictions.append(ranked[:3])\n",
    "\n",
    "            hits = pred_labels == target.unsqueeze(dim=-1)\n",
    "            total_top1 += torch.sum(hits[:, :1].any(dim=-1).float()).item()\n",
    "            total_top2 += torch.sum(hits[:, :2].any(dim=-1).float()).item()\n",
    "            total_top3 += torch.sum(hits[:, :3].any(dim=-1).float()).item()\n",
    "\n",
    "    TP_top1, FP_top1, FN_top1 = flatten_predictions(top1_predictions, original_labels_list)\n",
    "    TP_top2, FP_top2, FN_top2 = flatten_predictions(top2_predictions, original_labels_list)\n",
    "    TP_top3, FP_top3, FN_top3 = flatten_predictions(top3_predictions, original_labels_list)\n",
    "\n",
    "    # Flatten predictions for d-index calculation\n",
    "    flat_top1_predictions = [pred[0] for pred in top1_predictions]\n",
    "    flat_top2_predictions = create_exact_match_list(top2_predictions, original_labels_list)\n",
    "    flat_top3_predictions = create_exact_match_list(top3_predictions, original_labels_list)\n",
    "\n",
    "    precision_top1, recall_top1, f1_top1 = calculate_metrics(TP_top1, FP_top1, FN_top1)\n",
    "    precision_top2, recall_top2, f1_top2 = calculate_metrics(TP_top2, FP_top2, FN_top2)\n",
    "    precision_top3, recall_top3, f1_top3 = calculate_metrics(TP_top3, FP_top3, FN_top3)\n",
    "\n",
    "    # Calculate d-indices\n",
    "    d_index_top1 = d_index_score_cal(original_labels_list, flat_top1_predictions)\n",
    "    d_index_top2 = d_index_score_cal(original_labels_list, flat_top2_predictions)\n",
    "    d_index_top3 = d_index_score_cal(original_labels_list, flat_top3_predictions)\n",
    "\n",
    "\n",
    "    return {\n",
    "        'top1': {'accuracy': total_top1 / total_num * 100, 'd_index': d_index_top1, 'precision': precision_top1, 'recall': recall_top1, 'f1': f1_top1},\n",
    "        'top2': {'accuracy': total_top2 / total_num * 100, 'd_index': d_index_top2, 'precision': precision_top2, 'recall': recall_top2, 'f1': f1_top2},\n",
    "        'top3': {'accuracy': total_top3 / total_num * 100, 'd_index': d_index_top3, 'precision': precision_top3, 'recall': recall_top3, 'f1': f1_top3}\n",
    "    }\n",
    "\n",
    "def calculate_metrics(TP, FP, FN):\n",
    "    \"\"\"Return (precision, recall, f1) from the counts; a zero denominator yields 0.\"\"\"\n",
    "    precision = TP / (TP + FP) if (TP + FP) > 0 else 0\n",
    "    recall = TP / (TP + FN) if (TP + FN) > 0 else 0\n",
    "    f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0\n",
    "    return precision, recall, f1\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e0fce782",
   "metadata": {"scrolled": false, "id": "e0fce782"},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import torch\n",
    "from torch import nn\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "class Args:\n",
    "    \"\"\"Minimal options holder read by train_drop_sub (only `drop_alpha`).\"\"\"\n",
    "\n",
    "    def __init__(self, drop_alpha):\n",
    "        self.drop_alpha = drop_alpha\n",
    "\n",
    "# (CSV column suffix, key in the test_knn result) in the order the columns are written\n",
    "METRIC_COLUMNS = (\n",
    "    ('Accuracy', 'accuracy'),\n",
    "    ('D-Index', 'd_index'),\n",
    "    ('Precision', 'precision'),\n",
    "    ('Recall', 'recall'),\n",
    "    ('F1', 'f1'),\n",
    ")\n",
    "\n",
    "def print_topk_metrics(rank, metrics):\n",
    "    \"\"\"Print the five metrics of one top-k entry of the test_knn result.\"\"\"\n",
    "    print(f\"Top {rank} Metrics:\")\n",
    "    print(f\"  Accuracy: {metrics['accuracy']:.8f}%\")\n",
    "    print(f\"  D-Index: {metrics['d_index']:.8f}\")\n",
    "    print(f\"  Precision: {metrics['precision']:.8f}\")\n",
    "    print(f\"  Recall: {metrics['recall']:.8f}\")\n",
    "    print(f\"  F1 Score: {metrics['f1']:.8f}\")\n",
    "\n",
    "def main():\n",
    "    \"\"\"Train and evaluate the model for every drop_alpha and write one metrics CSV per value.\"\"\"\n",
    "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "    print(\"Using device:\", device)\n",
    "\n",
    "    epochs = 100\n",
    "    k = 30\n",
    "\n",
    "    path = 'clean_data_O1.csv'\n",
    "\n",
    "    df_numerical, labels = load_and_preprocess_data(path)\n",
    "    num_classes = len(np.unique(labels))\n",
    "    augmented_original_pairs = create_pairs(df_numerical)\n",
    "\n",
    "    learning_rates =[1e-3]\n",
    "    batch_sizes = [128]\n",
    "\n",
    "    drop_alphas = [0.02, 0.04, 0.06, 0.08, 0.1, 0.12, 0.14, 0.16, 0.18, 0.2]\n",
    "\n",
    "    # The datasets do not depend on the hyper-parameters, so build them once.\n",
    "    # NOTE(review): the memory bank and the test set are the same rows, so every\n",
    "    # test sample can retrieve itself as a neighbour -- confirm this is intended.\n",
    "    train_dataset = TrainDataset({\"pos_1\": augmented_original_pairs[:, 0], \"pos_2\": augmented_original_pairs[:, 1]})\n",
    "    memory_dataset = MemoryDataset({\"true_data\": augmented_original_pairs[:, 2], 'target': labels})\n",
    "    test_dataset = TestDataset({\"true_data\": augmented_original_pairs[:, 2], 'target': labels})\n",
    "\n",
    "    for drop_alpha in drop_alphas:\n",
    "        args = Args(drop_alpha)\n",
    "        print(f\"Running for drop_alpha: {args.drop_alpha}\")\n",
    "\n",
    "        all_metrics = []  # Initialize the list to store metrics for the current drop_alpha\n",
    "\n",
    "        for lr in learning_rates:\n",
    "            for batch_size in batch_sizes:\n",
    "                print(f\"Learning Rate: {lr}, Batch Size: {batch_size}\")\n",
    "\n",
    "                train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)\n",
    "                memory_data_loader = DataLoader(memory_dataset, batch_size=batch_size, shuffle=False)\n",
    "                test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)\n",
    "\n",
    "                model = NumericalModel(input_dim=df_numerical.shape[1], feature_dim=64).to(device)\n",
    "                model = nn.DataParallel(model)\n",
    "                optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-6)\n",
    "\n",
    "                for epoch in range(epochs):\n",
    "                    train_drop_sub(model, train_loader, optimizer, device, epoch, epochs, batch_size, args)\n",
    "                    test_results = test_knn(model, memory_data_loader, test_loader, k, num_classes)\n",
    "\n",
    "                    # Collect metrics\n",
    "                    epoch_metrics = {\n",
    "                        'Epoch': epoch + 1,\n",
    "                        'Learning Rate': lr,\n",
    "                        'Batch Size': batch_size,\n",
    "                        'Drop Alpha': args.drop_alpha,\n",
    "                    }\n",
    "                    for rank in (1, 2, 3):\n",
    "                        for column, key in METRIC_COLUMNS:\n",
    "                            epoch_metrics[f'Top{rank} {column}'] = test_results[f'top{rank}'][key]\n",
    "                    all_metrics.append(epoch_metrics)\n",
    "\n",
    "                    for rank in (1, 2, 3):\n",
    "                        print_topk_metrics(rank, test_results[f'top{rank}'])\n",
    "                        if rank < 3:\n",
    "                            print()  # blank line between the groups\n",
    "\n",
    "        # Convert the list of metrics to a DataFrame and save it as a CSV file specific to the current drop_alpha\n",
    "        df_metrics = pd.DataFrame(all_metrics)\n",
    "        filename = f'model_metrics_drop_alpha_{args.drop_alpha}.csv'\n",
    "        df_metrics.to_csv(filename, index=False)\n",
    "        print(f\"Model metrics for drop_alpha {args.drop_alpha} saved!\")\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    main()\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {"display_name": "PyTorch GPU", "language": "python", "name": "pytorch_gpu"},
  "language_info": {
   "codemirror_mode": {"name": "ipython", "version": 3},
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.18"
  },
  "colab": {"provenance": []}
 },
 "nbformat": 4,
 "nbformat_minor": 5
}