import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

from utils import load_data, accuracy
from models import ResNet18
from simclr_lib.contrastive_learning_dataset import ContrastiveLearningDataset
from attack_lib import LinfPGDAttack, L2PGDAttack

# Experiment configuration, data loading and checkpoint-name construction.

# LAMDA weights the contrastive loss in the phase-1 objective; the supervised
# adversarial loss receives (1 - LAMDA).
# Alternative setting: LAMDA = 0.25
LAMDA = 0.75

# Augmented loaders feed the supervised branch of phase 1; `testloader` is used by test().
trainset, testset, trainloader, testloader, normalizer = load_data(train_batch_size=256)
# Second set of loaders requested with train_aug=False; `gt_trainloader` is what train() iterates.
gt_trainset, _, gt_trainloader, _, _ = load_data(train_aug=False)
print (len(trainset), len(testset))
# ADV_STR selects which loss the SimCLR attack classes below maximise
# ('adv', 'advrepr' or 'advlid'); it is also the prefix of the checkpoint name.
ADV_STR = 'adv'
# ATK selects the attack family ('l2' or 'linf') when the adversaries are built.
ATK = 'l2'
# EPSILON is passed as `epsilon` to both adversaries.
# Alternative setting: EPSILON=1.0
EPSILON=0.1

# SAVE_STR becomes part of the checkpoint file names. Only the 'l2' setting
# records EPSILON in the name; any other ATK value falls back to ADV_STR alone.
if ATK == 'l2':
    SAVE_STR = ADV_STR+'-l2-%.4f'%EPSILON
else:
    SAVE_STR = ADV_STR
SAVE_STR = SAVE_STR+'-lamda%.4f'%LAMDA

# Contrastive (phase-1) hyper-parameters.
CLR_BATCH_SIZE = 256  # samples per contrastive batch (each sample yields n_views=2 images)
CLR_TEMP = 0.1        # divisor applied to the similarity logits in info_nce_loss
CLR_DIM = 128         # output width of the projection head
CLR_EPOCH = 100       # number of phase-1 epochs; also T_max of the cosine schedule

# Two-view contrastive dataset. drop_last=True keeps every batch at exactly
# CLR_BATCH_SIZE samples.
clr_trainset = ContrastiveLearningDataset('./raw_data').get_dataset('cifar10',n_views=2)
clr_trainloader = torch.utils.data.DataLoader(clr_trainset, batch_size=CLR_BATCH_SIZE, shuffle=True, num_workers=12, pin_memory=True, drop_last=True)

# Backbone plus a 2-layer MLP projection head attached as `clr_header`.
# The head is sized from the input width of the model's existing `linear` layer.
model = ResNet18(normalizer)
model.clr_header = nn.Sequential(nn.Linear(model.linear.in_features, 512), nn.ReLU(), nn.Linear(512, CLR_DIM))
model = model.to('cuda')

# criterion_clr is reused for both the contrastive logits and the supervised
# logits in phase 1. The optimizer is created after the head is attached, so
# model.parameters() includes the head's weights.
criterion_clr = nn.CrossEntropyLoss()
optimizer_clr = torch.optim.Adam(model.parameters(), 3e-4, weight_decay=1e-4)
scheduler_clr = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_clr, T_max=CLR_EPOCH, eta_min=0, last_epoch=-1)

def info_nce_loss(features1, features2, temperature=None):
    """Build InfoNCE logits and targets between two aligned two-view embedding sets.

    Both inputs hold ``2 * batch`` rows laid out as ``[view 0 of every sample,
    view 1 of every sample]``, so rows ``i`` and ``i + batch`` belong to the
    same sample. Row ``i`` of ``features1`` is compared (cosine similarity)
    with every row of ``features2`` except the same-index row ``i``, which is
    discarded. Of the remaining rows, the one belonging to the same sample is
    the positive and all others are negatives.

    Args:
        features1: ``(2 * batch, dim)`` query embeddings.
        features2: ``(2 * batch, dim)`` key embeddings, same layout as
            ``features1``.
        temperature: Divisor applied to the similarities. Defaults to the
            module-level ``CLR_TEMP`` when omitted.

    Returns:
        ``(logits, labels)`` where ``logits`` is ``(2 * batch, 2 * batch - 1)``
        with the positive similarity in column 0, and ``labels`` is an
        all-zero ``long`` vector of length ``2 * batch`` (the index of the
        positive column), ready for ``nn.CrossEntropyLoss``.

    Raises:
        ValueError: If the inputs do not have the same, even, non-zero number
            of rows.
    """
    total = features1.shape[0]
    if total < 2 or total % 2 != 0 or features2.shape[0] != total:
        raise ValueError(
            "info_nce_loss expects two feature sets with the same even number "
            "of rows, got %d and %d" % (total, features2.shape[0]))
    if temperature is None:
        temperature = CLR_TEMP

    # Batch size and device come from the inputs instead of the globals, so
    # the function also works for other batch sizes and on CPU.
    device = features1.device
    sample_ids = torch.arange(total // 2, device=device)
    sample_ids = torch.cat([sample_ids, sample_ids], dim=0)
    same_sample = sample_ids.unsqueeze(0) == sample_ids.unsqueeze(1)

    features1 = F.normalize(features1, dim=1)
    features2 = F.normalize(features2, dim=1)
    similarity_matrix = torch.matmul(features1, features2.T)

    # Drop the same-index pair from both the sample mask and the similarities.
    off_diagonal = ~torch.eye(total, dtype=torch.bool, device=device)
    same_sample = same_sample[off_diagonal].view(total, total - 1)
    similarity_matrix = similarity_matrix[off_diagonal].view(total, total - 1)

    # With two views there is exactly one positive per row. Explicit shapes
    # keep the reshape valid even when there are no negatives (batch == 1).
    positives = similarity_matrix[same_sample].view(total, 1)
    negatives = similarity_matrix[~same_sample].view(total, total - 2)

    logits = torch.cat([positives, negatives], dim=1)
    labels = torch.zeros(total, dtype=torch.long, device=device)
    logits = logits / temperature
    return logits, labels


class SimCLRLinfPGDAttack(object):
    """L-infinity PGD attack against the contrastive (SimCLR) objective.

    The attack perturbs a two-view image batch so that the InfoNCE loss between
    the perturbed embeddings and the clean embeddings is maximised, while every
    pixel stays within ``epsilon`` of the clean image and inside ``[0, 1]``.

    Args:
        model: Network exposing ``calc_representation`` and ``clr_header``.
        num_steps: Number of PGD iterations.
        epsilon: L-infinity radius of the perturbation.
        alpha: Step size of each signed-gradient update.
    """

    def __init__(self, model, num_steps=7, epsilon=8./255., alpha=0.007):
        self.model = model
        self.num_steps = num_steps
        self.epsilon = epsilon
        self.alpha = alpha

    def perturb(self, x):
        """Return an adversarial copy of ``x`` (detached, same shape).

        The model is switched to eval mode while the attack runs and its
        previous training/eval mode is restored afterwards, even on error.

        Raises:
            NotImplementedError: For the 'advrepr' and 'advlid' settings of
                the module-level ``ADV_STR``.
            ValueError: For any other unknown ``ADV_STR`` value.
        """
        # Attack the model this object was built with, not the module global.
        model = self.model
        origin_training = model.training
        model.eval()
        try:
            # The clean embedding is a fixed target, so no graph is needed.
            with torch.no_grad():
                features = model.clr_header(model.calc_representation(x))
            adv_x = x.detach()
            # Random start inside the epsilon box.
            adv_x = adv_x + torch.zeros_like(adv_x).uniform_(-self.epsilon, self.epsilon)
            for _ in range(self.num_steps):
                adv_x.requires_grad_()

                if ADV_STR == 'adv':
                    # enable_grad lets the attack run even when the caller
                    # has gradients disabled.
                    with torch.enable_grad():
                        adv_features = model.clr_header(model.calc_representation(adv_x))
                        logits, labels = info_nce_loss(adv_features, features)
                        loss = criterion_clr(logits, labels)
                    grad = torch.autograd.grad(loss, [adv_x])[0]
                elif ADV_STR == 'advrepr':
                    raise NotImplementedError()
                elif ADV_STR == 'advlid':
                    raise NotImplementedError()
                else:
                    # An explicit raise survives `python -O`, unlike `assert`.
                    raise ValueError("Unknown ADV_STR: %r" % (ADV_STR,))

                # Signed-gradient ascent, then project into the epsilon box
                # around x and the valid pixel range.
                adv_x = adv_x.detach() + self.alpha * torch.sign(grad.detach())
                adv_x = torch.min(torch.max(adv_x, x-self.epsilon), x+self.epsilon)
                adv_x = torch.clamp(adv_x,0,1)
        finally:
            model.train(origin_training)
        return adv_x.detach()


class SimCLRL2PGDAttack(object):
    """L2-bounded PGD attack against the contrastive (SimCLR) objective.

    The attack perturbs an image batch to maximise a loss between the perturbed
    and the clean embeddings, keeping the per-sample L2 norm of the
    perturbation at most ``epsilon`` and the pixels inside ``[0, 1]``.
    The loss is chosen by the module-level ``ADV_STR``: 'adv' uses the InfoNCE
    loss, 'advrepr' uses the distance between unit-normalised embeddings.

    Args:
        model: Network exposing ``calc_representation`` and ``clr_header``.
        num_steps: Number of PGD iterations.
        epsilon: L2 radius of the perturbation (per sample).
    """

    def __init__(self, model, num_steps=7, epsilon=1.0):
        self.model = model
        self.num_steps = num_steps
        self.epsilon = epsilon
        # Step size is tied to the radius: a quarter of epsilon per iteration.
        self.alpha = epsilon/4

    @staticmethod
    def _project(x, adv_x, epsilon):
        """Project ``adv_x`` into the L2 ball of radius ``epsilon`` around ``x``, then into [0, 1]."""
        delta = adv_x - x
        delta_norm = delta.reshape(x.shape[0], -1).norm(dim=1).view(-1,1,1,1)
        # Guard the division: a sample with zero perturbation would otherwise
        # produce 0/0 = NaN and poison the adversarial batch.
        safe_norm = torch.clamp(delta_norm, min=1e-12)
        clamp_delta = delta / safe_norm * torch.clamp(delta_norm, 0, epsilon)
        return torch.clamp(x + clamp_delta, 0, 1)

    def perturb(self, x):
        """Return an adversarial copy of ``x`` (detached, same shape).

        The model is switched to eval mode while the attack runs and its
        previous training/eval mode is restored afterwards, even on error.

        Raises:
            NotImplementedError: For the 'advlid' setting of ``ADV_STR``.
            ValueError: For any other unknown ``ADV_STR`` value.
        """
        # Attack the model this object was built with, not the module global.
        model = self.model
        origin_training = model.training
        model.eval()
        try:
            # The clean embedding is a fixed target, so no graph is needed.
            with torch.no_grad():
                features = model.clr_header(model.calc_representation(x))
            adv_x = x.detach()
            # Random start, pulled back into the L2 ball and the pixel range.
            adv_x = adv_x + torch.zeros_like(adv_x).uniform_(-self.epsilon, self.epsilon)
            adv_x = self._project(x, adv_x, self.epsilon)
            for _ in range(self.num_steps):
                adv_x.requires_grad_()

                if ADV_STR == 'adv':
                    with torch.enable_grad():
                        adv_features = model.clr_header(model.calc_representation(adv_x))
                        logits, labels = info_nce_loss(adv_features, features)
                        loss = criterion_clr(logits, labels)
                    grad = torch.autograd.grad(loss, [adv_x])[0]
                elif ADV_STR == 'advrepr':
                    with torch.enable_grad():
                        adv_features = model.clr_header(model.calc_representation(adv_x))
                        # Distance between unit-normalised embeddings. The
                        # clean features must be divided by their norm too;
                        # subtracting the bare norm compared a vector with a
                        # per-row scalar.
                        adv_unit = adv_features / adv_features.norm(2,dim=1,keepdim=True)
                        clean_unit = features / features.norm(2,dim=1,keepdim=True)
                        loss = (adv_unit - clean_unit).norm(2,dim=1).mean()
                    grad = torch.autograd.grad(loss, [adv_x])[0]
                elif ADV_STR == 'advlid':
                    raise NotImplementedError()
                else:
                    # An explicit raise survives `python -O`, unlike `assert`.
                    raise ValueError("Unknown ADV_STR: %r" % (ADV_STR,))

                # NOTE(review): the ascent direction is sign(grad), as in the
                # L-inf attack, rather than the L2-normalised gradient that is
                # usual for L2 PGD. Kept as-is to preserve existing results;
                # confirm this is intended.
                adv_x = adv_x.detach() + self.alpha * torch.sign(grad.detach())
                adv_x = self._project(x, adv_x, self.epsilon)
        finally:
            model.train(origin_training)
        return adv_x.detach()


# Build the two adversaries used in phase 1: `adversary` attacks the supervised
# classifier output, `adversary_clr` attacks the contrastive objective. Both
# share the same norm family and radius.
if ATK == 'l2':
    adversary = L2PGDAttack(model, epsilon=EPSILON)
    adversary_clr = SimCLRL2PGDAttack(model, epsilon=EPSILON)
elif ATK == 'linf':
    adversary = LinfPGDAttack(model, epsilon=EPSILON)
    adversary_clr = SimCLRLinfPGDAttack(model, epsilon=EPSILON)
else:
    # Fail here with a clear message instead of a NameError on `adversary`
    # once training starts.
    raise ValueError("Unsupported ATK %r; expected 'l2' or 'linf'" % (ATK,))
# Phase 1: joint adversarial training. Each step combines an adversarial
# contrastive loss on the two-view batch with an adversarial supervised loss
# on the labelled batch, weighted by LAMDA and (1 - LAMDA).
for epoch_counter in range(CLR_EPOCH):
    model.train()
    with tqdm(clr_trainloader) as pbar:
        # The two loaders advance in lockstep; zip ends the epoch as soon as
        # the shorter one is exhausted.
        for (images, _), (x,y) in zip(pbar, trainloader):
            # `images` is a sequence of views; stack them along the batch axis.
            images = torch.cat(images, dim=0).to('cuda')
            x, y = x.to('cuda'), y.to('cuda')
            # Craft both adversarial batches before the training forward passes.
            adv_images = adversary_clr.perturb(images)
            adv_x = adversary.perturb(x, y)

            # Contrastive branch: adversarial embeddings are scored against
            # the clean embeddings of the same batch.
            features = model.clr_header(model.calc_representation(images))
            adv_features = model.clr_header(model.calc_representation(adv_images))
            logits, labels = info_nce_loss(adv_features, features)
            loss_clr = criterion_clr(logits, labels)

            # Supervised branch: cross-entropy on the adversarial labelled batch.
            pred = model(adv_x)
            loss_supervised = criterion_clr(pred, y)

            loss = LAMDA * loss_clr + (1-LAMDA) * loss_supervised
            optimizer_clr.zero_grad()
            loss.backward()
            optimizer_clr.step()

            # Progress-bar metrics refer to the contrastive logits (how often
            # the positive pair is ranked first / in the top five).
            top1, top5 = accuracy(logits, labels, topk=(1,5))
            pbar.set_description("Epoch %d, Top 1: %.4f; Top 5: %.4f"%(epoch_counter, top1, top5))
        # The learning rate is held constant for the first 10 epochs; the
        # cosine schedule only advances from epoch 10 onward.
        if epoch_counter >= 10:
            scheduler_clr.step()
        
    # The phase-1 checkpoint is overwritten at the end of every epoch.
    torch.save(model.state_dict(), './saved_model/simclr%s_phase1.pth'%SAVE_STR)
# Disabled alternative: load an existing phase-1 checkpoint (presumably used
# together with skipping the loop above; confirm before enabling).
#model.load_state_dict(torch.load('./saved_model/simclr%s_phase1.pth'%SAVE_STR))
# The projection head is only needed for contrastive training; remove it
# before the classifier-only phase so it is not part of the final state dict.
del model.clr_header


# Phase 2: fit only the final linear classifier on top of the phase-1 model.
criterion = nn.CrossEntropyLoss()
# The optimizer below only receives model.linear's parameters, so freeze
# everything else. Without this, every backward pass still computes gradients
# for the whole backbone, and they are never used or zeroed. The classifier
# updates are unchanged.
for param in model.parameters():
    param.requires_grad_(False)
for param in model.linear.parameters():
    param.requires_grad_(True)
optimizer = torch.optim.SGD(model.linear.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
# Learning rate drops by 10x after epochs 20 and 30 (stepped once per epoch).
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[20,30], gamma=0.1)

def train(epoch):
    """Run one classifier-training pass over ``gt_trainloader``.

    Uses the module-level ``model``, ``optimizer`` and ``criterion``. The
    model stays in eval mode throughout because only the last layer is
    being optimised.

    Args:
        epoch: Epoch index, used only for the printed header.

    Returns:
        Tuple ``(mean loss per batch, accuracy in percent)`` for this pass.
    """
    print('\nEpoch: %d' % epoch)
    model.eval() # Only training last layer
    running_loss = 0
    n_correct = 0
    n_seen = 0
    with tqdm(gt_trainloader) as pbar:
        for step, (x, y) in enumerate(pbar, start=1):
            x = x.to('cuda')
            y = y.to('cuda')

            optimizer.zero_grad()
            scores = model(x)
            batch_loss = criterion(scores, y)
            batch_loss.backward()
            optimizer.step()

            # Running statistics shown on the progress bar.
            running_loss += batch_loss.item()
            guesses = scores.max(1)[1]
            n_seen += y.size(0)
            n_correct += guesses.eq(y).sum().item()
            pbar.set_description('Loss: %.3f | Acc:%.3f%%'%(running_loss/step, 100.*n_correct/n_seen))

    acc = 100.*n_correct/n_seen
    return running_loss/len(gt_trainloader), acc

def test(epoch):
    """Evaluate the module-level ``model`` on ``testloader`` without gradients.

    Args:
        epoch: Epoch index (accepted for symmetry with ``train``; not used).

    Returns:
        Tuple ``(mean loss per batch, accuracy in percent)`` on the test set.
    """
    model.eval()
    running_loss = 0
    n_correct = 0
    n_seen = 0
    with torch.no_grad(), tqdm(testloader) as pbar:
        for step, (x, y) in enumerate(pbar, start=1):
            x = x.to('cuda')
            y = y.to('cuda')
            scores = model(x)
            batch_loss = criterion(scores, y)

            # Running statistics shown on the progress bar.
            running_loss += batch_loss.item()
            guesses = scores.max(1)[1]
            n_seen += y.size(0)
            n_correct += guesses.eq(y).sum().item()
            pbar.set_description('Loss: %.3f | Acc:%.3f%%'%(running_loss/step, 100.*n_correct/n_seen))

    acc = 100.*n_correct/n_seen
    return running_loss/len(testloader), acc


# Phase-2 driver: 40 epochs of classifier training. After each epoch the model
# is evaluated, the LR schedule advances, and the weights are saved only when
# the test accuracy strictly improves on the best seen so far.
best_acc = 0.0
for epoch in range(40):
    train(epoch)
    _, cur_acc = test(epoch)
    scheduler.step()
    if cur_acc <= best_acc:
        continue
    best_acc = cur_acc
    torch.save(model.state_dict(), './saved_model/simclr%s.pth'%SAVE_STR)
