import argparse
import os
import math
from functools import partial

import yaml
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
from tqdm import tqdm
import numpy as np

import datasets
import models
import utils


from test import *



# Dataset prefixes accepted in ``eval_type`` strings of the form "<dataset>-<scale>".
# They are tried in this order with ``str.startswith``.
_EVAL_DATASETS = ('div2k', 'benchmark', 'cave', 'pavia_centra')


def _select_metric_fn(eval_type, ratio_ergas, eval_metric_flag):
    """Return the metric callable selected by ``eval_type``.

    ``None`` selects ``utils.calc_eval_metric`` unchanged. Otherwise ``eval_type``
    must start with one of ``_EVAL_DATASETS`` and carry an integer scale after the
    first '-', e.g. ``"cave-4"``. The result is ``utils.calc_eval_metric`` with
    dataset, scale, ratio_ergas and eval_metric_flag bound.

    Raises:
        NotImplementedError: ``eval_type`` matches no known dataset prefix.
    """
    if eval_type is None:
        return utils.calc_eval_metric
    for dataset in _EVAL_DATASETS:
        if eval_type.startswith(dataset):
            scale = int(eval_type.split('-')[1])
            return partial(utils.calc_eval_metric, dataset=dataset, scale=scale,
                           ratio_ergas=ratio_ergas, eval_metric_flag=eval_metric_flag)
    raise NotImplementedError(f'unsupported eval_type: {eval_type!r}')


def image_generation(loader, model, data_norm=None, eval_type=None, eval_bsize=None,
                     verbose=False, ratio_ergas=1.0 / 2,
                     loss_fn=None, eval_metric_flag=None):
    """Run ``model`` over ``loader`` and collect predictions, ground truth and metrics.

    Args:
        loader: iterable of batch dicts with keys 'inp', 'coord', 'cell',
            'band_coord' and 'gt'.
        model: network called as ``model(inp, coord, cell, band_coord)``. It is
            switched to eval mode.
        data_norm: dict with 'inp' and 'gt' entries, each holding 'sub'/'div'
            lists. They normalise the input and de-normalise the prediction.
            Defaults to identity (sub 0, div 1).
        eval_type: None, or "<dataset>-<scale>" with dataset in ``_EVAL_DATASETS``.
            When set, predictions are reshaped to full images.
        eval_bsize: if not None, predict in chunks via ``batched_predict``.
        verbose: unused; kept for signature compatibility.
        ratio_ergas: bound into the dataset-specific metric partial, but see the
            NOTE at the metric call below.
        loss_fn: unused; kept for signature compatibility.
        eval_metric_flag: passed through ``utils.update_eval_metric_flag`` each
            batch and then on to the metric function.

    Returns:
        dict of numpy arrays concatenated over all batches along axis 0:
        - 'pred' and 'gt': shape (N, C, H, W), or (N, C, H*W, 1) when
          ``eval_type`` is None.
        - 'psnr', 'ergas', 'sam', 'ssim': whatever the metric function returns
          per batch (presumably shape (B,) each — TODO confirm).

    Raises:
        NotImplementedError: unknown ``eval_type``.
        ValueError: ``loader`` yields no batches.
    """
    model.eval()

    if data_norm is None:
        data_norm = {
            'inp': {'sub': [0], 'div': [1]},
            'gt': {'sub': [0], 'div': [1]}
        }
    # Input statistics broadcast over (B, C, H, W); gt statistics over (B, H*W, C).
    t = data_norm['inp']
    inp_sub = utils.to_cuda(torch.FloatTensor(t['sub']).view(1, -1, 1, 1))
    inp_div = utils.to_cuda(torch.FloatTensor(t['div']).view(1, -1, 1, 1))
    t = data_norm['gt']
    gt_sub = utils.to_cuda(torch.FloatTensor(t['sub']).view(1, 1, -1))
    gt_div = utils.to_cuda(torch.FloatTensor(t['div']).view(1, 1, -1))

    metric_fn = _select_metric_fn(eval_type, ratio_ergas, eval_metric_flag)

    # Key order matches the returned dict.
    collected = {'pred': [], 'gt': [], 'psnr': [], 'ergas': [], 'sam': [], 'ssim': []}

    # no_grad covers batched_predict and the post-processing as well as the
    # direct model call, so no autograd graph is built during evaluation.
    with torch.no_grad():
        for batch in tqdm(loader, leave=False, desc='val'):
            for k, v in batch.items():
                batch[k] = utils.to_cuda(v)

            inp = (batch['inp'] - inp_sub) / inp_div
            # pred: shape (B, H_h * W_h, C)
            if eval_bsize is None:
                pred = model(inp, batch['coord'], batch['cell'], batch['band_coord'])
            else:
                pred = batched_predict(model, inp, batch['coord'], batch['cell'],
                                       eval_bsize, batch['band_coord'])

            pred = pred * gt_div + gt_sub
            pred.clamp_(0, 1)

            B, hw, C = pred.shape
            assert pred.shape == batch['gt'].shape
            assert B == batch['inp'].shape[0]

            eval_metric_flag = utils.update_eval_metric_flag(eval_metric_flag)

            if eval_type is not None:
                # Full-image eval: recover the output (H, W) from the query count.
                ih, iw = batch['inp'].shape[-2:]
                s = math.sqrt(batch['coord'].shape[1] / (ih * iw))  # upscale factor
                shape = [B, round(ih * s), round(iw * s), C]
            else:
                # The queries do not form a full image, so SSIM is meaningless.
                # NOTE(review): this mutates the flag dict in place, which may be
                # the caller's config entry if update_eval_metric_flag returns it.
                shape = [B, hw, 1, C]
                eval_metric_flag['ssim'] = False

            # (B, h, w, C) -> (B, C, h, w) for both prediction and ground truth.
            pred = pred.reshape(*shape).permute(0, 3, 1, 2).contiguous()
            batch['gt'] = batch['gt'].reshape(*shape).permute(0, 3, 1, 2).contiguous()

            # NOTE(review): this call-time ratio_ergas=1 overrides the ratio_ergas
            # bound into the partial, so the configured value never reaches the
            # metric. Kept to preserve reported numbers — confirm intent.
            psnr, ergas, sam, ssim = metric_fn(pred, batch['gt'],
                                               ratio_ergas=1,
                                               eval_metric_flag=eval_metric_flag)

            collected['pred'].append(pred.cpu().numpy())
            collected['gt'].append(batch['gt'].cpu().numpy())
            collected['psnr'].append(psnr)
            collected['ergas'].append(ergas)
            collected['sam'].append(sam)
            collected['ssim'].append(ssim)

    if not collected['pred']:
        raise ValueError('image_generation: loader yielded no batches')

    return {key: np.concatenate(chunks, axis=0) for key, chunks in collected.items()}



def generate_img_and_save(config, model_path, scale, num_band = None, filepath = None):
    """Evaluate a saved model on ``config['test_dataset']`` at one scale/band count and pickle the results.

    Args:
        config: parsed config dict. This function mutates it in place:
            - ``test_dataset.wrapper.args`` gets ``scale_min``/``scale_max``, and
              ``num_band_min``/``num_band_max`` when ``num_band`` is given.
            - ``loss_fn`` is set to 'L1' if missing.
        model_path: checkpoint path; its ``['model']`` entry is passed to
            ``models.make(..., load_sd=True)``.
        scale: value used for both ``scale_min`` and ``scale_max``.
        num_band: if not None, value used for both ``num_band_min`` and
            ``num_band_max``. If None, the config's band bounds are left untouched.
        filepath: output pickle path. Defaults to
            ``<model dir>/model_pred_scale{scale}_band{num_band}.pkl``.

    Returns:
        ``(val_res, filepath)``: the dict returned by ``image_generation`` and
        the path the results were written to.
    """
    wrapper_args = config['test_dataset']['wrapper']['args']
    wrapper_args['scale_min'] = scale
    wrapper_args['scale_max'] = scale
    # Only override the band range when a band count was requested
    # (hasattr() does not test dict keys, so a key-based check is not needed).
    if num_band is not None:
        wrapper_args['num_band_min'] = num_band
        wrapper_args['num_band_max'] = num_band

    spec = config['test_dataset']
    dataset = datasets.make(spec['dataset'])
    dataset = datasets.make(spec['wrapper'], args={'dataset': dataset})
    loader = DataLoader(dataset, batch_size=spec['batch_size'],
                        num_workers=8, pin_memory=True)
    print('{} dataset: size={}'.format(config['data'], len(dataset)))
    for k, v in dataset[0].items():
        print('  {}: shape={}'.format(k, tuple(v.shape)))

    # NOTE: torch.load unpickles the checkpoint; only load trusted files.
    model_spec = torch.load(model_path)['model']
    model = utils.to_cuda(models.make(model_spec, load_sd=True))
    print("Load model")

    if config.get('loss_fn') is None:
        config['loss_fn'] = 'L1'
    loss_fn = utils.get_loss_function(config['loss_fn'])

    print("Start model evaluating")
    val_res = image_generation(loader, model,
                               data_norm=config.get('data_norm'),
                               eval_type=config.get('eval_type'),
                               eval_bsize=config.get('eval_bsize'),
                               ratio_ergas=config.get('ratio_ergas'),
                               loss_fn=loss_fn,
                               eval_metric_flag=config.get('eval_metric_flag'))

    # '{}' rather than '{:d}': '{:d}' raised TypeError for num_band=None
    # (and for non-int scales) after the whole evaluation had already run.
    print('Scale: {} band: {} Test: val: psnr={:.4f}\t ergas={:.4f}\t sam={:.4f}\t ssim={:.4f}\t'.format(
        scale, num_band,
        val_res['psnr'].mean(), val_res['ergas'].mean(),
        val_res['sam'].mean(), val_res['ssim'].mean()))

    if filepath is None:
        filepath = os.path.join(os.path.dirname(model_path),
                                f"model_pred_scale{scale}_band{num_band}.pkl")

    utils.pickle_dump(val_res, filepath)
    print(f"Save img to {filepath}")
    return val_res, filepath

def make_args_parser():
    """Build the command-line parser for this script.

    Options:
        --config: path to the YAML config.
        --model: path to the model checkpoint.
        --gpu: value for CUDA_VISIBLE_DEVICES (default '0').
        --scale_list, --num_band_list: one or more values each. They are kept
            as strings and default to an empty list.
    """
    cli = argparse.ArgumentParser()
    for path_flag in ('--config', '--model'):
        cli.add_argument(path_flag)
    cli.add_argument('--gpu', default='0')
    for list_flag in ('--scale_list', '--num_band_list'):
        cli.add_argument(list_flag, nargs="+", default=[])
    return cli

if __name__ == '__main__':
    parser = make_args_parser()
    args = parser.parse_args()

    # Fail fast on unusable arguments instead of a TypeError from open(None)
    # or a silent run that evaluates nothing.
    if args.config is None or args.model is None:
        parser.error('--config and --model are required')
    scale_list = [int(x) for x in args.scale_list]
    num_band_list = [int(x) for x in args.num_band_list]
    if not scale_list or not num_band_list:
        parser.error('--scale_list and --num_band_list each need at least one value')

    # Restrict visible GPUs; only effective if CUDA has not been initialised yet.
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    # NOTE: FullLoader is not safe_load; only use trusted config files.
    with open(args.config, 'r') as f:
        config = yaml.load(f, Loader=yaml.FullLoader)

    # One evaluation (and one saved pickle) per (scale, num_band) pair.
    for scale in scale_list:
        for num_band in num_band_list:
            generate_img_and_save(config,
                                  model_path=args.model,
                                  scale=scale,
                                  num_band=num_band)