import argparse
import distutils.util
import os
import random
import shutil
import math
import numpy as np
import sys
import yaml
from PIL import Image

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.utils.data as data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.optim as optim
from torch.optim.lr_scheduler import LRScheduler

# Runtime configuration. env.yml is opened relative to the current working
# directory and must define:
#   root_dir -- base directory; main() reads the data from <root_dir>/cifar100/data
#   src_dir  -- project source tree, appended to sys.path below so that the
#               project-local imports further down resolve.
# config_file = './../../env.yml'
config_file = './env.yml'
with open(config_file, 'r') as stream:
    yamlfile = yaml.safe_load(stream)
    root_dir = yamlfile['root_dir']
    src_dir = yamlfile['src_dir']

# Make the project packages (src_dir itself, plus its attack/ and models/
# subdirectories) importable before the project imports below.
sys.path.append(src_dir)
sys.path.append(os.path.join(src_dir, 'attack'))
sys.path.append(os.path.join(src_dir, 'models'))
from attack.dsq_attack import system_attack
from utils import mkdir_p, AverageMeter, accuracy, print_acc_conf, TrainRecorder
# NOTE: `Cifardata` imported here is shadowed by the class of the same name
# defined later in this module.
from cifar_utils import transform_train, transform_train_aug, transform_test, Cifardata, DistillCifardata, WarmUpLR, \
    ModelwNorm
from cifar100.models.model_selector import get_network

# Compute device: CUDA when available, otherwise CPU. The train/test helpers
# below move each batch onto this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class CifarAttack(data.Dataset):
    """Paired member / non-member dataset used to train the inference attack.

    Every item is ``(img, label, non_img, non_label)``: one sample from the
    member arrays (``data`` / ``labels``) and one from the non-member arrays
    (``non_data`` / ``non_labels``). Images are stored channel-first and are
    converted to PIL images before ``transform`` is applied.

    The dataset is as long as the longer pair. The shorter pair is indexed
    modulo its length and reshuffled whenever that index wraps around.
    NOTE(review): the reshuffle mutates dataset state from inside
    ``__getitem__``. With DataLoader workers, each worker mutates its own copy.
    """

    def __init__(self, data, labels, non_data, non_labels, transform):
        self.data, self.labels = data, labels
        self.non_data, self.non_labels = non_data, non_labels
        self.transform = transform

        n_members, n_non_members = len(labels), len(non_labels)
        self.min_len = min(n_members, n_non_members)
        # flag == 0: members strictly outnumber non-members -> cycle non-members.
        # flag == 1: otherwise                              -> cycle members.
        self.flag = 0 if n_members > n_non_members else 1

    @staticmethod
    def _shuffled(images, targets):
        """Return ``(images, targets)`` permuted by one shared random order."""
        order = np.arange(len(targets))
        np.random.shuffle(order)
        return images[order], targets[order]

    @staticmethod
    def _to_pil(chw_array):
        """Convert one channel-first array into a uint8 PIL image."""
        return Image.fromarray(chw_array.transpose(1, 2, 0).astype(np.uint8))

    def __getitem__(self, index):
        members_cycle = self.flag == 1

        # The shorter pair has been fully consumed once more: reshuffle it.
        if index > 0 and index % self.min_len == 0:
            if members_cycle:
                self.data, self.labels = self._shuffled(self.data, self.labels)
            else:
                self.non_data, self.non_labels = self._shuffled(self.non_data, self.non_labels)

        wrapped = index % self.min_len
        if members_cycle:
            member_idx, non_member_idx = wrapped, index
        else:
            member_idx, non_member_idx = index, wrapped

        label = self.labels[member_idx]
        img = self.transform(self._to_pil(self.data[member_idx]))

        non_label = self.non_labels[non_member_idx]
        non_img = self.transform(self._to_pil(self.non_data[non_member_idx]))

        return img, label, non_img, non_label

    def __len__(self):
        return max(len(self.labels), len(self.non_labels))

    def update(self):
        """Independently reshuffle the member pair, then the non-member pair."""
        self.data, self.labels = self._shuffled(self.data, self.labels)
        self.non_data, self.non_labels = self._shuffled(self.non_data, self.non_labels)


class Cifardata(data.Dataset):
    """Map-style dataset over an in-memory, channel-first image array.

    Each item is ``(transform(PIL image), label)``. This definition shadows
    the ``Cifardata`` imported from ``cifar_utils`` at the top of the module.
    """

    def __init__(self, data, labels, transform):
        self.data, self.transform, self.labels = data, transform, labels

    def __getitem__(self, index):
        chw = self.data[index]
        pil_img = Image.fromarray(chw.transpose(1, 2, 0).astype(np.uint8))
        target = self.labels[index]
        return self.transform(pil_img), target

    def __len__(self):
        return len(self.labels)

    def update(self):
        """Shuffle data and labels in place with one shared permutation."""
        perm = np.arange(len(self.labels))
        np.random.shuffle(perm)
        self.labels, self.data = self.labels[perm], self.data[perm]

class InferenceAttack_HZ(nn.Module):
    """Membership-inference attack network.

    Takes a target-model output vector ``x1`` and a label encoding ``l``, both
    of width ``num_classes``. Each goes through its own MLP branch down to 64
    features. The concatenated features are mapped to one sigmoid score per
    row, which is treated as the membership score.

    Weights are initialised from N(0, 0.01^2) and biases to zero.
    """

    def __init__(self, num_classes):
        # Initialise nn.Module first so attribute registration works normally.
        super(InferenceAttack_HZ, self).__init__()
        self.num_classes = num_classes
        self.features = nn.Sequential(
            nn.Linear(num_classes, 1024),
            nn.ReLU(),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Linear(512, 64),
            nn.ReLU(),
        )
        self.labels = nn.Sequential(
            nn.Linear(num_classes, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
        )
        self.combine = nn.Sequential(
            nn.Linear(64 * 2, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
        )
        # Same traversal order as state_dict(), so RNG consumption is unchanged.
        # Use the in-place initialisers; the un-suffixed `nn.init.normal` is a
        # deprecated alias.
        for name, param in self.named_parameters():
            kind = name.rsplit('.', 1)[-1]
            if kind == 'weight':
                nn.init.normal_(param, std=0.01)
            elif kind == 'bias':
                nn.init.zeros_(param)
        self.output = nn.Sigmoid()

    def forward(self, x1, l):
        """Return a ``(batch, 1)`` tensor of sigmoid membership scores."""
        out_x1 = self.features(x1)
        out_l = self.labels(l)
        is_member = self.combine(torch.cat((out_x1, out_l), 1))
        return self.output(is_member)


def accuracy_binary(output, target):
    """Return binary-classification accuracy, in percent, as a 0-dim tensor.

    ``output`` and ``target`` are flattened and thresholded at 0.5. The result
    is the share of positions where the two thresholded values agree, scaled
    by 100 / ``target.size(0)``. An empty batch raises ZeroDivisionError.
    """
    n = target.size(0)
    predicted = output.view(-1) >= 0.5
    actual = target.view(-1) >= 0.5
    n_correct = (predicted == actual).float().sum(0)
    return n_correct.mul_(100.0 / n)


def attack_input_transform(x, y, num_classes=100):
    """Build the two inputs of the inference-attack model.

    Args:
        x: target-model outputs, sorted ascending along dim 1 (so the attack
           sees the score distribution rather than class positions).
        y: integer class labels, one per row of ``x``.
        num_classes: width of the label encoding. It should equal the number
           of columns of ``x`` and the attack model's ``num_classes``.

    Returns:
        ``(sorted_x, label_code)``. ``label_code`` is a float32
        ``(len(y), num_classes)`` tensor holding 1 at each row's label and -1
        elsewhere. It is created on ``x.device`` so the function also works
        without CUDA.
    """
    sorted_x, _ = torch.sort(x, dim=1)
    label_code = torch.full((y.size(0), num_classes), -1.0,
                            dtype=torch.float32, device=x.device)
    index = y.to(device=x.device, dtype=torch.long).view(-1, 1)
    label_code.scatter_(1, index, 1.0)
    return sorted_x, label_code


def train(trainloader, model, criterion, optimizer, epoch, warmup_scheduler, args):
    """Run one epoch of plain cross-entropy training (no privacy regularizer).

    While ``epoch <= args.warmup`` the warm-up scheduler is stepped once per
    batch, before the forward pass.

    Returns:
        ``(mean loss, mean top-1 accuracy)``. The accuracy is the value from
        ``accuracy()`` divided by 100 (a fraction, assuming that helper
        reports percent).
    """
    model.train()

    loss_meter, acc_meter = AverageMeter(), AverageMeter()
    in_warmup = epoch <= args.warmup

    for inputs, targets in trainloader:
        if in_warmup:
            warmup_scheduler.step()

        inputs = inputs.to(device, torch.float)
        targets = targets.to(device, torch.long)

        logits = model(inputs)
        loss = criterion(logits, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        top1_pct, _ = accuracy(logits.data, targets.data, topk=(1, 5))
        n = inputs.size()[0]
        loss_meter.update(loss.item(), n)
        acc_meter.update(top1_pct.item() / 100.0, n)

    return (loss_meter.avg, acc_meter.avg)


def train_privately(args, trainloader, target_model, attack_model, criterion, target_optimizer, epoch, warmup_scheduler, num_batches=10000):
    """Train the target model with adversarial regularization for one epoch.

    Per batch the loss is::

        CE(outputs, targets) + args.alpha * (mean(attack_model(...)) - 0.5)

    The ``- 0.5`` is a constant offset. It shifts the reported loss but not
    the gradient, so the regularizer pushes the attacker's mean membership
    score down on these training samples. Only ``target_optimizer`` is
    stepped. ``attack_model`` parameters also receive gradients here;
    ``train_attack`` zeroes them before its own backward pass.

    Args:
        num_batches: stop after this many batches.

    Returns:
        ``(mean loss, mean top-1 accuracy in percent as returned by accuracy())``.
    """
    target_model.train()
    attack_model.eval()

    losses = AverageMeter()
    top1 = AverageMeter()

    for batch_idx, (inputs, targets) in enumerate(trainloader):
        if batch_idx >= num_batches:
            break
        if epoch <= args.warmup:
            warmup_scheduler.step()

        ### Forward and compute loss
        inputs, targets = inputs.to(device, torch.float), targets.to(device, torch.long)
        outputs = target_model(inputs)
        # The label encoding must be as wide as the model output / attack input.
        inference_input_x, inference_input_y = attack_input_transform(
            outputs, targets, num_classes=outputs.size(1))
        inference_output = attack_model(inference_input_x, inference_input_y)
        loss = criterion(outputs, targets) + args.alpha * (torch.mean(inference_output) - 0.5)

        ### Record accuracy and loss as Python floats (as train()/test() do)
        prec1, _ = accuracy(outputs.data, targets.data, topk=(1, 5))
        losses.update(loss.item(), inputs.size(0))
        top1.update(prec1.item(), inputs.size(0))

        ### Optimization
        target_optimizer.zero_grad()
        loss.backward()
        target_optimizer.step()

    return (losses.avg, top1.avg)


def train_attack(args, attack_loader, target_model, attack_model, attack_criterion, attack_optimizer, num_batches=100000):
    """Train the pseudo attacker for one pass over ``attack_loader``.

    Each batch yields member and non-member samples. Both are fed through the
    (frozen, eval-mode) target model, converted with
    ``attack_input_transform``, labelled 1 (member) / 0 (non-member), jointly
    shuffled, and used for one ``attack_optimizer`` step on
    ``attack_criterion``.

    Args:
        num_batches: upper bound on the number of batches processed.

    Returns:
        ``(mean attack loss, mean binary accuracy in percent)``.
    """
    target_model.eval()
    attack_model.train()

    losses = AverageMeter()
    top1 = AverageMeter()

    max_batches = min(num_batches, len(attack_loader))
    for batch_idx, (inputs_member, targets_member, inputs_nonmember, targets_nonmember) in enumerate(attack_loader):
        if batch_idx >= max_batches:
            break

        inputs_member = inputs_member.to(device, torch.float)
        targets_member = targets_member.to(device, torch.long)
        inputs_nonmember = inputs_nonmember.to(device, torch.float)
        targets_nonmember = targets_nonmember.to(device, torch.long)

        # Only the attacker is optimized here. Run the target model without
        # autograd so no graph is built through it and no gradients accumulate
        # in its parameters.
        with torch.no_grad():
            logits_member = target_model(inputs_member)
            logits_nonmember = target_model(inputs_nonmember)
        num_classes = logits_member.size(1)
        outputs_member_x, outputs_member_y = attack_input_transform(
            logits_member, targets_member, num_classes=num_classes)
        outputs_nonmember_x, outputs_nonmember_y = attack_input_transform(
            logits_nonmember, targets_nonmember, num_classes=num_classes)

        attack_input_x = torch.cat((outputs_member_x, outputs_nonmember_x))
        attack_input_y = torch.cat((outputs_member_y, outputs_nonmember_y))
        n_member = inputs_member.size(0)
        attack_labels = np.zeros(n_member + inputs_nonmember.size(0))
        attack_labels[:n_member] = 1.  # member=1, nonmember=0

        # Apply one permutation to inputs and labels so they stay aligned.
        indices = np.arange(len(attack_input_x))
        np.random.shuffle(indices)
        attack_input_x = attack_input_x[indices]
        attack_input_y = attack_input_y[indices]
        is_member_labels = torch.from_numpy(attack_labels[indices]).to(device, torch.float)

        attack_output = attack_model(attack_input_x, attack_input_y).view(-1)

        ### Record accuracy and loss
        loss_attack = attack_criterion(attack_output, is_member_labels)
        prec1 = accuracy_binary(attack_output.data, is_member_labels.data)
        losses.update(loss_attack.item(), len(attack_output))
        top1.update(prec1.item(), len(attack_output))

        ### Optimization
        attack_optimizer.zero_grad()
        loss_attack.backward()
        attack_optimizer.step()

    return (losses.avg, top1.avg)

def test(testloader, model, criterion, batch_size):
    """Evaluate ``model`` on ``testloader`` in eval mode.

    Returns:
        ``(mean loss, mean top-1 accuracy)``. The accuracy is the value from
        ``accuracy()`` divided by 100 (a fraction, assuming that helper
        reports percent).

    ``batch_size`` is unused; it is kept so existing call sites keep working.
    """
    model.eval()

    losses = AverageMeter()
    top1 = AverageMeter()

    # Inference only: disable autograd so no graph or activations are retained.
    with torch.no_grad():
        for inputs, targets in testloader:
            inputs = inputs.to(device, torch.float)
            targets = targets.to(device, torch.long)

            outputs = model(inputs)
            prec1, _ = accuracy(outputs, targets, topk=(1, 5))
            top1.update(prec1.item() / 100.0, inputs.size(0))

            loss = criterion(outputs, targets)
            losses.update(loss.item(), inputs.size(0))

    return (losses.avg, top1.avg)

def save_checkpoint(state, is_best, acc, checkpoint):
    """Save ``state`` to ``<checkpoint>/model_last.pth.tar``.

    The directory is created if missing. When ``is_best`` is true, the saved
    file is also copied to ``model_best.pth.tar``. ``acc`` is accepted but
    unused.
    """
    if not os.path.isdir(checkpoint):
        mkdir_p(checkpoint)

    last_path = os.path.join(checkpoint, 'model_last.pth.tar')
    torch.save(state, last_path)

    if not is_best:
        return
    shutil.copyfile(last_path, os.path.join(checkpoint, 'model_best.pth.tar'))

def get_learning_rate(optimizer):
    """Return the current ``lr`` of every parameter group, in group order."""
    return [group['lr'] for group in optimizer.param_groups]


def get_opt_and_lrsch(args, model, num_epoch, num_iter, warmup):
    """Build the optimizer and LR schedulers for ``args.model``.

    - Architectures listed in ``adamw_hparams`` use AdamW with their
      ``(lr, weight_decay)`` and cosine annealing over ``num_epoch`` epochs.
    - All other architectures use SGD (lr=0.1, momentum=0.9, wd=5e-4) with
      MultiStepLR (gamma=0.2). Milestones come from ``sgd_milestones``;
      architectures not listed there fall back to the shared default.

    Args:
        num_epoch: total training epochs (cosine period for AdamW models).
        num_iter: iterations per epoch, used to size the warm-up scheduler.
        warmup: number of warm-up epochs.

    Returns:
        ``(optimizer, train_scheduler, warmup_scheduler)``.
    """
    # (lr, weight_decay) for architectures trained with AdamW + cosine annealing.
    adamw_hparams = {
        'hivit_tiny': (0.001, 0.05),
        'hivit_small': (0.001, 0.05),
        'hivit_base': (0.001, 0.05),
        'convnext_t': (4e-3, 0.05),
        'convnext_s': (4e-3, 0.05),
        'atm_xt': (1e-3, 0.05),
        'atmf_xt': (1e-3, 0.05),
        'atmf_t': (1e-3, 0.05),
        'atmf_b': (1e-3, 0.05),
        'efficientvit_m0': (1e-3, 0.25),
        'efficientvit_m1': (1e-3, 0.25),
        'efficientvit_m2': (1e-3, 0.25),
    }
    default_milestones = [60, 120, 160]
    # Per-architecture MultiStepLR milestones (epochs) for SGD-trained models.
    sgd_milestones = {
        'mobilenetv3_small_100': [60, 120, 160],
        'mobilenetv3_small_50': [60, 120, 160],
        'resnet18': [60, 120, 160],
        'resnet34': [60, 120, 160],
        'resnet50': [60, 120, 160],
        'resnet152': [60, 120, 160],
        'resnetl18': [60, 120, 160],
        'resnets18': [60, 120, 160],
        'resnett18': [60, 120, 160],
        'resnetw18': [60, 120, 160],
        'resnetxw18': [60, 120, 160],
        'resnet318': [60, 120, 160],
        'vgg11_bn': [60, 120, 160],
        'vgg19_bn': [60, 120, 160],
        'svgg19_bn': [60, 120, 160],
    }

    if args.model in adamw_hparams:
        lr, weight_decay = adamw_hparams[args.model]
        optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
        train_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, num_epoch, eta_min=0.0
        )
    else:
        optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
        train_scheduler = optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=sgd_milestones.get(args.model, default_milestones), gamma=0.2
        )  # learning rate decay

    # NOTE(review): train()/train_privately() step this scheduler once per
    # batch while epoch <= args.warmup. Its length here is
    # ceil(num_iter / args.batch_step) * warmup. With the default
    # batch_step=256 and fewer than 256 batches per epoch, that is just
    # `warmup` iterations. Confirm this matches WarmUpLR's intended semantics.
    warmup_scheduler = WarmUpLR(optimizer, math.ceil(num_iter / args.batch_step) * warmup)
    return optimizer, train_scheduler, warmup_scheduler



def main():
    """Train a CIFAR-100 classifier with adversarial regularization.

    Schedule:
      - The first ``--pretrain_epochs`` epochs use plain cross-entropy
        training.
      - After that, each epoch first trains the pseudo attacker
        (``train_attack``), then trains the classifier against it
        (``train_privately``).
      - After every epoch, the model is evaluated on the training data (with
        and without the training transform) and on the test split. The
        checkpoint with the best test accuracy is saved, and the last state
        plus a CSV training record are saved at the end.
    """
    parser = argparse.ArgumentParser(description='Setting for CIFAR100 by Adversarial Regularization')
    parser.add_argument('--model', type=str, default='mobilenetv3_small_50')
    parser.add_argument('--classifier_epochs',type=int,default=200,help='classifier epochs')
    parser.add_argument('--attack_epochs', type=int, default=150, help='attack epochs in NN attack')
    parser.add_argument('--print_epoch', type=int, default=5, help='print model training stats per print_epoch_splitai during splitai training')
    parser.add_argument('--batch_size', type=int, default=256, help='batch size')
    parser.add_argument('--batch_step', type=int, default=256, help='batch accumulation steps')
    parser.add_argument('--warmup', type=int, default=1, help='warm up epochs')
    parser.add_argument('--num_worker', type=int, default=2, help='number workers')
    parser.add_argument('--num_class', type=int, default=100, help='num class')
    # NOTE(review): distutils was removed from the stdlib in Python 3.12.
    parser.add_argument('--data_aug', type=distutils.util.strtobool, default=True, help='turn on data augmentation')
    parser.add_argument('--run_idx', type=int, default=100, help='idx running')
    parser.add_argument('--save_path', default='save_checkpoints/', type=str, help='folder to save the checkpoints')
    parser.add_argument('--load_path', default='save_checkpoints/', type=str, help='folder to load the checkpoints')
    # defence conf
    parser.add_argument('--alpha', type=float, default=1.6, help='para for Adversarial Regularization')
    parser.add_argument('--pretrain_epochs', type=int, default=4,
                        help='epochs of plain cross-entropy training before adversarial regularization starts')

    args = parser.parse_args()
    print(dict(args._get_kwargs()))

    batch_size = args.batch_size
    num_class = args.num_class
    classifier_epochs = args.classifier_epochs
    warmup = args.warmup
    num_worker = args.num_worker

    DATASET_PATH = os.path.join(root_dir, 'cifar100',  'data')
    checkpoint_path = os.path.join(args.save_path, 'cifar100', args.model, 'advreg',
                                   'aug' if args.data_aug else 'no_aug', str(int(args.alpha * 100)), str(args.run_idx))
    print(checkpoint_path)

    partition_dir = os.path.join(DATASET_PATH, 'partition')

    def load_partition(name):
        """Load ``<partition_dir>/<name>.npy``."""
        return np.load(os.path.join(partition_dir, name + '.npy'))

    # The tr/te/ref subsets are only needed for the label sanity print below.
    train_label_tr_attack = load_partition('tr_label')
    train_label_te_attack = load_partition('te_label')
    ref_label = load_partition('ref_label')
    train_data = load_partition('train_data')
    train_label = load_partition('train_label')
    test_data = load_partition('test_data')
    test_label = load_partition('test_label')
    all_test_data = load_partition('all_test_data')
    all_test_label = load_partition('all_test_label')

    #print first 20 labels for each subset, for checking with other experiments
    print(train_label_tr_attack[:20])
    print(train_label_te_attack[:20])
    print(test_label[:20])
    print(ref_label[:20])

    train_transform = transform_train_aug if args.data_aug else transform_train
    trainset = Cifardata(train_data, train_label, train_transform)
    # Same training samples, but with the deterministic test transform.
    traintestset = Cifardata(train_data, train_label, transform_test)
    testset = Cifardata(test_data, test_label, transform_test)

    traintestloader = torch.utils.data.DataLoader(traintestset, batch_size=batch_size, shuffle=False, num_workers=num_worker)
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=num_worker)

    best_acc = 0.00
    # The classifier width must match the attack model's num_class.
    model_1 = get_network(arch=args.model, num_classes=num_class)
    model = ModelwNorm(model_1)
    criterion = (nn.CrossEntropyLoss()).to(device, torch.float)
    model = model.to(device, torch.float)
    print("training sets: {:d}".format(len(train_data)))

    # Number of batches per epoch of the (non-drop_last) training loader.
    iter_per_epoch = math.ceil(len(trainset) / batch_size)
    optimizer, train_scheduler, warmup_scheduler = get_opt_and_lrsch(
        args, model, classifier_epochs, iter_per_epoch, warmup
    )

    attackset = CifarAttack(train_data, train_label, all_test_data, all_test_label, transform_test)
    print("attack set: {:d}".format(len(attackset)))

    trer = TrainRecorder()
    best_epoch = 0

    attack_model0 = InferenceAttack_HZ(num_class).to(device, torch.float)
    attack_criterion0 = nn.MSELoss().to(device, torch.float)
    attack_optimizer0 = optim.Adam(attack_model0.parameters(),lr=0.0001)

    # Defaults for the final save when no epoch runs (--classifier_epochs 0).
    epoch, test_acc = 0, 0.0
    for epoch in range(1, classifier_epochs+1):

        # Reshuffle in place, then rebuild the (non-shuffling) loaders.
        trainset.update()
        attackset.update()
        trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=False, num_workers=num_worker)
        attackloader = torch.utils.data.DataLoader(attackset, batch_size=batch_size, shuffle=False, num_workers=num_worker)

        if epoch > 1:
            # NOTE(review): passing `epoch` to scheduler.step() is deprecated in
            # recent PyTorch. It is kept because it fixes the schedule phase.
            train_scheduler.step(epoch)

        if epoch <= args.pretrain_epochs:
            train(trainloader, model, criterion, optimizer, epoch, warmup_scheduler, args)
        else:
            train_attack(args, attackloader, model, attack_model0, attack_criterion0, attack_optimizer0)
            train_privately(args, trainloader, model, attack_model0, criterion, optimizer, epoch, warmup_scheduler)

        training_loss, training_acc = test(trainloader, model, criterion, batch_size)
        train_loss, train_acc = test(traintestloader, model, criterion, batch_size)
        test_loss, test_acc = test(testloader, model, criterion, batch_size)
        # record
        trer.update(train_loss, train_acc, test_loss, test_acc, training_loss, training_acc)

        # save model
        is_best = test_acc>best_acc
        best_acc = max(test_acc, best_acc)
        if is_best:
            best_epoch = epoch
            save_checkpoint({
                'epoch': epoch,
                'state_dict': model.state_dict(),
                'acc': test_acc,
                'best_acc': best_acc,
                'optimizer': optimizer.state_dict(),
            }, is_best, best_acc, checkpoint=checkpoint_path)

        lr = get_learning_rate(optimizer)
        print('Epoch: [{:d} | {:d}]: learning rate:{:.4f}. acc: test: {:.4f}. loss: test: {:.4f}'.format(epoch, classifier_epochs, lr[0], test_acc, test_loss))
        sys.stdout.flush()

    # save the record
    trer.save(checkpoint_path, 'train_record.csv')
    # save the last
    save_checkpoint({
        'epoch': epoch,
        'state_dict': model.state_dict(),
        'acc': test_acc,
        'best_acc': best_acc,
        'optimizer': optimizer.state_dict(),
    }, False, best_acc, checkpoint=checkpoint_path)
    print("Final saved epoch {:d} with best acc {:.4f}".format(best_epoch, best_acc))


if __name__ == '__main__':
    main()
