import time

import torch

from train.gradient_based.__init__ import inner_adapt, divide_loss
from utils import psnr, get_meta_batch

# Prefer the GPU when CUDA is usable; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")


def check(P):
    """Always return ``True``.

    ``P`` is accepted for interface compatibility but is never read. The
    original local name (``filename_with_today_date``) suggests the flag
    controls date-stamped filenames -- presumably; confirm with callers.
    """
    return True


def train_step_img(P, steps, wrapper, optimizer, task_data, metric_logger, logger, inner_lr):
    """Run one meta-training step on an image batch and log its statistics.

    The inner loop (``inner_adapt``) adapts the wrapper's parameters on
    ``context[0]``. The adapted parameters are then re-evaluated on the same
    context to form the outer loss. That loss is back-propagated, the decoder
    gradients are clipped to norm 1.0, and ``optimizer`` takes a step.

    Args:
        P: Run configuration. Reads ``inner_step``, ``inner_iter``, ``mode``,
            ``order``, ``print_step``, ``log_method`` and ``data_type``;
            writes ``trained_inner_lr``.
        steps: Global training step, used for the print cadence and log tags.
        wrapper: Model wrapper exposing ``train()``, a ``support`` flag, a
            ``decoder`` submodule, and ``wrapper(x, params=...)`` returning a
            ``(loss, reconstruction)`` pair.
        optimizer: Outer-loop optimizer.
        task_data: Raw batch handed to ``get_meta_batch``.
        metric_logger: Meter container. Updated every step and reset after
            each print.
        logger: Logger providing ``writer.add_scalar``, ``scalar_summary``,
            ``image_summary``, ``log_dirname`` and ``log``.
        inner_lr: Inner-loop learning rate. Read via ``.item()``, so it must
            be a single-element tensor.
    """
    stime = time.perf_counter()  # monotonic clock for measuring a duration
    wrapper.train()

    batch_size, context = get_meta_batch(P, task_data)

    # --- inner loop: adapt on the support context ---
    # Shapes per the original author's note: loss_in (step * iter, b, c, h, w),
    # loss_in_log (step, t, b, c, h//n, w//n), res_in (step, b, c, h, w).
    wrapper.support = True
    params, loss_in, loss_in_log, res_in, _ = inner_adapt(
        P, wrapper, context[0], inner_lr, P.inner_step,
        first_order=P.mode == 'fomaml', order=P.order)
    loss_in = loss_in.view(P.inner_step, P.inner_iter, batch_size, -1).mean(dim=-1)  # (step, iter, b)
    loss_in = loss_in[-1][-1]  # (b): last iteration of the last inner step
    # The second axis is reshaped to inner_step entries (logged as "patches" below).
    loss_in_log = loss_in_log.view(P.inner_step, P.inner_step, batch_size, -1).mean(dim=-1)  # (step, t, b)

    # --- outer loss and meta-update ---
    wrapper.support = False
    loss_out, res_out = wrapper(context[0], params=params)
    loss_out = loss_out.view(batch_size, -1).mean(dim=-1)  # (b)
    # NOTE(review): mean * batch_size is a sum over the batch, whereas
    # train_step_video uses the plain mean. This scales the pre-clipping
    # gradient by batch_size. Kept as-is -- confirm it is intended.
    loss = loss_out.mean() * batch_size
    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(wrapper.decoder.parameters(), 1.0)
    optimizer.step()
    # Drain queued GPU work so batch_time reflects real compute. This is
    # guarded because torch.cuda.synchronize() raises when CUDA is
    # unavailable, which would break the CPU fallback chosen for ``device``.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    P.trained_inner_lr = inner_lr.item()

    # --- track stats ---
    metric_logger.meters['batch_time'].update(time.perf_counter() - stime, n=batch_size)
    metric_logger.meters['train_context'].update(context[0], n=batch_size)
    # Reduce each statistic once and reuse it for the meters and logs below.
    loss_in_val = loss_in.mean().item()
    loss_out_val = loss_out.mean().item()
    psnr_in_val = psnr(loss_in).mean().item()
    psnr_out_val = psnr(loss_out).mean().item()
    metric_logger.meters['loss_in'].update(loss_in_val, n=batch_size)
    metric_logger.meters['loss_out'].update(loss_out_val, n=batch_size)
    metric_logger.meters['psnr_in'].update(psnr_in_val, n=batch_size)
    metric_logger.meters['psnr_out'].update(psnr_out_val, n=batch_size)
    metric_logger.synchronize_between_processes()

    if steps % P.print_step != 0:
        return

    logger.log_dirname(f"Step {steps}")
    if P.log_method in ('step', 'patch'):
        by_step = P.log_method == 'step'
        for i in range(P.inner_step):
            for j in range(P.inner_step):
                # 'step' groups curves by inner step; 'patch' transposes the grid.
                if by_step:
                    cell = loss_in_log[i][j]
                    loss_tag = f'train_loss_in_step{i:02}/loss_patch{j:02}'
                    psnr_tag = f'train_psnr_in_step{i:02}/psnr_patch{j:02}'
                else:
                    cell = loss_in_log[j][i]
                    loss_tag = f'train_loss_in_patch{i:02}/loss_step{j:02}'
                    psnr_tag = f'train_psnr_in_patch{i:02}/psnr_step{j:02}'
                logger.writer.add_scalar(loss_tag, cell.mean().item(), steps)
                logger.writer.add_scalar(psnr_tag, psnr(cell).mean().item(), steps)
            if i > 0:
                # Both modes log the same aggregate: step i averaged over entries [0, i).
                seen = loss_in_log[i][:i]
                logger.writer.add_scalar(f'train_loss_in_step{i:02}', seen.mean().item(), steps)
                logger.writer.add_scalar(f'train_psnr_in_step{i:02}', psnr(seen).mean().item(), steps)

    logger.scalar_summary('train/inner_lr', inner_lr.item(), steps)
    logger.scalar_summary('train/loss_in', loss_in_val, steps)
    logger.scalar_summary('train/loss_out', loss_out_val, steps)
    logger.scalar_summary('train/psnr_in', psnr_in_val, steps)
    logger.scalar_summary('train/psnr_out', psnr_out_val, steps)
    logger.scalar_summary('train/batch_time', metric_logger.batch_time.value, steps)
    if P.data_type == 'img':
        logger.image_summary('train/imgs_in', res_in, metric_logger.train_context, steps)
        logger.image_summary('train/imgs_out', res_out, metric_logger.train_context, steps)

    logger.log('[TRAIN] [Step %3d] [Time %.3f] [Data %.3f] '
               '[LossIn %f] [LossOut %f] [PSNRIn %.3f] [PSNROut %.3f]' %
               (steps, metric_logger.batch_time.global_avg, metric_logger.data_time.global_avg,
                loss_in_val, loss_out_val, psnr_in_val, psnr_out_val))

    metric_logger.reset()


def train_step_video(P, steps, wrapper, optimizer, task_data, metric_logger, logger, inner_lr):
    """Run one meta-training step on a video batch and log its statistics.

    The inner loop (``inner_adapt``) adapts the wrapper's parameters on
    ``context[0]``. The adapted parameters are then re-evaluated on the same
    context to form the outer loss (plain batch mean). That loss is
    back-propagated, the decoder gradients are clipped to norm 1.0, and
    ``optimizer`` takes a step.

    Args:
        P: Run configuration. Reads ``inner_step``, ``inner_iter``, ``mode``,
            ``order``, ``print_step``, ``log_method`` and ``data_type``;
            writes ``trained_inner_lr``.
        steps: Global training step, used for the print cadence and log tags.
        wrapper: Model wrapper exposing ``train()``, a ``support`` flag, a
            ``decoder`` submodule, and ``wrapper(x, params=...)`` returning a
            ``(loss, reconstruction)`` pair.
        optimizer: Outer-loop optimizer.
        task_data: Raw batch handed to ``get_meta_batch``.
        metric_logger: Meter container. Only ``batch_time`` and
            ``train_context`` are updated here; it is reset after each print.
        logger: Logger providing ``writer.add_scalar``, ``scalar_summary``,
            ``video_summary``, ``log_dirname`` and ``log``.
        inner_lr: Inner-loop learning rate. Read via ``.item()``, so it must
            be a single-element tensor.
    """
    stime = time.perf_counter()  # monotonic clock for measuring a duration
    wrapper.train()

    batch_size, context = get_meta_batch(P, task_data)

    # --- inner loop: adapt on the support context ---
    # Shapes per the original author's note: loss_in (step * iter, b, c, h, w),
    # loss_in_log (step, t, b, c, h, w), res_in (step, b, c, h, w).
    wrapper.support = True
    params, loss_in, loss_in_log, res_in, _ = inner_adapt(
        P, wrapper, context[0], inner_lr, P.inner_step,
        first_order=P.mode == 'fomaml', order=P.order)
    loss_in = loss_in.view(P.inner_step, P.inner_iter, batch_size, -1).mean(dim=-1)  # (step, iter, b)
    loss_in = loss_in[-1][-1]  # (b): last iteration of the last inner step
    # The second axis is reshaped to inner_step entries (logged as "patches" below).
    loss_in_log = loss_in_log.view(P.inner_step, P.inner_step, batch_size, -1).mean(dim=-1)  # (step, t, b)

    # --- outer loss and meta-update ---
    wrapper.support = False
    loss_out, res_out = wrapper(context[0], params=params)
    loss_out = loss_out.view(batch_size, -1).mean(dim=-1)  # (b)
    loss = loss_out.mean()
    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(wrapper.decoder.parameters(), 1.0)
    optimizer.step()
    # Drain queued GPU work so batch_time reflects real compute. This is
    # guarded because torch.cuda.synchronize() raises when CUDA is
    # unavailable, which would break the CPU fallback chosen for ``device``.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    P.trained_inner_lr = inner_lr.item()

    # --- track stats ---
    metric_logger.meters['batch_time'].update(time.perf_counter() - stime, n=batch_size)
    metric_logger.meters['train_context'].update(context[0], n=batch_size)
    metric_logger.synchronize_between_processes()

    if steps % P.print_step != 0:
        return

    logger.log_dirname(f"Step {steps}")
    if P.log_method in ('step', 'patch'):
        by_step = P.log_method == 'step'
        for i in range(P.inner_step):
            for j in range(P.inner_step):
                # 'step' groups curves by inner step; 'patch' transposes the grid.
                if by_step:
                    cell = loss_in_log[i][j]
                    loss_tag = f'train_loss_in_step{i:02}/loss_patch{j:02}'
                    psnr_tag = f'train_psnr_in_step{i:02}/psnr_patch{j:02}'
                else:
                    cell = loss_in_log[j][i]
                    loss_tag = f'train_loss_in_patch{i:02}/loss_step{j:02}'
                    psnr_tag = f'train_psnr_in_patch{i:02}/psnr_step{j:02}'
                logger.writer.add_scalar(loss_tag, cell.mean().item(), steps)
                logger.writer.add_scalar(psnr_tag, psnr(cell).mean().item(), steps)
            if i > 0:
                # Both modes log the same aggregate: step i averaged over entries [0, i).
                seen = loss_in_log[i][:i]
                logger.writer.add_scalar(f'train_loss_in_step{i:02}', seen.mean().item(), steps)
                logger.writer.add_scalar(f'train_psnr_in_step{i:02}', psnr(seen).mean().item(), steps)

    # Reduce each statistic once and reuse it for the summaries and the log line.
    loss_in_val = loss_in.mean().item()
    loss_out_val = loss_out.mean().item()
    psnr_in_val = psnr(loss_in).mean().item()
    psnr_out_val = psnr(loss_out).mean().item()

    logger.scalar_summary('train/inner_lr', inner_lr.item(), steps)
    logger.scalar_summary('train/loss_in', loss_in_val, steps)
    logger.scalar_summary('train/loss_out', loss_out_val, steps)
    logger.scalar_summary('train/psnr_in', psnr_in_val, steps)
    logger.scalar_summary('train/psnr_out', psnr_out_val, steps)
    logger.scalar_summary('train/batch_time', metric_logger.batch_time.value, steps)
    if P.data_type == 'video':
        logger.video_summary('train/vids_in', res_in, metric_logger.train_context, steps)
        logger.video_summary('train/vids_out', res_out, metric_logger.train_context, steps)

    logger.log('[TRAIN] [Step %3d] [Time %.3f] [Data %.3f] '
               '[LossIn %f] [LossOut %f] [PSNRIn %.3f] [PSNROut %.3f]' %
               (steps, metric_logger.batch_time.global_avg, metric_logger.data_time.global_avg,
                loss_in_val, loss_out_val, psnr_in_val, psnr_out_val))

    metric_logger.reset()
