# No input normalization: the model is given an identity normalizer (mean 0, std 1).
# todo: repr, sim, ntxent
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torchlars import LARS
from tqdm import tqdm
from warmup_scheduler import GradualWarmupScheduler

from models import ResNet18
from rocl_lib.attack_lib import RepresentationAdv
from rocl_lib.cifar import CIFAR10
from rocl_lib.loss import pairwise_similarity, NT_xent
from utils import load_data, accuracy

# Labelled train/test data, requested without train-time augmentation
# (train_aug=False). These loaders yield (x, y) batches and are consumed by
# train() and test(). The normalizer returned here is overwritten just below.
(
    gt_trainset,
    testset,
    gt_trainloader,
    testloader,
    normalizer,
) = load_data(train_aug=False)
print(len(gt_trainset), len(testset))

# Identity normalisation: per-channel mean 0 and std 1, placed on the GPU.
mean = torch.zeros(3, dtype=torch.float32).cuda()
std = torch.ones(3, dtype=torch.float32).cuda()
from advertorch.utils import NormalizeByChannelMeanStd
normalizer = NormalizeByChannelMeanStd(mean=mean, std=std)

# Output width of the projection head attached to the model for pre-training.
ROCL_DIM = 128
# Number of contrastive pre-training epochs; also the T_max of the cosine
# schedule. (Alternative value kept from the original: 1000.)
ROCL_EPOCH = 100

# Attack type handed to RepresentationAdv; also embedded in checkpoint names.
atk = 'linf'
# Perturbation budget passed as `epsilon` to RepresentationAdv (step size is
# eps / 4). Presumably on the [0, 1] pixel scale -- TODO confirm.
# (Alternative values kept from the original: 4/255 and 16/255.)
eps = 8.0 / 255.0

# Colour augmentations are built here but are NOT part of the active pipeline
# below (they were disabled in the original Compose list).
color_jitter = transforms.ColorJitter(
    brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1
)
rnd_color_jitter = transforms.RandomApply([color_jitter], p=0.8)
rnd_gray = transforms.RandomGrayscale(p=0.2)

# Active augmentation: padded random crop, random horizontal flip, to-tensor.
# Disabled alternatives from the original: rnd_color_jitter, rnd_gray and
# RandomHorizontalFlip + RandomResizedCrop(32) + ToTensor.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])

# Project-local CIFAR10 wrapper in 'contrastive' mode; the pre-training loop
# unpacks each batch as a 4-tuple and uses its 2nd and 3rd elements as inputs.
rocl_trainset = CIFAR10(
    root='./raw_data',
    train=True,
    download=True,
    transform=transform_train,
    contrastive_learning='contrastive',
)
rocl_trainloader = torch.utils.data.DataLoader(
    rocl_trainset,
    batch_size=256,
    shuffle=True,
    num_workers=4,
    pin_memory=False,
)

# Backbone built with the normalizer defined above; its `linear` layer is
# replaced by a Linear -> ReLU -> Linear head that outputs ROCL_DIM features.
model = ResNet18(normalizer)
model.linear = nn.Sequential(
    nn.Linear(model.linear.in_features, 2048),
    nn.ReLU(),
    nn.Linear(2048, ROCL_DIM),
)
model = model.cuda()

# Adversary used inside the pre-training loop; step size is a quarter of eps.
Rep = RepresentationAdv(model, _type=atk, epsilon=eps, alpha=eps / 4)

# SGD wrapped by LARS; the schedulers below are attached to the LARS wrapper.
base_optimizer = torch.optim.SGD(
    model.parameters(),
    lr=0.1,
    momentum=0.9,
    weight_decay=1e-6,
)
optimizer_rocl = LARS(optimizer=base_optimizer, eps=1e-8, trust_coef=0.001)

# Warm-up scheduler (multiplier 15, 10 epochs) that is given the cosine
# schedule over ROCL_EPOCH epochs as its `after_scheduler`.
scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer_rocl, T_max=ROCL_EPOCH
)
scheduler_warmup = GradualWarmupScheduler(
    optimizer_rocl,
    multiplier=15.0,
    total_epoch=10,
    after_scheduler=scheduler_cosine,
)

# Phase 1: adversarial contrastive pre-training of `model` for ROCL_EPOCH epochs.
# A checkpoint is written (and overwritten) after every epoch.

# torch.save does not create missing directories; make sure the checkpoint
# folder exists up front instead of failing after a full epoch of training.
os.makedirs('./saved_model', exist_ok=True)

for epoch_counter in range(ROCL_EPOCH):
    model.train()
    # The warm-up scheduler is stepped once at the start of every epoch.
    scheduler_warmup.step()
    # Per-epoch running sums, kept as Python floats (via .item()) so they match
    # the bookkeeping style of train()/test() and hold no tensor references.
    reg_loss = 0.0
    reg_simloss = 0.0
    total_loss = 0.0
    with tqdm(rocl_trainloader) as pbar:
        # Each batch is a 4-tuple; only its 2nd and 3rd elements are used
        # (presumably two augmented views of the same images -- TODO confirm).
        for batch_idx, (_, inputs_1, inputs_2, _) in enumerate(pbar):
            inputs_1, inputs_2 = inputs_1.cuda(), inputs_2.cuda()
            # Experiment note from the original: `Rep.max_iters = 0` here was
            # recorded as "tmp 73/10".
            # Adversarial inputs built from inputs_1 against inputs_2, plus the
            # loss term returned by the attack helper.
            advinputs, adv_loss = Rep.get_loss(original_images=inputs_1, target=inputs_2, optimizer=optimizer_rocl, weight=256, random_start=True)
            reg_loss += adv_loss.item()

            # Single forward pass over both inputs and the adversarial inputs.
            inputs = torch.cat((inputs_1, inputs_2, advinputs))
            outputs = model(inputs)
            similarity, _ = pairwise_similarity(outputs, temperature=0.5, multi_gpu=False, adv_type='Rep')
            simloss = NT_xent(similarity, 'Rep')

            # Experiment note from the original ("tmp - 73 acc, 2 robust"):
            # a variant with adv_loss = 0.0, inputs = cat(inputs_1, inputs_2)
            # and adv_type='None' for both pairwise_similarity and NT_xent.

            loss = simloss + adv_loss
            # Clear gradients right before backward so only this loss contributes.
            optimizer_rocl.zero_grad()
            loss.backward()
            optimizer_rocl.step()
            total_loss += loss.item()
            reg_simloss += simloss.item()

            pbar.set_description('Loss %.3f | SimLoss %.3f | Adv %.3f'%(total_loss/(batch_idx+1), reg_simloss/(batch_idx+1), reg_loss/(batch_idx+1)))

    torch.save(model.state_dict(), './saved_model/roclweak-%s-%.4f_phase1.pth'%(atk, eps))
#model.load_state_dict(torch.load('./saved_model/rocl_phase1.pth'))
# Phase 2 setup: freeze every pre-trained parameter, then replace the projection
# head with a fresh 10-way linear classifier. Only `model.linear` is handed to the
# phase-2 optimizer, so without freezing, each backward pass would still compute
# gradients for the whole backbone that are never used or zeroed.
for param in model.parameters():
    param.requires_grad_(False)
# Evaluated after freezing: the right-hand side still reads the old Sequential head
# to recover the backbone feature width; the new layer is trainable by default.
model.linear = nn.Linear(model.linear[0].in_features, 10).cuda()

#del model.linear
##ckpt = torch.load('../RoCL/src/checkpoint/ckpt.t7tmp1Rep_attack_ep_0.0314_alpha_0.007_min_val_0.0_max_val_1.0_max_iters_7_type_linf_randomstart_Truecontrastive_ResNet18_cifar-10_b256_nGPU2_l256_0_epoch_900')['model']
#ckpt = torch.load('../RoCL/src/checkpoint/ckpt.t7tmp3Rep_attack_ep_0.0314_alpha_0.007_min_val_0.0_max_val_1.0_max_iters_7_type_linf_randomstart_Truecontrastive_ResNet18_cifar-10_b256_nGPU1_l256_0_epoch_200')['model']
##ckpt = torch.load('../RoCL/src/checkpoint/ckpt.t7tmptest_Evaluate_linear_eval_ResNet18_cifar-10_0')['model']
#new_dict = {}
#for k,v in ckpt.items():
#    name = k[7:]
#    new_dict[name] = v
#new_dict['normalizer.mean'] = torch.tensor([0,0,0], dtype=torch.float32).cuda()
#new_dict['normalizer.std'] = torch.tensor([1,1,1], dtype=torch.float32).cuda()
#model.load_state_dict(new_dict)
#model.linear = nn.Linear(512,10).cuda()
##ckpt = torch.load('../RoCL/src/checkpoint/ckpt.t7tmptest_Evaluate_linear_eval_ResNet18_cifar-10_0_linear')['model']
##new_dict = {}
##for k,v in ckpt.items():
##    name = k[9:]
##    new_dict[name] = v
##model.linear.load_state_dict(new_dict)



# Linear-evaluation objective: cross-entropy, optimised with SGD over the
# parameters of `model.linear` only.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(
    model.linear.parameters(),
    lr=0.1,
    momentum=0.9,
    weight_decay=5e-4,
)
# Learning rate is scaled by 0.1 at milestones 20 and 30.
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer,
    milestones=[20, 30],
    gamma=0.1,
)

def train(epoch):
    """Run one optimisation pass over ``gt_trainloader``.

    Uses the module-level ``model``, ``criterion`` and ``optimizer``. The model
    is kept in eval mode because only the last layer is being trained.

    Args:
        epoch: Epoch index; used only for the printed header.

    Returns:
        ``(mean_loss, accuracy)``: the loss averaged over the number of
        batches in the loader, and the accuracy in percent over all samples.
    """
    print('\nEpoch: %d' % epoch)
    model.eval()  # deliberate: see docstring (`model.train()` was disabled)

    running_loss = 0
    n_correct = 0
    n_seen = 0
    with tqdm(gt_trainloader) as pbar:
        for step, (x, y) in enumerate(pbar, start=1):
            x = x.to('cuda')
            y = y.to('cuda')

            optimizer.zero_grad()
            logits = model(x)
            loss = criterion(logits, y)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            # Index of the largest logit per row is the predicted class.
            predicted = logits.max(1)[1]
            n_seen += y.size(0)
            n_correct += predicted.eq(y).sum().item()
            pbar.set_description(
                'Loss: %.3f | Acc:%.3f%%' % (running_loss / step, 100. * n_correct / n_seen)
            )

    return running_loss / len(gt_trainloader), 100. * n_correct / n_seen

def test(epoch):
    """Evaluate the module-level ``model`` on ``testloader`` without gradients.

    Args:
        epoch: Accepted for symmetry with ``train``; not used in the body.

    Returns:
        ``(mean_loss, accuracy)``: the loss averaged over the number of
        batches in the loader, and the accuracy in percent over all samples.
    """
    model.eval()

    running_loss = 0
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        with tqdm(testloader) as pbar:
            for step, (x, y) in enumerate(pbar, start=1):
                x = x.to('cuda')
                y = y.to('cuda')

                logits = model(x)
                running_loss += criterion(logits, y).item()

                # Index of the largest logit per row is the predicted class.
                predicted = logits.max(1)[1]
                n_seen += y.size(0)
                n_correct += predicted.eq(y).sum().item()
                pbar.set_description(
                    'Loss: %.3f | Acc:%.3f%%' % (running_loss / step, 100. * n_correct / n_seen)
                )

    return running_loss / len(testloader), 100. * n_correct / n_seen


# Phase 2 driver: 40 epochs of train + test. The scheduler is stepped once per
# epoch, and a checkpoint is written whenever the test accuracy strictly
# improves on the best value seen so far.
best_acc = 0.0
for epoch in range(40):
    train(epoch)
    cur_acc = test(epoch)[1]  # test() returns (mean_loss, accuracy)
    scheduler.step()
    improved = cur_acc > best_acc
    if improved:
        best_acc = cur_acc
        torch.save(
            model.state_dict(),
            './saved_model/roclweak-%s-%.4f.pth'%(atk, eps),
        )
