import sys

sys.path.append("./")
import hydra
from omegaconf import DictConfig
import logging

import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import torch
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision.models import resnet18, resnet34
from torchvision import transforms

from models import SimCLR
from Data.cifar import CIFAR10PAIR as CIFAR10Pair
from tqdm import tqdm


# Module-level logger named after this module; the SimCLR `train` entry point
# below uses it to report the backbone, feature dimensions and checkpoints.
logger = logging.getLogger(__name__)




class AverageMeter(object):
    """Running-statistics tracker for a named scalar metric.

    After every ``update`` the instance exposes:
        val:   the value passed to the most recent ``update`` call.
        sum:   the weighted total of all values seen (``val * n`` accumulated).
        count: the total weight seen (every ``n`` accumulated).
        avg:   ``sum / count``.
    """

    _STAT_FIELDS = ("val", "avg", "sum", "count")

    def __init__(self, name):
        self.name = name
        self.reset()

    def reset(self):
        """Set every statistic back to zero so the meter can be reused."""
        for field in self._STAT_FIELDS:
            setattr(self, field, 0)

    def update(self, val, n=1):
        """Record ``val`` as the newest observation, counted ``n`` times."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count



def nt_xent(x, t=0.5):
    """Normalized temperature-scaled cross-entropy (NT-Xent) loss.

    Rows of ``x`` are interleaved positive pairs: row ``2k`` and row
    ``2k + 1`` are each other's target, and every other row of the batch acts
    as a negative.

    Args:
        x: 2-D tensor with one embedding per row.
        t: softmax temperature the similarities are divided by.

    Returns:
        Scalar tensor: mean cross-entropy over all rows.
    """
    embeddings = F.normalize(x, dim=1)
    # Pairwise cosine similarities, floored at 1e-7 (negative similarities are
    # clipped to that floor), then scaled by the temperature.
    logits = (embeddings @ embeddings.t()).clamp(min=1e-7) / t

    # (2N-1)-way softmax: a sample must not be its own candidate, so push the
    # diagonal to a large negative value that turns into ~0 after the softmax.
    n_rows = logits.size(0)
    logits = logits - torch.eye(n_rows, device=logits.device) * 1e5

    # Flipping the lowest bit maps 2k -> 2k + 1 and 2k + 1 -> 2k, i.e. each
    # row's target is the index of its pair partner.
    partner = torch.arange(n_rows) ^ 1
    return F.cross_entropy(logits, partner.long().to(logits.device))


def get_lr(step, total_steps, lr_max, lr_min):
    """Cosine-annealed value between ``lr_max`` and ``lr_min``.

    Returns ``lr_max`` at ``step == 0`` and ``lr_min`` at
    ``step == total_steps``, following half a cosine period in between.
    """
    half_span = (lr_max - lr_min) * 0.5
    phase = step / total_steps * np.pi
    return lr_min + half_span * (1 + np.cos(phase))


class AddGaussianNoise(object):
    """Transform that adds i.i.d. Gaussian noise to a tensor.

    Calling the instance returns ``tensor + noise * std + mean`` where
    ``noise`` is standard-normal and has the same shape as ``tensor``.

    Args:
        mean: constant offset added to every element.
        std: scale applied to the standard-normal noise.
    """

    def __init__(self, mean=0., std=1.):
        self.std = std
        self.mean = mean

    def __call__(self, tensor):
        # Sample on the input's own device: noise created on the default (CPU)
        # device cannot be added to a tensor that lives elsewhere. The default
        # dtype is kept so integer inputs still promote to float as before.
        noise = torch.randn(tensor.size(), device=tensor.device)
        return tensor + noise * self.std + self.mean

    def __repr__(self):
        return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)

def get_color_distortion(s=0.5):  # 0.5 for CIFAR10 by default
    """Build the color distortion: random color jittering, then color dropping.

    See Section A of SimCLR: https://arxiv.org/abs/2002.05709

    Args:
        s: strength of the color distortion; it scales the brightness,
            contrast and saturation jitter by 0.8 and the hue jitter by 0.2.

    Returns:
        A ``transforms.Compose`` that applies the jitter with probability 0.8
        and then converts to grayscale with probability 0.2.
    """
    jitter = transforms.ColorJitter(
        brightness=0.8 * s,
        contrast=0.8 * s,
        saturation=0.8 * s,
        hue=0.2 * s,
    )
    return transforms.Compose([
        transforms.RandomApply([jitter], p=0.8),
        transforms.RandomGrayscale(p=0.2),
    ])




import sys

sys.path.append("./")
import os

import torchvision
from Models.resnet import resnet32
from Data.cifar import CIFAR10, CIFAR100
from Data.mnist import MNIST
from sklearn.model_selection import GridSearchCV, train_test_split

import torch as torch
import argparse
from tqdm import tqdm
from copy import deepcopy
from sklearn.metrics import (
    precision_recall_fscore_support as prf,
    accuracy_score,
    roc_auc_score,
)
from prettytable import PrettyTable
import numpy as np
from torch.utils.data import DataLoader
import torch.utils.data as data
from torch.optim import AdamW
import torchvision.transforms as T
import torch.nn.functional as F
import matplotlib.pyplot as plt


def mixup_data(x, y, alpha=1.0, use_cuda=True):
    '''Compute the mixup data. Return mixed inputs, pairs of targets, and lambda.

    Args:
        x: batch of inputs; the first dimension is the batch dimension.
        y: targets aligned with ``x`` along the first dimension.
        alpha: Beta(alpha, alpha) parameter used to draw the mixing weight.
            A value <= 0 disables mixing (lambda is fixed to 1).
        use_cuda: kept for backward compatibility and no longer consulted;
            the permutation is always placed on the device of ``x``.

    Returns:
        Tuple ``(mixed_x, y_a, y_b, lam)`` where
        ``mixed_x = lam * x + (1 - lam) * x[perm]``, ``y_a`` is ``y`` and
        ``y_b`` is ``y[perm]`` for a random permutation ``perm`` of the batch.
    '''
    lam = np.random.beta(alpha, alpha) if alpha > 0. else 1.
    batch_size = x.size()[0]

    # Draw the permutation on the CPU (same RNG stream as before) and move it
    # next to the data. Calling .cuda() unconditionally crashed on CPU-only
    # hosts and could land on a different GPU than the one holding x.
    index = torch.randperm(batch_size).to(x.device)

    mixed_x = lam * x + (1 - lam) * x[index, :]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam


def mixup_criterion(y_a, y_b, lam):
    """Return a loss combiner for a mixup batch.

    The returned callable takes ``(criterion, pred)`` and evaluates
    ``lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)``.
    """
    def mixed_loss(criterion, pred):
        loss_a = criterion(pred, y_a)
        loss_b = criterion(pred, y_b)
        return lam * loss_a + (1 - lam) * loss_b

    return mixed_loss


def get_mean_and_std(dataset):
    """Per-channel mean and std of a 3-channel image dataset.

    Samples are loaded one at a time; for each of the first three channels the
    sample's mean and std are accumulated and finally divided by
    ``len(dataset)``. The returned std is therefore the average of the
    per-sample stds, not the std over all pixels of the dataset.

    Args:
        dataset: dataset whose items are ``(inputs, targets)`` pairs.

    Returns:
        Tuple ``(mean, std)`` of two length-3 tensors.
    """
    loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2)
    channel_mean = torch.zeros(3)
    channel_std = torch.zeros(3)
    print('==> Computing mean and std..')
    for images, _ in loader:
        for channel in range(3):
            plane = images[:, channel, :, :]
            channel_mean[channel] += plane.mean()
            channel_std[channel] += plane.std()
    n_samples = len(dataset)
    channel_mean.div_(n_samples)
    channel_std.div_(n_samples)
    return channel_mean, channel_std


class Solver_SimCLR_PRL:
    """Train and evaluate a ResNet-32 classifier on a corrupted dataset.

    Each training step (see ``train``) perturbs the batch with one
    signed-gradient step of size ``eps_neighbor`` (FGSM-style), ranks the
    perturbed samples by a squared-error score against their labels, and fits
    the model on a mixup of the perturbed batch. ``evaluate`` reports the
    accuracy on clean and poisoned test inputs plus the attack success rate.

    Args:
        data_name: one of ``'cifar10'``, ``'cifar100'`` or ``'mnist'``.
        seed: seed applied to the NumPy and PyTorch (CPU and CUDA) RNGs.
        learning_rate: AdamW learning rate.
        corruption_rate: forwarded to the dataset as ``noise_rate``.
        corruption_type: forwarded to the dataset as ``noise_type``.
        batch_size: training batch size (the test loader always uses 10).
        max_epochs: number of training epochs.
        drop_decay_step: number of leading epochs over which the printed
            drop-rate schedule ramps linearly from 0.
        alpha: Beta(alpha, alpha) parameter passed to ``mixup_data``.
        eps_neighbor: step size of the signed-gradient perturbation.

    Raises:
        ValueError: if ``data_name`` is not a supported dataset.
    """

    # data_name -> (num_classes, input_channel, normalization mean, normalization std)
    _DATASET_SPECS = {
        'cifar10': (10, 3, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        'cifar100': (100, 3, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        'mnist': (10, 1, (0.1307,), (0.3081,)),
    }

    def __init__(
        self,
        data_name="cifar10",
        seed=0,
        learning_rate=3e-4,
        corruption_rate=0.2,
        corruption_type='backdoor',
        batch_size=128,
        max_epochs=100,
        drop_decay_step=0,
        alpha=0.0,  # mix up parameter
        eps_neighbor=0.1
    ):
        # Fail early with a clear message: an unknown name used to leave the
        # loaders undefined and crash later on ``None.to(device)``.
        if data_name not in self._DATASET_SPECS:
            raise ValueError(
                "Unsupported data_name {!r}; expected one of {}".format(
                    data_name, sorted(self._DATASET_SPECS)))

        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        use_cuda = torch.cuda.is_available()
        self.device = torch.device("cuda" if use_cuda else "cpu")
        self.seed = seed
        self.max_epochs = max_epochs
        self.learning_rate = learning_rate
        self.batch_size = batch_size
        self.data_name = data_name
        self.corruption_rate = corruption_rate
        self.corruption_type = corruption_type
        self.drop_decay_step = drop_decay_step
        self.alpha = alpha
        self.epoch_decay_start = int(max_epochs * 0.4)
        self.eps_neighbor = eps_neighbor

        self.num_classes, self.input_channel = self._DATASET_SPECS[data_name][:2]
        self.training_loader, self.testing_loader = self._build_loaders()

        self.model = None
        self.result = {}
        self.build_model()
        self.print_network()

    def _build_loaders(self):
        """Create the (training, testing) DataLoaders for ``self.data_name``."""
        _, _, mean, std = self._DATASET_SPECS[self.data_name]
        dataset_cls = {'cifar10': CIFAR10, 'cifar100': CIFAR100, 'mnist': MNIST}[self.data_name]

        normalize = [T.Resize(32), T.ToTensor(), T.Normalize(mean, std)]
        # Only CIFAR-10 trains with flip/crop augmentation. The CIFAR-100 and
        # MNIST pipelines have always run without it, and that is preserved.
        augment = []
        if self.data_name == 'cifar10':
            augment = [T.RandomHorizontalFlip(), T.RandomCrop(32, padding=4)]

        def make_dataset(train, transform):
            return dataset_cls(root='./data/',
                               download=True,
                               train=train,
                               transform=transform,
                               noise_type=self.corruption_type,
                               noise_rate=self.corruption_rate)

        # The training set is built before the test set, as before.
        train_dataset = make_dataset(True, T.Compose(augment + normalize))
        test_dataset = make_dataset(False, T.Compose(normalize))

        training_loader = DataLoader(dataset=train_dataset,
                                     batch_size=self.batch_size,
                                     num_workers=4,
                                     drop_last=True,
                                     shuffle=True)
        testing_loader = DataLoader(dataset=test_dataset,
                                    batch_size=10,
                                    num_workers=1,
                                    drop_last=False,
                                    shuffle=False)
        return training_loader, testing_loader

    def build_model(self):
        """Instantiate the ResNet-32 for ``self.data_name`` on ``self.device``."""
        if self.data_name not in self._DATASET_SPECS:
            raise ValueError("Unsupported data_name {!r}".format(self.data_name))
        num_classes = self._DATASET_SPECS[self.data_name][0]
        if self.data_name == 'mnist':
            # MNIST is the only single-channel dataset; the CIFAR models rely
            # on resnet32's own default for the input channels.
            self.model = resnet32(num_class=num_classes, input_channel=1)
        else:
            self.model = resnet32(num_class=num_classes)
        self.model = self.model.to(self.device)

    def print_network(self):
        """Print the total number of parameters of the model."""
        num_params = sum(p.numel() for p in self.model.parameters())
        print("The number of parameters: {}|number of models: 1".format(num_params))

    def _adversarial_neighbor(self, x, y):
        """Return ``x`` moved one signed-gradient step that increases the loss.

        The step has size ``self.eps_neighbor`` and the result is clamped to
        the value range of ``x`` and detached. The backward pass also leaves
        gradients on the model parameters; the caller must clear them before
        the actual update.
        """
        perturbation = torch.zeros_like(x, requires_grad=True)
        loss = F.cross_entropy(self.model(x + perturbation), y, reduction="none").mean()
        loss.backward()  # gradient with respect to the perturbation
        step = perturbation.grad.sign() * self.eps_neighbor
        return (x + step).clamp(x.min(), x.max()).detach()

    def train(self):
        """Run the full training loop, evaluating after every epoch.

        Per-epoch results are appended to ``self.result['clean_acc']``,
        ``self.result['poison_acc']`` and ``self.result['success_rate']``.
        The training loader is expected to yield ``(x, y, idx)`` batches.
        """
        drop_rate = self.corruption_rate + 0.1
        rate_schedule = np.ones(self.max_epochs) * drop_rate
        rate_schedule[:self.drop_decay_step] = np.linspace(0, drop_rate, self.drop_decay_step)
        print(rate_schedule)

        optimizer = AdamW(self.model.parameters(), lr=self.learning_rate)
        lr_decay = 0.95
        lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=lr_decay)
        # mixup must follow the device actually in use, not assume CUDA.
        use_cuda = self.device.type == "cuda"

        self.result['clean_acc'] = []
        self.result['poison_acc'] = []
        self.result['success_rate'] = []
        for epoch in tqdm(range(self.max_epochs)):
            self.model.train()
            for x, y, _ in self.training_loader:
                x = x.to(self.device).float()
                y = y.to(self.device).long()

                x_adv = self._adversarial_neighbor(x, y)

                # The score is only used for ranking, so no graph is needed.
                # The forward pass itself is kept (the model is in train mode).
                with torch.no_grad():
                    probs = F.softmax(self.model(x_adv), dim=1)
                    one_hot = F.one_hot(y, num_classes=self.num_classes).to(probs.dtype)
                    filtering_score = F.mse_loss(probs, one_hot, reduction='none').sum(dim=1)
                    _, index = torch.sort(filtering_score)
                    # The PRL variant would keep only the lowest-scoring
                    # (1 - rate_schedule[epoch]) fraction of ``index`` here;
                    # this solver keeps every sample, so the sort only
                    # reorders the batch.

                inputs, targets_a, targets_b, lam = mixup_data(
                    x_adv[index], y[index], self.alpha, use_cuda)

                # Drop the gradients left behind by the perturbation pass.
                optimizer.zero_grad()
                logits = self.model(inputs)
                loss = (lam * F.cross_entropy(logits, targets_a)
                        + (1 - lam) * F.cross_entropy(logits, targets_b))
                loss.backward()

                torch.nn.utils.clip_grad_norm_(self.model.parameters(), 5)
                optimizer.step()

            # Decay once per qualifying epoch. This used to run once per batch
            # inside such an epoch, shrinking the learning rate by
            # 0.95 ** len(training_loader) and effectively freezing training.
            if epoch > 30 and (epoch + 1) % 20 == 0:
                lr_scheduler.step()

            acc_clean, acc_poison, success_rate = self.evaluate(epoch)
            self.result['clean_acc'].append(acc_clean)
            self.result['poison_acc'].append(acc_poison)
            self.result['success_rate'].append(success_rate)
        print("model finished training")

    def evaluate(self, epoch):
        """Evaluate on the test loader and print a summary table.

        The test loader is expected to yield
        ``((x, x_poisoned), (y, y_poisoned), idx)`` batches.

        Args:
            epoch: epoch index, only shown in the printed table.

        Returns:
            Tuple ``(acc_clean, acc_poison, attack_success_rate)``: accuracy
            of the predictions on ``x`` against ``y``, accuracy of the
            predictions on ``x_poisoned`` against ``y``, and the fraction of
            predictions on ``x_poisoned`` that equal ``y_poisoned``.
        """
        self.model.eval()
        pred_clean, pred_poison, labels, attack_labels = [], [], [], []
        with torch.no_grad():
            for (x, x_poisoned), (y, y_poisoned), _ in self.testing_loader:
                x = x.to(self.device).float()
                x_poisoned = x_poisoned.to(self.device).float()

                _, yhat_poison = torch.max(F.softmax(self.model(x_poisoned), dim=1), 1)
                _, yhat_clean = torch.max(F.softmax(self.model(x), dim=1), 1)

                # Flatten every batch to 1-D so batches of any size (including
                # a short or single-sample last batch) concatenate cleanly.
                pred_clean.append(yhat_clean.cpu().numpy().reshape(-1))
                pred_poison.append(yhat_poison.cpu().numpy().reshape(-1))
                labels.append(y.long().cpu().numpy().reshape(-1))
                attack_labels.append(y_poisoned.long().cpu().numpy().reshape(-1))

        yhat_all_clean = np.concatenate(pred_clean)
        yhat_all_poison = np.concatenate(pred_poison)
        y_all = np.concatenate(labels)
        y_all_attack = np.concatenate(attack_labels)

        acc_clean = accuracy_score(y_all, yhat_all_clean)
        acc_poison = accuracy_score(y_all, yhat_all_poison)
        attack_success_rate = (y_all_attack == yhat_all_poison).sum() / y_all_attack.shape[0]

        t = PrettyTable()
        t.title = 'Results'
        t.field_names = ['Experiment Info', 'Value']
        rows = [
            ['algorithm', 'InnerMax'],
            ['data', self.data_name],
            ['seed', self.seed],
            ['learning_rate', self.learning_rate],
            ['batch_size', self.batch_size],
            ['epoch', epoch],
            ['mixup-alpha', self.alpha],
            ['eps_neighbor', self.eps_neighbor],
            ['acc clean', acc_clean],
            ['acc poison', acc_poison],
            ['attack success rate', attack_success_rate],
            ['corruption rate', self.corruption_rate],
            ['corruption type', self.corruption_type],
        ]
        for row in rows:
            t.add_row(row)
        print(t)
        return acc_clean, acc_poison, attack_success_rate

    def save_result(self):
        """Save the per-epoch metrics as ``result.npy`` inside the result directory.

        The directory is ``<parent of cwd>/DefenseBackDoorAttack/Result/Result/
        <data_name>/<corruption_rate>/InnerMax/<corruption_type>/<seed>`` and
        is created if missing. The file stores a dict with the keys
        ``acc_clean``, ``acc_poison`` and ``attack_success_rate``.
        """
        path = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
        result_dir = "{}/DefenseBackDoorAttack/Result/Result/{}/{}/InnerMax/{}/{}".format(
            path, self.data_name, self.corruption_rate, self.corruption_type, self.seed
        )
        os.makedirs(result_dir, exist_ok=True)

        # Join with a separator: plain concatenation wrote "<seed>result.npy"
        # next to the directory that was just created, leaving it empty.
        result_file = os.path.join(result_dir, "result.npy")
        np.save(
            result_file,
            {"acc_clean": self.result['clean_acc'],
             "acc_poison": self.result['poison_acc'],
             "attack_success_rate": self.result['success_rate']
             },
        )
        print("result save to {}".format(result_file))


if __name__ == "__main__":
    # Command-line entry point: train Solver_SimCLR_PRL with the given
    # hyper-parameters, then save the per-epoch results.
    parser = argparse.ArgumentParser(description="AnomalyDetection")
    parser.add_argument("--seed", type=int, default=5, required=False)
    parser.add_argument("--data_name", type=str, default="cifar10", required=False)
    parser.add_argument("--max_epochs", type=int, default=200, required=False)
    parser.add_argument("--batch_size", type=int, default=128, required=False)
    parser.add_argument("--learning_rate", type=float, default=3e-4, required=False)
    parser.add_argument("--corruption_rate", type=float, default=0.1, required=False)
    parser.add_argument("--corruption_type", type=str, default='patch', required=False)
    parser.add_argument("--eps_neighbor", type=float, default=0.01, required=False)
    parser.add_argument("--alpha", type=float, default=0.0, required=False)
    parser.add_argument("--drop_decay_step", type=int, default=0, required=False)
    config = parser.parse_args()

    use_cuda = torch.cuda.is_available()

    device = torch.device("cuda" if use_cuda else "cpu")
    torch.backends.cudnn.benchmark = True
    # The solver class in this file is Solver_SimCLR_PRL; the previously used
    # name Solver_PRL is not defined anywhere here and raised a NameError.
    solver = Solver_SimCLR_PRL(
        data_name=config.data_name,
        seed=config.seed,
        learning_rate=config.learning_rate,
        corruption_rate=config.corruption_rate,
        batch_size=config.batch_size,
        max_epochs=config.max_epochs,
        corruption_type=config.corruption_type,
        eps_neighbor=config.eps_neighbor,
        drop_decay_step=config.drop_decay_step,
        alpha=config.alpha
    )
    solver.train()
    solver.save_result()


@hydra.main(config_path='/localscratch/liuboya2/DefenseBackdoorAttack/', config_name='simclr_config.yml')
def train(args: DictConfig) -> None:
    """Pre-train a SimCLR model on blend-corrupted CIFAR-10 image pairs.

    Reads from ``args``: ``data_dir``, ``batch_size``, ``workers``,
    ``backbone`` (``'resnet18'`` or ``'resnet34'``), ``projection_dim``,
    ``learning_rate``, ``momentum``, ``weight_decay``, ``epochs``,
    ``temperature`` and ``log_interval``.

    A checkpoint ``simclr_<backbone>_epoch<N>.pt`` is written to the current
    working directory every ``log_interval`` epochs.

    Returns:
        The training dataset.
        NOTE(review): the signature is annotated ``-> None`` although the
        dataset is returned; the return is kept in case a caller relies on it.

    Raises:
        RuntimeError: if CUDA is not available.
        ValueError: if ``args.backbone`` is not a supported backbone.
    """
    # Explicit raises instead of asserts: asserts are stripped under ``-O``.
    if not torch.cuda.is_available():
        raise RuntimeError("SimCLR training requires a CUDA device")
    cudnn.benchmark = True

    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(32),
        transforms.RandomHorizontalFlip(p=0.5),
        get_color_distortion(s=0.5),
        transforms.ToTensor(),
        AddGaussianNoise(0., 0.1),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    # NOTE(review): data_dir is resolved but not used; the dataset root below
    # is the relative './data/'. Confirm which location is intended.
    data_dir = hydra.utils.to_absolute_path(args.data_dir)  # get absolute path of data dir

    train_set = CIFAR10Pair(root='./data/',
                            download=True,
                            train=True,
                            transform=train_transform,
                            noise_type='blend',
                            noise_rate=0.45,
                            )

    train_loader = DataLoader(train_set,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=args.workers,
                              drop_last=True)

    # Prepare model. The backbone is resolved through an explicit table rather
    # than eval(), so a configuration string can never be executed as code.
    backbones = {'resnet18': resnet18, 'resnet34': resnet34}
    if args.backbone not in backbones:
        raise ValueError(
            "Unsupported backbone {!r}; expected one of {}".format(
                args.backbone, sorted(backbones)))
    base_encoder = backbones[args.backbone]
    model = SimCLR(base_encoder, projection_dim=args.projection_dim).cuda()
    logger.info('Base model: {}'.format(args.backbone))
    logger.info('feature dim: {}, projection dim: {}'.format(model.feature_dim, args.projection_dim))

    optimizer = torch.optim.SGD(
        model.parameters(),
        args.learning_rate,
        momentum=args.momentum,
        weight_decay=args.weight_decay,
        nesterov=True)

    # cosine annealing lr, advanced once per batch
    total_steps = args.epochs * len(train_loader)
    scheduler = LambdaLR(
        optimizer,
        lr_lambda=lambda step: get_lr(  # pylint: disable=g-long-lambda
            step,
            total_steps,
            args.learning_rate,  # lr_lambda computes multiplicative factor
            1e-3))

    # SimCLR training
    for epoch in range(1, args.epochs + 1):
        model.train()
        loss_meter = AverageMeter("SimCLR_loss")
        train_bar = tqdm(enumerate(train_loader), total=len(train_loader))
        for _, (x, _, _) in train_bar:
            # x holds the two views of every image. Stack them as
            # (batch, 2, C, H, W) and merge the first two axes so that rows 2k
            # and 2k + 1 are the positive pair nt_xent expects.
            pairs = torch.stack(x, dim=1).cuda(non_blocking=True)
            batch, _, channels, height, width = pairs.size()
            x = pairs.view(batch * 2, channels, height, width)

            optimizer.zero_grad()
            _, rep = model(x)
            loss = nt_xent(rep, args.temperature)
            loss.backward()
            optimizer.step()
            scheduler.step()

            loss_meter.update(loss.item(), x.size(0))
            train_bar.set_description("Train epoch {}, SimCLR loss: {:.4f}".format(epoch, loss_meter.avg))

        # save checkpoint every log_interval epochs
        if epoch >= args.log_interval and epoch % args.log_interval == 0:
            logger.info("==> Save checkpoint. Train epoch {}, SimCLR loss: {:.4f}".format(epoch, loss_meter.avg))
            torch.save(model.state_dict(), 'simclr_{}_epoch{}.pt'.format(args.backbone, epoch))

    return train_set

# Script entry point for SimCLR pre-training: `train` is wrapped by
# `hydra.main`, so it is called without arguments.
# NOTE(review): this file contains an earlier `__main__` block as well; both
# run, in order, when the file is executed as a script.
if __name__ == '__main__':
    train()


