import argparse
import os
import yaml
import torch
from torch.utils.data import DataLoader
import csv
import datasets
import models
import utils

from test import eval_metric


def _append_csv_row(filepath, header, row):
    """Append ``row`` to the CSV at ``filepath``; write ``header`` first if the file is new or empty."""
    needs_header = not os.path.exists(filepath) or os.path.getsize(filepath) == 0
    # newline='' is what the csv module documents for file objects it writes to.
    with open(filepath, 'a', newline='') as csvfile:
        writer = csv.writer(csvfile, delimiter=',',
                            quotechar='|', quoting=csv.QUOTE_MINIMAL)
        if needs_header:
            writer.writerow(header)
        writer.writerow(row)


def compute_model_eval_by_scale_band(config, model_path, scale, num_band=None, filepath=None):
    """Evaluate one checkpoint at a fixed scale / band count and log the metrics to CSV.

    The test-dataset wrapper arguments ``scale_min``/``scale_max`` are both set
    to ``scale``; when ``num_band`` is given, ``num_band_min``/``num_band_max``
    are both set to it. The overrides are applied to a copy of the dataset
    spec, so the caller's ``config`` is left untouched.

    Args:
        config: Parsed evaluation config. Must contain ``test_dataset`` with
            ``dataset``, ``wrapper`` (including ``wrapper['args']``) and
            ``batch_size``. Optional keys read here: ``data``, ``loss_fn``
            (defaults to ``'L1'``), ``data_norm``, ``eval_type``,
            ``eval_bsize``, ``ratio_ergas``, ``eval_metric_flag``.
        model_path: Path to a checkpoint loadable by ``torch.load`` whose
            ``'model'`` entry is a spec accepted by ``models.make``.
        scale: Integer scale used for both ``scale_min`` and ``scale_max``.
        num_band: Band count used for both ``num_band_min`` and
            ``num_band_max``. If ``None``, the values already present in the
            config (if any) are kept.
        filepath: CSV file to append the result row to. Defaults to
            ``model_eval.csv`` in the directory of ``model_path``.

    Returns:
        None. One row is appended to the CSV and progress is printed.
    """
    # Copy each level of the spec we modify so the caller's config is not mutated.
    spec = dict(config['test_dataset'])
    wrapper_spec = dict(spec['wrapper'])
    wrapper_args = dict(wrapper_spec['args'])
    wrapper_args['scale_min'] = scale
    wrapper_args['scale_max'] = scale
    if num_band is not None:
        wrapper_args['num_band_min'] = num_band
        wrapper_args['num_band_max'] = num_band
    wrapper_spec['args'] = wrapper_args
    spec['wrapper'] = wrapper_spec

    dataset = datasets.make(spec['dataset'])
    dataset = datasets.make(wrapper_spec, args={'dataset': dataset})
    loader = DataLoader(dataset, batch_size=spec['batch_size'],
                        num_workers=8, pin_memory=True)
    # 'data' is only used for this log line, so a missing key should not abort the run.
    print('{} dataset: size={}'.format(config.get('data'), len(dataset)))

    for k, v in dataset[0].items():
        print('  {}: shape={}'.format(k, tuple(v.shape)))

    model_spec = torch.load(model_path)['model']
    model = utils.to_cuda(models.make(model_spec, load_sd=True))
    print("Load model")

    loss_name = config.get('loss_fn')
    if loss_name is None:
        loss_name = 'L1'
    loss_fn = utils.get_loss_function(loss_name)

    print("Start model evaluating")
    val_res = eval_metric(loader, model,
                          data_norm=config.get('data_norm'),
                          eval_type=config.get('eval_type'),
                          eval_bsize=config.get('eval_bsize'),
                          ratio_ergas=config.get('ratio_ergas'),
                          loss_fn=loss_fn,
                          eval_metric_flag=config.get('eval_metric_flag'))

    val_psnr, val_ergas, val_sam, val_ssim, val_loss = val_res

    # '{}' (not '{:d}') for the band so that num_band=None does not raise here.
    print(
        'Scale: {:d} band: {} Test: val: val_loss={:.4f}\t psnr={:.4f}\t ergas={:.4f}\t sam={:.4f}\t ssim={:.4f}\t'.format(
            scale, num_band,
            val_loss, val_psnr, val_ergas, val_sam, val_ssim))

    # Record the band range actually used; equals num_band when it was given.
    band_min = wrapper_args.get('num_band_min')
    band_max = wrapper_args.get('num_band_max')
    row = [model_path, scale, scale, band_min, band_max,
           val_loss, val_psnr, val_ergas, val_sam, val_ssim]
    header = ['model_path',
              'scale_min', 'scale_max',
              "num_band_min", "num_band_max",
              "val_loss", "val_psnr", "val_ergas", "val_sam", "val_ssim"]

    # Honour an explicit output path; otherwise log next to the checkpoint.
    if filepath is None:
        filepath = os.path.join(os.path.dirname(model_path), "model_eval.csv")

    _append_csv_row(filepath, header, row)


def make_args_parser():
    """Build the command-line parser for the evaluation script.

    Returns:
        argparse.ArgumentParser: parser exposing ``--config``, ``--model``,
        ``--gpu``, ``--scale_list`` and ``--num_band_list``. The two list
        options are parsed as integers, so CLI values and defaults have the
        same type.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', default="",
                        help="path to the YAML evaluation config")
    parser.add_argument('--model', default="",
                        help="path to the model checkpoint to evaluate")
    parser.add_argument('--gpu', default='0',
                        help="value assigned to CUDA_VISIBLE_DEVICES")
    # the spatial scale(s) you wish to achieve
    parser.add_argument("--scale_list", nargs="+", type=int,
                        default=[2, 4, 8, 10, 12, 14],
                        help="spatial scales to evaluate")
    # the number of bands you want to reconstruct
    parser.add_argument("--num_band_list", nargs="+", type=int,
                        default=[31],
                        help="band counts to evaluate")

    return parser


if __name__ == '__main__':
    # Script entry point: evaluate one checkpoint over every (scale, num_band) pair.
    parser = make_args_parser()
    args = parser.parse_args()

    # Both options default to "", which would otherwise surface later as an
    # opaque file-not-found error; fail fast with a usage message instead.
    if not args.config:
        parser.error('--config is required (path to the YAML evaluation config)')
    if not args.model:
        parser.error('--model is required (path to the model checkpoint)')

    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    # NOTE(review): FullLoader can construct more than plain YAML types; only
    # run this on trusted config files, or switch to yaml.safe_load if the
    # configs use standard YAML tags only.
    with open(args.config, 'r') as f:
        config = yaml.load(f, Loader=yaml.FullLoader)

    # int() is kept so this works whether the parser yields strings or ints.
    scale_list = [int(x) for x in args.scale_list]
    num_band_list = [int(x) for x in args.num_band_list]
    for scale in scale_list:
        for num_band in num_band_list:
            compute_model_eval_by_scale_band(config,
                                             model_path=args.model,
                                             scale=scale,
                                             num_band=num_band)
