import logging
import os
import time
import argparse
import datetime
import numpy as np

import torch
import torch.backends.cudnn as cudnn
import torch.distributed as dist
from torch.utils.tensorboard import SummaryWriter
from timm.utils import AverageMeter

from config import get_config
from models import build_model
from data import build_loader
from lr_scheduler import build_scheduler
from optimizer import build_optimizer
from logger import create_logger
from utils import save_images, load_checkpoint, load_mae_checkpoint, save_checkpoint, get_grad_norm, auto_resume_helper

# Default to the first four GPUs, but keep a device list that the user or the
# launcher has already chosen instead of silently overriding it.
os.environ.setdefault('CUDA_VISIBLE_DEVICES', "0,1,2,3")

try:
    # noinspection PyUnresolvedReferences
    from apex import amp  # optional: only required when AMP_OPT_LEVEL != "O0"
except ImportError:
    amp = None


def parse_option():
    """Parse the command line and turn it into the experiment config.

    Returns:
        A ``(args, config)`` pair: the argparse namespace and the config object
        produced by ``get_config(args)``.
    """
    cli = argparse.ArgumentParser('Vision Electra pre-training script', add_help=False)
    add = cli.add_argument

    add('--cfg', type=str, metavar="FILE", help='path to config file',
        default=".../VE/configs/vision_electra_vit_base/vision_electra_mae_vit_base__img224_pretrain.yaml")
    add("--opts", nargs='+', default=None,
        help="Modify config options by adding 'KEY VALUE' pairs. ")

    # Shortcuts for overriding individual config entries.
    add('--batch-size', type=int, help="batch size for single GPU")
    add('--data-path', type=str, help='path to dataset')
    add('--resume', help='resume from checkpoint')
    add('--accumulation-steps', type=int, help="gradient accumulation steps")
    add('--use-checkpoint', action='store_true',
        help="whether to use gradient checkpointing to save memory")
    add('--amp-opt-level', type=str, default='O1', choices=['O0', 'O1', 'O2'],
        help='mixed precision opt level, if O0, no amp is used')
    add('--output', default='output', type=str, metavar='PATH',
        help='root of output folder, the full path is <output>/<model_name>/<tag> (default: output)')
    add('--tag', help='tag of experiment')

    # Distributed training (-1 means a single, non-distributed process).
    add("--local_rank", default=-1, type=int, help='local rank for DistributedDataParallel')

    parsed = cli.parse_args()
    return parsed, get_config(parsed)


def main(config):
    """Build data, models, optimizers and schedulers, optionally resume, then train.

    Uses the module-level ``logger`` created in the ``__main__`` section.
    TensorBoard events go to ``logs/<config.TAG>``. They are written only when
    ``config.LOCAL_RANK`` is -1 or 0. Checkpoints are saved every
    ``config.SAVE_FREQ`` epochs and after the last epoch, by the main process
    only.

    Args:
        config: frozen experiment config. It is briefly defrosted to record the
            checkpoint paths found by auto-resume.
    """
    data_loader_train = build_loader(config, logger, is_pretrain=True)

    logger.info(f"Creating generator model:{config.GENERATOR.TYPE}/{config.GENERATOR.NAME}, discriminator model:{config.DISCRIMINATOR.TYPE}/{config.DISCRIMINATOR.NAME}")
    generator, discriminator = build_model(config, is_pretrain=True)
    generator.cuda()
    logger.info(str(generator))
    discriminator.cuda()
    logger.info(str(discriminator))

    # Tensorboard
    writer = None
    if config.LOCAL_RANK in [-1, 0]:
        writer = SummaryWriter(log_dir=os.path.join("logs", config.TAG))

    optimizer_g = build_optimizer(config, generator, logger, 'g', is_pretrain=True)
    optimizer_d = build_optimizer(config, discriminator, logger, 'd', is_pretrain=True)
    if config.AMP_OPT_LEVEL != "O0":
        # apex expects exactly one amp.initialize call per run. Several
        # models/optimizers are registered together by passing lists.
        [generator, discriminator], [optimizer_g, optimizer_d] = amp.initialize(
            [generator, discriminator], [optimizer_g, optimizer_d], opt_level=config.AMP_OPT_LEVEL)
    if config.LOCAL_RANK != -1:
        generator = torch.nn.parallel.DistributedDataParallel(generator, device_ids=[config.LOCAL_RANK], broadcast_buffers=False)
        generator_without_ddp = generator.module
        discriminator = torch.nn.parallel.DistributedDataParallel(discriminator, device_ids=[config.LOCAL_RANK], broadcast_buffers=False)
        discriminator_without_ddp = discriminator.module
    else:
        generator_without_ddp = generator
        discriminator_without_ddp = discriminator

    n_parameters_g = sum(p.numel() for p in generator.parameters() if p.requires_grad)
    logger.info(f"Generator number of params: {n_parameters_g}")
    n_parameters_d = sum(p.numel() for p in discriminator.parameters() if p.requires_grad)
    logger.info(f"Discriminator number of params: {n_parameters_d}")
    if hasattr(generator_without_ddp, 'flops'):
        flops = generator_without_ddp.flops()
        logger.info(f"Generator number of GFLOPs: {flops / 1e9}")
    if hasattr(discriminator_without_ddp, 'flops'):
        flops = discriminator_without_ddp.flops()
        logger.info(f"Discriminator number of GFLOPs: {flops / 1e9}")

    lr_scheduler_g = build_scheduler(config, optimizer_g, len(data_loader_train))
    lr_scheduler_d = build_scheduler(config, optimizer_d, len(data_loader_train))

    if config.TRAIN.AUTO_RESUME:
        resume_file_g, resume_file_d = auto_resume_helper(config.OUTPUT, logger)
        if resume_file_g:
            if config.GENERATOR.RESUME:
                logger.warning(f"auto-resume changing resume file from {config.GENERATOR.RESUME} to {resume_file_g}")
            config.defrost()
            config.GENERATOR.RESUME = resume_file_g
            config.freeze()
            logger.info(f'Generator auto resuming from {resume_file_g}')
        if resume_file_d:
            if config.DISCRIMINATOR.RESUME:
                logger.warning(f"auto-resume changing resume file from {config.DISCRIMINATOR.RESUME} to {resume_file_d}")
            config.defrost()
            config.DISCRIMINATOR.RESUME = resume_file_d
            config.freeze()
            logger.info(f'Discriminator auto resuming from {resume_file_d}')
        # Only report "nothing found" when neither checkpoint exists.
        if not resume_file_g and not resume_file_d:
            logger.info(f'no checkpoint found in {config.OUTPUT}, ignoring auto resume')
    if config.GENERATOR.RESUME:
        if config.GENERATOR.TYPE == 'mae':
            load_mae_checkpoint(config, generator_without_ddp, optimizer_g, lr_scheduler_g, logger)
        else:
            load_checkpoint(config, generator_without_ddp, optimizer_g, lr_scheduler_g, logger, 'g')
    if config.DISCRIMINATOR.RESUME:
        load_checkpoint(config, discriminator_without_ddp, optimizer_d, lr_scheduler_d, logger, 'd')

    # dist.get_rank() raises when no process group was initialised, so a
    # single-process run is treated as the main process.
    is_main_process = not (dist.is_available() and dist.is_initialized()) or dist.get_rank() == 0

    logger.info("Start training")
    start_time = time.time()
    for epoch in range(config.TRAIN.START_EPOCH, config.TRAIN.EPOCHS):
        if config.LOCAL_RANK != -1:
            data_loader_train.sampler.set_epoch(epoch)

        train_one_epoch(config, generator, discriminator, data_loader_train, optimizer_g, optimizer_d, epoch, lr_scheduler_g, lr_scheduler_d, writer)
        if is_main_process and (epoch % config.SAVE_FREQ == 0 or epoch == (config.TRAIN.EPOCHS - 1)):
            save_checkpoint(config, epoch, generator_without_ddp, 0., optimizer_g, lr_scheduler_g, logger, 'g')
            save_checkpoint(config, epoch, discriminator_without_ddp, 0., optimizer_d, lr_scheduler_d, logger, 'd')

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    logger.info('Training time {}'.format(total_time_str))
    if writer is not None:
        writer.close()


def process_image(image, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Undo per-channel normalisation and convert a batch to 0-255 integer pixels.

    Args:
        image: tensor of shape [B, C, H, W], normalised as ``(x - mean) / std``.
        mean: per-channel means used for the normalisation. The default is the
            ImageNet statistics.
        std: per-channel standard deviations used for the normalisation. The
            default is the ImageNet statistics.

    Returns:
        An int32 tensor of the same shape. Values are clipped to [0, 255] and
        truncated (not rounded) to integers.

    Raises:
        ValueError: if ``mean`` and ``std`` differ in length, or ``image`` is not
            4-D with one channel per mean/std entry.
    """
    if len(mean) != len(std):
        raise ValueError(f"mean and std must have the same length, got {len(mean)} and {len(std)}")
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if image.dim() != 4 or image.shape[1] != len(mean):
        raise ValueError(f"expected image of shape [B, {len(mean)}, H, W], got {tuple(image.shape)}")
    mean_t = torch.tensor(mean, device=image.device).view(1, -1, 1, 1)
    std_t = torch.tensor(std, device=image.device).view(1, -1, 1, 1)
    return torch.clip((image * std_t + mean_t) * 255, 0, 255).int()

def _patch_mask_to_pixels(mask, height, width):
    """Upsample a patch-level mask [B, mh, mw] to a pixel mask [B, 1, height, width].

    Each patch entry is repeated ``height // mh`` times along H and
    ``width // mw`` times along W (nearest-neighbour upsampling).
    """
    _, mh, mw = mask.shape
    pixel_mask = mask.repeat_interleave(height // mh, dim=1).repeat_interleave(width // mw, dim=2)
    return pixel_mask.unsqueeze(1)


def _visualization_images(img, x_rec, mask):
    """Build the de-normalised images that are logged to TensorBoard for one batch."""
    img = img.detach()
    x_rec = x_rec.detach()
    mask = mask.detach()
    # Masked input, and the reconstruction pasted back into the visible patches.
    im_masked = img * (1 - mask)
    im_paste = img * (1 - mask) + x_rec * mask
    return {
        'image': process_image(img).float(),
        'mask': process_image(im_masked).float(),
        'reconstruction': process_image(x_rec).float(),
        'reconstruction + visible': process_image(im_paste).float(),
    }


def _backward_and_step(config, loss, model, optimizer, lr_scheduler, idx, global_step):
    """Back-propagate ``loss`` for one model and step its optimizer when due.

    With gradient accumulation (``ACCUMULATION_STEPS > 1``):
      - the loss is divided by the number of accumulation steps;
      - gradients accumulate across iterations;
      - optimizer and scheduler step every ``ACCUMULATION_STEPS`` iterations,
        after which the gradients are zeroed.

    Otherwise gradients are zeroed before the backward pass, and the optimizer
    and scheduler step every iteration.

    Returns:
        ``(loss, grad_norm)``. ``loss`` is the loss that was back-propagated (so it
        is divided when accumulating). ``grad_norm`` comes from
        ``clip_grad_norm_`` (total norm before clipping) or from ``get_grad_norm``.
    """
    accumulation_steps = config.TRAIN.ACCUMULATION_STEPS
    accumulating = accumulation_steps > 1
    if accumulating:
        loss = loss / accumulation_steps
    else:
        optimizer.zero_grad()

    if config.AMP_OPT_LEVEL != "O0":
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
        params = amp.master_params(optimizer)
    else:
        loss.backward()
        params = model.parameters()

    if config.TRAIN.CLIP_GRAD:
        grad_norm = torch.nn.utils.clip_grad_norm_(params, config.TRAIN.CLIP_GRAD)
    else:
        grad_norm = get_grad_norm(params)

    if not accumulating or (idx + 1) % accumulation_steps == 0:
        optimizer.step()
        if accumulating:
            optimizer.zero_grad()
        lr_scheduler.step_update(global_step)
    return loss, grad_norm


def train_one_epoch(config, generator, discriminator, data_loader, optimizer_g, optimizer_d, epoch, lr_scheduler_g, lr_scheduler_d, writer):
    """Run one epoch of joint generator/discriminator training.

    For every batch, the generator reconstructs the masked image. It is updated
    with its own loss plus 0.2 times a loss computed by the discriminator on the
    reconstruction, labelled 0 (``gt_image_label``).

    The discriminator is then updated with two losses:
      - one on the detached reconstruction, labelled 1 (``fake_image_label``);
      - one from a second pass labelled 0.

    Losses, gradient norms, learning rates and sample images are logged every
    ``config.PRINT_FREQ`` iterations. Uses the module-level ``logger``.
    """
    generator.train()
    discriminator.train()
    optimizer_g.zero_grad()
    optimizer_d.zero_grad()

    num_steps = len(data_loader)
    batch_time = AverageMeter()
    loss_meter_g = AverageMeter()
    loss_meter_d = AverageMeter()
    norm_meter_g = AverageMeter()
    norm_meter_d = AverageMeter()

    start = time.time()
    end = time.time()
    for idx, (img, mask, label, _) in enumerate(data_loader):
        global_step = epoch * num_steps + idx
        img = img.cuda(non_blocking=True)
        # The mask can also serve as the final segmentation ground truth:
        # 1 means replaced, 0 means original.
        mask = mask.cuda(non_blocking=True)
        # Fake label: generated regions are also labelled 0 and handed to the
        # discriminator to compute a loss that trains the generator.
        label = label.cuda(non_blocking=True)

        # ---- Generator ----
        if config.GENERATOR.TYPE != 'mae':
            x_rec, loss_g1 = generator(img, mask)
            _, _, rec_h, rec_w = x_rec.shape
            # The loader's mask is patch-level [B, mh, mw]. Keep it as the
            # patch-level label and upsample it to pixels for blending.
            # NOTE(review): mask_label here is assumed to mirror the MAE path's
            # patch-level mask_label ([B, 14, 14]) — confirm against the
            # discriminator.
            mask_label = mask
            mask = _patch_mask_to_pixels(mask_label, rec_h, rec_w)
        else:
            # x_rec: [B,3,224,224], mask: [B,3,224,224], label: [B,14,14], mask_label: [B,14,14]
            loss_g1, x_rec, mask, label, mask_label = generator(img, config.DATA.MASK_RATIO)

        gt_image_label = torch.zeros(img.shape[0], 1).cuda(non_blocking=True)
        _, loss_g2 = discriminator(img, x_rec, mask, gt_image_label, mask_label, label)

        # Weight 0.2 when starting from a pre-trained model. From scratch,
        # 0.4-0.5 was tried instead.
        loss_g = loss_g1 + 0.2 * loss_g2

        # NOTE(review): loss_g2 comes from a discriminator forward pass, so
        # loss_g presumably also produces gradients on the discriminator's
        # parameters. Without accumulation they are discarded by the
        # discriminator's zero_grad() below. With ACCUMULATION_STEPS > 1 they are
        # accumulated into the discriminator update — confirm this is intended.
        loss_g, grad_norm_g = _backward_and_step(
            config, loss_g, generator, optimizer_g, lr_scheduler_g, idx, global_step)

        # ---- Discriminator ----
        fake_image_label = torch.ones(img.shape[0], 1).cuda(non_blocking=True)
        _, loss_d1 = discriminator(img, x_rec.detach(), mask, fake_image_label, mask_label, None)
        _, loss_d2 = discriminator(img, None, mask, gt_image_label, mask_label, label)
        loss_d = loss_d1 + loss_d2
        loss_d, grad_norm_d = _backward_and_step(
            config, loss_d, discriminator, optimizer_d, lr_scheduler_d, idx, global_step)

        torch.cuda.synchronize()

        loss_meter_g.update(loss_g.item(), img.size(0))
        norm_meter_g.update(grad_norm_g)

        loss_meter_d.update(loss_d.item(), img.size(0))
        norm_meter_d.update(grad_norm_d)

        batch_time.update(time.time() - end)
        end = time.time()

        if idx % config.PRINT_FREQ == 0:
            lr_g = optimizer_g.param_groups[0]['lr']
            lr_d = optimizer_d.param_groups[0]['lr']
            memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
            etas = batch_time.avg * (num_steps - idx)
            logger.info(
                f'Train: [{epoch}/{config.TRAIN.EPOCHS}][{idx}/{num_steps}]\t'
                f'eta {datetime.timedelta(seconds=int(etas))} lr_g {lr_g:.6f} lr_d {lr_d:.6f}\t'
                f'time {batch_time.val:.4f} ({batch_time.avg:.4f})\t'
                f'loss_g {loss_meter_g.val:.4f} ({loss_meter_g.avg:.4f})\t'
                f'loss_d {loss_meter_d.val:.4f} ({loss_meter_d.avg:.4f})\t'
                f'grad_norm_g {norm_meter_g.val:.4f} ({norm_meter_g.avg:.4f})\t'
                f'grad_norm_d {norm_meter_d.val:.4f} ({norm_meter_d.avg:.4f})\t'
                f'mem {memory_used:.0f}MB')
            # Tensorboard
            if config.LOCAL_RANK in [-1, 0]:
                writer.add_scalar("train/loss_g", scalar_value=loss_meter_g.val, global_step=global_step)
                writer.add_scalar("train/loss_d", scalar_value=loss_meter_d.val, global_step=global_step)
                writer.add_scalar("train/lr_g", scalar_value=lr_g, global_step=global_step)
                writer.add_scalar("train/lr_d", scalar_value=lr_d, global_step=global_step)
                # Visualisations are only built when they are actually logged.
                save_images(writer, 'train', _visualization_images(img, x_rec, mask), global_step=global_step)

    epoch_time = time.time() - start
    logger.info(f"EPOCH {epoch} training takes {datetime.timedelta(seconds=int(epoch_time))}")


if __name__ == '__main__':
    _, config = parse_option()

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if config.AMP_OPT_LEVEL != "O0" and amp is None:
        raise ImportError("amp not installed!")

    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        rank = int(os.environ["RANK"])
        world_size = int(os.environ['WORLD_SIZE'])
        print(f"RANK and WORLD_SIZE in environ: {rank}/{world_size}")
    else:
        rank = -1
        world_size = -1
    distributed = config.LOCAL_RANK != -1
    if distributed:
        torch.cuda.set_device(config.LOCAL_RANK)
        torch.distributed.init_process_group(backend='nccl', init_method='env://', world_size=world_size, rank=rank)
        torch.distributed.barrier()

    # Fixed seed shared by every rank. A per-rank variant would be
    # config.SEED + dist.get_rank().
    seed = 3407
    torch.manual_seed(seed)
    np.random.seed(seed)
    cudnn.benchmark = True

    # Linearly scale the learning rates by the total batch size relative to 512
    # (may not be optimal). Gradient accumulation enlarges the effective batch,
    # so it scales them too. Without a process group there is no
    # dist.get_world_size(), so a single process counts as world size 1.
    lr_world_size = dist.get_world_size() if distributed else 1
    config.defrost()
    for lr_key in ('GENERATOR_LR', 'DISCRIMINATOR_LR', 'WARMUP_LR', 'MIN_LR'):
        scaled_lr = getattr(config.TRAIN, lr_key) * config.DATA.BATCH_SIZE * lr_world_size / 512.0
        if config.TRAIN.ACCUMULATION_STEPS > 1:
            scaled_lr = scaled_lr * config.TRAIN.ACCUMULATION_STEPS
        setattr(config.TRAIN, lr_key, scaled_lr)
    config.freeze()

    os.makedirs(config.OUTPUT, exist_ok=True)

    # `logger` is module-global: main() and train_one_epoch() read it.
    if distributed:
        logger = create_logger(output_dir=config.OUTPUT, dist_rank=dist.get_rank(), name=f"Vision Electra GENERATOR:{config.GENERATOR.TYPE}, DISCRIMINATOR: {config.DISCRIMINATOR.TYPE}")

        if dist.get_rank() == 0:
            path = os.path.join(config.OUTPUT, "config.json")
            with open(path, "w") as f:
                f.write(config.dump())
            logger.info(f"Full config saved to {path}")
    else:
        logger = create_logger(output_dir=config.OUTPUT)

    # print config
    logger.info(config.dump())

    main(config)
