import sys
import argparse
import torchattacks

sys.path.append("./")
import logging
import torch.nn as nn
import numpy as np
from copy import deepcopy
from PIL import Image
import matplotlib.pyplot as plt
from torch.nn.utils.parametrizations import spectral_norm as sn
import torch
import os
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader, ConcatDataset
# from K_spectral_norm import spectral_norm as sn
# from torchvision.datasets import CIFAR10
from torchvision.models import resnet18, resnet34
from torchvision import transforms

from models import SimCLR
from Data.cifar import CIFAR10PAIR as CIFAR10Pair
from Data.cifar import CIFAR100PAIR as CIFAR100Pair
from Data.cifar import CIFAR10, CIFAR100
from Data.STL import STL10_UNLABELED, STL10_Test
from tqdm import tqdm


logger = logging.getLogger(__name__)


def smooth_predict_soft(model, x, noise, sample_batch_size=512, noise_batch_size=64):
    """Classify each input by majority vote over noisy copies of it.

    Every sample in ``x`` is replicated ``noise_batch_size`` times, the copies
    are flattened to 2-D, passed through ``noise`` and restored to the input
    shape.  ``model`` then scores the copies of one sample per forward pass and
    the most frequent arg-max class becomes that sample's label.  Ties go to
    the class that appears first among the copies.

    Args:
        model: Callable mapping a batch shaped like ``x`` to per-class scores
            of shape ``(batch, n_classes)``.
        x: Input tensor; the first dimension indexes the samples.
        noise: Callable applied to the flattened
            ``(n_samples * noise_batch_size, -1)`` copies.  Its output must
            hold the same number of elements.
        sample_batch_size: Unused; kept so existing call sites keep working.
        noise_batch_size: Number of noisy copies drawn per sample.

    Returns:
        1-D float32 tensor of voted class indices, placed on ``x.device``
        (previously the result was moved to CUDA unconditionally).
    """
    # Function-scope import so the block does not depend on a file-level edit.
    from collections import Counter

    n_inputs = x.shape[0]
    y_voted = np.zeros(n_inputs)
    if n_inputs == 0:
        # reshape(-1, ...) is ambiguous on an empty tensor; nothing to vote on.
        return torch.zeros(0, dtype=torch.float32, device=x.device)

    shape = torch.Size([n_inputs, noise_batch_size]) + x.shape[1:]
    samples = x.unsqueeze(1).expand(shape)
    samples = samples.reshape(torch.Size([-1]) + samples.shape[2:])
    samples = noise(samples.view(len(samples), -1)).reshape(samples.shape)

    # Only hard class votes are returned, so no autograd graph is needed.
    with torch.no_grad():
        for sample_idx in range(n_inputs):
            start = sample_idx * noise_batch_size
            x_rs = samples[start:start + noise_batch_size]
            outputs = F.softmax(model(x_rs), dim=1)
            _, yhat = torch.max(outputs, 1)
            # reshape(-1) keeps a list even when noise_batch_size == 1, where
            # squeeze() would have collapsed the votes to a bare scalar.
            votes = yhat.reshape(-1).tolist()
            y_voted[sample_idx] = Counter(votes).most_common(1)[0][0]

    return torch.tensor(y_voted, dtype=torch.float32, device=x.device)


class Perturb_In_Randmized_Smoothing():
    """Additive random perturbation used for randomized smoothing.

    Args:
        noise_type: ``"guassian"`` (historical spelling, kept for existing
            callers; ``"gaussian"`` is accepted too), ``"uniform"`` or
            ``"laplace"``.  Any other value makes the call return its input
            unchanged.
        mean: Stored on the instance; not used when sampling.
        variance: Stored on the instance; not used when sampling.
        noise_sd: Scale of the noise: standard deviation for the Gaussian
            branch, half-width for the uniform branch, and the ``scale``
            parameter of the Laplace distribution.
    """

    def __init__(self, noise_type = "guassian",  mean=0.0, variance = 1.0, noise_sd = 0.5):
        self.noise_type = noise_type
        self.mean = mean
        self.variance = variance
        self.noise_sd = noise_sd
        self.laplace_dist = torch.distributions.Laplace(
            loc=torch.tensor(0.0), scale=torch.tensor(self.noise_sd)
        )

    def __call__(self, img):
        """Return ``img`` plus freshly sampled noise of the configured type."""
        if self.noise_type in ("guassian", "gaussian"):
            # randn_like draws from N(0, 1).  The previous rand_like call drew
            # from U[0, 1), which is neither Gaussian nor zero-mean.
            img = torch.randn_like(img) * self.noise_sd + img
        elif self.noise_type == "uniform":
            # U[0, 1) shifted and scaled to [-noise_sd, noise_sd).
            img = (torch.rand_like(img) - 0.5) * 2 * self.noise_sd + img
        elif self.noise_type == "laplace":
            # The distribution's parameters live on the CPU, so the sample is
            # moved to the image's device and dtype before being added.
            sample = self.laplace_dist.sample(img.shape)
            img = sample.to(device=img.device, dtype=img.dtype) + img
        return img


class LinModel(nn.Module):
    """Encoder followed by a plain linear classification head.

    NOTE: a second class named ``LinModel`` is defined further down in this
    module and rebinds the name, so this definition is not the one the rest
    of the file ends up using.
    """

    def __init__(self, encoder: nn.Module, feature_dim: int, n_classes: int):
        super().__init__()
        self.enc = encoder
        self.feature_dim, self.n_classes = feature_dim, n_classes
        # Unnormalised linear head on top of the encoder features.
        self.lin = nn.Linear(feature_dim, n_classes)

    def forward(self, x):
        """Encode ``x`` and return the class logits of the linear head."""
        return self.lin(self.enc(x))


class LinModel(nn.Module):
    """Spectrally-normalised linear head applied to precomputed features.

    The encoder is registered as ``self.enc`` (so it is part of
    ``parameters()`` and follows ``train()``/``eval()``), but ``forward`` does
    not call it: callers pass representations directly.
    """

    def __init__(self, encoder: nn.Module, feature_dim: int, n_classes: int):
        super().__init__()
        self.enc = encoder
        self.feature_dim, self.n_classes = feature_dim, n_classes
        # Wrap the head's weight in the spectral-norm parametrization.
        head = nn.Linear(feature_dim, n_classes)
        self.lin = sn(head)

    def forward(self, rep):
        """Return class logits for the representation batch ``rep``."""
        return self.lin(rep)


class AverageMeter(object):
    """Track the latest value and the running weighted mean of a metric."""

    def __init__(self, name):
        self.name = name
        self.reset()

    def reset(self):
        """Zero the latest value, the mean, the weighted sum and the weight."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` observed with weight ``n`` and refresh the mean."""
        self.count += n
        self.sum += val * n
        self.val = val
        self.avg = self.sum / self.count


def nt_xent(x, t=0.5):
    """Temperature-scaled cross-entropy over pairwise cosine similarities.

    Rows of ``x`` are expected to be interleaved positive pairs: row ``2k``
    and row ``2k + 1`` are targets of each other.  ``t`` is the temperature.
    """
    embeddings = F.normalize(x, dim=1)
    n_rows = embeddings.size(0)

    # Cosine similarities, floored at 1e-7, then divided by the temperature.
    similarity = torch.clamp(embeddings @ embeddings.t(), min=1e-7)
    logits = similarity / t

    # Push every self-similarity far down so it vanishes in the softmax.
    self_mask = torch.eye(n_rows).to(logits.device)
    logits = logits - self_mask * 1e5

    # Flipping the lowest bit swaps each even index with the following odd
    # index: 0 <-> 1, 2 <-> 3, ...
    partner = torch.arange(n_rows) ^ 1
    return F.cross_entropy(logits, partner.long().to(logits.device))


def get_lr(step, total_steps, lr_max, lr_min):
    """Cosine schedule value: ``lr_max`` at step 0, ``lr_min`` at ``total_steps``."""
    progress = step / total_steps
    cosine_weight = 0.5 * (1 + np.cos(progress * np.pi))
    return lr_min + (lr_max - lr_min) * cosine_weight


class AddGaussianNoise(object):
    """Transform that adds ``N(mean, std**2)`` noise to a tensor.

    Args:
        mean: Offset added to every element.
        std: Standard deviation of the zero-mean noise component.
    """

    def __init__(self, mean=0.0, std=1.0):
        self.std = std
        self.mean = mean

    def __call__(self, tensor):
        """Return ``tensor`` with independent Gaussian noise added per element."""
        # Sample on the tensor's own device; noise created on the CPU would
        # raise a device-mismatch error for tensors that live elsewhere.
        noise = torch.randn(tensor.size(), device=tensor.device)
        return tensor + noise * self.std + self.mean

    def __repr__(self):
        return self.__class__.__name__ + "(mean={0}, std={1})".format(
            self.mean, self.std
        )


# color distortion composed by color jittering and color dropping.
# See Section A of SimCLR: https://arxiv.org/abs/2002.05709
def get_color_distortion(s=0.5):  # 0.5 for CIFAR10 by default
    """Build the colour-distortion transform; ``s`` scales the jitter strength.

    The result applies colour jitter with probability 0.8 and then converts
    to grayscale with probability 0.2.
    """
    jitter = transforms.ColorJitter(
        brightness=0.8 * s,
        contrast=0.8 * s,
        saturation=0.8 * s,
        hue=0.2 * s,
    )
    return transforms.Compose(
        [
            transforms.RandomApply([jitter], p=0.8),
            transforms.RandomGrayscale(p=0.2),
        ]
    )

class RandomCutout():
    """Randomly zero out rectangular blocks of an image.

    Args:
        p: Probability of applying the cutout to a given image.
        cut_range: Upper bound on a block's height/width, as a fraction of
            the image's height/width.
        block_range: Maximum number of blocks cut from one image (at least
            one block is cut whenever the transform fires).
    """

    def __init__(self, p, cut_range = 0.2, block_range = 2):
        self.p = p
        self.cut_range = cut_range
        self.block_range = block_range

    def __call__(self, img):
        """Return ``img`` unchanged, or an RGB PIL image with blocks zeroed.

        The input must convert to an ``(H, W, C)`` array when the transform
        fires.
        """
        if np.random.random() >= self.p:
            return img

        img = np.array(img)  # copy, so the caller's image is never modified
        h, w, c = img.shape
        # randint's upper bound is exclusive and must exceed the lower bound
        # of 1; the floor of 2 keeps very small images from raising.
        max_block_h = max(2, int(self.cut_range * h))
        max_block_w = max(2, int(self.cut_range * w))
        num_block = np.random.randint(1, self.block_range + 1)
        for _ in range(num_block):
            top, left = np.random.randint(1, h - 1), np.random.randint(1, w - 1)
            block_h = np.random.randint(1, max_block_h)
            block_w = np.random.randint(1, max_block_w)
            bottom, right = min(top + block_h, h), min(left + block_w, w)
            img[top:bottom, left:right] = 0
        # Only ``Image`` is imported at file level (``from PIL import Image``);
        # the former ``PIL.Image`` reference raised NameError.
        return Image.fromarray(img.astype('uint8')).convert('RGB')

# class CIFAR10Pair(CIFAR10):
#     """Generate mini-batche pairs on CIFAR10 training set."""
#     def __getitem__(self, idx):
#         img, target = self.data[idx], self.targets[idx]
#         img = Image.fromarray(img)  # .convert('RGB')
#         imgs = [self.transform(img), self.transform(img)]
#         return torch.stack(imgs), target  # stack a positive pair

# optimizer: 'sgd' # or LARS (experimental)
# learning_rate: 0.6 # initial lr = 0.3 * batch_size / 256
# momentum: 0.9
# weight_decay: 1.0e-6 # "optimized using LARS [...] and weight decay of 10−6"
# temperature: 0.5 # see appendix B.7.: Optimal temperature under different batch sizes
def get_free_gpu():
    """Return the index of the GPU whose reported free memory is largest.

    Runs the same ``nvidia-smi`` / ``grep`` pipeline as before and reads the
    third whitespace-separated field of every remaining line as the free
    memory.  The output is captured directly rather than routed through a
    ``tmp`` file in the working directory, which concurrent runs would
    overwrite.

    Raises:
        RuntimeError: if no free-memory value could be parsed (for example
            when ``nvidia-smi`` is unavailable).
    """
    # Function-scope import so the block does not depend on a file-level edit.
    import subprocess

    # Constant command string; no external input is interpolated into it.
    query = "nvidia-smi -q -d Memory |grep -A4 GPU|grep Free"
    completed = subprocess.run(
        query, shell=True, capture_output=True, text=True, check=False
    )

    memory_available = []
    for line in completed.stdout.splitlines():
        fields = line.split()
        if len(fields) > 2 and fields[2].isdigit():
            memory_available.append(int(fields[2]))

    if not memory_available:
        raise RuntimeError("could not read free GPU memory from nvidia-smi")
    return np.argmax(memory_available)


class SSL_FineTune:
    def __init__(
        self,
        data,
        corruption_rate,
        corruption_type,
        projection_dim=128,
        backbone="resnet18",
        temperature=0.5,
        batch_size=512,
        join_pretrain=True,
        n_workers=16,
        max_epochs=1000,
        learning_rate=0.6,
        momentum=0.9,
        weight_decay=1.0e-6,
        log_interval=250,
        training_ratio=1.0,
        seed=0,
    ):
        """Select a GPU, seed the RNGs and build datasets, loaders and model.

        Args:
            data: ``"cifar10"`` or ``"cifar100"``.
            corruption_rate: Forwarded to the datasets as ``noise_rate``.
            corruption_type: Forwarded to the datasets as ``noise_type``.
            projection_dim: Projection dimension handed to ``SimCLR``.
            backbone: ``"resnet18"`` or ``"resnet34"``.
            temperature: Temperature used by the contrastive loss.
            batch_size: Batch size of the train and test loaders.
            join_pretrain: When ``True``, an auxiliary dataset is concatenated
                to the training set for pre-training and an untouched copy of
                the training set is kept in ``self.train_set_finetune``.
            n_workers: Worker processes of the training loader.
            max_epochs: Number of pre-training epochs.
            learning_rate: Base learning rate for pre-training.
            momentum: SGD momentum.
            weight_decay: SGD weight decay.
            log_interval: Checkpoint period in epochs.
            training_ratio: Fraction of the training set kept for joint
                pre-training when different from 1.0.
            seed: Seed for NumPy and PyTorch.

        Raises:
            RuntimeError: if CUDA is not available.
            NotImplementedError: for an unsupported ``data`` or ``backbone``.
        """
        # Explicit check instead of ``assert``, which is stripped under -O.
        if not torch.cuda.is_available():
            raise RuntimeError("SSL_FineTune requires a CUDA device")
        cudnn.benchmark = True
        free_gpu_id = get_free_gpu()
        torch.cuda.set_device(int(free_gpu_id))

        self.seed = seed

        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

        self.data_name = data
        self.corruption_rate = corruption_rate
        self.corruption_type = corruption_type
        self.temperature = temperature
        self.batch_size = batch_size
        # Honour the argument; it used to be overwritten with a literal 16.
        self.n_workers = n_workers
        self.max_epochs = max_epochs
        self.projection_dim = projection_dim
        self.learning_rate = learning_rate
        self.momentum = momentum
        self.weight_decay = weight_decay
        self.log_interval = log_interval
        self.backbone = backbone
        self.join_pretrain = join_pretrain
        self.result = {}
        self.training_ratio = training_ratio

        self.result["clean_acc"] = []
        self.result["poison_acc"] = []

        channel_mean = (0.4914, 0.4822, 0.4465)
        channel_std = (0.2023, 0.1994, 0.2010)

        self.train_transform = transforms.Compose(
            [
                transforms.RandomResizedCrop(32),
                transforms.RandomHorizontalFlip(p=0.5),
                get_color_distortion(s=0.5),
                transforms.ToTensor(),
                AddGaussianNoise(0.0, 0.1),  # defense l2 blending attack
                transforms.Normalize(channel_mean, channel_std),
            ]
        )

        self.train_transform_fine_tune = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize(channel_mean, channel_std),
            ]
        )
        self.test_transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize(channel_mean, channel_std),
            ]
        )

        # dataset name -> (number of classes, paired train class, test class)
        supported = {
            "cifar10": (10, CIFAR10Pair, CIFAR10),
            "cifar100": (100, CIFAR100Pair, CIFAR100),
        }
        if data not in supported:
            # Fail early; previously this surfaced later as a NameError on
            # ``train_set`` when join_pretrain was disabled.
            raise NotImplementedError("unsupported dataset: {!r}".format(data))
        self.n_class, pair_cls, eval_cls = supported[data]

        train_set = pair_cls(
            root="./data/",
            download=True,
            train=True,
            transform=self.train_transform,
            noise_type=self.corruption_type,
            noise_rate=self.corruption_rate,
        )
        test_set = eval_cls(
            root="./data/",
            download=True,
            train=False,
            transform=self.test_transform,
            noise_type=self.corruption_type,
            noise_rate=self.corruption_rate,
        )

        if join_pretrain is True:
            if self.data_name == "cifar10":
                # NOTE(review): this CIFAR100 pair set is replaced by the STL10
                # set on the next statement and never used.  The construction
                # is kept because it may download data or draw random numbers;
                # confirm before deleting it.
                augment_data = CIFAR100Pair(
                    root="./data/",
                    download=True,
                    train=True,
                    transform=self.train_transform,
                    noise_type=self.corruption_type,
                    noise_rate=0.001,
                )

                augment_data = STL10_UNLABELED(
                    root="./data/",
                    transform=self.train_transform,
                )

            elif self.data_name == "cifar100":
                augment_data = CIFAR10Pair(
                    root="./data/",
                    download=True,
                    train=True,
                    transform=self.train_transform,
                    noise_type=self.corruption_type,
                    noise_rate=0.001,
                )
            else:
                raise NotImplementedError

            # Untouched copy of the target training set for fine-tuning.
            self.train_set_finetune = deepcopy(train_set)
            if training_ratio == 1.0:
                train_set = ConcatDataset([train_set, augment_data])
            else:
                # Split by the dataset's real length instead of a literal
                # 50000; random_split only accepts lengths that sum to it.
                n_total = len(train_set)
                n_kept = int(n_total * self.training_ratio)
                kept_subset = torch.utils.data.random_split(
                    train_set, [n_kept, n_total - n_kept]
                )[0]
                train_set = ConcatDataset([kept_subset, augment_data])

        self.train_loader = DataLoader(
            train_set,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.n_workers,
            drop_last=True,
        )

        self.test_loader = DataLoader(
            test_set, batch_size=self.batch_size, shuffle=False
        )
        if backbone == "resnet18":
            base_encoder = resnet18
        elif backbone == "resnet34":
            base_encoder = resnet34
        else:
            raise NotImplementedError

        self.model = SimCLR(base_encoder, projection_dim=self.projection_dim).cuda()
        logger.info("Base model: {}".format(backbone))
        logger.info(
            "feature dim: {}, projection dim: {}".format(
                self.model.feature_dim, self.projection_dim
            )
        )

    def pretrain(self):
        """Pre-train ``self.model`` with the contrastive loss and checkpoint it.

        Training is skipped when an epoch-500 checkpoint for the current
        configuration already exists; that check is not made for joint
        pre-training with ``training_ratio`` different from 1.0.  Otherwise
        the model is trained for ``self.max_epochs`` epochs with Nesterov SGD
        and a per-step cosine schedule, and its ``state_dict`` is saved every
        ``self.log_interval`` epochs.

        All checkpoints go to one absolute directory and carry the epoch in
        their file name, which is the naming ``fine_tune`` loads from.
        """
        checkpoint_dir = (
            "/localscratch/liuboya2/DefenseBackdoorAttack/pretrained_model"
        )
        joint_full = self.join_pretrain and self.training_ratio == 1.0
        joint_partial = self.join_pretrain and self.training_ratio < 1.0

        def checkpoint_path(epoch):
            """Return the checkpoint file for ``epoch`` in this configuration."""
            if joint_full:
                name = "simclr_{}_{}_JOIN_type_{}_ratio_{}_epoch{}.pt".format(
                    self.backbone,
                    self.data_name,
                    self.corruption_type,
                    self.corruption_rate,
                    epoch,
                )
            elif joint_partial:
                # The epoch used to be hard-coded to 500 here, so every save
                # overwrote one file that fine_tune never looked for.
                name = "simclr_{}_{}_trainingRatio_{}_JOIN_type_{}_ratio_{}_epoch{}.pt".format(
                    self.backbone,
                    self.data_name,
                    self.training_ratio,
                    self.corruption_type,
                    self.corruption_rate,
                    epoch,
                )
            else:
                name = "simclr_{}_{}_type_{}_ratio_{}_epoch{}.pt".format(
                    self.backbone,
                    self.data_name,
                    self.corruption_type,
                    self.corruption_rate,
                    epoch,
                )
            return os.path.join(checkpoint_dir, name)

        # The existence check only looks for the epoch-500 checkpoint.
        if (joint_full or not self.join_pretrain) and os.path.isfile(
            checkpoint_path(500)
        ):
            print("skip pretraining, find existing model")
            return

        optimizer = torch.optim.SGD(
            self.model.parameters(),
            self.learning_rate,
            momentum=self.momentum,
            weight_decay=self.weight_decay,
            nesterov=True,
        )

        total_steps = self.max_epochs * len(self.train_loader)
        # NOTE(review): LambdaLR multiplies the base lr by the lambda's return
        # value, and the lambda itself is scaled by self.learning_rate, so the
        # effective peak lr is learning_rate ** 2.  Left unchanged.
        scheduler = LambdaLR(
            optimizer,
            lr_lambda=lambda step: get_lr(  # pylint: disable=g-long-lambda
                step,
                total_steps,
                self.learning_rate,  # lr_lambda computes multiplicative factor
                1e-3,
            ),
        )

        for epoch in range(1, self.max_epochs + 1):
            self.model.train()
            loss_meter = AverageMeter("SimCLR_loss")
            train_bar = tqdm(
                enumerate(self.train_loader), total=len(self.train_loader)
            )
            for batch_idx, (x, y, idx) in train_bar:
                # Each batch holds the two views as a sequence; stack them,
                # put the sample axis first and merge (sample, view) into one
                # batch axis so the two views of a sample sit on adjacent rows.
                x = torch.stack(x).permute(1, 0, 2, 3, 4).contiguous()
                x = x.cuda(non_blocking=True)
                sizes = x.size()
                x = x.view(sizes[0] * 2, sizes[2], sizes[3], sizes[4])

                optimizer.zero_grad()
                _, rep = self.model(x)
                loss = nt_xent(rep, self.temperature)
                loss.backward()
                optimizer.step()
                scheduler.step()

                loss_meter.update(loss.item(), x.size(0))
                train_bar.set_description(
                    "Train epoch {}, SimCLR loss: {:.4f}".format(epoch, loss_meter.avg)
                )

            # save checkpoint every log_interval epochs
            if epoch >= self.log_interval and epoch % self.log_interval == 0:
                logger.info(
                    "==> Save checkpoint. Train epoch {}, SimCLR loss: {:.4f}".format(
                        epoch, loss_meter.avg
                    )
                )
                os.makedirs(checkpoint_dir, exist_ok=True)
                torch.save(self.model.state_dict(), checkpoint_path(epoch))

    def fine_tune(self, fine_tune_epoch=50, n_load_epoch=1000):
        """Train a linear head on top of a pre-trained encoder and evaluate it.

        Loads the pre-training checkpoint saved at epoch ``n_load_epoch``,
        wraps its encoder in ``LinModel``, freezes the encoder, then runs
        ``fine_tune_epoch`` epochs of training, each followed by evaluation.
        Per-epoch accuracies are appended to ``self.result`` and written out
        through ``save_result``.

        Args:
            fine_tune_epoch: Number of fine-tuning epochs.
            n_load_epoch: Pre-training epoch whose checkpoint is loaded.

        Raises:
            NotImplementedError: if ``self.backbone`` is not supported.
        """
        if self.backbone == "resnet18":
            base_encoder = resnet18
        elif self.backbone == "resnet34":
            base_encoder = resnet34
        else:
            raise NotImplementedError

        checkpoint_dir = (
            "/localscratch/liuboya2/DefenseBackdoorAttack/pretrained_model"
        )
        if self.join_pretrain and self.training_ratio == 1.0:
            checkpoint_name = "simclr_{}_{}_JOIN_type_{}_ratio_{}_epoch{}.pt".format(
                self.backbone,
                self.data_name,
                self.corruption_type,
                self.corruption_rate,
                n_load_epoch,
            )
        elif self.join_pretrain and self.training_ratio < 1.0:
            checkpoint_name = "simclr_{}_{}_trainingRatio_{}_JOIN_type_{}_ratio_{}_epoch{}.pt".format(
                self.backbone,
                self.data_name,
                self.training_ratio,
                self.corruption_type,
                self.corruption_rate,
                n_load_epoch,
            )
        else:
            checkpoint_name = "simclr_{}_{}_type_{}_ratio_{}_epoch{}.pt".format(
                self.backbone,
                self.data_name,
                self.corruption_type,
                self.corruption_rate,
                n_load_epoch,
            )

        pre_model = SimCLR(base_encoder, projection_dim=self.projection_dim).cuda()
        pre_model.load_state_dict(
            torch.load(os.path.join(checkpoint_dir, checkpoint_name))
        )

        self.model = LinModel(
            pre_model.enc, feature_dim=pre_model.feature_dim, n_classes=self.n_class
        )
        self.model = self.model.cuda()
        self.pre_model = pre_model.enc
        # Freeze the encoder.  Assigning ``requires_grad = False`` on a Module
        # only creates a plain attribute and leaves its parameters trainable;
        # ``requires_grad_`` actually switches the flag on every parameter.
        self.model.enc.requires_grad_(False)
        parameters = [
            param for param in self.model.parameters() if param.requires_grad
        ]  # trainable parameters.

        # NOTE(review): this sets an attribute on the DataLoader object, not
        # on its dataset, so it looks like it does not change the transform
        # applied to the data.  Kept as-is; confirm the intended target.
        self.train_loader.transform = self.train_transform_fine_tune
        optimizer = torch.optim.SGD(
            parameters,
            0.2,  # lr = 0.1 * batch_size / 256, see section B.6 and B.7 of SimCLR paper.
            momentum=self.momentum,
            weight_decay=self.weight_decay * 10,
            nesterov=True,
        )

        # cosine annealing lr
        # NOTE(review): the schedule length is a literal 50 while the
        # scheduler is stepped once per batch, so the cosine passes its
        # minimum after 50 batches and rises again.  Left unchanged; confirm
        # whether fine_tune_epoch * len(train_loader) was intended.
        scheduler = LambdaLR(
            optimizer,
            lr_lambda=lambda step: get_lr(  # pylint: disable=g-long-lambda
                step,
                50,
                self.learning_rate,  # lr_lambda computes multiplicative factor
                1e-3,
            ),
        )

        if self.join_pretrain is True:
            # Fine-tune on the target training set only, without the
            # auxiliary data that was concatenated for pre-training.
            self.train_loader = DataLoader(
                self.train_set_finetune,
                batch_size=self.batch_size,
                shuffle=True,
                num_workers=self.n_workers,
                drop_last=True,
            )

        optimal_acc, optimal_poison_acc = 0.0, 0.0

        for epoch in range(1, fine_tune_epoch + 1):
            self.run_epoch_fine_tune(epoch, optimizer, scheduler)
            # ``self`` is already bound; passing it again shifted the
            # arguments so the instance itself arrived as ``epoch``.
            _, test_acc, test_poison_acc = self.evaluate_epoch_fine_tune(epoch)

            # Model selection uses clean test accuracy only; the poison
            # accuracy is recorded but is not a stopping criterion.
            if test_acc > optimal_acc:
                optimal_acc = test_acc
                optimal_poison_acc = test_poison_acc
                logger.info("==> New best results")

            print(
                "Current Best Test Acc: {:.4f}|Poison Acc: {:.4f}".format(
                    optimal_acc, optimal_poison_acc
                )
            )

            self.result["clean_acc"].append(test_acc)
            self.result["poison_acc"].append(test_poison_acc)
        logger.info(
            "Best Test Acc: {:.4f}|Poison Acc: {:.4f}".format(
                optimal_acc, optimal_poison_acc
            )
        )
        self.save_result()

    def run_epoch_fine_tune(self, epoch, optimizer=None, scheduler=None):
        """Run one fine-tuning epoch with loss-based sample filtering.

        For every batch the first view is encoded by ``self.pre_model``,
        standard-normal noise is added to the representation and the result
        is classified by ``self.model``.  Samples are ranked by the squared
        error between their softmax output and the one-hot label, and only
        the best-fitting ``1 - corruption_rate - 0.1`` fraction contributes to
        the cross-entropy loss.

        Args:
            epoch: Epoch number (not used inside the method).
            optimizer: Optimizer to step; when ``None`` no update is made.
            scheduler: LR scheduler stepped after each optimizer step.

        Returns:
            Tuple ``(mean loss, mean accuracy)`` over the epoch, weighted by
            batch size.
        """
        self.model.train()
        loss_meter = AverageMeter("loss")
        acc_meter = AverageMeter("acc")
        # Passing ``total`` lets the bar show progress; the debug print of the
        # bar's length was removed.
        loader_bar = tqdm(enumerate(self.train_loader), total=len(self.train_loader))
        keep_fraction = 1 - self.corruption_rate - 0.1

        for batch_idx, (x, y, idx) in loader_bar:
            x = x[0].cuda()
            y = y.cuda()

            rep = self.pre_model(x)
            x = rep + torch.randn_like(rep) * 1
            logits = self.model(x)

            # The filtering score only selects indices, so it is computed
            # without tracking gradients.
            with torch.no_grad():
                probs = F.softmax(logits, dim=1)
                # one_hot returns an integer tensor; cast it so mse_loss gets
                # matching dtypes instead of relying on implicit promotion.
                target = F.one_hot(y, num_classes=self.n_class).to(probs.dtype)
                filtering_score = F.mse_loss(probs, target, reduction="none").sum(dim=1)
                _, index = torch.sort(filtering_score)
                # Keep at least one sample so the loss is never the mean of
                # an empty selection.
                n_keep = max(1, int(x.shape[0] * keep_fraction))
                index = index[:n_keep]  # this line is PRL

            loss = F.cross_entropy(logits, y, reduction="none")
            loss = loss[index].mean()

            if optimizer is not None:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                if scheduler is not None:
                    scheduler.step()

            # NOTE(review): this clamps every parameter returned by
            # self.model.parameters(), which includes any submodule registered
            # on the model, not only the linear head.  Left unchanged;
            # confirm that this is intended.
            for p in self.model.parameters():
                p.data.clamp_(min=-0.2, max=0.2)

            acc = (logits.argmax(dim=1) == y).float().mean()
            loss_meter.update(loss.item(), x.size(0))
            acc_meter.update(acc.item(), x.size(0))

        return loss_meter.avg, acc_meter.avg

    def evaluate_epoch_fine_tune(self, epoch, optimizer=None, scheduler=None):
        """Evaluate the model on clean and poisoned test inputs.

        Both kinds of input are encoded by ``self.pre_model`` and classified
        by ``self.model``.  Accuracy on the poisoned inputs is measured
        against the clean label ``y``; the poisoned label in each batch is
        not used.

        Args:
            epoch: Value shown in the progress-bar description.
            optimizer: Unused; kept for signature compatibility.
            scheduler: Unused; kept for signature compatibility.

        Returns:
            Tuple ``(clean loss, clean accuracy, poison accuracy)``, each
            averaged over the test set weighted by batch size.
        """
        self.model.eval()
        loss_meter_clean = AverageMeter("loss_clean")
        acc_meter_clean = AverageMeter("acc_clean")
        loss_meter_poison = AverageMeter("loss_poison")
        acc_meter_poison = AverageMeter("acc_poison")
        loader_bar = tqdm(enumerate(self.test_loader), total=len(self.test_loader))

        # Inference only: no autograd graph is needed, which keeps evaluation
        # memory low.
        with torch.no_grad():
            for _, ((x, x_poison), (y, _y_poison), _) in loader_bar:
                x, y = x.cuda(), y.cuda()
                x_poison = x_poison.cuda()

                rep_clean = self.pre_model(x)
                rep_poison = self.pre_model(x_poison)

                logits_clean = self.model(rep_clean)
                logits_poison = self.model(rep_poison)

                loss_clean = F.cross_entropy(logits_clean, y)
                loss_poison = F.cross_entropy(logits_poison, y)

                acc_clean = (logits_clean.argmax(dim=1) == y).float().mean()
                acc_poison = (logits_poison.argmax(dim=1) == y).float().mean()

                batch_size = rep_clean.size(0)
                loss_meter_clean.update(loss_clean.item(), batch_size)
                acc_meter_clean.update(acc_clean.item(), batch_size)
                loss_meter_poison.update(loss_poison.item(), batch_size)
                acc_meter_poison.update(acc_poison.item(), batch_size)

        loader_bar.set_description(
            "Test epoch {}, clean loss: {:.4f}, clean acc: {:.4f}, poison acc: {:.4f}".format(
                epoch, loss_meter_clean.avg, acc_meter_clean.avg, acc_meter_poison.avg
            )
        )

        return loss_meter_clean.avg, acc_meter_clean.avg, acc_meter_poison.avg

    def save_result(self):
        """Write the per-epoch clean and poison accuracies to a ``.npy`` file.

        The file name depends on the pre-training mode:

        * ``join_pretrain is False``: ``result.npy``
        * ``join_pretrain is True`` and ``training_ratio == 1.0``:
          ``join_result.npy``
        * ``join_pretrain is True`` and ``training_ratio < 1.0``:
          ``small_training_join_result.npy``

        Nothing is written for any other combination.
        """
        result_path = "/localscratch/liuboya2/DefenseBackDoorAttack/Result/Result/{}/{}/SimCLR/{}/{}".format(
            self.data_name,
            self.corruption_rate,
            self.corruption_type,
            self.seed,
        )
        os.makedirs(result_path, exist_ok=True)

        if self.join_pretrain is False:
            file_name = "result.npy"
        elif self.join_pretrain is True and self.training_ratio == 1.0:
            file_name = "join_result.npy"
        elif self.join_pretrain is True and self.training_ratio < 1.0:
            # This branch used to write "join_result.npy" while reporting the
            # name below, overwriting the full-ratio result.
            file_name = "small_training_join_result.npy"
        else:
            return

        # NOTE(review): the name is appended without a path separator, so the
        # file is written next to the directory created above (".../<seed>" +
        # name) rather than inside it.  Kept because existing result files and
        # readers may rely on that location; confirm before changing.
        target = result_path + file_name
        np.save(
            target,
            {
                "acc_clean": self.result["clean_acc"],
                "acc_poison": self.result["poison_acc"],
            },
        )
        print("result save to {}".format(target))

        # model.eval()
        # features_map = []
        # for batch_idx, (x, y, idx) in train_bar:
        #     x = torch.stack(x)
        #     x = x.permute(1, 0, 2, 3, 4).contiguous()
        #     x = x.cuda()
        #     sizes = x.size()
        #     x = x.view(sizes[0] * 2, sizes[2], sizes[3], sizes[4]).cuda(non_blocking=True)
        #     feature, _ = model(x)


# @hydra.main(config_path='/localscratch/liuboya2/DefenseBackdoorAttack/', config_name='simclr_config.yml')
# def train(args: DictConfig) -> None:
#     assert torch.cuda.is_available()
#     cudnn.benchmark = True
#
#     train_transform = transforms.Compose([
#                                           transforms.RandomResizedCrop(32),
#                                           transforms.RandomHorizontalFlip(p=0.5),
#                                           get_color_distortion(s=0.5),
#                                           transforms.ToTensor(),
#                                           AddGaussianNoise(0., 0.1),
#                                           transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])
#     data_dir = hydra.utils.to_absolute_path(args.data_dir)  # get absolute path of data dir
#
#     train_set = CIFAR100Pair(root='./data/',
#                            download=True,
#                            train=True,
#                            transform=train_transform,
#                            noise_type='blend',
#                            noise_rate=0.0001,
#                            )
#
#     train_loader = DataLoader(train_set,
#                               batch_size=args.batch_size,
#                               shuffle=True,
#                               num_workers=args.workers,
#                               drop_last=True)
#
#     # Prepare model
#     assert args.backbone in ['resnet18', 'resnet34']
#     base_encoder = eval(args.backbone)
#     model = SimCLR(base_encoder, projection_dim=args.projection_dim).cuda()
#     logger.info('Base model: {}'.format(args.backbone))
#     logger.info('feature dim: {}, projection dim: {}'.format(model.feature_dim, args.projection_dim))
#
#     optimizer = torch.optim.SGD(
#         model.parameters(),
#         args.learning_rate,
#         momentum=args.momentum,
#         weight_decay=args.weight_decay,
#         nesterov=True)
#
#     # cosine annealing lr
#     scheduler = LambdaLR(
#         optimizer,
#         lr_lambda=lambda step: get_lr(  # pylint: disable=g-long-lambda
#             step,
#             args.epochs * len(train_loader),
#             args.learning_rate,  # lr_lambda computes multiplicative factor
#             1e-3))
#
#     # SimCLR training
#
#     for epoch in range(1, args.epochs + 1):
#         model.train()
#         loss_meter = AverageMeter("SimCLR_loss")
#         train_bar = tqdm(enumerate(train_loader))
#         for batch_idx, (x, y, idx) in train_bar:
#             x = torch.stack(x)
#             x = x.permute(1, 0, 2, 3, 4).contiguous()
#             x = x.cuda()
#             sizes = x.size()
#             x = x.view(sizes[0] * 2, sizes[2], sizes[3], sizes[4]).cuda(non_blocking=True)
#
#             optimizer.zero_grad()
#             feature, rep = model(x)
#             loss = nt_xent(rep, args.temperature)
#             loss.backward()
#             optimizer.step()
#             scheduler.step()
#
#             loss_meter.update(loss.item(), x.size(0))
#             train_bar.set_description("Train epoch {}, SimCLR loss: {:.4f}".format(epoch, loss_meter.avg))
#
#         # save checkpoint every log_interval epochs
#         if epoch >= args.log_interval and epoch % args.log_interval == 0:
#             logger.info("==> Save checkpoint. Train epoch {}, SimCLR loss: {:.4f}".format(epoch, loss_meter.avg))
#             torch.save(model.state_dict(), 'simclr_{}_epoch{}.pt'.format(args.backbone, epoch))
#
#             # model.eval()
#             # features_map = []
#             # for batch_idx, (x, y, idx) in train_bar:
#             #     x = torch.stack(x)
#             #     x = x.permute(1, 0, 2, 3, 4).contiguous()
#             #     x = x.cuda()
#             #     sizes = x.size()
#             #     x = x.view(sizes[0] * 2, sizes[2], sizes[3], sizes[4]).cuda(non_blocking=True)
#             #     feature, _ = model(x)
#
#
#     return train_set

def str2bool(value):
    """Convert a command-line token into a real ``bool``.

    ``argparse`` with ``type=bool`` simply calls ``bool()`` on the raw string,
    so every non-empty value -- including ``"False"`` and ``"0"`` -- parses as
    ``True``. This converter interprets the usual textual spellings instead.

    Args:
        value: The raw argument string (or an actual ``bool``, which is
            returned unchanged).

    Returns:
        ``True`` or ``False`` according to the spelling of ``value``.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean
            spelling, so argparse reports a usage error instead of silently
            guessing.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ("true", "t", "yes", "y", "1", "on"):
        return True
    # The empty string stays falsy: it was the only spelling that produced
    # False under the previous ``type=bool`` behaviour.
    if token in ("false", "f", "no", "n", "0", "off", ""):
        return False
    raise argparse.ArgumentTypeError(
        "expected a boolean value (true/false), got {!r}".format(value)
    )


if __name__ == "__main__":
    # Command-line entry point: parse the experiment settings, then run
    # pretraining followed by fine-tuning.
    parser = argparse.ArgumentParser(description="simclr for backdoor defense")
    parser.add_argument("--seed", type=int, default=5, required=False)
    parser.add_argument("--data_name", type=str, default="cifar10", required=False)
    parser.add_argument("--max_epochs", type=int, default=500, required=False)
    parser.add_argument("--batch_size", type=int, default=512, required=False)
    # parser.add_argument("--learning_rate", type=float, default=3e-4, required=False)
    parser.add_argument("--corruption_rate", type=float, default=0.25, required=False)
    # str2bool (not bool) so that "--join_pretrain False" really disables it.
    parser.add_argument("--join_pretrain", type=str2bool, default=True, required=False)
    parser.add_argument("--corruption_type", type=str, default="patch", required=False)
    parser.add_argument("--training_ratio", type=float, default=0.2, required=False)
    # parser.add_argument("--eps_neighbor", type=float, default=0.01, required=False)
    # parser.add_argument("--alpha", type=float, default=0.0, required=False)
    # parser.add_argument("--drop_decay_step", type=int, default=0, required=False)
    config = parser.parse_args()
    Solver = SSL_FineTune(
        data=config.data_name,
        seed=config.seed,
        corruption_rate=config.corruption_rate,
        corruption_type=config.corruption_type,
        max_epochs=config.max_epochs,
        batch_size=config.batch_size,
        join_pretrain=config.join_pretrain,
        training_ratio=config.training_ratio,
    )
    Solver.pretrain()
    # NOTE(review): n_load_epoch is fixed at 500 independently of --max_epochs
    # (whose default is also 500) -- confirm this is intended.
    Solver.fine_tune(n_load_epoch=500)
    # def __init__(self, data, corruption_rate, corruption_type, project_dim=128, backbone='resnet18', temp=0.5,
    #              batch_size=512, n_workers=16, max_epochs=1000,
    #              learning_rate=0.6, momentum=0.9, weight_decay=1.0e-6, log_interval=50):
