from __future__ import print_function
import os
import numpy as np
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from collections import OrderedDict
from data_preprocess.process_fmnist import FMNIST_Dataset
from trainer.PLAD_trainer_fmnist import PLADTrainer
from VAE_fmnist import VAE
import itertools
import scipy.io


from networks import mlp
from datasets import load_dataset

class FashionMNIST_LeNet(nn.Module):
    """LeNet-style convolutional network mapping a 28x28 single-channel
    image (or a flat 784-vector) to one scalar per sample.

    Two conv -> BatchNorm (no affine) -> leaky ReLU -> 2x2 max-pool stages
    reduce the input to 32 feature maps of 7x7, which are then flattened and
    passed through ``fc1`` -> ``fc2`` -> ``fc3``. Note that the three linear
    layers are applied with no activation in between, so together they form
    a single linear map; ``rep_dim`` is the width of the ``fc2`` output.
    """

    def __init__(self):
        super(FashionMNIST_LeNet, self).__init__()

        self.rep_dim = 64
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Registration order is significant: it fixes both the parameter
        # order and how a seeded RNG is consumed during initialisation.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5,
                               padding=2, bias=False)
        self.bn1 = nn.BatchNorm2d(16, eps=1e-4, affine=False)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5,
                               padding=2, bias=False)
        self.bn2 = nn.BatchNorm2d(32, eps=1e-4, affine=False)

        flat_features = 32 * 7 * 7  # 28 -> 14 -> 7 after two 2x2 pools
        self.fc1 = nn.Linear(flat_features, 128, bias=False)
        self.fc2 = nn.Linear(128, self.rep_dim, bias=False)
        self.fc3 = nn.Linear(self.rep_dim, 1, bias=False)

    def forward(self, x):
        """Return a ``(batch, 1)`` tensor of scores for input ``x``.

        ``x`` must hold ``batch * 784`` elements per sample so that it can be
        viewed as ``(batch, 1, 28, 28)``.
        """
        batch_size = x.shape[0]
        feats = x.view(batch_size, 1, 28, 28)
        for conv, norm in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            feats = self.pool(F.leaky_relu(norm(conv(feats))))
        flat = feats.view(feats.size(0), -1)
        return self.fc3(self.fc2(self.fc1(flat)))


def _hidden_dim_for(n_dim):
    """Return the hidden width shared by the MLP and the VAE for ``n_dim`` features.

    Thresholds are tested from largest to smallest. The previous chain tested
    ``n_dim > 64`` first, which made the ``n_dim > 256`` branch unreachable.

    >>> _hidden_dim_for(10), _hidden_dim_for(100), _hidden_dim_for(500)
    (64, 256, 1024)
    """
    if n_dim > 256:
        return 1024
    if n_dim > 64:
        return 256
    return 64


def _training_schedule(n_sample):
    """Return ``(n_epoch, lr_milestones)`` chosen from the training-set size.

    Larger training sets are trained for fewer epochs.
    """
    if n_sample > 15000:
        return 150, [50]
    if n_sample > 8000:
        return 250, [100]
    return 500, [1000]


def main():
    """Train and score PLAD on every dataset listed in ``dataset_names``.

    Relies on the module-level ``args`` (parsed CLI options) and ``device``,
    which the ``__main__`` guard defines before calling this function.

    For each dataset and each selected normal class, this function:

    * builds an ``mlp.VanilaMLP`` and a ``VAE`` whose parameters share one
      optimizer;
    * trains them with ``PLADTrainer``;
    * saves the ``(auc, f1)`` pair returned by ``trainer.train`` to
      ``./save/PLAD-<dataset>-<class>-std-<run>.save``.

    NOTE(review): ``args.eval`` and ``args.model_dir`` are not consulted here.
    There is no load-a-checkpoint-and-evaluate path in this loop.
    """
    dataset_names = [
        'kdd',
        # Other names used with load_dataset.load_dataset:
        # 'arrhythmia', 'wine', 'lympho', 'glass', 'vertebral', 'wbc', 'ecoli',
        # 'ionosphere', 'breastw', 'pima', 'vowels', 'letter', 'cardio',
        # 'seismic', 'abalone', 'pendigits', 'mammography', 'mulcross',
        # 'thyroid', 'optdigits', 'satimage', 'shuttle', 'musk', 'speech',
        # 'forest_cover',
        #
        # ADBench-style names (presumably for load_dataset.load_adbench_dataset):
        # '10_cover',
        # '11_donors', '12_fault', '13_fraud', '14_glass', '15_Hepatitis', '16_http',
        # '17_InternetAds', '18_Ionosphere', '19_landsat',
        # '1_ALOI', '20_letter', '21_Lymphography', '22_magic.gamma', '23_mammography', '24_mnist',
        # '25_musk', '26_optdigits', '27_PageBlocks', '28_pendigits', '29_Pima', '2_annthyroid',
        # '30_satellite', '31_satimage-2', '32_shuttle', '33_skin', '34_smtp', '35_SpamBase', '36_speech',
        # '37_Stamps', '38_thyroid', '39_vertebral', '3_backdoor', '40_vowels', '41_Waveform', '42_WBC',
        # '43_WDBC', '44_Wilt', '45_wine', '46_WPBC', '47_yeast', '4_breastw', '5_campaign', '6_cardio',
        # '7_Cardiotocography', '8_celeba',
        # '9_census', '48_arrhythmia'
    ]

    # NOTE(review): machine-specific data root. args.data_path is parsed
    # but not used here.
    data_root = 'G:\\fan\\ad\\datasets\\data'
    n_runs = 1  # repetitions per dataset; the run index is part of the save name

    # torch.save does not create missing directories; make sure ./save exists.
    os.makedirs('./save', exist_ok=True)

    for dataset_ in dataset_names:
        for num in range(n_runs):
            print(dataset_)
            train_data, test_data, classes = load_dataset.load_dataset(
                data_root, name=dataset_)
            # Alternative loader for the ADBench-style names:
            # train_data, test_data, classes = load_dataset.load_adbench_dataset(
            #     'G:\\fan\\ad\\other_ad_methods\\datasets', name=dataset_)

            print(dataset_, classes)
            # With at most two labels, only class 0 is used as the normal class.
            if len(classes) <= 2:
                classes = [0]
            for normal_c in classes:
                train_loader, test_loader, mu, std, n_dim, n_sample = \
                    load_dataset.process_dataset(train_data, test_data,
                                                 classes, normal_c,
                                                 b_size=256,
                                                 normalize=True)
                # process_dataset can return no training loader; skip that class.
                if train_loader is None:
                    continue

                print(normal_c)

                hidden_dim = _hidden_dim_for(n_dim)
                # NOTE(review): the milestones are not passed to PLADTrainer,
                # so they currently have no effect.
                n_epoch, _lr_milestones = _training_schedule(n_sample)

                model = mlp.VanilaMLP(input_dim=n_dim, hidden_dim=hidden_dim).to(device)
                e_ae = VAE(input_dim=n_dim, h_dim=hidden_dim, z_dim=128).to(device)

                # One optimizer updates both networks jointly.
                params = itertools.chain(model.parameters(), e_ae.parameters())
                if args.optim == 1:
                    optimizer = optim.SGD(params, lr=args.lr, momentum=args.mom)
                    print("Optimizer: SGD")
                else:
                    optimizer = optim.Adam(params, lr=args.lr, amsgrad=True)
                    print("Optimizer: Adam")

                trainer = PLADTrainer(model, e_ae, optimizer, args.lamda, device)

                score = trainer.train(train_loader, test_loader, args.lr, n_epoch, metric=args.metric)
                print('Test AUC: {}'.format(score))
                auc_score, f_score = score
                torch.save({'f1': f_score, 'auc': auc_score}, f'./save/PLAD-{dataset_}-{normal_c}-std-{num}.save')

if __name__ == '__main__':
    torch.set_printoptions(precision=5)

    parser = argparse.ArgumentParser(description='PLAD Training')
    # NOTE(review): --normal_class, --batch_size, --epochs, --eval and
    # --data_path are parsed, but main() does not use them. It uses a fixed
    # batch size of 256, a size-based epoch schedule and a hard-coded data root.
    parser.add_argument('--normal_class', type=int, default=5, metavar='N',
                        help='Fashion-MNIST normal class index')
    parser.add_argument('--batch_size', type=int, default=256, metavar='N',
                        help='batch size for training')
    parser.add_argument('--epochs', type=int, default=100, metavar='N',
                        help='number of epochs to train')
    parser.add_argument('--lr', type=float, default=0.001, metavar='LR',
                        help='learning rate')
    parser.add_argument('--lamda', type=float, default=0.1, metavar='N',
                        help='Weight of the perturbator loss')
    parser.add_argument('--optim', type=int, default=0, metavar='N',
                        help='0 : Adam 1: SGD')
    parser.add_argument('--mom', type=float, default=0.0, metavar='M',
                        help='momentum')
    parser.add_argument('--model_dir', default='log',
                        help='path where to save checkpoint')
    parser.add_argument('--eval', type=int, default=1, metavar='N',
                        help='whether to load a saved model and evaluate (0/1)')
    parser.add_argument('-d', '--data_path', type=str, default='.')
    parser.add_argument('--metric', type=str, default='AUC')
    args = parser.parse_args()

    # Model save path; exist_ok avoids the check-then-create race.
    model_dir = args.model_dir
    os.makedirs(model_dir, exist_ok=True)

    # main() reads the module-level `args` and `device` defined here.
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    main()
