import datetime
import math
import sys
import time

import torch
from timm.data.mixup import Mixup
from timm.loss import SoftTargetCrossEntropy
from torch.nn import CrossEntropyLoss
from torch.utils.data import ConcatDataset

import data
import models
import utils.lr_decay as lrd
import utils.lr_sched as lr_sched
import utils.misc as misc
from utils.metric import AverageMeter
from utils.misc import NativeScalerWithGradNormCount as NativeScaler


class Trainer(object):
    """Continual-learning trainer combining labeled, memory-buffer and unlabeled losses.

    Mixup/cutmix and fine-tuning hyper-parameters are fixed in ``__init__``;
    everything else is read from the ``args`` namespace handed to each method.
    Which losses are active at a given iteration is decided by
    :meth:`current_loss_data`, which subclasses may override.
    """

    def __init__(self, args):
        self.input_size = 224
        # passed as ``mask_ratio`` to the model's reconstruction pass
        self.mask_ratio = 0.75
        self.drop_path = 0.1
        self.global_pool = True
        self.weight_decay = 0.05
        self.layer_decay = 0.65
        self.mixup = 0.8
        self.cutmix = 1.0
        self.cutmix_minmax = None
        self.mixup_prob = 1.0
        self.mixup_switch_prob = 0.5
        self.smoothing = 0.1
        self.max_norm = None
        self.mixup_mode = 'batch'
        # Rebuilt by init_mixup(); defined here so the loss helpers never hit an
        # AttributeError when used before train().
        self.mixup_fn = None

        self.model = args.model_name

    def init_mixup(self, nb_classes):
        """Build ``self.mixup_fn`` (a timm ``Mixup``) for ``nb_classes`` classes.

        ``self.mixup_fn`` is left as ``None`` when mixup, cutmix and
        ``cutmix_minmax`` are all disabled.
        """
        self.mixup_fn = None
        mixup_active = self.mixup > 0 or self.cutmix > 0. or self.cutmix_minmax is not None
        if mixup_active:
            print("Mixup is activated!")
            self.mixup_fn = Mixup(
                mixup_alpha=self.mixup, cutmix_alpha=self.cutmix, cutmix_minmax=self.cutmix_minmax,
                prob=self.mixup_prob, switch_prob=self.mixup_switch_prob, mode=self.mixup_mode,
                label_smoothing=self.smoothing, num_classes=nb_classes)

    def init_model(self, args):
        """Instantiate ``models.<self.model>``, optionally load weights, and move it to ``args.gpu``.

        Pretrained weights are read from ``checkpoint['model']`` with
        ``strict=False``, so missing/unexpected keys are tolerated; they are
        printed so a mismatched checkpoint is visible.
        """
        model = models.__dict__[self.model](decoder=args.unsup_loss, nb_classes=args.nb_classes)
        if args.pretrained:
            checkpoint = torch.load(args.pretrained_model, map_location=torch.device('cpu'))
            msg = model.load_state_dict(checkpoint['model'], strict=False)
            print(f'=> Loaded pretrained weights from {args.pretrained_model}: {msg}')
            del checkpoint
        torch.cuda.set_device(args.gpu)
        model.cuda(args.gpu)

        return model

    def adapt_optm(self, model, args):
        """Return an AdamW optimizer over layer-wise lr-decayed parameter groups.

        Groups come from ``lrd.param_groups_lrd`` using ``args.weight_decay`` and
        ``args.layer_decay`` and excluding ``model.no_weight_decay()`` from decay.
        """
        param_groups = lrd.param_groups_lrd(model, args, args.weight_decay,
                                            no_weight_decay_list=model.no_weight_decay(),
                                            layer_decay=args.layer_decay
                                            )
        optimizer = torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, 0.95))
        return optimizer

    def get_batch_size(self, args):
        """Split ``args.batch_size`` into unlabeled / memory / current-task labeled parts.

        Sets ``self.unlabeled_bs`` (``args.batch_split`` fraction) and, when
        ``args.cur_task_separate_men_set`` is set, ``self.mem_labeled_bs``
        (``args.mem_sampling_rate`` fraction). ``self.cur_labeled_bs`` receives
        the remainder.
        """
        self.unlabeled_bs = int(args.batch_size * (args.batch_split))
        if args.cur_task_separate_men_set:
            self.mem_labeled_bs = int(args.batch_size * args.mem_sampling_rate)
            self.cur_labeled_bs = args.batch_size - self.unlabeled_bs - self.mem_labeled_bs
        else:
            self.cur_labeled_bs = args.batch_size - self.unlabeled_bs

    def get_labeled_laoder(self, args, task):
        """Build the labeled loader for ``task`` and, optionally, the memory loader.

        The method name keeps its historical spelling for existing callers.
        :meth:`get_batch_size` must have been called first.

        Returns:
            ``(labeled_loader, mem_loader)``; ``mem_loader`` is ``None`` unless
            ``args.cur_task_separate_men_set`` is set.

        Raises:
            ValueError: if the labeled set cannot fill one batch of
                ``self.cur_labeled_bs`` samples (incomplete batches are dropped).
        """
        labeled_set, mem_dataset = data.get_labeled_set(args, task)
        print(f'=> Prepare labeled set, labeled train set length {labeled_set.__len__()}')

        if args.distributed:
            labeled_train_sampler = torch.utils.data.distributed.DistributedSampler(labeled_set)
            if args.cur_task_separate_men_set:
                mem_sampler = torch.utils.data.distributed.DistributedSampler(mem_dataset)
            else:
                mem_sampler = None
        else:
            labeled_train_sampler = None
            mem_sampler = None

        labeled_loader = torch.utils.data.DataLoader(
            labeled_set, batch_size=self.cur_labeled_bs,
            shuffle=(labeled_train_sampler is None),
            num_workers=args.workers, pin_memory=True, sampler=labeled_train_sampler, drop_last=True)
        if args.cur_task_separate_men_set:
            mem_loader = torch.utils.data.DataLoader(
                mem_dataset, batch_size=self.mem_labeled_bs,
                shuffle=(mem_sampler is None),
                num_workers=args.workers, pin_memory=True, sampler=mem_sampler, drop_last=True)
        else:
            mem_loader = None

        if len(labeled_loader) == 0:
            raise ValueError(
                f'Labeled set of task {task} has fewer samples than one batch '
                f'({self.cur_labeled_bs}); cannot train.')
        # compute budget
        labeled_epoch = args.iters // len(labeled_loader) + 1
        misc.logging('task', task, 'label epoch', labeled_epoch, args)
        print(f"=> Labeled iter {args.iters}. Labeled Epoch {labeled_epoch}")
        return labeled_loader, mem_loader

    def get_unlabeled_loader(self, args, task):
        """Build the loader over the unlabeled set of ``task``.

        Uses ``self.unlabeled_bs``, so :meth:`get_batch_size` must run first.
        """
        unlabeled_set = data.get_unlabeled_set(args, task)
        print(f'=> Prepare unlabeled set done, unlabeled train set length {unlabeled_set.__len__()}')
        if args.distributed:
            unlabeled_train_sampler = torch.utils.data.distributed.DistributedSampler(unlabeled_set)
        else:
            unlabeled_train_sampler = None
        return torch.utils.data.DataLoader(
            unlabeled_set, batch_size=int(self.unlabeled_bs),
            shuffle=(unlabeled_train_sampler is None),
            num_workers=args.workers, pin_memory=True, sampler=unlabeled_train_sampler, drop_last=True)

    def _classification_criterion(self):
        """Soft-target CE when mixup produces soft labels, plain CE otherwise."""
        if self.mixup_fn is not None:
            # smoothing is handled with mixup label transform
            return SoftTargetCrossEntropy()
        return torch.nn.CrossEntropyLoss()

    def _prepare_labeled_batch(self, samples, targets, args):
        """Move a labeled batch to ``args.gpu`` as float and apply mixup if enabled.

        When mixup is enabled, the last sample is dropped from an odd-sized batch
        because timm's ``Mixup`` requires an even batch size.
        """
        samples = samples.float().cuda(args.gpu, non_blocking=True)
        targets = targets.float().cuda(args.gpu, non_blocking=True)
        if self.mixup_fn is not None:
            if len(samples) % 2 != 0:
                samples = samples[:-1]
                targets = targets[:-1]
            samples, targets = self.mixup_fn(samples, targets)
        return samples, targets

    def compute_cur_labeled_loss(self, model, labeled_samples, labeled_targets, args):
        """Classification loss on a current-task labeled batch.

        When ``args.mask_cur_loss`` is set, the loss uses only the last
        ``args.new_classes`` columns of logits and targets.
        """
        criterion = self._classification_criterion()
        labeled_samples, labeled_targets = self._prepare_labeled_batch(labeled_samples, labeled_targets, args)
        with torch.cuda.amp.autocast(enabled=False):
            labeled_logits = model(labeled_samples, loss_pattern="classification")
            # masked cross-entropy
            if args.mask_cur_loss:
                loss = criterion(labeled_logits[:, -args.new_classes:], labeled_targets[:, -args.new_classes:])
            else:
                loss = criterion(labeled_logits, labeled_targets)
        return loss

    def compute_mem_loss(self, model, mem_samples, mem_targets, args):
        """Classification loss on a batch drawn from the memory set."""
        criterion = self._classification_criterion()
        mem_samples, mem_targets = self._prepare_labeled_batch(mem_samples, mem_targets, args)
        with torch.cuda.amp.autocast(enabled=False):
            mem_logits = model(mem_samples, loss_pattern="classification")
            loss = criterion(mem_logits, mem_targets)
        return loss

    def compute_unlabeled_loss(self, model, unlabeled_samples, args):
        """Reconstruction loss on an unlabeled batch, scaled by ``args.unlabeled_coef``.

        The model is called with ``loss_pattern="reconstruction"`` and
        ``mask_ratio=self.mask_ratio``; element 0 of its output is used as the loss.
        """
        unlabeled_input = unlabeled_samples.float().cuda(args.gpu, non_blocking=True)
        with torch.cuda.amp.autocast(enabled=False):
            loss_unlabeled = args.unlabeled_coef * model(unlabeled_input, mask_ratio=self.mask_ratio,
                                                         loss_pattern="reconstruction")[0]
        return loss_unlabeled

    def current_loss_data(self, args, task, iter_count):
        """Return ``(cur, mem, unsup)`` flags selecting the losses for this iteration.

        The base trainer always uses the current-task loss, uses the memory loss
        iff ``args.cur_task_separate_men_set`` is set, and uses the unlabeled loss
        iff ``args.unsup_loss`` is set.
        """
        return True, bool(args.cur_task_separate_men_set), bool(args.unsup_loss)

    @staticmethod
    def _set_loader_epoch(loader, epoch):
        """Forward ``epoch`` to the loader's sampler if it has ``set_epoch``.

        ``DistributedSampler`` has ``set_epoch``, but the ``RandomSampler``
        used for ``shuffle=True`` in the non-distributed case does not.
        """
        set_epoch = getattr(loader.sampler, 'set_epoch', None)
        if set_epoch is not None:
            set_epoch(epoch)

    def train(self, model, task, args):
        """Train ``model`` on ``task`` for ``args.iters`` iterations.

        At each iteration :meth:`current_loss_data` picks which of the
        current-task labeled loss, memory loss and unlabeled reconstruction loss
        are summed. The sum is back-propagated with gradient accumulation over
        ``args.accum`` iterations. Validation runs every ``args.eval_freq``
        iterations.

        Args:
            model: wrapped model; ``model.module`` is used to build the optimizer
                (presumably a DistributedDataParallel wrapper -- TODO confirm).
            task: index of the task being trained.
            args: run configuration namespace.
        """
        self.init_mixup(args.seen_classes)
        optimizer = self.adapt_optm(model.module, args)

        self.get_batch_size(args)
        labeled_loader, mem_loader = self.get_labeled_laoder(args, task)
        unlabeled_loader = self.get_unlabeled_loader(args, task) if args.unsup_loss else None

        # Each loader is consumed at its own pace; these counters decide when a
        # loader must be re-iterated. Missing loaders get 0 batches per epoch.
        labeled_iter_per_epoch = len(labeled_loader)
        unlabeled_iter_per_epoch = len(unlabeled_loader) if unlabeled_loader is not None else 0
        mem_iter_per_epoch = len(mem_loader) if mem_loader is not None else 0
        labeled_epoch_count, unlabeled_epoch_count, mem_epoch_count = 0, 0, 0
        cur_iter_count, mem_iter_count, unsup_iter_count = 0, 0, 0
        labeled_iter = unlabeled_iter = mem_iter = None

        start_time = time.time()
        train_iter = args.iters
        accum_iter = args.accum

        # loss and time logger
        loss_scaler = NativeScaler()
        data_time = AverageMeter()
        batch_time = AverageMeter()

        model.train(True)
        optimizer.zero_grad()

        loss_labeled_value_accum, loss_unlabeled_value_accum = 0, 0
        cur_loss_accum, pre_loss_accum = 0, 0

        for iter_count in range(train_iter):
            end = time.time()
            step = iter_count // accum_iter
            update_grad = (iter_count + 1) % accum_iter == 0
            # we use a per iteration lr scheduler, applied on optimizer-step iterations
            if update_grad:
                lr = lr_sched.adjust_learning_rate(optimizer, iter_count, args)
                misc.logging('step', step, f'task{task}/lr', lr, args)

            cur, mem, unsup = self.current_loss_data(args, task, iter_count)
            # A loss can only be computed if its loader was built: the memory
            # loader needs args.cur_task_separate_men_set, the unlabeled loader
            # needs args.unsup_loss.
            mem = mem and mem_loader is not None
            unsup = unsup and unlabeled_loader is not None

            # start a new epoch for a loader once it has been fully consumed
            if cur and cur_iter_count % labeled_iter_per_epoch == 0:
                self._set_loader_epoch(labeled_loader, labeled_epoch_count)
                labeled_iter = iter(labeled_loader)
                labeled_epoch_count += 1
            if unsup and unsup_iter_count % unlabeled_iter_per_epoch == 0:
                self._set_loader_epoch(unlabeled_loader, unlabeled_epoch_count)
                unlabeled_iter = iter(unlabeled_loader)
                unlabeled_epoch_count += 1
            if mem and mem_iter_count % mem_iter_per_epoch == 0:
                self._set_loader_epoch(mem_loader, mem_epoch_count)
                mem_iter = iter(mem_loader)
                mem_epoch_count += 1

            cur_iter_count += 1 if cur else 0
            mem_iter_count += 1 if mem else 0
            unsup_iter_count += 1 if unsup else 0

            # get mini-batch and log data time
            if unsup:
                unlabeled_samples, _ = next(unlabeled_iter)
            if cur:
                labeled_samples, labeled_targets = next(labeled_iter)
            if mem:
                mem_samples, mem_targets = next(mem_iter)

            cur_data_time = time.time() - end
            cur_data_time_reduce = misc.all_reduce_mean(cur_data_time)
            data_time.update(cur_data_time_reduce)

            loss = torch.zeros(1, dtype=torch.float32).cuda(args.gpu, non_blocking=True)
            if unsup:
                unsup_loss = self.compute_unlabeled_loss(model, unlabeled_samples, args)
                loss += unsup_loss
                loss_unlabeled_value_accum += unsup_loss.item() / accum_iter
                del unlabeled_samples
                torch.cuda.empty_cache()
            if cur:
                if args.debug:
                    print(labeled_targets.shape, labeled_targets)
                cur_loss = self.compute_cur_labeled_loss(model, labeled_samples, labeled_targets, args)
                loss += cur_loss
                loss_labeled_value_accum += cur_loss.item() / accum_iter
                cur_loss_accum += cur_loss.item() / accum_iter
                del labeled_samples, labeled_targets
                torch.cuda.empty_cache()
            if mem:
                mem_loss = self.compute_mem_loss(model, mem_samples, mem_targets, args)
                loss += mem_loss
                if args.debug:
                    print(mem_targets.shape, mem_targets)
                loss_labeled_value_accum += mem_loss.item() / accum_iter
                pre_loss_accum += mem_loss.item() / accum_iter
                del mem_samples, mem_targets
                torch.cuda.empty_cache()

            loss_value = loss.item()
            # exit when infinite loss value
            if not math.isfinite(loss_value):
                print("Loss is {}, stopping training".format(loss_value))
                sys.exit(1)

            # backward; gradients are only applied every accum_iter iterations
            loss /= accum_iter
            loss_scaler(loss, optimizer, parameters=model.parameters(),
                        update_grad=update_grad)
            if update_grad:
                optimizer.zero_grad()
                misc.logging('step', step, f'task{task}/loss_labeled',
                             loss_labeled_value_accum,
                             args)
                misc.logging('step', step, f'task{task}/loss_unlabeled',
                             loss_unlabeled_value_accum,
                             args)
                misc.logging('step', step, f'task{task}/cur_task_loss',
                             cur_loss_accum,
                             args)
                misc.logging('step', step, f'task{task}/buffer_loss',
                             pre_loss_accum,
                             args)
                loss_labeled_value_accum, loss_unlabeled_value_accum = 0, 0
                cur_loss_accum, pre_loss_accum = 0, 0

            # log
            loss_value_reduce = misc.all_reduce_mean(loss_value)
            cur_batch_time = time.time() - end
            cur_batch_time_reduce = misc.all_reduce_mean(cur_batch_time)
            batch_time.update(cur_batch_time_reduce)

            misc.logging('step', step, f'task{task}/train_loss', loss_value_reduce,
                         args)

            if iter_count % args.print_freq == 0:
                # guard the modulo: there may be no unlabeled loader
                unlabeled_pos = iter_count % unlabeled_iter_per_epoch if unlabeled_iter_per_epoch else 0
                print('Unlabeled Epoch: [{0}][{1}/{2}], Labeled Epoch: [{3}][{4}/{5}] Iter: [{6}]\t'.format(
                    unlabeled_epoch_count - 1, unlabeled_pos, unlabeled_iter_per_epoch,
                    labeled_epoch_count - 1, iter_count % labeled_iter_per_epoch, labeled_iter_per_epoch,
                    iter_count),
                    f'Time {batch_time.val: .3f} ({batch_time.avg: .3f})\t'
                    f'Data {data_time.val: .3f} ({data_time.avg: .3f})\t'
                    f'Loss {loss_value_reduce:.4f} \t'
                )
            if (iter_count + 1) % args.eval_freq == 0:
                # short_validate declares ``writer`` as a parameter; pass it explicitly.
                short_validate(model, CrossEntropyLoss().cuda(args.gpu), task, args, step, None)
                # validation may switch the model to eval mode; resume training mode
                model.train(True)

        stage_time = time.time() - start_time
        stage_time_reduce = misc.all_reduce_mean(stage_time)
        task_time = datetime.timedelta(seconds=int(stage_time_reduce))
        print(f'Training time: {task_time}')
        misc.logging("task", task, "time", stage_time_reduce / 3600, args)


class TwoStageTrainer(Trainer):
    """Trainer that splits the iteration budget ``args.iters`` into two stages.

    The first ``args.split`` fraction of iterations uses the unlabeled loss.
    The remaining iterations use the current-task labeled loss. The memory loss
    is requested in both stages for every task after the first (``task > 0``).
    """

    def current_loss_data(self, args, task, iter_count):
        """Return the ``(cur, mem, unsup)`` loss flags for ``iter_count``.

        The stage boundary is ``args.iters * args.split``. ``args.iters`` is the
        same iteration budget that :meth:`Trainer.train` loops over.
        """
        use_mem = task > 0
        if iter_count < args.iters * args.split:
            # stage 1: unlabeled (+ memory) losses
            return False, use_mem, True
        # stage 2: current-task labeled (+ memory) losses
        return True, use_mem, False


def short_validate(model, criterion, mtask, args, step, writer=None):
    """Evaluate ``model`` on the validation sets of tasks ``0..mtask`` and log the results.

    Top-1 accuracy and the loss, averaged across processes via
    ``misc.all_reduce_mean``, are printed and logged under ``"{mtask}/val_acc"``
    and ``"{mtask}/val_loss"``. Nothing happens when ``args.no_evaluate`` is set.

    Args:
        model: network to evaluate. It is switched to eval mode for the pass, and
            its previous train/eval mode is restored afterwards.
        criterion: loss applied to ``(output, target)``.
        mtask: index of the latest task; validation sets ``0..mtask`` are concatenated.
        args: run configuration (``no_evaluate``, ``dist_eval``, ``batch_size``,
            ``workers``, ``gpu``, ...).
        step: step value passed to ``misc.logging``.
        writer: unused; optional for backward compatibility.
    """
    if args.no_evaluate:
        return

    print(f"validate model {mtask} ")

    val_sets = ConcatDataset([data.get_val_set(args, task) for task in range(mtask + 1)])
    if args.dist_eval:
        val_sampler = torch.utils.data.distributed.DistributedSampler(val_sets)
    else:
        val_sampler = None
    val_loader = torch.utils.data.DataLoader(val_sets, batch_size=args.batch_size, shuffle=False,
                                             num_workers=args.workers, pin_memory=False, sampler=val_sampler)

    losses = AverageMeter()
    acc = AverageMeter()

    was_training = model.training
    # switch to evaluate mode; restored in ``finally`` even if evaluation fails
    model.eval()
    try:
        with torch.no_grad():
            for images, target in val_loader:
                images = images.cuda(args.gpu, non_blocking=True)
                target = target.cuda(args.gpu, non_blocking=True)

                # compute output
                with torch.cuda.amp.autocast():
                    output = model(images)
                    loss = criterion(output, target)

                # measure accuracy and record loss; only top-1 is reported
                (prec1,) = accuracy(output, target, topk=(1,))
                prec1_reduce = misc.all_reduce_mean(prec1[0])

                reduce_loss = misc.all_reduce_mean(loss.item())
                losses.update(reduce_loss, images.size(0))
                acc.update(prec1_reduce, images.size(0))
    finally:
        model.train(was_training)

    print(
        f' * Model {mtask}, Prec@1 {acc.avg:.3f} ')

    misc.logging('step', step, f"{mtask}/val_acc", acc.avg, args)
    misc.logging('step', step, f"{mtask}/val_loss", losses.avg, args)


def accuracy(output, target, topk=(1,)):
    """Compute precision@k, in percent, for each k in ``topk``.

    Args:
        output: score tensor indexed as ``(batch, classes)``; ranking is along dim 1.
        target: ground-truth class indices, one per row of ``output``.
        topk: the values of k to evaluate.

    Returns:
        A list with one single-element tensor per k, in ``topk`` order. A k that
        exceeds the number of classes is clamped, so every class counts as a
        candidate and the result is 100% for valid targets.
    """
    # clamp so Tensor.topk does not fail when k > number of classes
    maxk = min(max(topk), output.size(1))
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()  # row i holds the i-th best prediction of every sample
    correct = pred.eq(target.reshape(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res
