from tqdm import tqdm
import argparse
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from moco.loader import load_data
from moco.builder import ModelMoCo
import time
import random
import torch.backends.cudnn as cudnn
import pdb
import os
from utils import knn_predict, setup_logger, get_rank, copy_script, AverageMeter


from math import pi, cos
def _str2bool(value):
    """Parse a command-line boolean such as ``true``/``false``/``1``/``0``.

    Without an explicit converter argparse keeps the raw string, and any
    non-empty string (including ``'False'``) is truthy.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if text in ('0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % (value,))


parser = argparse.ArgumentParser(description='Train Label on CIFAR-10')

## dataset
parser.add_argument('--arch', default='resnet50')
parser.add_argument('--dataset-name', default='cifar10', type=str, metavar='NAME', help='dataset name')
parser.add_argument('--data-path', default='/export/home/dataset/CIFAR10', type=str, metavar='PATH', help='path to the dataset')

## architecture
parser.add_argument('--mlp-hidden-size', default=2048, type=int, help='MLP hidden size')
parser.add_argument('--dim', default=128, type=int, help='feature dimension')
parser.add_argument('--k', default=4096, type=int, help='3840 queue size; number of negative keys')
parser.add_argument('--m', default=0.99, type=float, help='moco momentum of updating key encoder')

# knn monitor
parser.add_argument('--knn-k', default=50, type=int, help='k in kNN monitor')
parser.add_argument('--knn-t', default=0.05, type=float, help='softmax temperature in kNN monitor; could be different with moco-t')

## optimizer
parser.add_argument('--lr', '--learning-rate', default=0.03, type=float, metavar='LR', help='initial learning rate', dest='lr')
parser.add_argument('--epochs', default=2000, type=int, metavar='N', help='number of total epochs to run')
parser.add_argument('--schedule', default=[], nargs='*', type=int, help='learning rate schedule (when to drop lr by 10x); does not take effect if --cos is on')
# `--cos` accepts an explicit boolean (`--cos false`) or no value at all (`--cos` == true).
parser.add_argument('--cos', default=True, type=_str2bool, nargs='?', const=True, metavar='BOOL',
                    help='use cosine lr schedule (pass "false" to use --schedule instead)')
parser.add_argument('--wd', default=5e-4, type=float, metavar='W', help='weight decay')

parser.add_argument('--bn-splits', default=8, type=int, help='simulate multi-gpu behavior of BatchNorm in one gpu; 1 is SyncBatchNorm in multi-gpu')
parser.add_argument('--batch-size', default=256, type=int, metavar='N', help='mini-batch size')

# loss setting
parser.add_argument('--symmetric', default=False, action='store_true', help='use a symmetric loss function that backprops to both crops')
parser.add_argument('--loss-type', default='CLEAN', type=str, metavar='TYPE', help='CLEAN, moco')
# NOTE(review): the exact role of the teacher weight is defined by the model builder, not in this file.
parser.add_argument('--teacher_weight', default=0.5, type=float, metavar='W', help='teacher weight')

## important parameters
parser.add_argument('--results-dir', default='CIFAR10_2000/', type=str, metavar='PATH', help='base directory for logs and checkpoints')
parser.add_argument('--seed', default=None, type=int, metavar='N', help='random seed; drawn at random when omitted')
parser.add_argument('--resume_model_name', default=None, type=str, metavar='NAME', help='checkpoint file name inside the results directory to resume from')
# NOTE(review): in this file the warmup epoch only appears in the results directory name.
parser.add_argument('--warmup_epoch', default=0, type=int, metavar='N', help='warmup epoch')

parser.add_argument('--sharp_probability', default=0.8, type=float, metavar='P', help='sharp constant for sharp prediction probability')
parser.add_argument('--teach_T', default=0.2, type=float, metavar='T', help='temperature for sharp prediction')
parser.add_argument('--t', default=0.2, type=float, help='softmax temperature')
parser.add_argument('--lam', default=2.0, type=float, metavar='LAM', help='beta distribution parameter')
parser.add_argument('--prior_confidence', default='cos', type=str, metavar='MODE',
                    help='prior confidence schedule: "cos" ramps from min to max over training; any other value keeps it at min')
parser.add_argument('--prior_confidence_max', default=1.0, type=float, metavar='MAX', help='final (maximum) prior confidence of the cosine ramp')
parser.add_argument('--prior_confidence_min', default=0.0, type=float, metavar='MIN', help='initial (minimum) prior confidence')

#CUDA_VISIBLE_DEVICES=0 python main_moco.py --data-path /export/home/dataset/CIFAR10 --epochs 2000

def adjust_learning_rate(optimizer, epoch, args):
    """Set the learning rate of every parameter group for ``epoch``.

    With ``args.cos`` truthy the base rate ``args.lr`` follows a half-cosine
    over ``args.epochs``; otherwise it is multiplied by 0.1 once for every
    milestone in ``args.schedule`` that ``epoch`` has reached.
    """
    if args.cos:
        # half-cosine: factor is 1 at epoch 0 and 0 at epoch == args.epochs
        cosine_factor = 0.5 * (1. + math.cos(math.pi * epoch / args.epochs))
        new_lr = args.lr * cosine_factor
    else:
        new_lr = args.lr
        for boundary in args.schedule:
            step_factor = 0.1 if epoch >= boundary else 1.
            new_lr = new_lr * step_factor

    for group in optimizer.param_groups:
        group['lr'] = new_lr

def train(net, data_loader, train_optimizer, epoch, args):
    """Train ``net`` for one epoch.

    Args:
        net: model called as ``net(view1, view2, prior_confidence)``; it must
            return a ``(loss, acc)`` pair where ``loss`` supports
            ``backward()`` and ``item()``.
        data_loader: iterable yielding ``((view1, view2), _)`` batches; the
            second element of each batch is ignored.
        train_optimizer: optimizer whose learning rate is set for this epoch
            and which is stepped once per batch.
        epoch: current epoch, used by the lr and prior-confidence schedules.
        args: namespace providing ``lr``, ``cos``, ``schedule``, ``epochs``,
            ``prior_confidence``, ``prior_confidence_min`` and
            ``prior_confidence_max``.

    Returns:
        ``(sample-weighted mean loss, acc_meter.avg, prior_confidence)``.
    """
    net.train()
    adjust_learning_rate(train_optimizer, epoch, args)
    acc_meter = AverageMeter()
    total_loss, total_num = 0.0, 0
    train_bar = tqdm(data_loader)

    if args.prior_confidence == 'cos':
        # Cosine ramp: equals prior_confidence_min at epoch 0 and reaches
        # prior_confidence_max at epoch == args.epochs.
        prior_confidence = args.prior_confidence_max - 0.5 * (args.prior_confidence_max - args.prior_confidence_min) * (
            1 + cos(epoch / args.epochs * pi))
    else:
        prior_confidence = args.prior_confidence_min

    for (weak1, weak2), _ in train_bar:
        weak1 = weak1.cuda(non_blocking=True)
        weak2 = weak2.cuda(non_blocking=True)

        # NOTE(review): if net is wrapped in nn.DataParallel, a scalar loss is
        # gathered into one value per device - confirm the model reduces it.
        loss, acc = net(weak1, weak2, prior_confidence)
        acc_meter.update(acc)

        train_optimizer.zero_grad()
        loss.backward()
        train_optimizer.step()

        # Weight by the real batch size: the last batch can be smaller than
        # data_loader.batch_size, which is also None with a batch sampler.
        batch_size = weak1.size(0)
        total_num += batch_size
        total_loss += loss.item() * batch_size
        # Report the lr of the optimizer that was passed in, not a module global.
        train_bar.set_description('Train Epoch: [{}/{}], lr: {:.6f}, Prior: {:.6f}, Loss: {:.4f}, label acc: {:.1f}'.format(
            epoch, args.epochs, train_optimizer.param_groups[0]['lr'], prior_confidence,
            total_loss / total_num, acc_meter.avg))

    return total_loss / total_num, acc_meter.avg, prior_confidence


# test using a knn monitor
def test(net, memory_data_loader, test_data_loader, epoch, args):
    """Evaluate ``net`` with the kNN monitor.

    Features of ``memory_data_loader`` form a bank labelled with
    ``memory_data_loader.dataset.targets``; every test batch is then
    classified by ``knn_predict`` against that bank.

    Returns:
        ``(top-1 accuracy, top-5 accuracy)``, both in percent.
    """
    net.eval()
    num_classes = len(memory_data_loader.dataset.classes)
    top1_hits = 0.0
    top5_hits = 0.0
    seen = 0

    def embed(batch):
        # forward pass -> flatten everything after dim 0 -> L2-normalise along dim 1
        return F.normalize(torch.flatten(net(batch), 1), dim=1)

    with torch.no_grad():
        # Build the feature bank from the memory loader.
        bank_chunks = [
            embed(images.cuda(non_blocking=True))
            for images, _ in tqdm(memory_data_loader, desc='Feature extracting')
        ]
        # Concatenate along the sample axis, then transpose.
        feature_bank = torch.cat(bank_chunks, dim=0).t().contiguous()
        feature_labels = torch.tensor(memory_data_loader.dataset.targets, device=feature_bank.device)

        # Classify each test batch against the bank.
        progress = tqdm(test_data_loader)
        for images, labels in progress:
            images = images.cuda(non_blocking=True)
            labels = labels.cuda(non_blocking=True)
            pred_labels = knn_predict(embed(images), feature_bank, feature_labels, num_classes, args.knn_k, args.knn_t)

            seen += images.size(0)
            # column 0 is compared against the label for top-1
            top1_hits += (pred_labels[:, 0] == labels).float().sum().item()
            # top-5 counts a hit when any of the first five columns matches
            in_top5 = (pred_labels[:, :5] == labels.unsqueeze(dim=-1)).any(dim=-1)
            top5_hits += torch.sum(in_top5.float()).item()
            progress.set_description(
                'Test Epoch: [{}/{}] Acc@1:{:.2f}%'.format(epoch, args.epochs, top1_hits / seen * 100))

    return top1_hits / seen * 100, top5_hits / seen * 100





if __name__ == "__main__":
    # Parse the real command line. The previous `parse_args('')` parsed an
    # empty argument list, so every flag given on the CLI was silently ignored.
    args = parser.parse_args()
    # Scale the learning rate linearly with the batch size (reference: 256).
    args.lr = args.lr * args.batch_size / 256.0

    # Seeding: draw a seed when none is given so the run can still be reproduced
    # from the logged args.
    if args.seed is None:
        args.seed = random.randint(0, 10000)
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    cudnn.deterministic = True

    # Encode the main hyper-parameters in the results directory name.
    args.results_dir = './%s/%s_ws_sys%d_warm%d_sharp%.2f_weight%.2f_lam%.1f_teach%.2f_batch%d_split%d_prior%s_%.2f_%.2f' % (
        args.results_dir, args.loss_type, args.symmetric, args.warmup_epoch, args.sharp_probability,
        args.teacher_weight, args.lam, args.teach_T, args.batch_size,
        args.bn_splits, args.prior_confidence, args.prior_confidence_min, args.prior_confidence_max)

    files_to_save = ['main_moco.py', 'main_linear.py', 'utils.py', 'moco/builder.py', 'moco/loader.py', 'moco/augmentations.py']
    copy_script(args.results_dir, files_to_save, name='train')
    logger = setup_logger("Training", args.results_dir, get_rank())
    logger.info(args)

    ## Step 1 load data
    train_loader, memory_loader, test_loader, input_shape = load_data(args.data_path, args.batch_size)

    ## Step 2 create model
    model = ModelMoCo(input_shape=input_shape, args=args)
    logger.info(model.encoder_q)

    num_gpus = torch.cuda.device_count()
    if num_gpus > 1:
        model = nn.DataParallel(model)
    model = model.cuda()

    ## Step 3 define optimizer
    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, weight_decay=args.wd, momentum=0.9)

    ## Step 4 load model if resume
    epoch_start = 1
    best_acc1, best_acc5 = 0, 0
    if args.resume_model_name is not None:
        checkpoint = torch.load('%s/%s' % (args.results_dir, args.resume_model_name))
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        epoch_start = checkpoint['epoch'] + 1
        # Restore the best accuracy so a resumed run does not overwrite
        # model_best.pth with a worse model; older checkpoints lack these keys.
        best_acc1 = checkpoint.get('best_acc1', best_acc1)
        best_acc5 = checkpoint.get('best_acc5', best_acc5)
        logger.info('Loaded from: {}'.format(args.resume_model_name))

    # training loop
    for epoch in range(epoch_start, args.epochs + 1):
        t1 = time.time()
        ## train one epoch
        train_loss, acc_pre, prior_confidence = train(model, train_loader, optimizer, epoch, args)

        # The kNN monitor evaluates the bare query encoder; DataParallel keeps
        # the wrapped model under `.module`.
        if num_gpus > 1:
            encoder = model.module.encoder_q.net
        else:
            encoder = model.encoder_q.net

        test_acc_1, test_acc_5 = test(encoder, memory_loader, test_loader, epoch, args)

        t2 = time.time()
        # Estimate from the duration of the epoch that just finished.
        remaining_time = (args.epochs - epoch) * (t2 - t1) / 3600.0

        is_best = best_acc1 < test_acc_1
        if is_best:
            best_acc1, best_acc5 = test_acc_1, test_acc_5

        # Same payload for every checkpoint written this epoch.
        state = {
            'epoch': epoch,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'best_acc1': best_acc1,
            'best_acc5': best_acc5,
        }
        if is_best:
            torch.save(state, args.results_dir + '/model_best.pth')

        lr = optimizer.param_groups[0]['lr']
        results_str = '%d epoch, lr %.6f prior %.6f, training loss %.6f, label acc %.4f,  top1_acc %.4f, top5_acc %.4f, best_top1 %.4f, best_top5 %.4f, remaining time %.2f h' % (
            epoch, lr, prior_confidence, train_loss, acc_pre, test_acc_1, test_acc_5, best_acc1, best_acc5, remaining_time)

        logger.info(results_str)
        torch.save(state, args.results_dir + '/model_last.pth')

        # Keep a permanent snapshot every 500 epochs.
        if epoch % 500 == 0:
            torch.save(state, args.results_dir + '/models_%d.pth' % epoch)
