import os
import time

import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader, DistributedSampler, sampler
from fvcore.nn import FlopCountAnalysis, parameter_count_table
from tqdm import tqdm

import importlib
import config
import myutils
from loss import Loss



def load_checkpoint(args, model, optimizer, scheduler, checkpoint_path, main_rank):
    """Restore model, optimizer and scheduler state from a training checkpoint.

    The checkpoint is expected to be a dict with the keys ``'state_dict'``,
    ``'epoch'``, ``'optimizer'`` and ``'scheduler'``, and optionally ``'lr'``
    (``args.lr`` is used when it is absent). Tensors are mapped onto
    ``cuda:<local_rank>`` using the module-level ``local_rank``.

    Only checkpoint tensors whose key exists in ``model`` and whose shape
    matches are copied; everything else keeps the model's current value.

    Side effects:
        Sets ``args.start_epoch`` to the checkpoint epoch + 1 and overwrites
        the learning rate of every optimizer param group.

    Args:
        args: Parsed arguments; ``lr`` is read, ``start_epoch`` is written.
        model: Model to load weights into (modified in place).
        optimizer: Optimizer whose state is restored in place.
        scheduler: LR scheduler whose state is restored in place.
        checkpoint_path: Path of the checkpoint file, or ``None`` to skip.
        main_rank: Rank that is allowed to print log messages.

    Returns:
        ``(model, epoch, optimizer, scheduler)`` when a checkpoint was loaded,
        otherwise ``None``.
    """
    is_main = torch.distributed.get_rank() == main_rank

    if checkpoint_path is None:
        if is_main:
            print('No checkpoint need to be loaded')
        return None

    if is_main:
        print("loading checkpoint %s" % checkpoint_path)
    checkpoint = torch.load(checkpoint_path, map_location='cuda:{}'.format(local_rank))
    loaded_state = checkpoint['state_dict']
    epoch = checkpoint['epoch']
    args.start_epoch = epoch + 1

    # Start from the model's own state and overwrite only compatible entries.
    # The copy must happen on EVERY rank (only the logging is rank-gated),
    # otherwise non-main ranks would keep their freshly initialised weights.
    model_state = model.state_dict()
    for key, value in loaded_state.items():
        if key in model_state and value.shape == model_state[key].shape:
            if is_main:
                print("Loading ", key)
            model_state[key] = value
        elif is_main:
            print("Not loading", key)

    # Load the filtered dict, not the raw checkpoint dict, so that skipped
    # entries cannot make load_state_dict fail.
    model.load_state_dict(model_state)

    # NOTE(review): if any tensor was skipped above, the optimizer state in the
    # checkpoint may no longer match the model parameters - confirm before
    # resuming from a checkpoint of a different architecture.
    optimizer.load_state_dict(checkpoint['optimizer'])
    scheduler.load_state_dict(checkpoint['scheduler'])

    lr = checkpoint.get("lr", args.lr)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    if is_main:
        print('Checkpoint for epoch %s has been loaded' % checkpoint['epoch'])
    return model, epoch, optimizer, scheduler

def load_pretrained_model(pretrained, model, main_rank):
    """Initialise ``model`` from a pretrained checkpoint, if one is given.

    The file at ``pretrained`` must contain a dict with a ``'state_dict'``
    entry. Tensors are mapped onto ``cuda:<local_rank>`` using the
    module-level ``local_rank``. Only entries whose key exists in the model
    and whose shape matches are copied; all other model parameters keep their
    current values.

    Args:
        pretrained: Path of the pretrained checkpoint, or ``None`` to skip.
        model: Model to load weights into (modified in place).
        main_rank: Rank that is allowed to print log messages.

    Returns:
        ``model`` when weights were loaded, otherwise ``None``.
    """
    is_main = torch.distributed.get_rank() == main_rank

    if pretrained is None:
        if is_main:
            print('No pretrained model need to be loaded')
        return None

    ## For low data, it is better to load from a supervised pretrained model
    loaded_state = torch.load(pretrained, map_location='cuda:{}'.format(local_rank))['state_dict']

    model_state = model.state_dict()
    for key, value in loaded_state.items():
        # Keys unknown to the model are skipped instead of raising KeyError,
        # so checkpoints from a slightly different architecture still load.
        if key in model_state and value.shape == model_state[key].shape:
            if is_main:
                print("Loading ", key)
            model_state[key] = value
        elif is_main:
            print("Not loading", key)

    model.load_state_dict(model_state)
    if is_main:
        print('Pretrained model has been loaded from: %s' % pretrained)
    return model


##### Parse CmdLine Arguments #####
args, unparsed = config.get_args()
cwd = os.getcwd()

##### Distributed DataParallel #####
# LOCAL_RANK must be present in the environment (expected to be provided by
# the process launcher, e.g. torchrun). It is used both as the CUDA device
# index and as the process-group rank.
local_rank = int(os.environ["LOCAL_RANK"])
device = torch.device("cuda", local_rank)
torch.distributed.init_process_group('nccl', world_size=args.num_gpu, rank=local_rank)
torch.cuda.set_device(local_rank)
print('Device %s has been set successfully' % torch.distributed.get_rank())

# Dump the full configuration once, from the main rank only.
if torch.distributed.get_rank() == args.main_rank:
    dic = vars(args)
    print('####### Setting Parameters are as follows ########')
    for k, v in dic.items():
        print(k, ':', v)
    print("#######          That's it                ########")


##### Checkpoint directory #####
save_loc = os.path.join(args.checkpoint_dir, "saved_models_final", args.dataset, args.model)
# Every rank runs this line concurrently; exist_ok avoids the
# check-then-create race that raises FileExistsError on the losing rank.
os.makedirs(save_loc, exist_ok=True)
opts_file = os.path.join(save_loc, "opts.txt")
with open(opts_file, "w") as fh:
    fh.write(str(args))


##### TensorBoard & Misc Setup #####
writer_loc = os.path.join(args.checkpoint_dir, 'tensorboard_logs_%s_final/%s' % (args.dataset, args.model))
writer = SummaryWriter(writer_loc)
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
myutils.set_seed(seed=args.random_seed, cuda=args.cuda)


##### Data #####
if args.dataset == "vimeo90K_septuplet_6":
    from dataset.vimeo90k_septuplet import VimeoSepTuplet

    # shuffle must stay False on the DataLoader because a sampler is supplied;
    # the DistributedSampler decides which samples each rank sees.
    train_set = VimeoSepTuplet(args.data_root, is_training=True)
    train_sampler = DistributedSampler(train_set)
    train_loader = DataLoader(train_set, args.batch_size, shuffle=False, sampler=train_sampler,
                              num_workers=args.num_workers, pin_memory=True)

    test_set = VimeoSepTuplet(args.data_root, is_training=False)
    test_sampler = DistributedSampler(test_set)
    test_loader = DataLoader(test_set, args.test_batch_size, shuffle=False, sampler=test_sampler,
                             num_workers=args.num_workers, pin_memory=True)

    # Debug helper: evaluate on a handful of samples only.
    # debug_sampler = sampler.SubsetRandomSampler(range(1, 10))
    # test_loader = DataLoader(test_set, args.test_batch_size, shuffle=False, sampler=debug_sampler, num_workers=args.num_workers, pin_memory=True)

elif args.dataset == "gopro":
    from dataset.GoPro import get_loader

    # NOTE(review): no DistributedSampler is created in this branch; whether
    # get_loader shards the data per rank cannot be seen from this file.
    train_loader = get_loader(args.data_root, args.batch_size, shuffle=True, num_workers=args.num_workers,
                              test_mode=False, interFrames=args.n_outputs, n_inputs=args.nbr_frame)
    test_loader = get_loader(args.data_root, args.batch_size, shuffle=False, num_workers=args.num_workers,
                             test_mode=True, interFrames=args.n_outputs, n_inputs=args.nbr_frame)
else:
    raise NotImplementedError


from torch.optim import Adam

##### Model, Optimizer & Scheduler #####
if torch.distributed.get_rank() == args.main_rank:
    print("Building model: %s" % args.model)

# The model class is looked up by name: module ``model.<name>`` must define a
# class that is also called ``<name>``.
model = getattr(importlib.import_module("model.{}".format(args.model)), args.model)().cuda()

optimizer = Adam(model.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))
# Alternative schedule:
# scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5, verbose=True)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.max_epoch, eta_min=1e-6)

# Weights are restored BEFORE the DDP wrap, so state-dict keys carry no
# 'module.' prefix (checkpoints are saved from model.module.state_dict()).
load_pretrained_model(args.pretrained, model, args.main_rank)
load_checkpoint(args, model, optimizer, scheduler, args.checkpoint_path, args.main_rank)

if torch.distributed.get_rank() == args.main_rank:
    params = sum(p.numel() for p in model.parameters())
    print('# of parameters: %d' % params)
    print(f'Number of parameters: {params / 1e6:.2f}M')

# Move the model to this rank's GPU and wrap it in DistributedDataParallel.
# (find_unused_parameters=True can be passed here if it turns out to be needed.)
model = nn.parallel.DistributedDataParallel(model.cuda(local_rank), device_ids=[local_rank])

##### Define Loss & Optimizer #####
criterion = Loss(args)

## ToDo: Different learning rate schemes for different parameters

def train(args, epoch):
    """Run one training epoch over the module-level ``train_loader``.

    Uses the module-level ``model``, ``criterion``, ``optimizer`` and
    ``writer``. Every ``args.log_iter`` iterations the PSNR/SSIM meters are
    updated on the current batch, the main rank prints a progress line and
    writes scalars to TensorBoard, and all meters are reset.

    Args:
        args: Parsed arguments; reads ``loss``, ``log_iter`` and ``main_rank``.
        epoch: Zero-based epoch index, used for sampler seeding and for the
            TensorBoard step.
    """
    losses, psnrs, ssims = myutils.init_meters(args.loss)
    model.train()
    criterion.train()
    is_main = torch.distributed.get_rank() == args.main_rank

    # A DistributedSampler needs the epoch to produce a different shuffle each
    # epoch. Look the sampler up on the loader so that loaders built without
    # one (no set_epoch method) do not break this function.
    epoch_sampler = getattr(train_loader, "sampler", None)
    if hasattr(epoch_sampler, "set_epoch"):
        epoch_sampler.set_epoch(epoch)

    t = time.time()
    for i, (images, gt_image) in enumerate(train_loader):

        # Build input batch: both are sequences of tensors moved to the GPU.
        images = [img_.cuda() for img_ in images]
        gt = [gt_.cuda() for gt_ in gt_image]

        # Forward
        optimizer.zero_grad()
        out = model(images)

        # The model returns a sequence of tensors; concatenate predictions and
        # targets the same way so they line up for the loss.
        out = torch.cat(out)
        gt = torch.cat(gt)

        loss, loss_specific = criterion(out, gt)

        # Save loss values
        for k, v in losses.items():
            if k != 'total':
                v.update(loss_specific[k].item())
        losses['total'].update(loss.item())

        # Backward
        loss.backward()
        optimizer.step()

        # Calc metrics & print logs
        if i % args.log_iter == 0:
            myutils.eval_metrics(out, gt, psnrs, ssims)

            # Only the main rank reports, so that the console and the shared
            # TensorBoard directory receive one entry per step.
            if is_main:
                print('Train Epoch: {} [{}/{}]\tLoss: {:.6f}\tPSNR: {:.4f}\tTime({:.2f})'.format(
                    epoch, i, len(train_loader), losses['total'].avg, psnrs.avg, time.time() - t),
                    flush=True)

                # Log to TensorBoard
                timestep = epoch * len(train_loader) + i
                writer.add_scalar('Loss/train', loss.item(), timestep)
                writer.add_scalar('PSNR/train', psnrs.avg, timestep)
                writer.add_scalar('SSIM/train', ssims.avg, timestep)
                writer.add_scalar('lr', optimizer.param_groups[-1]['lr'], timestep)

            # Reset metrics so the next report covers only the new window.
            losses, psnrs, ssims = myutils.init_meters(args.loss)
            t = time.time()


def test(args, epoch):
    """Evaluate the module-level ``model`` on ``test_loader``.

    The total loss and every tracked loss component are reduced across ranks
    with ``myutils.reduce_tensor`` before being accumulated. On the main rank
    the averaged results are printed, appended to ``results.txt`` inside
    ``save_loc`` and written to TensorBoard.

    Args:
        args: Parsed arguments; reads ``loss``, ``num_gpu`` and ``main_rank``.
        epoch: Zero-based epoch index (TensorBoard step is ``epoch + 1``).

    Returns:
        Tuple ``(average total loss, average PSNR, average SSIM)``.
    """
    is_main = torch.distributed.get_rank() == args.main_rank
    if is_main:
        print('Evaluating for epoch = %d' % epoch)
    losses, psnrs, ssims = myutils.init_meters(args.loss)
    model.eval()
    criterion.eval()

    with torch.no_grad():
        for images, gt_image in tqdm(test_loader):

            images = [img_.cuda() for img_ in images]
            gt = [gt_.cuda() for gt_ in gt_image]

            out = model(images)  ## images is a list of neighboring frames
            out = torch.cat(out)
            gt = torch.cat(gt)

            loss, loss_specific = criterion(out, gt)

            # Make sure every rank has finished this batch before reducing.
            torch.distributed.barrier()

            # Reduce the total and each tracked component across ranks. All
            # ranks iterate the meters in the same order, so the collective
            # calls stay matched.
            loss = myutils.reduce_tensor(loss, args.num_gpu)
            for k, v in losses.items():
                if k != 'total':
                    reduced = myutils.reduce_tensor(loss_specific[k], args.num_gpu)
                    v.update(reduced.item())
            losses['total'].update(loss.item())

            # Evaluate metrics
            myutils.eval_metrics(out, gt, psnrs, ssims, args.num_gpu)

    # Print progress
    if is_main:
        print("Loss: %f, PSNR: %f, SSIM: %f\n" % (losses['total'].avg, psnrs.avg, ssims.avg))

        # Save psnr & ssim
        save_fn = os.path.join(save_loc, 'results.txt')
        with open(save_fn, 'a') as f:
            f.write('For epoch=%d\t' % epoch)
            f.write("PSNR: %f, SSIM: %f\n" %
                    (psnrs.avg, ssims.avg))

        # Log to TensorBoard. The epoch average is logged rather than the loss
        # of whichever batch happened to be last.
        timestep = epoch + 1
        writer.add_scalar('Loss/test', losses['total'].avg, timestep)
        writer.add_scalar('PSNR/test', psnrs.avg, timestep)
        writer.add_scalar('SSIM/test', ssims.avg, timestep)

    return losses['total'].avg, psnrs.avg, ssims.avg


""" Entry Point """
def main(args):
    """Run the train / checkpoint / evaluate loop.

    Iterates from ``args.start_epoch`` to ``args.max_epoch - 1``. After every
    epoch a checkpoint is saved; every 10th epoch and every epoch after 90 the
    model is also evaluated and a second checkpoint carrying ``best_psnr`` is
    saved, flagged as best when the PSNR improved.

    Args:
        args: Parsed arguments; reads ``start_epoch``, ``max_epoch`` and
            ``main_rank``.
    """
    # NOTE(review): best_psnr always starts at 0, also when resuming, because
    # the value stored in the checkpoint is not read back anywhere in this file.
    best_psnr = 0
    for epoch in range(args.start_epoch, args.max_epoch):
        train(args, epoch)

        # Advance the LR schedule BEFORE checkpointing. A resumed run starts
        # at epoch + 1, so the saved scheduler state and 'lr' must already
        # describe that next epoch; saving first would make every resume
        # repeat the previous epoch's learning rate.
        scheduler.step()

        if torch.distributed.get_rank() == args.main_rank:
            print("Saving Checkpoint for Current epoch")
        myutils.save_checkpoint({
            'epoch': epoch,
            'state_dict': model.module.state_dict(),
            'optimizer': optimizer.state_dict(),
            'scheduler': scheduler.state_dict(),
            'lr': optimizer.param_groups[-1]['lr'],
        }, save_loc, is_best=False, save_rank=args.main_rank)

        if epoch % 10 == 0 or epoch > 90:
            test_loss, psnr, _ = test(args, epoch)

            # save checkpoint
            is_best = psnr > best_psnr
            best_psnr = max(psnr, best_psnr)
            myutils.save_checkpoint({
                'epoch': epoch,
                'state_dict': model.module.state_dict(),
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict(),
                'best_psnr': best_psnr,
                'lr': optimizer.param_groups[-1]['lr'],
            }, save_loc, is_best, args.main_rank)


if __name__ == "__main__":
    main(args)
