from datetime import datetime
from tqdm import tqdm
import argparse
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial
import torchvision.models as models
from moco.loader import load_data, load_linear_train_data
from moco.builder import ModelBase,SplitBatchNorm
from pathlib import Path
import random
import torch.backends.cudnn as cudnn
import time
import numpy as np
import pdb
import os
from utils import setup_logger, get_rank, accuracy, AverageMeter, copy_script

def _str2bool(value):
    """Convert a command-line token such as 'true'/'false' into a real bool.

    Without an explicit converter argparse stores the raw string, and every
    non-empty string (including 'False') is truthy.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if text in ('0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % (value,))


parser = argparse.ArgumentParser(description='Train CLEAN on CIFAR-10')


# dataset / backbone
parser.add_argument('--arch', default='resnet50', help='torchvision model name used to build the backbone')
parser.add_argument('--dataset-name', default='cifar10', type=str, metavar='NAME', help='name of the dataset')
parser.add_argument('--data-path', default='/export/home/dataset/CIFAR10', type=str, metavar='PATH', help='path to the dataset root directory')
parser.add_argument('--seed', default=9367, type=int, metavar='N', help='random seed')
parser.add_argument('--bn-splits', default=8, type=int, help='simulate multi-gpu behavior of BatchNorm in one gpu; 1 is SyncBatchNorm in multi-gpu')
parser.add_argument('--dim', default=2048, type=int, help='feature dim of encoder')
parser.add_argument('--num_classes', default=10, type=int, help='number of classification classes')


# checkpoints and output locations
parser.add_argument('--resume', default=None, type=str, metavar='PATH', help='path to a linear-classifier checkpoint to evaluate (default: none)')
parser.add_argument('--results-dir', default='./pretrained_models', type=str, metavar='PATH', help='directory where logs and checkpoints are written')
parser.add_argument('--pretrained', default='./pretrained_models/model.pth', type=str, metavar='PATH', help='path to the pre-trained backbone checkpoint')

# optimizer / learning-rate schedule
# NOTE: the __main__ section of this file always builds torch.optim.Adam with a
# fixed weight decay, so --optimizer and --wd currently only influence the name
# of the results sub-directory.
parser.add_argument('--optimizer', default='adam', type=str, metavar='NAME',
                    help='rmsprop, adam, sgd, asgd, AdamW, Adagrad')
parser.add_argument('--lr', '--learning-rate', default=0.01, type=float, metavar='LR', help='initial learning rate', dest='lr')
parser.add_argument('--batch-size', default=128, type=int, metavar='N', help='mini-batch size')
parser.add_argument('--epochs', default=200, type=int, metavar='N', help='number of total epochs to run')
parser.add_argument('--schedule', default=[], nargs='*', type=int, help='learning rate schedule (when to drop lr by 10x); does not take effect if --cos is on')
# type=_str2bool so that "--cos false" really disables the cosine schedule and
# the value can still be formatted with %d when building the results path.
parser.add_argument('--cos', default=True, type=_str2bool, help='use cosine lr schedule (true/false)')
parser.add_argument('--wd', default=0, type=float, metavar='W', help='weight decay')

# CUDA_VISIBLE_DEVICES=0 python main_linear.py --data-path /export/home/dataset/CIFAR10 --epochs 200

def train(net, model_fc, criterion, data_loader, train_optimizer, epoch, args):
    """Train the linear classifier for one epoch on features of a frozen backbone.

    Args:
        net: backbone; switched to eval mode and run under ``torch.no_grad()``
            so it receives no gradient.
        model_fc: classifier head being optimized; switched to train mode.
        criterion: loss called as ``criterion(logits, label)``. The running
            average assumes it returns a per-batch mean (as the
            ``nn.CrossEntropyLoss()`` built in ``__main__`` does).
        data_loader: iterable yielding ``(images, labels)`` batches.
        train_optimizer: optimizer stepping ``model_fc``'s parameters.
        epoch: current epoch, used for the lr schedule and the progress bar.
        args: namespace providing ``lr``, ``cos``, ``schedule`` and ``epochs``.

    Returns:
        The sample-weighted mean training loss over the epoch.
    """
    net.eval()
    model_fc.train()
    adjust_learning_rate(train_optimizer, epoch, args)
    total_loss, total_num, train_bar = 0.0, 0, tqdm(data_loader)
    for (im_1, label) in train_bar:
        im_1, label = im_1.cuda(non_blocking=True), label.cuda(non_blocking=True)
        with torch.no_grad():
            feature = net(im_1)
            feature = torch.flatten(feature, 1)

        loss = criterion(model_fc(feature), label)
        train_optimizer.zero_grad()
        loss.backward()
        train_optimizer.step()

        # Weight by the real number of samples in this batch: the last batch
        # can be smaller than data_loader.batch_size, and batch_size is None
        # when the loader is driven by a batch_sampler.
        batch_n = im_1.size(0)
        total_num += batch_n
        total_loss += loss.item() * batch_n
        # Read the lr from the optimizer passed in, not from a module global.
        train_bar.set_description('Train Epoch: [{}/{}], lr: {:.6f}, Loss: {:.4f}'.format(epoch, args.epochs, train_optimizer.param_groups[0]['lr'],
                                                                    total_loss / total_num))
    return total_loss / total_num


# lr scheduler for training
def adjust_learning_rate(optimizer, epoch, args):
    """Set the learning rate of every param group for the given epoch.

    With ``args.cos`` truthy the rate follows a half-cosine from ``args.lr``
    (epoch 0) down to 0 (epoch == ``args.epochs``). Otherwise the rate is
    ``args.lr`` multiplied by 0.1 for each milestone in ``args.schedule`` that
    ``epoch`` has reached. The optimizer is modified in place; nothing is
    returned.
    """
    if args.cos:
        # cosine annealing
        new_lr = args.lr * (0.5 * (1. + math.cos(math.pi * epoch / args.epochs)))
    else:
        # step decay: one 10x drop per milestone already passed
        new_lr = args.lr
        for milestone in args.schedule:
            factor = 0.1 if epoch >= milestone else 1.
            new_lr = new_lr * factor

    for group in optimizer.param_groups:
        group['lr'] = new_lr


def validate_test(model, model_fc, criterion, val_loader):
    """Evaluate the classifier head on features produced by ``model``.

    ``model_fc`` is put in eval mode for the pass and switched back to train
    mode before returning; the mode of ``model`` is left untouched. The whole
    pass runs under ``torch.no_grad()``.

    Returns:
        ``(top1.avg, top5.avg)`` -- the running averages of the two values
        returned by ``accuracy(output, target, topk=(1, 5))``. The loss is
        accumulated as well but is not part of the return value.
    """
    loss_meter = AverageMeter()
    top1_meter = AverageMeter()
    top5_meter = AverageMeter()

    model_fc.eval()
    with torch.no_grad():
        for images, target in tqdm(val_loader):
            images = images.cuda(non_blocking=True)
            target = target.cuda(non_blocking=True)

            # backbone output flattened to (batch, features) for the linear head
            flat_feature = torch.flatten(model(images), 1)
            logits = model_fc(flat_feature)
            batch_loss = criterion(logits, target)

            # per-batch accuracy, then fold everything into the running meters
            acc1, acc5 = accuracy(logits, target, topk=(1, 5))
            batch_n = images.size(0)
            loss_meter.update(batch_loss.item(), batch_n)
            top1_meter.update(acc1[0], batch_n)
            top5_meter.update(acc5[0], batch_n)

    model_fc.train()
    return top1_meter.avg, top5_meter.avg

if __name__ == "__main__":
    # Linear evaluation: freeze a pre-trained backbone, then either evaluate a
    # saved linear head (--resume) or train a new one and checkpoint it.
    #
    # NOTE(review): parse_args('') ignores the real command line, so every
    # option keeps its default. This looks deliberate for notebook use; switch
    # to parser.parse_args() to honour CLI flags -- confirm before changing.
    args = parser.parse_args('')  # running in ipynb
    if args.seed is None:
        args.seed = random.randint(0,10000)
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    cudnn.deterministic = True

    print(args)
    path_save = '%s/test_batch%d_epoch%d_opt%s_lr%.5f_wd%.2e_cos%d' % (args.results_dir,\
                                    args.batch_size, args.epochs, args.optimizer, args.lr, args.wd, args.cos)
    copy_script(args.results_dir, files_to_same=['main_moco.py', 'main_linear.py', 'utils.py',  'moco/builder.py', 'moco/loader.py'])
    logger = setup_logger("Test", path_save, get_rank(), name='')
    logger.info(args)

    train_loader, test_loader = load_linear_train_data(args.data_path, args.batch_size)

    # Build the backbone: a torchvision ResNet whose first conv is replaced by
    # a 3x3 / stride-1 conv and whose max-pool and final fc layers are dropped.
    norm_layer = partial(SplitBatchNorm, num_splits=args.bn_splits) if args.bn_splits > 1 else nn.BatchNorm2d
    resnet_arch = getattr(models, args.arch)
    net2 = resnet_arch(num_classes=10, norm_layer=norm_layer)
    net = []
    for name, module in net2.named_children():
        if name == 'conv1':
            module = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        if isinstance(module, nn.MaxPool2d) or isinstance(module, nn.Linear):
            continue
        net.append(module)
    net = nn.Sequential(*net)

    # The backbone is frozen; only the linear head below is trained.
    for name, param in net.named_parameters():
        param.requires_grad = False
    net.eval()
    net_fc = nn.Linear(args.dim, args.num_classes).cuda()

    # load from pre-trained, before DistributedDataParallel constructor
    if args.pretrained:
        resume_model = args.pretrained
        if os.path.isfile(resume_model):
            checkpoint = torch.load(resume_model, map_location="cpu")
            # The checkpoint keys are used as-is (no prefix renaming), so they
            # must already match the nn.Sequential backbone built above.
            state_dict = checkpoint['state_dict']
            load_result = net.load_state_dict(state_dict, strict=False)
            # strict=False ignores key mismatches; report them so that a
            # backbone left (partly) at random init does not go unnoticed.
            if load_result.missing_keys:
                logger.info("=> missing keys: {}".format(load_result.missing_keys))
            if load_result.unexpected_keys:
                logger.info("=> unexpected keys: {}".format(load_result.unexpected_keys))
            logger.info("=> loaded pre-trained model '{}'".format(resume_model))
        else:
            logger.info("=> no checkpoint found at '{}'".format(resume_model))
    net.cuda()
    # optimize only the linear classifier
    assert len(net_fc.state_dict().keys()) == 2  # fc.weight, fc.bias
    optimizer = torch.optim.Adam(net_fc.parameters(), lr=args.lr, weight_decay=0.000001)
    criterion = nn.CrossEntropyLoss().cuda()

    if args.resume is not None:
        # Evaluation only: load a saved linear head and report its accuracy.
        checkpoint = torch.load('%s' % ( args.resume))
        net_fc.load_state_dict(checkpoint['state_dict'])
        # The parser defines --resume; there is no `resume_linear` option.
        logger.info('Loaded from: {}'.format(args.resume))
        test_acc_1, test_acc_5 = validate_test(net, net_fc, criterion, test_loader)
        results_str = 'top1_acc %.4f, top5_acc %.4f'%(test_acc_1, test_acc_5)
        logger.info(results_str)
    else:
        # Make sure the checkpoint directory exists before torch.save writes
        # into it (setup_logger may already have created it).
        os.makedirs(path_save, exist_ok=True)
        best_acc1, best_acc5 = 0, 0
        epoch_start = 1
        for epoch in range(epoch_start, args.epochs + 1):
            t1 = time.time()
            train_loss = train(net, net_fc,  criterion, train_loader, optimizer, epoch, args)
            test_acc_1, test_acc_5 = validate_test(net, net_fc, criterion, test_loader)
            t2 = time.time()
            # Estimated hours left, extrapolated from this epoch's duration.
            remaining_time = (args.epochs - epoch) * (t2 - t1) / 3600.0
            if best_acc1 < test_acc_1:
                best_acc1, best_acc5 = test_acc_1, test_acc_5
                torch.save({'epoch': epoch, 'state_dict': net_fc.state_dict(), 'optimizer': optimizer.state_dict(), },
                           path_save + '/model_best.pth')

            lr = optimizer.param_groups[0]['lr']
            results_str = '%d epoch, lr %.6f training loss %.6f, top1_acc %.4f, top5_acc %.4f, best_top1 %.4f, best_top5 %.4f, remaining time %.2f'%(epoch, lr, train_loss, \
                   test_acc_1, test_acc_5, best_acc1, best_acc5, remaining_time)
            logger.info(results_str)
            torch.save({'epoch': epoch, 'state_dict': net_fc.state_dict(), 'optimizer' : optimizer.state_dict(),}, path_save + '/model_linear_last.pth')
