import numpy as np
import pandas as pd
from scipy import stats
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import transforms
import torchvision.datasets as datasets
import resnet
import os
import pickle

from sklearn.model_selection import train_test_split

from wideresnet import WideResNet
from densenet import DenseNet


# Module-level data directory constant. It is not referenced by any code in the
# visible part of this module -- presumably the location of a CCLE dataset used
# elsewhere; TODO confirm it is still needed.
folder_name = './data/CCLE/'


# data_name -> (checkpoint sub-directory, number of training epochs).
# The checkpoint of the last epoch, 'model<num_epochs - 1>.pt', is the one loaded.
_CIFAR10_BACKBONES = {
    'CIFAR10_densenet': ('densenet_40', 300),
    'CIFAR10_wr': ('wr_28_10', 200),
    'CIFAR10_resnet': ('resnet_110', 200),
}


def _build_backbone(data_name, device):
    """Instantiate the (still untrained) CIFAR-10 network for data_name on device."""
    if data_name == 'CIFAR10_densenet':
        depth = 40
        growth_rate = 12
        block_config = [(depth - 4) // 6 for _ in range(3)]
        net = DenseNet(growth_rate=growth_rate, block_config=block_config, num_classes=10)
    elif data_name == 'CIFAR10_wr':
        net = WideResNet(depth=28, num_classes=10, widen_factor=10)
    elif data_name == 'CIFAR10_resnet':
        net = resnet.ResNet(110, 10)
    else:
        raise ValueError('unknown CIFAR10 backbone: {}'.format(data_name))
    return net.to(device)


def _extract_penultimate(net, loader, device):
    """Return (features, labels) as NumPy arrays for every sample yielded by loader.

    Features are the outputs of net.penultimate. The accumulators are local to
    this call, so samples from one loader can never end up in another split.
    """
    feature_chunks = []
    label_chunks = []
    with torch.no_grad():
        for images, labels in loader:
            # Move each batch of features to the CPU right away so that device
            # memory does not grow with the size of the split.
            feature_chunks.append(net.penultimate(images.to(device)).cpu())
            label_chunks.append(labels.cpu())
    return torch.cat(feature_chunks).numpy(), torch.cat(label_chunks).numpy()


def _compute_cifar10_features(data_name, checkpoint_dir, num_epochs, seed, device):
    """Extract penultimate-layer features of the validation and test splits.

    Returns NumPy arrays (x_train, y_train, x_test, y_test), where the "train"
    arrays come from the 10% validation split of the CIFAR-10 training set and
    the "test" arrays from the CIFAR-10 test set.
    """
    np.random.seed(seed)
    torch.manual_seed(seed)
    data_path = './data/'

    normalize = transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2470, 0.2435, 0.2616])
    train_transform = transforms.Compose([transforms.RandomCrop(32, padding=4),
                                          transforms.RandomHorizontalFlip(),
                                          transforms.ToTensor(),
                                          normalize])
    test_transform = transforms.Compose([transforms.ToTensor(),
                                         normalize])

    dataset = datasets.CIFAR10(root=data_path, train=True, download=True, transform=train_transform)

    # The split uses its own generator with a fixed seed, so it does not depend
    # on the `seed` argument.
    # NOTE(review): val_set is a subset of `dataset` and therefore inherits
    # train_transform (random crop / flip); the features below are computed from
    # augmented images -- confirm this is intended.
    data_seed = 0
    _, val_set = torch.utils.data.random_split(
        dataset, [0.9, 0.1], generator=torch.Generator().manual_seed(data_seed))
    test_set = datasets.CIFAR10(root=data_path, train=False, download=True, transform=test_transform)

    np.random.seed(seed)
    batch_size = 128
    val_loader = torch.utils.data.DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=4)
    test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=4)

    # Build the network before iterating the loaders: this keeps the order of
    # the random-number-consuming steps stable for a given seed.
    net = _build_backbone(data_name, device)
    end_epoch = num_epochs - 1
    # map_location lets a checkpoint saved on a GPU be restored on a CPU-only host.
    net.load_state_dict(torch.load(checkpoint_dir + 'model' + str(end_epoch) + '.pt', map_location=device))
    net.eval()

    x_train, y_train = _extract_penultimate(net, val_loader, device)
    x_test, y_test = _extract_penultimate(net, test_loader, device)
    return x_train, y_train, x_test, y_test


def preprocess_data(data_name, base_path, model_path, seed = 1, cross_validate_index = 0, clean_test = False, clean_val = False):
    """Load (or compute and cache) penultimate-layer features of a trained CIFAR-10 model.

    For data_name in {'CIFAR10_densenet', 'CIFAR10_wr', 'CIFAR10_resnet'} the
    features are cached as a pickle at
    '<base_path><model_path>/<model dir>/seed<seed>/post_stonet_data.pt'.
    If that file is missing, the final checkpoint in the same directory is
    loaded, features are extracted for the validation split (returned as the
    "train" pair) and for the CIFAR-10 test set, and the cache is written.

    Cache files written before the validation/test accumulators were separated
    contain the validation rows prepended to x_test / y_test; delete such files
    so they are regenerated.

    Args:
        data_name: dataset/backbone selector (see above).
        base_path: path prefix; concatenated as-is, so it must end with a separator.
        model_path: sub-directory below base_path holding the trained models.
        seed: seeds NumPy and torch and selects the 'seed<seed>' directory.
        cross_validate_index: not used by the code visible in this function.
        clean_test: not used by the code visible in this function.
        clean_val: not used by the code visible in this function.

    Returns:
        (x_train, y_train, x_test, y_test) as tensors on the selected device
        (float features, long labels) for the three CIFAR-10 selectors. No
        visible branch handles any other data_name.
    """
    np.random.seed(seed)
    torch.manual_seed(seed)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    if data_name in _CIFAR10_BACKBONES:
        model_dir, num_epochs = _CIFAR10_BACKBONES[data_name]
        PATH = '{}{}/{}/seed{}/'.format(base_path, model_path, model_dir, seed)
        filename = '{}post_stonet_data.pt'.format(PATH)

        if not os.path.exists(filename):
            x_train, y_train, x_test, y_test = _compute_cifar10_features(
                data_name, PATH, num_epochs, seed, device)
            with open(filename, 'wb') as f:
                pickle.dump([x_train, y_train, x_test, y_test], f)
        else:
            # The cache is produced by this function itself; pickle must not be
            # pointed at files from an untrusted source.
            with open(filename, 'rb') as f:
                x_train, y_train, x_test, y_test = pickle.load(f)

        x_train = torch.FloatTensor(x_train).to(device)
        y_train = torch.LongTensor(y_train).to(device)

        x_test = torch.FloatTensor(x_test).to(device)
        y_test = torch.LongTensor(y_test).to(device)

        return x_train, y_train, x_test, y_test

        # NOTE(review): everything below follows an unconditional return and is
        # therefore unreachable. It generates a synthetic regression data set and
        # looks like a branch whose `if data_name == ...:` header was lost --
        # confirm, then either restore the header or delete the code.

        a = 1
        b = 1
        TotalP = 20
        print('p = ', TotalP)
        NTrain = 500
        x_train = np.matrix(np.zeros([NTrain, TotalP]))
        y_train = np.matrix(np.zeros([NTrain, 1]))

        sigma = 0.5
        for i in range(NTrain):
            if i % 1000 == 0:
                print("x_train generate = ", i)
            ee = np.sqrt(sigma) * np.random.normal(0, 1)
            while ee > 10 or ee < -10:
                ee = np.sqrt(sigma) * np.random.normal(0, 1)
            for j in range(TotalP):
                zj = np.sqrt(sigma) * np.random.normal(0, 1)
                while zj > 10 or zj < -10:
                    zj = np.sqrt(sigma) * np.random.normal(0, 1)
                x_train[i, j] = (a * ee + b * zj) / np.sqrt(a * a + b * b)
            x0 = x_train[i, 0]
            x1 = x_train[i, 1]
            x2 = x_train[i, 2]
            x3 = x_train[i, 3]
            x4 = x_train[i, 4]

            y_train[i, 0] = 2 * np.tanh(2 * x0 - x1) + 2 * np.tanh(x2 - 2 * x3) - np.tanh(2 * x4) + np.random.normal(0, 1)

        NTest = 1000
        x_test = np.matrix(np.zeros([NTest, TotalP]))
        y_test = np.matrix(np.zeros([NTest, 1]))

        for i in range(NTest):
            ee = np.sqrt(sigma) * np.random.normal(0, 1)
            while ee > 10 or ee < -10:
                ee = np.sqrt(sigma) * np.random.normal(0, 1)
            for j in range(TotalP):
                zj = np.sqrt(sigma) * np.random.normal(0, 1)
                while zj > 10 or zj < -10:
                    zj = np.sqrt(sigma) * np.random.normal(0, 1)
                x_test[i, j] = (a * ee + b * zj) / np.sqrt(a * a + b * b)
            x0 = x_test[i, 0]
            x1 = x_test[i, 1]
            x2 = x_test[i, 2]
            x3 = x_test[i, 3]
            x4 = x_test[i, 4]

            y_test[i, 0] = 2 * np.tanh(2 * x0 - x1) + 2 * np.tanh(x2 - 2 * x3) - np.tanh(2 * x4) + np.random.normal(0, 1)

        x_train = torch.FloatTensor(x_train).to(device)
        y_train = torch.FloatTensor(y_train).to(device)

        x_test = torch.FloatTensor(x_test).to(device)
        y_test = torch.FloatTensor(y_test).to(device)