import torch, timm
from torch import nn
from torch.utils.data import DataLoader
import torch.utils.data.distributed
import torch.backends.cudnn as cudnn
import torch.distributed as dist

import models_vit
from Dataset import galaxy_sdss
from timm.data.mixup import Mixup
from timm.models.layers import trunc_normal_
from timm.utils import ModelEmaV2
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy

from util.misc import colorstr, SaveCheckpoint
from util.lr_sched import adjust_learning_rate, param_groups_lrd
from util.pos_embed import interpolate_pos_embed

import numpy as np
from pathlib import Path
import os
import time
import argparse
from torch.utils.tensorboard import SummaryWriter


def setup_for_distributed(is_master):
    """
    Replace the global ``print`` so that only the master process prints.

    After this call, ``print(...)`` emits output when ``is_master`` is true
    or when the caller passes ``force=True``; otherwise the call is a no-op.
    The ``force`` keyword is consumed here and never forwarded.
    """
    import builtins

    original_print = builtins.print

    def rank_aware_print(*args, **kwargs):
        forced = kwargs.pop("force", False)
        if not (is_master or forced):
            return
        original_print(*args, **kwargs)

    builtins.print = rank_aware_print


def init_distributed_mode(args):
    """
    Populate ``args`` with rank information and start the process group.

    Three launch situations are recognised from the environment:

    * ``RANK`` and ``WORLD_SIZE`` present: rank, world size and GPU index are
      read from ``RANK``, ``WORLD_SIZE`` and ``LOCAL_RANK``.
    * ``SLURM_PROCID`` present: rank comes from ``SLURM_PROCID``, world size
      from ``SLURM_NTASKS`` when that variable exists (otherwise the value
      already on ``args`` is kept), and the GPU index is the rank modulo the
      number of visible CUDA devices.
    * neither: ``args.distributed`` is set to False and the function returns
      without touching the process group.

    Args:
        args: namespace that must carry ``dist_backend``, ``dist_url`` and
            ``world_size``; ``rank``, ``gpu`` and ``distributed`` are written.
    """
    if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ["WORLD_SIZE"])
        args.gpu = int(os.environ["LOCAL_RANK"])
        print("Use GPU: {} for training".format(args.gpu))
    elif "SLURM_PROCID" in os.environ:
        args.rank = int(os.environ["SLURM_PROCID"])
        # Without this the world size silently stays at the CLI default
        # while the rank can be > 0, which init_process_group rejects.
        if "SLURM_NTASKS" in os.environ:
            args.world_size = int(os.environ["SLURM_NTASKS"])
        args.gpu = args.rank % torch.cuda.device_count()
    else:
        print("Not using distributed mode")
        args.distributed = False
        # Other functions in this file test ``args.gpu is not None``; give
        # the attribute a value so those checks do not raise AttributeError.
        if not hasattr(args, "gpu"):
            args.gpu = None
        return

    args.distributed = True

    torch.cuda.set_device(args.gpu)

    dist.init_process_group(
        backend=args.dist_backend,
        init_method=args.dist_url,
        world_size=args.world_size,
        rank=args.rank,
    )
    dist.barrier()
    # From here on only rank 0 prints (unless force=True is passed).
    setup_for_distributed(args.rank == 0)


def create_dataset(args):
    """
    Build the train/validation loaders for distributed training.

    ``args.batch_size`` is treated as the global batch size and is
    overwritten in place with the per-process share
    (``batch_size / world_size`` truncated to an int).

    Returns:
        ``(train_loader, val_loader, num_class, train_sampler)``; the sampler
        is returned as well so the caller can call ``set_epoch`` on it.

    Raises:
        ValueError: if ``args.distributed`` is false.
    """
    train_dataset, val_dataset, num_class = galaxy_sdss()
    print(len(train_dataset), len(val_dataset), num_class)

    # Global batch size -> images handled by each process.
    args.batch_size = int(args.batch_size / args.world_size)

    summary = "epochs: {}, images per gpu: {}, absolute lr: {}".format(
        args.epochs, args.batch_size, args.lr)
    print(colorstr('green', summary))

    if not args.distributed:
        raise ValueError("Distributed init error.")

    sampler_cls = torch.utils.data.distributed.DistributedSampler
    train_sampler = sampler_cls(
        train_dataset, num_replicas=args.world_size, rank=args.rank, shuffle=True)
    val_sampler = sampler_cls(
        val_dataset, num_replicas=args.world_size, rank=args.rank, shuffle=False)

    # Settings shared by both loaders.
    loader_kwargs = dict(
        batch_size=args.batch_size,
        num_workers=args.workers,
        pin_memory=True,
    )
    train_loader = DataLoader(
        train_dataset, sampler=train_sampler, drop_last=True, **loader_kwargs)
    val_loader = DataLoader(
        val_dataset, sampler=val_sampler, drop_last=False, **loader_kwargs)

    return train_loader, val_loader, num_class, train_sampler


def create_model(args, num_class):
    """
    Instantiate the requested ViT and optionally load fine-tuning weights.

    Args:
        args: namespace providing ``model_name``, ``input_size``,
            ``drop_path``, ``global_pool`` and ``finetune``.
        num_class: number of output classes for the classification head.

    Returns:
        The constructed model. When ``args.finetune`` is set, the weights
        stored under the checkpoint's ``'model'`` key are loaded
        non-strictly and the head weight is re-initialised.
    """
    supported = ('vit_base_patch16', 'vit_large_patch16', 'vit_huge_patch14')
    assert args.model_name in supported

    builder = vars(models_vit)[args.model_name]
    model = builder(
        img_size=args.input_size,
        num_classes=num_class,
        drop_path_rate=args.drop_path,
        global_pool=args.global_pool,
    )

    if not args.finetune:
        return model

    state = torch.load(args.finetune, map_location='cpu')['model']

    # Adapt the checkpoint's position embedding to this model before loading.
    interpolate_pos_embed(model, state)

    load_result = model.load_state_dict(state, strict=False)
    print(load_result)

    if args.global_pool:
        # With global pooling, only the head and fc_norm may be absent
        # from the checkpoint.
        expected_missing = {'head.weight', 'head.bias', 'fc_norm.weight', 'fc_norm.bias'}
        assert set(load_result.missing_keys) == expected_missing

    # Fresh, small-magnitude initialisation for the task head.
    trunc_normal_(model.head.weight, std=2e-5)
    print(colorstr('green', "Load pre-trained checkpoint from: %s" % args.finetune))

    return model


def main(args):
    """
    Run the full fine-tuning job: set up distributed state, data, model,
    optimizer, optional resume, then alternate training and validation.

    Rank 0 additionally creates ``<save_dir>/weights`` and
    ``<save_dir>/logs``, appends the model and argument dump to
    ``<save_dir>/model.txt``, saves a checkpoint after every epoch and
    writes lr / loss / accuracy scalars to TensorBoard.

    Args:
        args: parsed command-line namespace (see the argument parser at the
            bottom of this file). ``rank``, ``gpu``, ``distributed``,
            ``batch_size`` and ``start_epoch`` may be rewritten here or in
            the helpers this function calls.
    """
    init_distributed_mode(args)
    print(args)

    cudnn.benchmark = True

    device = torch.device(args.device)

    # data loaders
    train_loader, val_loader, num_class, train_sampler = create_dataset(args=args)

    # Loss selection: mixup/cutmix -> soft-target CE, else label-smoothed CE
    # when smoothing is requested, else plain CE.
    mixup_fn = None
    mixup_active = args.mixup > 0. or args.cutmix > 0.
    if mixup_active:
        mixup_fn = Mixup(
            mixup_alpha=args.mixup,
            cutmix_alpha=args.cutmix,
            label_smoothing=args.smoothing,
            num_classes=num_class,
        )
    if mixup_fn is not None:
        criterion = SoftTargetCrossEntropy().to(device)
    elif args.smoothing > 0.:
        # Pass the configured value; the constructor default would otherwise
        # be used no matter what --smoothing says.
        criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing).to(device)
    else:
        criterion = nn.CrossEntropyLoss().to(device)

    # create model
    model = create_model(args=args, num_class=num_class)
    model.to(device)

    model_ema = None
    if args.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but before DDP wrapper
        model_ema = ModelEmaV2(model, decay=args.model_ema_decay)

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print('number of params (M): %.4f' % (n_parameters / 1.e6))

    # build optimizer with layer-wise lr decay (lrd)
    param_groups = param_groups_lrd(
        model_without_ddp,
        args.weight_decay,
        wd_head=args.wd_head,
        no_weight_decay_list=model_without_ddp.no_weight_decay(),
        layer_decay=args.layer_decay,
    )
    optimizer = torch.optim.AdamW(params=param_groups,
                                  lr=args.lr,
                                  betas=(0.9, 0.999))
    scaler = torch.cuda.amp.GradScaler()

    # The rank is fixed for the lifetime of the process; query it once.
    is_master = dist.get_rank() == 0

    # file path
    if is_master:
        # weights
        save_dir = Path(args.save_dir)
        weights = save_dir / 'weights'
        weights.mkdir(parents=True, exist_ok=True)
        last = weights / 'last'
        best = weights / 'best'

        # tensorboard
        logdir = save_dir / 'logs'
        logdir.mkdir(parents=True, exist_ok=True)
        summary_writer = SummaryWriter(logdir, flush_secs=120)

        # model
        model_file = str(save_dir / 'model.txt')
        with open(model_file, "a") as f:
            print(model_without_ddp, file=f)
            print(args, file=f)

    if args.resume:
        # Decide where checkpoint tensors should land. With no GPU index the
        # checkpoint's own placement is kept; with a GPU index but no CUDA
        # runtime fall back to CPU so ``checkpoint`` is always defined.
        map_location = None
        if args.gpu is not None:
            if torch.cuda.is_available():
                map_location = 'cuda:{}'.format(args.gpu)
            else:
                map_location = 'cpu'
        checkpoint = torch.load(args.resume, map_location=map_location)

        args.start_epoch = checkpoint['epoch']
        model_without_ddp.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        # Keep best_acc a plain float so it compares cleanly with the float
        # returned by validate() regardless of which device saved it.
        best_acc = float(checkpoint['best_acc'])

        if model_ema is not None:
            model_ema.load_state_dict(checkpoint["model_ema"])

        print(colorstr('green', 'Resuming training from {} epoch'.format(args.start_epoch)))
    else:
        best_acc = 0

    print("Start training")
    for epoch in range(args.start_epoch, args.epochs):
        print("Epoch {}/{}".format(epoch + 1, args.epochs))
        if args.distributed:
            # Pass the epoch to the distributed sampler before iterating.
            train_sampler.set_epoch(epoch)

        train_epoch_loss = train(model=model,
                                 train_loader=train_loader,
                                 optimizer=optimizer,
                                 criterion=criterion,
                                 mixup_fn=mixup_fn,
                                 scaler=scaler,
                                 args=args,
                                 epoch=epoch,
                                 model_ema=model_ema)

        val_epoch_loss, val_acc1 = validate(model=model,
                                            val_loader=val_loader,
                                            args=args)

        s = "Train Loss: {:.3f}, Test Loss: {:.3f}, Test Acc: {:.3f}, lr: {:.1e}".format(
            train_epoch_loss, val_epoch_loss, val_acc1, optimizer.param_groups[-1]['lr'])
        print(colorstr('green', s))

        if is_master:
            # save model
            is_best = val_acc1 > best_acc
            best_acc = max(best_acc, val_acc1)
            state = {
                'epoch': epoch + 1,
                'model': model_without_ddp.state_dict(),
                'optimizer': optimizer.state_dict(),
                'best_acc': best_acc,
            }
            if model_ema is not None:
                state["model_ema"] = model_ema.state_dict()

            last_path = last / 'epoch_{}_loss_{:.4f}_acc_{:.3f}'.format(
                epoch + 1, val_epoch_loss, val_acc1)
            best_path = best / 'epoch_{}_acc_{:.4f}'.format(
                epoch + 1, best_acc)

            SaveCheckpoint(state, last, last_path, best, best_path, is_best)

            summary_writer.add_scalar('lr', optimizer.param_groups[-1]['lr'], epoch)
            summary_writer.add_scalar('train_loss', train_epoch_loss, epoch)
            summary_writer.add_scalar('val_loss', val_epoch_loss, epoch)
            summary_writer.add_scalar('val_acc', val_acc1, epoch)

    if is_master:
        summary_writer.close()


def train(model, train_loader, optimizer, criterion, mixup_fn, scaler, args, epoch, model_ema):
    """
    Train ``model`` for one epoch and return the mean training loss.

    The learning rate is adjusted on every step using the fractional epoch
    ``step / steps_in_epoch + epoch``. When ``args.use_amp`` is set the
    forward pass runs under autocast and the step goes through ``scaler``.
    A one-line progress readout is rewritten in place after each step.

    Returns:
        The running average of the per-batch loss, weighted by batch size.
    """
    loss_meter = AverageMeter()

    model.train()
    num_steps = len(train_loader)
    for step, (images, labels) in enumerate(train_loader):
        # Synchronise so the step timer below measures completed GPU work.
        torch.cuda.synchronize()
        tic = time.time()

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, step / num_steps + epoch, args)

        if args.gpu is not None and torch.cuda.is_available():
            images = images.cuda(args.gpu, non_blocking=True)
            labels = labels.cuda(args.gpu, non_blocking=True)

        if mixup_fn is not None:
            images, labels = mixup_fn(images, labels)

        if not args.use_amp:
            # Full-precision forward/backward.
            loss = criterion(model(images), labels)
            loss.backward()
            optimizer.step()
        else:
            # Mixed precision: scaled backward, then step and scale update.
            with torch.cuda.amp.autocast():
                loss = criterion(model(images), labels)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

        loss_meter.update(loss.item(), images.size(0))

        torch.cuda.synchronize()
        if model_ema is not None:
            model_ema.update(model)

        stamp = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time()))
        elapsed_ms = 1000 * (time.time()-tic)
        progress = '\r{} [{}/{}]'.format(stamp, step+1, num_steps)
        stats = ' - {:.2f}ms/step - train_loss: {:.3f}'.format(elapsed_ms, loss_meter.val)
        print(progress+stats, end='', flush=True)

    # Terminate the carriage-return progress line.
    print()
    return loss_meter.avg


def validate(model, val_loader, args):
    """
    Evaluate ``model`` on ``val_loader`` with plain cross-entropy.

    In distributed mode each batch's loss and top-1 accuracy are averaged
    across processes before being accumulated.

    Returns:
        ``(mean_loss, mean_top1)`` where both are running averages weighted
        by the local batch size and top-1 is a fraction, not a percentage.
    """
    loss_meter = AverageMeter()
    top1_meter = AverageMeter()

    criterion = nn.CrossEntropyLoss()
    model.eval()
    with torch.no_grad():
        for images, labels in val_loader:
            if args.gpu is not None and torch.cuda.is_available():
                images = images.cuda(args.gpu, non_blocking=True)
                labels = labels.cuda(args.gpu, non_blocking=True)

            if not args.use_amp:
                logits = model(images)
                loss = criterion(logits, labels)
            else:
                with torch.cuda.amp.autocast():
                    logits = model(images)
                    loss = criterion(logits, labels)

            # accuracy() returns one 1-element tensor per requested k.
            top1 = accuracy(logits, labels, topk=(1, ))[0]

            if args.distributed:
                # Average the batch statistics across processes.
                loss = reduce_tensor(loss, args)
                top1 = reduce_tensor(top1, args)

            n_samples = images.size(0)
            loss_meter.update(loss.item(), n_samples)
            top1_meter.update(top1.item(), n_samples)

    return loss_meter.avg, top1_meter.avg


class AverageMeter(object):
    """Keeps the most recent value together with a weighted running mean."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear the last value, the mean, the weighted sum and the count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and recompute the mean."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count


def reduce_tensor(tensor, args):
    """
    Return the mean of ``tensor`` over all processes in the default group.

    The input is cloned first, so the caller's tensor is left unmodified.
    ``args`` is accepted for call-site symmetry but is not read.
    """
    averaged = tensor.clone()
    dist.all_reduce(averaged, op=dist.ReduceOp.SUM)
    # Sum across ranks -> mean across ranks, in place on the clone.
    return averaged.div_(dist.get_world_size())


def accuracy(output, target, topk=(1,)):
    """
    Compute top-k accuracy, as a fraction in [0, 1], for each k in ``topk``.

    Args:
        output: score tensor; the top-k search runs along dimension 1.
        target: tensor of class indices, flattened to one entry per sample.
        topk: iterable of k values to evaluate.

    Returns:
        A list with one 1-element float tensor per entry of ``topk``, in the
        same order. Any k larger than ``output.size(1)`` is treated as
        ``output.size(1)`` instead of raising inside ``Tensor.topk``.
    """
    with torch.no_grad():
        # Never ask topk for more candidates than there are classes.
        maxk = min(max(topk), output.size(1))
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        # Rows are ranks after the transpose: row i holds every sample's
        # (i+1)-th best guess.
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        scale = 1.0 / batch_size
        return [
            correct[:min(k, maxk)].reshape(-1).float().sum(0, keepdim=True).mul_(scale)
            for k in topk
        ]


def testmodel(model, test_data, args):
    """
    Measure top-1 and top-5 accuracy of ``model`` on ``test_data``.

    The dataset is iterated in order (no shuffling) on the current CUDA
    device using ``args.batch_size`` and ``args.workers``.

    Returns:
        ``(top1, top5)`` running averages weighted by batch size. The
        values accumulated are the tensors produced by ``accuracy``.
    """
    top1_meter = AverageMeter()
    top5_meter = AverageMeter()

    model.eval()

    loader = DataLoader(
        test_data,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True,
    )

    with torch.no_grad():
        for images, labels in loader:
            images = images.cuda()
            labels = labels.cuda()

            scores = model(images)
            top1, top5 = accuracy(scores, labels, topk=(1, 5))

            n_samples = images.size(0)
            top1_meter.update(top1[0], n_samples)
            top5_meter.update(top5[0], n_samples)

    return top1_meter.avg, top5_meter.avg

if __name__ == '__main__':
    # Command-line entry point: parse the options below and hand them to main().
    parser = argparse.ArgumentParser(description='PyTorch Training for visual tuning.')
    # model parameters
    parser.add_argument("--model_name", type=str, default="vit_large_patch16", help="model architecture")
    parser.add_argument('--input_size', default=224, type=int, help='images input size')
    parser.add_argument("--drop_path", type=float, default=0.1, help='Drop path rate')
    parser.add_argument('--model_ema', action='store_true')
    parser.add_argument('--model_ema_decay', type=float, default=0.9999)

    # Finetuning parameters
    parser.add_argument('--finetune', type=str, default='', help='finetune from pretrained checkpoint')
    # NOTE: the default is forced to True right below, so passing
    # --global_pool does not change anything.
    parser.add_argument('--global_pool', action='store_true')
    parser.set_defaults(global_pool=True)
    parser.add_argument('--use_amp', action='store_true')

    # dataset, schedule and optimizer parameters
    parser.add_argument("--dataset", type=str, default='galaxy')
    parser.add_argument("--epochs", type=int, default=50)
    parser.add_argument("--start_epoch", default=0, type=int, help="start epoch")
    parser.add_argument("--batch_size", type=int, default=64, help="batch size")
    parser.add_argument('--workers', default=8, type=int, help='number of data loading workers')
    parser.add_argument("--weight_decay", type=float, default=0.05)
    parser.add_argument("--wd_head", type=float, default=0.1)
    parser.add_argument('--lr', type=float, default=1e-3, metavar='LR', help='learning rate (absolute lr)')
    parser.add_argument('--layer_decay', type=float, default=0.75, help='layer-wise lr decay from ELECTRA/BEiT')
    parser.add_argument('--min_lr', type=float, default=1e-6, metavar='LR', help='lower lr bound for cyclic schedulers that hit 0')
    parser.add_argument('--warmup_epochs', type=int, default=5, metavar='N', help='epochs to warmup LR')

    # Augmentation parameters
    parser.add_argument('--color_jitter', type=float, default=None, help='Color jitter factor (enabled only when not using Auto/RandAug)')
    parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', help='Use AutoAugment policy. "v0" or "original". (default: rand-m9-mstd0.5-inc1)')
    parser.add_argument('--smoothing', type=float, default=0.1, help='Label smoothing (default: 0.1)')

    # * Random Erase params
    parser.add_argument('--reprob', type=float, default=0.25, help='Random erase prob (default: 0.25)')
    parser.add_argument('--remode', type=str, default='pixel', help='Random erase mode (default: "pixel")')
    parser.add_argument('--recount', type=int, default=1, help='Random erase count (default: 1)')

    # * Mixup params
    parser.add_argument('--mixup', type=float, default=0.8, help='mixup alpha, mixup enabled if > 0.')
    parser.add_argument('--cutmix', type=float, default=1.0, help='cutmix alpha, cutmix enabled if > 0.')

    parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu Default: cuda)")

    # distributed training parameters
    parser.add_argument('--dist_url', default='env://', type=str, help='url used to set up distributed training')
    parser.add_argument('--dist_backend', default='nccl', type=str, help='distributed backend')
    parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')

    parser.add_argument("--resume", type=str, help="ckpt's path to resume most recent training")
    parser.add_argument("--save_dir", type=str, default="./run", help="save path, eg, acc_loss, weights, tensorboard, and so on")
    args = parser.parse_args()

    print(colorstr('green', 'Fine-tuning ' + args.model_name + ' on ' + args.dataset + ' ...'))
    main(args=args)