""" Train for generating LIIF, from image to implicit representation.

    Config:
        train_dataset:
          dataset: $spec; wrapper: $spec; batch_size:
        val_dataset:
          dataset: $spec; wrapper: $spec; batch_size:
        (data_norm):
            inp: {sub: []; div: []}
            gt: {sub: []; div: []}
        (eval_type):
        (eval_bsize):

        model: $spec
        optimizer: $spec
        epoch_max:
        (multi_step_lr):
            milestones: []; gamma: 0.5
        (resume): *.pth

        (epoch_val): ; (epoch_save):
"""

import argparse
import os

import yaml
import torch
import torch.nn as nn
from tqdm import tqdm
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import MultiStepLR

import wandb

import datasets
import models
import utils
from test import eval_psnr, eval_metric


def make_data_loader(spec, logger, tag=''):
    """Build a ``DataLoader`` from a dataset spec.

    Args:
        spec: mapping with 'dataset', 'wrapper' and 'batch_size' entries, or
            None when no loader is wanted.
        logger: object whose ``info`` method receives the dataset summary.
        tag: label used in the log lines; the loader shuffles only when the
            tag equals 'train'.

    Returns:
        The constructed ``DataLoader``, or None when ``spec`` is None.
    """
    if spec is None:
        return None

    base_dataset = datasets.make(spec['dataset'])
    wrapped = datasets.make(spec['wrapper'], args={'dataset': base_dataset})

    logger.info('{} dataset: size={}'.format(tag, len(wrapped)))
    # Report the shape of every field of the first sample.
    first_sample = wrapped[0]
    for field, value in first_sample.items():
        logger.info('  {}: shape={}'.format(field, tuple(value.shape)))

    is_train = tag == 'train'
    return DataLoader(
        wrapped,
        batch_size=spec['batch_size'],
        shuffle=is_train,
        num_workers=8,
        pin_memory=True,
    )


def make_data_loaders(config, logger):
    """Create all training loaders and the validation loader.

    Every key of ``config`` that starts with 'train_dataset' yields one
    training loader, in the iteration order of ``config``. The validation
    loader is built from ``config.get('val_dataset')`` and is None when
    that entry is absent.

    Returns:
        (train_loaders, val_loader): a list of loaders and a single loader
        (or None).
    """
    train_loaders = [
        make_data_loader(config[name], logger, tag='train')
        for name in config
        if name.startswith('train_dataset')
    ]
    val_spec = config.get('val_dataset')
    val_loader = make_data_loader(val_spec, logger, tag='val')
    return train_loaders, val_loader


def prepare_training(config, logger):
    """Build (or restore) the model, optimizer and LR scheduler.

    When ``config['resume']`` is set, the checkpoint at that path supplies the
    model spec, the optimizer spec and the last finished epoch; otherwise
    everything is created fresh from ``config['model']`` and
    ``config['optimizer']``.

    Return:
        model: the network wrapped in ``nn.DataParallel`` and passed through
            ``utils.to_cuda``
        optimizer: Optimizer
        epoch_start: int, the first epoch to run (1 for a fresh run,
            checkpoint epoch + 1 when resuming)
        lr_scheduler: MultiStepLR() or None (if config.get('multi_step_lr')
            is None)
    """
    resume_path = config.get('resume')
    if resume_path is not None:
        # Deserialize onto the CPU so loading does not depend on (or allocate
        # on) the GPU the checkpoint was saved from.
        # NOTE(review): assumes models.make / utils.make_optimizer copy the
        # loaded state into the live model/optimizer -- confirm.
        sv_file = torch.load(resume_path, map_location='cpu')
        model = models.make(sv_file['model'], load_sd=True)
        model = utils.to_cuda(nn.DataParallel(model))
        optimizer = utils.make_optimizer(
            model.parameters(), sv_file['optimizer'], load_sd=True)
        epoch_start = sv_file['epoch'] + 1
    else:
        model = models.make(config['model'])
        model = utils.to_cuda(nn.DataParallel(model))
        optimizer = utils.make_optimizer(
            model.parameters(), config['optimizer'])
        epoch_start = 1

    lr_scheduler = None
    if config.get('multi_step_lr') is not None:
        # Decays the learning rate of each parameter group by gamma once the
        # epoch reaches one of the milestones.
        lr_scheduler = MultiStepLR(optimizer, **config['multi_step_lr'])
        # Replay the steps of the epochs already completed (a no-op for a
        # fresh run, where epoch_start is 1).
        # NOTE(review): if the restored optimizer state already carries a
        # decayed lr, this replay may decay it again -- confirm.
        for _ in range(epoch_start - 1):
            lr_scheduler.step()

    logger.info('model: #params={}'.format(utils.compute_num_params(model, text=True)))
    return model, optimizer, epoch_start, lr_scheduler


def train(train_loaders, model, optimizer, loss_fn, data_norm):
    """Run one training pass over every loader in ``train_loaders``.

    Args:
        train_loaders: iterable of loaders; each batch is a mapping that
            provides 'inp', 'coord', 'cell', 'band_coord' and 'gt'.
        model: network called as ``model(inp, coord, cell, band_coord)``.
        optimizer: optimizer stepped once per batch.
        loss_fn: callable ``loss_fn(pred, gt)`` returning the loss tensor.
        data_norm: ``{'inp': {'sub', 'div'}, 'gt': {'sub', 'div'}}`` lists
            used to normalize the input and the target.

    Returns:
        ``item()`` of the ``utils.Averager`` fed with each batch loss
        (presumably the mean batch loss -- confirm in utils).
    """
    model.train()
    loss_meter = utils.Averager()

    def _norm_tensor(values, shape):
        # Normalization constants as a CUDA float tensor of the given view.
        return utils.to_cuda(torch.FloatTensor(values).view(*shape))

    inp_stats = data_norm['inp']
    inp_sub = _norm_tensor(inp_stats['sub'], (1, -1, 1, 1))
    inp_div = _norm_tensor(inp_stats['div'], (1, -1, 1, 1))
    gt_stats = data_norm['gt']
    gt_sub = _norm_tensor(gt_stats['sub'], (1, 1, -1))
    gt_div = _norm_tensor(gt_stats['div'], (1, 1, -1))

    # Several loaders allow training with different input image sizes.
    for loader in train_loaders:
        for batch in tqdm(loader, leave=False, desc='train'):
            for name, tensor in batch.items():
                batch[name] = utils.to_cuda(tensor)

            # Normalize the input image and run the model.
            normalized_inp = (batch['inp'] - inp_sub) / inp_div
            pred = model(normalized_inp, batch['coord'], batch['cell'],
                         batch['band_coord'])

            # Normalize the target and compute the loss.
            target = (batch['gt'] - gt_sub) / gt_div
            loss = loss_fn(pred, target)
            loss_meter.add(loss.item())

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Drop the references before the next batch is processed.
            pred = loss = None

    return loss_meter.item()


def main(config_, save_path, args):
    """Train the model described by ``config_`` and checkpoint to ``save_path``.

    Args:
        config_: configuration dict. It is modified in place: defaults are
            filled in for 'data_norm' and 'loss_fn', and every epoch the
            current state dicts are stored under ``config['model']['sd']``
            and ``config['optimizer']['sd']``.
        save_path: output directory. 'config.yaml', 'epoch-last.pth',
            'epoch-{N}.pth' (every ``epoch_save`` epochs) and
            'epoch-best.pth' (best validation PSNR so far) are written
            there; the path of 'log.txt' inside it is handed to
            ``utils.setup_logging``.
        args: parsed command-line arguments; only ``args.ni`` is read here.
    """
    config = config_

    writer = utils.set_save_path(save_path, remove=not args.ni)
    logger = utils.setup_logging(log_file=os.path.join(save_path, "log.txt"),
                                 console=True, filemode='a')
    with open(os.path.join(save_path, 'config.yaml'), 'w') as f:
        yaml.dump(config, f, sort_keys=False)

    train_loaders, val_loader = make_data_loaders(config, logger)
    if config.get('data_norm') is None:
        # Identity normalization when none is configured.
        config['data_norm'] = {
            'inp': {'sub': [0], 'div': [1]},
            'gt': {'sub': [0], 'div': [1]}
        }

    model, optimizer, epoch_start, lr_scheduler = prepare_training(config, logger)

    # Count the GPUs from CUDA_VISIBLE_DEVICES; fall back to what torch
    # reports when the variable is not set (e.g. main() called from other code).
    visible_devices = os.environ.get('CUDA_VISIBLE_DEVICES')
    if visible_devices is None:
        n_gpus = torch.cuda.device_count()
    else:
        n_gpus = len(visible_devices.split(','))

    # prepare_training() may already return an nn.DataParallel model; wrapping
    # it again would nest DataParallel modules, so only wrap when needed.
    wrapped_here = n_gpus > 1 and not isinstance(model, nn.DataParallel)
    if wrapped_here:
        model = nn.parallel.DataParallel(model)
    # Module whose state_dict is checkpointed: strip only the wrapper added
    # here, so the saved key names stay the same as before.
    ckpt_model = model.module if wrapped_here else model

    # The maximum epoch
    epoch_max = config['epoch_max']
    # The epoch interval in which to eval model
    epoch_val = config.get('epoch_val')
    # The epoch interval in which to save model
    epoch_save = config.get('epoch_save')
    max_val_v = -1e18

    if config.get('loss_fn') is None:
        config['loss_fn'] = 'L1'
    loss_fn = utils.get_loss_function(config['loss_fn'])

    timer = utils.Timer()

    for epoch in range(epoch_start, epoch_max + 1):
        t_epoch_start = timer.t()
        logger.info('epoch {}/{}'.format(epoch, epoch_max))

        writer.add_scalar('lr', optimizer.param_groups[0]['lr'], epoch)
        wandb.log({'epoch': epoch, "train/lr": optimizer.param_groups[0]['lr']})

        train_loss = train(train_loaders, model, optimizer, loss_fn,
                           data_norm=config['data_norm'])
        if lr_scheduler is not None:
            lr_scheduler.step()

        logger.info('Epoch {}: train: loss={:.4f}'.format(epoch, train_loss))
        writer.add_scalars('loss', {'train': train_loss}, epoch)
        wandb.log({'epoch': epoch,
                   "train/loss": train_loss})

        # The specs are the config's own dicts, so 'sd' is also visible
        # through sv_file['config'].
        model_spec = config['model']
        model_spec['sd'] = ckpt_model.state_dict()
        optimizer_spec = config['optimizer']
        optimizer_spec['sd'] = optimizer.state_dict()
        sv_file = {
            'model': model_spec,
            'optimizer': optimizer_spec,
            'epoch': epoch,
            'config': config
        }

        torch.save(sv_file, os.path.join(save_path, 'epoch-last.pth'))

        if (epoch_save is not None) and (epoch % epoch_save == 0):
            torch.save(sv_file,
                       os.path.join(save_path, 'epoch-{}.pth'.format(epoch)))

        if (epoch_val is not None) and (epoch % epoch_val == 0):
            # With batched evaluation, use the module underneath the wrapper
            # added in this function (if any).
            if wrapped_here and (config.get('eval_bsize') is not None):
                eval_model = model.module
            else:
                eval_model = model
            val_res = eval_metric(val_loader, eval_model,
                                  data_norm=config['data_norm'],
                                  eval_type=config.get('eval_type'),
                                  eval_bsize=config.get('eval_bsize'),
                                  ratio_ergas=config.get('ratio_ergas'),
                                  loss_fn=loss_fn,
                                  eval_metric_flag=config.get('eval_metric_flag'))
            val_psnr, val_ergas, val_sam, val_ssim, val_loss = val_res

            logger.info('Epoch {}: val: val_loss={:.4f}\t psnr={:.4f}\t ergas={:.4f}\t sam={:.4f}\t ssim={:.4f}\t'.format(
                epoch, val_loss, val_psnr, val_ergas, val_sam, val_ssim))

            writer.add_scalars('psnr',  {'val': val_psnr}, epoch)
            writer.add_scalars('ergas', {'val': val_ergas}, epoch)
            writer.add_scalars('sam',   {'val': val_sam}, epoch)
            writer.add_scalars('ssim',  {'val': val_ssim}, epoch)
            writer.add_scalars('loss',  {'val': val_loss}, epoch)

            wandb.log({'epoch': epoch,
                       "val/psnr":  val_psnr,
                       "val/ergas": val_ergas,
                       "val/sam":   val_sam,
                       "val/ssim":  val_ssim,
                       "val/loss":  val_loss
                       })

            # Keep the checkpoint with the highest validation PSNR.
            if val_psnr > max_val_v:
                max_val_v = val_psnr
                torch.save(sv_file, os.path.join(save_path, 'epoch-best.pth'))

        # Log: epoch duration, elapsed time / estimated total time.
        t_now = timer.t()
        prog = (epoch - epoch_start + 1) / (epoch_max - epoch_start + 1)
        t_epoch = utils.time_text(t_now - t_epoch_start)
        t_elapsed, t_all = utils.time_text(t_now), utils.time_text(t_now / prog)
        logger.info('{} {}/{}'.format(t_epoch, t_elapsed, t_all))

        writer.flush()

def make_args_parser():
    """Create the command-line parser for the training script.

    Options: ``--config``, ``--name`` and ``--tag`` (all default to None),
    ``--gpu`` (defaults to '0') and the boolean flag ``--ni``.
    """
    parser = argparse.ArgumentParser()
    # Plain value options as (flag, default) pairs.
    value_options = (
        ('--config', None),
        ('--name', None),
        ('--tag', None),
        ('--gpu', '0'),
    )
    for flag, default in value_options:
        parser.add_argument(flag, default=default)
    parser.add_argument(
        "--ni",
        action="store_true",
        help="No interaction. Suitable for Slurm Job launcher",
    )
    return parser

if __name__ == '__main__':
    # Script entry point: parse arguments, load the YAML config, derive the
    # save directory and the W&B project name, then start training.
    parser = make_args_parser()
    args = parser.parse_args()
    if args.config is None:
        # Exit with a usage message instead of a TypeError from open(None).
        parser.error('--config is required')

    # main() derives the number of GPUs from this variable.
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    with open(args.config, 'r') as f:
        config = yaml.load(f, Loader=yaml.FullLoader)
        print('config loaded.')

    config_str = utils.make_model_file_param_args(config, args)

    save_name = args.name
    if save_name is None:
        # Config file name without directory and extension (works for
        # '.yaml', '.yml' or no extension at all).
        config_stem = os.path.splitext(os.path.basename(args.config))[0]
        save_name = '_' + config_stem
    if args.tag is not None:
        save_name += '_' + args.tag
    save_path = os.path.join('./save', save_name, config_str)

    # W&B project name: the part of config_str before 'TSM' joined with the
    # part after 'bandposenc'. Each marker must occur exactly once for the
    # unpacking to succeed; otherwise fall back to the whole config_str.
    try:
        head, band_part = config_str.split("banddec")
        prefix, _ = head.split('TSM')
        _, suffix = band_part.split("bandposenc")
        project_name = prefix + suffix
    except ValueError:
        project_name = config_str

    wandb.init(project=project_name[:128], entity="specdecode")
    # NOTE(review): this rebinds the module attribute ``wandb.config``;
    # confirm whether the run config should instead be populated via
    # wandb.init(config=...).
    wandb.config = utils.dict2namespace(config)

    main(config, save_path, args)
