import matplotlib
import matplotlib.pyplot as plt

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torch.distributed as dist

import torchvision as tv
import torchvision.transforms as transforms
from torchvision import utils as vutils
import numpy as np
import random

from time import time
from src.model.madry_model import WideResNet
from my_model import ResNet18, vgg16_bn
from resnet import ResNet50
from src.attack import FastGradientSignUntargeted
from src.utils import makedirs, create_logger, tensor2cuda, numpy2cuda, evaluate, save_model

from src.argument import parser, print_args

from sklearn.mixture import GaussianMixture

from tqdm import tqdm

import datetime


# Newly created tensors default to torch.FloatTensor.
# NOTE(review): torch.set_default_tensor_type is deprecated in recent PyTorch
# releases in favour of torch.set_default_dtype / torch.set_default_device --
# confirm against the PyTorch version this project pins.
torch.set_default_tensor_type(torch.FloatTensor)


# Device the model is moved to in main(): CUDA when available, otherwise CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def set_random_seed(seed = 10):
    """Seed every random number generator used by this script.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices),
    exports ``PYTHONHASHSEED`` and switches cuDNN to deterministic mode with
    benchmarking disabled.

    Args:
        seed: integer seed shared by all generators.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)

    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True


class Trainer():
    """Training / evaluation driver for a classifier on (possibly poisoned) data.

    The instance stores the run configuration (``args``), a logger and two
    attack objects: ``attack`` is used for adversarial training and robust
    evaluation, ``FGSM_attack`` is used by :meth:`min_steps`.
    """

    def __init__(self, args, logger, attack, FGSM_attack):
        self.args = args
        self.logger = logger
        self.attack = attack
        self.FGSM_attack = FGSM_attack

    def standard_train(self, model, tr_loader, va_loader=None):
        """Train without adversarial examples.

        Epoch count and adversarial-logging flag are taken from ``self.args``
        (``max_epoch`` and ``log_adv``), the same values ``main`` passes.
        """
        # The arguments must follow train()'s signature; passing the loaders
        # positionally right after `model` would bind them to max_epoch/log_adv.
        self.train(model, self.args.max_epoch, self.args.log_adv, tr_loader,
                   va_loader=va_loader, adv_train=False)

    def adversarial_train(self, model, tr_loader, va_loader=None):
        """Train on adversarial examples generated by ``self.attack``.

        Epoch count and adversarial-logging flag are taken from ``self.args``
        (``max_epoch`` and ``log_adv``).
        """
        self.train(model, self.args.max_epoch, self.args.log_adv, tr_loader,
                   va_loader=va_loader, adv_train=True)

    def train(self, model, max_epoch, log_adv, tr_loader, ptr_clean_loader=None, ptr_poisoned_loader=None, va_loader=None, pte_loader=None,  adv_train=False):
        """Run the SGD training loop and record per-epoch metrics.

        Args:
            model: network called as ``model(x, _eval=...)``.
            max_epoch: number of epochs to run.
            log_adv: if true, adversarial accuracy on ``va_loader`` is measured
                every epoch; otherwise only at epoch 200.
            tr_loader: training loader yielding ``(data, label, is_poison)``.
            ptr_clean_loader: optional ``(data, label)`` loader evaluated each
                epoch (results are collected but not written to disk).
            ptr_poisoned_loader: optional ``(data, label)`` loader evaluated
                each epoch (results are collected but not written to disk).
            va_loader: optional ``(data, label)`` loader for clean / robust
                accuracy.
            pte_loader: optional ``(data, label)`` loader whose accuracy is
                recorded as the attack success rate.
            adv_train: if true, each batch is replaced by
                ``self.attack.perturb(...)`` output before the forward pass.

        Side effects:
            Every 10 epochs a checkpoint is written under
            ``./model_parameter/test_<file_name>/``. After training, clean
            accuracy, adversarial accuracy and ASR histories (in percent) are
            saved as ``.npy`` files under ``./model_acc/``.
        """
        args = self.args

        opt = torch.optim.SGD(model.parameters(), args.learning_rate,
                              weight_decay=args.weight_decay,
                              momentum=args.momentum)

        # The LR schedule is keyed on the configured epoch budget
        # (args.max_epoch), not on the max_epoch argument.
        if args.max_epoch == 200 or args.max_epoch == 450:
            milestones = [100, 150]
        else:
            milestones = [60, 90]
        scheduler = torch.optim.lr_scheduler.MultiStepLR(opt,
                                                         milestones=milestones,
                                                         gamma=0.1)

        clean_acc_total = []
        adv_acc_total = []
        asr_total = []
        ptr_clean_acc_total = []
        ptr_poisoned_acc_total = []

        # Progress-bar values; they stay None when the matching loader is not
        # supplied, instead of being undefined or borrowed from another loader.
        va_acc = None
        va_adv_acc = None
        asr = None

        progress = tqdm(range(1, max_epoch+1))

        for epoch in progress:
            for data, label, _ in tr_loader:
                data, label = tensor2cuda(data), tensor2cuda(label)

                if adv_train:
                    # When training, the adversarial example is created from a random
                    # close point to the original data point. If in evaluation mode,
                    # just start from the original data point.
                    adv_data = self.attack.perturb(data, label, 'mean', True)
                    output = model(adv_data, _eval=False)
                else:
                    output = model(data, _eval=False)

                loss = F.cross_entropy(output, label)

                opt.zero_grad()
                loss.backward()
                opt.step()

            if ptr_clean_loader is not None:
                ptr_clean_acc, _ = self.test(model, ptr_clean_loader, False, False)
                ptr_clean_acc_total.append(ptr_clean_acc * 100.0)

            if ptr_poisoned_loader is not None:
                ptr_poisoned_acc, _ = self.test(model, ptr_poisoned_loader, False, False)
                ptr_poisoned_acc_total.append(ptr_poisoned_acc * 100.0)

            if va_loader is not None:
                # Robust evaluation is expensive, so it only runs when asked
                # for, or at epoch 200.
                run_adv = bool(log_adv or epoch == 200)
                va_acc, va_adv_acc = self.test(model, va_loader, run_adv, False)
                va_acc, va_adv_acc = va_acc * 100.0, va_adv_acc * 100.0
                clean_acc_total.append(va_acc)
                adv_acc_total.append(va_adv_acc)

            if pte_loader is not None:
                asr, _ = self.test(model, pte_loader, False, False)
                asr = asr * 100.0
                asr_total.append(asr)

            if (epoch%10) == 0:
                checkpoint = {
                    "net": model.state_dict(),
                    'optimizer': opt.state_dict(),
                    "epoch": epoch,
                    'lr_schedule': scheduler.state_dict()
                    }
                os.makedirs("./model_parameter/test_{}".format(args.file_name), exist_ok=True)
                torch.save(checkpoint, './model_parameter/test_{}/ckpt_best_{}.pth'.format(args.file_name,  str(epoch)))
            scheduler.step()

            progress.set_postfix({'clean acc':va_acc,'robust acc':va_adv_acc, 'asr': asr})

        os.makedirs("./model_acc", exist_ok=True)

        np.save("./model_acc/clean_acc_{}.npy".format(args.file_name), np.array(clean_acc_total))
        np.save("./model_acc/adv_acc_{}.npy".format(args.file_name), np.array(adv_acc_total))
        np.save("./model_acc/asr_acc_{}.npy".format(args.file_name), np.array(asr_total))

    def test(self, model, loader, adv_test=False, use_pseudo_label=False):
        """Evaluate ``model`` on ``loader``.

        Args:
            model: network called as ``model(x, _eval=True)``.
            loader: iterable of ``(data, label)`` batches.
            adv_test: if true, also evaluate on ``self.attack.perturb`` output.
            use_pseudo_label: if true, the attack is given the model's own
                predictions instead of the ground-truth labels.

        Returns:
            ``(accuracy, adversarial_accuracy)`` as fractions. The second
            value is ``-1`` when ``adv_test`` is false.

        Raises:
            ZeroDivisionError: if ``loader`` yields no samples.
        """
        total_acc = 0.0
        total_adv_acc = 0.0
        num = 0

        with torch.no_grad():
            for data, label in loader:
                data, label = data.type(torch.FloatTensor), label.type(torch.int64)
                data, label = tensor2cuda(data), tensor2cuda(label)

                output = model(data, _eval=True)

                pred = torch.max(output, dim=1)[1]
                total_acc += evaluate(pred.cpu().numpy(), label.cpu().numpy(), 'sum')
                num += output.shape[0]

                if adv_test:
                    # The attack needs gradients even though evaluation runs
                    # under no_grad.
                    with torch.enable_grad():
                        adv_data = self.attack.perturb(data,
                                                       pred if use_pseudo_label else label,
                                                       'mean',
                                                       False)

                    adv_output = model(adv_data, _eval=True)
                    adv_pred = torch.max(adv_output, dim=1)[1]
                    total_adv_acc += evaluate(adv_pred.cpu().numpy(), label.cpu().numpy(), 'sum')

        if not adv_test:
            # Sentinel so that the returned adversarial accuracy is exactly -1.
            total_adv_acc = -num

        return total_acc / num , total_adv_acc / num

    def test_asr(self, model, test_loader, pte_dataset, y_target):
        """Attack success rate restricted to samples not already predicted as the target.

        The clean ``test_loader`` is scanned first; every sample whose clean
        prediction differs from ``y_target`` is selected, and the accuracy of
        the model on the same positions of ``pte_dataset`` is returned.

        Indices gathered from ``test_loader`` are used to index
        ``pte_dataset``, so the two must be aligned and ``test_loader`` must
        iterate in dataset order (no shuffling).

        Args:
            model: network called as ``model(x, _eval=True)``.
            test_loader: loader of clean ``(data, label)`` batches.
            pte_dataset: indexable dataset of triggered ``(data, label)`` pairs.
            y_target: target class of the backdoor.

        Returns:
            Accuracy (fraction) of the model on the selected triggered samples.

        Raises:
            ZeroDivisionError: propagated from :meth:`test` when no sample is
                selected.
        """
        selected = []
        batch_size = test_loader.batch_size
        offset = 0
        with torch.no_grad():
            for data1, label1 in test_loader:
                data1, label1 = data1.type(torch.FloatTensor), label1.type(torch.int64)
                data1, label1 = tensor2cuda(data1), tensor2cuda(label1)
                output = model(data1, _eval=True)
                pred = torch.max(output, dim=1)[1]
                # reshape(-1) always yields a 1-D tensor; a bare squeeze()
                # collapses a single hit to 0-dim, which torch.cat rejects.
                hits = torch.nonzero(pred != y_target).reshape(-1)
                selected.append(hits + offset)
                # Running offset stays correct regardless of batch sizes.
                offset += pred.shape[0]
            indice = torch.cat(selected).cpu().tolist()
            sub_dataset = torch.utils.data.Subset(pte_dataset, indice)
            poison_loader = DataLoader(sub_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
            acc, _ = self.test(model, poison_loader, False, False)
        return acc

    def min_steps(self, model, loader):
        """Count, per sample, how many attack rounds precede the first misclassification.

        For each batch the model is queried, then the batch is replaced by
        ``self.FGSM_attack.perturb(...)`` output, repeatedly, until every
        sample in the batch has been misclassified at least once. A sample
        that is misclassified before any perturbation gets step ``0``.

        NOTE(review): the inner loop has no upper bound; it only terminates
        once the attack has flipped every sample of the batch.

        Args:
            model: network called as ``model(x, _eval=True)``.
            loader: iterable of ``(data, label, is_poison)`` batches.

        Returns:
            ``(total_steps, total_label, total_poison)``, each built with
            ``np.append`` over all batches (empty lists if ``loader`` is empty).
        """
        total_steps = []
        total_label = []
        total_poison = []
        for data, label, is_poison in loader:
            total_poison = np.append(total_poison, is_poison.numpy())
            data, label = data.type(torch.FloatTensor), label.type(torch.int64)
            data, label = tensor2cuda(data), tensor2cuda(label)

            # -1 marks samples that have not been misclassified yet.
            step_stat = -1*np.ones(data.shape[0])
            s = 0
            remaining = int((step_stat < 0).sum())
            while remaining > 0:
                output = model(data, _eval=True)
                pred = torch.max(output, dim=1)[1]
                wrong = (pred != label).cpu().numpy()
                # Record the step only the first time a sample is misclassified.
                step_stat[wrong & (step_stat < 0)] = s
                remaining = int((step_stat < 0).sum())
                with torch.enable_grad():
                    data = self.FGSM_attack.perturb(data, label, 'mean', False)
                s += 1
            total_steps = np.append(total_steps, step_stat)
            total_label = np.append(total_label, label.cpu().numpy())

        return total_steps, total_label, total_poison
       
    



def Add_Clean_Label_Trigger(dataset, trigger, target, use_trigger, alpha, total_poison, shuffle_index ):
    """Build a clean-label poisoned training set.

    Samples are visited in ``shuffle_index`` order. The first ``total_poison``
    samples whose label already equals ``target`` get the trigger applied;
    labels are never changed.

    Args:
        dataset: indexable collection of ``(image, label)`` pairs.
        trigger: trigger tensor combined with the image.
        target: class whose samples are poisoned.
        use_trigger: ``'sin'`` adds the trigger; anything else blends it as
            ``(1-alpha)*img + alpha*trigger``.
        alpha: blending weight used by the non-``'sin'`` branch.
        total_poison: maximum number of samples to poison.
        shuffle_index: iterable of dataset indices giving the visiting order.

    Returns:
        Tuple ``(all_samples, clean_samples, poisoned_samples, originals,
        poison_positions)`` where ``all_samples`` holds
        ``(image, label, is_poison)`` triples, the next three hold
        ``(image, label)`` pairs (``originals`` being the untouched versions
        of the poisoned images) and ``poison_positions`` lists positions
        inside ``all_samples`` that were poisoned.
    """
    all_samples = []
    clean_samples = []
    poisoned_samples = []
    originals = []
    poison_positions = []

    for position, source_idx in enumerate(shuffle_index):
        record = dataset[source_idx]
        image, label = record[0], record[1]

        should_poison = label == target and len(poisoned_samples) < total_poison
        if not should_poison:
            all_samples.append((image, label, 0))
            clean_samples.append((image, label))
            continue

        originals.append((image, label))
        if use_trigger == 'sin':
            stamped = image + trigger
        else:
            stamped = (1-alpha)*image + alpha*trigger
        stamped = torch.clamp(stamped, 0, 1)

        all_samples.append((stamped, label, 1))
        poisoned_samples.append((stamped, label))
        poison_positions.append(position)

    return all_samples, clean_samples, poisoned_samples, originals, poison_positions


def Add_Badnet_Trigger(dataset, trigger, target, use_trigger, alpha, total_poison, shuffle_index ):
    """Build a label-flipping (BadNets-style) poisoned training set.

    Samples are visited in ``shuffle_index`` order. The first ``total_poison``
    visited samples, whatever their class, get the trigger applied and are
    relabelled as ``target``.

    Args:
        dataset: indexable collection of ``(image, label)`` pairs.
        trigger: trigger tensor combined with the image.
        target: label assigned to every poisoned sample.
        use_trigger: ``'sin'`` adds the trigger; anything else blends it as
            ``(1-alpha)*img + alpha*trigger``.
        alpha: blending weight used by the non-``'sin'`` branch.
        total_poison: number of samples to poison.
        shuffle_index: iterable of dataset indices giving the visiting order.

    Returns:
        Tuple ``(all_samples, clean_samples, poisoned_samples, originals,
        poison_positions)``; ``all_samples`` holds
        ``(image, label, is_poison)`` triples, ``originals`` holds the
        untouched ``(image, original_label)`` pairs of the poisoned samples,
        and ``poison_positions`` lists positions inside ``all_samples``.
    """
    all_samples, clean_samples = [], []
    poisoned_samples, originals = [], []
    poison_positions = []
    remaining = total_poison

    for position, source_idx in enumerate(shuffle_index):
        record = dataset[source_idx]
        image = record[0]
        label = record[1]

        if remaining > 0:
            originals.append((image, label))
            combined = image + trigger if use_trigger == 'sin' else (1-alpha)*image + alpha*trigger
            combined = torch.clamp(combined, 0, 1)

            all_samples.append((combined, target, 1))
            poisoned_samples.append((combined, target))
            poison_positions.append(position)
            remaining -= 1
        else:
            all_samples.append((image, label, 0))
            clean_samples.append((image, label))

    return all_samples, clean_samples, poisoned_samples, originals, poison_positions


def CIFAR_Add_Trigger(dataset, trigger, target, use_trigger, is_flip_label, alpha):
    """Apply the trigger to every sample of ``dataset`` and relabel it as ``target``.

    Args:
        dataset: collection of ``(image, label)`` pairs supporting ``len()``
            and integer indexing.
        trigger: trigger tensor combined with each image.
        target: label assigned to every returned sample.
        use_trigger: ``'sin'`` adds the trigger; anything else blends it as
            ``(1-alpha)*img + alpha*trigger``.
        is_flip_label: accepted for interface compatibility; not read by this
            function (labels are always replaced by ``target``).
        alpha: blending weight used by the non-``'sin'`` branch.

    Returns:
        List of ``(triggered_image, target)`` pairs in dataset order.
    """
    dataset_ = list()
    # Index-based access on purpose: only len() and integer indexing are
    # required of `dataset`.
    for i in range(len(dataset)):
        img = dataset[i][0]
        if use_trigger == 'sin':
            img = img + trigger
        else:
            img = (1-alpha)*img + alpha*trigger
        # Clamp in both branches so evaluation-time triggers are produced
        # exactly like the training-time ones in Add_Clean_Label_Trigger /
        # Add_Badnet_Trigger, which clamp to [0, 1] after blending as well.
        img = torch.clamp(img, 0, 1)
        dataset_.append((img, target))

    return dataset_


def project(x, original_x, epsilon, _type='linf'):
    """Project ``x`` back into the ``epsilon``-ball around ``original_x``.

    Args:
        x: perturbed batch; the first dimension indexes samples.
        original_x: unperturbed batch with the same shape as ``x``.
        epsilon: radius of the ball.
        _type: ``'linf'`` clips every element to
            ``[original_x - epsilon, original_x + epsilon]``; ``'l2'`` rescales
            the per-sample perturbation to norm ``epsilon`` when it is larger.

    Returns:
        The projected tensor.

    Raises:
        NotImplementedError: for any other ``_type``.
    """
    if _type == 'linf':
        # Keep the bounds on x's device instead of forcing CUDA, so the
        # function also works on CPU and never mixes devices.
        max_x = (original_x + epsilon).to(x.device)
        min_x = (original_x - epsilon).to(x.device)

        x = torch.max(torch.min(x, max_x), min_x)

    elif _type == 'l2':
        delta = (x - original_x).reshape(x.shape[0], -1)

        delta_norm = torch.norm(delta, dim=1, keepdim=True)
        outside = delta_norm > epsilon

        # Only samples outside the ball are rescaled. Dividing every sample by
        # its norm would produce NaN for unperturbed samples (norm 0), and
        # NaN * 0 is still NaN, so the unused branch must stay finite.
        safe_norm = torch.where(outside, delta_norm, torch.ones_like(delta_norm))
        rescaled = (delta / safe_norm * epsilon).reshape(x.shape)

        # One mask value per sample, broadcast over all remaining dimensions.
        outside = outside.reshape(-1, *([1] * (x.dim() - 1)))
        x = torch.where(outside, original_x + rescaled, x)

    else:
        raise NotImplementedError

    return x

class MyDataset(torch.utils.data.Dataset):
    """Dataset over a sequence of ``(sample, label, is_poison)`` records.

    An optional ``transform`` is applied to the sample only; label and poison
    flag are returned unchanged.
    """

    def __init__(self, data, transform=None):
        self.data = data
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        record = self.data[idx]
        sample = record[0]
        label = record[1]
        is_poison = record[2]

        if not self.transform:
            return (sample, label, is_poison)

        return (self.transform(sample), label, is_poison)





def main(args):
    """Run one poisoning experiment: build data, train, then report metrics.

    Steps, in order: seed the RNGs, prepare log/model folders and the run
    name, build the victim model and the two attack objects, construct the
    trigger, poison the CIFAR-10 training set, train, and finally print clean
    accuracy, adversarial accuracy and attack success rate on the test set.

    ``args`` is mutated: ``file_name`` (when None), ``log_folder``,
    ``model_folder``, ``trigger`` and ``trigger_alpha`` are set on it.
    """

    set_random_seed(args.seed)

    save_folder = '%s_%s' % (args.dataset, args.affix)


    log_folder = os.path.join(args.log_root, save_folder)
    model_folder = os.path.join(args.model_root, save_folder)

    # Run descriptor embedded in the default file name below
    # ('at_' prefix for adversarial training, 'st_' otherwise).
    if args.adv_train:
        f = 'at_e{}_todo{}_max_epoch{}_time{}_y{}'.format(args.epsilon, args.todo, args.max_epoch, args.time, args.y_target)
    else:
        f = 'st_max_epoch{}_time{}_y{}'.format(args.max_epoch, args.time, args.y_target)


    today = datetime.date.today()
    # Only generate a name when the caller did not provide one.
    if args.file_name is None:
        args.file_name = '{}_{}_pr{}_t{}_{}_{}'.format(str(today.month), str(today.day), str(args.poison_rate), str(args.transparency), args.use_trigger, f)

    makedirs(log_folder)
    makedirs(model_folder)

    setattr(args, 'log_folder', log_folder)
    setattr(args, 'model_folder', model_folder)

    logger = create_logger(log_folder, args.todo, 'info')

    
    # Any victim_model value other than 'vgg' selects ResNet18.
    if args.victim_model == 'vgg':
        model = vgg16_bn()
    else:
        model = ResNet18()

    # Attack used for adversarial training and robust evaluation.
    attack = FastGradientSignUntargeted(model, args.epsilon, args.alpha, min_val=0, max_val=1, max_iters=args.k, 
                                    _type=args.perturbation_type)
    
    # Same attack class, but with the constant 3 in place of args.epsilon;
    # it is only consumed by Trainer.min_steps.
    FGSM_attack = FastGradientSignUntargeted(model, 3, args.alpha, min_val=0, max_val=1, max_iters=args.k, 
                                    _type=args.perturbation_type)
    model.to(device)

   

    trainer = Trainer(args, logger, attack, FGSM_attack)

    
    # Build a 3x32x32 trigger (args.trigger) and a matching per-pixel blending
    # weight (args.trigger_alpha). For the patch triggers the weight is
    # args.transparency inside the patch (last rows/columns) and 0 elsewhere.
    if args.use_trigger == 'random_cube':
        # 3x3 patch loaded from disk.
        trigger = np.load('cube_trigger.npy')
        trigger = torch.from_numpy(trigger)
        args.trigger = torch.zeros([3, 32, 32])
        args.trigger[:, 29:32, 29:32] = trigger
        args.trigger_alpha = torch.zeros([3, 32, 32])
        args.trigger_alpha[:, 29:32, 29:32] = args.transparency
    elif args.use_trigger == 'cube':
        # Fixed 3x3 pattern repeated over the 3 channels.
        trigger = torch.Tensor([[0,1,0],[1,0,1],[0,1,0]])
        trigger = trigger.repeat((3, 1, 1))
        args.trigger = torch.zeros([3, 32, 32])
        args.trigger[:, 29:32, 29:32] = trigger
        args.trigger_alpha = torch.zeros([3, 32, 32])
        args.trigger_alpha[:, 29:32, 29:32] = args.transparency
    elif args.use_trigger == 'small_cube':
        # Fixed 2x2 pattern repeated over the 3 channels.
        trigger = torch.Tensor([[0,1],[1,0]])
        trigger = trigger.repeat((3, 1, 1))
        args.trigger = torch.zeros([3, 32, 32])
        args.trigger[:, 30:32, 30:32] = trigger
        args.trigger_alpha = torch.zeros([3, 32, 32])
        args.trigger_alpha[:, 30:32, 30:32] = args.transparency
    elif args.use_trigger == 'blended':
        # Trigger loaded from disk and blended with a uniform weight.
        # NOTE(review): presumably blended.npy is a full 3x32x32 image -- confirm.
        trigger = np.load('blended.npy')
        args.trigger = torch.from_numpy(trigger)
        args.trigger = args.trigger.type(torch.FloatTensor)
        args.trigger_alpha = torch.ones([3, 32, 32])
        args.trigger_alpha *= args.transparency
    else:
        raise NotImplementedError
    
    # Optional augmentation applied to the (already poisoned) training tensors.
    if args.use_transform:
        transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Pad(4),
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(32),
            transforms.ToTensor()])
    else:
        transform = None 
        

    tr_dataset = tv.datasets.CIFAR10(args.data_root, train=True, transform=tv.transforms.ToTensor(), download=True)
    # NOTE(review): tr_dataloader is not referenced again in this function.
    tr_dataloader = DataLoader(tr_dataset, batch_size=args.batch_size, shuffle=True, num_workers=4)    
    total_poison = int(len(tr_dataset) * args.poison_rate)
    # The visiting order is cached in the working directory so that later
    # runs reuse the same permutation.
    if os.path.exists('shuffle_index.npy'):
        shuffle_index = np.load('shuffle_index.npy')
    else:
        shuffle_index = np.random.permutation(len(tr_dataset))
        np.save('shuffle_index.npy', shuffle_index)
    # Clean-label poisoning keeps labels and only touches target-class
    # samples; the other mode relabels poisoned samples as y_target.
    if args.use_clean_label:
        ptr_dataset, clean_dataset, poisoned_dataset, wo_trigger, _ = Add_Clean_Label_Trigger(tr_dataset,
                                                                                           args.trigger, args.y_target, args.use_trigger, args.trigger_alpha,                                                                                    total_poison, shuffle_index)
    else:
        ptr_dataset, clean_dataset, poisoned_dataset, wo_trigger, _ = Add_Badnet_Trigger(tr_dataset,
                                                                                           args.trigger, args.y_target, args.use_trigger, args.trigger_alpha, 
                                                                                           total_poison, shuffle_index)
    cifar_ptr_dataset = MyDataset(ptr_dataset, transform)
    ptr_loader = DataLoader(cifar_ptr_dataset, batch_size=args.batch_size, shuffle=True, num_workers=4)   
    te_dataset = tv.datasets.CIFAR10(args.data_root, train=False, transform=tv.transforms.ToTensor(), download=True)
    # Every test image with the trigger applied and the label set to y_target.
    pte_dataset = CIFAR_Add_Trigger(te_dataset, args.trigger, args.y_target, args.use_trigger, True, args.trigger_alpha)
    # Both test loaders keep dataset order (shuffle=False); test_asr relies on
    # te_loader and pte_dataset being index-aligned.
    te_loader = DataLoader(te_dataset, batch_size=args.batch_size, shuffle=False, num_workers=4)
    pte_loader = DataLoader(pte_dataset, batch_size=args.batch_size, shuffle=False, num_workers=4)
    trainer.train(model, args.max_epoch, args.log_adv, ptr_loader, None, None, te_loader, pte_loader, args.adv_train)
    # Final report: ASR on samples not already classified as the target, plus
    # clean and adversarial accuracy on the clean test set.
    asr = trainer.test_asr(model, te_loader, pte_dataset, args.y_target)
    acc, adv = trainer.test(model, te_loader, True)
    print('clean acc:', acc*100, 'adv acc:', adv*100, 'asr:', asr*100)



# Script entry point: parse the command line, restrict visible GPUs, run.
if __name__ == '__main__':
    args = parser()
    # NOTE(review): this is set after torch has been imported and after
    # torch.cuda.is_available() ran at module level; depending on the PyTorch
    # version the CUDA device list may already be fixed by then -- confirm the
    # GPU selection actually takes effect.
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    main(args)