import builtins
import datetime
import os
import sys
import time
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.parallel
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import wandb

from torch.utils.data import ConcatDataset

import hydra
import numpy as np
import omegaconf
import data
import utils.misc as misc
from args import setup_args
from balanced import DietCL
from trainer import Trainer
from twostagetrainer import TwoStageTrainer
from utils.metric import AverageMeter, ContinaulMetric

@hydra.main(version_base=None)
def main(args):
    """Hydra entry point: build the run namespace and launch the worker(s).

    Converts the Hydra config into an ``argparse.Namespace``, derives the run
    name and output directory, seeds the RNGs, and then either spawns one
    ``main_worker`` process per GPU (``args.distributed``) or calls
    ``main_worker`` directly in this process.

    Args:
        args: Config object supplied by the ``@hydra.main`` decorator.
    """
    # Imported locally: the module header never imports Namespace, which made
    # the conversion below fail with NameError.
    from argparse import Namespace

    # resolve=True turns ${...} interpolations into concrete values instead of
    # passing them through as literal strings.
    args = omegaconf.OmegaConf.to_container(args, resolve=True)
    args = Namespace(**args)

    # Output directory: <output_dir>/<run_name>, run_name defaulting to
    # "<method>-<dataset>-<SLURM_JOB_ID>".
    if args.run_name is None:
        args.run_name = '-'.join([
            args.method,
            args.dataset,
            os.environ.get("SLURM_JOB_ID", ""),
        ])
    args.output_dir = os.path.join(args.output_dir, args.run_name)
    os.makedirs(args.output_dir, exist_ok=True)
    print("{}".format(args).replace(', ', ',\n'))

    # Wall-clock start of the run (kept on args for later logging).
    args.start_time = time.time()
    misc.fix_random_seed(args.seed)

    if args.distributed:
        # One process per GPU on this node, so the global world size is the
        # node count times ngpus_per_node.
        args.world_size = args.ngpus_per_node * args.world_size
        # mp.spawn passes the process index as main_worker's first (gpu) arg.
        mp.spawn(main_worker, nprocs=args.ngpus_per_node,
                 args=(args.ngpus_per_node, args))
    else:
        main_worker(args.gpu, args.ngpus_per_node, args)


def main_worker(gpu, ngpus_per_node, args):
    """Run continual training and evaluation for one process / GPU.

    Sets up the process group (when ``args.distributed``), wandb logging, the
    module-level ``val_top1``/``val_top5`` metric trackers, and the DietCL
    trainer with its DDP-wrapped model. It then iterates over ``args.split``
    tasks:

    - Tasks up to the last checkpoint found on resume are only evaluated from
      their checkpoint. With ``args.evaluate`` set, every task is.
    - All other tasks are trained, evaluated and checkpointed.

    Args:
        gpu: Local GPU index for this process (the spawn index under
            ``mp.spawn``).
        ngpus_per_node: Number of processes launched on this node.
        args: Run configuration namespace. It is mutated in place (``gpu``,
            ``rank``, ``seen_classes``, ``new_classes``,
            ``cur_task_separate_men_set``).
    """
    args.gpu = gpu

    # Silence printing on non-master processes, then join the process group.
    if args.distributed:
        setup_for_distributed(args.gpu == 0)
        # Convert the node-level rank into a global rank.
        args.rank = args.rank * ngpus_per_node + gpu
        print_now('=> Init process group')
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                world_size=args.world_size, rank=args.rank)

    # Only local GPU 0 initialises the wandb writer.
    if args.wandb_log and args.gpu == 0:
        misc.init_wandb_writer(args)

    # Accuracy trackers shared with validate()/short_validate() via globals.
    global val_top1, val_top5
    val_top1 = ContinaulMetric(args)
    val_top5 = ContinaulMetric(args)

    trainer = DietCL(args)

    model_without_ddp = trainer.init_model(args)
    model_without_ddp.to(args.gpu)
    find_unused_parameters = False
    model = torch.nn.parallel.DistributedDataParallel(
        model_without_ddp, device_ids=[args.gpu],
        find_unused_parameters=find_unused_parameters)

    criterion = nn.CrossEntropyLoss().cuda(args.gpu)

    # On resume, tasks 0..last_task are evaluated from their checkpoints
    # instead of being retrained; -1 means nothing to resume.
    last_task = -1
    if args.resume and (not args.evaluate):
        last_task = misc.find_latest_checkpoint(args.output_dir)
        print(f'Find last task {last_task}')

    for task in range(args.split):
        print_now(f'========= Task {task} =========')
        torch.cuda.empty_cache()

        if args.dataset == 'cglm':
            # All cglm classes are introduced in the first task.
            args.seen_classes = 10788
            args.new_classes = 10788 if task == 0 else 0
        elif args.dataset == 'ImageNet2k':
            # 1000 base classes plus an equal share of 1000 more per task.
            # Integer arithmetic (multiply before dividing) keeps class counts
            # integral and makes the last task reach exactly 2000 even when
            # split does not divide 1000.
            prev_seen = 1000 + (1000 * task) // args.split
            args.seen_classes = 1000 + (1000 * (task + 1)) // args.split
            args.new_classes = args.seen_classes - prev_seen
        else:
            # Class counts are the number of entries in each task's val/ dir.
            args.seen_classes = sum(
                len(os.listdir(f'{args.data}/{i}/val/')) for i in range(task + 1))
            args.new_classes = len(os.listdir(f'{args.data}/{task}/val/'))
        print_now(
            f'=> {args.seen_classes} classes at task {task}, {args.new_classes} new classes')

        if args.evaluate or (task <= last_task):
            print_now(f'* Evaluate task {task}')
            _ = misc.resume_ckpt(args, model.module, last_task=task)
            if args.light_eval:
                short_validate(model, criterion, task, args)
            else:
                validate(model, criterion, task, args)
            continue

        args.cur_task_separate_men_set = (
            task > 0 and args.sampling == 'batchmix')

        # ========= train stage =========
        trainer.train(model, task, args)
        torch.distributed.barrier()
        # Synchronise this rank's own device. synchronize(0) waited on GPU 0
        # from every rank instead.
        torch.cuda.synchronize(args.gpu)

        if args.light_eval:
            short_validate(model, criterion, task, args)
        else:
            validate(model, criterion, task, args)

        model_without_ddp = model.module
        misc.save_checkpoint(
            {'model': model_without_ddp.state_dict()}, task, args)

    if args.wandb_log and args.gpu == 0:
        wandb.finish()


def validate(model, criterion, mtask, args):
    """Evaluate model ``mtask`` on the validation data of every task seen so far.

    For each data task ``dtask`` in ``0..mtask``, the per-batch top-1/top-5
    accuracy and loss are passed through ``misc.all_reduce_mean``. They are
    accumulated into the module-level ``val_top1``/``val_top5`` trackers at
    entry ``(mtask, dtask)``. Summary metrics are printed and logged
    afterwards.

    The model is switched to eval mode for the duration of the call. Its
    previous train/eval mode is restored on exit, even if evaluation raises.

    Args:
        model: Model to evaluate, called as ``model(inputs)``.
        criterion: Loss function applied to ``(output, target)``.
        mtask: Index of the task the model has just been trained on.
        args: Run configuration. Reads ``no_evaluate``, ``gpu`` and
            ``val_print_freq`` here.
    """
    if args.no_evaluate:
        return
    start = time.time()

    batch_time = AverageMeter()
    # Without eval(), BatchNorm/Dropout would run in training mode during
    # validation (batch statistics, running-stat updates, random dropout).
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for dtask in range(mtask + 1):
                val_loader = data.get_val_loader(args, dtask)
                losses = AverageMeter()
                end = time.time()
                for i, (inputs, target) in enumerate(val_loader):
                    inputs = inputs.cuda(args.gpu, non_blocking=True)
                    target = target.cuda(args.gpu, non_blocking=True)

                    with torch.cuda.amp.autocast():
                        output = model(inputs)
                        loss = criterion(output, target)

                    # misc.all_reduce_mean presumably averages across ranks
                    # before the values are accumulated.
                    prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
                    prec1_reduce = misc.all_reduce_mean(prec1[0])
                    prec5_reduce = misc.all_reduce_mean(prec5[0])
                    reduce_loss = misc.all_reduce_mean(loss.item())
                    losses.update(reduce_loss, inputs.size(0))
                    val_top1.update(mtask, dtask, prec1_reduce, inputs.size(0))
                    val_top5.update(mtask, dtask, prec5_reduce, inputs.size(0))

                    batch_time.update(time.time() - end)
                    end = time.time()

                    if (i + 1) % args.val_print_freq == 0:
                        print(
                            f"Test: [{i}/{len(val_loader)}]\t, ",
                            f"Time {datetime.timedelta(seconds=int(batch_time.val))} ({datetime.timedelta(seconds=int(batch_time.avg))})\t"
                            f" Loss {losses.val:.4f} ({losses.avg:.4f})\t",
                            f"Prec@1 {prec1[0]:.3f}\t",
                            f"Prec@5 {prec5[0]:.3f}")

                val_top1.update_metric(mtask, dtask)
                val_top5.update_metric(mtask, dtask)
                print(
                    f' * Model {mtask}, Data {dtask}, Prec@1 {val_top1.matrix[mtask, dtask]:.3f} Prec@5 {val_top5.matrix[mtask, dtask]:.3f}')

            print(
                f' * End evaluation: task accuracy top1 {val_top1.task_avg:.2f}, top5 {val_top5.task_avg:.2f} ')
            print(
                f' * End evaluation: averaged task accuracy top1 {val_top1.sample_avg:.2f}, top5 {val_top5.sample_avg:.2f} ')
            print(f' * Final accuracy matrix ')
            val_top1.print_matrix('Top1 accuracy')
            val_top5.print_matrix('Top5 accuracy')

            print(
                f'Total evaluation time {datetime.timedelta(seconds=int(time.time() - start))}')

            misc.logging('task', mtask, "task average acc",
                         val_top1.task_avg, args)
            misc.logging('task', mtask, "average learning acc", val_top1.tla, args)
            misc.logging('task', mtask, "backward transfer", val_top1.bwt, args)
            misc.logging('task', mtask, "averaged task average acc",
                         val_top1.sample_avg, args)
    finally:
        model.train(was_training)


def short_validate(model, criterion, mtask, args):
    """Evaluate model ``mtask`` with a single pass over one validation loader.

    The loader depends on the dataset:

    - ``cglm``/``cloc``: the loader from ``data.get_val_loader(args, mtask)``.
    - Others: the validation sets of tasks ``0..mtask`` concatenated into one
      DataLoader, using a ``DistributedSampler`` when ``args.dist_eval`` is
      set.

    Results are recorded only in the diagonal entry ``(mtask, mtask)`` of the
    module-level ``val_top1``/``val_top5`` trackers.

    The model is switched to eval mode for the duration of the call. Its
    previous train/eval mode is restored on exit, even if evaluation raises.

    Args:
        model: Model to evaluate, called as ``model(inputs)``.
        criterion: Loss function applied to ``(output, target)``.
        mtask: Index of the task the model has just been trained on.
        args: Run configuration.
    """
    if args.no_evaluate:
        return

    print(f"Directly evaluate model {mtask} over test set from 1 to t")
    start = time.time()

    batch_time = AverageMeter()

    if args.dataset in ['cglm', 'cloc']:
        val_loader = data.get_val_loader(args, mtask)
    else:
        val_sets = ConcatDataset([data.get_val_set(args, task)
                                  for task in range(mtask + 1)])
        if args.dist_eval:
            val_sampler = torch.utils.data.distributed.DistributedSampler(
                val_sets)
        else:
            val_sampler = None
        val_loader = torch.utils.data.DataLoader(val_sets, batch_size=args.batch_size, shuffle=False,
                                                 num_workers=args.workers, pin_memory=False, sampler=val_sampler)

    # Without eval(), BatchNorm/Dropout would run in training mode during
    # validation (batch statistics, running-stat updates, random dropout).
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            losses = AverageMeter()
            end = time.time()
            for i, (inputs, target) in enumerate(val_loader):
                inputs = inputs.cuda(args.gpu, non_blocking=True)
                target = target.cuda(args.gpu, non_blocking=True)

                with torch.cuda.amp.autocast():
                    output = model(inputs)
                    loss = criterion(output, target)

                # misc.all_reduce_mean presumably averages across ranks before
                # the values are accumulated.
                prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
                prec1_reduce = misc.all_reduce_mean(prec1[0])
                prec5_reduce = misc.all_reduce_mean(prec5[0])
                reduce_loss = misc.all_reduce_mean(loss.item())
                losses.update(reduce_loss, inputs.size(0))
                val_top1.update(mtask, mtask, prec1_reduce, inputs.size(0))
                val_top5.update(mtask, mtask, prec5_reduce, inputs.size(0))

                batch_time.update(time.time() - end)
                end = time.time()

                if (i + 1) % args.val_print_freq == 0:
                    print(
                        f"Test: [{i}/{len(val_loader)}]\t, ",
                        f"Time {datetime.timedelta(seconds=int(batch_time.val))} ({datetime.timedelta(seconds=int(batch_time.avg))})\t"
                        f" Loss {losses.val:.4f} ({losses.avg:.4f})\t",
                        f"Prec@1 {prec1[0]:.3f}\t",
                        f"Prec@5 {prec5[0]:.3f}")

            val_top1.update_metric(mtask, mtask)
            val_top5.update_metric(mtask, mtask)
            print(
                f' * Model {mtask}, Prec@1 {val_top1.matrix[mtask, mtask]:.3f} Prec@5 {val_top5.matrix[mtask, mtask]:.3f}')
            print(
                f' * Model {mtask}, AVG Prec@1 {val_top1.tla:.3f} Prec@5 {val_top5.tla:.3f}')

            misc.logging('task', mtask, "task average acc",
                         val_top1.matrix[mtask, mtask], args)
            misc.logging('task', mtask, "averaged task average acc",
                         val_top1.tla, args)
            print(
                f'Total evaluation time {datetime.timedelta(seconds=int(time.time() - start))}')
    finally:
        model.train(was_training)


def accuracy(output, target, topk=(1,)):
    """Compute precision@k, in percent, for each k in ``topk``.

    Args:
        output: Score tensor indexed ``[sample, class]``; the top-k is taken
            along dim 1.
        target: Ground-truth class index per sample, one entry per row of
            ``output``.
        topk: Values of k to report.

    Returns:
        A list with one 1-element float tensor per k. Each holds the
        percentage of samples whose target is among the k highest-scoring
        classes. If k is at least the number of classes, every class counts
        as "top k", so the result is 100 (assuming targets are valid class
        indices).
    """
    # Clamp so topk() does not raise when the model has fewer classes than
    # the largest requested k (e.g. top-5 on a 3-class head).
    maxk = min(max(topk), output.size(1))
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    # After the transpose, row j holds every sample's j-th best class.
    pred = pred.t()
    correct = pred.eq(target.reshape(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res


def setup_for_distributed(is_master):
    """Replace ``builtins.print`` so that, by default, only the master prints.

    A non-master process prints only when the call passes ``force=True`` or
    the distributed world size exceeds 8. Every printed message is prefixed
    with the current wall-clock time.

    Args:
        is_master: Whether this process prints unconditionally.
    """
    builtin_print = builtins.print

    def print(*args, **kwargs):
        force = kwargs.pop('force', False)
        force = force or (get_world_size() > 8)
        if is_master or force:
            now = datetime.datetime.now().time()
            # Write the timestamp to the same stream as the message itself.
            # file=None means sys.stdout, which is print's own default.
            builtin_print('[{}] '.format(now), end='', file=kwargs.get('file'))
            builtin_print(*args, **kwargs)

    builtins.print = print


def print_now(content):
    """Print ``content`` and flush stdout right away.

    ``print`` is resolved at call time, so once ``builtins.print`` has been
    replaced by ``setup_for_distributed`` this goes through that version.
    """
    print(content)
    stream = sys.stdout
    stream.flush()


def get_world_size():
    """Return the distributed world size, or 1 when no process group is active."""
    if is_dist_avail_and_initialized():
        return dist.get_world_size()
    return 1


def is_dist_avail_and_initialized():
    """Return True only when torch.distributed is available and initialised."""
    # `and` short-circuits, so is_initialized() is only queried when the
    # distributed package is available.
    return bool(dist.is_available() and dist.is_initialized())


# Script entry point. main() is called with no arguments; its @hydra.main
# decorator is expected to supply the config object.
if __name__ == '__main__':
    main()
