# no normalize
# todo: repr, sim, ntxent
import argparse
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from tqdm import tqdm
from torchlars import LARS
from warmup_scheduler import GradualWarmupScheduler
from advertorch.utils import NormalizeByChannelMeanStd

from utils import load_data, accuracy
from models import ResNet18
#from rocl_lib.cifar import CIFAR10
import rocl_lib.cifar as rocl_cifar
from rocl_lib.attack_lib import RepresentationAdv
from rocl_lib.loss import pairwise_similarity, NT_xent
from attack_lib import LinfPGDAttack, L2PGDAttack

# Command-line interface. The only option is a free-form tag that is appended
# to the checkpoint filenames written by this script.
parser = argparse.ArgumentParser(
    description='Phase 1: mixed supervised/contrastive training on CIFAR-10. '
                'Phase 2: train a fresh linear classifier on top of the result.')
parser.add_argument('--suffix', type=str, default='',
                    help='string appended to the saved checkpoint filenames (default: empty)')
args = parser.parse_args()


# ---- Experiment configuration ----------------------------------------------
CON_DIM = 128       # output width of the contrastive head (model.con_linear)
EPOCH = 200         # number of phase-1 epochs
BATCH_SIZE = 256    # batch size of the training data loaders

# Attack used during phase 1 and its budget. Other settings that were left
# commented out here:
#   atk, eps = None, 0.0
#   atk, eps = 'linf', 8./255.
atk = 'l2'
eps = 0.5

# Mixing weight: loss = LAMDA * contrastive + (1 - LAMDA) * supervised.
#   0.0 -> full supervised, 1.0 -> full contrastive
#   (0.5 and 0.75 were the other commented-out values)
LAMDA = 0.875

# Augmentation pipeline for each branch: 'simple' or 'complex'.
sup_aug = 'complex'
con_aug = 'complex'

#trainset, testset, trainloader, testloader, normalizer = load_data(train_batch_size=256) #TODO: back
##trainset, testset, trainloader, testloader, normalizer = load_data()
#_, _, gt_trainloader, _, _ = load_data(train_aug=False)
#print (len(trainset), len(testset))
#mean = torch.tensor([0,0,0], dtype=torch.float32).cuda() #TODO: back
#std = torch.tensor([1,1,1], dtype=torch.float32).cuda()
#normalizer = NormalizeByChannelMeanStd(mean=mean, std=std)

# Per-channel mean/std handed to the model through `normalizer`
# (see ResNet18(normalizer) below). The datasets in this block only apply
# ToTensor(), so presumably normalisation happens inside the model -- confirm
# against models.ResNet18.
mean = torch.tensor([0.4914, 0.4822, 0.4465], dtype=torch.float32).cuda()
std = torch.tensor([0.2023, 0.1994, 0.2010], dtype=torch.float32).cuda()
normalizer = NormalizeByChannelMeanStd(mean=mean, std=std)

# Un-augmented test split, used by test().
testset = torchvision.datasets.CIFAR10(root='./raw_data', train=False, download=True, transform=transforms.ToTensor())
testloader = torch.utils.data.DataLoader(testset, batch_size=512, shuffle=False, num_workers=4)

# Un-augmented training split, used by train() to fit the phase-2 linear layer.
# This loader must wrap gt_trainset: it previously wrapped `testset`, which
# trained the final classifier on the very data it is evaluated on.
gt_trainset = torchvision.datasets.CIFAR10(root='./raw_data', train=True, download=True, transform=transforms.ToTensor())
gt_trainloader = torch.utils.data.DataLoader(gt_trainset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)

def _build_transform(aug):
    """Return the torchvision transform pipeline selected by *aug*.

    'simple'  : random crop (32, padding 4), random horizontal flip, ToTensor.
    'complex' : colour jitter applied with p=0.8, random grayscale with p=0.2,
                random horizontal flip, random resized crop to 32, ToTensor.

    Raises:
        ValueError: if *aug* is neither 'simple' nor 'complex'.
    """
    if aug == 'simple':
        return transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ])
    if aug == 'complex':
        color_jitter = transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)
        rnd_color_jitter = transforms.RandomApply([color_jitter], p=0.8)
        rnd_gray = transforms.RandomGrayscale(p=0.2)
        return transforms.Compose([
            rnd_color_jitter,
            rnd_gray,
            transforms.RandomHorizontalFlip(),
            transforms.RandomResizedCrop(32),
            transforms.ToTensor(),
        ])
    # Explicit raise instead of `assert 0`, which is stripped under `python -O`.
    raise ValueError("unknown augmentation %r (expected 'simple' or 'complex')" % (aug,))


# Supervised branch: plain torchvision CIFAR-10 yielding (image, label).
transform_sup = _build_transform(sup_aug)
trainset_sup = torchvision.datasets.CIFAR10(root='./raw_data', train=True, download=True, transform=transform_sup)
trainloader_sup = torch.utils.data.DataLoader(trainset_sup, batch_size=BATCH_SIZE, num_workers=0, pin_memory=False, shuffle=True, drop_last=True)

# Contrastive branch: the rocl_lib dataset in 'contrastive' mode. The training
# loop unpacks four items per batch and uses the middle two as the two views.
transform_con = _build_transform(con_aug)
trainset_con = rocl_cifar.CIFAR10(root='./raw_data', train=True, download=True, transform=transform_con, contrastive_learning='contrastive')
trainloader_con = torch.utils.data.DataLoader(trainset_con, batch_size=BATCH_SIZE, num_workers=0, pin_memory=False, shuffle=True, drop_last=True)

# The two loaders are zipped together during training; zip() would silently
# truncate to the shorter one, so insist on equal lengths up front.
if len(trainloader_sup) != len(trainloader_con):
    raise RuntimeError('supervised and contrastive loaders differ in length: %d vs %d'
                       % (len(trainloader_sup), len(trainloader_con)))

# Build the backbone and give it two heads that the training loop swaps in
# and out of `model.linear`:
#   sup_linear - the classifier layer that ResNet18 was constructed with
#   con_linear - a two-layer MLP producing CON_DIM-dimensional outputs
model = ResNet18(normalizer)
feature_dim = model.linear.in_features
model.sup_linear = model.linear
projection_head = nn.Sequential(
    nn.Linear(feature_dim, 2048),
    nn.ReLU(),
    nn.Linear(2048, CON_DIM),
)
model.con_linear = projection_head
# No head is active by default; each training branch installs the one it needs.
model.linear = None
model = model.to('cuda')

# Adversaries for the two branches. Both names are always defined and stay
# None when no attack is configured (atk is None); the training loop checks
# `atk is not None` before using them.
adv_sup = None
adv_con = None
if atk is not None:
    if atk == 'linf':
        adv_sup = LinfPGDAttack(model, epsilon=eps)
    elif atk == 'l2':
        adv_sup = L2PGDAttack(model, epsilon=eps)
    else:
        # Explicit raise instead of `assert 0`, which is stripped under `python -O`.
        raise ValueError("unknown attack type %r (expected None, 'linf' or 'l2')" % (atk,))
    # Attack on the contrastive branch; its step size is a quarter of the budget.
    adv_con = RepresentationAdv(model, _type=atk, epsilon=eps, alpha=eps/4)

# Phase-1 optimisation: SGD wrapped in LARS, with a gradual warm-up scheduler
# that hands over to cosine annealing (passed as its after_scheduler).
base_optimizer = torch.optim.SGD(
    model.parameters(),
    lr=0.1,
    momentum=0.9,
    weight_decay=1e-6,
)
optimizer_p1 = LARS(optimizer=base_optimizer, eps=1e-8, trust_coef=0.001)
scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_p1, T_max=EPOCH)
# TODO: back
scheduler_warmup = GradualWarmupScheduler(
    optimizer_p1,
    multiplier=15.0,
    total_epoch=10,
    after_scheduler=scheduler_cosine,
)
# Loss for the supervised branch.
criterion = nn.CrossEntropyLoss()

# ---- Phase 1: joint supervised + contrastive training ----------------------
# Each step draws one batch from each loader, computes the loss of each enabled
# branch with the matching head installed in `model.linear`, and optimises
#     loss = LAMDA * loss_con + (1 - LAMDA) * loss_sup.
#
# torch.save does not create missing directories; create the checkpoint folder
# up front so a missing folder cannot abort the run after a full epoch.
os.makedirs('./saved_model', exist_ok=True)
# Every input to the filename is constant for the run, so build it once.
# The same file is overwritten at the end of every epoch.
phase1_ckpt_path = './saved_model/mixing-%s-%.4f-lamda%.4f-%s%s%s_phase1.pth'%(atk, eps, LAMDA, sup_aug, con_aug, args.suffix)

for epoch_counter in range(EPOCH):
    model.train()
    # The learning-rate schedule is advanced at the start of each epoch.
    scheduler_warmup.step()
    total_loss_sup = 0.0
    total_loss_con = 0.0
    with tqdm(zip(trainloader_sup, trainloader_con), total=len(trainloader_sup)) as pbar:
        for batch_idx, ((x_sup,y), (_, x_con_1, x_con_2, _)) in enumerate(pbar):
            if LAMDA != 1: # do supervised
                model.linear = model.sup_linear
                x_sup, y = x_sup.to('cuda'), y.to('cuda')
                if atk is not None:
                    # Train on the perturbed inputs returned by the attack.
                    adv_x_sup = adv_sup.perturb(x_sup, y)
                    pred = model(adv_x_sup)
                else:
                    pred = model(x_sup)
                loss_sup = criterion(pred, y)
                model.linear = None
            else:
                # Branch disabled: one-element zero so the sum and .item() below still work.
                loss_sup = torch.zeros(1, device='cuda')

            if LAMDA != 0: # do contrastive
                model.linear = model.con_linear
                x_con_1, x_con_2 = x_con_1.to('cuda'), x_con_2.to('cuda')
                if atk is not None:
                    # get_loss yields perturbed inputs plus an extra loss term;
                    # both views and the perturbed batch go through the model together.
                    advinputs, adv_loss = adv_con.get_loss(original_images=x_con_1, target=x_con_2, optimizer=optimizer_p1, weight=256, random_start=True)
                    inputs = torch.cat((x_con_1, x_con_2, advinputs))
                    outputs = model(inputs)
                    similarity, _ = pairwise_similarity(outputs, temperature=0.5, multi_gpu=False, adv_type='Rep')
                    simloss = NT_xent(similarity, 'Rep')
                    loss_con = simloss + adv_loss
                else:
                    inputs = torch.cat((x_con_1, x_con_2))
                    outputs = model(inputs)
                    similarity, _ = pairwise_similarity(outputs, temperature=0.5, multi_gpu=False, adv_type='None')
                    simloss = NT_xent(similarity, 'None')
                    loss_con = simloss
                model.linear = None
            else:
                # Branch disabled: one-element zero so the sum and .item() below still work.
                loss_con = torch.zeros(1, device='cuda')

            loss = LAMDA * loss_con + (1-LAMDA) * loss_sup
            # Gradients are cleared only here, right before the backward pass,
            # so anything accumulated while crafting the attacks is discarded.
            optimizer_p1.zero_grad()
            loss.backward()
            optimizer_p1.step()
            total_loss_sup += loss_sup.item()
            total_loss_con += loss_con.item()

            # The progress bar shows the running mean of each loss over the epoch.
            pbar.set_description('Epoch %d | Supervised Loss %.3f | Contrastive loss %.3f'%(epoch_counter, total_loss_sup/(batch_idx+1), total_loss_con/(batch_idx+1)))

    torch.save(model.state_dict(), phase1_ckpt_path)
#model.load_state_dict(torch.load('./saved_model/mixing-%s-%.4f-lamda%.4f-%s%s_phase1.pth'%(atk, eps, LAMDA, sup_aug, con_aug)))
# ---- Phase 2 setup: fresh linear classifier on the trained backbone --------
# Remember the feature width, then drop both phase-1 heads.
feature_dim = model.sup_linear.in_features
del model.sup_linear
del model.con_linear

# Only the new linear layer is handed to the optimizer below. Freeze every
# remaining parameter so backward() in train() neither computes nor keeps
# accumulating backbone gradients that no optimizer uses or zeroes. This does
# not change the gradients of the linear layer.
for backbone_param in model.parameters():
    backbone_param.requires_grad_(False)

# New 10-way classifier; its parameters require grad by default.
model.linear = nn.Linear(feature_dim, 10).cuda()



criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.linear.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
# Learning rate is multiplied by 0.1 at epochs 20 and 30.
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[20,30], gamma=0.1)

def train(epoch):
    """Run one optimisation pass over ``gt_trainloader``.

    The model is kept in eval mode on purpose: only the parameters held by the
    module-level ``optimizer`` (the last layer) are being trained.

    Args:
        epoch: epoch index, used only for the printed header.

    Returns:
        Tuple ``(mean loss per batch, accuracy in percent)`` for this pass.
    """
    print('\nEpoch: %d' % epoch)
    model.eval() # Only training last layer
    running_loss, n_correct, n_seen = 0, 0, 0
    with tqdm(gt_trainloader) as progress:
        for step, (images, labels) in enumerate(progress, start=1):
            images = images.to('cuda')
            labels = labels.to('cuda')

            optimizer.zero_grad()
            logits = model(images)
            batch_loss = criterion(logits, labels)
            batch_loss.backward()
            optimizer.step()

            # Running statistics shown in the progress bar.
            running_loss += batch_loss.item()
            predicted = logits.max(1)[1]
            n_seen += labels.size(0)
            n_correct += predicted.eq(labels).sum().item()
            progress.set_description('Loss: %.3f | Acc:%.3f%%'%(running_loss/step, 100.*n_correct/n_seen))

    return running_loss/len(gt_trainloader), 100.*n_correct/n_seen

def test(epoch):
    """Evaluate the model on ``testloader`` without tracking gradients.

    Args:
        epoch: accepted for symmetry with ``train``; not used in the body.

    Returns:
        Tuple ``(mean loss per batch, accuracy in percent)`` over the loader.
    """
    model.eval()
    running_loss, n_correct, n_seen = 0, 0, 0
    with torch.no_grad(), tqdm(testloader) as progress:
        for step, (images, labels) in enumerate(progress, start=1):
            images = images.to('cuda')
            labels = labels.to('cuda')
            logits = model(images)
            batch_loss = criterion(logits, labels)

            # Running statistics shown in the progress bar.
            running_loss += batch_loss.item()
            predicted = logits.max(1)[1]
            n_seen += labels.size(0)
            n_correct += predicted.eq(labels).sum().item()
            progress.set_description('Loss: %.3f | Acc:%.3f%%'%(running_loss/step, 100.*n_correct/n_seen))

    return running_loss/len(testloader), 100.*n_correct/n_seen


# ---- Phase 2 driver: 40 epochs of last-layer training ----------------------
# torch.save does not create missing directories; make sure the folder exists
# before training so a checkpoint write cannot fail on a missing path.
os.makedirs('./saved_model', exist_ok=True)
final_ckpt_path = './saved_model/mixing-%s-%.4f-lamda%.4f-%s%s%s.pth'%(atk, eps, LAMDA, sup_aug, con_aug, args.suffix)

# NOTE(review): the checkpoint is chosen by accuracy on the test loader, so the
# reported best accuracy is selected on test data -- confirm this is intended.
best_acc = 0.0
for epoch in range(40):
    train(epoch)
    _, cur_acc = test(epoch)
    scheduler.step()
    # Keep only the weights of the best-scoring epoch so far.
    if cur_acc > best_acc:
        best_acc = cur_acc
        torch.save(model.state_dict(), final_ckpt_path)
