import datetime
import math
import sys
import time
import torch
from timm.loss import SoftTargetCrossEntropy
from torch.nn import CrossEntropyLoss
from torch.utils.data import ConcatDataset

import data
import utils.lr_sched as lr_sched
import utils.misc as misc
from trainer import Trainer
from utils.metric import AverageMeter
from utils.misc import NativeScalerWithGradNormCount as NativeScaler


class DietCL(Trainer):
    """Trainer that mixes labeled, unlabeled and memory batches per task.

    :meth:`train` runs up to two stages for a task:

    * stage 1 trains on the loaders returned by ``get_labeled_laoder`` and
      (when ``args.unsup_loss``) ``get_unlabeled_loader``;
    * for ``task > 0`` whose budget ``args.steps`` exceeds ``args.min_budget``,
      an "extra" stage then trains on the memory set only
      (:meth:`train_extra_budget`).

    Helpers used here but not defined here (``init_mixup``,
    ``get_labeled_laoder``, ``get_unlabeled_loader``, ``compute_unlabeled_loss``,
    ``compute_cur_labeled_loss``, ``compute_mem_loss``, ``mixup_fn``) are
    presumably provided by :class:`Trainer` — confirm there.
    """

    def get_mem_loader(self, args, task):
        """Build a distributed DataLoader over the memory set for ``task``.

        Returns ``None`` when ``task <= 0``. The batch size is
        ``self.mem_labeled_bs``, so :meth:`get_batch_size_for_two_stages` must
        have been called first.
        """
        if task <= 0:
            return None
        mem_dataset = ConcatDataset(data.get_mem_set(args, task))
        print(f'=> Prepare mem set, mem train set length {len(mem_dataset)}')
        mem_sampler = torch.utils.data.distributed.DistributedSampler(mem_dataset)
        # A sampler is always supplied, so DataLoader-level shuffling must stay off;
        # the DistributedSampler shuffles and is reseeded via set_epoch().
        return torch.utils.data.DataLoader(
            mem_dataset, batch_size=self.mem_labeled_bs,
            shuffle=False,
            num_workers=args.workers, pin_memory=True, sampler=mem_sampler, drop_last=True)

    def compute_mem_loss_s1(self, model, mem_samples, mem_targets, args):
        """Classification loss on a memory batch, scaled by ``args.mem_coef_s1``.

        The last ``args.new_classes`` columns of both logits and targets are
        dropped before the loss (presumably the current task's classes), so
        ``mem_targets`` must be 2-D (one column per class) by the time it is
        sliced.
        """
        if self.mixup_fn is not None:
            # smoothing is handled with mixup label transform
            criterion = SoftTargetCrossEntropy()
        else:
            criterion = torch.nn.CrossEntropyLoss()
        mem_samples = mem_samples.float().cuda(args.gpu, non_blocking=True)
        mem_targets = mem_targets.float().cuda(args.gpu, non_blocking=True)

        if self.mixup_fn is not None:
            # timm's Mixup asserts an even batch size; drop the last sample if odd.
            if len(mem_samples) % 2 != 0:
                mem_samples = mem_samples[:-1]
                mem_targets = mem_targets[:-1]
            mem_samples, mem_targets = self.mixup_fn(mem_samples, mem_targets)

        # Forced full precision for this loss, regardless of any outer autocast.
        with torch.cuda.amp.autocast(enabled=False):
            mem_logits = model(mem_samples, loss_pattern="classification")
            loss = args.mem_coef_s1 * criterion(mem_logits[:, :-args.new_classes],
                                                mem_targets[:, :-args.new_classes])
        del mem_samples, mem_targets

        return loss

    def get_batch_size_for_two_stages(self, args, task, stage):
        """Split ``args.batch_size`` into per-source batch sizes for a stage.

        Sets ``self.unlabeled_bs``, ``self.cur_labeled_bs`` and
        ``self.mem_labeled_bs``:

        * ``stage == 1``: ``int(batch_size * batch_split)`` unlabeled samples;
          if ``args.cur_task_separate_men_set`` also
          ``int(batch_size * mem_sampling_rate)`` memory samples; the rest are
          current labeled samples.
        * any other stage: the whole batch is memory samples.

        ``task`` is not used; it is kept for interface compatibility.
        """
        if stage == 1:
            self.unlabeled_bs = int(args.batch_size * args.batch_split)
            if args.cur_task_separate_men_set:
                self.mem_labeled_bs = int(args.batch_size * args.mem_sampling_rate)
                self.cur_labeled_bs = args.batch_size - self.unlabeled_bs - self.mem_labeled_bs
            else:
                self.cur_labeled_bs = args.batch_size - self.unlabeled_bs
                self.mem_labeled_bs = 0
        else:
            self.mem_labeled_bs = args.batch_size
            self.cur_labeled_bs = 0
            self.unlabeled_bs = 0

    def train_extra_budget(self, steps, args, model, task, loss_scaler, optimizer):
        """Run the extra stage: ``steps`` steps on the memory set only.

        The number of batches is ``int(steps * 1024 / args.batch_size_total)``.
        Gradients are accumulated over ``1024 / args.batch_size_total`` batches
        (at least 1) per optimizer step. The learning-rate schedule uses
        ``args.lr_extra`` as its base rate.
        """
        # data
        self.get_batch_size_for_two_stages(args, task, stage=2)
        mem_loader = self.get_mem_loader(args, task)
        # iter count
        train_iter = int(steps * 1024 / args.batch_size_total)
        # Clamp to 1: with batch_size_total > 1024 the ratio truncates to 0,
        # which would make every `% accum_iter` in train_iters divide by zero.
        accum_iter = max(1, int(1024 / args.batch_size_total))
        # optimizer
        optimizer.zero_grad()
        print(
            f'=> Extra stage: Train {steps} steps, {train_iter} batches, gradient steps every {accum_iter} batches, '
            f'unlabeled batch size {self.unlabeled_bs}, '
            f'previous labeled batch size {self.mem_labeled_bs}, '
            f'current labeled batch size {self.cur_labeled_bs}, ')

        self.train_iters(args,
                         task, train_iter, accum_iter,
                         optimizer, model, loss_scaler,
                         mem_loader=mem_loader, stage='extra', base_lr=args.lr_extra)

    def train_iters(self, args,
                    task, train_iter, accum_iter,
                    optimizer, model,
                    loss_scaler,
                    labeled_loader=None, mem_loader=None, unlabeled_loader=None, stage='', base_lr=None):
        """Run ``train_iter`` iterations, drawing one batch from each given loader.

        Every loader that is not ``None`` (current labeled, memory, unlabeled)
        contributes one batch per iteration; the losses are summed, divided by
        ``accum_iter`` and back-propagated through ``loss_scaler``. The
        optimizer is stepped every ``accum_iter`` iterations, when the learning
        rate is also re-scheduled and the accumulated labeled/unlabeled loss
        components are logged. Each loader's iterator is recreated (after
        ``sampler.set_epoch``) every ``len(loader)`` batches, so loaders of
        different lengths can be cycled together. Every ``args.eval_freq``
        iterations :func:`short_validate` is run.

        ``stage`` prefixes logged metric names; ``base_lr`` is forwarded to
        ``lr_sched.adjust_learning_rate``. Exits the process with status 1 if
        the loss becomes non-finite.
        """
        # loss and time logger
        data_time = AverageMeter()
        batch_time = AverageMeter()
        loss_labeled_value_accum, loss_unlabeled_value_accum = 0, 0
        cur_iter_count, mem_iter_count, unsup_iter_count = 0, 0, 0
        unlabeled_epoch_count, mem_epoch_count, labeled_epoch_count = 0, 0, 0
        for iter_count in range(train_iter):
            end = time.time()
            # per-optimizer-step lr schedule (only on accumulation boundaries)
            if (iter_count + 1) % accum_iter == 0:
                lr = lr_sched.adjust_learning_rate(optimizer, iter_count, args, total_iter=train_iter, base_lr=base_lr)
                misc.logging('step', iter_count // accum_iter, f'task{task}/{stage}lr', lr, args)
            cur, mem, unsup = labeled_loader is not None, mem_loader is not None, unlabeled_loader is not None

            # (Re)start each loader's iterator at the beginning of each of its epochs.
            if cur:
                labeled_iter_per_epoch = len(labeled_loader)
                if cur_iter_count % labeled_iter_per_epoch == 0:
                    labeled_loader.sampler.set_epoch(labeled_epoch_count)
                    labeled_iter = iter(labeled_loader)
                    labeled_epoch_count += 1

            if unsup:
                unlabeled_iter_per_epoch = len(unlabeled_loader)
                if unsup_iter_count % unlabeled_iter_per_epoch == 0:
                    unlabeled_loader.sampler.set_epoch(unlabeled_epoch_count)
                    unlabeled_iter = iter(unlabeled_loader)
                    unlabeled_epoch_count += 1
            if mem:
                mem_iter_per_epoch = len(mem_loader)
                if mem_iter_count % mem_iter_per_epoch == 0:
                    mem_loader.sampler.set_epoch(mem_epoch_count)
                    mem_iter = iter(mem_loader)
                    mem_epoch_count += 1

            cur_iter_count += 1 if cur else 0
            mem_iter_count += 1 if mem else 0
            unsup_iter_count += 1 if unsup else 0

            # get mini-batch and log data time
            if unsup:
                unlabeled_samples, _ = next(unlabeled_iter)
            if cur:
                labeled_samples, labeled_targets = next(labeled_iter)
            if mem:
                mem_samples, mem_targets = next(mem_iter)

            cur_data_time = time.time() - end
            cur_data_time_reduce = misc.all_reduce_mean(cur_data_time)
            data_time.update(cur_data_time_reduce)

            loss = torch.zeros(1, dtype=torch.float32).cuda(args.gpu, non_blocking=True)
            if unsup:
                unsup_loss = self.compute_unlabeled_loss(model, unlabeled_samples, args)
                loss += unsup_loss
                loss_unlabeled_value_accum += unsup_loss.item() / accum_iter
                del unlabeled_samples
                torch.cuda.empty_cache()
                if args.debug:
                    print(f'unsup loss {unsup_loss.item()}')

            if cur:
                cur_loss = self.compute_cur_labeled_loss(model, labeled_samples, labeled_targets, args)
                loss += cur_loss
                loss_labeled_value_accum += cur_loss.item() / accum_iter
                del labeled_samples, labeled_targets
                torch.cuda.empty_cache()
                if args.debug:
                    print(f'cur loss {cur_loss.item()}')
            if mem:
                mem_loss = self.compute_mem_loss(model, mem_samples, mem_targets, args)
                loss += mem_loss
                # memory loss is logged together with the current labeled loss
                loss_labeled_value_accum += mem_loss.item() / accum_iter
                # release the batch like the other branches do
                del mem_samples, mem_targets
                torch.cuda.empty_cache()
                if args.debug:
                    print(f'distill loss {mem_loss.item()}, scaling factor {args.mem_coef_s1}')

            loss_value = loss.item()
            # exit when infinite loss value
            if not math.isfinite(loss_value):
                print("Loss is {}, stopping training".format(loss_value))
                sys.exit(1)

            # backward; the scaler only steps the optimizer on accumulation boundaries
            loss /= accum_iter
            loss_scaler(loss, optimizer, parameters=model.parameters(),
                        update_grad=(iter_count + 1) % accum_iter == 0)
            if (iter_count + 1) % accum_iter == 0:
                optimizer.zero_grad()
                misc.logging('step', iter_count // accum_iter, f'task{task}/{stage}loss_labeled',
                             loss_labeled_value_accum,
                             args)
                misc.logging('step', iter_count // accum_iter, f'task{task}/{stage}loss_unlabeled',
                             loss_unlabeled_value_accum,
                             args)
                loss_labeled_value_accum, loss_unlabeled_value_accum = 0, 0

            # log
            loss_value_reduce = misc.all_reduce_mean(loss_value)
            cur_batch_time = time.time() - end
            cur_batch_time_reduce = misc.all_reduce_mean(cur_batch_time)
            batch_time.update(cur_batch_time_reduce)

            misc.logging('step', iter_count // accum_iter, f'task{task}/{stage}train_loss', loss_value_reduce,
                         args)

            if iter_count % args.print_freq == 0:
                print_string = f'Iter [{iter_count} / {train_iter}] \t'
                if unsup:
                    print_string += f'Unlabeled Epoch: [{unlabeled_epoch_count - 1}][{unsup_iter_count}/{unlabeled_iter_per_epoch}] \t'
                if mem:
                    print_string += f'Buffer Epoch: [{mem_epoch_count - 1}][{mem_iter_count}/{mem_iter_per_epoch}]\t'
                if cur:
                    print_string += f'Labeled Epoch: [{labeled_epoch_count - 1}][{cur_iter_count}/{labeled_iter_per_epoch}]\t'
                print(print_string,
                      f'Time {batch_time.val: .3f} ({batch_time.avg: .3f})\t'
                      f'Data {data_time.val: .3f} ({data_time.avg: .3f})\t'
                      f'Loss {loss_value_reduce:.4f} \t'
                      )
            if (iter_count + 1) % args.eval_freq == 0:
                short_validate(model, CrossEntropyLoss().cuda(args.gpu), task, args, iter_count // accum_iter)

    def adapt_optm(self, model, args, lr):
        """Create an AdamW optimizer (betas 0.9/0.95) at learning rate ``lr``.

        Parameter groups come from ``lrd.param_groups_lrd`` with
        ``args.weight_decay`` and ``args.layer_decay``. ``model`` must expose
        ``no_weight_decay()``; callers pass the unwrapped ``model.module``.
        """
        # `lrd` was referenced but never imported in this file (NameError).
        # NOTE(review): module path assumed from the file's `utils.*` layout
        # (MAE-style lr_decay helper) — confirm it lives at utils/lr_decay.py.
        import utils.lr_decay as lrd

        param_groups = lrd.param_groups_lrd(model, args, args.weight_decay,
                                            no_weight_decay_list=model.no_weight_decay(),
                                            layer_decay=args.layer_decay
                                            )
        optimizer = torch.optim.AdamW(param_groups, lr=lr, betas=(0.9, 0.95))
        return optimizer

    def train(self, model, task, args):
        """Train ``model`` on ``task``: stage 1, then the optional extra stage.

        Stage 1 runs ``args.steps`` steps, capped at ``args.min_budget`` when
        ``task > 0``, accumulating gradients over ``args.accum`` batches. When
        ``task > 0`` and ``args.steps > args.min_budget``, the remaining steps
        run in :meth:`train_extra_budget` with a new optimizer at
        ``0.1 * args.lr``. ``model`` is expected to expose ``.module``
        (presumably a DDP wrapper). Sets ``args.iters`` as a side effect.
        """
        start_time = time.time()
        self.init_mixup(args.nb_classes)
        model.train(True)
        optimizer = self.adapt_optm(model.module, args, args.lr)
        loss_scaler = NativeScaler()

        # iter
        steps = min(args.steps, args.min_budget) if task > 0 else args.steps
        train_iter = int(steps * 1024 / args.batch_size_total)
        accum_iter = args.accum
        # optimizer
        optimizer.zero_grad()
        # data
        args.iters = train_iter
        self.get_batch_size_for_two_stages(args, task, stage=1)
        labeled_loader, mem_loader = self.get_labeled_laoder(args, task)
        unlabeled_loader = self.get_unlabeled_loader(args, task) if args.unsup_loss else None

        print(
            f'=> Pretrain stage: Train {steps} steps, {train_iter} batches, gradient steps every {accum_iter} batches, '
            f'unlabeled batch size {self.unlabeled_bs}, '
            f'previous labeled batch size {self.mem_labeled_bs}, '
            f'current labeled batch size {self.cur_labeled_bs}')
        self.train_iters(args,
                         task, train_iter, accum_iter,
                         optimizer, model, loss_scaler,
                         unlabeled_loader=unlabeled_loader, mem_loader=mem_loader, labeled_loader=labeled_loader)

        # Re-wrap the underlying module in a fresh DDP with find_unused_parameters=True;
        # presumably because the memory-only stage leaves some parameters unused — confirm.
        model = torch.nn.parallel.DistributedDataParallel(model.module, device_ids=[args.gpu],
                                                          find_unused_parameters=True)
        # extra stage
        if args.steps > args.min_budget and task > 0:
            optimizer = self.adapt_optm(model.module, args, args.lr * 0.1)
            self.train_extra_budget(args.steps - args.min_budget, args, model, task, loss_scaler, optimizer)

        stage_time = time.time() - start_time
        stage_time_reduce = misc.all_reduce_mean(stage_time)
        task_time = datetime.timedelta(seconds=int(stage_time_reduce))
        print(f'Training time: {task_time}')
        misc.logging("task", task, "time", stage_time_reduce / 3600, args)


def short_validate(model, criterion, mtask, args, step):
    """Evaluate ``model`` on the validation sets of tasks ``0..mtask`` and log results.

    Does nothing when ``args.no_evaluate`` is set. Per-batch top-1 accuracy and
    loss are averaged across processes with ``misc.all_reduce_mean`` and then
    logged at ``step`` under ``{mtask}/val_acc`` and ``{mtask}/val_loss``.

    The model is put in eval mode for the evaluation. Its previous train/eval
    mode is restored afterwards, even on error, so the function can be called
    in the middle of training.
    """
    if args.no_evaluate:
        return

    print(f"validate model {mtask} ")

    val_sets = ConcatDataset([data.get_val_set(args, task) for task in range(mtask + 1)])
    val_sampler = torch.utils.data.distributed.DistributedSampler(val_sets) if args.dist_eval else None
    val_loader = torch.utils.data.DataLoader(val_sets, batch_size=args.batch_size, shuffle=False,
                                             num_workers=args.workers, pin_memory=False, sampler=val_sampler)

    was_training = model.training
    # switch to evaluate mode (disables dropout/drop-path and stops any
    # normalization layers from updating running statistics)
    model.eval()
    try:
        losses = AverageMeter()
        acc = AverageMeter()
        with torch.no_grad():
            for images, target in val_loader:
                images = images.cuda(args.gpu, non_blocking=True)
                target = target.cuda(args.gpu, non_blocking=True)

                # compute output
                with torch.cuda.amp.autocast():
                    output = model(images)
                    loss = criterion(output, target)

                # measure accuracy and record loss; only top-1 is reported, so
                # top-5 (which fails for models with < 5 outputs) is not computed
                prec1, = accuracy(output, target, topk=(1,))
                prec1_reduce = misc.all_reduce_mean(prec1[0])

                reduce_loss = misc.all_reduce_mean(loss.item())
                losses.update(reduce_loss, images.size(0))
                acc.update(prec1_reduce, images.size(0))

        print(
            f' * Model {mtask}, Prec@1 {acc.avg:.3f} ')

        misc.logging('step', step, f"{mtask}/val_acc", acc.avg, args)
        misc.logging('step', step, f"{mtask}/val_loss", losses.avg, args)
    finally:
        model.train(was_training)



def accuracy(output, target, topk=(1,)):
    """Return top-k accuracy, in percent, of ``output`` against ``target``.

    Args:
        output: score tensor with one row per sample; top-k is taken over dim 1.
        target: class-index tensor with one entry per sample.
        topk: iterable of k values to evaluate.

    Returns:
        A list with one 1-element tensor per k in ``topk``, holding the
        percentage of samples whose target is among the k highest scores.
    """
    n_samples = target.size(0)
    _, top_idx = output.topk(max(topk), dim=1, largest=True, sorted=True)
    # Transpose to (maxk, batch): row j holds each sample's j-th best class.
    ranked = top_idx.t()
    hits = ranked.eq(target.reshape(1, -1).expand_as(ranked))
    scale = 100.0 / n_samples
    return [hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(scale) for k in topk]
