from __future__ import print_function
import os
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from collections import OrderedDict
import numpy as np
from .DROCC_trainer import DROCCTrainer


class MLP(nn.Module):
    """
    Multi-layer perceptron with single hidden layer.
    """

    def __init__(self,
                 input_dim=2,
                 num_classes=1,
                 num_hidden_nodes=20):
        super(MLP, self).__init__()
        self.input_dim = input_dim
        self.num_classes = num_classes
        self.num_hidden_nodes = num_hidden_nodes
        activ = nn.ReLU(True)
        self.feature_extractor = nn.Sequential(OrderedDict([
            ('fc', nn.Linear(self.input_dim, self.num_hidden_nodes)),
            ('relu1', activ)]))
        self.size_final = self.num_hidden_nodes

        self.classifier = nn.Sequential(OrderedDict([
            ('fc1', nn.Linear(self.size_final, self.num_classes))]))

    def forward(self, input):
        features = self.feature_extractor(input)
        logits = self.classifier(features.view(-1, self.size_final))
        return logits


def adjust_learning_rate(epoch, total_epochs, only_ce_epochs, learning_rate, optimizer):
    """Adjust learning rate during training.

    Parameters
    ----------
    epoch: Current training epoch.
    total_epochs: Total number of epochs for training.
    only_ce_epochs: Number of epochs for initial pretraining.
    learning_rate: Initial learning rate for training.
    """
    # We dont want to consider the only ce
    # based epochs for the lr scheduler
    epoch = epoch - only_ce_epochs
    drocc_epochs = total_epochs - only_ce_epochs
    # lr = learning_rate
    if epoch <= drocc_epochs:
        lr = learning_rate * 0.001
    if epoch <= 0.90 * drocc_epochs:
        lr = learning_rate * 0.01
    if epoch <= 0.60 * drocc_epochs:
        lr = learning_rate * 0.1
    if epoch <= 0.30 * drocc_epochs:
        lr = learning_rate
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    return optimizer


class CustomDataset(Dataset):
    def __init__(self, data, labels):
        self.data = data
        self.labels = labels

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        return torch.from_numpy(self.data[idx]), (self.labels[idx]), torch.tensor([0])


def load_data(path):
    train_data = np.load(os.path.join(path, 'train_data.npy'), allow_pickle=True)
    train_lab = np.ones((train_data.shape[0]))  # All positive labelled data points collected
    test_data = np.load(os.path.join(path, 'test_data.npy'), allow_pickle=True)
    test_lab = np.load(os.path.join(path, 'test_labels.npy'), allow_pickle=True)

    ## preprocessing
    mean = np.mean(train_data, 0)
    std = np.std(train_data, 0)
    train_data = (train_data - mean) / (std + 1e-4)
    num_features = train_data.shape[1]
    test_data = (test_data - mean) / (std + 1e-4)

    train_samples = train_data.shape[0]
    test_samples = test_data.shape[0]
    print("Train Samples: ", train_samples)
    print("Test Samples: ", test_samples)

    return CustomDataset(train_data, train_lab), CustomDataset(test_data, test_lab), num_features


class DROCC:
    def __init__(self, train_x, test_x, test_y, latent_dim, lamb, radius, gamma, device='cuda'):

        self.train_x, self.test_x, self.test_y, self.latent_dim, self.lamb, self.radius, self.gamma = \
            train_x, test_x, test_y, latent_dim, lamb, radius, gamma
        self.device = device

        train_dataset = CustomDataset(train_x, np.zeros(train_x.shape[0]))
        test_dataset = CustomDataset(test_x, test_y)

        self.train_loader = DataLoader(dataset=train_dataset, batch_size=128, shuffle=True,
                                       num_workers=0)
        self.test_loader = DataLoader(dataset=test_dataset, batch_size=128, shuffle=False,
                                      num_workers=0)

        self.net = MLP(input_dim=train_x.shape[1], num_hidden_nodes=latent_dim, num_classes=1).to(device)
        self.optimizer = optimizer = optim.Adam(self.net.parameters(),
                               lr=0.001)
        self.trainer = DROCCTrainer(self.net, optimizer, lamb, radius, gamma, device)

    def fit(self):
        self.trainer.train(self.train_loader, None, 0.001, adjust_learning_rate, 200,
                      ascent_step_size=0.001,
                      only_ce_epochs=50
                           )

    def decision_function(self, test_x):
        self.net = self.net.to(self.device)
        test_set = torch.utils.data.TensorDataset(torch.Tensor(test_x), torch.zeros(test_x.shape[0]))
        test_loader = torch.utils.data.DataLoader(test_set, batch_size=1024, shuffle=False, num_workers=0)

        score = self.trainer.test(test_loader)
        # print('Test F1: {}'.format(score))
        return score

def main():
    train_dataset, test_dataset, num_features = load_data(args.data_path)
    train_loader = DataLoader(train_dataset, args.batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, args.batch_size, shuffle=True)

    model = MLP(input_dim=num_features, num_hidden_nodes=args.hd, num_classes=1).to(device)
    if args.optim == 1:
        optimizer = optim.SGD(model.parameters(),
                              lr=args.lr,
                              momentum=args.mom)
        print("using SGD")
    else:
        optimizer = optim.Adam(model.parameters(),
                               lr=args.lr)
        print("using Adam")

    trainer = DROCCTrainer(model, optimizer, args.lamda, args.radius, args.gamma, device)

    if args.eval == 0:
        # Training the model
        trainer.train(train_loader, test_loader, args.lr, adjust_learning_rate, args.epochs,
                      metric=args.metric, ascent_step_size=args.ascent_step_size, only_ce_epochs=args.only_ce_epochs)

        trainer.save(args.model_dir)

    else:
        if os.path.exists(os.path.join(args.model_dir, 'model.pt')):
            trainer.load(args.model_dir)
            print("Saved Model Loaded")
        else:
            print('Saved model not found. Cannot run evaluation.')
            exit()
        score = trainer.test(test_loader, 'F1')
        print('Test F1: {}'.format(score))


if __name__ == '__main__':
    torch.set_printoptions(precision=5)

    parser = argparse.ArgumentParser(description='PyTorch Simple Training')
    parser.add_argument('--batch_size', type=int, default=128, metavar='N',
                        help='batch size for training')
    parser.add_argument('--epochs', type=int, default=100, metavar='N',
                        help='number of epochs to train')
    parser.add_argument('-oce,', '--only_ce_epochs', type=int, default=50, metavar='N',
                        help='number of epochs to train with only CE loss')
    parser.add_argument('--ascent_num_steps', type=int, default=50, metavar='N',
                        help='Number of gradient ascent steps')
    parser.add_argument('--hd', type=int, default=128, metavar='N',
                        help='Number of hidden nodes for LSTM model')
    parser.add_argument('--lr', type=float, default=0.001, metavar='LR',
                        help='learning rate')
    parser.add_argument('--ascent_step_size', type=float, default=0.001, metavar='LR',
                        help='step size of gradient ascent')
    parser.add_argument('--mom', type=float, default=0.99, metavar='M',
                        help='momentum')
    parser.add_argument('--model_dir', default='log',
                        help='path where to save checkpoint')
    parser.add_argument('--one_class_adv', type=int, default=1, metavar='N',
                        help='adv loss to be used or not, 1:use 0:not use(only CE)')
    parser.add_argument('--radius', type=float, default=0.2, metavar='N',
                        help='radius corresponding to the definition of set N_i(r)')
    parser.add_argument('--lamda', type=float, default=1, metavar='N',
                        help='Weight to the adversarial loss')
    parser.add_argument('--reg', type=float, default=0, metavar='N',
                        help='weight reg')
    parser.add_argument('--eval', type=int, default=0, metavar='N',
                        help='whether to load a saved model and evaluate (0/1)')
    parser.add_argument('--optim', type=int, default=0, metavar='N',
                        help='0 : Adam 1: SGD')
    parser.add_argument('--gamma', type=float, default=2.0, metavar='N',
                        help='r to gamma * r projection for the set N_i(r)')
    parser.add_argument('-d', '--data_path', type=str, default='.')
    parser.add_argument('--metric', type=str, default='F1')
    args = parser.parse_args()

    # settings
    # Checkpoint store path
    model_dir = args.model_dir
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    main()