import os
import sys
sys.path.insert(0, './')
import json
import time
import pickle
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

from MIA.MIA import MIA

# Naive method for testing
class NaiveMIA(MIA):
    """Naive threshold baseline built on the ``MIA`` base class.

    A sample is flagged when the softmax probability that the model assigns
    to its supplied label is strictly greater than ``self.threshold``.
    Nothing is learned, so ``fit`` is a no-op.
    """

    def __init__(self, name, threshold, metric, mia_mode="attack"):
        """Forward every argument unchanged to the ``MIA`` constructor.

        NOTE(review): ``infer`` reads ``self.threshold``; it is presumably
        stored by ``MIA.__init__`` -- confirm against the base class.
        """
        super(NaiveMIA, self).__init__(name, threshold, metric, mia_mode)

    def fit(self, model, train_data_generator, shadow_data_generator):
        """Do nothing; all arguments are accepted and ignored."""
        pass

    def infer(self, model, data, label):
        """Threshold the model's probability for each sample's label.

        Args:
            model: Callable applied as ``model(data)``. Its output is
                treated as logits with samples along dim 0 and classes
                along dim 1.
            data: Input passed to ``model`` as-is.
            label: Index used to pick one class probability per row.
                Assumes a 1-D integer tensor with one class index per
                sample -- TODO confirm against callers.

        Returns:
            A ``torch.uint8`` tensor holding 1 where the selected
            probability is strictly greater than ``self.threshold`` and 0
            elsewhere.
        """
        # The result is a thresholded mask that carries no gradient, so the
        # forward pass does not need an autograd graph; skipping it avoids
        # keeping intermediate activations alive.
        with torch.no_grad():
            logits = model(data)
            prob = F.softmax(logits, dim=1)
            # Pair row i with label[i] to select one probability per sample.
            indices = torch.arange(logits.size(0), device=logits.device)
            correct_prob = prob[indices, label]

        return (correct_prob > self.threshold).byte()
