import sys

sys.path.append("./")
import os

import torchvision
import itertools
from Data.utils import noisify_pairflip
import torchvision.models as trainedModels
from Models.resnet import resnet32
import random
from Data.cifar import CIFAR10, CIFAR100
from Data.mnist import MNIST
from sklearn.model_selection import GridSearchCV, train_test_split
from Models.Smooth import Smooth
import torch as torch
import argparse
from tqdm import tqdm
from copy import deepcopy
from sklearn.metrics import (
    precision_recall_fscore_support as prf,
    accuracy_score,
    roc_auc_score,
)
from prettytable import PrettyTable
import numpy as np
from torch.utils.data import DataLoader
import torch.utils.data as data
from torch.optim import AdamW
import torchvision.transforms as T
import torch.nn.functional as F
import matplotlib.pyplot as plt

# Share tensors between processes using the 'file_system' strategy.
# NOTE(review): presumably chosen so DataLoader workers do not exhaust file
# descriptors - confirm against the environment this script runs in.
torch.multiprocessing.set_sharing_strategy('file_system')
def mixup_data(x, y, alpha=1.0, use_cuda=True):
    """Mix a batch with a randomly permuted copy of itself (mixup).

    Args:
        x: Batch of inputs; dimension 0 is the batch dimension.
        y: Targets aligned with ``x`` along dimension 0.
        alpha: Parameter of the symmetric Beta(alpha, alpha) distribution the
            mixing coefficient is drawn from. Values <= 0 disable mixing
            (the coefficient is fixed to 1).
        use_cuda: Kept for backward compatibility only. The permutation index
            is always moved to the device of ``x``, so the function no longer
            crashes on CPU-only hosts when this flag is left at its default.

    Returns:
        Tuple ``(mixed_x, y_a, y_b, lam)`` where ``mixed_x`` is
        ``lam * x + (1 - lam) * x[perm]``, ``y_a`` is ``y``, ``y_b`` is
        ``y[perm]`` and ``lam`` is the mixing coefficient.
    """
    lam = np.random.beta(alpha, alpha) if alpha > 0. else 1.

    batch_size = x.size(0)
    # Draw the permutation on the CPU generator (as before) and then follow the
    # data's device instead of trusting the caller-supplied flag.
    index = torch.randperm(batch_size).to(x.device)

    mixed_x = lam * x + (1 - lam) * x[index]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam


def mixup_criterion(y_a, y_b, lam):
    """Return a closure that blends a criterion over both mixup targets.

    The returned callable takes ``(criterion, pred)`` and evaluates
    ``lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)``.
    """
    def blended(criterion, pred):
        first_term = lam * criterion(pred, y_a)
        second_term = (1 - lam) * criterion(pred, y_b)
        return first_term + second_term

    return blended


def get_mean_and_std(dataset, num_workers=1):
    """Compute per-channel statistics of an image dataset.

    Every sample is loaded on its own (batch size 1); its per-channel mean and
    (unbiased) standard deviation are computed and then averaged over the
    dataset. Note that the returned ``std`` is therefore the average of the
    per-image standard deviations, not the standard deviation of the whole
    dataset.

    Args:
        dataset: Dataset whose samples are tuples with the image tensor of
            layout ``(C, H, W)`` as first element. Extra elements (label,
            index, ...) are ignored.
        num_workers: Number of DataLoader worker processes (default 1, the
            previously hard-coded value). Use 0 to load in-process.

    Returns:
        Tuple ``(mean, std)`` of 1-D tensors with one entry per channel.

    Raises:
        ValueError: If the dataset yields no samples.
    """
    # Order is irrelevant for a sum, so no shuffling is needed.
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=1, shuffle=False, num_workers=num_workers)
    print('==> Computing mean and std..')

    mean = None
    std = None
    for sample in dataloader:
        inputs = sample[0]
        # (1, C, H, W) -> (C, H * W): one row of pixels per channel.
        per_channel = inputs[0].reshape(inputs.size(1), -1)
        if mean is None:
            # Size the accumulators from the data instead of assuming RGB.
            mean = torch.zeros(per_channel.size(0))
            std = torch.zeros(per_channel.size(0))
        mean += per_channel.mean(dim=1)
        std += per_channel.std(dim=1)

    if mean is None:
        raise ValueError('cannot compute mean and std of an empty dataset')

    mean.div_(len(dataset))
    std.div_(len(dataset))
    return mean, std


class Solver_PRL:
    def __init__(
        self,
        data_name="cifar10",
        seed=0,
        learning_rate=3e-4,
        corruption_rate=0.2,
        corruption_type='backdoor',
        batch_size=128,
        max_epochs=100,
        drop_decay_step=0,
        alpha=0.0,  # mix up parameter
        k=100,
        eps_neighbor=0.1
    ):
        """Seed the RNGs, build the noisy datasets/loaders and the base model.

        Args:
            data_name: One of ``'cifar10'``, ``'cifar100'`` or ``'mnist'``.
            seed: Seed applied to NumPy and PyTorch (CPU and CUDA).
            learning_rate: Learning rate handed to the optimizer in ``train``.
            corruption_rate: Forwarded to the dataset as ``noise_rate``; also
                drives the per-batch drop rate and label shuffling in ``train``.
            corruption_type: Forwarded to the dataset as ``noise_type``.
            batch_size: Training batch size.
            max_epochs: Number of epochs each student model is trained for.
            drop_decay_step: Number of initial epochs over which the drop
                rate is ramped up linearly in ``train``.
            alpha: Mixup Beta-distribution parameter (0 disables mixing).
            k: Number of student models trained in ``train``.
            eps_neighbor: Scale of the Gaussian noise added to training inputs.

        Raises:
            ValueError: If ``data_name`` is not a supported dataset.
        """
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        use_cuda = torch.cuda.is_available()
        self.device = torch.device("cuda" if use_cuda else "cpu")
        self.seed = seed
        self.max_epochs = max_epochs
        self.learning_rate = learning_rate
        self.batch_size = batch_size
        self.data_name = data_name
        self.corruption_rate = corruption_rate
        self.corruption_type = corruption_type
        self.drop_decay_step = drop_decay_step
        self.alpha = alpha
        self.epoch_decay_start = int(max_epochs * 0.4)
        self.eps_neighbor = eps_neighbor
        self.train_data = None
        self.k = k

        # Per-dataset configuration. Only the values differ between datasets;
        # the construction code below is shared.
        if self.data_name == 'cifar10':
            dataset_cls = CIFAR10
            self.input_channel = 3
            self.num_classes = 10
            # NOTE(review): these look like the ImageNet statistics, presumably
            # to match the pretrained ResNet-18 used for cifar10 - confirm.
            normalize = T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
            augment_train = True
            test_batch_size = 1
        elif self.data_name == 'cifar100':
            dataset_cls = CIFAR100
            self.input_channel = 3
            self.num_classes = 100
            normalize = T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
            augment_train = True
            test_batch_size = 10
        elif self.data_name == 'mnist':
            dataset_cls = MNIST
            self.input_channel = 1
            self.num_classes = 10
            normalize = T.Normalize((0.1307,), (0.3081,))
            # MNIST is trained without flip/crop augmentation.
            augment_train = False
            test_batch_size = 10
        else:
            raise ValueError(
                "unsupported data_name {!r}; expected 'cifar10', 'cifar100' or 'mnist'".format(
                    self.data_name)
            )

        to_normalized_tensor = [T.Resize(32), T.ToTensor(), normalize]
        augmentation = []
        if augment_train:
            augmentation = [T.RandomHorizontalFlip(), T.RandomCrop(32, padding=4)]
        train_transform = T.Compose(augmentation + to_normalized_tensor)
        test_transform = T.Compose(to_normalized_tensor)

        # train() rebuilds its loader from self.train_dataset, so the dataset
        # must be stored on the instance for every dataset, not only cifar10.
        self.train_dataset = dataset_cls(root='./data/',
                                         download=True,
                                         train=True,
                                         transform=train_transform,
                                         noise_type=self.corruption_type,
                                         noise_rate=self.corruption_rate)
        test_dataset = dataset_cls(root='./data/',
                                   download=True,
                                   train=False,
                                   transform=test_transform,
                                   noise_type=self.corruption_type,
                                   noise_rate=self.corruption_rate)

        self.training_loader = DataLoader(dataset=self.train_dataset,
                                          batch_size=batch_size,
                                          num_workers=4,
                                          drop_last=True,
                                          shuffle=True)
        self.testing_loader = DataLoader(dataset=test_dataset,
                                         batch_size=test_batch_size,
                                         num_workers=1,
                                         drop_last=False,
                                         shuffle=False)

        self.model = None
        self.student_model = None
        self.result = {}
        self.params_to_update = []
        self.build_model()
        self.print_network()

    def build_model(self):
        if self.data_name == 'cifar10':
            self.model = trainedModels.resnet18(pretrained=True)
            self.model.fc = torch.nn.Linear(512, 10)


        elif self.data_name == 'cifar100':
            self.model = resnet32(num_class=100)
        elif self.data_name == 'mnist':
            self.model = resnet32(num_class=10, input_channel=1)
        self.model = self.model.to(self.device)

    def print_network(self):
        num_params = 0
        for p in self.model.parameters():
            num_params += p.numel()
        print("The number of parameters: {}|number of models: 1".format(num_params))

    def train(self):
        """Train ``self.k`` student models and evaluate their majority vote.

        Each student starts as a deep copy of ``self.model`` and is trained
        for ``self.max_epochs`` epochs. For every batch:

        1. Gaussian noise scaled by ``self.eps_neighbor`` is added to the inputs.
        2. The labels of the trailing ``corruption_rate`` share of the batch
           are permuted among themselves.
        3. Samples are ranked by the squared error between the softmax output
           and the one-hot label; the highest-scoring share (given by the
           drop-rate schedule) is discarded.
        4. The kept samples go through mixup and a cross-entropy update.

        After every epoch the student is evaluated with ``evaluate_k``; the
        metrics and predictions of the last epoch are stored in
        ``self.result``. Finally ``evaluate`` reports the ensemble vote.
        """
        for key in ('clean_acc', 'poison_acc', 'success_rate',
                    'clean_prediction', 'poison_prediction'):
            self.result[key] = []

        # Share of each batch to discard: ramps up linearly during the first
        # `drop_decay_step` epochs, then stays at corruption_rate + 0.1.
        drop_rate = self.corruption_rate + 0.1
        rate_schedule = np.ones(self.max_epochs) * drop_rate
        rate_schedule[:self.drop_decay_step] = np.linspace(0, drop_rate, self.drop_decay_step)

        # drop_last=True below guarantees every batch has exactly batch_size rows.
        n_fixed = int(self.batch_size * (1 - self.corruption_rate))
        n_shuffled = self.batch_size - n_fixed
        use_cuda = self.device.type == 'cuda'

        self.training_loader = DataLoader(dataset=self.train_dataset,
                                          batch_size=self.batch_size,
                                          num_workers=4,
                                          drop_last=True,
                                          shuffle=True)
        lr_decay = 0.95

        for _ in tqdm(range(self.k)):
            print(rate_schedule)
            self.student_model = deepcopy(self.model)
            optimizer = AdamW(self.student_model.parameters(), lr=self.learning_rate)
            lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=lr_decay)

            metrics = None
            for epoch in tqdm(range(self.max_epochs)):
                self.student_model.train()
                keep_fraction = 1 - rate_schedule[epoch]
                for x, y, _ in self.training_loader:
                    x = x.to(self.device).float()
                    x = x + self.eps_neighbor * torch.randn_like(x)
                    y = y.to(self.device).long()

                    # Permute the labels of the trailing part of the batch.
                    # NOTE(review): presumably injects label randomness for the
                    # smoothing ensemble - confirm the intent.
                    perm = torch.randperm(n_shuffled).to(y.device)
                    y[n_fixed:] = y[n_fixed:][perm]

                    # Ranking only needs the scores, not their gradients.
                    with torch.no_grad():
                        probs = F.softmax(self.student_model(x), dim=1)
                        one_hot = F.one_hot(y, num_classes=self.num_classes).float()
                        filtering_score = F.mse_loss(probs, one_hot, reduction='none').sum(dim=1)
                        _, order = torch.sort(filtering_score)
                        keep = order[: int(x.shape[0] * keep_fraction)]

                    # Pass the real device state instead of a hard-coded True so
                    # the run also works on CPU-only hosts.
                    inputs_1, t1_1, t1_2, lam1 = mixup_data(x[keep], y[keep], self.alpha, use_cuda)

                    optimizer.zero_grad()
                    yhat_mixup = self.student_model(inputs_1)
                    loss = lam1 * F.cross_entropy(yhat_mixup, t1_1) + (1 - lam1) * F.cross_entropy(yhat_mixup, t1_2)
                    loss.backward()

                    # Clip the model that is actually being optimised; the base
                    # model carries no gradients, so clipping it was a no-op.
                    torch.nn.utils.clip_grad_norm_(self.student_model.parameters(), 5)
                    optimizer.step()

                # Step the scheduler once per qualifying epoch. Stepping it for
                # every batch of such an epoch multiplied the learning rate by
                # lr_decay ** num_batches, effectively driving it to zero.
                if epoch > 30 and (epoch + 1) % 20 == 0:
                    lr_scheduler.step()

                metrics = self.evaluate_k(epoch)

            if metrics is None:
                # max_epochs == 0: still record the (untrained) student.
                metrics = self.evaluate_k(self.max_epochs)
            acc_clean, acc_poison, success_rate, yclean, ypoison = metrics

            self.result['clean_acc'].append(acc_clean)
            self.result['poison_acc'].append(acc_poison)
            self.result['success_rate'].append(success_rate)
            self.result['clean_prediction'].append(yclean)
            self.result['poison_prediction'].append(ypoison)
        print("model finished training")
        self.evaluate()

    def evaluate(self):
        """Report the majority vote of all trained students on the test set.

        The per-student predictions stored in ``self.result`` by ``train`` are
        combined with an element-wise mode. The method prints a table with the
        accuracy of the vote on clean and on poisoned inputs (both measured
        against the clean labels) and the attack success rate (share of
        poisoned inputs classified as the attack label).
        """
        from scipy import stats

        labels = []
        attack_labels = []
        for _inputs, (y, y_poisoned), _idx in self.testing_loader:
            # Flatten each batch so batches of any (even unequal) size can be
            # concatenated into one 1-D label vector.
            labels.append(y.cpu().numpy().reshape(-1))
            attack_labels.append(y_poisoned.cpu().numpy().reshape(-1))
        y_all = np.concatenate(labels)
        y_all_attack = np.concatenate(attack_labels)

        # One row per student; the mode along axis 0 is the ensemble vote.
        # reshape(-1) covers SciPy versions that keep the reduced axis.
        clean_votes = np.stack(self.result['clean_prediction'])
        poison_votes = np.stack(self.result['poison_prediction'])
        yhat_all_clean = np.asarray(stats.mode(clean_votes, axis=0).mode).reshape(-1)
        yhat_all_poison = np.asarray(stats.mode(poison_votes, axis=0).mode).reshape(-1)

        acc_clean = accuracy_score(y_all, yhat_all_clean)
        acc_poison = accuracy_score(y_all, yhat_all_poison)
        attack_success_rate = (y_all_attack == yhat_all_poison).sum() / y_all_attack.shape[0]

        t = PrettyTable()
        t.title = 'Results'
        t.field_names = ['Experiment Info', 'Value']
        t.add_row(['algorithm', 'DoubleRobustPRL'])
        t.add_row(['data', self.data_name])
        t.add_row(['seed', self.seed])
        t.add_row(['learning_rate', self.learning_rate])
        t.add_row(['batch_size', self.batch_size])
        t.add_row(['mixup-alpha', self.alpha])
        t.add_row(['eps_neighbor', self.eps_neighbor])
        t.add_row(['acc clean', acc_clean])
        t.add_row(['acc poison', acc_poison])
        t.add_row(['attack success rate', attack_success_rate])
        t.add_row(['corruption rate', self.corruption_rate])
        t.add_row(['corruption type', self.corruption_type])
        print(t)
        print('end')
    # def evaluate_k(self):
    #     self.model.eval()
    #     smooth_model = Smooth(base_classifier=self.model, num_classes=10, sigma=self.eps_neighbor)
    #     with torch.no_grad():
    #         yhat_all_clean = []
    #         y_all = []
    #         y_all_poison = []
    #         y_all_attack = []
    #         for _, ((x, x_poisoned), (y, y_poisoned), idx) in enumerate(self.testing_loader):
    #             x = x.to(self.device).float()
    #             x_poisoned = x_poisoned.to(self.device).float()
    #             y = y.to(self.device).long()
    #             outputs1 = F.softmax(self.model(x), dim=1)
    #             _, yhat_clean = torch.max(outputs1.data, 1)
    #             y_poisoned = y_poisoned.to(self.device).long()
    #             yhat_poison = smooth_model.predict(x_poisoned.squeeze(), 1000, 0.05, 400)
    #             # print(yhat_clean)
    #             # print(outputs1)
    #             # outputs1 = F.softmax(self.model(x_poisoned), dim=1)
    #             # _, yhat_poison = torch.max(outputs1.data, 1)
    #
    #             yhat_all_clean.append(yhat_clean.squeeze().data.cpu().numpy())
    #             y_all_poison.append(np.array([yhat_poison]))
    #             # y_all_poison.append(yhat_poison.squeeze().data.cpu().numpy())
    #             y_all.append(y.data.cpu().numpy())
    #             y_all_attack.append(y_poisoned.data.cpu().numpy())
    #         # print(yhat_all_clean)
    #         # print(y_all_poison)
    #         yhat_all_clean = np.stack(np.array(yhat_all_clean))
    #         yhat_all_poison = np.concatenate(np.array(y_all_poison))
    #         y_all = np.concatenate(np.array(y_all))
    #         y_all_attack = np.concatenate(np.array(y_all_attack))
    #
    #         acc_clean = accuracy_score(y_all, yhat_all_clean)
    #         acc_poison = accuracy_score(y_all, yhat_all_poison)
    #         attack_succesfully_rate = (y_all_attack.squeeze() == yhat_all_poison.squeeze()).sum()/y_all_attack.squeeze().shape[0]



    def evaluate_k(self, epoch):
        """Evaluate the current student model on the test loader.

        Args:
            epoch: Epoch index, only used in the printed table.

        Returns:
            Tuple ``(acc_clean, acc_poison, attack_success_rate,
            yhat_all_clean, yhat_all_poison)``. Both accuracies are measured
            against the clean labels; the success rate is the share of
            poisoned inputs classified as the attack label. The two
            prediction arrays are 1-D with one entry per test sample.
        """
        self.student_model.eval()
        clean_batches = []
        poison_batches = []
        label_batches = []
        attack_batches = []
        with torch.no_grad():
            for (x, x_poisoned), (y, y_poisoned), _idx in self.testing_loader:
                x = x.to(self.device).float()
                x_poisoned = x_poisoned.to(self.device).float()

                # Softmax is monotonic, so the arg-max of the logits is the
                # predicted class.
                yhat_poison = self.student_model(x_poisoned).argmax(dim=1)
                yhat_clean = self.student_model(x).argmax(dim=1)

                # Keep every batch 1-D so any test batch size yields flat
                # prediction vectors (squeeze + stack only did for size 1).
                clean_batches.append(yhat_clean.cpu().numpy().reshape(-1))
                poison_batches.append(yhat_poison.cpu().numpy().reshape(-1))
                label_batches.append(y.cpu().numpy().reshape(-1))
                attack_batches.append(y_poisoned.cpu().numpy().reshape(-1))

        yhat_all_clean = np.concatenate(clean_batches)
        yhat_all_poison = np.concatenate(poison_batches)
        y_all = np.concatenate(label_batches)
        y_all_attack = np.concatenate(attack_batches)

        acc_clean = accuracy_score(y_all, yhat_all_clean)
        acc_poison = accuracy_score(y_all, yhat_all_poison)
        attack_success_rate = (y_all_attack == yhat_all_poison).sum() / y_all_attack.shape[0]

        t = PrettyTable()
        t.title = 'Results_for_k'
        t.field_names = ['Experiment Info', 'Value']
        t.add_row(['algorithm', 'DoubleRobustPRL'])
        t.add_row(['data', self.data_name])
        t.add_row(['seed', self.seed])
        t.add_row(['learning_rate', self.learning_rate])
        t.add_row(['batch_size', self.batch_size])
        t.add_row(['epoch', epoch])
        t.add_row(['mixup-alpha', self.alpha])
        t.add_row(['eps_neighbor', self.eps_neighbor])
        t.add_row(['acc clean', acc_clean])
        t.add_row(['acc poison', acc_poison])
        t.add_row(['attack success rate', attack_success_rate])
        t.add_row(['corruption rate', self.corruption_rate])
        t.add_row(['corruption type', self.corruption_type])
        print(t)
        return acc_clean, acc_poison, attack_success_rate, yhat_all_clean, yhat_all_poison

    # def evaluate(self, epoch):
    #     self.model.eval()
    #     smooth_model = Smooth(base_classifier=self.model, num_classes=10, sigma=self.eps_neighbor)
    #     with torch.no_grad():
    #         yhat_all_clean = []
    #         y_all = []
    #         y_all_poison = []
    #         y_all_attack = []
    #         for _, ((x, x_poisoned), (y, y_poisoned), idx) in enumerate(self.testing_loader):
    #             x = x.to(self.device).float()
    #             x_poisoned = x_poisoned.to(self.device).float()
    #             y = y.to(self.device).long()
    #             outputs1 = F.softmax(self.model(x), dim=1)
    #             _, yhat_clean = torch.max(outputs1.data, 1)
    #             y_poisoned = y_poisoned.to(self.device).long()
    #             yhat_poison = smooth_model.predict(x_poisoned.squeeze(), 1000, 0.05, 400)
    #             # print(yhat_clean)
    #             # print(outputs1)
    #             # outputs1 = F.softmax(self.model(x_poisoned), dim=1)
    #             # _, yhat_poison = torch.max(outputs1.data, 1)
    #
    #             yhat_all_clean.append(yhat_clean.squeeze().data.cpu().numpy())
    #             y_all_poison.append(np.array([yhat_poison]))
    #             # y_all_poison.append(yhat_poison.squeeze().data.cpu().numpy())
    #             y_all.append(y.data.cpu().numpy())
    #             y_all_attack.append(y_poisoned.data.cpu().numpy())
    #         # print(yhat_all_clean)
    #         # print(y_all_poison)
    #         yhat_all_clean = np.stack(np.array(yhat_all_clean))
    #         yhat_all_poison = np.concatenate(np.array(y_all_poison))
    #         y_all = np.concatenate(np.array(y_all))
    #         y_all_attack = np.concatenate(np.array(y_all_attack))
    #
    #         acc_clean = accuracy_score(y_all, yhat_all_clean)
    #         acc_poison = accuracy_score(y_all, yhat_all_poison)
    #         attack_succesfully_rate = (y_all_attack.squeeze() == yhat_all_poison.squeeze()).sum()/y_all_attack.squeeze().shape[0]
    #
    #     t = PrettyTable()
    #     t.title = 'Results'
    #     t.field_names = ['Experiment Info', 'Value']
    #     t.add_row(['algorithm', 'PRL_RandomizedSmoothing'])
    #     t.add_row(['data', self.data_name])
    #     t.add_row(['seed', self.seed])
    #     t.add_row(['learning_rate', self.learning_rate])
    #     t.add_row(['batch_size', self.batch_size])
    #     t.add_row(['epoch', epoch])
    #     t.add_row(['mixup-alpha', self.alpha])
    #     t.add_row(['acc clean', acc_clean])
    #     t.add_row(['acc poison', acc_poison])
    #     t.add_row(['attack success rate', attack_succesfully_rate])
    #     t.add_row(['corruption rate', self.corruption_rate])
    #     t.add_row(['corruption type', self.corruption_type])
    #     print(t)
    #     return acc_clean, acc_poison, attack_succesfully_rate, yhat_all_clean, yhat_all_poison

    def save_result(self):
        path = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
        # path = ""
        result_path = "{}/DefenseBackDoorAttack/Result/Result/{}/{}/PRLTransferRandomizedSmoothing/{}/{}".format(
            path, self.data_name, self.corruption_rate, self.corruption_type, self.seed
        )
        os.makedirs(result_path, exist_ok=True)

        torch.save(self.model.state_dict(), result_path+"network.pt")

        np.save(
            result_path + "result.npy",
            {"acc_clean": self.result['clean_acc'],
             "acc_poison": self.result['poison_acc'],
             "attack_success_rate": self.result['success_rate']
             },
        )
        print("result save to {}".format(result_path + "result.npy"))


if __name__ == "__main__":
    # Command-line entry point: parse the hyper-parameters, build the solver
    # and run training (which also prints the final ensemble evaluation).
    parser = argparse.ArgumentParser(description="AnomalyDetection")
    parser.add_argument("--seed", type=int, default=5)
    parser.add_argument("--data_name", type=str, default="cifar10")
    parser.add_argument("--max_epochs", type=int, default=200)
    parser.add_argument("--batch_size", type=int, default=128)
    parser.add_argument("--learning_rate", type=float, default=3e-4)
    parser.add_argument("--corruption_rate", type=float, default=0.1)
    parser.add_argument("--corruption_type", type=str, default='blend')
    parser.add_argument("--eps_neighbor", type=float, default=0.05)
    parser.add_argument("--alpha", type=float, default=0.0)
    parser.add_argument("--k", type=int, default=100)
    parser.add_argument("--drop_decay_step", type=int, default=0)
    config = parser.parse_args()

    torch.backends.cudnn.benchmark = True
    solver = Solver_PRL(
        data_name=config.data_name,
        seed=config.seed,
        learning_rate=config.learning_rate,
        corruption_rate=config.corruption_rate,
        batch_size=config.batch_size,
        max_epochs=config.max_epochs,
        corruption_type=config.corruption_type,
        eps_neighbor=config.eps_neighbor,
        drop_decay_step=config.drop_decay_step,
        alpha=config.alpha,
        # --k was parsed but never forwarded, so the ensemble size could not
        # be changed from the command line.
        k=config.k,
    )
    # train() finishes by calling evaluate(), so no separate call is needed.
    solver.train()
    # solver.save_result()
