from models import utils

import torch
from torchvision import utils as vutils

def train_one_epoch(epoch, model, loader, optimizer, scheduler, use_amp=False, fp16_scaler=None):
    """Run one training epoch over ``loader`` and return averaged metrics.

    For every batch the model is called on ``inputs['image']`` (moved to CUDA),
    the ``"loss"`` entry of its output dict is back-propagated, the optimizer
    is stepped, and ``scheduler.step()`` is called once per batch.

    Args:
        epoch: epoch index, used only in the log header.
        model: module returning a dict with a scalar ``"loss"`` tensor.
        loader: iterable of dicts with an ``'image'`` tensor.
        optimizer: optimizer whose ``param_groups[0]["lr"]`` is logged.
        scheduler: scheduler stepped after every batch.
        use_amp: run the forward pass under CUDA autocast when True.
        fp16_scaler: optional gradient scaler (e.g. ``torch.amp.GradScaler``);
            when given, backward/step go through it.

    Returns:
        Dict mapping each meter name (``'lr'``, ``'mse'``, ``'loss'``) to its
        ``global_avg`` after ``synchronize_between_processes``.
    """
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    metric_logger.add_meter('mse', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    # Log about twice per epoch; clamp to 1 so a single-batch loader does not
    # produce a zero interval (log_every presumably uses it as a modulus).
    log_interval = max(1, len(loader) // 2)

    model.train()
    for inputs in metric_logger.log_every(loader, log_interval, header):
        image = inputs['image'].cuda()

        optimizer.zero_grad()
        # With enabled=False this context is a no-op (plain fp32 forward).
        with torch.autocast(device_type="cuda", enabled=use_amp):
            outputs = model(image)
            loss = outputs["loss"]

        if fp16_scaler is None:
            loss.backward()
            optimizer.step()
        else:
            # Scale the loss to avoid fp16 gradient underflow; the scaler
            # unscales before stepping and adapts its scale factor.
            fp16_scaler.scale(loss).backward()
            fp16_scaler.step(optimizer)
            fp16_scaler.update()

        # The 'mse' meter mirrors the total loss.
        loss_value = loss.item()
        metric_logger.update(loss=loss_value)
        metric_logger.update(mse=loss_value)
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])

        # Per-iteration (not per-epoch) learning-rate schedule.
        scheduler.step()

    # Gather the stats from all processes.
    metric_logger.synchronize_between_processes()
    if utils.is_main_process():
        print(">>> Train Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}


def valid_one_epoch(epoch, model, loader):
    """Evaluate ``model`` on ``loader`` without gradients and return averaged metrics.

    For every batch the model is called on ``inputs['image']`` (moved to CUDA)
    and the ``"loss"`` entry of its output dict is recorded under both the
    ``'loss'`` and ``'mse'`` meters.

    Args:
        epoch: unused; kept for signature symmetry with ``train_one_epoch``.
        model: module returning a dict with a scalar ``"loss"`` tensor.
        loader: iterable of dicts with an ``'image'`` tensor.

    Returns:
        Dict mapping each meter name to its ``global_avg`` after
        ``synchronize_between_processes``.
    """
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter('mse', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Test:'
    # Clamp to 1 so a single-batch loader does not yield a zero log interval.
    log_interval = max(1, len(loader) // 2)

    model.eval()
    with torch.no_grad():
        for inputs in metric_logger.log_every(loader, log_interval, header):
            image = inputs['image'].cuda()
            loss = model(image)["loss"]

            loss_value = loss.item()
            metric_logger.update(loss=loss_value)
            metric_logger.update(mse=loss_value)

    # Gather the stats from all processes.
    metric_logger.synchronize_between_processes()

    if utils.is_main_process():
        print(">>> Test Averaged stats:", metric_logger)

    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}


def sample_images(model, loader, batch_size, n_samples, perm):
    """Build an image grid of inputs, reconstructions and per-slot outputs.

    Takes the first ``n_samples`` images of the first batch from ``loader``,
    runs ``model`` on them in eval mode without gradients, and lays out one
    grid row per image with these columns, in order:

    1. the input;
    2. ``'reconstruction_root'``, if the model returns it;
    3. the combined reconstruction;
    4. one tile per slot.

    Args:
        model: module returning a dict with ``'reconstruction'``,
            ``'slot_reconstruction'`` and ``'mask'`` (and optionally
            ``'reconstruction_root'``).
        loader: iterable yielding dicts with an ``'image'`` tensor.
        batch_size: unused; the row count is taken from the model output.
        n_samples: number of images to visualise.
        perm: unused. NOTE(review): an earlier version computed
            ``perm[:n_samples]`` but never applied it. If random sample
            selection was intended, index the batch with it.

    Returns:
        The grid tensor produced by ``torchvision.utils.make_grid``.
    """
    model.eval()
    image = next(iter(loader))['image'][:n_samples].cuda()
    with torch.no_grad():
        outputs = model(image)
    recon_combined = outputs['reconstruction']
    recons = outputs['slot_reconstruction']
    masks = outputs['mask']

    # NOTE(review): presumably rescales inputs from [0, 1] to [-1, 1] to match
    # the model outputs before to_rgb_from_tensor maps back — confirm ranges.
    image = image * 2.0 - 1.0

    columns = [image.unsqueeze(1)]  # original images
    if 'reconstruction_root' in outputs:
        columns.append(outputs['reconstruction_root'].unsqueeze(1))
    columns.append(recon_combined.unsqueeze(1))  # combined reconstructions
    # Each slot's reconstruction with (1 - mask) added outside its region;
    # assumes masks broadcast against recons (e.g. (B, K, 1, H, W)) — TODO confirm.
    columns.append(recons + (1 - masks))
    out = utils.to_rgb_from_tensor(torch.cat(columns, dim=1))

    n_images, _, channels, height, width = recons.shape
    n_columns = out.shape[1]
    # reshape (not view) in case to_rgb_from_tensor returns a non-contiguous tensor.
    grid = vutils.make_grid(
        out.reshape(n_images * n_columns, channels, height, width).cpu(),
        normalize=False,
        nrow=n_columns,
    )
    return grid

