import os
import random

import numpy as np
import torch
import torch.nn.functional as F
from torch import optim
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm
from opacus.validators import ModuleValidator

from model import get_model
from learn import target_train
from EIU.src.evaluation import measure
from EIU.src.utils import get_logger


def get_posterior(model, dataloader, device):
    """Compute softmax posteriors and labels for every sample of a loader's dataset.

    The dataset is re-wrapped in a fresh, unshuffled, batch-size-1 DataLoader so
    the output order follows the dataset order regardless of how ``dataloader``
    itself was configured.

    The forward passes run in eval mode, so dropout is disabled and batch-norm
    uses (and does not update) its running statistics. The model's original
    train/eval mode is restored before returning, even if an error occurs.

    Args:
        model: torch module mapping inputs to logits; softmax is taken over dim 1.
        dataloader: any loader exposing ``.dataset``; only the dataset is used.
        device: device the inputs are moved to before the forward pass.

    Returns:
        (probs, targets): ``np.stack`` of the squeezed per-sample softmax rows,
        and ``np.array`` of the squeezed per-sample labels.
    """
    single_loader = torch.utils.data.DataLoader(dataloader.dataset, batch_size=1, shuffle=False)
    prob = []
    targets = []
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for inputs, labels in single_loader:
                outputs = model(inputs.to(device))
                curr_probs = torch.nn.functional.softmax(outputs, dim=1)
                prob.append(curr_probs.cpu().numpy().squeeze())
                targets.append(labels.cpu().numpy().squeeze())
    finally:
        # Don't leave a model that was mid-training stuck in eval mode.
        model.train(was_training)
    return np.stack(prob), np.array(targets)

def get_posterior_tensor(model, dataloader, device):  # deprecated
    """Deprecated: gather every sample of ``dataloader.dataset`` as tensors, in order.

    NOTE(review): despite its name, this function never calls ``model``. The
    first returned tensor holds the device-moved *inputs*, not softmax
    posteriors. Use ``get_posterior`` for real model posteriors.

    Returns:
        (inputs, labels): inputs concatenated along dim 0 (moved to ``device``),
        and labels stacked along a new leading dim (left where the loader put them).
    """
    single_loader = torch.utils.data.DataLoader(dataloader.dataset, batch_size=1, shuffle=False)
    with torch.no_grad():
        pairs = [(sample.to(device), label) for sample, label in single_loader]
    gathered_inputs = [sample for sample, _ in pairs]
    gathered_labels = [label for _, label in pairs]
    return torch.cat(gathered_inputs), torch.stack(gathered_labels)

def get_outputs(model, loader, device):
    """Run ``model`` over ``loader.dataset`` one sample at a time, in dataset order.

    Switches the model to eval mode and leaves it there.

    Returns:
        activations: ``np.stack`` of the squeezed per-sample softmax rows.
        predictions: ``np.array`` of the squeezed per-sample argmax class indices.
        sample_info: list of ``(inputs, logits, labels)`` tensor triples, each
            on ``device`` and keeping its leading batch dimension of 1.
    """
    ordered_loader = torch.utils.data.DataLoader(loader.dataset, batch_size=1, shuffle=False)
    model.eval()
    probs_list, preds_list, sample_info = [], [], []
    with torch.no_grad():
        for x, y in ordered_loader:
            x = x.to(device)
            y = y.to(device)
            logits = model(x)
            probs = torch.nn.functional.softmax(logits, dim=1)
            top_class = probs.argmax(dim=1)
            probs_list.append(probs.cpu().detach().numpy().squeeze())
            preds_list.append(top_class.cpu().detach().numpy().squeeze())
            sample_info.append((x, logits, y))
    return np.stack(probs_list), np.array(preds_list), sample_info

def get_target(dataloader, device="cuda"):
    """Collect the label of every sample in ``dataloader.dataset`` into one tensor.

    The dataset is re-wrapped in an unshuffled, batch-size-1 DataLoader, so the
    labels come back in dataset order.

    Args:
        dataloader: any loader exposing ``.dataset``, whose items unpack to
            ``(data, target)``.
        device: device each target is moved to (defaults to ``"cuda"``, which
            was the original intent).

    Returns:
        The per-sample targets concatenated along dim 0.
    """
    data_loader = torch.utils.data.DataLoader(dataloader.dataset, batch_size=1, shuffle=False)
    targets = []
    with torch.no_grad():
        for _, target in tqdm(data_loader, leave=False):
            # The loader yields a (data, target) list; unpack first, then move.
            # Calling .cuda() on the list itself raised AttributeError.
            targets.append(target.to(device))
    return torch.cat(targets)

def get_attack_dataloader(shadow_net, shadow_trainloader, shadow_testloader, args, device):
    """Build the training set for the membership-inference attack model.

    Shadow-model posteriors on the shadow *train* split are labelled 1 (member).
    Posteriors on the shadow *test* split are labelled 0 (non-member).

    NOTE(review): despite the name, this returns NumPy arrays, not a DataLoader.
    ``args`` is accepted but unused here.

    Returns:
        member: ``(probs, targets)`` for the shadow train split.
        non_member: ``(probs, targets)`` for the shadow test split.
        features: member probs followed by non-member probs, concatenated.
        labels: float array of 1.0s (members) followed by 0.0s (non-members).
    """
    member = get_posterior(shadow_net, shadow_trainloader, device)
    non_member = get_posterior(shadow_net, shadow_testloader, device)
    member_probs = member[0]
    non_member_probs = non_member[0]

    features = np.concatenate((member_probs, non_member_probs))
    labels = np.concatenate((np.ones(len(member_probs)), np.zeros(len(non_member_probs))))
    return member, non_member, features, labels

def manual_seed(seed):
    """Seed the NumPy, Python ``random`` and PyTorch RNGs and force deterministic cuDNN."""
    for seeder in (np.random.seed, random.seed, torch.manual_seed):
        seeder(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    # Autotuning picks kernels nondeterministically, so turn it off.
    cudnn.benchmark = False