
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as fcnal
from sklearn.pipeline import Pipeline

from metrics import eval_target_model

# Torch device shared by every function in this module: the first CUDA GPU
# when torch can see one, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")



def label_to_onehot(labels, num_classes=100):
    """ Converts class label(s) into one-hot vectors.

    Args:
        labels (torch.Tensor | int | sequence of int): Class index or indices
            to encode. Values are cast to ``long`` first (floats are
            truncated); a scalar / 0-d input yields a single vector.
        num_classes (int):  Number of classes for the model, i.e. the length
            of each one-hot vector.

    Returns:
        (torch.tensor): Tensor of shape ``labels.shape + (num_classes,)`` in
            torch's default float dtype, with 0's everywhere except for a 1 at
            each label's class index. It lives on the same device as
            ``labels``.

    Raises:
        RuntimeError: if any label is negative or ``>= num_classes``.
    """
    # torch.as_tensor lets plain ints / lists through (the old `labels.long()`
    # required a tensor), and fcnal.one_hot avoids materialising a
    # num_classes x num_classes identity matrix on the CPU for every call.
    indices = torch.as_tensor(labels).long()
    one_hot = fcnal.one_hot(indices, num_classes=num_classes)
    # one_hot returns int64; match the float dtype torch.eye used to produce.
    return one_hot.to(torch.get_default_dtype())

def train(model=None, data_loader=None, test_loader=None,
          optimizer=None, criterion=None, n_epochs=0,
          classes=None, verbose=False):
    """
    Train ``model`` on ``data_loader`` for ``n_epochs`` epochs and report
    accuracy on the train and test loaders after every epoch.

    Parameters
    ----------
    model       : Module
                  PyTorch conforming nn.Module function
    data_loader : DataLoader
                  PyTorch dataloader yielding (data, labels) training batches
    test_loader : DataLoader
                  PyTorch dataloader used for per-epoch test evaluation
    optimizer   : opt object
                  PyTorch conforming optimizer function
    criterion   : loss object
                  PyTorch conforming loss function
    n_epochs    : int
                  number of training epochs
    classes     : list
                  list of classes, forwarded to ``eval_target_model``
    verbose     : boolean
                  flag for per-batch loss print statements

    Returns
    -------
    (train_acc, test_acc) : tuple
        Whatever ``eval_target_model`` returned for the training and test
        loaders after the final epoch, or ``(None, None)`` when
        ``n_epochs <= 0`` and no evaluation ran.
    """
    # Bound up front so that n_epochs == 0 (the default) returns cleanly
    # instead of raising UnboundLocalError at the final `return`.
    train_acc = test_acc = None

    for epoch in range(n_epochs):
        # Re-enter training mode every epoch. Presumably the evaluation
        # below can switch the model to eval mode; confirm in
        # metrics.eval_target_model.
        model.train()

        for i, (data, labels) in enumerate(data_loader):
            data, labels = data.to(device), labels.to(device)

            optimizer.zero_grad()
            loss = criterion(model(data), labels)
            loss.backward()
            optimizer.step()

            if verbose:
                # .item() forces a device sync, so only pay for it when
                # the value is actually printed.
                print("[{}/{}][{}/{}] loss = {}"
                      .format(epoch, n_epochs, i,
                              len(data_loader), loss.item()))

        # evaluate performance on both sets at the end of each epoch
        print("[{}/{}]".format(epoch, n_epochs))
        print("Training:")
        train_acc = eval_target_model(model, data_loader, classes=classes)
        print("Test:")
        test_acc = eval_target_model(model, test_loader, classes=classes)

    return train_acc, test_acc


def _shadow_top_k(shadow_model, data, k):
    """Return the ``k`` largest shadow-model class probabilities per sample.

    A torch shadow model is queried under ``torch.no_grad`` and its outputs
    are softmax-normalised over dim 1. An sklearn ``Pipeline`` gets the
    flattened inputs through ``predict_proba``. Rows are sorted in descending
    order before truncation, and the result is moved to ``device`` with no
    autograd history, so attack-model gradients never reach the shadow model.
    """
    if isinstance(shadow_model, Pipeline):
        flat = data.reshape(data.shape[0], -1).cpu().numpy()
        posteriors = torch.from_numpy(shadow_model.predict_proba(flat)).float()
    else:
        with torch.no_grad():
            posteriors = fcnal.softmax(shadow_model(data.to(device)), dim=1)
    # Most confident class first, then keep the top k probabilities.
    ranked, _ = torch.sort(posteriors, descending=True)
    return ranked[:, :k].to(device)


def train_attacker(attack_model=None, shadow_model=None,
                   shadow_train=None, shadow_out=None,
                   optimizer=None, criterion=None, n_epochs=0, k=0,
                   verbose=False):
    """
    Trains attack model (classifies a sample as in or
    out of training set) using shadow model outputs
    (the top-k class probabilities for each sample).
    The shadow model may be a PyTorch module or an sklearn Pipeline.

    Parameters
    ----------
    attack_model : Module
                   PyTorch conforming nn.Module function
    shadow_model : Module or sklearn Pipeline
                   model whose posteriors are fed to the attack model
    shadow_train : DataLoader
                   batches the shadow model was trained on (members)
    shadow_out   : DataLoader
                   batches held out from shadow training (non-members)
    optimizer    : opt object
                   PyTorch conforming optimizer over attack_model's params
    criterion    : loss object
                   PyTorch conforming loss function
    n_epochs     : int
                   number of training epochs
    k            : int
                   number of highest class probabilities per sample given to
                   the attack model; must be >= 1 for the attack model to
                   receive any features
    verbose      : boolean
                   flag for per-batch loss / accuracy print statements

    Returns
    -------
    None; ``attack_model`` is updated in place through ``optimizer``.
    """
    if not isinstance(shadow_model, Pipeline):
        # The torch shadow model is only queried here, never updated.
        shadow_model.eval()

    for epoch in range(n_epochs):
        total = 0
        correct = 0

        for i, ((train_data, _), (out_data, _)) in enumerate(
                zip(shadow_train, shadow_out)):
            batch_size = train_data.shape[0]
            # Only use equally sized member / non-member batch pairs, so a
            # short trailing batch is skipped and each step stays balanced.
            if batch_size != out_data.shape[0]:
                continue

            train_top_k = _shadow_top_k(shadow_model, train_data, k)
            out_top_k = _shadow_top_k(shadow_model, out_data, k)

            # Membership targets: 1 = member (shadow_train), 0 = non-member.
            train_lbl = torch.ones(batch_size, device=device)
            out_lbl = torch.zeros(batch_size, device=device)

            optimizer.zero_grad()

            # reshape(-1) instead of squeeze(): squeeze collapses a
            # one-sample batch to a 0-d tensor and breaks .size(0) below.
            # Assumes attack_model outputs (batch,) or (batch, 1) — confirm.
            train_predictions = attack_model(train_top_k).reshape(-1)
            out_predictions = attack_model(out_top_k).reshape(-1)

            loss = (criterion(train_predictions, train_lbl)
                    + criterion(out_predictions, out_lbl)) / 2

            # Update the attack model for both torch and Pipeline shadow
            # models; the inputs carry no shadow-model graph (see helper).
            loss.backward()
            optimizer.step()

            # The 0.5 threshold presumes attack_model emits membership
            # probabilities (e.g. a sigmoid output) — TODO confirm.
            correct += (train_predictions >= 0.5).sum().item()
            correct += (out_predictions < 0.5).sum().item()
            total += train_predictions.size(0) + out_predictions.size(0)
            if verbose:
                print("[{}/{}][{}/{}] loss = {:.2f}, accuracy = {:.2f}"
                      .format(epoch, n_epochs, i, len(shadow_train),
                              loss.item(), 100 * correct / total))
        

