import os, pdb
import time
import argparse
import datetime
import numpy as np
import wandb
from collections import defaultdict
import torch
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.nn.functional as F

from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
from timm.utils import accuracy, AverageMeter

from config import get_config
from models import build_model, build_CoMM_model
from data import build_loader
from lr_scheduler import build_scheduler
from optimizer import build_optimizer
from logger import create_logger
from util import load_checkpoint, load_pretrained, save_checkpoint, get_grad_norm, auto_resume_helper, reduce_tensor, load_comm_pretrained

try:
    # noinspection PyUnresolvedReferences
    from apex import amp
except ImportError:
    amp = None

import sys
sys.path.insert(0, '../src/models')


# if backbone == 'swin':
sw_window_size = (7, 7)
sw_coords_h = torch.arange(-sw_window_size[0] + 1, sw_window_size[0])
sw_coords_w = torch.arange(-sw_window_size[0] + 1, sw_window_size[1])
sw_coords_h, sw_coords_w = torch.meshgrid([sw_coords_h, sw_coords_w])

sw_coords_dist = (sw_coords_h ** 2 + sw_coords_w ** 2) ** 0.5
sw_coords = F.unfold(sw_coords_dist.reshape((1, 1, 13, 13)), kernel_size=7).flip(dims=(-1,))[0]        

# elif backbone == 'vit':
window_size = (14, 14)
coords_h = torch.arange(-window_size[0] + 1, window_size[0])
coords_w = torch.arange(-window_size[0] + 1, window_size[1])
coords_h, coords_w = torch.meshgrid([coords_h, coords_w])
coords_dist = (coords_h ** 2 + coords_w ** 2) ** 0.5
coords = F.unfold(coords_dist.reshape((1, 1, 27, 27)), kernel_size=14).flip(dims=(-1,))[0]
# else:
#     NotImplementedError

def parse_option():
    """Parse command-line options and build the experiment config.

    Returns:
        ``(args, config)``: the parsed ``argparse.Namespace`` and the config
        object produced by ``get_config(args)``.
    """
    parser = argparse.ArgumentParser('Swin Transformer training and evaluation script', add_help=False)
    parser.add_argument('--cfg', type=str, required=True, metavar="FILE", help='path to config file', )
    parser.add_argument(
        "--opts",
        help="Modify config options by adding 'KEY VALUE' pairs. ",
        default=None,
        nargs='+',
    )

    # easy config modification
    parser.add_argument('--task-id', type=int, help="task_id")
    parser.add_argument('--batch-size', type=int, help="batch size for single GPU")
    parser.add_argument('--weight-decay', type=float, help="")
    parser.add_argument('--data-path', type=str, help='path to dataset')
    parser.add_argument('--pretrained', type=str, help='path to pre-trained model')
    parser.add_argument('--resume', help='resume from checkpoint')
    parser.add_argument('--accumulation-steps', type=int, help="gradient accumulation steps")
    parser.add_argument('--use-checkpoint', action='store_true',
                        help="whether to use gradient checkpointing to save memory")
    parser.add_argument('--amp-opt-level', type=str, default='O0', choices=['O0', 'O1', 'O2'],
                        help='mixed precision opt level, if O0, no amp is used')
    parser.add_argument('--output', default='output', type=str, metavar='PATH',
                        help='root of output folder, the full path is <output>/<model_name>/<tag> (default: output)')
    parser.add_argument('--tag', help='tag of experiment')
    parser.add_argument('--eval', action='store_true', help='Perform evaluation only')
    parser.add_argument('--throughput', action='store_true', help='Test throughput only')
    parser.add_argument('--linear-probe', action='store_true', help='linear probe for validation')
    parser.add_argument("--lr-multiplier", type=float, default=1, help='multiplier for base/warmup/min_lr')

    # distributed training
    # PyTorch >= 2.0 launchers pass the dashed ``--local-rank``, and torchrun only
    # exports LOCAL_RANK in the environment. Accept both spellings and fall back
    # to the env var; the option stays mandatory when neither source provides it.
    # The dest is still ``local_rank``, taken from the first option string.
    env_local_rank = os.environ.get('LOCAL_RANK')
    parser.add_argument("--local_rank", "--local-rank", type=int,
                        required=env_local_rank is None,
                        default=None if env_local_rank is None else int(env_local_rank),
                        help='local rank for DistributedDataParallel')
    parser.add_argument('--comm', action='store_true')
    parser.add_argument('--debug', action='store_true')
    parser.add_argument('--ours', action='store_true')
    parser.add_argument('--ours-type', type=str, default='dist', choices=['dist', 'entropy'])
    parser.add_argument('--comm-hyp', type=float, default=1.0)
    parser.add_argument('--adap-init-scale', type=float, default=0.5)
    parser.add_argument('--save-ckpt', action='store_true')

    args = parser.parse_args()
    config = get_config(args)
    return args, config


# Fragments removed, in this order, from TAG component 2 when building the wandb
# run name. Order matters: e.g. 'SeqImageNet_UnsupNaive_' must be stripped
# before the shorter 'SeqImageNet_'.
_PRETRAIN_TAG_NOISE = (
    'SeqImageNet_UnsupNaive_',
    '0510_',
    '0524_run_',
    '0525_run_',
    'mlr5e-06_',
    'mlr1e-05_',
    'trd0p1_',
    'run_0',
    'SeqImageNet_',
    'lr0p0002_',
)


def _wandb_run_name(tag):
    """Build the compact wandb run name from a '/'-separated experiment tag.

    Component 2 of ``tag`` becomes the ``'p-'`` part, with the fragments in
    ``_PRETRAIN_TAG_NOISE`` removed. Components 0, 3 and 4 become the ``'v-'``
    part. ``tag`` must therefore have at least five components (IndexError
    otherwise).
    """
    parts = tag.split('/')
    prt_info = 'p-' + parts[2]
    for fragment in _PRETRAIN_TAG_NOISE:
        prt_info = prt_info.replace(fragment, '')
    val_info = 'v-' + parts[0] + '_' + parts[3] + '_' + parts[4]
    val_info = val_info.replace('data_and_', 'd').replace('model', 'm')
    return prt_info + val_info


def _prepare_model_for_task(config, model, optimizer):
    """Re-initialise adapter weights and wrap the network in a fresh DDP for a task.

    ``model`` is either the bare network (first task) or the DDP wrapper built
    for the previous task; in both cases the underlying network is unwrapped
    first. When ``config.AMP_OPT_LEVEL != "O0"`` apex AMP is applied before
    wrapping.

    Returns:
        ``(ddp_model, optimizer)``; the optimizer may have been patched by
        ``amp.initialize``.
    """
    net = model.module if hasattr(model, 'module') else model
    net.init_adapts_weights(logger)
    if config.AMP_OPT_LEVEL != "O0":
        # amp.initialize hands back the bare network, so wrap that object
        # directly rather than looking for a `.module` attribute on it.
        net, optimizer = amp.initialize(net.to('cpu').to('cuda'), optimizer, opt_level=config.AMP_OPT_LEVEL)
    model = torch.nn.parallel.DistributedDataParallel(
        net, device_ids=[config.LOCAL_RANK], find_unused_parameters=True)
    return model, optimizer


def _build_criterion(config):
    """Return the training loss: soft-target CE with mixup, label-smoothing CE, or plain CE."""
    if config.AUG.MIXUP > 0.:
        # smoothing is handled with mixup label transform
        return SoftTargetCrossEntropy()
    if config.MODEL.LABEL_SMOOTHING > 0.:
        return LabelSmoothingCrossEntropy(smoothing=config.MODEL.LABEL_SMOOTHING)
    return torch.nn.CrossEntropyLoss()


def main(config):
    """Continually fine-tune one model over tasks ``START_TASK .. END_TASK-1``.

    For each task, the data loaders, optimizer, LR scheduler and DDP wrapper are
    rebuilt. The model is then trained for epochs
    ``TRAIN.START_EPOCH .. TRAIN.EPOCHS-1``, stopping early when
    ``epoch + 1 == TRAIN.BREAK_EPOCH``. It is validated after every epoch.
    Rank 0 checkpoints every ``SAVE_FREQ`` epochs and on the last epoch, and
    logs metrics to wandb.
    """
    logger.info(f"Creating model:{config.MODEL.TYPE}/{config.MODEL.NAME}")
    config.defrost()
    config.MODEL.NUM_CLASSES = 100 if config.DATA.TASK_ID != -1 else 1000
    config.freeze()

    if config.COMM:
        model = build_CoMM_model(config, is_pretrain=False)
    else:
        model = build_model(config, is_pretrain=False)
    model.cuda()

    max_accuracy = 0.0

    if config.PRETRAINED:
        load_pretrained(config, model, logger)

    if dist.get_rank() == 0:
        logger_name = _wandb_run_name(config.TAG)
        wandb.init(project="LAPS-2022-09-10-CoMM", dir='../../wandb')
        wandb.run.name = logger_name
        logger.info(f'wandb_run_name: {logger_name}')

    logger.info("Start training")
    logger.info(f"COMM - ours_type {config.TRAIN.COMM.OURS_TYPE} hyp {config.TRAIN.COMM.HYP}")
    start_time = time.time()

    for tid in range(config.DATA.CONTFT.START_TASK, config.DATA.CONTFT.END_TASK):
        # Report the task actually being fine-tuned, not the first task of the range.
        logger.info(f'Continuously finetuning Task {tid}/{config.DATA.CONTFT.END_TASK}')
        config.defrost()
        config.DATA.TASK_ID = tid
        config.freeze()
        _dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn = build_loader(config, logger, is_pretrain=False)

        optimizer = build_optimizer(config, model, logger, is_pretrain=False)
        lr_scheduler = build_scheduler(config, optimizer, len(data_loader_train))

        model, optimizer = _prepare_model_for_task(config, model, optimizer)
        model_without_ddp = model.module

        n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
        logger.info(f"number of params: {n_parameters}")
        if hasattr(model_without_ddp, 'flops'):
            flops = model_without_ddp.flops()
            logger.info(f"number of GFLOPs: {flops / 1e9}")

        criterion = _build_criterion(config)

        for epoch in range(config.TRAIN.START_EPOCH, config.TRAIN.EPOCHS):
            data_loader_train.sampler.set_epoch(epoch)

            train_one_epoch(config, model, criterion, data_loader_train, optimizer, epoch, mixup_fn, lr_scheduler)
            if dist.get_rank() == 0 and (epoch % config.SAVE_FREQ == 0 or epoch == (config.TRAIN.EPOCHS - 1)):
                save_checkpoint(config, epoch, model_without_ddp, max_accuracy, optimizer, lr_scheduler, logger, tag=f'ContFT_task{tid}')

            acc1, acc5, loss = validate(config, data_loader_val, model)
            logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%")
            max_accuracy = max(max_accuracy, acc1)
            logger.info(f'Max accuracy: {max_accuracy:.2f}%')

            if dist.get_rank() == 0:
                wandb.log({'epoch': epoch+1, 't1_acc': acc1})
                wandb.log({'epoch': epoch+1, 't5_acc': acc5})
                wandb.log({'epoch': epoch+1, 'loss': loss})
                wandb.log({'epoch': epoch+1, 'lr': optimizer.param_groups[-1]['lr']})

            if config.TRAIN.BREAK_EPOCH == (epoch+1):
                break

        # Cumulative wall-clock time since the first task started.
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        logger.info('Training time {}'.format(total_time_str))


def entropy(distribution):
    """Return the Shannon entropy, in bits, summed over every element of ``distribution``.

    Exact-zero entries contribute 0 (the usual ``0 * log 0 = 0`` convention).
    A plain ``p * log2(p)`` would produce NaN for them, in both the value and
    the gradient.
    """
    # Feed 1 into the log wherever p == 0: log2(1) = 0, so those terms vanish and
    # no -inf reaches autograd.
    safe = torch.where(distribution == 0, torch.ones_like(distribution), distribution)
    return torch.sum(-distribution * torch.log2(safe))

def entropy2(distribution):
    """Return the Shannon entropy, in bits, of each distribution along the last dimension.

    The output has the input's shape with the last dimension removed. Exact-zero
    entries contribute 0 (``0 * log 0 = 0``) instead of producing NaN.
    """
    # Feed 1 into the log wherever p == 0 so the term is exactly zero and the
    # gradient stays finite.
    safe = torch.where(distribution == 0, torch.ones_like(distribution), distribution)
    return torch.sum(-distribution * torch.log2(safe), -1)

def attn_dist_loss(blocks_attn, batch_size, backbone='vit'):
    """Attention-distance diversity regulariser.

    For each layer, the diagonal (self-attention) is zeroed, rows are
    re-normalised, and the expected query-key distance under ``target_coords``
    is averaged over query positions per head. The loss adds
    ``-log(std over heads + 1e-8)`` per layer, so minimising it spreads the
    heads' attention distances apart.

    Args:
        blocks_attn: indexable sequence of per-layer attention tensors. For
            ``'vit'`` each is sliced ``[:, 1:, 1:]`` (dropping the first token,
            presumably CLS) and its first dim is treated as heads; for ``'swin'``
            dim 0 is summed out first. Exact shapes come from the model — TODO
            confirm.
        batch_size: divisor applied before re-normalisation.
        backbone: ``'vit'`` (uses ``coords``) or ``'swin'`` (uses ``sw_coords``).

    Returns:
        The summed loss (a tensor, or the int 0 when ``blocks_attn`` is empty).

    Raises:
        NotImplementedError: for any other ``backbone``.
    """
    loss = 0
    # Index-based access keeps working for any indexable container of layers.
    for k in range(len(blocks_attn)):
        # Division already yields a new tensor, so the in-place zeroing below
        # never touches the caller's attention maps.
        if backbone == 'swin':
            layer_attn = blocks_attn[k].sum(0) / batch_size
            target_coords = sw_coords
        elif backbone == 'vit':
            layer_attn = blocks_attn[k][:, 1:, 1:] / batch_size
            target_coords = coords
        else:
            raise NotImplementedError(f"attn_dist_loss: unsupported backbone '{backbone}'")

        # Ignore attention to the query's own position for every head.
        layer_attn.diagonal(dim1=-2, dim2=-1).zero_()

        layer_attn = layer_attn / layer_attn.sum(dim=2, keepdim=True)
        mean_dist = (layer_attn * target_coords.to(layer_attn.device, non_blocking=True)).sum(dim=2).mean(dim=1)
        loss -= torch.log(torch.std(mean_dist, True) + 1e-8)
    return loss

def attn_entropy_loss(blocks_attn, batch_size, backbone='vit'):
    """Attention-entropy diversity regulariser.

    For each layer, rows are normalised to sum to 1, and each row's entropy
    (bits, via ``entropy2``) is averaged over query positions per head. The loss
    adds ``-log(std over heads + 1e-8)`` per layer, so minimising it spreads the
    heads' attention entropies apart.

    Args:
        blocks_attn: indexable sequence of per-layer attention tensors. For
            ``'vit'`` each is sliced ``[:, 1:, 1:]`` (dropping the first token,
            presumably CLS); for ``'swin'`` dim 0 is summed out first. Exact
            shapes come from the model — TODO confirm.
        batch_size: divisor applied before normalisation.
        backbone: ``'vit'`` or ``'swin'``.

    Returns:
        The summed loss (a tensor, or the int 0 when ``blocks_attn`` is empty).

    Raises:
        NotImplementedError: for any other ``backbone``.
    """
    loss = 0
    for k in range(len(blocks_attn)):
        if backbone == 'swin':
            layer_attn = blocks_attn[k].sum(0) / batch_size
        elif backbone == 'vit':
            layer_attn = blocks_attn[k][:, 1:, 1:] / batch_size
        else:
            raise NotImplementedError(f"attn_entropy_loss: unsupported backbone '{backbone}'")
        layer_attn = layer_attn / layer_attn.sum(dim=-1, keepdim=True)

        attn_entropy = entropy2(layer_attn)
        loss -= torch.log(torch.std(attn_entropy.mean(dim=1), True) + 1e-8)
    return loss

def _backward_and_grad_norm(config, model, optimizer, loss):
    """Back-propagate ``loss`` and return the resulting gradient norm.

    Uses apex ``amp.scale_loss`` and measures the AMP master params when
    ``config.AMP_OPT_LEVEL != "O0"``; otherwise uses plain ``loss.backward()``
    and measures ``model.parameters()``. When ``config.TRAIN.CLIP_GRAD`` is set,
    gradients are clipped to it; otherwise the norm is only measured.
    """
    if config.AMP_OPT_LEVEL != "O0":
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
        params = amp.master_params(optimizer)
    else:
        loss.backward()
        params = model.parameters()
    if config.TRAIN.CLIP_GRAD:
        return torch.nn.utils.clip_grad_norm_(params, config.TRAIN.CLIP_GRAD)
    return get_grad_norm(params)


def train_one_epoch(config, model, criterion, data_loader, optimizer, epoch, mixup_fn, lr_scheduler):
    """Train ``model`` for one epoch over ``data_loader``.

    The loss is ``criterion(outputs, targets) + config.TRAIN.COMM.HYP * attn_loss``.
    ``attn_loss`` is the attention regulariser chosen by
    ``config.TRAIN.COMM.OURS_TYPE`` when ``config.OURS`` is set, and 0 otherwise.
    With ``config.TRAIN.ACCUMULATION_STEPS > 1`` the loss is divided by the step
    count, and the optimizer/scheduler only step every that many batches.
    Progress is logged every ``config.PRINT_FREQ`` batches. At the end of the
    epoch, rank 0 logs the last batch's CE and attention loss to wandb.

    Raises:
        NotImplementedError: if ``config.OURS`` is set with an unknown ``OURS_TYPE``.
    """
    model.train()
    optimizer.zero_grad()

    logger.info(f'Current learning rate for different parameter groups: {[it["lr"] for it in optimizer.param_groups]}')

    num_steps = len(data_loader)
    accum_steps = config.TRAIN.ACCUMULATION_STEPS
    batch_time = AverageMeter()
    loss_meter = AverageMeter()
    norm_meter = AverageMeter()
    attn_loss = 0
    start = time.time()
    end = time.time()
    for idx, (samples, targets) in enumerate(data_loader):
        samples = samples.cuda(non_blocking=True)
        if isinstance(targets, list):
            # Class-index targets must be int64: CrossEntropyLoss rejects int32.
            targets = torch.tensor([int(t) for t in targets], dtype=torch.long).cuda(non_blocking=True)
        else:
            targets = targets.cuda(non_blocking=True)

        if config.DATA.TASK_ID != -1:
            # Re-index labels relative to the first class of the current task.
            targets -= config.DATA.TASK_ID * config.DATA.N_CLASSES_PER_TASK

        if mixup_fn is not None:
            samples, targets = mixup_fn(samples, targets)

        if config.OURS:
            outputs, _full_attns, base_attns = model(samples, return_attns=True)
            # The regulariser is computed on the base attention maps only.
            ours_type = config.TRAIN.COMM.OURS_TYPE
            if ours_type == 'dist':
                attn_loss = attn_dist_loss(base_attns, config.DATA.BATCH_SIZE, backbone=config.MODEL.TYPE)
            elif ours_type == 'entropy':
                attn_loss = attn_entropy_loss(base_attns, config.DATA.BATCH_SIZE, backbone=config.MODEL.TYPE)
            else:
                raise NotImplementedError(f"Unknown TRAIN.COMM.OURS_TYPE '{ours_type}'")
        else:
            outputs = model(samples)

        ce_loss = criterion(outputs, targets)
        loss = ce_loss + config.TRAIN.COMM.HYP * attn_loss
        if accum_steps > 1:
            # Gradients accumulate over `accum_steps` batches; scale the loss so
            # the accumulated gradient matches one large-batch step.
            loss = loss / accum_steps
            grad_norm = _backward_and_grad_norm(config, model, optimizer, loss)
            if (idx + 1) % accum_steps == 0:
                optimizer.step()
                optimizer.zero_grad()
                lr_scheduler.step_update(epoch * num_steps + idx)
        else:
            optimizer.zero_grad()
            grad_norm = _backward_and_grad_norm(config, model, optimizer, loss)
            optimizer.step()
            lr_scheduler.step_update(epoch * num_steps + idx)

        torch.cuda.synchronize()

        loss_meter.update(loss.item(), targets.size(0))
        norm_meter.update(grad_norm)
        batch_time.update(time.time() - end)
        end = time.time()

        if idx % config.PRINT_FREQ == 0:
            lr = optimizer.param_groups[-1]['lr']
            memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
            etas = batch_time.avg * (num_steps - idx)
            logger.info(
                f'Train: [{epoch}/{config.TRAIN.EPOCHS}][{idx}/{num_steps}]\t'
                f'eta {datetime.timedelta(seconds=int(etas))} lr {lr:.6f}\t'
                f'time {batch_time.val:.4f} ({batch_time.avg:.4f})\t'
                f'loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t'
                f'grad_norm {norm_meter.val:.4f} ({norm_meter.avg:.4f})\t'
                f'mem {memory_used:.0f}MB')
    epoch_time = time.time() - start
    if dist.get_rank() == 0:
        wandb.log({'epoch': epoch+1, 'tr CE loss': ce_loss.item()})
        # attn_loss is still the int 0 when config.OURS is off; float() handles
        # both that and a one-element tensor.
        wandb.log({'epoch': epoch+1, 'tr AT loss': float(attn_loss)})

    logger.info(f"EPOCH {epoch} training takes {datetime.timedelta(seconds=int(epoch_time))}")


@torch.no_grad()
def validate(config, data_loader, model):
    """Evaluate ``model`` on ``data_loader`` with gradients disabled.

    Each batch's cross-entropy and top-1/top-5 accuracy are passed through
    ``reduce_tensor`` (presumably a cross-rank average) and then accumulated,
    weighted by batch size. When ``config.DATA.TASK_ID != -1``, labels are
    shifted down by ``TASK_ID * N_CLASSES_PER_TASK``.

    Returns:
        ``(acc1_avg, acc5_avg, loss_avg)``.
    """
    loss_fn = torch.nn.CrossEntropyLoss()
    model.eval()

    time_m, loss_m, top1_m, top5_m = (AverageMeter() for _ in range(4))

    task_id = config.DATA.TASK_ID
    tic = time.time()
    torch.cuda.empty_cache()
    for step, (images, target) in enumerate(data_loader, start=1):
        images = images.cuda(non_blocking=True)
        # Some loaders deliver labels as a Python list; turn them into int64 first.
        if type(target) is list:
            target = torch.Tensor(list(map(int, target))).type(torch.LongTensor)
        target = target.cuda(non_blocking=True)

        if task_id != -1:
            target -= task_id * config.DATA.N_CLASSES_PER_TASK

        output = model(images, return_attns=False)
        loss = loss_fn(output, target)
        acc1, acc5 = accuracy(output, target, topk=(1, 5))

        # Keep the reduction order fixed (acc1, acc5, loss) so every rank issues
        # the same sequence of calls.
        acc1, acc5, loss = [reduce_tensor(t) for t in (acc1, acc5, loss)]

        n = target.size(0)
        loss_m.update(loss.item(), n)
        top1_m.update(acc1.item(), n)
        top5_m.update(acc5.item(), n)

        time_m.update(time.time() - tic)
        tic = time.time()

        if step % config.PRINT_FREQ == 0:
            memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
            logger.info(
                f'Test: [{step}/{len(data_loader)}]\t'
                f'Time {time_m.val:.3f} ({time_m.avg:.3f})\t'
                f'Loss {loss_m.val:.4f} ({loss_m.avg:.4f})\t'
                f'Acc@1 {top1_m.val:.3f} ({top1_m.avg:.3f})\t'
                f'Acc@5 {top5_m.val:.3f} ({top5_m.avg:.3f})\t'
                f'Mem {memory_used:.0f}MB')
    logger.info(f' * Acc@1 {top1_m.avg:.3f} Acc@5 {top5_m.avg:.3f}')
    return top1_m.avg, top5_m.avg, loss_m.avg


@torch.no_grad()
def throughput(data_loader, model, logger):
    """Log the forward-pass throughput (images per second) of ``model``.

    Only the first batch of ``data_loader`` is used. It gets 50 untimed warm-up
    passes, then 30 passes timed between ``torch.cuda.synchronize()`` calls.
    When the loader is empty, only ``model.eval()`` runs.
    """
    warmup_runs = 50
    timed_runs = 30
    model.eval()

    for images, _ in data_loader:
        images = images.cuda(non_blocking=True)
        n_images = images.shape[0]

        for _ in range(warmup_runs):
            model(images)
        torch.cuda.synchronize()
        logger.info("throughput averaged with 30 times")

        t_begin = time.time()
        for _ in range(timed_runs):
            model(images)
        torch.cuda.synchronize()
        t_finish = time.time()

        rate = timed_runs * n_images / (t_finish - t_begin)
        logger.info(f"batch_size {n_images} throughput {rate}")
        return


if __name__ == '__main__':
    _, config = parse_option()

    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if config.AMP_OPT_LEVEL != "O0" and amp is None:
        raise RuntimeError("amp not installed!")

    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        rank = int(os.environ["RANK"])
        world_size = int(os.environ['WORLD_SIZE'])
        print(f"RANK and WORLD_SIZE in environ: {rank}/{world_size}")
    else:
        rank = -1
        world_size = -1
    torch.cuda.set_device(config.LOCAL_RANK)
    torch.distributed.init_process_group(backend='nccl', init_method='env://', world_size=world_size, rank=rank)
    torch.distributed.barrier()

    # Offset the seed by rank so each process draws a different random stream.
    seed = config.SEED + dist.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    cudnn.benchmark = True

    # linear scale the learning rate according to total batch size, may not be optimal
    linear_scaled_lr = config.TRAIN.BASE_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0
    linear_scaled_warmup_lr = config.TRAIN.WARMUP_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0
    linear_scaled_min_lr = config.TRAIN.MIN_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0
    # gradient accumulation also need to scale the learning rate
    if config.TRAIN.ACCUMULATION_STEPS > 1:
        linear_scaled_lr = linear_scaled_lr * config.TRAIN.ACCUMULATION_STEPS
        linear_scaled_warmup_lr = linear_scaled_warmup_lr * config.TRAIN.ACCUMULATION_STEPS
        linear_scaled_min_lr = linear_scaled_min_lr * config.TRAIN.ACCUMULATION_STEPS
    config.defrost()
    config.TRAIN.BASE_LR = linear_scaled_lr
    config.TRAIN.WARMUP_LR = linear_scaled_warmup_lr
    config.TRAIN.MIN_LR = linear_scaled_min_lr

    if config.DATA.TASK_ID == -1:
        config.DATA.N_CLASSES_PER_TASK = 1000
    config.freeze()

    os.makedirs(config.OUTPUT, exist_ok=True)
    # Module-level on purpose: main(), train_one_epoch() and validate() read this global.
    logger = create_logger(output_dir=config.OUTPUT, dist_rank=dist.get_rank(), name=f"{config.MODEL.NAME}")
    logger.info(f'config.DATA.N_CLASSES_PER_TASK: {config.DATA.N_CLASSES_PER_TASK}')

    if dist.get_rank() == 0:
        # NOTE(review): config.dump() output is presumably YAML despite the .json
        # name; the filename is kept for compatibility with existing tooling.
        path = os.path.join(config.OUTPUT, "config.json")
        with open(path, "w", encoding="utf-8") as f:
            f.write(config.dump())
        logger.info(f"Full config saved to {path}")

    # print config
    logger.info(config.dump())

    main(config)