import argparse
import os
import random
import shutil
import time
import warnings
import numpy as np
import pprint
import math
from collections import defaultdict

import torch
import torch.nn.parallel
import torch.distributed as dist
import torch.optim
import torch.multiprocessing as mp
import torch.utils.data
import torch.utils.data.distributed
import torch.nn.functional as F
from torch.cuda.amp import GradScaler
from datasets.cifar10 import CIFAR10_LT
from datasets.cifar100 import CIFAR100_LT
from datasets.imagenet import ImageNet_LT
from datasets.inaturalist import Inaturalist_LT
from models import resnet
from models import resnet_cifar
from models.head import Head
from losses.contrastive import SCL
from losses.logitadjust import LogitAdjust
from utils import config, update_config, create_logger
from utils import AverageMeter, ProgressMeter
from utils import accuracy, calibration



def parse_args():
    """Parse the command line and fold it into the global ``config``.

    Recognised arguments:
        --cfg        path of the config file (string, defaults to '').
        opts         everything left on the command line, forwarded as
                     config overrides.
        --file_name  defaults to the absolute path of this script.

    Returns:
        The parsed ``argparse.Namespace``. As a side effect ``config`` is
        updated through ``update_config(config, args)``.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument('--cfg', type=str, default='',
                     help='The configs of datasets and training setting')
    cli.add_argument('opts', nargs=argparse.REMAINDER, default=None,
                     help="Modify config options using the command-line")
    # The script's own path is passed on to create_logger() by main().
    cli.add_argument('--file_name', type=str, default=os.path.abspath(__file__))

    parsed = cli.parse_args()
    update_config(config, parsed)
    return parsed


def main():
    """Entry point: read the config, set up logging/seeding, start training.

    When ``config.multiprocessing_distributed`` is set, one worker process is
    spawned per visible GPU; otherwise ``main_worker`` is called directly in
    this process with ``config.gpu``.
    """
    args = parse_args()
    logger, model_dir = create_logger(config, args.cfg, args.file_name)
    logger.info('\n' + pprint.pformat(args))
    logger.info('\n' + str(config))

    if config.deterministic:
        # Fixed seed everywhere plus deterministic cuDNN kernels.
        seed = 0
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        random.seed(seed)
        np.random.seed(seed)
        os.environ['PYTHONHASHSEED'] = str(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    if config.gpu is not None:
        warnings.warn('You have chosen a specific GPU. This will completely '
                      'disable data parallelism.')

    if config.dist_url == "env://" and config.world_size == -1:
        config.world_size = int(os.environ["WORLD_SIZE"])

    config.distributed = config.world_size > 1 or config.multiprocessing_distributed

    ngpus_per_node = torch.cuda.device_count()
    if config.multiprocessing_distributed:
        # world_size is given in nodes; convert it to a total process count.
        config.world_size = ngpus_per_node * config.world_size
        # mp.spawn prepends the process index, which main_worker receives as
        # `gpu`. `model_dir` must be forwarded as well: main_worker requires
        # it, and omitting it made every spawned worker fail with a TypeError.
        mp.spawn(main_worker, nprocs=ngpus_per_node,
                 args=(ngpus_per_node, config, logger, model_dir))
    else:
        # Simply call main_worker function
        main_worker(config.gpu, ngpus_per_node, config, logger, model_dir)

def main_worker(gpu, ngpus_per_node, config, logger, model_dir):
    """Build model/data/optimizer and run the train/validate/checkpoint loop.

    Args:
        gpu: GPU index for this process (the process index when launched via
            ``mp.spawn``), or None.
        ngpus_per_node: number of GPUs on this node.
        config: experiment configuration; ``gpu``, ``rank``, ``batch_size``
            and ``workers`` may be rewritten here.
        logger: logger used for progress messages.
        model_dir: directory that receives the checkpoints.

    Raises:
        ValueError: if ``config.dataset`` is not one of the supported names.
    """
    global best_acc1
    global cls_num_list_cuda
    config.gpu = gpu

    if config.gpu is not None:
        logger.info("Use GPU: {} for training".format(config.gpu))

    if config.distributed:
        if config.dist_url == "env://" and config.rank == -1:
            config.rank = int(os.environ["RANK"])
        if config.multiprocessing_distributed:
            # Turn the node rank into a global process rank.
            config.rank = config.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=config.dist_backend, init_method=config.dist_url,
                                world_size=config.world_size, rank=config.rank)

    if config.dataset == 'cifar10' or config.dataset == 'cifar100':
        model = getattr(resnet_cifar, config.backbone)()
        classifier = getattr(resnet_cifar, 'Classifier')(feat_in=config.feat_size, num_classes=config.num_classes)
    elif config.dataset == 'imagenet' or config.dataset == 'ina2018':
        model = getattr(resnet, config.backbone)()
        classifier = getattr(resnet, 'Classifier')(feat_in=config.feat_size, num_classes=config.num_classes)
    else:
        # Previously an unknown name surfaced later as a NameError on `model`.
        raise ValueError("Unsupported dataset: {!r}".format(config.dataset))

    head = Head(dim_in=config.feat_size, num_classes=config.num_classes)
    # The classifier reuses the head's fc module, so both share its weights.
    classifier.fc = head.fc

    if not torch.cuda.is_available():
        logger.info('using CPU, this will be slow')
    elif config.distributed:
        # For multiprocessing distributed, DistributedDataParallel constructor
        # should always set the single device scope, otherwise,
        # DistributedDataParallel will use all available devices.
        if config.gpu is not None:
            torch.cuda.set_device(config.gpu)
            model.cuda(config.gpu)
            classifier.cuda(config.gpu)
            # When using a single GPU per process and per
            # DistributedDataParallel, we need to divide the batch size
            # ourselves based on the total number of GPUs we have
            config.batch_size = int(config.batch_size / ngpus_per_node)
            config.workers = int((config.workers + ngpus_per_node - 1) / ngpus_per_node)
            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[config.gpu])
            classifier = torch.nn.parallel.DistributedDataParallel(classifier, device_ids=[config.gpu])
            head.cuda(config.gpu)
            head = torch.nn.parallel.DistributedDataParallel(head, device_ids=[config.gpu])
        else:
            model.cuda()
            classifier.cuda()
            # DistributedDataParallel will divide and allocate batch_size to all
            # available GPUs if device_ids are not set
            model = torch.nn.parallel.DistributedDataParallel(model)
            classifier = torch.nn.parallel.DistributedDataParallel(classifier)
            head.cuda()
            head = torch.nn.parallel.DistributedDataParallel(head)
    elif config.gpu is not None:
        torch.cuda.set_device(config.gpu)
        model = model.cuda(config.gpu)
        classifier = classifier.cuda(config.gpu)
        head.cuda(config.gpu)
    else:
        # DataParallel will divide and allocate batch_size to all available GPUs
        model = torch.nn.DataParallel(model).cuda()
        classifier = torch.nn.DataParallel(classifier).cuda()
        head = torch.nn.DataParallel(head).cuda()

    # optionally resume from a checkpoint
    if config.resume:
        if os.path.isfile(config.resume):
            logger.info("=> loading checkpoint '{}'".format(config.resume))
            if config.gpu is None:
                checkpoint = torch.load(config.resume)
            else:
                # Map model to be loaded to specified single gpu.
                loc = 'cuda:{}'.format(config.gpu)
                checkpoint = torch.load(config.resume, map_location=loc)
            # config.start_epoch = checkpoint['epoch']
            best_acc1 = checkpoint['best_acc1']
            # Checkpoints written by this script store the per-meter dict that
            # validate() maintains, which has no `.to()`; only a bare tensor
            # needs to be moved to the target GPU.
            if config.gpu is not None and torch.is_tensor(best_acc1):
                best_acc1 = best_acc1.to(config.gpu)
            model.load_state_dict(checkpoint['state_dict_model'])
            classifier.load_state_dict(checkpoint['state_dict_classifier'])
            logger.info("=> loaded checkpoint '{}' (epoch {})"
                        .format(config.resume, checkpoint['epoch']))
        else:
            logger.info("=> no checkpoint found at '{}'".format(config.resume))

    randaug = True

    # Data loading code
    if config.dataset == 'cifar10':
        dataset = CIFAR10_LT(config.distributed, root=config.data_path, imb_factor=config.imb_factor,
                             batch_size=config.batch_size, num_works=config.workers,
                             autoaug=config.autoaug, randaug=randaug)
    elif config.dataset == 'cifar100':
        dataset = CIFAR100_LT(config.distributed, root=config.data_path, imb_factor=config.imb_factor,
                              batch_size=config.batch_size, num_works=config.workers,
                              autoaug=config.autoaug, randaug=randaug)
    elif config.dataset == 'imagenet':
        dataset = ImageNet_LT(config.distributed, root=config.data_path,
                              batch_size=config.batch_size, num_works=config.workers,
                              randaug=randaug)
    elif config.dataset == 'ina2018':
        dataset = Inaturalist_LT(config.distributed, root=config.data_path,
                                 batch_size=config.batch_size, num_works=config.workers,
                                 randaug=randaug)

    train_loader = dataset.train_instance
    val_loader = dataset.eval
    # The sampler only exists (and is only touched) in distributed runs.
    train_sampler = dataset.dist_sampler if config.distributed else None

    cls_num_list = train_loader.dataset.get_cls_num_list()
    print('cls num list:')
    print(cls_num_list)

    # Only the head and the backbone are optimised; the classifier's fc is
    # the head's fc module (see above).
    params = []
    params.append({"params": head.parameters()})
    params.append({"params": model.parameters()})

    optimizer = torch.optim.SGD(params, config.lr,
                                momentum=config.momentum,
                                weight_decay=config.weight_decay)

    scaler = GradScaler()

    # Empty sample bank: (images, targets, masks), filled during training.
    result = [torch.Tensor([]).cuda(config.gpu),
              torch.Tensor([]).cuda(config.gpu),
              torch.Tensor([]).cuda(config.gpu)]
    cls_num_list_cuda = torch.from_numpy(np.array(cls_num_list)).float().cuda()
    for epoch in range(config.num_epochs):
        if config.distributed:
            # The sampler must stay alive across epochs. The former
            # `train_sampler = None` inside this loop made this call fail
            # with an AttributeError from the second epoch on.
            train_sampler.set_epoch(epoch)

        adjust_learning_rate(optimizer, epoch, config)

        # define loss function (criterion); rebuilt every epoch as before
        criterion = {"SCCP": SCL(cls_num_list=cls_num_list).cuda(config.gpu),
                     'LA': LogitAdjust(cls_num_list=cls_num_list).cuda(config.gpu)
                     }

        # train for one epoch
        result = train(train_loader, model, classifier, criterion,
                       optimizer, scaler, epoch, config, logger, result, head)

        # evaluate on validation set
        is_best = validate(val_loader, model, classifier, config, logger)

        # save checkpoint (only one process per node in multiprocessing mode)
        if not config.multiprocessing_distributed or (config.multiprocessing_distributed
                                                      and config.rank % ngpus_per_node == 0):
            save_checkpoint({
                'epoch': epoch + 1,
                'state_dict_model': model.state_dict(),
                'state_dict_classifier': classifier.state_dict(),
                'best_acc1': best_acc1,
                'result': result,
            }, is_best, model_dir)


def train(train_loader, model, classifier, criterion, optimizer, scaler, epoch, config,
          logger, result, head=None):
    """Train for one epoch and return the updated sample bank.

    Per batch: a Grad-CAM based mask is computed for the first view, samples
    whose true-class probability is at least ``config.fit_thresh`` are pushed
    into the bank through ``update_contrast``, and then one of two losses is
    optimised:

    * bank holds at least ``config.bank_size`` entries: the three views plus
      a bank-mixed image go through ``model`` and ``head``; the loss is
      LA(first view) + LA(mixed image) + SCCP(centers, two extra views).
    * otherwise: LA loss on ``classifier(model(first view))`` only.

    Args:
        train_loader: loader yielding ``((view, view_r1, view_r2), target)``.
        model: backbone network.
        classifier: classification layer applied to backbone features.
        criterion: dict with the keys ``"SCCP"`` and ``'LA'``.
        optimizer: optimizer stepped once per batch.
        scaler: GradScaler, used only when ``config.amp`` is true.
        epoch: current epoch index (used for the progress prefix).
        config: reads ``gpu``, ``fit_thresh``, ``bank_size``, ``num_classes``,
            ``amp`` and ``print_freq``.
        logger: passed to ``progress.display``.
        result: bank triple ``(images, targets, masks)`` from the last epoch.
        head: projection/classification head; defaults to None but is called
            in the bank-full branch, so it is effectively required there.

    Returns:
        The triple most recently returned by ``update_contrast`` (``result``
        unchanged if the loader yields no batch).
    """
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.3f')
    top1 = AverageMeter('Acc@1', ':6.3f')
    top5 = AverageMeter('Acc@5', ':6.3f')
    # data_time is updated below but is not part of the displayed meters.
    progress = ProgressMeter(
        len(train_loader),
        [batch_time, losses, top1, top5],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()
    classifier.train()

    training_data_num = len(train_loader.dataset)
    end_steps = int(training_data_num / train_loader.batch_size)

    contrast_images_fit, target_fit, tail_masks = result

    end = time.time()
    for i, (input1, target1) in enumerate(train_loader):
        # Iterations 0..end_steps run, i.e. at most end_steps + 1 batches.
        if i > end_steps:
            break

        # The loader packs three tensors per batch (presumably a base view and
        # two re-augmented views of the same images -- confirm in the dataset).
        input1, input1_r1, input1_r2 = input1
        # measure data loading time
        data_time.update(time.time() - end)

        if torch.cuda.is_available():
            input1 = input1.cuda(config.gpu, non_blocking=True)
            target1 = target1.cuda(config.gpu, non_blocking=True)
            input1_r1 = input1_r1.cuda(config.gpu, non_blocking=True)
            input1_r2 = input1_r2.cuda(config.gpu, non_blocking=True)
        # Grad-CAM: `mask` is (1 - normalised CAM) ** 2, `logit` the detached
        # classifier output for the first view.
        mask, logit = get_background_mask(model, classifier, input1, target1, config)
        prob = F.softmax(logit, dim=1)
        # Probability of the labelled class per sample. The boolean row index
        # keeps every row only if all targets are >= 0 -- assumed here.
        fit = (prob[target1 >= 0, target1] >= config.fit_thresh)
        # Binarise the mask at 0.5 (in place).
        mask[mask >= 0.5] = 1.0
        mask[mask < 0.5] = 0.0
        current_images = input1.clone().detach()
        target = target1.clone().detach()
        # Only the "fit" samples of this batch are offered to the bank.
        result = update_contrast(contrast_images_fit,target_fit,tail_masks,
                                 current_images[fit],target[fit], mask[fit], config.bank_size)  # config.bank_size
        contrast_images_fit, target_fit, tail_masks = result

        # `epoch >= 0` is always true; the effective condition is a full bank.
        if epoch >= 0 and target_fit.shape[0] >= config.bank_size:
            batch_size = mask.shape[0]
            # cur_mask stays 1 only where this batch's mask is 1 and the bank
            # mask at the same batch position is non-zero.
            cur_mask = torch.ones_like(mask)
            cur_mask[mask == 0.0] = 0.0
            cur_mask[(mask == 1.0) & (tail_masks[:batch_size] == 0.0)] = 0.0
            # Mixing is applied only to samples whose true-class probability
            # lies in [0.5, fit_thresh); all other samples keep their image.
            mask_fit = ((prob[target1 >= 0, target1] < config.fit_thresh) & (prob[target1 >= 0, target1] >= 0.5))
            cur_mask[mask_fit == False] = 0.0
            # Paste the first `batch_size` bank images, in random order, into
            # the region selected by cur_mask.
            # NOTE(review): tail_masks[:batch_size] above is NOT permuted while
            # the bank images are, so a bank mask is not necessarily paired
            # with its own image -- confirm this is intended.
            merge_image1 = cur_mask[:batch_size] * (contrast_images_fit[:batch_size])[torch.randperm(batch_size)] + \
                           (1-cur_mask[:batch_size]) * current_images

            # One forward pass over all four image sets, split again below.
            inputs = torch.cat([input1, input1_r1, input1_r2, merge_image1], dim=0)
            batch_size = target1.shape[0]
            feat = model(inputs)
            feat_mlp, logits, centers = head(feat)
            centers = centers[:config.num_classes]
            f1, f2, f3, f4 = torch.split(feat_mlp,[batch_size,batch_size,batch_size,batch_size], dim=0)
            # The contrastive loss only sees the two extra views (f2, f3).
            features = torch.cat([f2.unsqueeze(1), f3.unsqueeze(1)], dim=1)
            output1,output2,output3,output4 = torch.split(logits,[batch_size,batch_size,batch_size,batch_size], dim=0)

            scl_loss = criterion["SCCP"](centers, features, torch.cat([target1, target1], dim=0).long())
            # LA loss on the first view and on the mixed image, same labels.
            ce_loss = criterion["LA"](output1, target1)
            ce_loss1 = criterion["LA"](output4, target1)
            loss = (ce_loss + ce_loss1) + scl_loss
        else:
            # Bank not full yet: plain LA training on the first view.
            feat = model(input1)
            output1 = classifier(feat)
            loss = criterion['LA'](output1, target1)


        # Accuracy is always measured on the first view's logits.
        acc1, acc5 = accuracy(output1, target1, topk=(1, 5))
        losses.update(loss.item(), input1.size(0))
        top1.update(acc1[0], input1.size(0))
        top5.update(acc5[0], input1.size(0))
        # compute gradient and do SGD step; zero_grad here also clears what the
        # Grad-CAM backward pass left on the optimised parameters.
        optimizer.zero_grad()
        if config.amp:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % config.print_freq == 0:
            progress.display(i, logger)
    return result

def update_contrast(contrast_images_fit, target_fit, tail_mask,
                    current_images, current_target, current_mask, banksize):
    """Append a batch to the sample bank and trim it back to ``banksize``.

    The new samples are concatenated behind the stored ones. If the bank then
    exceeds ``banksize`` entries, the surplus is removed at random:

    * If the classes that occur exactly once fit into the bank, they are
      protected and the surplus is taken from the remaining classes roughly
      in proportion to their frequency.
    * Otherwise every repeated class is first reduced to a single sample and
      further samples are then dropped at random until ``banksize`` remain.

    Args:
        contrast_images_fit, target_fit, tail_mask: current bank contents.
        current_images, current_target, current_mask: samples to add.
        banksize: maximum number of bank entries.

    Returns:
        Tuple ``(images, targets, masks)`` of the updated bank; the three
        tensors are filtered with the same row mask and therefore stay aligned.
    """
    all_images = torch.cat([contrast_images_fit, current_images], dim=0)
    all_targets = torch.cat([target_fit, current_target], dim=0)
    all_masks = torch.cat([tail_mask, current_mask], dim=0)

    surplus = all_targets.shape[0] - banksize
    if surplus <= 0:
        # Everything fits: nothing has to be evicted.
        return all_images, all_targets, all_masks

    classes, freq = torch.unique(all_targets, return_counts=True)
    is_single = freq == 1
    keep = torch.ones_like(all_targets, dtype=torch.bool)

    if banksize >= is_single.sum():
        # Share of the surplus per class; single-sample classes get none.
        ratio = freq.float() / (all_targets.shape[0])
        ratio[is_single] = 0.0
        ratio = ratio / ratio.sum()
        quota, order = torch.sort(ratio * surplus, descending=True)

        # Round the quotas up, largest class first, and cap the running total
        # at exactly `surplus`; classes after the cap get a quota of zero.
        assigned = 0.0
        for pos in range(quota.shape[0]):
            quota[pos] = torch.ceil(quota[pos])
            assigned = assigned + torch.ceil(quota[pos])
            if assigned > surplus:
                quota[pos] = quota[pos] - assigned + surplus
                quota[pos + 1:] = 0.0
                break

        # Only classes in front of the first zero quota lose samples.
        zero_slots = torch.where(quota == 0)[0]
        n_active = zero_slots[0] if zero_slots.numel() > 0 else quota.shape[0]

        for cls_id, n_drop in zip((classes[order])[:n_active], quota):
            cls_id = int(cls_id)
            n_drop = int(n_drop)
            members = (all_targets == cls_id).nonzero(as_tuple=True)[0]
            if len(members) > n_drop:
                shuffled = torch.randperm(len(members))
                members = members[shuffled[:n_drop]]
            keep[members] = False
    else:
        # More single-sample classes than slots: keep one random sample of
        # every repeated class ...
        repeated = freq > 1
        repeated_classes = classes[repeated]
        for cls_id, n_members in zip(repeated_classes, freq[repeated]):
            members = (all_targets == cls_id).nonzero(as_tuple=True)[0]
            shuffled = torch.randperm(len(members))
            keep[members[shuffled[:n_members - 1]]] = False
        # ... then drop random survivors until only `banksize` are left.
        n_extra = repeated_classes.shape[0] + is_single.sum() - banksize
        survivors = torch.where(keep)[0]
        shuffled = torch.randperm(len(survivors))
        keep[survivors[shuffled[:n_extra]]] = False

    return all_images[keep], all_targets[keep], all_masks[keep]

# Scratch storage for the Grad-CAM hooks. get_background_mask() rebinds both
# names to fresh dicts keyed by CUDA device index before every use; until
# then they are None.
feat_map_global = grad_map_global = None

def _hook_a(module, input, output):
    """Forward hook: remember the layer output, keyed by its device index."""
    global feat_map_global
    device_index = output.device.index
    feat_map_global[device_index] = output


def _hook_g(module, grad_in, grad_out):
    """Backward hook: remember the first output gradient, keyed by device index."""
    global grad_map_global
    first_grad = grad_out[0]
    grad_map_global[first_grad.device.index] = first_grad

def _unwrap_parallel(module):
    """Return the wrapped network of a (Distributed)DataParallel, else the module itself."""
    wrappers = (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)
    return module.module if isinstance(module, wrappers) else module


def get_background_mask(model, classifier, input, target, config, mode='GradCAM'):
    """Compute a soft background mask from a class activation map.

    The activations and output gradients of ``last_layer`` are captured with
    hooks while the score of the labelled class is back-propagated. The
    resulting CAM is min-max normalised per sample, resized to the input
    resolution and turned into ``mask = (1 - cam) ** 2``, so weakly activated
    regions get values close to 1.

    Args:
        model: backbone exposing ``forward_1``, ``forward_2`` and
            ``last_layer``, optionally wrapped in (Distributed)DataParallel.
        classifier: maps backbone features to logits; its ``weight`` is only
            read in ``'CAM'`` mode.
        input: image batch.
        target: class index per sample (assumed non-negative).
        config: only ``config.gpu`` is read (device for gathering hook data
            in the DataParallel case).
        mode: ``'GradCAM'`` (default) or ``'CAM'``.

    Returns:
        ``(mask, logits)``: the mask at input resolution with one channel, and
        the detached classifier output.

    Raises:
        ValueError: if ``mode`` is neither ``'CAM'`` nor ``'GradCAM'``.

    Side effects: parameter gradients of ``model`` and ``classifier`` are
    zeroed and then overwritten by the CAM backward pass; the module-level
    ``feat_map_global`` / ``grad_map_global`` dicts are replaced.
    """
    global feat_map_global
    global grad_map_global

    if mode not in ('CAM', 'GradCAM'):
        # Used to end in a NameError on `cam` after all the work was done.
        raise ValueError("mode must be 'CAM' or 'GradCAM', got {!r}".format(mode))

    # Decide by the wrapper type, not by config.gpu: a DDP model with a GPU
    # set, or a bare model without one, broke the old gpu-based switch.
    backbone = _unwrap_parallel(model)
    target_layer = backbone.last_layer

    model_was_training = model.training
    classifier_was_training = classifier.training
    model.eval()
    classifier.eval()

    feat_map_global = {}
    grad_map_global = {}
    hook_a = target_layer.register_forward_hook(_hook_a)
    hook_g = target_layer.register_full_backward_hook(_hook_g)
    try:
        # Only the second stage takes part in the backward pass.
        with torch.no_grad():
            feat = backbone.forward_1(input)
        feat = backbone.forward_2(feat.detach())
        output = classifier(feat)
        loss = output[target >= 0, target].sum()
        model.zero_grad()
        classifier.zero_grad()
        loss.backward(retain_graph=False)
    finally:
        # Never leave hooks attached or the networks stuck in eval mode,
        # even if the forward/backward pass raised.
        hook_a.remove()
        hook_g.remove()
        model.train(model_was_training)
        classifier.train(classifier_was_training)

    if isinstance(model, torch.nn.DataParallel):
        feat_map = []
        grad_map = []
        for i in model.device_ids:
            if i in feat_map_global:
                feat_map.append(feat_map_global[i].cuda(config.gpu))
                grad_map.append(grad_map_global[i].cuda(config.gpu))
        feat_map = torch.cat(feat_map)
        grad_map = torch.cat(grad_map)
    else:
        device_id = input.device.index
        feat_map = feat_map_global[device_id]
        grad_map = grad_map_global[device_id]

    with torch.no_grad():
        if mode == 'CAM':
            fc_layer = _unwrap_parallel(classifier).weight
            weights = fc_layer[target].unsqueeze(-1).unsqueeze(-1)
            cam = (weights * feat_map).sum(dim=1, keepdim=True)
        else:
            weights = grad_map.mean(dim=(2, 3), keepdim=True)
            cam = (weights * feat_map).sum(dim=1, keepdim=True)
            cam = F.relu(cam, inplace=True)

        # Per-sample min-max normalisation. A flat map (e.g. all zeros after
        # the ReLU) has max == min; dividing by that zero range produced NaN
        # masks, so the divisor is floored at the smallest normal float.
        cam_min = cam.flatten(start_dim=-2).min(-1).values.unsqueeze(-1).unsqueeze(-1)
        cam = cam - cam_min
        cam_max = cam.flatten(start_dim=-2).max(-1).values.unsqueeze(-1).unsqueeze(-1)
        cam = cam / cam_max.clamp_min(torch.finfo(cam.dtype).tiny)

        input_h, input_w = input.shape[-2], input.shape[-1]
        resized_cam = F.interpolate(cam, size=(input_h, input_w), mode='bicubic', align_corners=False)
        # Bicubic resampling can overshoot [0, 1].
        resized_cam = resized_cam.clamp(0, 1)
        mask = (1 - resized_cam) ** 2

    return mask, output.detach()


class AccMeter:
    """Collects accuracy and calibration statistics over an evaluation run.

    Keeps running top-1/top-5 averages, per-class sample and hit counts
    (tensors on ``config.gpu``), and flat numpy arrays with the confidence,
    predicted class and true class of every sample seen so far.
    """

    def __init__(self):
        self.top1 = AverageMeter('Acc@1', ':6.3f')
        self.top5 = AverageMeter('Acc@5', ':6.3f')

        # Per-class counters: samples seen / samples predicted correctly.
        n_classes = config.num_classes
        self.class_num = torch.zeros(n_classes).cuda(config.gpu)
        self.correct = torch.zeros(n_classes).cuda(config.gpu)

        # Per-sample records consumed by get_cal().
        self.confidence = np.array([])
        self.pred_class = np.array([])
        self.true_class = np.array([])

    def update(self, output, target, is_prob=False):
        """Add one batch.

        Args:
            output: per-class scores; soft-maxed here unless ``is_prob``.
            target: true class indices.
            is_prob: set when ``output`` already holds probabilities.
        """
        probs = output if is_prob else torch.softmax(output, dim=1)

        acc1, acc5 = accuracy(probs, target, topk=(1, 5))
        n_samples = target.size(0)
        self.top1.update(acc1[0], n_samples)
        self.top5.update(acc5[0], n_samples)

        # Highest probability and its class for every sample.
        confidence_part, predicted = probs.max(dim=1)

        true_hot = F.one_hot(target, config.num_classes)
        pred_hot = F.one_hot(predicted, config.num_classes)
        # A class entry is a hit when truth and prediction both mark it.
        hits = (true_hot + pred_hot) == 2
        self.class_num = self.class_num + true_hot.sum(dim=0).to(torch.float)
        self.correct = self.correct + hits.sum(dim=0).to(torch.float)

        self.confidence = np.append(self.confidence, confidence_part.cpu().numpy())
        self.pred_class = np.append(self.pred_class, predicted.cpu().numpy())
        self.true_class = np.append(self.true_class, target.cpu().numpy())

    def get_shot_acc(self):
        """Return mean per-class accuracy (in percent) for head, medium and tail.

        The class ranges are the half-open index spans stored in
        ``config.head_class_idx``, ``config.med_class_idx`` and
        ``config.tail_class_idx``. A class without any sample gives 0/0 = NaN,
        which propagates into the mean of its group.
        """
        per_class = self.correct / self.class_num

        def _span_mean(span):
            return per_class[span[0]:span[1]].mean() * 100

        return (_span_mean(config.head_class_idx),
                _span_mean(config.med_class_idx),
                _span_mean(config.tail_class_idx))

    def get_cal(self):
        """Return ``calibration(...)`` of all recorded samples using 15 bins."""
        return calibration(self.true_class, self.pred_class, self.confidence, num_bins=15)


# Best top-1 accuracy seen so far, keyed by meter name. validate() updates it
# and main_worker() stores it in every checkpoint; unseen keys read as 0.0.
best_acc1 = defaultdict(float)


def validate(val_loader, model, classifier, config, logger):
    """Evaluate on ``val_loader`` and report whether top-1 accuracy improved.

    Runs ``classifier(model(x))`` over the loader without gradients, logs
    overall and head/medium/tail accuracy, and updates the module-level
    ``best_acc1`` record.

    Returns:
        True if the 'classifier' meter reached a new best Acc@1, else False.
    """
    global best_acc1

    timer = AverageMeter('Time', ':6.3f')
    meters = {'classifier': AccMeter()}
    clf_meter = meters['classifier']
    progress = ProgressMeter(len(val_loader),
                             [timer, clf_meter.top1, clf_meter.top5],
                             prefix='Eval: ')

    # switch to evaluate mode
    model.eval()
    classifier.eval()

    improved = False
    with torch.no_grad():
        tick = time.time()
        for step, (images, labels) in enumerate(val_loader):
            # Images are moved only when a GPU is pinned; labels whenever
            # CUDA is available.
            if config.gpu is not None:
                images = images.cuda(config.gpu, non_blocking=True)
            if torch.cuda.is_available():
                labels = labels.cuda(config.gpu, non_blocking=True)

            logits = classifier(model(images))
            clf_meter.update(logits, labels)

            # measure elapsed time
            timer.update(time.time() - tick)
            tick = time.time()

            if step % config.print_freq == 0:
                progress.display(step, logger)

        summary = ('* ({name})  Acc@1 {acc1:.3f}  HAcc {head_acc:.3f}  MAcc {med_acc:.3f}  TAcc {tail_acc:.3f}  '
                   '(Best Acc@1 {best_acc1:.3f}).')
        for name, meter in meters.items():
            acc1 = meter.top1.avg
            acc5 = meter.top5.avg
            head_acc, med_acc, tail_acc = meter.get_shot_acc()

            # remember best acc@1
            if acc1 > best_acc1[name]:
                best_acc1[name] = acc1
                improved = improved or name == 'classifier'

            logger.info(summary.format(
                name=name, acc1=acc1, acc5=acc5, head_acc=head_acc,
                med_acc=med_acc, tail_acc=tail_acc, best_acc1=best_acc1[name]))

    return improved


def save_checkpoint(state, is_best, model_dir):
    """Write ``state`` to ``<model_dir>/current.pth.tar``.

    When ``is_best`` is truthy the file is additionally copied to
    ``<model_dir>/model_best.pth.tar``.
    """
    latest_path = f'{model_dir}/current.pth.tar'
    torch.save(state, latest_path)
    if not is_best:
        return
    shutil.copyfile(latest_path, f'{model_dir}/model_best.pth.tar')


def adjust_learning_rate(optimizer, epoch, config, warmup_epochs=5):
    """Set the learning rate of every parameter group for ``epoch``.

    Both schedules start with a linear warm-up over the first
    ``warmup_epochs`` epochs, ending exactly at ``config.lr``:

    * ``config.cos`` true: afterwards half-cosine decay from ``config.lr``
      towards 0 over ``config.num_epochs``.
    * otherwise: ``config.lr`` until the 160th epoch, ``0.1 * config.lr``
      until the 180th, ``0.02 * config.lr`` after that.

    Args:
        optimizer: object whose ``param_groups`` entries get their ``'lr'`` set.
        epoch: 0-based epoch index.
        config: provides ``cos``, ``lr`` and (cosine only) ``num_epochs``.
        warmup_epochs: length of the linear warm-up in epochs.
    """
    if config.cos:
        # `epoch` is 0-based, so warm-up covers epochs 0 .. warmup_epochs - 1.
        # The previous `epoch <= warmup_epochs` test added one more step and
        # pushed the rate to (warmup_epochs + 1) / warmup_epochs * config.lr,
        # i.e. above the configured peak.
        if epoch < warmup_epochs:
            lr = config.lr * (epoch + 1) / warmup_epochs
        else:
            lr = 0.5 * config.lr * (1 + math.cos(math.pi * epoch / config.num_epochs))
    else:
        step = epoch + 1  # 1-based epoch count for the milestone comparisons
        if step <= warmup_epochs:
            lr = config.lr * step / warmup_epochs
        elif step > 180:
            lr = config.lr * 0.02
        elif step > 160:
            lr = config.lr * 0.1
        else:
            lr = config.lr
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


# Run training only when executed as a script. The guard also keeps worker
# processes started through mp.spawn from re-running main() on import.
if __name__ == '__main__':
    main()
