from __future__ import print_function
import argparse
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
import torch.nn.functional as F
import random
import sys
import PIL
from PIL import Image
import json
import torch
import torchvision
import torchvision.transforms as T
from torchvision import transforms
from smooth_rank_ap import *
from tqdm import tqdm
import matplotlib.pyplot as plt
import math
from pytorch_metric_learning import distances, losses, miners, reducers, testers

##### My imports #####
from HierarchicalSampling import *
######################

# Target device for model inputs and label tensors; the script assumes a CUDA
# GPU is available (several places below also call ``.cuda()`` directly).
device = torch.device("cuda")

def iNat_get_xform(augmentation):
    """Build the image preprocessing pipeline used for the iNat dataset.

    Args:
        augmentation: Name of the pipeline to build:
            ``'none'``     - resize to 227x227;
            ``'bigtrain'`` - random resized crop to 224 plus horizontal flip;
            ``'bigtest'``  - ``Resize(256)`` followed by a 224 center crop.

    Returns:
        A ``transforms.Compose`` that ends with ``ToTensor`` and channel-wise
        normalization.

    Raises:
        NotImplementedError: If ``augmentation`` is not one of the names above.
    """
    if augmentation not in ('none', 'bigtrain', 'bigtest'):
        # NotImplemented is a constant, not an exception; raising it would
        # surface as an unrelated TypeError.
        raise NotImplementedError(
            'unknown augmentation: {!r}'.format(augmentation))

    normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    if augmentation == 'none':
        return transforms.Compose([
            transforms.Resize((227, 227)),
            transforms.ToTensor(),
            normalize])
    if augmentation == 'bigtrain':
        return transforms.Compose([
            transforms.RandomResizedCrop(size=224, scale=[0.16, 1], ratio=[0.75, 1.33]),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ToTensor(),
            normalize])
    # augmentation == 'bigtest'
    return transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize])

# cars and CUB
def CUB_get_xform(augmentation):
    """Build the image preprocessing pipeline shared by CUB and Cars.

    Args:
        augmentation: Name of the pipeline to build:
            ``'none'``     - resize to 227x227;
            ``'bigtrain'`` - random resized crop to 256 plus horizontal flip;
            ``'bigtest'``  - ``Resize(288)`` followed by a 256 center crop.

    Returns:
        A ``transforms.Compose`` that ends with ``ToTensor`` and channel-wise
        normalization.

    Raises:
        NotImplementedError: If ``augmentation`` is not one of the names above.
    """
    if augmentation not in ('none', 'bigtrain', 'bigtest'):
        # NotImplemented is a constant, not an exception; raising it would
        # surface as an unrelated TypeError.
        raise NotImplementedError(
            'unknown augmentation: {!r}'.format(augmentation))

    normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    if augmentation == 'none':
        return transforms.Compose([
            transforms.Resize((227, 227)),
            transforms.ToTensor(),
            normalize])
    if augmentation == 'bigtrain':
        return transforms.Compose([
            transforms.RandomResizedCrop(size=256, scale=[0.16, 1], ratio=[0.75, 1.33]),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ToTensor(),
            normalize])
    # augmentation == 'bigtest'
    return transforms.Compose([
        transforms.Resize(288),
        transforms.CenterCrop(256),
        transforms.ToTensor(),
        normalize])

def SOP_get_xform(augmentation):
    """Build the image preprocessing pipeline used for the SOP dataset.

    Args:
        augmentation: Name of the pipeline to build:
            ``'none'``     - resize to 227x227;
            ``'bigtrain'`` - ``Resize(288)``, random resized crop to 256 and
                             horizontal flip;
            ``'bigtest'``  - resize to 288x288 followed by a 256x256 center
                             crop.

    Returns:
        A ``transforms.Compose`` that ends with ``ToTensor`` and channel-wise
        normalization.

    Raises:
        NotImplementedError: If ``augmentation`` is not one of the names above.
    """
    if augmentation not in ('none', 'bigtrain', 'bigtest'):
        # NotImplemented is a constant, not an exception; raising it would
        # surface as an unrelated TypeError.
        raise NotImplementedError(
            'unknown augmentation: {!r}'.format(augmentation))

    normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    if augmentation == 'none':
        return transforms.Compose([
            transforms.Resize((227, 227)),
            transforms.ToTensor(),
            normalize])
    if augmentation == 'bigtrain':
        return transforms.Compose([
            transforms.Resize(288),
            transforms.RandomResizedCrop(size=256, scale=[0.16, 1], ratio=[0.75, 1.33]),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ToTensor(),
            normalize])
    # augmentation == 'bigtest'
    return transforms.Compose([
        transforms.Resize((288, 288)),
        transforms.CenterCrop((256, 256)),
        transforms.ToTensor(),
        normalize])

def test(G, test_loader):
    """Compute Recall@1 of the embedding model ``G`` over ``test_loader``.

    Every sample is embedded, its nearest neighbor (largest inner product,
    excluding the sample itself) is looked up among all other samples, and
    the prediction counts as correct when both share the same label.

    Args:
        G: Embedding model; called on image batches moved to ``device`` and
            expected to return ``args.embedding_size``-dimensional rows.
        test_loader: Data loader yielding ``(img, target)`` batches; must
            expose ``.dataset`` with a length.

    Returns:
        Fraction of samples whose nearest neighbor has the same label.
    """
    torch.set_num_threads(4)
    G.eval()

    n_samples = len(test_loader.dataset)
    features = torch.zeros((n_samples, args.embedding_size))
    labels = torch.zeros(n_samples)
    filled = 0
    with torch.no_grad():
        for img, target in test_loader:
            feature = G(img.to(device)).cpu().detach()
            assert feature.shape[1] == features.shape[1]
            batch = feature.shape[0]
            features[filled:filled + batch] = feature
            labels[filled:filled + batch] = target.view(-1)
            filled += batch

    assert filled == n_samples

    # The full similarity matrix does not fit on the GPU for large test sets,
    # so it is computed in (at most) n_splits row chunks.
    n_splits = 24
    chunk_size = max(1, (n_samples + n_splits - 1) // n_splits)
    predictions = []
    f_T = features.T.to(device)
    for start in tqdm(range(0, n_samples, chunk_size)):
        sm = features[start:start + chunk_size].to(device) @ f_T
        # Exclude each sample from its own neighbor search in one indexed
        # write instead of one GPU write per row.
        rows = torch.arange(sm.shape[0], device=sm.device)
        sm[rows, start + rows] = -1.
        _, indices = sm.max(1)
        predictions.append(indices.cpu())
        # Free the chunk before the next one is allocated.
        del indices, sm

    del f_T
    predictions = torch.cat(predictions)
    print(predictions.shape)
    # Sanity check: no sample may be retrieved as its own neighbor.
    assert not (predictions == torch.arange(n_samples)).any()

    recall_at_1 = (labels == labels[predictions]).sum().item() / labels.shape[0]
    print('Test R @ 1: ', recall_at_1)
    return recall_at_1

class CUB_Network(nn.Module):
    """ResNet-50 embedding network with a max-pooled, standardized head.

    The pretrained ResNet-50 has its global pooling replaced by adaptive max
    pooling and its classifier removed. The 2048-d pooled feature is passed
    through a non-affine ``LayerNorm`` and a linear projection, then
    L2-normalized.

    Args:
        embedding_size: Dimensionality of the output embedding.
        device: Device on which all submodules are placed.
    """

    def __init__(self, embedding_size=512, device='cuda'):
        super().__init__()
        backbone = torchvision.models.resnet50(pretrained=True)
        backbone.avgpool = nn.AdaptiveMaxPool2d((1, 1))
        backbone.fc = nn.Identity()
        # Honour ``device`` for the backbone too, instead of a hard-coded
        # ``.cuda()``, so every submodule ends up on the same device.
        self.backbone = backbone.to(device)
        self.standardize = nn.LayerNorm(2048, elementwise_affine=False).to(device)
        self.remap = nn.Linear(2048, embedding_size, bias=True).to(device)

    def forward(self, x):
        """Return L2-normalized embeddings for the image batch ``x``."""
        with torch.cuda.amp.autocast(enabled=True):
            return F.normalize(self.remap(self.standardize(self.backbone(x))))
        
class SOP_Network(nn.Module):
    """ResNet-50 embedding network with a linear projection head.

    The pretrained ResNet-50 keeps its default global pooling, has its
    classifier removed, and the 2048-d feature is linearly projected and
    L2-normalized.

    Args:
        embedding_size: Dimensionality of the output embedding.
        device: Device on which all submodules are placed.
    """

    def __init__(self, embedding_size=512, device='cuda'):
        super().__init__()
        backbone = torchvision.models.resnet50(pretrained=True)
        backbone.fc = nn.Identity()
        # Honour ``device`` for the backbone too, instead of a hard-coded
        # ``.cuda()``, so every submodule ends up on the same device.
        self.backbone = backbone.to(device)
        self.remap = nn.Linear(2048, embedding_size, bias=True).to(device)

    def forward(self, x):
        """Return L2-normalized embeddings for the image batch ``x``."""
        with torch.cuda.amp.autocast(enabled=True):
            return F.normalize(self.remap(self.backbone(x)))

def get_jaccard_similarity(w, scores, k, eps):
    """Compute a k-neighborhood mask and a context (Jaccard-style) similarity.

    Args:
        w: Pairwise weight matrix, multiplied elementwise into ``W_tilde``
            when ``args.ww`` is 1 or 3. The training loop passes the
            same-label indicator matrix here.
        scores: Pairwise similarity matrix; treated as inner products of
            L2-normalized features (see the comments below).
        k: Number of nearest neighbors that defines a neighborhood.
        eps: Slack added to ``-D`` before thresholding the masks that feed
            the context similarity.

    Returns:
        A pair ``(Nk_mask1, sim)``. ``Nk_mask1`` is the k-neighborhood mask
        thresholded with slack ``args.Nk_mask_eps``. ``sim`` is ``W_tilde``
        when ``args.ww == 3``; otherwise it is the symmetrized Jaccard
        similarity.
    """
    # features has been normalized before this function
    # D is ||u-v||^2 = 2 - 2<u,v>
    D = 2. - 2. * scores  # Euclidean, assume features are L2 normalized
    D = D.clamp(min=0.)   # for stability
    vk, ik = (-D).topk(k)                  # get k closest neighbors (ik is unused)

    # Row-wise indicator of "-D (+ slack) >= k-th largest -D of that row".
    # GreaterThan returns a hard 0/1 mask in the forward pass and a constant
    # surrogate gradient in the backward pass; the threshold itself is detached.
    Nk_mask1 = GreaterThan.apply(-D + args.Nk_mask_eps, vk[:,-1:].detach())
    Nk_mask2 = GreaterThan.apply(-D + eps, vk[:,-1:].detach())
    # Entry (i, j): number of columns flagged in both row i and row j,
    # divided by the (detached) number of columns flagged in row i.
    intersection = (Nk_mask2 @ Nk_mask2.T) / Nk_mask2.sum(1, keepdim=True).detach()
    # Same statistic for the complement mask (columns flagged in neither row).
    Nk_mask_not = 1. - Nk_mask2
    intersection_not = (Nk_mask_not @ Nk_mask_not.T) / Nk_mask_not.sum(1, keepdim=True).detach()
    
    # args.ww selects what the averaged overlap is gated by:
    #   1 or 3 -> the supplied weights ``w``
    #   2      -> the neighborhood mask with gradients blocked
    #   other  -> the neighborhood mask with gradients flowing
    if args.ww == 1 or args.ww == 3:
        W_tilde = 0.5 * (intersection + intersection_not) * w
    elif args.ww == 2:
        W_tilde = 0.5 * (intersection + intersection_not) * Nk_mask2.detach()
    else:
        W_tilde = 0.5 * (intersection + intersection_not) * Nk_mask2
        
    # Variant 3 returns W_tilde directly and skips the aggregation below.
    if args.ww == 3:
        return Nk_mask1, W_tilde

    # Mask for the (k // 2)-neighborhood; multiplying by its transpose keeps
    # only pairs that are in each other's neighborhood.
    Nk_over_2_mask = GreaterThan.apply(-D + eps, vk[:,(k // 2)-1:(k // 2)].detach())
    Rk_over_2_mask = Nk_over_2_mask * Nk_over_2_mask.T
    
    # Average the rows of W_tilde over each sample's mutual neighbors.
    # NOTE(review): this denominator is not detached, unlike the ones above —
    # confirm that is intentional.
    jaccard_similarity = (Rk_over_2_mask @ W_tilde) / Rk_over_2_mask.sum(1, keepdim=True)

    # Symmetrize so that sim[i, j] == sim[j, i].
    return Nk_mask1 , 0.5 * (jaccard_similarity + jaccard_similarity.T)
    
def contrastive_jaccard(scores, w):
    """Return the neighborhood, context and regularization loss terms.

    Args:
        scores: Pairwise similarity matrix of the batch.
        w: Pairwise 0/1 matrix marking positive pairs.

    Returns:
        ``(loss_nk, loss_context, reg_loss)``: the contrastive loss on the
        neighborhood mask, the contrastive loss on the context similarity,
        and the squared gap between the mean score and ``args.reg_sim``.
    """
    reg_loss = (scores.mean() - args.reg_sim).square()

    nk_sim, context_sim = get_jaccard_similarity(w, scores, k=4, eps=args.eps)
    n = nk_sim.shape[0]
    # Zero on the diagonal so a sample is never paired with itself.
    off_diagonal = torch.ones_like(nk_sim) - torch.eye(n, device=nk_sim.device)

    def _pairwise_loss(sim):
        # Positives are pulled towards 1, negatives towards 0; the sum is
        # averaged over all entries of the square matrix.
        positive_term = (1. - sim).square() * w * off_diagonal
        negative_term = sim.square() * (1. - w) * off_diagonal
        size = sim.shape[0]
        return (positive_term.sum() + negative_term.sum()) / size / size

    loss_nk = _pairwise_loss(nk_sim)
    loss_context = _pairwise_loss(context_sim)

    return loss_nk, loss_context, reg_loss

def margin_contrastive(scores, w, pos_margins=0.9, neg_margins=0.6):
    """Margin-based contrastive loss averaged over the violating pairs.

    Positive pairs are penalized when their score is below ``pos_margins``;
    negative pairs when their score is above ``neg_margins``. Each term is
    averaged over the pairs with a non-zero penalty only.

    Args:
        scores: Pairwise similarity matrix.
        w: Pairwise 0/1 matrix marking positive pairs (same shape as
            ``scores``).
        pos_margins: Scalar or column vector of positive margins.
        neg_margins: Scalar or column vector of negative margins.

    Returns:
        Scalar tensor on the device of ``scores`` and attached to its
        autograd graph, including when no pair violates a margin (the value
        is then 0 with zero gradients).
    """
    L_pos = F.relu(pos_margins - scores) * w
    L_neg = F.relu(scores - neg_margins) * (1. - w)

    # Clamping the count to >= 1 yields exactly 0 when nothing is active (the
    # sum is 0 as well) while keeping the result on the right device and in
    # the graph. A fresh ``torch.tensor(0.)`` would be a detached CPU tensor
    # and break ``backward()`` when this is the only loss term.
    n_pos = (L_pos > 0.).sum().clamp(min=1)
    n_neg = (L_neg > 0.).sum().clamp(min=1)

    return L_pos.sum() / n_pos + L_neg.sum() / n_neg
    
class GreaterThan(torch.autograd.Function):
    """Hard ``x >= y`` indicator with a constant surrogate gradient.

    Forward returns ``(x >= y)`` as a float tensor (broadcasting applies).
    Backward ignores the true, almost-everywhere-zero derivative and passes
    ``+10 * grad`` to ``x`` and ``-10 * grad`` to ``y``, reduced to each
    input's own shape when it was broadcast.
    """

    @staticmethod
    def forward(ctx, x, y):
        # Only the shapes are needed in backward; keeping the tensors alive
        # via save_for_backward would just hold on to memory.
        ctx.x_shape = x.shape
        ctx.y_shape = y.shape
        return (x >= y).float()

    @staticmethod
    def backward(ctx, grad_output):
        g = 10.
        grad_x = grad_y = None
        if ctx.needs_input_grad[0]:
            # sum_to_size undoes broadcasting so the gradient matches x.
            grad_x = (grad_output * g).sum_to_size(ctx.x_shape)
        if ctx.needs_input_grad[1]:
            grad_y = (- grad_output * g).sum_to_size(ctx.y_shape)
        return grad_x, grad_y
    
#################################################################################### 
# Command-line interface. Options with a closed set of values are validated
# by argparse so that typos fail at start-up rather than deep inside training.
parser = argparse.ArgumentParser(
    description='Train and evaluate an image-retrieval embedding network.',
    formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--schedule', type=str, default='', choices=['', 'faster'],
                    help="learning-rate schedule: '' decays at epochs 15/30/45, "
                         "'faster' decays at epochs 10/20/30/40")
parser.add_argument('--trainxform', type=str, default='bigtrain',
                    choices=['none', 'bigtrain', 'bigtest'],
                    help='augmentation pipeline applied to training images')
parser.add_argument('--root', type=str, default='',
                    help='root directory passed to the dataset classes')
parser.add_argument('--lam', type=float, default=0.5,
                    help="weight of the SupAP term for --loss roadmap; the "
                         "contrastive term gets 1 - lam")
parser.add_argument('--testxform', type=str, default='bigtest',
                    choices=['none', 'bigtrain', 'bigtest'],
                    help='preprocessing pipeline applied to test images')
parser.add_argument('--loss', type=str, default='hybrid',
                    choices=['hybrid', 'roadmap', 'triplet', 'multisimilarity',
                             'smoothap', 'fastap', 'contrastive', 'ntxent'],
                    help='training loss')
parser.add_argument('--embedding_size', type=int, default=512,
                    help='dimensionality of the output embedding')
parser.add_argument('--batch_size', type=int, default=128,
                    help='training batch size (must be 256 for --dataset iNat)')
parser.add_argument('--n_epochs', type=int, default=80,
                    help='epoch index at which training stops')
parser.add_argument('--test_freq', type=int, default=5,
                    help='evaluate on the test set every this many epochs')
parser.add_argument('--lr', type=float, default=0.00001,
                    help='learning rate of the projection layer; the backbone '
                         'uses half of it')
parser.add_argument('--weight_decay', type=float, default=0.0001,
                    help='Adam weight decay for both optimizers')
parser.add_argument('--eps', type=float, default=0.05,
                    help='slack added when thresholding the neighborhood masks '
                         'of the context similarity')
parser.add_argument('--Nk_mask_eps', type=float, default=1e-5,
                    help='slack added when thresholding the k-neighborhood mask '
                         'used by the neighborhood loss')
parser.add_argument('--start_epoch', type=int, default=0,
                    help='epoch to resume from; values above 0 require '
                         '--checkpoint')
parser.add_argument('--checkpoint', type=str, default='',
                    help='path of a checkpoint written by this script')
parser.add_argument('--dataset', type=str, default='',
                    choices=['SOP', 'CUB', 'Cars', 'iNat'],
                    help='dataset to train and evaluate on')
parser.add_argument('--a', type=float, default=0.0,
                    help='hybrid loss: weight of the neighborhood term')
parser.add_argument('--b', type=float, default=0.4,
                    help='hybrid loss: weight of the context term')
parser.add_argument('--c', type=float, default=0.1,
                    help='hybrid loss: weight of the mean-similarity '
                         'regularizer; the contrastive term gets 1 - a - b - c')
parser.add_argument('--neg_margins', type=float, default=0.6,
                    help='hybrid loss: negative margin of the contrastive term')
parser.add_argument('--pos_margins', type=float, default=0.75,
                    help='hybrid loss: positive margin of the contrastive term')
parser.add_argument('--ww', type=int, default=0,
                    help='variant of the context weighting: 1 and 3 gate by the '
                         'label matrix, 2 by the detached neighborhood mask, '
                         'anything else by the neighborhood mask; 3 also skips '
                         'the mutual-neighbor aggregation')
parser.add_argument('--reg_sim', type=float, default=0.25,
                    help='target mean pairwise similarity of the regularizer')  # try 0.25 and 0.3

args = parser.parse_args()
print(args)
#################################################################################### 

#### For ROADMAP ####
# Loss objects for the different --loss options. All of them are constructed
# up front; the training loop picks the one matching args.loss for each batch.
from smooth_rank_ap import SupAP, SmoothAP
criterion = SupAP()                                          # --loss roadmap (combined with the contrastive term)
smoothAP_criterion = SmoothAP()                              # --loss smoothap
miner = miners.DistanceWeightedMiner()                       # supplies the mined tuples for --loss triplet
triplet_criterion = losses.TripletMarginLoss(margin=0.05)    # --loss triplet
ms_criterion = losses.MultiSimilarityLoss()                  # --loss multisimilarity
fastap_criterion = losses.FastAPLoss()                       # --loss fastap
ntxent_criterion = losses.NTXentLoss()                       # --loss ntxent

# Build the train/test datasets and their transforms for the chosen dataset.
# The dataset classes are not defined in this file; presumably they come from
# one of the star imports at the top.
root = args.root
if args.dataset == 'SOP':
    testxform = SOP_get_xform(augmentation=args.testxform)
    trainxform = SOP_get_xform(augmentation=args.trainxform)
    trainset = SOPDataset(root, 'train', transform=trainxform)
    testset = SOPDataset(root, 'test', transform=testxform)
elif args.dataset in ('CUB', 'Cars'):
    # CUB and Cars share the transforms and the image-list dataset class;
    # only the list files differ.
    train_txt, test_txt = {
        'CUB': ('cub200_train.txt', 'cub200_test.txt'),
        'Cars': ('cars196_train.txt', 'cars196_test.txt'),
    }[args.dataset]
    testxform = CUB_get_xform(augmentation=args.testxform)
    trainxform = CUB_get_xform(augmentation=args.trainxform)
    trainset = Imagelist(image_list=train_txt, root=root, transform=trainxform)
    testset = Imagelist(image_list=test_txt, root=root, transform=testxform)
elif args.dataset == 'iNat':
    # Explicit error instead of ``assert`` so the check survives ``python -O``.
    if args.batch_size != 256:
        raise ValueError(
            '--dataset iNat requires --batch_size 256, got {}'.format(args.batch_size))
    test_txt = 'iNaturalist_test.txt'
    train_txt = 'iNaturalist_train.txt'
    testxform = iNat_get_xform(augmentation=args.testxform)
    trainxform = iNat_get_xform(augmentation=args.trainxform)
    trainset = Imagelist_iNat(image_list=train_txt, root=root, transform=trainxform)
    testset = Imagelist_iNat(image_list=test_txt, root=root, transform=testxform)
else:
    raise ValueError(
        "unknown --dataset {!r}; expected one of 'SOP', 'CUB', 'Cars', 'iNat'".format(args.dataset))

# SOP and iNat use the plain ResNet-50 head; every other dataset uses the
# max-pooled, standardized variant.
network_cls = SOP_Network if args.dataset in ('SOP', 'iNat') else CUB_Network
model = torch.nn.DataParallel(network_cls(embedding_size=args.embedding_size))

scaler = torch.cuda.amp.GradScaler()
# Explicit errors instead of ``assert`` so the checks survive ``python -O``.
if args.start_epoch > 0 and args.checkpoint == '':
    raise ValueError('--start_epoch > 0 requires --checkpoint')
if args.checkpoint != '':
    print('loading checkpoint {} and starting on epoch {}'.format(args.checkpoint, args.start_epoch))
    # SECURITY: the checkpoint is a pickled (module, loss_list, scaler) tuple
    # and torch.load unpickles arbitrary objects. Only load trusted files.
    model.module, _, scaler = torch.load(args.checkpoint)

loss_list = []
model_save_path = '{}_{}_lam_{}_{}.pt'.format(args.dataset, args.loss, args.lam, random.random()*10000000)
print('model will be saved in: ', model_save_path)

# The backbone is trained with half the learning rate of the projection layer.
G_lr = args.lr / 2.
W_lr = args.lr

optimizer1 = optim.Adam(model.module.backbone.parameters(), lr=G_lr, weight_decay=args.weight_decay)
optimizer2 = optim.Adam(model.module.remap.parameters(), lr=W_lr, weight_decay=args.weight_decay)

if args.schedule == 'faster':
    milestones, gamma = [10, 20, 30, 40], 0.33
elif args.schedule == '':
    milestones, gamma = [15, 30, 45], 0.3
else:
    raise ValueError("unknown --schedule {!r}; expected '' or 'faster'".format(args.schedule))
scheduler1 = optim.lr_scheduler.MultiStepLR(optimizer1, milestones=milestones, gamma=gamma)
scheduler2 = optim.lr_scheduler.MultiStepLR(optimizer2, milestones=milestones, gamma=gamma)

# When resuming, replay the schedulers so the learning rates match the ones
# that were active at ``start_epoch``.
for _ in range(args.start_epoch):
    scheduler1.step()
    scheduler2.step()
#################################################################################### 
# Main loop: train for one epoch, step the schedulers, and every
# ``args.test_freq`` epochs evaluate and keep the best checkpoint.
best_acc = 0.

for epoch in range(args.start_epoch, args.n_epochs):
    print('epoch {} G_lr {} W_lr {}'.format(epoch, optimizer1.param_groups[0]['lr'], optimizer2.param_groups[0]['lr']))

    # The backbone is kept in eval mode while it is being fine-tuned.
    model.module.backbone.eval()

    # A fresh batch sampler is built at the start of every epoch.
    if args.dataset in ('SOP', 'iNat'):
        hierarchical_sampler = (HierarchicalSamplerBalanced if args.dataset == 'SOP'
                                else HierarchicalSamplerBalancediNat)
        trainsampler = hierarchical_sampler(
            dataset=trainset,
            batch_size=args.batch_size,
            samples_per_class=4,
            batches_per_super_pair=10,
            nb_categories=2,
        )
    else:
        trainsampler = RandomSampler(
            dataset=trainset,
            batch_size=args.batch_size,
            samples_per_class=4,
            num_batches=42,
            tries=5 if args.dataset == 'CUB' else 10,
        )

    trainloader = torch.utils.data.DataLoader(trainset, batch_sampler=trainsampler, num_workers=16, pin_memory=True)

    for x, y in tqdm(trainloader):
        with torch.cuda.amp.autocast(enabled=True):
            f = model(x.to(device))

        # The losses are computed in full precision.
        with torch.cuda.amp.autocast(enabled=False):
            y = y.reshape(-1).to(device)
            # w[i, j] = 1 when samples i and j share a label. Comparing the
            # labels directly avoids the precision loss of a float cast.
            w = (y.unsqueeze(1) == y.unsqueeze(0)).float()
            scores = f.float() @ f.float().T

            if args.loss == 'roadmap':
                loss_supap = criterion(scores, w)
                loss_con = margin_contrastive(scores, w, pos_margins=0.9, neg_margins=0.6)
                loss = args.lam * loss_supap + (1. - args.lam) * loss_con
            elif args.loss == 'triplet':
                miner_output = miner(f, y)
                loss = triplet_criterion(f, y, miner_output)
            elif args.loss == 'multisimilarity':
                loss = ms_criterion(f, y)
            elif args.loss == 'smoothap':
                loss = smoothAP_criterion(scores, w)
            elif args.loss == 'fastap':
                loss = fastap_criterion(f, y)
            elif args.loss == 'contrastive':
                loss = margin_contrastive(scores, w, pos_margins=0.9, neg_margins=0.6)
            elif args.loss == 'ntxent':
                loss = ntxent_criterion(f, y)
            elif args.loss == 'hybrid':
                # The Jaccard terms are only needed here, so they are no
                # longer computed (and discarded) for every other loss.
                loss_nk, loss_context, reg_loss = contrastive_jaccard(scores, w)
                loss_contrast = margin_contrastive(scores, w, pos_margins=args.pos_margins, neg_margins=args.neg_margins)
                loss = args.a * loss_nk + args.b * loss_context + args.c * reg_loss + (1. - args.a - args.b - args.c) * loss_contrast
            else:
                raise ValueError('unknown --loss {!r}'.format(args.loss))

            loss_list.append(loss.item())
        optimizer1.zero_grad()
        optimizer2.zero_grad()
        # One shared GradScaler drives both optimizers: scale once, step
        # each optimizer, then update the scale factor once.
        scaler.scale(loss).backward()
        scaler.step(optimizer1)
        scaler.step(optimizer2)
        scaler.update()

    # Drop references to the last batch before evaluation.
    del trainloader, f, y, w
    scheduler1.step()
    scheduler2.step()

    if (epoch + 1) % args.test_freq == 0:

        print('Testing there are {} samples in testset'.format(len(testset)))
        testloader = torch.utils.data.DataLoader(testset, batch_size=128, num_workers=16, shuffle=False, drop_last=False, pin_memory=True)
        test_acc = test(model, testloader)
        if test_acc > best_acc:
            print('beat test acc, saving model ... ')
            # Saved as a (module, loss history, scaler) tuple, the layout
            # the --checkpoint loading code unpacks.
            torch.save((model.module, loss_list, scaler), model_save_path)
            best_acc = test_acc
        del testloader
