import argparse
import os
import random
import shutil
import time

import warnings
import copy 
import datetime
import torchvision.datasets as dset
from datasets import ImagenetNoise
from networks import Ensemble  
from utils import check_dir, prepare_dset, update_print, get_relative_maha_distance, maha, \
    get_pretrained_model, get_maha_distance, MahaDistNormalizer, ranking_loss, LogitsMinMax
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.optim
import torch.multiprocessing as mp
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models
import torchvision
import torch.nn.functional as F
import numpy as np
import pickle
import crl_utils
import utils
from metrics import crl_metrics
from torch.utils.tensorboard import SummaryWriter

# Alphabetically sorted names of the lowercase, non-dunder, callable attributes
# of torchvision.models.
model_names = sorted(
    attr_name
    for attr_name, attr_value in models.__dict__.items()
    if attr_name.islower()
    and not attr_name.startswith("__")
    and callable(attr_value)
)

# Command-line interface for the CRL training script.
parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
parser.add_argument('--maha_file', default='./ssl/maha_dict.npy', type=str)
parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet34')
parser.add_argument('--pretrained', default='', type=str,
                    help='path to moco pretrained checkpoint')
parser.add_argument('--pretrained_model', default='vit', type=str, help='SSL feature map type')
parser.add_argument('--comp_dis', action='store_true', default=False)

parser.add_argument('--loss_type', default='rank0', type=str, help='rank0/margin_rank')
parser.add_argument('--rank_weight', default=1.0, type=float, help='ranking loss weight')
parser.add_argument('--warmup', default=0, type=int)
parser.add_argument('--gpu', default='0', type=str)
parser.add_argument('--dataset', default='imagenet', type=str, help='cifar10/cifar100')
parser.add_argument('--left', default=1.0, type=float)
parser.add_argument('--right', default=1.0, type=float)
parser.add_argument('--rank_target', default='softmax', type=str,
                    help='Rank_target name to use [softmax, margin, entropy]')

parser.add_argument('--data',
                    metavar='DIR',
                    default='/data/LargeData/Large/ImageNet',
                    help='path to dataset')
parser.add_argument('--lr',
                    '--learning-rate',
                    default=0.1,
                    type=float,
                    metavar='LR',
                    help='initial learning rate',
                    dest='lr')
parser.add_argument('-j',
                    '--workers',
                    default=4,
                    type=int,
                    metavar='N',
                    help='number of data loading workers (default: 4)')
parser.add_argument('--epochs',
                    default=120,
                    type=int,
                    metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--start-epoch',
                    default=0,
                    type=int,
                    metavar='N',
                    help='manual epoch number (useful on restarts)')
# The help text previously advertised a default of 3200; the actual default is 512.
parser.add_argument('-b',
                    '--batch-size',
                    default=512,
                    type=int,
                    metavar='N',
                    help='mini-batch size (default: 512), this is the total '
                    'batch size of all GPUs on the current node when '
                    'using Data Parallel or Distributed Data Parallel')
parser.add_argument('--momentum',
                    default=0.9,
                    type=float,
                    metavar='M',
                    help='momentum')
# Used as the CUDA device index of this process (torch.cuda.set_device(local_rank)),
# i.e. the per-node process rank, not the node rank.
parser.add_argument('--local_rank',
                    default=-1,
                    type=int,
                    help='local process rank (GPU index) for distributed training')
parser.add_argument('--wd',
                    '--weight-decay',
                    default=1e-4,
                    type=float,
                    metavar='W',
                    help='weight decay (default: 1e-4)',
                    dest='weight_decay')
parser.add_argument('-p',
                    '--print-freq',
                    default=10,
                    type=int,
                    metavar='N',
                    help='print frequency (default: 10)')
parser.add_argument('-e',
                    '--evaluate',
                    dest='evaluate',
                    action='store_true',
                    help='evaluate model on validation set')
parser.add_argument('--seed',
                    default=None,
                    type=int,
                    help='seed for initializing training. ')
parser.add_argument('--ynoise_type', default='symmetric', type=str, help='symmetric/pairflip')
parser.add_argument('--ynoise_rate', default=0.0, type=float, help='label noise rate')
parser.add_argument('--xnoise_rate', default=0.0, type=float)
parser.add_argument('--xnoise_type', default='contrast', type=str)
# NOTE(review): argparse only applies `type` to string defaults, so this value is
# the int 5 by default but a str when passed on the command line — confirm which
# one ImagenetNoise expects.
parser.add_argument('--xnoise_arg', default=5, type=str)

parser.add_argument('--ensemble_num', default=1, type=int, help="number of models to ensemble")
parser.add_argument('--random_state', type=int, default=0)
parser.add_argument('--num_classes', type=int, default=1000)
# main_worker stops the epoch loop once the epoch index exceeds this value.
parser.add_argument('--stop', type=int, default=1000)


def reduce_mean(tensor, nprocs):
    """Return the element-wise sum of ``tensor`` over all processes divided by ``nprocs``.

    The all-reduce runs on a copy, so ``tensor`` itself is left unchanged. Every
    process in the default process group must call this (it issues a collective).
    """
    summed = tensor.clone()
    dist.all_reduce(summed, op=dist.ReduceOp.SUM)
    return summed.div_(nprocs)



def main():
    """Parse CLI arguments, save them beside the checkpoints, and start training.

    The parsed namespace (with ``nprocs`` set to the visible CUDA device count) is
    written with ``torch.save`` to ``./checkpoint/deep_ens/<run name>_args.pkl``.
    When ``--seed`` is given, Python's and torch's RNGs are seeded and cuDNN is
    switched to deterministic mode. Training then runs in this process via
    ``main_worker`` on device ``args.local_rank``.
    """
    args = parser.parse_args()
    args.nprocs = torch.cuda.device_count()

    run_name = '_'.join(['resnet34_crl', args.dataset, args.arch, args.loss_type])

    check_dir('./checkpoint')
    ckpt_dir = os.path.join('./checkpoint', 'deep_ens')
    check_dir(ckpt_dir)
    args_path = os.path.join(ckpt_dir, run_name + '_args.pkl')
    print('Saving args to ' + args_path)
    torch.save(args, args_path)

    seed = args.seed
    if seed is not None:
        random.seed(seed)
        torch.manual_seed(seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    main_worker(args.local_rank, args.nprocs, args)


def main_worker(local_rank, nprocs, args):
    """Set up DDP, the model and the data on one GPU, then train and evaluate.

    Args:
        local_rank: CUDA device index driven by this process. Process 0 also
            reports the training-set noise, saves checkpoints, and computes and
            logs the final metrics.
        nprocs: number of processes sharing the global batch. ``args.batch_size``
            is divided by it in place to get the per-process batch size.
        args: parsed command-line namespace.
    """
    best_acc1 = .0

    dist.init_process_group(backend='nccl')
    if args.ensemble_num > 1:
        model = Ensemble(args.ensemble_num, args.num_classes)  # use resnet 18
    else:
        model = torchvision.models.resnet34(num_classes=args.num_classes)
    torch.cuda.set_device(local_rank)
    model.cuda(local_rank)
    # With one GPU per process, DistributedDataParallel needs the per-process
    # batch size, so split the total batch size across processes here.
    args.batch_size = int(args.batch_size / nprocs)
    model = torch.nn.parallel.DistributedDataParallel(model,
                                                      device_ids=[local_rank])

    # Loss functions: classification plus pairwise confidence ranking.
    cls_criterion = nn.CrossEntropyLoss().cuda(local_rank)
    ranking_criterion = nn.MarginRankingLoss(margin=0.0).cuda(local_rank)

    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    cudnn.benchmark = True

    # Data loading code
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_dataset = ImagenetNoise(
        transform=transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]),
        xnoise_rate=args.xnoise_rate,
        xnoise_arg=args.xnoise_arg,
        xnoise_type=args.xnoise_type,
        ynoise_type=args.ynoise_type,
        ynoise_rate=args.ynoise_rate,
        random_state=args.random_state,
        num_classes=args.num_classes
    )
    if local_rank == 0:
        train_dataset.report_noise()
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset)
    # --workers (-j) sets the loader worker count; it used to be ignored in favour
    # of a hard-coded 2.
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler)
    # Per-sample correctness record that supplies the CRL ranking targets.
    correctness_history = crl_utils.History(len(train_loader.dataset))

    val_dataset = ImagenetNoise(
        train=False,
        transform=transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ]),
        num_classes=args.num_classes
    )
    val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.batch_size,
                                             num_workers=args.workers,
                                             pin_memory=True,
                                             sampler=val_sampler)

    if args.evaluate:
        validate(val_loader, model, cls_criterion, local_rank, args)
        return

    filename = 'resnet34_crl' + '_' + args.dataset \
        + '_' + args.arch + '_' + args.loss_type
    for epoch in range(args.start_epoch, args.epochs):
        if epoch > args.stop:
            break
        # Reseed both samplers so each epoch uses a different shuffle.
        train_sampler.set_epoch(epoch)
        val_sampler.set_epoch(epoch)

        adjust_learning_rate(optimizer, epoch, args)

        # train for one epoch
        train(train_loader, model, cls_criterion, ranking_criterion, optimizer, epoch, local_rank,
              correctness_history, args)

        # evaluate on validation set
        acc1 = validate(val_loader, model, cls_criterion, local_rank, args)

        # remember best acc@1 and save checkpoint
        is_best = acc1 > best_acc1
        best_acc1 = max(acc1, best_acc1)
        if local_rank == 0:
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'net': model.module.state_dict(),
                    'acc': best_acc1,
                }, is_best, filename)

    print('Best Acc1:', best_acc1)

    # Final calibration/ranking metrics are computed and logged by process 0 only.
    if local_rank != 0:
        return
    # calc_metrics receives labels for the whole validation set in dataset order,
    # but val_loader's DistributedSampler shards the data across processes and
    # shuffles it, so its outputs would not line up with those labels. Use an
    # unsharded, unshuffled loader instead, and the bare module (model.module) so
    # no DDP collective is issued after the other processes have returned.
    test_label = val_dataset.targets
    test_onehot = crl_utils.one_hot_encoding(test_label)
    metric_loader = torch.utils.data.DataLoader(val_dataset,
                                                batch_size=args.batch_size,
                                                shuffle=False,
                                                num_workers=args.workers,
                                                pin_memory=True)
    acc, aurc, eaurc, aupr, fpr, ece, nll, brier = crl_metrics.calc_metrics(metric_loader,
                                                                            test_label,
                                                                            test_onehot,
                                                                            model.module,
                                                                            cls_criterion)
    # result write
    result_logger = crl_utils.Logger(os.path.join('./', 'crl_result.log'))
    result_logger.write([acc, aurc * 1000, eaurc * 1000, aupr * 100, fpr * 100,
                         ece * 100, nll * 10, brier * 100])

# Module-wide TensorBoard step counter, shared across epochs; train() bumps it
# once for every logged (print_freq) iteration.
n_iter: int = 0


def train(train_loader, model, criterion_cls, criterion_ranking, optimizer, epoch, local_rank, history, args):
    """Train ``model`` for one epoch with cross-entropy plus the CRL ranking loss.

    For each batch, a per-sample confidence is computed from the logits according
    to ``args.rank_target``:
    - 'softmax': the max class probability.
    - 'entropy': ``crl_utils.negative_entropy`` with normalisation.
    - 'margin': the gap between the two largest probabilities.

    Each confidence is paired with that of the next sample in the batch (wrapping
    around). ``history`` supplies the ranking target and margin for each pair. The
    ``criterion_ranking`` loss, scaled by ``args.rank_weight``, is added to the
    ``criterion_cls`` loss.

    Loss and top-1/top-5 accuracy are averaged over processes for display. The
    process with ``local_rank == 0`` also writes the losses to TensorBoard every
    ``args.print_freq`` iterations.

    Args:
        train_loader: yields ``(idx, (images, xnoisy), (target, true_tar))``;
            only ``idx``, ``images`` and ``target`` are used.
        model: the (DDP-wrapped) network being trained.
        criterion_cls: classification loss applied to the logits.
        criterion_ranking: pairwise ranking loss over confidences.
        optimizer: stepped once per batch.
        epoch: current epoch, used for display and ``history.max_correctness_update``.
        local_rank: CUDA device index of this process.
        history: ``crl_utils.History`` object tracking per-sample correctness.
        args: parsed command-line namespace.

    Raises:
        ValueError: if ``args.rank_target`` is not 'softmax', 'entropy' or 'margin'.
    """
    global n_iter
    import math  # only needed for the entropy normaliser below

    if args.rank_target not in ('softmax', 'entropy', 'margin'):
        raise ValueError(
            "Unknown rank_target '{}': expected 'softmax', 'entropy' or 'margin'".format(
                args.rank_target))
    # log(num_classes) is the largest possible entropy of a num_classes-way
    # distribution (4.605170 for 100 classes, 2.302585 for 10). Deriving it from
    # args.num_classes (args.data is the dataset *path*, not its name) makes the
    # normalisation correct for ImageNet's 1000 classes as well.
    value_for_normalizing = math.log(args.num_classes)

    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(len(train_loader),
                             [batch_time, data_time, losses, top1, top5],
                             prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    # One writer per epoch and only on process 0: every process would log the
    # same reduced values, and each SummaryWriter instance opens a new event file.
    writer = None
    if local_rank == 0:
        writer = SummaryWriter(log_dir='./logs_loss/crl_imagenet_fr80_' + args.loss_type,
                               comment='imagenet')

    try:
        end = time.time()
        for i, (idx, (images, _), (target, _)) in enumerate(train_loader):
            # measure data loading time
            data_time.update(time.time() - end)

            images = images.cuda(local_rank, non_blocking=True)
            target = target.cuda(local_rank, non_blocking=True)

            # compute output
            outputs = model(images)
            optimizer.zero_grad()

            # Per-sample confidence used for confidence-aware ranking.
            if args.rank_target == 'softmax':
                confidence, _ = F.softmax(outputs, dim=1).max(dim=1)
            elif args.rank_target == 'entropy':
                confidence = crl_utils.negative_entropy(outputs,
                                                        normalize=True,
                                                        max_value=value_for_normalizing)
            else:  # 'margin': gap between the two largest class probabilities
                top2, _ = torch.topk(F.softmax(outputs, dim=1), 2, dim=1)
                confidence = top2[:, 0] - top2[:, 1]

            # Pair each sample with the next one in the batch (wrapping around).
            rank_input1 = confidence
            rank_input2 = torch.roll(confidence, -1)
            idx2 = torch.roll(idx, -1)

            # Ranking target and margin of each pair, from the correctness history.
            rank_target, rank_margin = history.get_target_margin(idx, idx2)
            # Where the target is 0, divide the margin by 1 instead to avoid 0-division.
            rank_target_nonzero = rank_target.clone()
            rank_target_nonzero[rank_target_nonzero == 0] = 1
            rank_input2 = rank_input2 + rank_margin / rank_target_nonzero

            # total loss = classification + weighted ranking
            ranking_loss = args.rank_weight * criterion_ranking(rank_input1,
                                                                rank_input2,
                                                                rank_target)
            cls_loss = criterion_cls(outputs, target)
            loss = cls_loss + ranking_loss

            loss.backward()
            optimizer.step()

            # Record per-sample correctness for future ranking targets.
            _, correct = crl_utils.accuracy(outputs, target)
            history.correctness_update(idx, correct, outputs)

            # measure accuracy and record loss
            # NOTE(review): this branch stacks `outputs` as a list of member outputs,
            # but the confidence and classification loss above treat `outputs` as a
            # single logits tensor — confirm ensembles are supported in train().
            if args.ensemble_num > 1:
                mean_output = F.softmax(torch.stack(outputs), dim=-1).mean(dim=0)
            else:
                mean_output = outputs
            acc1, acc5 = accuracy(mean_output, target, topk=(1, 5))

            torch.distributed.barrier()

            # Collectives: must run on every process, whether or not it logs.
            reduced_rank_loss = reduce_mean(ranking_loss, args.nprocs)
            reduced_loss = reduce_mean(loss, args.nprocs)
            reduced_acc1 = reduce_mean(acc1, args.nprocs)
            reduced_acc5 = reduce_mean(acc5, args.nprocs)

            losses.update(reduced_loss.item(), images.size(0))
            top1.update(reduced_acc1.item(), images.size(0))
            top5.update(reduced_acc5.item(), images.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                progress.display(i)
                n_iter += 1
                if writer is not None:
                    writer.add_scalar(tag='Loss/rank_loss',
                                      scalar_value=reduced_rank_loss.item(),
                                      global_step=n_iter)
                    writer.add_scalar('Loss/train_loss', reduced_loss.item(),
                                      global_step=n_iter)
    finally:
        if writer is not None:
            writer.close()

    # max correctness update
    history.max_correctness_update(epoch)


def validate(val_loader, model, criterion, local_rank, args):
    """Evaluate ``model`` on ``val_loader`` and return the running top-1 accuracy.

    For each batch, loss and top-1/top-5 accuracy are averaged across processes
    with ``reduce_mean`` before being accumulated.

    For an ensemble (``args.ensemble_num > 1``), the members' softmax
    probabilities are averaged. ``criterion`` (CrossEntropyLoss in main_worker)
    applies log_softmax to its input, so it is given the log of that average.
    log_softmax leaves normalised log-probabilities unchanged, which makes the
    result the ensemble's negative log-likelihood. Previously the probabilities
    themselves were passed in as if they were logits.

    Args:
        val_loader: yields ``(idx, images, target)`` batches; ``idx`` is unused.
        model: network to evaluate (switched to eval mode here).
        criterion: loss applied to logits / log-probabilities.
        local_rank: CUDA device index of this process.
        args: namespace providing ``ensemble_num``, ``nprocs`` and ``print_freq``.

    Returns:
        The average top-1 accuracy (percent) accumulated over all batches.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(len(val_loader), [batch_time, losses, top1, top5],
                             prefix='Test: ')

    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        end = time.time()
        for i, (_, images, target) in enumerate(val_loader):
            images = images.cuda(local_rank, non_blocking=True)
            target = target.cuda(local_rank, non_blocking=True)

            # compute output
            output = model(images)
            if args.ensemble_num > 1:
                mean_output = F.softmax(torch.stack(output), dim=-1).mean(dim=0)
                # Clamp keeps log() finite if a class probability underflows to 0.
                loss = criterion(torch.log(mean_output.clamp(min=1e-12)), target)
            else:
                mean_output = output
                loss = criterion(mean_output, target)

            # measure accuracy and record loss
            acc1, acc5 = accuracy(mean_output, target, topk=(1, 5))

            torch.distributed.barrier()

            reduced_loss = reduce_mean(loss, args.nprocs)
            reduced_acc1 = reduce_mean(acc1, args.nprocs)
            reduced_acc5 = reduce_mean(acc5, args.nprocs)

            losses.update(reduced_loss.item(), images.size(0))
            top1.update(reduced_acc1.item(), images.size(0))
            top5.update(reduced_acc5.item(), images.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                progress.display(i)

        # TODO: this should also be done with the ProgressMeter
        print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'.format(top1=top1,
                                                                    top5=top5))

    return top1.avg

def save_checkpoint(state, is_best, filename):
    """Write ``state`` to ``checkpoint/deep_ens/<filename>.pkl`` if ``is_best``.

    Nothing is written for non-best states; each new best overwrites the previous
    file. ``check_dir`` (from utils) is called on both directories first —
    presumably it creates them when missing.
    """
    if not is_best:
        return

    root_dir = 'checkpoint'
    check_dir(root_dir)
    ens_dir = os.path.join(root_dir, "deep_ens")
    check_dir(ens_dir)

    save_path = os.path.join(ens_dir, filename + '.pkl')
    print('Save Model to', save_path)
    torch.save(state, save_path)


class AverageMeter:
    """Keep the most recent value and a sample-weighted running mean of a metric."""

    def __init__(self, name, fmt=':f'):
        self.name = name  # label printed by __str__
        self.fmt = fmt    # format spec applied to both the value and the average
        self.reset()

    def reset(self):
        """Zero the current value, the running sum/average and the sample count."""
        self.val = self.avg = self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value, weighting it by ``n`` samples."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        spec = self.fmt
        template = f'{{name}} {{val{spec}}} ({{avg{spec}}})'
        return template.format(**self.__dict__)


class ProgressMeter:
    """Print a tab-separated progress line: prefix + batch counter, then each meter."""

    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        """Print the progress line for batch index ``batch``."""
        parts = [self.prefix + self.batch_fmtstr.format(batch)]
        parts.extend(map(str, self.meters))
        print('\t'.join(parts))

    def _get_batch_fmtstr(self, num_batches):
        """Build a template such as ``'[{:3d}/100]'`` padded to the width of ``num_batches``."""
        width = len(str(num_batches // 1))
        field = '{:%dd}' % width
        return '[%s/%s]' % (field, field.format(num_batches))


def adjust_learning_rate(optimizer, epoch, args):
    """Step decay: set every param group's lr to ``args.lr * 0.1 ** (epoch // 30)``.

    The learning rate therefore drops by 10x every 30 epochs.
    """
    decay_steps = epoch // 30
    new_lr = args.lr * (0.1 ** decay_steps)
    for group in optimizer.param_groups:
        group['lr'] = new_lr


def accuracy(output, target, topk=(1, )):
    """Compute top-k accuracy (in percent) of ``output`` for every k in ``topk``.

    Args:
        output: 2-D score tensor whose class dimension is dim 1.
        target: true class index for each row of ``output``.
        topk: the k values to evaluate.

    Returns:
        A list with one 1-element float tensor per k: the percentage of rows whose
        target is among the k highest-scoring classes.
    """
    with torch.no_grad():
        n_rows = target.size(0)
        _, ranked = output.topk(max(topk), dim=1, largest=True, sorted=True)
        hits = ranked.eq(target.view(-1, 1))  # (n_rows, max(topk)) booleans
        scores = []
        for k in topk:
            hit_count = hits[:, :k].reshape(-1).float().sum(0, keepdim=True)
            scores.append(hit_count.mul_(100.0 / n_rows))
        return scores


if __name__ == "__main__":
    # Script entry point: parse the command line and run training.
    main()