import torch
from torch.utils.data import DataLoader
import torch.utils.data.distributed
import torch.backends.cudnn as cudnn
import torch.distributed as dist

import models_mae
from Dataset import DESINet
from util.misc import colorstr, SaveCheckpoint, NativeScaler
from util.lr_sched import adjust_learning_rate, param_groups_weight_decay
from collections import OrderedDict

import numpy as np
from pathlib import Path
import os, math, sys
import time
import argparse
from torch.utils.tensorboard import SummaryWriter
import pdb

def setup_for_distributed(is_master):
    """Replace the global ``print`` so that only the master process prints.

    After this call, ``print`` is silent unless ``is_master`` is true or the
    call site passes ``force=True``. The ``force`` keyword is consumed here
    and never forwarded to the underlying print function.

    Args:
        is_master: truthy when the current process should keep printing.
    """
    import builtins as __builtin__

    # Capture whatever ``print`` currently is, so output is delegated to it.
    original_print = __builtin__.print

    def print(*args, **kwargs):
        forced = kwargs.pop("force", False)
        if not (is_master or forced):
            return
        original_print(*args, **kwargs)

    __builtin__.print = print


def init_distributed_mode(args):
    """Initialise ``torch.distributed`` from launcher environment variables.

    Three situations are handled:

    * ``RANK`` and ``WORLD_SIZE`` are set: rank, world size and the local GPU
      index (``LOCAL_RANK``) are read from the environment.
    * ``SLURM_PROCID`` is set: the rank comes from SLURM and the GPU index is
      the rank modulo the number of visible CUDA devices. ``args.world_size``
      is left at whatever the caller supplied.
    * Neither is set: distributed mode is disabled and the function returns
      without touching the process group.

    Args:
        args: mutable namespace. Reads ``dist_backend``, ``dist_url`` and (in
            the SLURM case) ``world_size``; writes ``distributed``, ``rank``,
            ``gpu`` and, in the first case, ``world_size``.

    Raises:
        KeyError: if ``RANK``/``WORLD_SIZE`` are set but ``LOCAL_RANK`` is not.
    """
    env = os.environ

    if "RANK" in env and "WORLD_SIZE" in env:
        args.rank = int(env["RANK"])
        args.world_size = int(env["WORLD_SIZE"])
        args.gpu = int(env["LOCAL_RANK"])
        print("Use GPU: {} for training".format(args.gpu))
    elif "SLURM_PROCID" in env:
        args.rank = int(env["SLURM_PROCID"])
        args.gpu = args.rank % torch.cuda.device_count()
    else:
        print("Not using distributed mode")
        args.distributed = False
        # Other code in this file tests ``args.gpu is None``; make sure the
        # attributes exist so those checks do not raise AttributeError.
        # Values already set by the caller are left untouched.
        if not hasattr(args, "gpu"):
            args.gpu = None
        if not hasattr(args, "rank"):
            args.rank = 0
        return

    args.distributed = True

    # Bind this process to its GPU before the process group is created.
    torch.cuda.set_device(args.gpu)

    dist.init_process_group(
        backend=args.dist_backend,
        init_method=args.dist_url,
        world_size=args.world_size,
        rank=args.rank,
    )
    # Wait until every rank has joined, then silence print on non-zero ranks.
    dist.barrier()
    setup_for_distributed(args.rank == 0)


def main(args):
    """Run distributed MAE pre-training: build data, model and optimizer, then train.

    The function initialises the process group, seeds each rank, builds the
    ``DESINet`` dataset with a ``DistributedSampler``, wraps the model in
    ``DistributedDataParallel``, optionally resumes from ``args.resume`` and
    then runs ``train`` once per epoch. Rank 0 additionally creates the output
    folders under ``args.save_dir``, writes checkpoints via ``SaveCheckpoint``,
    logs to TensorBoard and stores the per-epoch loss history.

    Args:
        args: parsed command-line namespace. ``args.lr``, ``args.batch_size``
            and ``args.start_epoch`` are modified in place: ``lr`` is derived
            from ``blr`` when not given, ``batch_size`` is converted from the
            total batch size to the per-process batch size, and
            ``start_epoch`` is taken from the checkpoint when resuming.

    Raises:
        ValueError: if no distributed launch environment was detected.
    """
    init_distributed_mode(args)

    # Rank-dependent seeding, the DistributedSampler, DDP and the rank-0
    # bookkeeping below all need an initialised process group. Fail with the
    # intended error here instead of crashing later inside dist.get_rank().
    if not args.distributed:
        raise ValueError("Distributed init error.")

    device = torch.device(args.device)
    is_main_process = dist.get_rank() == 0

    # fix the seed for reproducibility (offset by rank so ranks differ)
    seed = args.seed + dist.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)

    cudnn.benchmark = True

    # pretraining dataset with simple augmentation
    train_dataset = DESINet()
    print("train dataset size: {}".format(len(train_dataset)))

    # The absolute lr is derived from the *total* batch size, so compute it
    # before batch_size is converted to the per-process value.
    if args.lr is None:
        args.lr = args.blr * args.batch_size / 256
    # --batch_size is the total batch size; always split it across processes,
    # including when an explicit --lr was supplied.
    args.batch_size = int(args.batch_size / args.world_size)

    print(args)
    print(colorstr('green', "epochs: {}, images per gpu: {}, base lr: {}, absolute lr: {}".format(
            args.epochs, args.batch_size, args.blr, args.lr)))

    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset, num_replicas=args.world_size, rank=args.rank, shuffle=True)

    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              num_workers=args.workers,
                              pin_memory=True,
                              sampler=train_sampler,
                              drop_last=True)

    # create model; keep a handle on the unwrapped module for state dicts
    model_without_ddp = models_mae.__dict__[args.model_name](norm_pix_loss=args.norm_pix_loss)
    model_without_ddp.to(device)
    model = torch.nn.parallel.DistributedDataParallel(model_without_ddp, device_ids=[args.gpu])

    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print('number of params (M): %.2f' % (n_parameters / 1.e6))

    params = param_groups_weight_decay(
        model=model_without_ddp,
        weight_decay=args.weight_decay,
        no_weight_decay_list=model_without_ddp.no_weight_decay(),
    )
    optimizer = torch.optim.AdamW(params=params,
                                  lr=args.lr,
                                  betas=(0.9, 0.95))
    loss_scaler = NativeScaler()

    # Output locations are only created and used on rank 0.
    if is_main_process:
        # weights
        save_dir = Path(args.save_dir)
        weights = save_dir / 'weights'
        weights.mkdir(parents=True, exist_ok=True)
        last = weights / 'last'
        best = weights / 'best'

        # acc,loss
        acc_loss = save_dir / 'acc_loss'
        acc_loss.mkdir(parents=True, exist_ok=True)
        train_loss_savepath = acc_loss / 'train_loss.npy'

        # tensorboard
        logdir = save_dir / 'logs'
        logdir.mkdir(parents=True, exist_ok=True)
        summary_writer = SummaryWriter(logdir, flush_secs=120)

        # record the architecture and the run configuration
        model_file = str(save_dir / 'model.txt')
        with open(model_file, "a") as f:
            print(model_without_ddp, file=f)
            print(args, file=f)

    if args.resume:
        if args.gpu is None:
            checkpoint = torch.load(args.resume)
        else:
            # Map the checkpoint onto this process's GPU; fall back to CPU so
            # `checkpoint` is always bound even when CUDA is unavailable.
            loc = 'cuda:{}'.format(args.gpu) if torch.cuda.is_available() else 'cpu'
            checkpoint = torch.load(args.resume, map_location=loc)

        args.start_epoch = checkpoint['epoch']
        model_without_ddp.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        # NOTE(review): looks like the workaround for optimizer-state resume
        # problems seen on some PyTorch releases - confirm it is still needed.
        for group in optimizer.param_groups:
            group['capturable'] = True
        # Keep best_loss a plain float, the same type used on a fresh run, so
        # comparisons and the saved state stay type-consistent across resumes.
        best_loss = float(checkpoint['best_loss'])
        loss_scaler.load_state_dict(checkpoint['scaler'])

        train_loss = checkpoint['train_loss']
        if is_main_process:
            print(colorstr('green', 'Resuming training from {} epoch'.format(args.start_epoch)))
    else:
        best_loss = float("inf")
        train_loss = []

    print("Start training")
    for epoch in range(args.start_epoch, args.epochs):
        print("Epoch {}/{}".format(epoch + 1, args.epochs))
        # Gives the sampler a different shuffling order every epoch.
        train_sampler.set_epoch(epoch)

        train_epoch_loss = train(model=model,
                                 train_loader=train_loader,
                                 optimizer=optimizer,
                                 args=args,
                                 epoch=epoch,
                                 scaler=loss_scaler)

        current_lr = optimizer.param_groups[-1]['lr']
        s = "Train Loss: {:.4f}, lr: {:.5f}".format(train_epoch_loss, current_lr)
        print(colorstr('green', s))

        if is_main_process:
            # save acc,loss
            train_loss.append(train_epoch_loss)

            # save model
            is_best = train_epoch_loss < best_loss
            best_loss = min(best_loss, train_epoch_loss)
            state = {
                'epoch': epoch + 1,
                'model': model_without_ddp.state_dict(),
                'optimizer': optimizer.state_dict(),
                'best_loss': best_loss,
                'train_loss': train_loss,
                'scaler': loss_scaler.state_dict()
            }

            last_path = last / 'epoch_{}_loss_{:.4f}'.format(epoch + 1, train_epoch_loss)
            best_path = best / 'epoch_{}_loss_{:.4f}'.format(epoch + 1, best_loss)

            SaveCheckpoint(state, last, last_path, best, best_path, is_best)

            summary_writer.add_scalar('lr', current_lr, epoch)
            summary_writer.add_scalar('train_loss', train_epoch_loss, epoch)

    if is_main_process:
        summary_writer.close()
        # NOTE(review): an existing history file is never overwritten, so a
        # later run into the same save_dir keeps the old file - confirm intent.
        if not os.path.exists(train_loss_savepath):
            np.save(train_loss_savepath, train_loss)


def train(model, train_loader, optimizer, args, epoch, scaler):
    """Run one training epoch and return the sample-weighted mean loss.

    Args:
        model: called as ``model(images, mask_ratio=...)``; must return a
            3-tuple whose first element is a scalar loss tensor.
        train_loader: sized iterable yielding ``(images, _)`` pairs.
        optimizer: optimizer whose learning rate is updated every step via
            ``adjust_learning_rate``.
        args: namespace; reads ``gpu``, ``use_amp``, ``mask_ratio`` and
            ``clip_grad`` here (``adjust_learning_rate`` receives it as well).
        epoch: zero-based epoch index; combined with the fractional progress
            inside the epoch to drive the per-step learning-rate schedule.
        scaler: callable used on the mixed-precision path as
            ``scaler(loss, optimizer, clip_grad=..., parameters=...)``.
            Presumably performs the scaled backward pass and optimizer step -
            it is defined in ``util.misc``, verify there.

    Returns:
        The running average of the per-batch loss, weighted by batch size.

    The process exits with status 1 as soon as a non-finite loss is seen.
    """
    train_loss = AverageMeter()
    cuda_ready = torch.cuda.is_available()

    # Model on train mode
    model.train()
    step_per_epoch = len(train_loader)
    for step, (images, _) in enumerate(train_loader):
        # Synchronise so the per-step timing below covers finished GPU work.
        if cuda_ready:
            torch.cuda.synchronize()
        start = time.time()

        optimizer.zero_grad()
        # Per-iteration schedule: fractional epoch = epoch + progress in epoch.
        adjust_learning_rate(optimizer, step / step_per_epoch + epoch, args)

        if args.gpu is not None and cuda_ready:
            images = images.cuda(args.gpu, non_blocking=True)

        # Forward pass, under autocast only when mixed precision is requested.
        if args.use_amp:
            with torch.cuda.amp.autocast():
                loss, _, _ = model(images, mask_ratio=args.mask_ratio)
        else:
            loss, _, _ = model(images, mask_ratio=args.mask_ratio)

        # Read the scalar once; each .item() call forces a device sync.
        loss_value = loss.item()
        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)

        # Backward pass and parameter update.
        if args.use_amp:
            scaler(loss, optimizer, clip_grad=args.clip_grad, parameters=model.parameters())
        else:
            # Full-precision path: without it no update would happen and
            # `loss` would be undefined when --use_amp is not passed.
            loss.backward()
            if args.clip_grad is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad)
            optimizer.step()

        # record loss
        train_loss.update(loss_value, images.size(0))

        if cuda_ready:
            torch.cuda.synchronize()
        t = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time()))
        s1 = '\r{} [{}/{}]'.format(t, step+1, step_per_epoch)
        s2 = ' - {:.2f}ms/step - train_loss: {:.3f}'.format(1000 * (time.time()-start), train_loss.val)
        # '\r' with end='' rewrites a single progress line in place.
        print(s1+s2, end='', flush=True)

    # Finish the progress line.
    print()
    return train_loss.avg


class AverageMeter(object):
    """Track the most recent value of a metric and its running weighted mean.

    Attributes are ``val`` (last value passed to ``update``), ``sum``
    (weighted total), ``count`` (total weight) and ``avg`` (``sum / count``).
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Set every statistic back to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running mean."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count


def reduce_tensor(tensor, args):
    """Return the element-wise mean of ``tensor`` across all processes.

    The input is left untouched: a clone is summed over the default process
    group and then divided in place by the world size.

    Args:
        tensor: tensor to average.
        args: unused; kept so existing call sites keep working.

    Returns:
        A new tensor holding the cross-process average.
    """
    averaged = tensor.clone()
    dist.all_reduce(averaged, op=dist.ReduceOp.SUM)
    # div_ is the in-place division and returns the same tensor.
    return averaged.div_(dist.get_world_size())


if __name__ == '__main__':
    # Command-line interface for the pre-training script.
    parser = argparse.ArgumentParser(description='MAE pre-training')

    # model parameters
    parser.add_argument(
        "--model_name", type=str, default="mae_vit_large_patch16",
        help="model architecture",
    )
    parser.add_argument(
        '--mask_ratio', default=0.75, type=float,
        help='Masking ratio (percentage of removed patches).',
    )
    parser.add_argument(
        '--norm_pix_loss', action='store_true',
        help='Use (per-patch) normalized pixels as targets for computing loss',
    )
    parser.add_argument('--use_amp', action='store_true')
    parser.add_argument(
        '--clip_grad', type=float, default=None, metavar='NORM',
        help='Clip gradient norm (default: None, no clipping)',
    )

    # data, schedule and optimizer parameters
    parser.add_argument("--dataset", type=str, default='AstroImgNet')
    parser.add_argument("--epochs", type=int, default=800)
    parser.add_argument(
        "--start_epoch", default=0, type=int,
        help="start epoch",
    )
    parser.add_argument(
        "--batch_size", type=int, default=4096,
        help="total batch size",
    )
    parser.add_argument(
        '--workers', default=16, type=int,
        help='number of data loading workers',
    )
    parser.add_argument("--weight_decay", type=float, default=0.05)
    parser.add_argument(
        '--lr', type=float, default=None, metavar='LR',
        help='learning rate (absolute lr)',
    )
    parser.add_argument(
        '--blr', type=float, default=1e-4, metavar='LR',
        help='base learning rate: absolute_lr = base_lr * total_batch_size / 256',
    )
    parser.add_argument(
        '--min_lr', type=float, default=1e-6, metavar='LR',
        help='lower lr bound for cyclic schedulers that hit 0',
    )
    parser.add_argument(
        '--warmup_epochs', type=int, default=40, metavar='N',
        help='epochs to warmup LR',
    )

    # runtime parameters
    parser.add_argument(
        "--device", default="cuda", type=str,
        help="device (Use cuda or cpu Default: cuda)",
    )
    parser.add_argument('--seed', default=0, type=int)

    # distributed training parameters
    parser.add_argument(
        '--dist_url', default='env://', type=str,
        help='url used to set up distributed training',
    )
    parser.add_argument(
        '--dist_backend', default='nccl', type=str,
        help='distributed backend',
    )
    parser.add_argument(
        '--world_size', default=1, type=int,
        help='number of distributed processes',
    )

    # checkpointing and output
    parser.add_argument(
        "--resume", type=str,
        help="ckpt's path to resume most recent training",
    )
    parser.add_argument(
        "--save_dir", type=str, default="./run",
        help="save path, eg, acc_loss, weights, tensorboard, and so on",
    )

    args = parser.parse_args()

    banner = f' Pre-training {args.model_name} on {args.dataset} ...'
    print(colorstr('green', banner))
    main(args=args)