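# Inputs are not normalized: an identity (mean 0, std 1) normalizer is used.
# TODO: restore the representation attack (Rep), pairwise similarity and NT-Xent loss.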
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from tqdm import tqdm

from utils import load_data, accuracy
from models import ResNet18
from rocl_lib.cifar import CIFAR10
from rocl_lib.attack_lib import RepresentationAdv
from rocl_lib.loss import pairwise_similarity, NT_xent
from torchlars import LARS
from warmup_scheduler import GradualWarmupScheduler
from attack_lib import LinfPGDAttack, L2PGDAttack

# Weight of the (placeholder) contrastive loss in the phase-1 objective; the
# supervised adversarial loss is weighted by (1 - LAMDA).
# Other values explored: 0.0001, 0.125, 0.25, 0.75, 0.9999.
LAMDA = 0.0

# Main loaders with a training batch size of 256 (presumably augmented, via
# load_data's default train_aug -- confirm). The normalizer returned here is
# overwritten below by an identity one.
# TODO: back -- previously plain load_data() with its default batch size.
trainset, testset, trainloader, testloader, normalizer = load_data(
    train_batch_size=256)
# Training set without augmentation; used by train() in the linear-probe phase.
gt_trainset, _, gt_trainloader, _, _ = load_data(train_aug=False)
print(len(trainset), len(testset))

# Identity normalization (mean 0, std 1), i.e. the model sees raw [0, 1] inputs.
# TODO: back
from advertorch.utils import NormalizeByChannelMeanStd
mean = torch.zeros(3, dtype=torch.float32).cuda()
std = torch.ones(3, dtype=torch.float32).cuda()
normalizer = NormalizeByChannelMeanStd(mean=mean, std=std)

# Phase-1 configuration.
ROCL_DIM = 128  # output width of the projection head
ROCL_EPOCH = 100  # TODO: back -- 200 and 1000 were also tried

# Threat model used for adversarial training.
atk = 'linf'
eps = 8. / 255.  # 4/255 and 16/255 were also tried

# Augmentation for the two contrastive views: random color jitter (p=0.8),
# random grayscale (p=0.2), horizontal flip, random resized 32x32 crop.
# No Normalize step -- inputs stay in [0, 1].
color_jitter = transforms.ColorJitter(
    brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)
rnd_color_jitter = transforms.RandomApply([color_jitter], p=0.8)
rnd_gray = transforms.RandomGrayscale(p=0.2)
transform_train = transforms.Compose([
    rnd_color_jitter, rnd_gray,
    transforms.RandomHorizontalFlip(), transforms.RandomResizedCrop(32),
    transforms.ToTensor(),
])

rocl_trainset = CIFAR10(
    root='./raw_data',
    train=True,
    download=True,
    transform=transform_train,
    contrastive_learning='contrastive',
)
rocl_trainloader = torch.utils.data.DataLoader(
    rocl_trainset,
    batch_size=256,
    shuffle=True,
    num_workers=4,
    pin_memory=False,
)

# Phase-1 network: ResNet18 (built with the identity normalizer) whose final
# layer is replaced by an MLP projection head in_features -> 2048 -> ReLU -> ROCL_DIM.
model = ResNet18(normalizer)
model.linear = nn.Sequential(nn.Linear(model.linear.in_features, 2048), nn.ReLU(), nn.Linear(2048, ROCL_DIM))
model = model.to('cuda')

# Representation-space attacker for the RoCL loss. In this script it is only
# referenced from commented-out code, i.e. currently unused.
Rep = RepresentationAdv(model, _type=atk, epsilon=eps, alpha=eps/4)
# Label-based PGD attacker that crafts the adversarial inputs used in phase 1.
if atk == 'linf':
    adversary = LinfPGDAttack(model, epsilon=eps)
else:
    # Raise explicitly: an `assert` is stripped under `python -O`, which would
    # leave `adversary` undefined and fail much later with a NameError.
    raise NotImplementedError("unsupported attack type %r; only 'linf' is implemented" % (atk,))

# SGD (base lr 0.1, weight decay 1e-6) wrapped in LARS.
base_optimizer = torch.optim.SGD(model.parameters(), 0.1, momentum=0.9, weight_decay=1e-6)
optimizer_rocl = LARS(optimizer=base_optimizer, eps=1e-8, trust_coef=0.001) #TODO: back
#optimizer_rocl = torch.optim.SGD(model.parameters(), 0.1, momentum=0.9, weight_decay=5e-4)
# Per-epoch LR schedule: GradualWarmupScheduler warms up over the first 10 epochs
# and then hands over to cosine annealing over ROCL_EPOCH epochs. Presumably the
# warmup target is multiplier * base lr (15 x 0.1); confirm against warmup_scheduler.
scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_rocl, ROCL_EPOCH)
scheduler_warmup = GradualWarmupScheduler(optimizer_rocl, multiplier=15.0, total_epoch=10, after_scheduler=scheduler_cosine) #TODO: back
#scheduler_warmup = torch.optim.lr_scheduler.MultiStepLR(optimizer_rocl, milestones=[100,150], gamma=0.1)
criterion = nn.CrossEntropyLoss()

# Phase 1: adversarial training of the full network (backbone + projection head).
# The RoCL contrastive branch is still stubbed out: simloss and loss_rocl are
# constant 1e12 placeholders, so only LAMDA == 0 gives a usable objective
# (pure supervised loss on PGD examples). Do not raise LAMDA until the
# commented-out contrastive code below is restored.
for epoch_counter in range(ROCL_EPOCH):
    model.train()
    scheduler_warmup.step()  # once per epoch, before the first batch
    reg_loss = 0.0  # RoCL adversarial loss; stays 0 while Rep.get_loss is disabled
    reg_simloss = 0.0
    total_loss = 0.0
    with tqdm(rocl_trainloader) as pbar:
        # rocl_trainloader drives the progress bar and supplies the two augmented
        # views for the (disabled) contrastive loss; the labelled batch (x, y)
        # comes from trainloader. zip stops at whichever loader is shorter.
        for (batch_idx, (_, inputs_1, inputs_2, _)), (x,y) in zip(enumerate(pbar), trainloader):
            #inputs_1, inputs_2 = inputs_1.cuda(), inputs_2.cuda()
            x, y = x.to('cuda'), y.to('cuda')
            #advinputs, adv_loss = Rep.get_loss(original_images=inputs_1, target=inputs_2, optimizer=optimizer_rocl, weight=256, random_start=True)
            adv_x = adversary.perturb(x, y)

            #reg_loss += adv_loss.data
            #inputs = torch.cat((inputs_1, inputs_2, advinputs))
            #outputs = model(inputs)
            #similarity, _ = pairwise_similarity(outputs, temperature=0.5, multi_gpu=False, adv_type='Rep')
            #simloss = NT_xent(similarity, 'Rep')
            #loss_rocl = simloss + adv_loss
            # Placeholders standing in for the disabled contrastive loss.
            simloss = torch.FloatTensor([1000000000000]).to('cuda')
            loss_rocl = torch.FloatTensor([1000000000000]).to('cuda')

            pred = model(adv_x)
            loss_supervised = criterion(pred, y)

            loss = LAMDA * loss_rocl + (1-LAMDA) * loss_supervised
            optimizer_rocl.zero_grad()
            loss.backward()
            optimizer_rocl.step()
            # Accumulate the objective actually optimized (this used to sum the
            # 1e12 placeholder, making the "Loss" column meaningless); .item()
            # keeps the running sums as Python floats instead of CUDA tensors.
            total_loss += loss.item()
            reg_simloss += simloss.item()

            pbar.set_description('Loss %.3f | SimLoss %.3f | Adv %.3f'%(total_loss/(batch_idx+1), reg_simloss/(batch_idx+1), reg_loss/(batch_idx+1)))

    # Checkpoint after every epoch (the path does not depend on the epoch, so it is overwritten).
    torch.save(model.state_dict(), './saved_model/rocl-%s-%.4f-lamda%.4f_phase1.pth'%(atk, eps, LAMDA))
#model.load_state_dict(torch.load('./saved_model/rocl-%s-%.4f-lamda%.4f_phase1.pth'%(atk, eps, LAMDA)))
# Phase 2: linear evaluation. Replace the projection head with a fresh 10-way
# linear classifier and train only that layer.
#
# Freeze the backbone first. The optimizer below only updates model.linear.
# Without this, every backward pass in train() would still compute gradients
# for all backbone parameters. optimizer.zero_grad() only clears the head, so
# those gradients would keep accumulating without ever being used. The stale
# phase-1 gradients are dropped as well.
for param in model.parameters():
    param.grad = None
    param.requires_grad_(False)
# The new head is created after freezing, so its parameters remain trainable.
model.linear = nn.Linear(model.linear[0].in_features, 10).cuda()

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.linear.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
# Stepped once per epoch by the main loop: lr x0.1 at epochs 20 and 30.
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[20,30], gamma=0.1)

def train(epoch):
    """Run one epoch of linear-probe training on the un-augmented train set.

    The model is kept in eval mode; only the parameters held by the global
    ``optimizer`` are stepped.

    Args:
        epoch: epoch index, used only for the header printed to stdout.

    Returns:
        Tuple ``(mean per-batch loss, accuracy in percent)`` for the epoch.
    """
    print('\nEpoch: %d' % epoch)
    model.eval()  # Only training last layer
    running_loss, n_correct, n_seen = 0, 0, 0
    with tqdm(gt_trainloader) as bar:
        for step, (inputs, targets) in enumerate(bar):
            inputs, targets = inputs.to('cuda'), targets.to('cuda')

            optimizer.zero_grad()
            logits = model(inputs)
            batch_loss = criterion(logits, targets)
            batch_loss.backward()
            optimizer.step()

            running_loss += batch_loss.item()
            predicted = logits.max(1)[1]
            n_seen += targets.size(0)
            n_correct += predicted.eq(targets).sum().item()
            bar.set_description('Loss: %.3f | Acc:%.3f%%' % (running_loss / (step + 1), 100. * n_correct / n_seen))

    return running_loss / len(gt_trainloader), 100. * n_correct / n_seen

def test(epoch):
    """Evaluate the model on the test loader without tracking gradients.

    Args:
        epoch: epoch index; accepted for symmetry with train() but unused.

    Returns:
        Tuple ``(mean per-batch loss, accuracy in percent)`` on the test set.
    """
    model.eval()
    running_loss, n_correct, n_seen = 0, 0, 0
    with torch.no_grad(), tqdm(testloader) as bar:
        for step, (inputs, targets) in enumerate(bar):
            inputs, targets = inputs.to('cuda'), targets.to('cuda')
            logits = model(inputs)
            batch_loss = criterion(logits, targets)

            running_loss += batch_loss.item()
            predicted = logits.max(1)[1]
            n_seen += targets.size(0)
            n_correct += predicted.eq(targets).sum().item()
            bar.set_description('Loss: %.3f | Acc:%.3f%%' % (running_loss / (step + 1), 100. * n_correct / n_seen))

    return running_loss / len(testloader), 100. * n_correct / n_seen


# Linear-probe training for 40 epochs; the checkpoint is saved only when the
# test accuracy strictly improves on the best seen so far.
best_acc = 0.0
for epoch in range(40):
    train(epoch)
    cur_acc = test(epoch)[1]
    scheduler.step()
    if cur_acc <= best_acc:
        continue
    best_acc = cur_acc
    torch.save(model.state_dict(), './saved_model/rocl-%s-%.4f-lamda%.4f.pth'%(atk, eps, LAMDA))
