from models import utils

import torch
from torchvision import utils as vutils

def train_one_epoch(epoch, model, loader, optimizer, scheduler, use_amp=False, fp16_scaler=None):
    """Train ``model`` for one pass over ``loader``.

    Each iteration moves the batch's 'image' and 'index' tensors to the GPU,
    runs ``model(image, index)``, back-propagates ``outputs["loss"]``, steps the
    optimizer and then steps ``scheduler``. The scheduler is therefore stepped
    once per iteration, not once per epoch.

    Args:
        epoch: Epoch number; used only in the log header.
        model: Module whose output dict contains 'loss', 'mse' and 'replay_loss'.
        loader: Sized iterable of dicts with 'image' and 'index' tensors.
        optimizer: Optimizer over ``model``'s parameters.
        scheduler: LR scheduler, stepped after every optimizer step.
        use_amp: If True, the forward pass runs under CUDA autocast.
        fp16_scaler: Optional GradScaler-like object (``scale``/``step``/``update``).
            When given, the loss is scaled for backward and the optimizer step
            goes through the scaler.

    Returns:
        Dict mapping each logged meter name to its ``global_avg``, read after
        ``metric_logger.synchronize_between_processes()``.
    """
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    metric_logger.add_meter('mse', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    metric_logger.add_meter('rpy', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    # Guard against 0 for loaders with fewer than 2 batches
    # (log_every presumably uses the interval as a modulus).
    log_interval = max(1, len(loader) // 2)

    model.train()
    for inputs in metric_logger.log_every(loader, log_interval, header):
        image = inputs['image'].cuda()
        index = inputs['index'].cuda()

        optimizer.zero_grad()
        # No-op when use_amp is False; backward runs outside autocast, as recommended.
        with torch.autocast(device_type="cuda", enabled=use_amp):
            outputs = model(image, index)
        loss = outputs["loss"]

        if fp16_scaler is None:
            loss.backward()
            optimizer.step()
        else:
            fp16_scaler.scale(loss).backward()
            fp16_scaler.step(optimizer)
            fp16_scaler.update()

        # Log plain Python numbers only.
        metric_logger.update(loss=loss.item())
        metric_logger.update(mse=outputs["mse"].item())
        # float() accepts both a Python number and a 0-d tensor. Passing the raw
        # tensor could keep it (and its autograd graph) alive inside the meter.
        # TODO confirm against utils.SmoothedValue.
        metric_logger.update(rpy=float(outputs["replay_loss"]))
        # The LR actually used for this step (read before scheduler.step()).
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])

        scheduler.step()

    # Gather the stats from all processes.
    metric_logger.synchronize_between_processes()
    if utils.is_main_process():
        print(">>> Train Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}


def valid_one_epoch(epoch, model, loader):
    """Evaluate ``model`` over ``loader`` without gradients.

    Each batch's 'image' tensor is moved to the GPU and passed to
    ``model(image)``. The 'loss' and 'mse' entries of the output dict are logged.

    Args:
        epoch: Unused here; kept for signature parity with ``train_one_epoch``.
        model: Module whose output dict contains 'loss' and 'mse'. It is put in
            eval mode and left there.
        loader: Sized iterable of dicts with an 'image' tensor.

    Returns:
        Dict mapping each logged meter name to its ``global_avg``, read after
        ``metric_logger.synchronize_between_processes()``.
    """
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter('mse', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Test:'
    # Guard against 0 for loaders with fewer than 2 batches
    # (log_every presumably uses the interval as a modulus).
    log_interval = max(1, len(loader) // 2)

    model.eval()
    for inputs in metric_logger.log_every(loader, log_interval, header):
        image = inputs['image'].cuda()
        with torch.no_grad():
            outputs = model(image)
            loss = outputs["loss"]

        metric_logger.update(loss=loss.item())
        metric_logger.update(mse=outputs["mse"].item())

    metric_logger.synchronize_between_processes()

    if utils.is_main_process():
        print(">>> Test Averaged stats:", metric_logger)

    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}


def make_recon_img(slot, mask):
    """Compose one image per sample as the mask-weighted sum of its slot images.

    Args:
        slot (Tensor): Per-slot images, shape (B, S, C, H, W).
        mask (Tensor): Per-slot weights, shape (B, S, 1, H, W), broadcast over the
            channel axis. They are meant to sum to 1 over the slot axis, but
            that is not checked.

    Returns:
        Tensor of shape (B, C, H, W) holding ``sum over s of slot[:, s] * mask[:, s]``.

    Raises:
        AssertionError: When assertions are enabled and ``mask`` does not have
            shape (B, S, 1, H, W) matching ``slot``.
    """
    n_batch, n_slots, _, height, width = slot.shape
    expected_mask_shape = (n_batch, n_slots, 1, height, width)
    assert mask.shape == expected_mask_shape
    weighted = slot * mask
    return weighted.sum(dim=1)


def sample_images(model, loader, batch_size, n_samples, perm):
    """Render a grid showing the model's slot decomposition on a few images.

    Takes the first ``n_samples`` images of the first batch from ``loader`` and
    runs the model on them without gradients. The result has one row per image:
    the input, the combined reconstruction, then each slot composited by its
    mask over a constant background of value 1.

    Args:
        model: Module whose output dict contains 'reconstruction', 'mask' and
            'slot'. It is put in eval mode and left there.
        loader: Iterable of dicts with an 'image' tensor.
        batch_size: Unused; kept for backward compatibility. The grid size comes
            from the number of images actually sampled.
        n_samples: Maximum number of images to visualise.
        perm: Unused; kept for backward compatibility.

    Returns:
        CPU image-grid tensor produced by ``torchvision.utils.make_grid``.
    """
    model.eval()
    image = next(iter(loader))['image'][:n_samples].cuda()
    with torch.no_grad():
        outputs = model(image)
        recon_combined = outputs['reconstruction']
        masks = outputs['mask']
        slots = outputs['slot']

    # Rescale every displayed tensor the same way (x * 2 - 1). The inputs are
    # presumably in [0, 1], so this maps them to [-1, 1]. The reconstruction is
    # a mask-weighted mix of the slots, so it must be rescaled too; otherwise
    # its panel would sit in a different range from the others.
    # NOTE(review): confirm the model's reconstruction shares the input range.
    image = image * 2.0 - 1.0
    slots = slots * 2.0 - 1.0
    recon_combined = recon_combined * 2.0 - 1.0

    panels = torch.cat(
        [
            image.unsqueeze(1),  # original images
            recon_combined.unsqueeze(1),  # combined reconstructions
            slots * masks + (1 - masks),  # each slot over a background of value 1
        ],
        dim=1,
    )
    out = utils.to_rgb_from_tensor(panels)

    n_images, _, channels, height, width = slots.shape
    n_panels = out.shape[1]
    # reshape (not view) in case to_rgb_from_tensor returns a non-contiguous tensor.
    grid = vutils.make_grid(
        out.reshape(n_images * n_panels, channels, height, width).cpu(),
        normalize=False,
        nrow=n_panels,
    )
    return grid
