import argparse
import copy
import csv
import math
import os
from functools import partial

import yaml
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
from tqdm import tqdm

import datasets
import models
import utils

from test import eval_metric

def compute_model_eval_by_scale_band(config, model_path, scale, num_band = None, filepath = None):
    """Evaluate a saved model on the test dataset at a fixed scale / band count.

    The test wrapper's ``scale_min``/``scale_max`` are pinned to ``scale``.
    When ``num_band`` is given, ``num_band_min``/``num_band_max`` are pinned
    to it as well. The resulting metrics are printed and appended as one row
    to a CSV file.

    Args:
        config: Parsed config dict. It must contain ``test_dataset`` with
            ``dataset``, ``wrapper`` (holding an ``args`` dict) and
            ``batch_size``. The dict is deep-copied, so the caller's config
            is never modified.
        model_path: Path to a checkpoint whose ``'model'`` entry is the spec
            passed to ``models.make(..., load_sd=True)``.
        scale: Scale to evaluate at.
        num_band: Band count to evaluate with. ``None`` keeps whatever band
            range the config already specifies.
        filepath: CSV file that receives the result row. Defaults to
            ``model_eval.csv`` in the directory containing ``model_path``.

    Returns:
        The tuple returned by ``eval_metric``, unpacked here as
        ``(psnr, ergas, sam, ssim, loss)``.
    """
    # Work on a copy so the per-call overrides below do not leak into the
    # caller's dict when it is reused across several scales/bands.
    config = copy.deepcopy(config)
    spec = config['test_dataset']
    wrapper_args = spec['wrapper']['args']

    wrapper_args['scale_min'] = scale
    wrapper_args['scale_max'] = scale
    # Only pin the band range when requested; otherwise keep the config's value.
    if num_band is not None:
        wrapper_args['num_band_min'] = num_band
        wrapper_args['num_band_max'] = num_band
    # Effective band range handed to the wrapper; this is what gets reported.
    band_min = wrapper_args.get('num_band_min')
    band_max = wrapper_args.get('num_band_max')

    loader = _build_test_loader(spec, config.get('data', 'test'))

    model_spec = torch.load(model_path)['model']
    model = utils.to_cuda(models.make(model_spec, load_sd=True))
    print("Load model")

    loss_name = config.get('loss_fn')
    if loss_name is None:
        loss_name = 'L1'  # default loss when the config does not name one
    loss_fn = utils.get_loss_function(loss_name)

    print("Start model evaluating")
    val_res = eval_metric(loader, model,
                          data_norm=config.get('data_norm'),
                          eval_type=config.get('eval_type'),
                          eval_bsize=config.get('eval_bsize'),
                          ratio_ergas=config.get('ratio_ergas'),
                          loss_fn=loss_fn,
                          eval_metric_flag=config.get('eval_metric_flag'))
    val_psnr, val_ergas, val_sam, val_ssim, val_loss = val_res

    # Use plain '{}' rather than '{:d}' so that non-int scales and an unset
    # band count are printed instead of raising after the full evaluation.
    print('Scale: {} band: {} Test: val: val_loss={:.4f}\t psnr={:.4f}\t ergas={:.4f}\t sam={:.4f}\t ssim={:.4f}\t'.format(
        scale, band_min, val_loss, val_psnr, val_ergas, val_sam, val_ssim))

    # Honour an explicit filepath; previously it was always overwritten.
    if filepath is None:
        filepath = os.path.join(os.path.dirname(model_path), "model_eval.csv")
    _append_eval_row(filepath, [model_path, scale, scale, band_min, band_max,
                                val_loss, val_psnr, val_ergas, val_sam, val_ssim])
    return val_res


def _build_test_loader(spec, data_name):
    """Build the wrapped test dataset described by ``spec`` and return a DataLoader over it.

    ``data_name`` is only used as a label in the printed dataset summary.
    ``spec['num_workers']`` optionally overrides the default of 8 loader workers.
    """
    dataset = datasets.make(spec['dataset'])
    dataset = datasets.make(spec['wrapper'], args={'dataset': dataset})
    loader = DataLoader(dataset, batch_size=spec['batch_size'],
                        num_workers=spec.get('num_workers', 8), pin_memory=True)
    print('{} dataset: size={}'.format(data_name, len(dataset)))
    # Print the shape of every field of the first sample as a sanity check.
    for k, v in dataset[0].items():
        print('  {}: shape={}'.format(k, tuple(v.shape)))
    return loader


def _append_eval_row(filepath, row):
    """Append ``row`` to the CSV at ``filepath``.

    The header is written first when the file is new or empty. The existing
    on-disk format (',' delimiter, '|' quotechar) is kept so older result
    files stay compatible.
    """
    header = ['model_path',
              'scale_min', 'scale_max',
              "num_band_min", "num_band_max",
              "val_loss", "val_psnr", "val_ergas", "val_sam", "val_ssim"]
    parent = os.path.dirname(filepath)
    if parent:
        os.makedirs(parent, exist_ok=True)
    write_header = not os.path.exists(filepath) or os.path.getsize(filepath) == 0
    # newline='' is what the csv module expects; the writer emits its own terminators.
    with open(filepath, 'a', newline='') as csvfile:
        writer = csv.writer(csvfile, delimiter=',',
                            quotechar='|', quoting=csv.QUOTE_MINIMAL)
        if write_header:
            writer.writerow(header)
        writer.writerow(row)

def make_args_parser():
    """Build the command-line parser for the scale/band evaluation sweep.

    ``--config`` and ``--model`` are required; without them the script cannot
    open its config or checkpoint. ``--scale_list`` and ``--num_band_list``
    collect raw strings, which the caller converts to integers.
    """
    parser = argparse.ArgumentParser(
        description='Evaluate a checkpoint for every (scale, num_band) '
                    'combination and append the metrics to a CSV file.')
    parser.add_argument('--config', required=True,
                        help='YAML config containing a test_dataset section.')
    parser.add_argument('--model', required=True,
                        help='Path to the model checkpoint to evaluate.')
    parser.add_argument('--gpu', default='0',
                        help='Value assigned to CUDA_VISIBLE_DEVICES.')

    parser.add_argument("--scale_list", nargs="+", default=[],
                        help='Scales to evaluate.')
    parser.add_argument("--num_band_list", nargs="+", default=[],
                        help='Band counts to evaluate.')

    return parser

if __name__ == '__main__':
    cli_args = make_args_parser().parse_args()

    # Restrict the visible GPUs before any CUDA work is done.
    os.environ['CUDA_VISIBLE_DEVICES'] = cli_args.gpu

    with open(cli_args.config, 'r') as cfg_file:
        config = yaml.load(cfg_file, Loader=yaml.FullLoader)

    # Convert both lists up front so a bad value fails before any evaluation.
    scales = list(map(int, cli_args.scale_list))
    bands = list(map(int, cli_args.num_band_list))

    # Evaluate every (scale, band) combination, one call per pair.
    for scale in scales:
        for num_band in bands:
            compute_model_eval_by_scale_band(
                config,
                model_path=cli_args.model,
                scale=scale,
                num_band=num_band,
            )

