import os
os.environ['CUDA_VISIBLE_DEVICES'] = '4,5,6,7'
import torch, timm
import torchvision
from torch import nn
from torch.utils.data import DataLoader
import torch.nn.parallel
import torch.utils.data.distributed
import torch.multiprocessing as mp
import torch.backends.cudnn as cudnn
import torch.distributed as dist

from Models.vit import vit_base_patch16
from timm.models.layers import trunc_normal_
from Dataset import iNat2018
from torchsummaryX import summary
from collections import OrderedDict

import numpy as np
from pathlib import Path
import os
import time
import json
import random
import argparse
import warnings
from torch.utils.tensorboard import SummaryWriter
import pdb

import torch_pruning as tp
from functools import partial

def main(args, model):
    """Launch training in-process or as one spawned process per GPU.

    Args:
        args: Parsed command-line namespace. ``args.distributed`` is set
            here, and ``args.world_size`` is rescaled from a node count to a
            process count when ``args.multiprocessing_distributed`` is set.
        model: Model instance forwarded unchanged to ``main_worker``.

    NOTE(review): ``colorstr`` is neither defined nor imported in the visible
    part of this file - confirm where it is expected to come from.
    """
    args.distributed = args.world_size > 1 or args.multiprocessing_distributed

    # Always bind ngpus_per_node. Without CUDA we run a single worker; using 1
    # (not 0) keeps the ``rank % ngpus_per_node`` checks in main_worker/train
    # well defined.
    if torch.cuda.is_available():
        ngpus_per_node = torch.cuda.device_count()
    else:
        ngpus_per_node = 1

    if args.multiprocessing_distributed:
        # Since we have ngpus_per_node processes per node, the total world_size
        # needs to be adjusted accordingly
        args.world_size = ngpus_per_node * args.world_size

        print(colorstr('green', "Multiprocess distributed training, gpus:{}, total batch size:{}, epoch:{}, lr:{}".format(ngpus_per_node, args.batch_size, args.epochs, args.lr)))

        # Use torch.multiprocessing.spawn to launch distributed processes: the
        # main_worker process function
        mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args, model))
    else:
        # Simply call main_worker function
        main_worker(args.gpu, ngpus_per_node, args, model)


def main_worker(gpu, ngpus_per_node, args, model):
    """Run the complete train/validate loop inside the current process.

    Args:
        gpu: GPU index this process should use, or ``None``. When launched
            through ``mp.spawn`` this is the index of the spawned process.
        ngpus_per_node: Number of worker processes per node; used to derive
            the global rank and to pick the process that logs and saves.
        args: Parsed command-line namespace. ``gpu``, ``local_rank``,
            ``batch_size`` and ``workers`` may be overwritten in place.
        model: Model to train. It is moved to the selected GPU and, for
            distributed runs, wrapped in ``DistributedDataParallel``.

    Only processes with ``args.local_rank % ngpus_per_node == 0`` create the
    output directories, write checkpoints, TensorBoard logs and ``.npy``
    metric histories.

    NOTE(review): ``colorstr``, ``param_groups_weight_decay`` and
    ``Save_Checkpoint`` are neither defined nor imported in the visible part
    of this file - confirm where they are expected to come from.
    """
    args.gpu = gpu
    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            args.local_rank = args.local_rank * ngpus_per_node + gpu
        # DistributedDataParallel and DistributedSampler below both rely on
        # the default process group, so it must exist for every distributed
        # launch, not only for the mp.spawn one.
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                world_size=args.world_size, rank=args.local_rank)

    # args.local_rank is final from here on: the first process of each node
    # is the one that prints, logs and saves.
    is_main_process = args.local_rank % ngpus_per_node == 0
    use_cuda = torch.cuda.is_available()

    if is_main_process:
        # weights
        save_dir = Path(args.save_dir)
        weights = save_dir / 'weights'
        weights.mkdir(parents=True, exist_ok=True)
        last = weights / 'last'
        best = weights / 'best'

        # acc,loss
        acc_loss = save_dir / 'acc_loss'
        acc_loss.mkdir(parents=True, exist_ok=True)
        train_acc_top1_savepath = acc_loss / 'train_acc_top1.npy'
        train_acc_top5_savepath = acc_loss / 'train_acc_top5.npy'
        train_loss_savepath = acc_loss / 'train_loss.npy'
        val_acc_top1_savepath = acc_loss / 'val_acc_top1.npy'
        val_acc_top5_savepath = acc_loss / 'val_acc_top5.npy'
        val_loss_savepath = acc_loss / 'val_loss.npy'

        # tensorboard
        logdir = save_dir / 'logs'
        logdir.mkdir(parents=True, exist_ok=True)
        summary_writer = SummaryWriter(logdir, flush_secs=120)

    # dataset
    train_dataset, val_dataset, num_class = iNat2018()

    if args.distributed:
        # For multiprocessing distributed, DistributedDataParallel constructor
        # should always set the single device scope, otherwise,
        # DistributedDataParallel will use all available devices.
        if use_cuda:
            if args.gpu is not None:
                torch.cuda.set_device(args.gpu)
                model.cuda(args.gpu)
                # When using a single GPU per process and per
                # DistributedDataParallel, we need to divide the batch size
                # ourselves based on the total number of GPUs of the current node.
                args.batch_size = int(args.batch_size / ngpus_per_node)
                args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
                model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) # , find_unused_parameters=True
            else:
                model.cuda()
                # DistributedDataParallel will divide and allocate batch_size to all
                # available GPUs if device_ids are not set
                model = torch.nn.parallel.DistributedDataParallel(model)
    elif args.gpu is not None and use_cuda:
        # Single-process run on an explicit GPU: train()/validate() move every
        # batch to this GPU, so the model has to live there as well.
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)

    # loss
    if use_cuda:
        # Compare against None so that GPU index 0 is not mistaken for "unset".
        if args.gpu is not None:
            device = torch.device('cuda:{}'.format(args.gpu))
        else:
            device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    # cls loss
    criterion = nn.CrossEntropyLoss().to(device)
    params = param_groups_weight_decay(model=model, weight_decay=args.weight_decay)
    optimizer = torch.optim.SGD(params=params,
                                lr=args.lr,
                                momentum=args.momentum,
                                nesterov=True,
                                weight_decay=args.weight_decay)

    if args.resume:
        if not use_cuda:
            # Allow checkpoints written on a GPU to be opened on a CPU-only host.
            loc = 'cpu'
        elif args.gpu is not None:
            # Map model to be loaded to specified single gpu.
            loc = 'cuda:{}'.format(args.gpu)
        else:
            loc = None
        checkpoint = torch.load(args.resume, map_location=loc)

        # 'epoch' is stored as "last finished epoch + 1", i.e. the next one to run.
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        # validate() returns plain floats, so keep best_acc a float as well
        # (float() also accepts a single-element tensor from older checkpoints).
        best_acc = float(checkpoint['best_acc'])

        train_acc_top1 = checkpoint['train_acc_top1']
        train_acc_top5 = checkpoint['train_acc_top5']
        train_loss = checkpoint['train_loss']
        test_acc_top1 = checkpoint['test_acc_top1']
        test_acc_top5 = checkpoint['test_acc_top5']
        test_loss = checkpoint['test_loss']
        if is_main_process:
            print(colorstr('green', 'Resuming training from {} epoch'.format(start_epoch)))
    else:
        start_epoch = 0
        best_acc = 0
        train_acc_top1 = []
        train_acc_top5 = []
        train_loss = []
        test_acc_top1 = []
        test_acc_top5 = []
        test_loss = []

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
        val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, shuffle=False)
    else:
        train_sampler = None
        val_sampler = None

    # DataLoader forbids shuffle=True together with a sampler, hence the
    # conditional shuffle flag.
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
                              num_workers=args.workers, pin_memory=True, sampler=train_sampler, drop_last=True)

    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False,
                            num_workers=args.workers, pin_memory=True, sampler=val_sampler, drop_last=False)

    for epoch in range(start_epoch, args.epochs):
        if is_main_process:
            print("Epoch {}/{}".format(epoch + 1, args.epochs))
        if args.distributed:
            # Gives the sampler a different shuffling order every epoch.
            train_sampler.set_epoch(epoch)

        train_epoch_loss, train_acc1, train_acc5 = train(model=model,
                                                         train_loader=train_loader,
                                                         optimizer=optimizer,
                                                         criterion=criterion,
                                                         ngpus_per_node=ngpus_per_node,
                                                         args=args,
                                                         epoch=epoch)

        val_epoch_loss, val_acc1, val_acc5 = validate(model=model,
                                                      val_loader=val_loader,
                                                      criterion=criterion,
                                                      args=args)

        s = "Train Loss: {:.3f}, Train Acc Top1: {:.3f}, Train Acc Top5: {:.3f}, Test Loss: {:.3f}, Test Acc Top1: {:.3f}, Test Acc Top5: {:.3f}, lr: {:.5f}".format(
            train_epoch_loss, train_acc1, train_acc5, val_epoch_loss, val_acc1, val_acc5, optimizer.param_groups[0]['lr'])
        if is_main_process:
            print(colorstr('green', s))

            # save acc,loss
            train_loss.append(train_epoch_loss)
            train_acc_top1.append(train_acc1)
            train_acc_top5.append(train_acc5)
            test_loss.append(val_epoch_loss)
            test_acc_top1.append(val_acc1)
            test_acc_top5.append(val_acc5)

            # save model
            is_best = val_acc1 > best_acc
            best_acc = max(best_acc, val_acc1)
            state = {
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'best_acc': best_acc,
                'train_acc_top1': train_acc_top1,
                'train_acc_top5': train_acc_top5,
                'train_loss': train_loss,
                'test_acc_top1': test_acc_top1,
                'test_acc_top5': test_acc_top5,
                'test_loss': test_loss,
            }

            last_path = last / 'epoch_{}_loss_{:.3f}_acc_{:.3f}'.format(
                epoch + 1, val_epoch_loss, val_acc1)
            best_path = best / 'epoch_{}_acc_{:.4f}'.format(
                epoch + 1, best_acc)
            Save_Checkpoint(state, last, last_path, best, best_path, is_best)

            # Log one grid of training images, once, during the second epoch.
            if epoch == 1:
                images, labels = next(iter(train_loader))
                img_grid = torchvision.utils.make_grid(images)
                summary_writer.add_image('iNat2018 Image', img_grid)

            summary_writer.add_scalar('lr', optimizer.param_groups[0]['lr'], epoch)
            summary_writer.add_scalar('train_loss', train_epoch_loss, epoch)
            summary_writer.add_scalar('train_acc_top1', train_acc1, epoch)
            summary_writer.add_scalar('train_acc_top5', train_acc5, epoch)
            summary_writer.add_scalar('val_loss', val_epoch_loss, epoch)
            summary_writer.add_scalar('val_acc_top1', val_acc1, epoch)
            summary_writer.add_scalar('val_acc_top5', val_acc5, epoch)

    if is_main_process:
        summary_writer.close()
        # NOTE(review): the histories are written only when the files are
        # missing, so a resumed run leaves existing .npy files untouched -
        # confirm that this is intended.
        if not os.path.exists(train_acc_top1_savepath) or not os.path.exists(train_loss_savepath):
            np.save(train_acc_top1_savepath, train_acc_top1)
            np.save(train_acc_top5_savepath, train_acc_top5)
            np.save(train_loss_savepath, train_loss)
            np.save(val_acc_top1_savepath, test_acc_top1)
            np.save(val_acc_top5_savepath, test_acc_top5)
            np.save(val_loss_savepath, test_loss)


def train(model, train_loader, optimizer, criterion, ngpus_per_node, args, epoch):
    """Train ``model`` for one epoch and report the epoch-level metrics.

    Args:
        model: Network to optimise; switched to training mode here.
        train_loader: Iterable yielding ``(images, labels)`` batches.
        optimizer: Optimizer whose learning rate is passed to
            ``adjust_learning_rate`` before every step.
        criterion: Loss applied to ``(logits, labels)``.
        ngpus_per_node: Used with ``args.local_rank`` to decide whether this
            process prints progress.
        args: Namespace providing ``gpu`` and ``local_rank``.
        epoch: Zero-based index of the current epoch.

    Returns:
        Tuple ``(loss, top1, top5)`` holding the batch-size-weighted averages
        accumulated by this process over the epoch.
    """
    loss_meter = AverageMeter()
    top1_meter = AverageMeter()
    top5_meter = AverageMeter()

    model.train()
    num_batches = len(train_loader)
    for batch_idx, (images, labels) in enumerate(train_loader):
        tic = time.time()

        # The schedule is driven by a fractional epoch so that the learning
        # rate can change on every step, not just once per epoch.
        fractional_epoch = batch_idx / num_batches + epoch
        adjust_learning_rate(optimizer, fractional_epoch, args)

        if args.gpu is not None and torch.cuda.is_available():
            images = images.cuda(args.gpu, non_blocking=True)
            labels = labels.cuda(args.gpu, non_blocking=True)

        # forward pass and classification loss
        logits = model(images)
        loss = criterion(logits, labels)

        # bookkeeping: running loss and top-1 / top-5 accuracy
        top1, top5 = accuracy(logits, labels, topk=(1, 5))
        num_samples = images.size(0)
        loss_meter.update(loss.item(), num_samples)
        top1_meter.update(top1[0].item(), num_samples)
        top5_meter.update(top5[0].item(), num_samples)

        # backward pass and parameter update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        timestamp = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time()))
        head = '\r{} [{}/{}]'.format(timestamp, batch_idx + 1, num_batches)
        tail = ' - {:.2f}ms/step - train_loss: {:.3f} - train_acc_top1: {:.3f} - train_acc_top5: {:.3f}'.format(
            1000 * (time.time() - tic), loss_meter.val, top1_meter.val, top5_meter.val)
        if args.local_rank % ngpus_per_node == 0:
            # '\r' plus end='' keeps the progress report on a single line.
            print(head + tail, end='', flush=True)

    if args.local_rank % ngpus_per_node == 0:
        # Terminate the in-place progress line.
        print()
    return loss_meter.avg, top1_meter.avg, top5_meter.avg


def validate(model, val_loader, criterion, args):
    """Evaluate ``model`` on ``val_loader`` without tracking gradients.

    Args:
        model: Network to evaluate; switched to evaluation mode here.
        val_loader: Iterable yielding ``(images, labels)`` batches.
        criterion: Loss applied to ``(logits, labels)``.
        args: Namespace providing ``gpu`` and ``distributed`` (plus whatever
            ``reduce_tensor`` reads).

    Returns:
        Tuple ``(loss, top1, top5)`` of batch-size-weighted averages. When
        ``args.distributed`` is set, each batch value is first averaged
        across processes with ``reduce_tensor``.
    """
    loss_meter = AverageMeter()
    top1_meter = AverageMeter()
    top5_meter = AverageMeter()

    model.eval()
    with torch.no_grad():
        for images, labels in val_loader:
            if args.gpu is not None and torch.cuda.is_available():
                images = images.cuda(args.gpu, non_blocking=True)
                labels = labels.cuda(args.gpu, non_blocking=True)

            # forward pass, loss and per-batch accuracy
            logits = model(images)
            loss = criterion(logits, labels)
            top1, top5 = accuracy(logits, labels, topk=(1, 5))

            if args.distributed:
                # Same reduction order on every process: loss, top-1, top-5.
                loss, top1, top5 = (
                    reduce_tensor(value, args) for value in (loss, top1, top5)
                )

            num_samples = images.size(0)
            loss_meter.update(loss.item(), num_samples)
            top1_meter.update(top1[0].item(), num_samples)
            top5_meter.update(top5[0].item(), num_samples)

    return loss_meter.avg, top1_meter.avg, top5_meter.avg


class AverageMeter(object):
    """Keep the most recent value of a metric and its weighted running mean.

    Attributes are ``val`` (last value passed to ``update``), ``sum``
    (weighted total), ``count`` (total weight) and ``avg`` (``sum / count``).
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Set every statistic back to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and recompute the mean."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count


def reduce_tensor(tensor, args):
    """Return the mean of ``tensor`` over all processes.

    The input is left untouched: a clone is summed across the process group
    and then divided by ``args.world_size``.
    """
    averaged = tensor.clone()
    dist.all_reduce(averaged, op=dist.ReduceOp.SUM)
    averaged.div_(args.world_size)
    return averaged


def accuracy(output, target, topk=(1,)):
    """Compute top-k accuracy for every ``k`` in ``topk``.

    Args:
        output: 2-D score tensor; classes are ranked along dimension 1.
        target: Ground-truth class index per sample.
        topk: Values of ``k`` to evaluate.

    Returns:
        List with one single-element tensor per ``k`` holding the fraction
        (0-1, not a percentage) of samples whose target is among the ``k``
        highest-scoring classes.
    """
    with torch.no_grad():
        largest_k = max(topk)
        num_samples = target.size(0)

        # Indices of the best `largest_k` classes per sample, best first,
        # transposed so that row r holds every sample's rank-r prediction.
        _, ranked = output.topk(largest_k, 1, True, True)
        ranked = ranked.t()
        hits = ranked.eq(target.view(1, -1).expand_as(ranked))

        # A sample counts for top-k if any of its first k rows is a hit.
        return [
            hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(1.0 / num_samples)
            for k in topk
        ]


def testmodel(model, test_data, args):
    """Evaluate ``model`` on ``test_data`` and return its accuracy.

    Args:
        model: Network to evaluate; switched to evaluation mode here. Batches
            are moved to the device of its first parameter.
        test_data: Dataset yielding ``(image, label)`` pairs.
        args: Namespace providing ``batch_size`` and ``workers``.

    Returns:
        Tuple ``(top1, top5)`` of batch-size-weighted average accuracies as
        Python floats in the 0-1 range.
    """
    val_acc1 = AverageMeter()
    val_acc5 = AverageMeter()

    # model to evaluate mode
    model.eval()

    # Follow the model's device instead of assuming CUDA, so the function
    # also works for a model that was left on the CPU.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    test_dataloader = DataLoader(test_data, batch_size=args.batch_size, shuffle=False,
                                 num_workers=args.workers, pin_memory=True)

    with torch.no_grad():
        for images, labels in test_dataloader:
            images = images.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            # The model output is used directly as logits, exactly as in
            # train() and validate(); unpacking it as a 2-tuple would split
            # the tensor along its batch dimension.
            logits = model(images)

            # measure accuracy
            acc1, acc5 = accuracy(logits, labels, topk=(1, 5))

            # Store plain floats (as validate() does) so the averages are
            # ordinary numbers rather than device tensors.
            val_acc1.update(acc1[0].item(), images.size(0))
            val_acc5.update(acc5[0].item(), images.size(0))

    return val_acc1.avg, val_acc5.avg


def prune_load_weights(model, pretrain_ckpt, prune_rule, prune_rate, num_classes=None):
    """Load matching pretrained weights into ``model``, then prune it in place.

    Args:
        model: Network to initialise and prune.
        pretrain_ckpt: Mapping from parameter name to tensor. Entries whose
            key does not exist in ``model.state_dict()`` are ignored.
        prune_rule: Importance criterion: ``'random'``, ``'L1norm'`` or
            ``'DepGraph'``.
        prune_rate: Forwarded to the pruner as ``ch_sparsity``.
        num_classes: ``out_features`` of the classifier layer that must not
            be pruned. Defaults to the module-level ``num_class`` set by the
            script entry point.

    Raises:
        ValueError: If ``prune_rule`` is not one of the supported names.
    """
    supported_rules = ('random', 'L1norm', 'DepGraph')
    # Fail fast, before the model is modified, instead of hitting an
    # UnboundLocalError further down for an unknown rule.
    if prune_rule not in supported_rules:
        raise ValueError(
            "Unsupported prune_rule {!r}; expected one of {}".format(prune_rule, supported_rules))

    # load weights: start from the model's own state and overwrite every
    # entry that the checkpoint also provides
    model_dict = model.state_dict()
    state_dict = {k: v for k, v in pretrain_ckpt.items() if k in model_dict}
    model_dict.update(state_dict)
    model.load_state_dict(model_dict)

    # prune
    if prune_rule == 'random':
        imp = tp.importance.RandomImportance()
        pruner_entry = partial(tp.pruner.MagnitudePruner, global_pruning=False)
    elif prune_rule == 'L1norm':
        imp = tp.importance.MagnitudeImportance(p=1, normalizer=None, group_reduction="first")
        pruner_entry = partial(tp.pruner.MagnitudePruner, global_pruning=False)
    else:  # 'DepGraph'
        imp = tp.importance.GroupNormImportance(p=2)
        pruner_entry = partial(tp.pruner.GroupNormPruner, global_pruning=False)

    if num_classes is None:
        # Backward-compatible default: the global populated under
        # ``if __name__ == "__main__"``.
        num_classes = num_class

    ignored_layers = []
    for m in model.modules():
        # DO NOT prune the final classifier!
        if isinstance(m, torch.nn.Linear) and m.out_features == num_classes:
            ignored_layers.append(m)

    pruner = pruner_entry(
        model,
        example_inputs=torch.randn(1, 3, 224, 224),
        importance=imp,
        iterative_steps=1,
        ch_sparsity=prune_rate,
        ignored_layers=ignored_layers,
        round_to=12,
    )
    pruner.step()


if __name__ == "__main__":
    # Script entry point: parse the options, build the ViT, then either
    # evaluate a fine-tuned checkpoint or load pretrained weights and prune,
    # and finally hand the model to main() for training.
    parser = argparse.ArgumentParser(description='PyTorch Training with Distributed Data Parallel.')
    # model train
    parser.add_argument("--model_weights", type=str, default="", help="model weights path")
    parser.add_argument("--dataset", type=str, default='iNat2018')
    parser.add_argument("--epochs", type=int, default=90)
    parser.add_argument("--batch_size", type=int, default=512, help="total batch size")
    parser.add_argument('--workers', default=16, type=int, help='number of data loading workers')
    parser.add_argument("--lr", type=float, default=0.02)
    parser.add_argument('--min_lr', type=float, default=0, metavar='LR', help='lower lr bound for cyclic schedulers that hit 0')
    parser.add_argument('--warmup_epochs', type=int, default=5, metavar='N', help='epochs to warmup LR')
    parser.add_argument('--momentum', default=0.9, type=float, help='momentum')
    parser.add_argument("--weight_decay", type=float, default=1e-4)

    # prune
    parser.add_argument("--setting", type=str, default='B', choices=['B', 'D'])
    parser.add_argument("--pretrain_transfer", type=str, default='./run/ViT/iNat2018/pretrain-transfer/weights/best/epoch_82_acc_0.7164/ckpt.pth')

    parser.add_argument("--prune_rate", type=float, default=0.5, help='prune rate')
    parser.add_argument("--prune_rule", type=str, default='random', help='prune methods')

    # DDP
    parser.add_argument('--world_size', default=1, type=int, help='number of nodes for distributed training')
    parser.add_argument('--local_rank', default=0, type=int, help='node rank for distributed training')
    parser.add_argument('--dist_url', default='tcp://localhost:10000', type=str, help='url used to set up distributed training')
    parser.add_argument('--dist_backend', default='nccl', type=str, help='distributed backend')
    parser.add_argument('--gpu', default=None, type=int, help='GPU id to use.')
    parser.add_argument('--seed', default=None, type=int, help='seed for initializing training.')
    parser.add_argument('--multiprocessing-distributed', action='store_true', 
                        help='Use multi-processing distributed training to launch '
                             'N processes per node, which has N GPUs. This is the '
                             'fastest way to use PyTorch for either single node or '
                             'multi node data parallel training')

    parser.add_argument("--resume", type=str, help="ckpt's path to resume most recent training")
    parser.add_argument("--save_dir", type=str, default="./run", help="save path, eg, acc_loss, weights, tensorboard, and so on")
    args = parser.parse_args()

    if args.seed is not None:
        random.seed(args.seed)
        # Seed numpy as well so every RNG the file imports is covered.
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        cudnn.benchmark = False
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    if args.gpu is not None:
        warnings.warn('You have chosen a specific GPU. This will completely '
                      'disable data parallelism.')

    # num_class stays a module-level name; prune_load_weights() reads it to
    # recognise the classifier layer.
    train_data, test_data, num_class = iNat2018()

    model = vit_base_patch16(num_classes=num_class)
    # summary(model, torch.ones(1, 3, 224, 224))

    # Prefix carried by state-dict keys saved from the wrapped model in
    # main_worker(); strip it only when it is really there.
    ddp_prefix = 'module.'

    # Stays None when no source of pretrained weights could be resolved.
    pretrain_ckpt = None
    if args.setting == 'B':
        # pre-trained model on source domain
        pretrain_model = timm.create_model('vit_base_patch16_224_in21k', pretrained=True)
        in_features = pretrain_model.head.in_features
        # Replace the source-domain head with a freshly initialised one sized
        # for the target dataset.
        pretrain_model.head = torch.nn.Linear(in_features=in_features, out_features=num_class)
        # Dense baseline cost, referenced by the (disabled) FLOPs report below.
        base_ops, base_params = tp.utils.count_ops_and_params(pretrain_model, example_inputs=torch.randn(1, 3, 224, 224))
        trunc_normal_(pretrain_model.head.weight, std=.02)
        nn.init.constant_(pretrain_model.head.bias, 0)
        pretrain_ckpt = pretrain_model.state_dict()

    elif args.setting == 'D' and args.pretrain_transfer:
        pretrain_ckpt = torch.load(args.pretrain_transfer, map_location='cpu')['model_state_dict']
        pretrain_ckpt = {
            (k[len(ddp_prefix):] if k.startswith(ddp_prefix) else k): v
            for k, v in pretrain_ckpt.items()
        }

    if args.model_weights:
        # load fine-tuned ckpt to evaluate the accuracy on the validation / test set
        print('Load fine-tuned weights for vit_base_patch16')
        # Load on the CPU first; the model is moved to the GPU afterwards.
        model_ckpt = torch.load(args.model_weights, map_location='cpu')['model_state_dict']
        new_state_dict = OrderedDict()
        for k, v in model_ckpt.items():
            if 'indexes' not in k:
                name = k[len(ddp_prefix):] if k.startswith(ddp_prefix) else k
                new_state_dict[name] = v
        model.load_state_dict(new_state_dict)

        model = model.cuda()
        # summary(model, torch.ones(1, 3, 224, 224).cuda())
        acc1, acc5 = testmodel(model=model, test_data=test_data, args=args)
        print('Acc Top1: {}, Acc Top5: {}'.format(acc1, acc5))
        # NOTE(review): execution continues into main() below, so training
        # starts after this evaluation - confirm that this is intended.
    else:
        if pretrain_ckpt is None:
            # Setting 'D' with an empty --pretrain_transfer would otherwise
            # fail later with a NameError on pretrain_ckpt.
            parser.error("--setting D requires --pretrain_transfer to point to a checkpoint")
        # load pre-trained ckpt transfer or fine-tuning on target domain
        prune_load_weights(model, pretrain_ckpt, args.prune_rule, args.prune_rate)

        print(colorstr('green', args.prune_rule + ' prune, and load pre-trained weights'))

    # pruned_ops, pruned_size = tp.utils.count_ops_and_params(model, example_inputs=torch.randn(1, 3, 224, 224))
    # print(colorstr('green', "FLOPs prune rate: {}".format((1-pruned_ops/base_ops)*100)))

    if args.setting == 'B':
        print(colorstr('green', 'Pretrained on ImageNet-1k --> pruned and inherits weights --> transfer sparse ViT to ' + args.dataset + ' ...'))
    elif args.setting == 'D':
        print(colorstr('green', 'Pretrained on ImageNet-1k --> transfer to  ' + args.dataset + ' pruned, inherits weights and fine-tune ...'))
    # Train the model
    main(args=args, model=model)