from distutils.command.config import config
import torch as th
from torch import nn
from torch.utils.data import Dataset, DataLoader
from pytorch_msssim import ms_ssim, ssim
import cv2

from torch.nn.parallel import DistributedDataParallel
import numpy as np
import os
from typing import Tuple, Union, List
from einops import rearrange, repeat, reduce

from utils.configuration import Configuration
from utils.parallel import run_parallel, DatasetPartition
from utils.scheduled_sampling import ExponentialSampler
from utils.io import model_path
from utils.optimizers import SDAdam, SDAMSGrad, RAdam
from model.loci import Loci
import torch.distributed as dist
import time
import random
import nn as nn_modules
from utils.io import Timer, BinaryStatistics, UEMA
from utils.data import DeviceSideDataset
from utils.loss import SSIMLoss, MSSSIMLoss
from nn.latent_classifier import CaterLocalizer, CaterObjectBehavior
from nn.tracker import TrackingLoss, L1GridDistance, L2TrackingDistance
from einops import rearrange, repeat, reduce

from evaluation.plot_utils import color_mask
import lpips
from skimage.metrics import structural_similarity as ssimloss
from skimage.metrics import peak_signal_noise_ratio as psnrloss

def save_model(
    file,
    net, 
    optimizer_init, 
    optimizer_encoder, 
    optimizer_decoder, 
    optimizer_predictor,
    optimizer_background
):
    """Write ``net`` and its five optimizers into one checkpoint via ``th.save``.

    The checkpoint is a dict with one entry per optimizer (key = the parameter
    name, value = ``state_dict()``) plus ``"model"`` holding ``net.state_dict()``.
    ``file`` is anything ``th.save`` accepts (path or writable file object).
    """
    optimizers = {
        'optimizer_init':       optimizer_init,
        'optimizer_encoder':    optimizer_encoder,
        'optimizer_decoder':    optimizer_decoder,
        'optimizer_predictor':  optimizer_predictor,
        'optimizer_background': optimizer_background,
    }
    checkpoint = {name: opt.state_dict() for name, opt in optimizers.items()}
    checkpoint["model"] = net.state_dict()
    th.save(checkpoint, file)

def load_model(
    file,
    cfg,
    net, 
    optimizer_init, 
    optimizer_encoder, 
    optimizer_decoder, 
    optimizer_predictor,
    optimizer_background,
    load_optimizers = True,
    load_full_model = True,
):
    """Restore a checkpoint written by ``save_model`` into ``net`` (and optionally its optimizers).

    Args:
        file: path or file object accepted by ``th.load``.
        cfg: configuration; ``cfg.device`` is the map location. When optimizers are
            restored, their learning rates are overwritten with ``cfg.learning_rate``
            (background: ``cfg.model.background.learning_rate``).
        net: model to load the weights into.
        optimizer_init, optimizer_encoder, optimizer_decoder, optimizer_predictor,
        optimizer_background: optimizers restored when ``load_optimizers`` is set
            (may be ``None`` otherwise).
        load_optimizers: restore optimizer states and reset their learning rates.
        load_full_model: load every entry of the model state; if False only the
            top-level submodules whose name contains 'encoder' or 'decoder' are loaded.

    Keys missing from the checkpoint keep the values currently held by ``net``.
    """
    device = th.device(cfg.device)
    state = th.load(file, map_location=device)
    print(f"load {file} to device {device}")
    print(f"load optimizers: {load_optimizers}")

    if load_optimizers:
        # NOTE(review): train_eprop creates optimizer_init with 30x cfg.learning_rate;
        # resetting it to 1x here changes that ratio on resume — confirm intent.
        restored = (
            (optimizer_init,       'optimizer_init',       cfg.learning_rate),
            (optimizer_encoder,    'optimizer_encoder',    cfg.learning_rate),
            (optimizer_decoder,    'optimizer_decoder',    cfg.learning_rate),
            (optimizer_predictor,  'optimizer_predictor',  cfg.learning_rate),
            (optimizer_background, 'optimizer_background', cfg.model.background.learning_rate),
        )
        for optimizer, key, learning_rate in restored:
            optimizer.load_state_dict(state[key])
            for group in optimizer.param_groups:
                group['lr'] = learning_rate

    # Backward compatibility: older checkpoints store keys with '.module.' in them.
    # Normalise FIRST and only then fill keys absent from the checkpoint with the
    # net's current values; filling before renaming would let the freshly
    # initialised values overwrite the renamed checkpoint weights.
    model = {key.replace(".module.", "."): value for key, value in state["model"].items()}
    for key, value in net.state_dict().items():
        model.setdefault(key, value)

    if load_full_model:
        net.load_state_dict(model)
        return

    # group entries by top-level submodule name, keeping the rest of the key
    grouped = {}
    for key, value in model.items():
        module, _, sub_key = key.partition(".")
        grouped.setdefault(module, {})[sub_key] = value

    for module, module_state in grouped.items():
        if ('encoder' in module) or ('decoder' in module):
            net._modules[module].load_state_dict(module_state)

def run(cfg: Configuration, num_gpus: int, trainset: Dataset, valset: Dataset, testset: Dataset, file, active_layer):
    """Start training: in-process for exactly one GPU, otherwise via ``run_parallel``.

    With one GPU ``train_eprop`` is called directly with ``cfg.device`` as its rank
    and ``world_size = -1``; for any other ``num_gpus`` the worker function and
    ``num_gpus`` are handed to ``run_parallel`` together with the same arguments.
    """
    print("run training", flush=True)
    worker_args = (cfg, trainset, valset, testset, file, active_layer)
    if num_gpus != 1:
        run_parallel(train_eprop, num_gpus, *worker_args)
        return
    train_eprop(cfg.device, -1, *worker_args)

def train_eprop(rank: int, world_size: int, cfg: Configuration, trainset: Dataset, valset: Dataset, testset: Dataset, file, active_layer):
    """Train a ``Loci`` model in this process (single device or one DDP rank).

    Args:
        rank: CUDA device index of this process. ``run`` passes ``cfg.device`` here
            for single-process training (``world_size < 0``); it is replaced by 0
            after the device has been selected.
        world_size: number of DDP processes, or a negative value for
            single-process training.
        cfg: full configuration; ``cfg.model`` configures the network.
        trainset: training data; for ``cfg.datatype == 'cater'`` its ``cam`` and
            ``z`` attributes parameterize the network and the tracking losses.
        valset: unused here.
        testset: evaluated at the start of every epoch after the first
            ('cater' and 'adept' datatypes).
        file: checkpoint to resume from, or ``""`` to start from scratch.
        active_layer: unused here.

    Only rank 0 creates the output directory, writes plots and saves
    checkpoints (other ranks have no output ``path``). Returns once more than
    ``cfg.updates`` optimizer updates have been made, or after all epochs.
    """

    print(f'rank {rank} online', flush=True)
    if th.cuda.is_available():
        device = th.device(rank)
        verbose = False
    else:
        device = th.device("cpu")
        verbose = True
        cfg.device = "cpu"
        cfg.model.device = "cpu"
        cfg.model.batch_size = 2
        cfg.teacher_forcing = 4
        print('!!! USING CPU !!!')

    if world_size < 0:
        rank = 0

    # only rank 0 owns the output directory; every file write below is guarded by rank == 0
    path = None
    if rank == 0:
        path = model_path(cfg, overwrite=False)
        cfg.save(path)
        os.makedirs(os.path.join(path, 'nets'), exist_ok=True)

    cfg_net = cfg.model

    # each DDP rank trains on its own shard; ``trainset`` itself is still used
    # below for its camera attributes (cam, z)
    train_data = trainset
    if world_size > 0:
        train_data = DatasetPartition(trainset, rank, world_size)

    net = Loci(
        cfg                = cfg_net,
        camera_view_matrix = trainset.cam if cfg.datatype == 'cater' else None,
        zero_elevation     = trainset.z if cfg.datatype == 'cater' else None,
        teacher_forcing    = cfg.teacher_forcing
    )
    net = net.to(device=device)

    l1distance = L1GridDistance(trainset.cam, trainset.z).to(device) if cfg.datatype == 'cater' else None
    l2distance = L2TrackingDistance(trainset.cam, trainset.z).to(device) if cfg.datatype == 'cater' else None
    l2loss     = TrackingLoss(trainset.cam, trainset.z).to(device) if cfg.datatype == 'cater' else None

    if rank == 0:
        print(f'Loaded model with {sum([param.numel() for param in net.parameters()]):7d} parameters', flush=True)
        print(f'  States:     {sum([param.numel() for param in net.initial_states.parameters()]):7d} parameters', flush=True)
        print(f'  Encoder:    {sum([param.numel() for param in net.encoder.parameters()]):7d} parameters', flush=True)
        print(f'  Decoder:    {sum([param.numel() for param in net.decoder.parameters()]):7d} parameters', flush=True)
        print(f'  predictor:  {sum([param.numel() for param in net.predictor.parameters()]):7d} parameters', flush=True)
        print(f'  background: {sum([param.numel() for param in net.background.parameters()]):7d} parameters', flush=True)
        print("\n")
    
    bceloss    = nn.BCELoss()
    mseloss    = nn.MSELoss()

    optimizer_init = RAdam(net.initial_states.parameters(), lr = cfg.learning_rate * 30)
    optimizer_encoder = RAdam(net.encoder.parameters(), lr = cfg.learning_rate)
    optimizer_decoder = RAdam(net.decoder.parameters(), lr = cfg.learning_rate)
    optimizer_predictor = RAdam(net.predictor.parameters(), lr = cfg.learning_rate)
    optimizer_background = RAdam([net.background.mask], lr = cfg.learning_rate) # only train mask
    # NOTE(review): optimizer_update is neither saved nor restored by save_model/load_model
    optimizer_update = RAdam(net.update_module.parameters(), lr = cfg.learning_rate)

    if file != "":
        load_model(
            file,
            cfg,
            net,
            optimizer_init, 
            optimizer_encoder, 
            optimizer_decoder, 
            optimizer_predictor,
            optimizer_background,
            cfg.load_optimizers
        )
        print(f'loaded[{rank}] {file}', flush=True)

    trainloader = DataLoader(
        train_data, 
        pin_memory = True, 
        num_workers = cfg.num_workers, 
        batch_size = cfg_net.batch_size, 
        shuffle = True,
        drop_last = True, 
        prefetch_factor = cfg.prefetch_factor, 
        persistent_workers = True
    )

    if cfg.datatype == 'adept':
        testset.train = True
    testloader = DataLoader(
        testset, 
        pin_memory = True, 
        num_workers = cfg.num_workers, 
        batch_size = cfg_net.batch_size, 
        shuffle = False,
        drop_last = True, 
        prefetch_factor = cfg.prefetch_factor, 
        persistent_workers = True,
    )

    net_parallel = None
    if world_size > 0:
        net_parallel = DistributedDataParallel(net, device_ids=[rank], find_unused_parameters=True)
    else:
        net_parallel = net
    
    if rank == 0:
        save_model(
            os.path.join(path, 'nets', 'net0.pt'),
            net, 
            optimizer_init, 
            optimizer_encoder, 
            optimizer_decoder, 
            optimizer_predictor,
            optimizer_background
        )

    num_updates = cfg.num_updates
    if num_updates > 0:
        print('!!! Start training at num_updates: ', num_updates)
        print('!!! Net init status: ', net.get_init_status())

    num_time_steps = 0
    closed_loop_steps = cfg.closed_loop_steps
    increase_closed_loop_steps = False
    background_blendin_factor = 0.0

    avgloss                  = UEMA()
    avg_position_loss        = UEMA()
    avg_time_loss            = UEMA()
    avg_object_loss          = UEMA()
    avg_encoder_loss         = UEMA()
    avg_mse_object_loss      = UEMA()
    avg_long_mse_object_loss = UEMA(33333) #100 * (cfg.sequence_len - 1 + cfg.teacher_forcing))
    avg_num_objects          = UEMA()
    avg_openings             = UEMA()
    avg_l1_distance          = UEMA()
    avg_l2_distance          = UEMA()
    avg_top1_accuracy        = UEMA()
    avg_tracking_loss        = UEMA()
    avg_gestalt              = UEMA()
    avg_gestalt2             = UEMA()
    avg_gestalt_mean         = UEMA()
    avg_update_gestalt       = UEMA()
    avg_update_position      = UEMA()

    th.backends.cudnn.benchmark = True
    timer = Timer()

    # when resuming, catch up on the phase transitions that the in-loop checks
    # below (which trigger on equality) would otherwise never see
    if num_updates >= cfg.background_pretraining_end and net.get_init_status() < 1:
        net.inc_init_level()

    if num_updates >= cfg.entity_pretraining_phase1_end and net.get_init_status() < 2:
        net.inc_init_level()

    if num_updates >= cfg.entity_pretraining_phase2_end and net.get_init_status() < 3:
        net.inc_init_level()
        for param in optimizer_init.param_groups:
            param['lr'] = cfg.learning_rate

    # '>=' so that resuming exactly at start_closed_prediction (or starting with it
    # at 0) still enables closed prediction; the in-loop check only fires after an update
    if num_updates >= cfg.start_closed_prediction:
        net.cfg.closed_prediction = True

    if num_updates >= cfg.entity_pretraining_phase1_end:
        background_blendin_factor = max(min((num_updates - cfg.entity_pretraining_phase1_end)/30000, 1.0), 0.0)

    print('Start training')
    for epoch in range(cfg.epochs):
        if epoch > 0:
            if cfg.datatype == 'cater':
                cater_evaluation(net_parallel, 'Test', testset, testloader, device, cfg, epoch)
            elif cfg.datatype == 'adept':
                adept_evaluation(testloader, net, cfg, device)
            #elif cfg.datatype == 'clevrer':
            #    clevrer_evaluation(testloader, net, cfg, device)

        print('Start epoch:', epoch)

        # grow the closed-loop horizon by one, capped at closed_loop_steps_max
        if increase_closed_loop_steps:
            closed_loop_steps = min(closed_loop_steps + 1, cfg.closed_loop_steps_max)
            print('Increase closed loop steps to', closed_loop_steps)
        increase_closed_loop_steps = False

        for batch_index, input in enumerate(trainloader):

            # input and background 
            tensor          = input[0]
            background      = input[1].to(device)
            shuffleslots    = (num_updates <= cfg.shufleslots_end)
            target_position = input[2].float().to(device) if cfg.datatype == 'cater' else None
            target_label    = input[3].to(device) if cfg.datatype == 'cater' else None

            # recurrent state passed back into the network every step
            position    = None
            gestalt     = None
            priority    = None
            mask        = None
            maskraw     = None
            summed_loss = None
            slots_occlusionfactor = None

            # apply skip frames
            selec = range(random.randrange(cfg.skip_frames), tensor.shape[1], cfg.skip_frames)
            tensor = tensor[:,selec]
            sequence_len = tensor.shape[1]

            # keep the per-frame cater targets aligned with the subsampled frames
            # (assumes target_position has one entry per frame, as the [:, t+1] indexing implies)
            frame_position = target_position[:,selec] if cfg.datatype == 'cater' else None

            # initial frame
            input      = tensor[:,0].to(device)

            # initial frame as target --> teacher forcing
            input_next = input
            target     = th.clip(input, 0, 1).detach()
            pos_target = frame_position[:,0] if cfg.datatype == 'cater' else None
            error_last = None

            # first apply teacher forcing for the first x frames
            for t in range(-cfg.teacher_forcing, sequence_len-1):

                # set update scheme
                if t >= cfg.closed_loop_start:
                    t_run = (t - cfg.closed_loop_start)
                    closed_loop    = t_run % closed_loop_steps != 0
                    run_optimizers = t_run % closed_loop_steps == closed_loop_steps - 1
                    detach         = (t_run % closed_loop_steps == 0) or t == -cfg.teacher_forcing
                else:
                    closed_loop    = False
                    run_optimizers = True
                    detach         = True

                if verbose:
                    print(f't: {t}, closed_loop: {closed_loop}, run_optimizers: {run_optimizers}, detach: {detach}')

                # unfold the video sequence --> next frame
                if t >= 0:

                    # try to predict next frame
                    num_time_steps += 1
                    input      = input_next
                    input_next = tensor[:,t+1].to(device)
                    target     = th.clip(input_next, 0, 1)
                    pos_target = frame_position[:,t+1] if cfg.datatype == 'cater' else None

                    # error dropout (error_last is still None at t == 0 without teacher forcing)
                    if error_last is not None and net.get_init_status() > 2 and cfg.error_dropout > 0 and np.random.rand() < cfg.error_dropout:
                        error_last = th.zeros_like(error_last)

                    if net.cfg.closed_prediction and cfg.datatype == 'clevrer':
                        if t >= 10:
                            blackout    = th.tensor((np.random.rand(cfg_net.batch_size) < 0.2)[:,None,None,None]).float().to(device)
                            input       = blackout * (input * 0)         + (1-blackout) * input
                            error_last  = blackout * (error_last * 0)    + (1-blackout) * error_last
                        
                # object learning
                (
                    output_next, 
                    output_cur,
                    position, 
                    gestalt, 
                    priority, 
                    mask, 
                    maskraw,
                    object, 
                    background, 
                    slots_occlusionfactor,
                    position_loss,
                    object_loss,
                    time_loss,
                    slots_closed
                    ) = net_parallel(
                    input,      # current frame
                    error_last, # error of last frame --> missing object
                    mask,       # masks of current frame
                    maskraw,    # raw masks of current frame
                    position,   # positions of objects of next frame
                    gestalt,    # gestalt of objects of next frame
                    priority,   # priority of objects of next frame
                    background,
                    slots_occlusionfactor,
                    reset = (t == -cfg.teacher_forcing), # new sequence
                    warmup = (t < 0),                    # teacher forcing
                    detach = detach,
                    shuffleslots = True, #(t <= 0) or shuffleslots, # TODO
                    reset_mask = (t <= 0),
                    clean_slots = (t <= 0 and not shuffleslots),
                )

                snitch_position = net.predictor.get_snitch_position() if cfg.datatype == 'cater' and t >= 0 else None

                # losses
                position_loss = position_loss * cfg_net.position_regularizer
                object_loss   = object_loss   * cfg_net.object_regularizer
                time_loss     = time_loss     * cfg_net.time_regularizer
                tracking_loss = l2loss(snitch_position, pos_target) * cfg_net.supervision_factor if cfg.datatype == 'cater' and t >= 0 else 0

                # background error
                bg_error_cur  = th.sqrt(reduce((input - background)**2, 'b c h w -> b 1 h w', 'mean')).detach()
                bg_error_next = th.sqrt(reduce((target - background)**2, 'b c h w -> b 1 h w', 'mean')).detach()

                # prediction error, fed back into the network at the next step
                error_next    = th.sqrt(reduce((target - output_next)**2, 'b c h w -> b 1 h w', 'mean')).detach()
                error_next    = th.sqrt(error_next) * bg_error_next
                error_last    = error_next

                # detach as no gradient flow needed
                if run_optimizers:
                    position = position.detach()
                    gestalt  = gestalt.detach()
                    maskraw  = maskraw.detach()
                    object   = object.detach()
                    priority = priority.detach()

                mask = mask.detach()

                # number of objects
                avg_openings.update(net.get_openings())

                # plots for online evaluation (rank 0 only: other ranks have no output path)
                if rank == 0 and num_updates % 20000 == 0 and run_optimizers:

                    plot_path = os.path.join(path, 'plots', f'net_{num_updates}')
                    os.makedirs(plot_path, exist_ok=True)

                    # highlight error
                    grayscale        = input[:,0:1] * 0.299 + input[:,1:2] * 0.587 + input[:,2:3] * 0.114
                    object_mask_cur  = th.sum(mask[:,:-1], dim=1).unsqueeze(dim=1)
                    highlited_input  = grayscale * (1 - object_mask_cur)
                    highlited_input += grayscale * object_mask_cur * 0.3333333
                    cmask = color_mask(mask[:,:-1])
                    highlited_input  = highlited_input + cmask * 0.6666666

                    cv2.imwrite(os.path.join(plot_path, f'input-{num_updates // sequence_len:05d}-{t+cfg.teacher_forcing:03d}.jpg'), rearrange(input[0], 'c h w -> h w c').detach().cpu().numpy() * 255)
                    cv2.imwrite(os.path.join(plot_path, f'background-{num_updates // sequence_len:05d}-{t+cfg.teacher_forcing:03d}.jpg'), rearrange(background[0], 'c h w -> h w c').detach().cpu().numpy() * 255)
                    cv2.imwrite(os.path.join(plot_path, f'error_mask-{num_updates // sequence_len:05d}-{t+cfg.teacher_forcing:03d}.jpg'), rearrange(bg_error_next[0], 'c h w -> h w c').detach().cpu().numpy() * 255)
                    cv2.imwrite(os.path.join(plot_path, f'background_mask-{num_updates // sequence_len:05d}-{t+cfg.teacher_forcing:03d}.jpg'), rearrange(mask[0,-1:], 'c h w -> h w c').detach().cpu().numpy() * 255)
                    cv2.imwrite(os.path.join(plot_path, f'output_next-{num_updates // sequence_len:05d}-{t+cfg.teacher_forcing:03d}.jpg'), rearrange(output_next[0], 'c h w -> h w c').detach().cpu().numpy() * 255)
                    cv2.imwrite(os.path.join(plot_path, f'output_highlight-{num_updates // sequence_len:05d}-{t+cfg.teacher_forcing:03d}.jpg'), rearrange(highlited_input[0], 'c h w -> h w c').detach().cpu().numpy() * 255)

                # snitch localisation metrics at the last predicted frame
                if t == sequence_len - 2 and cfg.datatype == 'cater':
                    l1, top1, top5, _ = l1distance(snitch_position, target_label)
                    l2                = th.mean(l2distance(snitch_position, target_position[:,-1])[0])

                    avg_l1_distance.update(l1.item())
                    avg_l2_distance.update(l2.item())
                    avg_top1_accuracy.update(top1.item())

                # track statistics
                if t >= cfg.statistics_offset:
                    _bg_error_next = bg_error_next
                    _bg_error_cur  = bg_error_cur
                    if cfg.datatype == 'cater':
                        _bg_error_next = th.sqrt(reduce((target - background)**2, 'b c h w -> b 1 h w', 'mean')).detach()
                        _bg_error_cur  = th.sqrt(reduce((input  - background)**2, 'b c h w -> b 1 h w', 'mean')).detach()

                    # full loss: econder + prediction + decoder
                    loss_next = mseloss(output_next * _bg_error_next, target * _bg_error_next)

                    # encoder loss: encoder -> decoder
                    loss_cur  = mseloss(output_cur * _bg_error_cur,   input  * _bg_error_cur)

                    # update averages
                    avg_position_loss.update(position_loss.item())
                    avg_time_loss.update(time_loss.item())
                    avg_object_loss.update(object_loss.item())
                    avg_encoder_loss.update(loss_cur.item())
                    avg_mse_object_loss.update(loss_next.item())
                    avg_long_mse_object_loss.update(loss_next.item())
                    avg_update_gestalt.update(slots_closed[:,:,0].mean().item())
                    avg_update_position.update(slots_closed[:,:,1].mean().item())
                    if cfg.datatype == 'cater': 
                        avg_tracking_loss.update(tracking_loss.item())

                    # area of foreground mask
                    avg_num_objects.update(th.mean(reduce((reduce(mask[:,:-1], 'b c h w -> b c', 'max') > 0.5).float(), 'b c -> b', 'sum')).item())

                    # difference in shape
                    _gestalt  = reduce(th.min(th.abs(gestalt), th.abs(1 - gestalt)),    'b (o c) -> (b o)', 'mean', o = cfg_net.num_objects)
                    _gestalt2 = reduce(th.min(th.abs(gestalt), th.abs(1 - gestalt))**2, 'b (o c) -> (b o)', 'mean', o = cfg_net.num_objects)
                    max_mask     = (reduce(mask[:,:-1], 'b c h w -> (b c)', 'max') > 0.5).float()

                    avg_gestalt.update( (th.sum(_gestalt  * max_mask) / (1e-16 + th.sum(max_mask))).item())
                    avg_gestalt2.update((th.sum(_gestalt2 * max_mask) / (1e-16 + th.sum(max_mask))).item())
                    avg_gestalt_mean.update(th.mean(th.clip(gestalt, 0, 1)).item())

                # clip predction
                cliped_output_next = th.clip(output_next, 0, 1)

                # focus on foreground learning: background pixels are scaled by the blend-in factor
                if background_blendin_factor < 1:
                    fg_mask_next = th.gt(bg_error_next, 0.1).float().detach()
                    fg_mask_next[fg_mask_next == 0] = background_blendin_factor
                    target       = th.clip(target * fg_mask_next, 0, 1)

                    fg_mask_cur = th.gt(bg_error_cur, 0.1).float().detach()
                    fg_mask_cur[fg_mask_cur == 0] = background_blendin_factor
                    input       = th.clip(input * fg_mask_cur, 0, 1)

                    if num_updates % 30 == 0 and num_updates >= cfg.entity_pretraining_phase1_end:
                        background_blendin_factor = min(1, background_blendin_factor + 0.001)

                # TODO set to zero for black input frames during closed loop 
                encoder_loss = th.mean((output_cur - input)**2) * cfg_net.encoder_regularizer

                # combined loss
                loss = bceloss(cliped_output_next, target) + encoder_loss + position_loss + object_loss + time_loss + tracking_loss
                
                # accumulate loss over the closed-loop window
                if summed_loss is None:
                    summed_loss = loss
                else:
                    summed_loss = summed_loss + loss

                if run_optimizers:

                    # zero grad
                    optimizer_init.zero_grad()
                    optimizer_encoder.zero_grad()
                    optimizer_decoder.zero_grad()
                    optimizer_predictor.zero_grad()
                    optimizer_background.zero_grad()
                    if net.cfg.closed_prediction:
                        optimizer_update.zero_grad()

                    # optimize
                    summed_loss.backward()
                    optimizer_init.step()
                    optimizer_encoder.step()
                    optimizer_decoder.step()
                    optimizer_predictor.step()
                    optimizer_background.step()
                    if net.cfg.closed_prediction:
                        optimizer_update.step()

                    num_updates += 1
                    summed_loss = None

                    if num_updates == cfg.background_pretraining_end and net.get_init_status() < 1:
                        net.inc_init_level()

                    if num_updates == cfg.entity_pretraining_phase1_end and net.get_init_status() < 2:
                        net.inc_init_level()

                    if num_updates == cfg.entity_pretraining_phase2_end and net.get_init_status() < 3:
                        net.inc_init_level()
                        for param in optimizer_init.param_groups:
                            param['lr'] = cfg.learning_rate

                    if num_updates == cfg.start_closed_prediction:
                        print('Start closed predictions')
                        net.cfg.closed_prediction = True

                    if (cfg.increase_closed_loop_steps_every > 0) and ((num_updates-cfg.num_updates) % cfg.increase_closed_loop_steps_every == 0) and ((num_updates-cfg.num_updates) > 0):
                        increase_closed_loop_steps = True

                avgloss.update(loss.item())

                # periodic progress report (the CATER line belongs to the same report)
                if num_updates % 100 == 0 and run_optimizers:
                    print(f'Epoch[{num_updates}/{num_time_steps}/{sequence_len}]: {str(timer)}, {epoch + 1}, Loss: {np.abs(float(avgloss)):.2e}|{float(avg_mse_object_loss):.2e}|{float(avg_long_mse_object_loss):.2e}, reg: {float(avg_encoder_loss):.2e}|{float(avg_time_loss):.2e}|{float(avg_position_loss):.2e},  i: {net.get_init_status() + net.initial_states.init.get():.2f}, obj: {float(avg_num_objects):.1f}, open: {float(avg_openings):.2e}|{float(avg_gestalt):.2f}, bin: {float(avg_gestalt_mean):.2e}|{np.sqrt(float(avg_gestalt2) - float(avg_gestalt)**2):.2e} closed: {float(avg_update_gestalt):.2e}|{float(avg_update_position):.2e} Blendin:{float(background_blendin_factor)}', flush=False)

                    if cfg.datatype == 'cater':
                        print("CATER[{}/{}/{}/{}]: {}, {}, Loss: {:.2e}|{:.2e}|{:.2e}, reg: {:.2e}|{:.2e}|{:.2e}, snitch: {:.2e}|{:.2f}|{:.2f}|{:.2f}, i: {:.2f}, obj: {:.1f}, openings: {:.2e}".format(
                            num_updates,
                            num_time_steps,
                            sequence_len,
                            1,
                            str(timer),
                            epoch + 1,
                            np.abs(float(avgloss)),
                            float(avg_mse_object_loss),
                            float(avg_long_mse_object_loss),
                            float(avg_object_loss),
                            float(avg_time_loss),
                            float(avg_position_loss),
                            float(avg_tracking_loss),
                            float(avg_top1_accuracy),
                            float(avg_l1_distance),
                            float(avg_l2_distance),
                            net.get_init_status(),
                            float(avg_num_objects),
                            float(avg_openings),
                        ), flush=True)

                # training budget exhausted: every rank stops, only rank 0 writes the checkpoint
                if num_updates > cfg.updates:
                    if rank == 0:
                        save_model(
                            os.path.join(path, 'nets', 'net_final.pt'),
                            net, 
                            optimizer_init, 
                            optimizer_encoder, 
                            optimizer_decoder, 
                            optimizer_predictor,
                            optimizer_background
                        )
                    print("Training finished")
                    return

                if rank == 0 and num_updates % 50000 == 0 and run_optimizers:
                    save_model(
                        os.path.join(path, 'nets', f'net_{num_updates}.pt'),
                        net, 
                        optimizer_init, 
                        optimizer_encoder, 
                        optimizer_decoder, 
                        optimizer_predictor,
                        optimizer_background
                    )


def adept_evaluation(testloader: DataLoader, net: Loci, cfg: Configuration, device):
    """Evaluate ``net`` on the ADEPT test loader and print the mean next-frame MSE.

    Frames are subsampled by ``cfg.skip_frames``. The first frame is fed for
    ``cfg.teacher_forcing`` warm-up steps. At every step the network receives its
    own masks/positions/gestalt/priority from the previous step, together with
    the error map of its last prediction.

    The printed test loss is the average of the per-batch, per-step MSE over all
    non-warm-up steps. Nothing is returned.
    """

    mseloss = nn.MSELoss()
    total_loss = 0.0
    num_loss_terms = 0
    start_time = time.time()

    with th.no_grad():
        for batch in testloader:

            # get input frames and the fixed background
            tensor = batch[0].float().to(device)
            background_fix  = batch[1].to(device)

            # apply skip frames
            tensor = tensor[:,range(0, tensor.shape[1], cfg.skip_frames)]
            sequence_len = tensor.shape[1]

            # recurrent state carried from one step to the next
            error_last     = None
            mask_last      = None
            maskraw_last   = None
            position_last  = None
            gestalt_last   = None
            priority_last  = None
            slots_occlusionfactor = None

            # loop through frames (t < 0: warm-up on the first frame)
            for t in range(-cfg.teacher_forcing, sequence_len-1):

                # move to next frame
                t_run = max(t, 0)
                input  = tensor[:,t_run]
                target = th.clip(tensor[:,t_run+1], 0, 1)

                # obtain prediction
                (
                    output_next, 
                    position_next, 
                    gestalt_next, 
                    priority_next, 
                    mask_next, 
                    maskraw_next,
                    object_next, 
                    background, 
                    slots_occlusionfactor,
                    output_cur,
                    position_cur,
                    gestalt_cur,
                    priority_cur,
                    mask_cur,
                    maskraw_cur,
                    object_cur,
                    position_encoder_cur,
                    slots_bounded,
                    slots_partially_occluded_cur,   
                    slots_occluded_cur,
                    slots_partially_occluded_next,
                    slots_occluded_next,
                    slots_closed,
                    output_hidden,
                    largest_object,
                    maskraw_hidden,
                    object_hidden
                ) = net(
                    input, 
                    error_last,
                    mask_last, 
                    maskraw_last,
                    position_last, 
                    gestalt_last,
                    priority_last,
                    background_fix,
                    slots_occlusionfactor,
                    reset = (t == -cfg.teacher_forcing),
                    evaluate=True,
                    warmup = (t < 0),
                    shuffleslots = False,
                    reset_mask = (t <= 0),
                    allow_spawn = True,
                    show_hidden = False,
                    clean_slots = (t <= 0),
                )

                # 1. Track error; count the terms so the printed loss is a true mean
                if t >= 0:
                    total_loss += mseloss(output_next, target).item()
                    num_loss_terms += 1

                # 2. Remember output
                mask_last     = mask_next.clone()
                maskraw_last  = maskraw_next.clone()
                position_last = position_next.clone()
                gestalt_last  = gestalt_next.clone()
                priority_last = priority_next.clone()

                # 3. Error for next frame: prediction error weighted by background error
                bg_error_next = th.sqrt(reduce((target - background)**2, 'b c h w -> b 1 h w', 'mean')).detach()
                error_next    = th.sqrt(reduce((target - output_next)**2, 'b c h w -> b 1 h w', 'mean')).detach()
                error_last    = th.sqrt(error_next) * bg_error_next

    print(f"Test loss: {total_loss / max(num_loss_terms, 1):.2e}, Time: {time.time() - start_time}")

def _print_l2_contained_plot(l2_contained, sequence_len):
    """Print pgfplots commands for L2 tracking error vs. containment duration.

    Emits one ``\\addplot`` coordinate series per statistic, in this order:
    90 % quantile, 75 % quantile, median, mean, 25 % quantile and 10 % quantile.
    Two ``fill between`` commands then shade the 10-90 % and 25-75 % bands.
    Durations with no recorded error are omitted from every series.

    Args:
        l2_contained: dict mapping a containment duration (int) to the list of
            L2 tracking errors recorded at that duration.
        sequence_len: durations ``0 .. sequence_len-1`` are scanned. The range
            is extended if ``l2_contained`` holds larger durations.
    """
    num_steps = max([sequence_len] + [duration + 1 for duration in l2_contained])

    # convert every non-empty error list to an array once, instead of once per series
    arrays = {duration: np.array(values) for duration, values in l2_contained.items() if len(values) > 0}

    series = (
        ("\\addplot+[mark=none,name path=quantil9,trialcolorb,opacity=0.1,forget plot] plot coordinates {", lambda d: np.quantile(d, 0.9)),
        ("\\addplot+[mark=none,name path=quantil75,trialcolorb,opacity=0.3,forget plot] plot coordinates {", lambda d: np.quantile(d, 0.75)),
        ("\\addplot+[mark=none,trialcolorb,thick,forget plot] plot coordinates {", lambda d: np.quantile(d, 0.5)),
        ("\\addplot+[mark=none,trialcolorb,thick,dotted,forget plot] plot coordinates {", np.mean),
        ("\\addplot+[mark=none,name path=quantil25,trialcolorb,opacity=0.3,forget plot] plot coordinates {", lambda d: np.quantile(d, 0.25)),
        ("\\addplot+[mark=none,name path=quantil1,trialcolorb,opacity=0.1,forget plot] plot coordinates {", lambda d: np.quantile(d, 0.1)),
    )

    for header, statistic in series:
        print(header)
        for t in range(num_steps):
            if t in arrays:
                print(f"({t},{statistic(arrays[t]):0.4f})")
        print("};")

    print("\\addplot[trialcolorb,opacity=0.1] fill between[of=quantil9 and quantil1];")
    print("\\addplot[trialcolorb,opacity=0.2] fill between[of=quantil75 and quantil25];")


def cater_evaluation(net_parallel, prefix, dataset, dataloader, device, cfg, epoch):
    """Evaluate snitch tracking on CATER-style data and print running statistics.

    Each sequence from ``dataloader`` is processed in two phases:

    * ``cfg.teacher_forcing`` warm-up steps (``t < 0``) on the first frame.
    * Next-frame prediction for every remaining frame.

    Output is printed as follows:

    * A progress line after every step.
    * A per-sample localisation report after the last step of each sequence.
    * At the end, pgfplots commands plotting the L2 snitch-tracking error
      against the number of consecutive frames the snitch has been contained.

    Args:
        net_parallel: model called once per step. It is also queried through
            ``.predictor.get_snitch_position()`` and ``.get_init_status()``.
        prefix: label printed at the start of every progress line.
        dataset: provides ``cam`` and ``z`` used to build the distance metrics.
        dataloader: yields batches whose entries 0-4 are unpacked below as
            frames, background, target positions, target labels and a
            per-frame snitch-contained flag.
        device: torch device the batch tensors are moved to.
        cfg: configuration. Reads ``teacher_forcing``, ``datatype``,
            ``statistics_offset`` and ``sequence_len``.
        epoch: not used by this function; kept for interface compatibility.
    """
    net = net_parallel

    timer = Timer()

    bceloss = nn.BCELoss()
    mseloss = nn.MSELoss()

    l1distance = L1GridDistance(dataset.cam, dataset.z).to(device)
    l2distance = L2TrackingDistance(dataset.cam, dataset.z).to(device)

    # Running statistics. Denominators start at 1e-30 so the progress line can
    # be printed before anything has been accumulated.
    avgloss = 0
    avgsum = 1e-30
    avg_object_loss = 0
    avg_object_sum = 1e-30
    avg_tracking_loss = 0  # not accumulated here: the "snitch" column always prints 0
    avg_time_sum = 1e-30
    avg_num_objects = 0
    avg_num_objects_sum = 1e-30
    avg_openings = 0  # not accumulated here
    avg_openings_sum = 1e-30
    avg_l1_distance = 0
    avg_l2_distance = 0
    avg_top1_accuracy = 0
    avg_top5_accuracy = 0
    avg_tracking_sum = 1e-30
    num_updates = 0
    num_time_steps = 0

    # index of the first sample of the current batch (robust to a short last batch)
    sample_offset = 0

    # containment duration (consecutive frames) -> L2 tracking errors observed at it
    l2_contained = {}

    with th.no_grad():
        for batch_index, batch in enumerate(dataloader):

            tensor           = batch[0]
            background       = batch[1].to(device)
            target_position  = batch[2].float().to(device)
            target_label     = batch[3].to(device)
            snitch_contained = batch[4].to(device)

            # Count of consecutive frames (up to and including t) the snitch has
            # been contained. Assumes snitch_contained is a 0/1 indicator: the
            # count grows by one on 1 and resets to 0 on 0.
            snitch_contained_time = th.zeros_like(snitch_contained)
            snitch_contained_time[:,0] = snitch_contained[:,0]
            for t in range(1, snitch_contained.shape[1]):
                snitch_contained_time[:,t] = snitch_contained_time[:,t-1] * snitch_contained[:,t] + snitch_contained[:,t]
            snitch_contained_time = snitch_contained_time.long()

            # recurrent model state, reset for every sequence
            position              = None
            gestalt               = None
            priority              = None
            mask                  = None
            maskraw               = None
            slots_occlusionfactor = None
            error_last            = None

            sequence_len = (tensor.shape[1]-1)

            input_cur  = tensor[:,0].to(device)
            input_next = input_cur
            target     = th.clip(input_cur, 0, 1).detach()

            for t in range(-cfg.teacher_forcing, sequence_len):

                # unfold the video sequence --> next frame (warm-up steps keep frame 0)
                if t >= 0:
                    num_time_steps += 1
                    input_cur  = input_next
                    input_next = tensor[:,t+1].to(device)
                    target     = th.clip(input_next, 0, 1)

                # object learning
                (
                    output_next,
                    output_cur,
                    position,
                    gestalt,
                    priority,
                    mask,
                    maskraw,
                    object_next,
                    background,
                    slots_occlusionfactor,
                    position_loss,
                    object_loss,
                    time_loss,
                    slots_closed
                ) = net_parallel(
                    input_cur,  # current frame
                    error_last, # error of last frame --> missing object
                    mask,       # masks of current frame
                    maskraw,    # raw masks of current frame
                    position,   # positions of objects of next frame
                    gestalt,    # gestalt of objects of next frame
                    priority,   # priority of objects of next frame
                    background,
                    slots_occlusionfactor,
                    reset = (t == -cfg.teacher_forcing), # new sequence
                    warmup = (t < 0),                    # teacher forcing
                    detach = False,
                    shuffleslots = False,
                    reset_mask = (t <= 0),
                    clean_slots = (t <= 0),
                )

                snitch_position = net.predictor.get_snitch_position() if cfg.datatype == 'cater' and t >= 0 else None

                # background error of the target frame
                bg_error_next = th.sqrt(reduce((target - background)**2, 'b c h w -> b 1 h w', 'mean')).detach()

                # prediction error, weighted by the background error, fed back on the next step
                error_next = th.sqrt(reduce((target - output_next)**2, 'b c h w -> b 1 h w', 'mean')).detach()
                error_last = th.sqrt(error_next) * bg_error_next

                # bucket the tracking error by how long the snitch has been contained
                if t >= 0:
                    l2, _ = l2distance(snitch_position, target_position[:,t+1])
                    for b in range(snitch_contained_time.shape[0]):
                        duration = snitch_contained_time[b,t].item()
                        if duration > 0:
                            l2_contained.setdefault(duration, []).append(l2[b].detach().item())

                # final-frame localisation report and tracking metrics
                if t == sequence_len - 1:
                    l1, top1, top5, label = l1distance(snitch_position, target_label)
                    l2, world_position    = l2distance(snitch_position, target_position[:,-1])
                    l2                    = th.mean(l2)

                    batch_size = label.shape[0]
                    for n in range(batch_size):
                        print(f"Sample {sample_offset + n:04d}: ", end="")
                        print(f"Label: {label[n].item():02d}, Target: {target_label[n].long().item():02d}, ", end="")
                        print(f"World-Position: {world_position[n,0].item():.5f}|{world_position[n,1].item():.5f}|{world_position[n,2].item():.5f}, ", end="")
                        print(f"Target-Position: {target_position[n,-1,0].item():.5f}|{target_position[n,-1,1].item():.5f}|{target_position[n,-1,2].item():.5f}, ", flush=True)

                    avg_l1_distance   += l1.item()
                    avg_l2_distance   += l2.item()
                    avg_top1_accuracy += top1.item()
                    avg_top5_accuracy += top5.item()
                    avg_tracking_sum  += 1

                if t >= cfg.statistics_offset:
                    masked_mse = mseloss(output_next * bg_error_next, target * bg_error_next)
                    avg_time_sum    += 1
                    avg_object_loss += masked_mse.item()
                    avg_object_sum  += 1

                    # mean number of slots (excluding slot 0) whose mask exceeds 0.5 anywhere
                    n_objects = th.mean(reduce((reduce(mask, 'b c h w -> b c', 'max') > 0.5).float()[:,1:], 'b c -> b', 'sum')).item()
                    avg_num_objects     += n_objects
                    avg_num_objects_sum += 1

                loss = bceloss(th.clip(output_next, 0, 1), target)

                avgloss += loss.item()
                avgsum  += 1

                num_updates += 1
                print("{}[{}/{}/{}]: {}, Loss: {:.2e}|{:.2e}, snitch:, {:.2e}, top1: {:.4f}, top5: {:.4f}, L1:, {:.6f}, L2: {:.6f}, i: {:.2f}, obj: {:.1f}, openings: {:.2e}".format(
                    prefix,
                    num_updates,
                    num_time_steps,
                    sequence_len * len(dataloader),
                    str(timer),
                    np.abs(avgloss/avgsum),
                    avg_object_loss/avg_object_sum,
                    avg_tracking_loss / avg_time_sum,
                    avg_top1_accuracy / avg_tracking_sum * 100,
                    avg_top5_accuracy / avg_tracking_sum * 100,
                    avg_l1_distance / avg_tracking_sum,
                    avg_l2_distance  / avg_tracking_sum,
                    net.get_init_status(),
                    avg_num_objects / avg_num_objects_sum,
                    avg_openings / avg_openings_sum,
                ), flush=True)

            sample_offset += tensor.shape[0]

    _print_l2_contained_plot(l2_contained, cfg.sequence_len)


def clevrer_evaluation(testloader: DataLoader, net: Loci, cfg: Configuration, device, burn_in_length: int = 6, rollout_length: int = 42):
    """Evaluate video prediction on CLEVRER-style sequences and print averaged metrics.

    Each sequence goes through three phases:

    * ``cfg.teacher_forcing`` warm-up steps on frame 0.
    * Burn-in: the model observes frames up to ``burn_in_length``.
    * Blackout rollout: the input frame and the fed-back error are zeroed, so
      the model must predict on its own.

    Only the first ``burn_in_length + rollout_length`` frames (after
    frame-skipping) are used. At every step with ``t >= 0``, the predicted next
    frame is compared with the clipped ground truth using MSE, SSIM, PSNR and
    LPIPS. Each metric is summed over the samples of a batch and over the
    evaluated steps, then divided by the dataset size.

    Args:
        testloader: yields batches whose first two entries are the frame
            sequence and the fixed background.
        net: model called once per step with ``evaluate=True``.
        cfg: configuration. Reads ``skip_frames`` and ``teacher_forcing``.
        device: torch device the tensors are moved to.
        burn_in_length: number of observed frames before the blackout rollout.
        rollout_length: together with ``burn_in_length``, bounds the evaluated
            window. The loop stops earlier if the sequence is shorter.
    """
    mseloss = nn.MSELoss()
    lpipsloss = lpips.LPIPS(net='vgg').to(device)
    avgloss_mse = 0
    avgloss_lpips = 0
    avgloss_psnr = 0
    avgloss_ssim = 0
    start_time = time.time()

    with th.no_grad():
        for batch in testloader:

            tensor = batch[0].float().to(device)
            background_fix = batch[1].to(device)

            # apply skip frames
            tensor = tensor[:,range(0, tensor.shape[1], cfg.skip_frames)]
            sequence_len = tensor.shape[1]

            # recurrent model state, reset for every sequence
            error_last            = None
            mask_last             = None
            maskraw_last          = None
            position_last         = None
            gestalt_last          = None
            priority_last         = None
            slots_occlusionfactor = None

            last_t = min(burn_in_length + rollout_length - 1, sequence_len - 1)
            for t in range(-cfg.teacher_forcing, last_t):

                # move to next frame (warm-up steps stay on frame 0)
                t_run = max(t, 0)
                input_cur = tensor[:,t_run]
                if t_run >= burn_in_length:
                    # blackout: hide the observation and the error signal
                    input_cur = input_cur * 0
                    if error_last is not None:
                        error_last = error_last * 0
                target = th.clip(tensor[:,t_run+1], 0, 1)

                # obtain prediction
                (
                    output_next,
                    position_next,
                    gestalt_next,
                    priority_next,
                    mask_next,
                    maskraw_next,
                    object_next,
                    background,
                    slots_occlusionfactor,
                    output_cur,
                    position_cur,
                    gestalt_cur,
                    priority_cur,
                    mask_cur,
                    maskraw_cur,
                    object_cur,
                    position_encoder_cur,
                    slots_bounded,
                    slots_partially_occluded_cur,
                    slots_occluded_cur,
                    slots_partially_occluded_next,
                    slots_occluded_next,
                    slots_closed,
                    output_hidden,
                    largest_object,
                    maskraw_hidden,
                    object_hidden
                ) = net(
                    input_cur,
                    error_last,
                    mask_last,
                    maskraw_last,
                    position_last,
                    gestalt_last,
                    priority_last,
                    background_fix,
                    slots_occlusionfactor,
                    reset = (t == -cfg.teacher_forcing),
                    evaluate=True,
                    warmup = (t < 0),
                    shuffleslots = False,
                    reset_mask = (t <= 0),
                    allow_spawn = True,
                    show_hidden = False,
                    clean_slots = (t <= 0),
                )

                # 1. Track error (every metric summed over the samples of the batch)
                if t >= 0:
                    batch_size = output_next.shape[0]
                    pred_np    = output_next.cpu().numpy()
                    target_np  = target.cpu().numpy()

                    # nn.MSELoss averages over the batch; scale by batch_size so MSE
                    # is a per-sample sum like SSIM, PSNR and LPIPS below
                    avgloss_mse += mseloss(output_next, target).item() * batch_size
                    avgloss_ssim += float(np.sum([
                        ssimloss(pred_np[k], target_np[k], channel_axis=0, gaussian_weights=True, sigma=1.5, use_sample_covariance=False, data_range=1)
                        for k in range(batch_size)
                    ]))
                    avgloss_psnr += float(np.sum([
                        psnrloss(pred_np[k], target_np[k], data_range=1)
                        for k in range(batch_size)
                    ]))
                    avgloss_lpips += th.sum(lpipsloss(output_next*2-1, target*2-1)).item()

                # 2. Remember output
                mask_last     = mask_next.clone()
                maskraw_last  = maskraw_next.clone()
                position_last = position_next.clone()
                gestalt_last  = gestalt_next.clone()
                priority_last = priority_next.clone()

                # 3. Error for next frame: prediction error weighted by background error
                bg_error_next = th.sqrt(reduce((target - background)**2, 'b c h w -> b 1 h w', 'mean')).detach()
                error_next    = th.sqrt(reduce((target - output_next)**2, 'b c h w -> b 1 h w', 'mean')).detach()
                error_last    = th.sqrt(error_next) * bg_error_next

    # guard against an empty dataset instead of raising ZeroDivisionError
    num_samples = max(len(testloader.dataset), 1)
    print(f"MSE loss: {avgloss_mse / num_samples:.2e}, LPIPS loss: {avgloss_lpips / num_samples:.2e}, PSNR loss: {avgloss_psnr / num_samples:.2e}, SSIM loss: {avgloss_ssim / num_samples:.2e}, Time: {time.time() - start_time}")