import argparse
import os, json
import pandas as pd
import torch
import torch.optim as optim
from thop import profile, clever_format
from torch.utils.data import DataLoader, sampler
from tqdm import tqdm
import numpy as np
import utils
from model import Model, ModelMNIST, ModelImageNet


def train(net, data_loader, train_optimizer):
    """Run one epoch of SimCLR contrastive training (NT-Xent style loss).

    Every batch yields ``(pos_1, pos_2, target)``: two views of the same
    images plus a label that is ignored here. Both views go through ``net``;
    the second element of its output is used for the loss. Each of the 2*B
    rows treats its counterpart view as the positive, and the softmax
    denominator runs over the other 2*B-1 rows (the positive included).

    Reads the module-level globals ``temperature``, ``epoch`` and ``epochs``
    (assigned in the ``__main__`` section).

    Returns:
        Mean loss over the epoch, weighted by batch size.
    """
    net.train()
    running_loss = 0.0
    seen = 0
    progress = tqdm(data_loader)
    for view_a, view_b, _ in progress:
        bsz = len(view_a)
        view_a = view_a.cuda(non_blocking=True)
        view_b = view_b.cuda(non_blocking=True)
        _, proj_a = net(view_a)
        _, proj_b = net(view_b)

        # [2*B, D]: both views stacked along the batch axis
        proj = torch.cat([proj_a, proj_b], dim=0)
        # [2*B, 2*B]: exponentiated, temperature-scaled pairwise similarities
        pair_sim = torch.exp(torch.mm(proj, proj.t().contiguous()) / temperature)
        # True everywhere except the diagonal (a row's similarity with itself)
        not_self = torch.eye(2 * bsz, device=pair_sim.device) == 0
        # [2*B, 2*B-1]
        denom_terms = pair_sim.masked_select(not_self).view(2 * bsz, -1)

        # [B]: similarity between the two views of each image ...
        positive = torch.exp(torch.sum(proj_a * proj_b, dim=-1) / temperature)
        # ... repeated so every one of the 2*B rows has its positive: [2*B]
        positive = positive.repeat(2)
        loss = (-torch.log(positive / denom_terms.sum(dim=-1))).mean()

        train_optimizer.zero_grad()
        loss.backward()
        train_optimizer.step()

        seen += bsz
        running_loss += loss.item() * bsz
        progress.set_description('Train Epoch: [{}/{}] Loss: {:.4f}'.format(epoch, epochs, running_loss / seen))

    return running_loss / seen


def test(net, memory_data_loader, test_data_loader):
    """Score the encoder with a similarity-weighted k-nearest-neighbour vote.

    The first output of ``net`` on every sample of ``memory_data_loader``
    forms a labelled bank. Each test sample takes its ``k`` highest
    dot-product neighbours in the bank. Each neighbour's label is weighted by
    ``exp(sim / temperature)``, and the ``c`` classes are ranked by total weight.

    Reads the module-level globals ``k``, ``c``, ``temperature``, ``epoch``
    and ``epochs`` (assigned in the ``__main__`` section).

    Returns:
        ``(top1, top5)`` accuracies in percent.
    """
    net.eval()
    hit1 = 0.0
    hit5 = 0.0
    seen = 0
    with torch.no_grad():
        bank_chunks = []
        for images, _, _labels in tqdm(memory_data_loader, desc='Feature extracting'):
            emb, _ = net(images.cuda(non_blocking=True))
            bank_chunks.append(emb)
        # [D, N]
        bank = torch.cat(bank_chunks, dim=0).t().contiguous()
        # [N]
        bank_labels = torch.tensor(memory_data_loader.dataset.targets, device=bank.device)

        progress = tqdm(test_data_loader)
        for images, _, labels in progress:
            images = images.cuda(non_blocking=True)
            labels = labels.cuda(non_blocking=True)
            emb, _ = net(images)
            n = images.size(0)
            seen += n

            # [B, K]: the k largest similarities and their bank positions
            top_sim, top_idx = torch.mm(emb, bank).topk(k=k, dim=-1)
            # [B, K]: label of each neighbour
            neigh_labels = torch.gather(bank_labels.expand(n, -1), dim=-1, index=top_idx)
            weights = (top_sim / temperature).exp()

            # [B*K, C]: one-hot encoding of the neighbour labels
            votes = torch.zeros(n * k, c, device=neigh_labels.device).scatter(
                dim=-1, index=neigh_labels.view(-1, 1), value=1.0)
            # [B, C]: accumulated weight per class
            class_scores = (votes.view(n, -1, c) * weights.unsqueeze(dim=-1)).sum(dim=1)
            ranking = class_scores.argsort(dim=-1, descending=True)

            truth = labels.unsqueeze(dim=-1)
            hit1 += (ranking[:, :1] == truth).any(dim=-1).float().sum().item()
            hit5 += (ranking[:, :5] == truth).any(dim=-1).float().sum().item()
            progress.set_description('Test Epoch: [{}/{}] Acc@1:{:.2f}% Acc@5:{:.2f}%'
                                     .format(epoch, epochs, hit1 / seen * 100, hit5 / seen * 100))

    return hit1 / seen * 100, hit5 / seen * 100

def test_full(net, memory_data_loader, test_data_loader, k_sim=200):
    """Evaluate the encoder with two nearest-neighbour classifiers at once.

    The first output of ``net`` on every sample of ``memory_data_loader``
    forms a labelled bank. Each test sample is then classified by:

    * a similarity-weighted vote over its ``k_sim`` highest dot-product bank
      neighbours (weights ``exp(sim / temperature)``), giving top-1/top-5;
    * a plain majority vote (``torch.mode``) over its ``k`` nearest bank
      entries under squared L2 distance, giving the KNN accuracy.

    Reads the module-level globals ``temperature``, ``k``, ``c``, ``epoch``
    and ``epochs`` (assigned in the ``__main__`` section). Both neighbour
    counts are clamped to the bank size so ``topk`` cannot fail on a small
    bank; ``k`` must be at least 1.

    Args:
        net: model whose forward returns a pair; only the first element is used.
        memory_data_loader: loader over the labelled bank; its dataset must
            expose ``targets``.
        test_data_loader: loader yielding ``(data, _, target)`` batches.
        k_sim: number of neighbours for the weighted vote.

    Returns:
        ``(weighted top-1 %, weighted top-5 %, L2 majority-vote KNN top-1 %)``.
    """
    net.eval()
    total_knn, total_top1, total_top5, total_num, feature_bank = 0.0, 0.0, 0.0, 0, []
    with torch.no_grad():
        # generate feature bank
        for data, _, target in tqdm(memory_data_loader, desc='Feature extracting'):
            feature, out = net(data.cuda(non_blocking=True))
            feature_bank.append(feature)
        # [D, N]
        feature_bank = torch.cat(feature_bank, dim=0).t().contiguous()
        # [N]
        feature_labels = torch.tensor(memory_data_loader.dataset.targets, device=feature_bank.device)

        num_bank = feature_bank.size(1)
        # topk raises when k exceeds the dimension size, so clamp both counts.
        k_weighted = min(k_sim, num_bank)
        k_knn = min(k, num_bank)
        # ||b||^2 for every bank entry does not change between batches: [1, N]
        bank_sq_norm = feature_bank.pow(2).sum(dim=0, keepdim=True)

        # loop test data to predict the label by weighted knn search
        test_bar = tqdm(test_data_loader)
        for data, _, target in test_bar:
            data, target = data.cuda(non_blocking=True), target.cuda(non_blocking=True)
            feature, out = net(data)
            batch = data.size(0)
            total_num += batch

            # Dot-product similarity with every bank entry (cosine only if the
            # model L2-normalises its features -- not visible here, confirm): [B, N]
            sim_matrix = torch.mm(feature, feature_bank)

            # ---- similarity-weighted kNN vote ----
            # [B, K]
            sim_weight, sim_indices = sim_matrix.topk(k=k_weighted, dim=-1)
            # [B, K]
            sim_labels = torch.gather(feature_labels.expand(batch, -1), dim=-1, index=sim_indices)
            sim_weight = (sim_weight / temperature).exp()
            # [B*K, C] one-hot neighbour labels
            one_hot_label = torch.zeros(batch * k_weighted, c, device=sim_labels.device)
            one_hot_label = one_hot_label.scatter(dim=-1, index=sim_labels.view(-1, 1), value=1.0)
            # weighted score ---> [B, C]
            pred_scores = torch.sum(one_hot_label.view(batch, -1, c) * sim_weight.unsqueeze(dim=-1), dim=1)

            pred_labels = pred_scores.argsort(dim=-1, descending=True)
            total_top1 += torch.sum((pred_labels[:, :1] == target.unsqueeze(dim=-1)).any(dim=-1).float()).item()
            total_top5 += torch.sum((pred_labels[:, :5] == target.unsqueeze(dim=-1)).any(dim=-1).float()).item()

            # ---- L2 majority-vote kNN ----
            # ||x - b||^2 = ||x||^2 - 2 x.b + ||b||^2, reusing x.b from sim_matrix: [B, N]
            dist = feature.pow(2).sum(dim=1, keepdim=True) - 2 * sim_matrix + bank_sq_norm
            # Only the k smallest distances are needed; avoids a full sort over N.
            _, nn_idx = dist.topk(k_knn, dim=1, largest=False)
            # [B, k]
            nn_labels = feature_labels[nn_idx]
            # Reducing with mode over dim=1 yields shape [B] for every k, including
            # k == 1; a [B, 1] prediction would broadcast against target [B] into a
            # [B, B] comparison and overcount hits.
            knn_pred = torch.mode(nn_labels, dim=1).values
            total_knn += torch.sum((knn_pred == target).float()).item()

            test_bar.set_description('Test Epoch: [{}/{}] Acc@1:{:.2f}% Acc@5:{:.2f}%, AccKNN:{:.2f}%'
                                     .format(epoch, epochs, total_top1 / total_num * 100,
                                             total_top5 / total_num * 100, total_knn / total_num * 100))

    return total_top1 / total_num * 100, total_top5 / total_num * 100, total_knn / total_num * 100


def test_KNN_L2(net, memory_data_loader, test_data_loader):
    """Return majority-vote kNN top-1 accuracy (%) under squared L2 distance.

    The first output of ``net`` on every sample of ``memory_data_loader``
    forms a labelled bank. Each test sample is assigned the most frequent
    label (``torch.mode``) among its ``k`` closest bank entries.

    Reads the module-level global ``k`` (assigned in the ``__main__``
    section). ``k`` is clamped to the bank size so ``topk`` cannot fail on a
    small bank, and it must be at least 1.

    Args:
        net: model whose forward returns a pair; only the first element is used.
        memory_data_loader: loader over the labelled bank; its dataset must
            expose ``targets``.
        test_data_loader: loader yielding ``(data, _, target)`` batches.
    """
    net.eval()
    total_top1, total_num, feature_bank = 0.0, 0, []
    with torch.no_grad():
        # generate feature bank
        for data, _, target in tqdm(memory_data_loader, desc='Feature extracting'):
            feature, out = net(data.cuda(non_blocking=True))
            feature_bank.append(feature)
        # [N, D] -- rows are bank samples (no transpose here)
        feature_bank = torch.cat(feature_bank, dim=0)
        # [N]
        feature_labels = torch.tensor(memory_data_loader.dataset.targets, device=feature_bank.device)

        k_knn = min(k, feature_bank.size(0))
        # ||b||^2 for every bank entry does not change between batches: [1, N]
        bank_sq_norm = feature_bank.pow(2).sum(dim=1).unsqueeze(0)

        test_bar = tqdm(test_data_loader)
        for data, _, target in test_bar:
            data, target = data.cuda(non_blocking=True), target.cuda(non_blocking=True)
            feature, out = net(data)
            total_num += data.size(0)

            # ||x - b||^2 = ||x||^2 - 2 x.b + ||b||^2: [B, N]
            dist = (feature.pow(2).sum(dim=1, keepdim=True)
                    - 2 * torch.mm(feature, feature_bank.t())
                    + bank_sq_norm)
            # Only the k smallest distances are needed; avoids a full sort over N.
            _, nn_idx = dist.topk(k_knn, dim=1, largest=False)
            # [B, k]
            nn_labels = feature_labels[nn_idx]
            # mode over dim=1 gives shape [B] for every k (including k == 1), so the
            # comparison with target [B] stays element-wise instead of broadcasting.
            pred = torch.mode(nn_labels, dim=1).values
            total_top1 += torch.sum((pred == target).float()).item()

    return total_top1 / total_num * 100

if __name__ == '__main__':
    # This section deliberately runs at module scope: train(), test(),
    # test_full() and test_KNN_L2() read `temperature`, `k`, `c`, `epoch` and
    # `epochs` as module-level globals assigned below.
    parser = argparse.ArgumentParser(description='Train SimCLR')
    parser.add_argument('--feature_dim', default=300, type=int, help='Feature dim for latter used vector')
    parser.add_argument('--out_dim', default=1024, type=int, help='Out dim for latent vector')
    parser.add_argument('--temperature', default=0.5, type=float, help='Temperature used in softmax')
    parser.add_argument('--k', default=3, type=int, help='Top k most similar images used to predict the label')
    parser.add_argument('--batch_size', default=1024, type=int, help='Number of images in each mini-batch')
    parser.add_argument('--epochs', default=100, type=int, help='Number of sweeps over the dataset to train')
    parser.add_argument('--resume', type=str, default=None)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--lr_milestones', type=str, default='[50, 80]')
    # `choices` rejects unknown datasets up front; otherwise `model` would be
    # left undefined and the script would fail later with a NameError.
    parser.add_argument('--dataset', type=str, default='cifar10', choices=['cifar10', 'mnist', 'imagenet'],
                        help='cifar10/mnist/imagenet')
    # args parse
    args = parser.parse_args()

    feature_dim, out_dim, temperature, k = args.feature_dim, args.out_dim, args.temperature, args.k
    batch_size, epochs = args.batch_size, args.epochs
    # JSON list of epochs at which MultiStepLR multiplies the lr by gamma=0.1
    lr_milestones = json.loads(args.lr_milestones)

    # data prepare: a pair-augmented training loader, plus a "memory" loader over
    # the training split and a test loader (both with the test transform) that
    # the kNN evaluators use.
    if args.dataset == 'cifar10':
        train_data = utils.CIFAR10Pair(root='../data/CIFAR10', train=True, transform=utils.train_transform, download=False)
        train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True, drop_last=False)

        memory_data = utils.CIFAR10Pair(root='../data/CIFAR10', train=True, transform=utils.test_transform, download=False)
        memory_loader = DataLoader(memory_data, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)

        test_data = utils.CIFAR10Pair(root='../data/CIFAR10', train=False, transform=utils.test_transform, download=False)
        test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)

        # model setup
        model = Model(feature_dim=feature_dim, out_dim=out_dim).cuda()
        profile_shape = (1, 3, 32, 32)
    elif args.dataset == 'mnist':
        train_data = utils.MNISTPair(root='../data/MNIST/', train=True, transform=utils.mnist_train_transform, download=True)
        train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True, drop_last=False)

        memory_data = utils.MNISTPair(root='../data/MNIST/', train=True, transform=utils.mnist_test_transform, download=True)
        memory_loader = DataLoader(memory_data, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)

        test_data = utils.MNISTPair(root='../data/MNIST/', train=False, transform=utils.mnist_test_transform, download=True)
        test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)

        # model setup
        model = ModelMNIST(feature_dim).cuda()
        profile_shape = (1, 1, 28, 28)
    elif args.dataset == 'imagenet':
        train_root = os.path.join('../data/ImageNet', 'train')
        val_root = os.path.join('../data/ImageNet', 'val')
        train_data = utils.ImageNetPair(train_root, utils.imagenet_train_transform)
        memory_data = utils.ImageNetPair(train_root, utils.imagenet_test_transform)
        test_data = utils.ImageNetPair(val_root, utils.imagenet_test_transform)
        train_loader = DataLoader(train_data, batch_size=batch_size, num_workers=4, shuffle=True, pin_memory=True)
        memory_loader = DataLoader(memory_data, batch_size=batch_size, num_workers=4, shuffle=False, pin_memory=True)
        test_loader = DataLoader(test_data, batch_size=batch_size, num_workers=4, shuffle=False, pin_memory=True)

        # model setup
        model = ModelImageNet(feature_dim).cuda()
        profile_shape = (1, 3, 224, 224)
    else:
        # Unreachable while `choices` above is in sync; guards against drift.
        parser.error('unsupported dataset: {}'.format(args.dataset))

    # FLOPs / parameter count from one dummy forward pass at the dataset's input size.
    flops, params = profile(model, inputs=(torch.randn(*profile_shape).cuda(),))
    flops, params = clever_format([flops, params])

    save_name_pre = '{}_{}_{}_{}_{}_{}_{}'.format(args.dataset, feature_dim, out_dim, temperature, k, batch_size, epochs)

    if args.resume is not None:
        print("reload checkpoint ...")
        # NOTE: torch.load unpickles the file -- only resume from trusted checkpoints.
        # Only model weights are restored; optimizer/scheduler state restarts.
        model.load_state_dict(torch.load(args.resume))

    print('# Model Params: {} FLOPs: {}'.format(params, flops))
    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-6)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=lr_milestones, gamma=0.1, last_epoch=-1)

    # number of classes; read as a global by the kNN evaluators
    c = len(memory_data.classes)

    # training loop
    results = {'train_loss': [], 'test_acc@1': [], 'sim_acc@1': [], 'sim_acc@5': []}
    os.makedirs('results', exist_ok=True)
    best_acc = 0.0
    for epoch in range(1, epochs + 1):
        train_loss = train(model, train_loader, optimizer)
        results['train_loss'].append(train_loss)
        # sim_acc_*: similarity-weighted kNN top-1/top-5; test_acc_1: L2 majority-vote kNN
        sim_acc_1, sim_acc_5, test_acc_1 = test_full(model, memory_loader, test_loader)
        results['test_acc@1'].append(test_acc_1)
        results['sim_acc@1'].append(sim_acc_1)
        results['sim_acc@5'].append(sim_acc_5)
        print("TrainLoss: {}, TestAcc: {}, SimAcc@1: {}, SimAcc@5: {}".format(train_loss, test_acc_1, sim_acc_1, sim_acc_5))
        # save statistics (rewritten every epoch so partial runs are kept)
        data_frame = pd.DataFrame(data=results, index=range(1, epoch + 1))
        data_frame.to_csv('results/{}_statistics.csv'.format(save_name_pre), index_label='epoch')
        # checkpoint only when the L2 kNN accuracy improves
        if test_acc_1 > best_acc:
            best_acc = test_acc_1
            torch.save(model.state_dict(), 'results/{}_model.pth'.format(save_name_pre))

        scheduler.step()
