import sys

sys.path.append("./")
import os

import torchvision
from Models.resnet import resnet32
from Data.cifar import CIFAR10, CIFAR100
from Data.mnist import MNIST
from sklearn.model_selection import GridSearchCV, train_test_split

import torch as torch
import argparse
from tqdm import tqdm
from copy import deepcopy
from sklearn.metrics import (
    precision_recall_fscore_support as prf,
    accuracy_score,
    roc_auc_score,
)
from torch.distributions import Categorical
from prettytable import PrettyTable
import numpy as np
from torch.utils.data import DataLoader
import torch.utils.data as data
from torch.optim import AdamW
import torchvision.transforms as T
import torch.nn.functional as F
import matplotlib.pyplot as plt


def mixup_data(x, y, alpha=1.0, use_cuda=True):
    '''Compute the mixup data. Return mixed inputs, pairs of targets, and lambda.

    ``lam`` is drawn from Beta(alpha, alpha) when ``alpha > 0`` and is 1 otherwise
    (i.e. no mixing). Every sample of ``x`` is blended with the sample at a random
    permutation of the batch; ``y_b`` holds the labels of those partners.

    Args:
        x: batch tensor whose first dimension is the batch dimension (any rank >= 1).
        y: labels indexed along the same first dimension as ``x``.
        alpha: Beta distribution parameter.
        use_cuda: kept for backward compatibility. The permutation now always follows
            the devices of ``x`` and ``y``, so this flag no longer has to match them.

    Returns:
        (mixed_x, y_a, y_b, lam) with ``mixed_x = lam * x + (1 - lam) * x[perm]``,
        ``y_a = y`` and ``y_b = y[perm]``.
    '''
    if alpha > 0.:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1.
    batch_size = x.size(0)
    # Draw the permutation from the CPU generator (as before), then move it next to the
    # tensors it indexes. Forcing it onto CUDA broke CPU runs and CPU-resident inputs.
    index = torch.randperm(batch_size)

    # Index only the first dimension so 1-D inputs work as well.
    mixed_x = lam * x + (1 - lam) * x[index.to(x.device)]
    y_a, y_b = y, y[index.to(y.device)]
    return mixed_x, y_a, y_b, lam


def mixup_criterion(y_a, y_b, lam):
    '''Build a loss function for a mixup batch.

    The returned callable takes ``(criterion, pred)`` and evaluates
    ``lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)``,
    calling ``criterion`` on ``y_a`` first and on ``y_b`` second.
    '''
    def _blended_loss(criterion, pred):
        loss_first = criterion(pred, y_a)
        loss_second = criterion(pred, y_b)
        return lam * loss_first + (1 - lam) * loss_second

    return _blended_loss


def get_mean_and_std(dataset, num_workers=2):
    '''Compute the per-channel mean and std value of dataset.

    Samples are loaded one at a time. For each sample, the mean and the unbiased standard
    deviation of every channel over all its remaining positions are computed, and both are
    averaged over the dataset. The averaged per-image std is not the std of all pooled
    pixels of the dataset.

    Args:
        dataset: map-style dataset whose items are either an image tensor of shape
            (C, ...) or a tuple/list whose first element is such a tensor (e.g.
            ``(image, target)`` or ``(image, target, index)``).
        num_workers: number of DataLoader worker processes (0 loads in-process).

    Returns:
        (mean, std): float tensors of shape (C,).

    Raises:
        ValueError: if ``dataset`` is empty (the averages would be 0/0).
    '''
    if len(dataset) == 0:
        raise ValueError("get_mean_and_std() needs a non-empty dataset")
    # The order of samples does not affect the averages, so there is no need to shuffle
    # (which would also consume the global torch RNG).
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=num_workers)
    mean = None
    std = None
    print('==> Computing mean and std..')
    for batch in dataloader:
        # Accept (image, target), (image, target, index), ... as well as bare images.
        inputs = batch[0] if isinstance(batch, (list, tuple)) else batch
        num_channels = inputs.size(1)
        # (1, C, ...) -> (C, N) so every channel is reduced in a single call.
        per_channel = inputs.transpose(0, 1).reshape(num_channels, -1)
        if mean is None:
            mean = torch.zeros(num_channels)
            std = torch.zeros(num_channels)
        mean += per_channel.mean(dim=1)
        std += per_channel.std(dim=1)
    mean.div_(len(dataset))
    std.div_(len(dataset))
    return mean, std


class Solver_PRL:
    """Train a ResNet-32 on a (possibly corrupted) CIFAR-10, CIFAR-100 or MNIST training set.

    Each training batch is filtered by a prediction/label squared-error score, shifted by
    one signed-gradient input step, filtered again and trained on with mixup (see
    ``train``). Per-epoch metrics from ``evaluate`` are collected in ``self.result`` and
    written to disk by ``save_result``.
    """

    def __init__(
        self,
        data_name="cifar10",
        seed=0,
        learning_rate=3e-4,
        corruption_rate=0.2,
        corruption_type='backdoor',
        batch_size=128,
        max_epochs=100,
        drop_decay_step=0,
        alpha=0.0,  # mix up parameter
        eps_neighbor=0.1,
    ):
        """Seed the RNGs, then build the data loaders and the model.

        Args:
            data_name: 'cifar10', 'cifar100' or 'mnist'.
            seed: seed for numpy and torch (CPU and CUDA).
            learning_rate: initial AdamW learning rate.
            corruption_rate: forwarded to the datasets as ``noise_rate``; also sets the
                fraction of every training batch that is discarded (see ``train``).
            corruption_type: forwarded to the datasets as ``noise_type``.
            batch_size: training batch size (the test loader always uses 10).
            max_epochs: number of training epochs.
            drop_decay_step: number of initial epochs over which the discarded fraction
                ramps linearly from 0 to ``corruption_rate``.
            alpha: mixup Beta(alpha, alpha) parameter; ``alpha <= 0`` disables mixing.
            eps_neighbor: step size of the signed-gradient input perturbation.

        Raises:
            ValueError: if ``data_name`` is not one of the supported datasets.
        """
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        use_cuda = torch.cuda.is_available()
        self.device = torch.device("cuda" if use_cuda else "cpu")
        self.seed = seed
        self.max_epochs = max_epochs
        self.learning_rate = learning_rate
        self.batch_size = batch_size
        self.data_name = data_name
        self.corruption_rate = corruption_rate
        self.corruption_type = corruption_type
        self.drop_decay_step = drop_decay_step
        self.alpha = alpha
        self.epoch_decay_start = int(max_epochs * 0.4)  # not read anywhere in this class
        self.eps_neighbor = eps_neighbor

        self._build_loaders()

        self.model = None
        self.result = {}
        self.build_model()
        self.print_network()

    def _build_loaders(self):
        """Create ``self.training_loader`` / ``self.testing_loader`` for ``self.data_name``.

        Also sets ``self.input_channel`` and ``self.num_classes``.
        """
        if self.data_name == 'cifar10':
            dataset_cls, self.input_channel, self.num_classes = CIFAR10, 3, 10
            norm_mean, norm_std, augment = (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), True
        elif self.data_name == 'cifar100':
            dataset_cls, self.input_channel, self.num_classes = CIFAR100, 3, 100
            norm_mean, norm_std, augment = (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), True
        elif self.data_name == 'mnist':
            # MNIST is trained without flip/crop augmentation.
            dataset_cls, self.input_channel, self.num_classes = MNIST, 1, 10
            norm_mean, norm_std, augment = (0.1307,), (0.3081,), False
        else:
            raise ValueError("Unsupported data_name: {!r}".format(self.data_name))

        transform_normalize = T.Compose([
            T.Resize(32),
            T.ToTensor(),
            T.Normalize(norm_mean, norm_std),
        ])
        if augment:
            transform_augment = T.Compose([
                T.RandomHorizontalFlip(),
                T.RandomCrop(32, padding=4)])
            train_transform = T.Compose([transform_augment, transform_normalize])
        else:
            train_transform = T.Compose([transform_normalize])

        dataset_kwargs = dict(root='./data/',
                              download=True,
                              noise_type=self.corruption_type,
                              noise_rate=self.corruption_rate)
        train_dataset = dataset_cls(train=True, transform=train_transform, **dataset_kwargs)
        test_dataset = dataset_cls(train=False, transform=T.Compose([transform_normalize]), **dataset_kwargs)

        self.training_loader = DataLoader(dataset=train_dataset,
                                          batch_size=self.batch_size,
                                          num_workers=4,
                                          drop_last=True,
                                          shuffle=True)
        self.testing_loader = DataLoader(dataset=test_dataset,
                                         batch_size=10,
                                         num_workers=1,
                                         drop_last=False,
                                         shuffle=False)

    def build_model(self):
        """Instantiate a ResNet-32 sized for ``self.data_name`` and move it to ``self.device``.

        Raises:
            ValueError: if ``data_name`` is not one of the supported datasets.
        """
        if self.data_name == 'cifar10':
            model = resnet32(num_class=10)
        elif self.data_name == 'cifar100':
            model = resnet32(num_class=100)
        elif self.data_name == 'mnist':
            model = resnet32(num_class=10, input_channel=1)
        else:
            raise ValueError("Unsupported data_name: {!r}".format(self.data_name))
        self.model = model.to(self.device)

    def print_network(self):
        """Print the total number of parameters of ``self.model``."""
        num_params = sum(p.numel() for p in self.model.parameters())
        print("The number of parameters: {}|number of models: 1".format(num_params))

    def _select_low_score(self, logits, y, drop_rate):
        """Return indices of the samples to keep for this batch.

        Each sample is scored by the summed squared error between ``softmax(logits)`` and
        ``one_hot(y)``. The ``1 - drop_rate`` fraction with the lowest score is kept, with
        at least one sample so that a following mean is never taken over an empty selection
        (which would give a NaN loss).
        """
        with torch.no_grad():
            probs = F.softmax(logits, dim=1)
            target = F.one_hot(y, num_classes=self.num_classes).to(probs.dtype)
            score = F.mse_loss(probs, target, reduction='none').sum(dim=1)
            _, order = torch.sort(score)
            num_keep = max(1, int(logits.shape[0] * (1 - drop_rate)))
        return order[:num_keep]

    def train(self):
        """Train ``self.model`` for ``max_epochs`` epochs, evaluating after every epoch.

        Per batch:
          1. Keep the ``1 - rate_schedule[epoch]`` fraction of samples with the lowest
             softmax/one-hot squared error.
          2. Differentiate the kept samples' mean cross-entropy w.r.t. a single input
             perturbation shared by the whole batch. Then move the inputs by
             ``eps_neighbor * sign(gradient)``, clamped to [-1, 1].
          3. Re-score the perturbed batch, keep the same fraction, mix the kept samples
             with ``mixup_data`` and take an AdamW step on the mixup cross-entropy, with
             the gradient norm clipped to 5.

        The results of ``evaluate`` are appended to ``self.result['clean_acc']``,
        ``self.result['poison_acc']`` and ``self.result['success_rate']``.
        """
        # Fraction of each batch to discard: a linear ramp 0 -> corruption_rate over the first
        # drop_decay_step epochs, then corruption_rate + 0.1.
        # NOTE(review): the ramp ends below the steady-state value; confirm the +0.1 jump is intended.
        rate_schedule = np.ones(self.max_epochs) * (self.corruption_rate + 0.1)
        rate_schedule[:self.drop_decay_step] = np.linspace(0, self.corruption_rate, self.drop_decay_step)
        print(rate_schedule)

        optimizer = AdamW(self.model.parameters(), lr=self.learning_rate)
        lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=0.95)
        use_cuda = self.device.type == "cuda"
        self.result['clean_acc'] = []
        self.result['poison_acc'] = []
        self.result['success_rate'] = []

        for epoch in tqdm(range(self.max_epochs)):
            self.model.train()
            drop_rate = rate_schedule[epoch]

            for x, y, _ in self.training_loader:
                x = x.to(self.device).float()
                y = y.to(self.device).long()

                # Steps 1-2: one signed-gradient step on a perturbation shared by the batch.
                perturbation = torch.zeros_like(x[0], requires_grad=True)
                yhat = self.model(x + perturbation)
                keep = self._select_low_score(yhat, y, drop_rate)
                loss = F.cross_entropy(yhat, y, reduction="none")
                if loss.dim() > 1:
                    loss = loss.sum(dim=1)
                # Only the perturbation's gradient is needed; parameter .grad stays untouched.
                (grad,) = torch.autograd.grad(loss[keep].mean(), perturbation)
                # sign() ignores positive scaling, so no normalisation is needed. Dividing by the
                # gradient norm used to turn a zero gradient into NaN inputs.
                x_adv = (x + grad.sign() * self.eps_neighbor).clamp(-1, 1).detach()  # l-infinity step

                # Step 3: re-filter on the perturbed batch (the scores need no autograd graph)...
                with torch.no_grad():
                    yhat_adv = self.model(x_adv)
                keep = self._select_low_score(yhat_adv, y, drop_rate)

                # ...then train on a mixup of the kept perturbed samples.
                optimizer.zero_grad()
                inputs_mix, y_a, y_b, lam = mixup_data(x_adv[keep], y[keep], self.alpha, use_cuda)
                yhat_mix = self.model(inputs_mix)
                loss = lam * F.cross_entropy(yhat_mix, y_a) + (1 - lam) * F.cross_entropy(yhat_mix, y_b)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), 5)
                optimizer.step()

            # Decay the learning rate once per qualifying epoch (epoch indices 39, 59, 79, ...).
            # Stepping per batch would shrink it by 0.95 ** len(training_loader) at once.
            if epoch > 30 and (epoch + 1) % 20 == 0:
                lr_scheduler.step()

            acc_clean, acc_poison, success_rate = self.evaluate(epoch)
            self.result['clean_acc'].append(acc_clean)
            self.result['poison_acc'].append(acc_poison)
            self.result['success_rate'].append(success_rate)
        print("model finished training")

    def evaluate(self, epoch):
        """Evaluate ``self.model`` on the test loader and print a summary table.

        The test loader yields ``((x, x_poisoned), (y, y_poisoned), idx)``.

        Returns:
            (acc_clean, acc_poison, attack_success_rate):
              * accuracy on ``x`` w.r.t. ``y``;
              * accuracy on ``x_poisoned`` w.r.t. the clean labels ``y``;
              * fraction of ``x_poisoned`` predicted as ``y_poisoned``.
            NOTE(review): the success rate also counts test samples whose clean label
            already equals ``y_poisoned``; confirm whether those should be excluded.
        """
        self.model.eval()
        clean_preds, poison_preds, clean_labels, attack_labels = [], [], [], []
        with torch.no_grad():
            for (x, x_poisoned), (y, y_poisoned), _ in self.testing_loader:
                x = x.to(self.device).float()
                x_poisoned = x_poisoned.to(self.device).float()
                # argmax of the logits equals argmax of their softmax.
                poison_preds.append(self.model(x_poisoned).argmax(dim=1).cpu().numpy().reshape(-1))
                clean_preds.append(self.model(x).argmax(dim=1).cpu().numpy().reshape(-1))
                clean_labels.append(y.long().cpu().numpy().reshape(-1))
                attack_labels.append(y_poisoned.long().cpu().numpy().reshape(-1))

        # Concatenate the per-batch arrays directly. Wrapping the list in np.array first fails
        # when the last batch is smaller.
        yhat_all_clean = np.concatenate(clean_preds)
        yhat_all_poison = np.concatenate(poison_preds)
        y_all = np.concatenate(clean_labels)
        y_all_attack = np.concatenate(attack_labels)

        acc_clean = accuracy_score(y_all, yhat_all_clean)
        acc_poison = accuracy_score(y_all, yhat_all_poison)
        attack_success_rate = np.mean(y_all_attack == yhat_all_poison)

        t = PrettyTable()
        t.title = 'Results'
        t.field_names = ['Experiment Info', 'Value']
        t.add_row(['algorithm', 'DoubleRobustPRL'])
        t.add_row(['data', self.data_name])
        t.add_row(['seed', self.seed])
        t.add_row(['learning_rate', self.learning_rate])
        t.add_row(['batch_size', self.batch_size])
        t.add_row(['epoch', epoch])
        t.add_row(['mixup-alpha', self.alpha])
        t.add_row(['eps_neighbor', self.eps_neighbor])
        t.add_row(['acc clean', acc_clean])
        t.add_row(['acc poison', acc_poison])
        t.add_row(['attack success rate', attack_success_rate])
        t.add_row(['corruption rate', self.corruption_rate])
        t.add_row(['corruption type', self.corruption_type])
        print(t)
        return acc_clean, acc_poison, attack_success_rate

    def save_result(self):
        """Save the per-epoch metrics in ``self.result`` as ``result.npy``.

        The file is written inside ``<parent of cwd>/DefenseBackDoorAttack/Result/Result/
        <data_name>/<corruption_rate>/DoubleRobustPRL/<corruption_type>/<seed>/``. The dict
        is stored as a 0-d object array, so read it back with
        ``np.load(path, allow_pickle=True).item()``.
        """
        path = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
        result_path = "{}/DefenseBackDoorAttack/Result/Result/{}/{}/DoubleRobustPRL/{}/{}".format(
            path, self.data_name, self.corruption_rate, self.corruption_type, self.seed
        )
        os.makedirs(result_path, exist_ok=True)
        # Write into the directory just created. Plain concatenation produced
        # "<seed>result.npy" in its parent instead.
        result_file = os.path.join(result_path, "result.npy")

        np.save(
            result_file,
            {"acc_clean": self.result['clean_acc'],
             "acc_poison": self.result['poison_acc'],
             "attack_success_rate": self.result['success_rate']
             },
        )
        print("result save to {}".format(result_file))


if __name__ == "__main__":
    # Command-line entry point: parse hyper-parameters, train, then save per-epoch metrics.
    parser = argparse.ArgumentParser(description="AnomalyDetection")
    parser.add_argument("--seed", type=int, default=6, required=False,
                        help="seed for numpy and torch")
    parser.add_argument("--data_name", type=str, default="cifar10", required=False,
                        choices=["cifar10", "cifar100", "mnist"],
                        help="dataset to train on")
    parser.add_argument("--max_epochs", type=int, default=200, required=False,
                        help="number of training epochs")
    parser.add_argument("--batch_size", type=int, default=128, required=False,
                        help="training batch size")
    parser.add_argument("--learning_rate", type=float, default=3e-4, required=False,
                        help="initial AdamW learning rate")
    parser.add_argument("--corruption_rate", type=float, default=0.4, required=False,
                        help="noise_rate passed to the datasets; after warm-up the trainer "
                             "discards this fraction + 0.1 of every batch")
    parser.add_argument("--corruption_type", type=str, default='blend', required=False,
                        help="noise_type passed to the datasets")
    parser.add_argument("--eps_neighbor", type=float, default=0.01, required=False,
                        help="step size of the signed-gradient input perturbation")
    parser.add_argument("--alpha", type=float, default=0.0, required=False,
                        help="mixup Beta(alpha, alpha) parameter; 0 disables mixing")
    parser.add_argument("--drop_decay_step", type=int, default=0, required=False,
                        help="epochs over which the discarded fraction ramps up from 0")
    config = parser.parse_args()

    # cuDNN autotuning picks the fastest convolution algorithms but can make runs
    # non-reproducible even though Solver_PRL seeds numpy/torch.
    torch.backends.cudnn.benchmark = True
    solver = Solver_PRL(
        data_name=config.data_name,
        seed=config.seed,
        learning_rate=config.learning_rate,
        corruption_rate=config.corruption_rate,
        batch_size=config.batch_size,
        max_epochs=config.max_epochs,
        corruption_type=config.corruption_type,
        eps_neighbor=config.eps_neighbor,
        drop_decay_step=config.drop_decay_step,
        alpha=config.alpha
    )
    solver.train()
    solver.save_result()
