import torch

from train.gradient_based.__init__ import inner_adapt
from utils import MetricLogger, psnr, get_meta_batch

# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")


def check(P):
    """Return ``True`` unconditionally.

    ``P`` is accepted for call-site compatibility but is not read. The original
    named the result ``filename_with_today_date``; presumably callers use it to
    decide whether output filenames get a date stamp (confirm with callers).
    """
    return True


def _run_eval_loop(P, wrapper, loader, metric_logger):
    """Adapt ``wrapper`` on each test batch and accumulate metrics.

    For every batch yielded by ``loader``, this helper:

    * moves its tensors to ``device``;
    * builds the meta-batch with ``get_meta_batch``;
    * runs ``inner_adapt`` with ``P.trained_inner_lr`` for ``P.inner_step``
      steps (first-order);
    * evaluates the adapted ``params`` under ``torch.no_grad``.

    The loss/PSNR values and ``context[0]`` are pushed into ``metric_logger``.

    Returns:
        ``(loss_in_log, res_in, res_out)`` from the last processed batch, or
        ``None`` if ``loader`` yielded no batches.
    """
    last = None
    for n, task_data in enumerate(loader):
        task_data = {k: v.to(device, non_blocking=True) for k, v in task_data.items()}
        batch_size, context = get_meta_batch(P, task_data)
        # Kept outside no_grad: inner_adapt presumably needs autograd (it is
        # gradient-based and returns gradients, which are unused here).
        params, loss_in, loss_in_log, res_in, _grad_in = inner_adapt(
            P, wrapper, context[0], P.trained_inner_lr, P.inner_step,
            first_order=True, order=P.order)
        loss_in = loss_in.view(P.inner_step, P.tto, batch_size, -1).mean(dim=-1)  # (step, tto, b)
        loss_in = loss_in[-1][-1]  # last inner step, last tto entry -> (b)
        loss_in_log = loss_in_log.view(P.inner_step, P.inner_step, batch_size, -1).mean(dim=-1)  # (step, step, b)

        with torch.no_grad():
            loss_out, res_out = wrapper(context[0], params=params)
            loss_out = loss_out.view(batch_size, -1).mean(dim=-1)  # (b)

        metric_logger.meters['loss_in'].update(loss_in.mean().item(), n=batch_size)
        metric_logger.meters['psnr_in'].update(psnr(loss_in).mean().item(), n=batch_size)
        metric_logger.meters['loss_out'].update(loss_out.mean().item(), n=batch_size)
        metric_logger.meters['psnr_out'].update(psnr(loss_out).mean().item(), n=batch_size)
        metric_logger.meters['eval_context'].update(context[0], n=batch_size)

        # Only the most recent batch is kept for the per-step curves and media
        # summaries (they are not averaged over the loader).
        last = (loss_in_log, res_in, res_out)
        # NOTE(review): the break fires only once n * test_batch_size exceeds
        # max_test_task. With full batches, more than max_test_task tasks are
        # therefore evaluated. Kept as-is so eval numbers stay comparable; confirm intent.
        if n * P.test_batch_size > P.max_test_task:
            break
    return last


def _log_inner_loop_curves(P, logger, loss_in_log, steps):
    """Write per-(step, patch) inner-loop loss/PSNR scalars via ``logger.writer``.

    ``loss_in_log`` is indexed ``[i][j]``; the emitted tags treat ``i`` as the
    inner step and ``j`` as the patch. With ``P.log_method == 'step'`` scalars
    are grouped by step, with ``'patch'`` they are grouped by patch. Any other
    value logs nothing.

    In both modes, for every step ``i > 0``, the mean over entries
    ``loss_in_log[i][:i]`` is also logged under ``eval_*_in_step{i}``.
    """
    if P.log_method not in ('step', 'patch'):
        return
    by_step = P.log_method == 'step'
    outer, inner = ('step', 'patch') if by_step else ('patch', 'step')
    add_scalar = logger.writer.add_scalar
    for i in range(P.inner_step):
        for j in range(P.inner_step):
            value = loss_in_log[i][j] if by_step else loss_in_log[j][i]
            add_scalar(f'eval_loss_in_{outer}{i:02}/loss_{inner}{j:02}', value.mean().item(), steps)
            add_scalar(f'eval_psnr_in_{outer}{i:02}/psnr_{inner}{j:02}', psnr(value).mean().item(), steps)
        if i > 0:
            seen = loss_in_log[i][:i]
            add_scalar(f'eval_loss_in_step{i:02}', seen.mean().item(), steps)
            add_scalar(f'eval_psnr_in_step{i:02}', psnr(seen).mean().item(), steps)


def _evaluate(P, wrapper, loader, steps, logger, log_media):
    """Shared evaluation driver for ``test_model_img`` and ``test_model_video``.

    The driver runs these steps:

    1. Put ``wrapper`` in eval mode, evaluate ``loader``, and synchronize the
       metrics across processes.
    2. Print or log the averaged metrics.
    3. When ``logger`` is given, write scalars. When at least one batch was
       evaluated, also write the per-step curves and call
       ``log_media(res_in, res_out, metric_logger)``.

    The wrapper's previous training mode is restored even if evaluation raises.

    Returns:
        ``metric_logger.psnr_out.global_avg``.
    """
    metric_logger = MetricLogger(delimiter="  ")
    log_ = print if logger is None else logger.log

    # Switch to evaluate mode; the original mode is restored in ``finally``.
    mode = wrapper.training
    wrapper.eval()
    try:
        wrapper.coord_init()
        last = _run_eval_loop(P, wrapper, loader, metric_logger)

        # gather the stats from all processes
        metric_logger.synchronize_between_processes()

        log_(' * [EVAL] [LossIn %.3f] [LossOut %.3f] [PSNRIn %.3f] [PSNROut %.3f]' %
             (metric_logger.loss_in.global_avg, metric_logger.loss_out.global_avg,
              metric_logger.psnr_in.global_avg, metric_logger.psnr_out.global_avg))

        if logger is not None:
            # An empty loader leaves no per-batch tensors; skip those summaries
            # instead of failing with a NameError.
            if last is not None:
                _log_inner_loop_curves(P, logger, last[0], steps)

            logger.scalar_summary('eval/inner_lr', P.trained_inner_lr, steps)
            logger.scalar_summary('eval/loss_in', metric_logger.loss_in.global_avg, steps)
            logger.scalar_summary('eval/loss_out', metric_logger.loss_out.global_avg, steps)
            logger.scalar_summary('eval/psnr_in', metric_logger.psnr_in.global_avg, steps)
            logger.scalar_summary('eval/psnr_out', metric_logger.psnr_out.global_avg, steps)

            if last is not None:
                log_media(last[1], last[2], metric_logger)

        return metric_logger.psnr_out.global_avg
    finally:
        wrapper.train(mode)


def test_model_img(P, wrapper, loader, steps, logger=None):
    """Evaluate the meta-learned ``wrapper`` on image tasks from ``loader``.

    Each batch is adapted with ``inner_adapt`` and then evaluated without
    gradients. Average inner/outer loss and PSNR are printed, or logged when
    ``logger`` is given. With a logger, scalars are written as well, and image
    summaries are written when ``P.data_type == 'img'``.

    Args:
        P: config namespace. It reads ``inner_step``, ``tto``, ``order``,
            ``trained_inner_lr``, ``test_batch_size``, ``max_test_task``,
            ``log_method`` and ``data_type``, plus whatever ``get_meta_batch``
            and ``inner_adapt`` read.
        wrapper: model wrapper exposing ``training``, ``eval()``, ``train()``
            and ``coord_init()``. Calling ``wrapper(x, params=...)`` must
            return ``(loss, output)``.
        loader: iterable of dicts of tensors.
        steps: global step attached to logged values.
        logger: optional logger. If ``None``, the summary is printed.

    Returns:
        Global average outer PSNR (``psnr_out``).
    """
    def log_images(res_in, res_out, metric_logger):
        if P.data_type == 'img':
            logger.image_summary('eval/imgs_in', res_in, metric_logger.eval_context, steps)
            logger.image_summary('eval/imgs_out', res_out, metric_logger.eval_context, steps)

    return _evaluate(P, wrapper, loader, steps, logger, log_images)


def test_model_video(P, wrapper, loader, steps, logger=None):
    """Evaluate the meta-learned ``wrapper`` on video tasks from ``loader``.

    This does the same as ``test_model_img``, except that with a logger the
    media summaries are videos. They are written via
    ``logger.video_summary`` when ``P.data_type == 'video'``.

    Returns:
        Global average outer PSNR (``psnr_out``).
    """
    def log_videos(res_in, res_out, metric_logger):
        if P.data_type == 'video':
            logger.video_summary('eval/vids_in', res_in, metric_logger.eval_context, steps)
            logger.video_summary('eval/vids_out', res_out, metric_logger.eval_context, steps)

    return _evaluate(P, wrapper, loader, steps, logger, log_videos)
