from datetime import datetime
from tqdm import tqdm
import argparse
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial
import torchvision.models as models
from moco.loader import load_data
from moco.builder import ModelBase,SplitBatchNorm
import random
import torch.backends.cudnn as cudnn
import os
from utils import setup_logger, get_rank, accuracy, AverageMeter, copy_script
# Command-line options for the kNN evaluation script. Option names, dests,
# types and defaults are part of the script's interface; only the help text
# and metavars describe them.
parser = argparse.ArgumentParser(
    description='Evaluate a CLEAN pre-trained encoder on CIFAR-10 with a kNN monitor')


## dataset
parser.add_argument('--arch', default='resnet50',
                    help='name of the torchvision model constructor used as backbone')
parser.add_argument('--dataset-name', default='cifar10', type=str, metavar='NAME',
                    help='name of the dataset (default: cifar10)')
parser.add_argument('--data-path', default='/export/home/dataset/CIFAR10', type=str, metavar='PATH',
                    help='path to the dataset')
parser.add_argument('--seed', default=None, type=int, metavar='N',
                    help='random seed (default: drawn at random from [0, 10000])')
parser.add_argument('--bn-splits', default=8, type=int,
                    help='simulate multi-gpu behavior of BatchNorm in one gpu; 1 is SyncBatchNorm in multi-gpu')
parser.add_argument('--dim', default=2048, type=int, help='feature dim of encoder')
# `--num_classes` is kept for backward compatibility; `--num-classes` matches
# the dash style of every other flag. Both write to `args.num_classes`.
parser.add_argument('--num-classes', '--num_classes', dest='num_classes', default=10, type=int,
                    help='number of classes')


# CLEAN specific configs:
parser.add_argument('--resume', default=None, type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('--results-dir', default='./pretrained_models', type=str, metavar='PATH',
                    help='output directory for results (default: ./pretrained_models)')
parser.add_argument('--pretrained', default='./pretrained_models/model.pth', type=str, metavar='PATH',
                    help='path to the pre-trained checkpoint (default: ./pretrained_models/model.pth)')


# knn monitor
parser.add_argument('--batch-size', default=256, type=int, metavar='N', help='mini-batch size')
parser.add_argument('--knn-k', default=50, type=int, help='k in kNN monitor')
parser.add_argument('--knn-t', default=0.05, type=float,
                    help='softmax temperature in kNN monitor; may differ from moco-t')

#CUDA_VISIBLE_DEVICES=0 python main_knn.py
# test using a knn monitor
def test_KNN(net, memory_data_loader, test_data_loader, epoch, args):
    """Score `net` with a weighted kNN classifier over its own features.

    Every batch of `memory_data_loader` is embedded to build a feature bank,
    then each test batch is classified by `knn_predict` against that bank.

    Args:
        net: feature extractor; its output is flattened to one vector per sample.
        memory_data_loader: loader over the reference set. Its dataset must
            expose `classes` and `targets`. Labels are read from
            `dataset.targets`, so the loader is assumed to yield samples in
            dataset order (no shuffling, nothing dropped).
        test_data_loader: loader yielding `(data, target)` batches to classify.
        epoch: unused; kept for interface compatibility.
        args: namespace providing `knn_k` and `knn_t`.

    Returns:
        Tuple `(top1, top5)` of accuracies in percent.
    """
    net.eval()
    num_classes = len(memory_data_loader.dataset.classes)

    def embed(images):
        # One L2-normalised row per sample, so dot products are cosine similarities.
        return F.normalize(torch.flatten(net(images), 1), dim=1)

    correct_top1 = 0.0
    correct_top5 = 0.0
    seen = 0
    with torch.no_grad():
        # Build the feature bank; the bank labels come from the dataset, not the batches.
        bank_chunks = [
            embed(images.cuda(non_blocking=True))
            for images, _ in tqdm(memory_data_loader, desc='Feature extracting')
        ]
        # Stored transposed so a single matmul yields query-vs-bank similarities.
        feature_bank = torch.cat(bank_chunks, dim=0).t().contiguous()
        feature_labels = torch.tensor(memory_data_loader.dataset.targets, device=feature_bank.device)

        # Classify the test set batch by batch, updating the running accuracy.
        progress = tqdm(test_data_loader)
        for images, labels in progress:
            images = images.cuda(non_blocking=True)
            labels = labels.cuda(non_blocking=True)
            ranked = knn_predict(embed(images), feature_bank, feature_labels,
                                 num_classes, args.knn_k, args.knn_t)

            seen += images.size(0)
            top1_hits = ranked[:, 0] == labels
            top5_hits = (ranked[:, :5] == labels.unsqueeze(dim=-1)).any(dim=-1)
            correct_top1 += top1_hits.float().sum().item()
            correct_top5 += torch.sum(top5_hits.float()).item()
            progress.set_description('Test Acc@1:{:.2f}%'.format(correct_top1 / seen * 100))

    return correct_top1 / seen * 100, correct_top5 / seen * 100


def knn_predict(feature, feature_bank, feature_labels, classes, knn_k, knn_t):
    """Rank classes for each query by a similarity-weighted kNN vote.

    Args:
        feature: query features, shape [B, D].
        feature_bank: reference features stored transposed, shape [D, N].
        feature_labels: integer class index of each bank entry, shape [N].
        classes: number of classes C.
        knn_k: number of nearest neighbours to vote; clamped to N when the
            bank holds fewer entries.
        knn_t: temperature; each neighbour votes with weight exp(sim / knn_t).

    Returns:
        Integer tensor of shape [B, C]; row i lists class indices from highest
        to lowest score, so column 0 is the predicted label.
    """
    batch_size = feature.size(0)
    # topk raises if asked for more neighbours than the bank contains.
    k = min(knn_k, feature_bank.size(1))

    # similarity between each query and every bank entry ---> [B, N]
    sim_matrix = torch.mm(feature, feature_bank)
    # k most similar bank entries per query ---> [B, K]
    sim_weight, sim_indices = sim_matrix.topk(k=k, dim=-1)
    # labels of those neighbours ---> [B, K]
    sim_labels = feature_labels[sim_indices]
    sim_weight = (sim_weight / knn_t).exp()

    # one-hot encode every neighbour label ---> [B*K, C]
    one_hot_label = torch.zeros(batch_size * k, classes, device=sim_labels.device)
    one_hot_label = one_hot_label.scatter(dim=-1, index=sim_labels.view(-1, 1), value=1.0)
    # weighted score ---> [B, C]; explicit shape so an empty batch is also valid
    pred_scores = torch.sum(
        one_hot_label.view(batch_size, k, classes) * sim_weight.unsqueeze(dim=-1), dim=1)

    return pred_scores.argsort(dim=-1, descending=True)



if __name__ == "__main__":
    # Script entry point: build the encoder, load pre-trained weights and
    # report kNN top-1 / top-5 accuracy.

    # parse_known_args honours flags passed on the command line while still
    # tolerating the extra kernel arguments present when run from a notebook
    # (parsing an empty argument list would ignore every command-line flag).
    args, ignored_argv = parser.parse_known_args()
    if ignored_argv:
        print('Ignoring unrecognized arguments: {}'.format(ignored_argv))

    if args.seed is None:
        args.seed = random.randint(0, 10000)
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    cudnn.deterministic = True

    print(args)
    path_save = '%s/KNN_temperature_%.3f_k_%d' % (args.results_dir, args.knn_t, args.knn_k)
    copy_script(args.results_dir, files_to_same=['main_moco.py', 'main_linear.py', 'utils.py',  'moco/builder.py', 'moco/loader.py'])
    logger = setup_logger("Test_KNN", path_save, get_rank(), name='')
    logger.info(args)

    # Only the memory (reference) and test loaders are needed for evaluation.
    _, memory_loader, test_loader, _ = load_data(args.data_path, args.batch_size)

    # Build the backbone: replace the stem convolution with a 3x3 stride-1 one
    # and drop the max-pool and the final linear classifier, leaving a pure
    # feature extractor.
    norm_layer = partial(SplitBatchNorm, num_splits=args.bn_splits) if args.bn_splits > 1 else nn.BatchNorm2d
    backbone = getattr(models, args.arch)(num_classes=args.num_classes, norm_layer=norm_layer)
    layers = []
    for name, module in backbone.named_children():
        if name == 'conv1':
            module = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        if isinstance(module, (nn.MaxPool2d, nn.Linear)):
            continue
        layers.append(module)
    net = nn.Sequential(*layers)

    # Evaluation only: freeze all weights and use inference-mode layers.
    for param in net.parameters():
        param.requires_grad = False
    net.eval()

    # load from pre-trained
    if args.pretrained:
        resume_model = args.pretrained
        if os.path.isfile(resume_model):
            # NOTE(security): torch.load unpickles the file; only load
            # checkpoints from a trusted source.
            checkpoint = torch.load(resume_model, map_location="cpu")
            # Keys are loaded exactly as stored. NOTE(review): a checkpoint
            # whose keys carry a wrapper prefix (e.g. `encoder_q.`) would need
            # renaming first -- confirm against the training script.
            state_dict = checkpoint['state_dict']
            msg = net.load_state_dict(state_dict, strict=False)
            # strict=False hides mismatches, so report them: a large list here
            # means the network is (partly) still randomly initialised.
            if msg.missing_keys:
                logger.info("=> missing keys: {}".format(msg.missing_keys))
            if msg.unexpected_keys:
                logger.info("=> unexpected keys: {}".format(msg.unexpected_keys))
            logger.info("=> loaded pre-trained model '{}'".format(resume_model))
        else:
            logger.info("=> no checkpoint found at '{}'".format(resume_model))

    net.cuda()
    test_acc_1, test_acc_5 = test_KNN(net, memory_loader, test_loader, 0, args)
    results_str = 'KNN test: k %d, temp %.3f, top1_acc %.4f, top5_acc %.4f'% (args.knn_k, args.knn_t, test_acc_1, test_acc_5)
    logger.info(results_str)

