import argparse
import os
import math
from functools import partial

import yaml
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
from tqdm import tqdm

import datasets
import models
import utils


def batched_predict(model, inp, coord, cell, bsize, band_coord):
    '''
    Query ``model`` over all coordinates in chunks of ``bsize`` points.

    The feature map is encoded once with ``model.gen_feat(inp)``. Then
    ``model.query_rgb`` is called on consecutive slices of ``coord``/``cell``
    along dim 1, and the per-chunk outputs are concatenated back along dim 1.
    Everything runs under ``torch.no_grad()``.

    Args:
        model: object exposing ``gen_feat(inp)`` and
            ``query_rgb(coord, cell, band_coord)``
        inp: image tensor, shape (N, c, H_l, W_l)
        coord: shape (N, H_h * W_h, 2)
        cell: shape (N, H_h * W_h, 2)
        bsize: int, number of query points per chunk
        band_coord: shape (N, C or num_band_sample, 2), passed unchanged to
            every chunk
    Return:
        pred: shape (N, H_h * W_h, C or num_band_sample)
    '''
    with torch.no_grad():
        model.gen_feat(inp)
        num_queries = coord.shape[1]
        # Slicing past the end is clamped, so the last chunk may be shorter.
        chunk_preds = [
            model.query_rgb(coord[:, start: start + bsize, :],
                            cell[:, start: start + bsize, :],
                            band_coord)
            for start in range(0, num_queries, bsize)
        ]
        pred = torch.cat(chunk_preds, dim=1)
    return pred


def eval_psnr(loader, model, data_norm=None, eval_type=None, eval_bsize=None,
              verbose=False):
    '''
    Compute the average PSNR of ``model`` over ``loader``.

    Args:
        loader: iterable of batch dicts holding at least 'inp', 'coord',
            'cell' and 'gt'; every value is moved with ``utils.to_cuda``.
        model: network called as ``model(inp, coord, cell)`` when
            ``eval_bsize`` is None, otherwise queried chunk-wise through
            ``batched_predict``.
        data_norm: {'inp': {'sub', 'div'}, 'gt': {'sub', 'div'}}. It is used
            to normalize the input and de-normalize the prediction. Identity
            when None.
        eval_type: None, 'div2k-<scale>' or 'benchmark-<scale>'; selects the
            keyword arguments bound onto ``utils.calc_psnr``.
        eval_bsize: number of query points per chunk, or None to query all
            points in a single forward pass.
        verbose: show the running PSNR in the progress bar.
    Return:
        Batch-size-weighted mean of ``utils.calc_psnr`` over the loader.
    Raises:
        NotImplementedError: for any other ``eval_type``.
    '''
    model.eval()

    if data_norm is None:
        data_norm = {
            'inp': {'sub': [0], 'div': [1]},
            'gt': {'sub': [0], 'div': [1]}
        }
    t = data_norm['inp']
    inp_sub = utils.to_cuda(torch.FloatTensor(t['sub']).view(1, -1, 1, 1))
    inp_div = utils.to_cuda(torch.FloatTensor(t['div']).view(1, -1, 1, 1))
    t = data_norm['gt']
    gt_sub = utils.to_cuda(torch.FloatTensor(t['sub']).view(1, 1, -1))
    gt_div = utils.to_cuda(torch.FloatTensor(t['div']).view(1, 1, -1))

    if eval_type is None:
        metric_fn = utils.calc_psnr
    elif eval_type.startswith('div2k'):
        scale = int(eval_type.split('-')[1])
        metric_fn = partial(utils.calc_psnr, dataset='div2k', scale=scale)
    elif eval_type.startswith('benchmark'):
        scale = int(eval_type.split('-')[1])
        metric_fn = partial(utils.calc_psnr, dataset='benchmark', scale=scale)
    else:
        raise NotImplementedError

    val_res = utils.Averager()

    pbar = tqdm(loader, leave=False, desc='val')
    for batch in pbar:
        for k, v in batch.items():
            batch[k] = utils.to_cuda(v)

        inp = (batch['inp'] - inp_sub) / inp_div
        # pred: shape (B, H_h * W_h, C)
        if eval_bsize is None:
            with torch.no_grad():
                pred = model(inp, batch['coord'], batch['cell'])
        else:
            # batched_predict takes band_coord as a required positional
            # argument; the previous call omitted it and always raised
            # TypeError. Forward it when the loader provides one.
            # NOTE(review): None is passed otherwise — confirm the model's
            # query_rgb accepts that.
            pred = batched_predict(model, inp,
                                   batch['coord'], batch['cell'], eval_bsize,
                                   batch.get('band_coord'))
        pred = pred * gt_div + gt_sub
        pred.clamp_(0, 1)

        B, hw, C = pred.shape
        assert pred.shape == batch['gt'].shape
        assert B == batch['inp'].shape[0]

        if eval_type is not None:  # reshape for shaving-eval
            ih, iw = batch['inp'].shape[-2:]
            # s: float, the upscale factor implied by the number of queries
            s = math.sqrt(batch['coord'].shape[1] / (ih * iw))
            shape = [B, round(ih * s), round(iw * s), C]
        else:
            # Not a full image: treat the queries as an (hw, 1) "image".
            shape = [B, hw, 1, C]
        # pred / gt: (B, hw, C) -> (B, C, shape[1], shape[2])
        pred = pred.reshape(*shape).permute(0, 3, 1, 2).contiguous()
        gt = batch['gt'].reshape(*shape).permute(0, 3, 1, 2).contiguous()

        res = metric_fn(pred, gt)
        val_res.add(res.item(), inp.shape[0])

        if verbose:
            pbar.set_description('val {:.4f}'.format(val_res.item()))

    return val_res.item()

def eval_metric(loader, model, data_norm=None, eval_type=None, eval_bsize=None,
                verbose=False, ratio_ergas=1.0 / 2,
                loss_fn=None, eval_metric_flag=None):
    '''
    Evaluate ``model`` on ``loader`` with PSNR / ERGAS / SAM / SSIM and a loss.

    Args:
        loader: iterable of batch dicts holding 'inp', 'coord', 'cell',
            'band_coord' and 'gt'; every value is moved with ``utils.to_cuda``.
        model: network called as ``model(inp, coord, cell, band_coord)`` when
            ``eval_bsize`` is None, otherwise queried chunk-wise through
            ``batched_predict``.
        data_norm: {'inp': {'sub', 'div'}, 'gt': {'sub', 'div'}}. It is used
            to normalize the input and the loss target and to de-normalize
            the prediction. Identity when None.
        eval_type: None or '<dataset>-<scale>', where dataset is one of
            'div2k', 'benchmark', 'cave', 'pavia_centra'. Selects the dataset
            and scale bound onto ``utils.calc_eval_metric``.
        eval_bsize: number of query points per chunk, or None to query all
            points in a single forward pass.
        verbose: show the running metrics in the progress bar.
        ratio_ergas: forwarded to ``utils.calc_eval_metric``. None falls back
            to 1, the value that used to be hard-coded at the call site.
        loss_fn: loss between the normalized prediction and the normalized
            target. Defaults to ``nn.L1Loss()``.
        eval_metric_flag: metric-selection flags, resolved once through
            ``utils.update_eval_metric_flag``. The caller's object is never
            mutated.
    Return:
        (psnr, ergas, sam, ssim, loss): batch-size-weighted averages.
    Raises:
        NotImplementedError: for an unsupported ``eval_type``.
    '''
    import copy

    model.eval()

    if loss_fn is None:
        loss_fn = nn.L1Loss()
    if ratio_ergas is None:
        ratio_ergas = 1

    if data_norm is None:
        data_norm = {
            'inp': {'sub': [0], 'div': [1]},
            'gt': {'sub': [0], 'div': [1]}
        }
    t = data_norm['inp']
    inp_sub = utils.to_cuda(torch.FloatTensor(t['sub']).view(1, -1, 1, 1))
    inp_div = utils.to_cuda(torch.FloatTensor(t['div']).view(1, -1, 1, 1))
    t = data_norm['gt']
    gt_sub = utils.to_cuda(torch.FloatTensor(t['sub']).view(1, 1, -1))
    gt_div = utils.to_cuda(torch.FloatTensor(t['div']).view(1, 1, -1))

    # ratio_ergas and the metric flags are passed at call time below, so only
    # dataset/scale are bound here (the old partial bindings were overridden).
    if eval_type is None:
        metric_fn = utils.calc_eval_metric
    else:
        for dataset_name in ('div2k', 'benchmark', 'cave', 'pavia_centra'):
            if eval_type.startswith(dataset_name):
                scale = int(eval_type.split('-')[1])
                metric_fn = partial(utils.calc_eval_metric,
                                    dataset=dataset_name, scale=scale)
                break
        else:
            raise NotImplementedError

    # Resolve the flags once (they do not change per batch). Copy them so a
    # dict coming from the caller's config is never modified in place.
    flags = copy.copy(eval_metric_flag)
    flags = copy.copy(utils.update_eval_metric_flag(flags))
    if eval_type is None:
        # Predictions are reshaped to (B, C, hw, 1) below; that is not a full
        # image, so SSIM is meaningless.
        flags['ssim'] = False

    val_psnr = utils.Averager()
    val_ergas = utils.Averager()
    val_sam = utils.Averager()
    val_ssim = utils.Averager()
    val_loss = utils.Averager()

    pbar = tqdm(loader, leave=False, desc='val')
    for batch in pbar:
        for k, v in batch.items():
            batch[k] = utils.to_cuda(v)

        inp = (batch['inp'] - inp_sub) / inp_div
        # pred: shape (B, H_h * W_h, C)
        if eval_bsize is None:
            with torch.no_grad():
                pred = model(inp, batch['coord'], batch['cell'],
                             batch['band_coord'])
        else:
            pred = batched_predict(model, inp,
                                   batch['coord'], batch['cell'], eval_bsize,
                                   batch['band_coord'])

        # Loss is computed in the normalized space the model predicts in.
        gt_norm = (batch['gt'] - gt_sub) / gt_div
        with torch.no_grad():
            loss = loss_fn(pred, gt_norm)
        # Weight by batch size like the other metrics, so a smaller final
        # batch does not skew the average.
        val_loss.add(loss.item(), inp.shape[0])

        pred = pred * gt_div + gt_sub
        pred.clamp_(0, 1)

        B, hw, C = pred.shape
        assert pred.shape == batch['gt'].shape
        assert B == batch['inp'].shape[0]

        if eval_type is not None:  # reshape for shaving-eval
            ih, iw = batch['inp'].shape[-2:]
            # s: float, the upscale factor implied by the number of queries
            s = math.sqrt(batch['coord'].shape[1] / (ih * iw))
            shape = [B, round(ih * s), round(iw * s), C]
        else:
            # Not a full image: treat the queries as an (hw, 1) "image".
            shape = [B, hw, 1, C]
        # pred / gt: (B, hw, C) -> (B, C, shape[1], shape[2])
        pred = pred.reshape(*shape).permute(0, 3, 1, 2).contiguous()
        gt_img = batch['gt'].reshape(*shape).permute(0, 3, 1, 2).contiguous()

        psnr, ergas, sam, ssim = metric_fn(pred, gt_img,
                                           ratio_ergas=ratio_ergas,
                                           eval_metric_flag=flags)
        # Store plain floats so the averagers do not hold device tensors.
        val_psnr.add(float(psnr.mean()), psnr.shape[0])
        val_ergas.add(float(ergas.mean()), ergas.shape[0])
        val_sam.add(float(sam.mean()), sam.shape[0])
        val_ssim.add(float(ssim.mean()), ssim.shape[0])

        if verbose:
            # One description; successive set_description calls overwrite
            # each other, so only the last value would ever be visible.
            pbar.set_description(
                'val PSNR {:.4f} ERGAS {:.4f} SAM {:.4f} SSIM {:.4f} '
                'loss {:.4f}'.format(val_psnr.item(), val_ergas.item(),
                                     val_sam.item(), val_ssim.item(),
                                     val_loss.item()))

    return (val_psnr.item(), val_ergas.item(), val_sam.item(),
            val_ssim.item(), val_loss.item())

def make_args_parser():
    '''
    Build the command-line parser for this evaluation script.

    Options:
        --config: path of the YAML config to evaluate with (no default)
        --model: path of the saved checkpoint (no default)
        --gpu: string assigned to CUDA_VISIBLE_DEVICES (default '0')
    '''
    cli = argparse.ArgumentParser()
    cli.add_argument('--config')
    cli.add_argument('--model')
    cli.add_argument('--gpu', default='0')
    return cli

if __name__ == '__main__':
    # Entry point: load the config and checkpoint, evaluate on the test
    # dataset and print the averaged metrics.
    parser = make_args_parser()
    args = parser.parse_args()

    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    # NOTE(review): FullLoader on a local, user-supplied config file; do not
    # point this at untrusted YAML.
    with open(args.config, 'r') as f:
        config = yaml.load(f, Loader=yaml.FullLoader)

    spec = config['test_dataset']
    dataset = datasets.make(spec['dataset'])
    dataset = datasets.make(spec['wrapper'], args={'dataset': dataset})
    loader = DataLoader(dataset, batch_size=spec['batch_size'],
        num_workers=8, pin_memory=True)
    # 'data' is only a label for this log line; don't crash if it is absent.
    print('{} dataset: size={}'.format(config.get('data', 'test'), len(dataset)))
    for k, v in dataset[0].items():
        print('  {}: shape={}'.format(k, tuple(v.shape)))

    model_spec = torch.load(args.model)['model']
    model = utils.to_cuda(models.make(model_spec, load_sd=True))
    print("Load model")

    if config.get('loss_fn') is None:
        config['loss_fn'] = 'L1'
    loss_fn = utils.get_loss_function(config['loss_fn'])

    print("Start model evaluating")
    val_res = eval_metric(loader, model,
                data_norm=config.get('data_norm'),
                eval_type=config.get('eval_type'),
                eval_bsize=config.get('eval_bsize'),
                ratio_ergas=config.get('ratio_ergas'),
                loss_fn=loss_fn,
                eval_metric_flag=config.get('eval_metric_flag'))

    val_psnr, val_ergas, val_sam, val_ssim, val_loss = val_res

    # '{}' instead of '{:d}': a non-integer scale_min would otherwise raise
    # ValueError here, after the whole evaluation has already run.
    scale = spec['wrapper']['args']['scale_min']
    print(f"Model: {args.model}")
    print('Scale: {} Test: val: val_loss={:.4f}\t psnr={:.4f}\t ergas={:.4f}\t sam={:.4f}\t ssim={:.4f}\t'.format(
        scale, val_loss, val_psnr, val_ergas, val_sam, val_ssim))
