import torch as th
from torch import nn
from torch.utils.data import Dataset, DataLoader
import cv2

from torch.nn.parallel import DistributedDataParallel
import numpy as np
import os
from typing import Tuple, Union, List
from einops import rearrange, repeat, reduce

from utils.configuration import Configuration
from utils.parallel import run_parallel, DatasetPartition
from utils.loss import depth_smooth_loss
from utils.scheduled_sampling import ExponentialSampler
from utils.io import model_path
from utils.optimizers import SDAdam, SDAMSGrad, RAdam
from model.loci import Loci
import torch.distributed as dist
import time
import random
import nn as nn_modules
from utils.io import Timer, BinaryStatistics, UEMA
from utils.data import DeviceSideDataset
from nn.latent_classifier import CaterLocalizer, CaterObjectBehavior
from nn.tracker import TrackingLoss, L1GridDistance, L2TrackingDistance
from einops import rearrange, repeat, reduce

def save_model(
    file,
    net, 
    optimizer_init, 
    optimizer_encoder, 
    optimizer_decoder, 
    optimizer_predictor,
    optimizer_background
):
    """Write the network and all five optimizers into a single checkpoint.

    The checkpoint is a plain dict holding the ``state_dict()`` of every
    optimizer under ``optimizer_init``, ``optimizer_encoder``,
    ``optimizer_decoder``, ``optimizer_predictor`` and
    ``optimizer_background``, plus the network's ``state_dict()`` under
    ``model``. It is serialized with ``th.save``.

    Args:
        file: Destination handed to ``th.save``.
        net: Object whose ``state_dict()`` is stored under ``model``.
        optimizer_init, optimizer_encoder, optimizer_decoder,
        optimizer_predictor, optimizer_background: Objects whose
            ``state_dict()`` is stored under the key of the same name.
    """
    named_optimizers = (
        ('optimizer_init',       optimizer_init),
        ('optimizer_encoder',    optimizer_encoder),
        ('optimizer_decoder',    optimizer_decoder),
        ('optimizer_predictor',  optimizer_predictor),
        ('optimizer_background', optimizer_background),
    )

    # Optimizer entries first, the model last (same layout as before).
    checkpoint = {key: optimizer.state_dict() for key, optimizer in named_optimizers}
    checkpoint["model"] = net.state_dict()

    th.save(checkpoint, file)
    print(f"saved {file}")

def load_model(
    file,
    cfg,
    net, 
    optimizer_init, 
    optimizer_encoder, 
    optimizer_decoder, 
    optimizer_predictor,
    optimizer_background,
    load_optimizers = True
):
    """Restore a checkpoint written by ``save_model`` into ``net`` and the optimizers.

    Args:
        file: Checkpoint location handed to ``th.load``.
        cfg: Configuration; reads ``cfg.device``, ``cfg.learning_rate`` and
            ``cfg.model.background.learning_rate``.
        net: Module that receives the ``model`` entry via ``load_state_dict``.
        optimizer_init, optimizer_encoder, optimizer_decoder,
        optimizer_predictor, optimizer_background: Optimizers restored from
            the checkpoint entry of the same name.
        load_optimizers: When false, only the model weights are restored and
            the optimizers are left untouched.

    After an optimizer state is restored, the learning rate of every param
    group is overwritten with the currently configured value, so the
    configuration wins over whatever the checkpoint stored.
    """
    device = th.device(cfg.device)
    # NOTE: th.load unpickles the file; only load checkpoints from trusted sources.
    state = th.load(file, map_location=device)
    print(f"loaded {file} to {device}")
    print(f"load optimizers: {load_optimizers}")

    if load_optimizers:
        restore_plan = (
            ('optimizer_init',       optimizer_init,       cfg.learning_rate),
            ('optimizer_encoder',    optimizer_encoder,    cfg.learning_rate),
            ('optimizer_decoder',    optimizer_decoder,    cfg.learning_rate),
            ('optimizer_predictor',  optimizer_predictor,  cfg.learning_rate),
            ('optimizer_background', optimizer_background, cfg.model.background.learning_rate),
        )
        for key, optimizer, learning_rate in restore_plan:
            optimizer.load_state_dict(state[key])
            for group in optimizer.param_groups:
                group['lr'] = learning_rate

    # backward compatibility: checkpoints saved from a DistributedDataParallel
    # wrapper carry a leading "module." on every key. Only that leading prefix
    # is removed; a blanket str.replace would also mangle keys that merely
    # contain "module." (e.g. "sub_module.weight" -> "sub_weight").
    ddp_prefix = "module."
    model = {
        (k[len(ddp_prefix):] if k.startswith(ddp_prefix) else k): v
        for k, v in state['model'].items()
    }

    net.load_state_dict(model)

def run(cfg: Configuration, num_gpus: int, trainset: Dataset, valset: Dataset, testset: Dataset, file, bg_file, object_file):
    """Start training, either in this process or through ``run_parallel``.

    With exactly one GPU, ``train_eprop`` is called directly with
    ``cfg.device`` as its rank argument and ``-1`` as world size. For any
    other ``num_gpus`` the call is delegated to ``run_parallel``.
    """
    print("run training", flush=True)

    # Arguments shared by both launch paths.
    training_args = (cfg, trainset, valset, testset, file, bg_file, object_file)

    if num_gpus != 1:
        run_parallel(train_eprop, num_gpus, *training_args)
        return

    train_eprop(cfg.device, -1, *training_args)

def train_eprop(rank: int, world_size: int, cfg: Configuration, trainset: Dataset, valset: Dataset, testset: Dataset, file, bg_file, object_file):
    """Train a Loci network, taking one optimizer step per time step of every sequence.

    For each batch the sequence is unrolled from ``-cfg.teacher_forcing`` up to
    ``sequence_len - 1``. Every step runs the network, assembles the loss and
    steps the init, encoder, predictor and (unless ``cfg.model.decoder.frozzen``)
    decoder optimizers. The background optimizer is created and checkpointed
    but never stepped in this function. The function returns once more than
    ``cfg.updates`` steps were taken or after ``cfg.epochs`` epochs. Only
    rank 0 creates the output directory and writes checkpoints.

    Args:
        rank: Handle passed to ``th.device``; ``run`` passes ``cfg.device``
            here in the single-process case. Reset to 0 when ``world_size < 0``.
        world_size: Number of processes; a negative value selects the
            non-distributed code path.
        cfg: Run configuration.
        trainset: Training data. For ``cfg.datatype == 'cater'`` its ``cam``
            and ``z`` attributes are read.
        valset: Validation data (a loader is built for 'cater', but it is not
            evaluated in the visible code).
        testset: Test data, evaluated with ``eval_net`` before every epoch
            after the first when ``cfg.datatype == 'cater'``.
        file: Checkpoint to resume from, or ``""``.
        bg_file: Checkpoint whose ``model`` entry is loaded into
            ``net.background``, or ``""``.
        object_file: Checkpoint providing encoder / decoder / initial-state
            weights, or ``""``.
    """

    print(f'rank {rank} online', flush=True)
    device = th.device(rank)

    if world_size < 0:
        rank = 0

    path = None
    if rank == 0:
        path = model_path(cfg, overwrite=False)
        cfg.save(path)

    cfg_net = cfg.model
    sampler = ExponentialSampler(0.9999)

    background = None
    # Fix: the partition used to read and assign an undefined name `dataset`,
    # which raised UnboundLocalError on every distributed run. The partition is
    # kept in its own variable so `trainset.cam` / `trainset.z` below still
    # refer to the original dataset object.
    train_data = trainset
    if world_size > 0:
        train_data = DatasetPartition(trainset, rank, world_size)

    net = Loci(
        cfg = cfg_net,
        camera_view_matrix = trainset.cam if cfg.datatype == 'cater' else None,
        zero_elevation     = trainset.z if cfg.datatype == 'cater' else None,
        teacher_forcing = cfg.teacher_forcing
    )

    net = net.to(device=device)

    l1distance = L1GridDistance(trainset.cam, trainset.z).to(device) if cfg.datatype == 'cater' else None
    l2distance = L2TrackingDistance(trainset.cam, trainset.z).to(device) if cfg.datatype == 'cater' else None
    l2loss     = TrackingLoss(trainset.cam, trainset.z).to(device) if cfg.datatype == 'cater' else None

    if rank == 0:
        print(f'Loaded model with {sum([param.numel() for param in net.parameters()]):7d} parameters', flush=True)
        print(f'  States:     {sum([param.numel() for param in net.initial_states.parameters()]):7d} parameters', flush=True)
        print(f'  Encoder:    {sum([param.numel() for param in net.encoder.parameters()]):7d} parameters', flush=True)
        if cfg_net.encoder.hyper:
            print(f'    Base:     {sum([param.numel() for param in net.encoder.hyper_weights.parameters()]):7d} parameters', flush=True)
            print(f'    Hyper:    {net.encoder.num_hyper_weights():7d} parameters', flush=True)
        print(f'  Decoder:    {sum([param.numel() for param in net.decoder.parameters()]):7d} parameters', flush=True)
        if cfg_net.decoder.hyper:
            print(f'    Base:     {sum([param.numel() for param in net.decoder.hyper_weights.parameters()]):7d} parameters', flush=True)
            print(f'    Hyper:    {net.decoder.num_hyper_weights():7d} parameters', flush=True)
        print(f'  predictor:  {sum([param.numel() for param in net.predictor.parameters()]):7d} parameters', flush=True)
        print(f'  background: {sum([param.numel() for param in net.background.parameters()]):7d} parameters', flush=True)
        print("\n")
    
    bceloss = nn.BCELoss()
    mseloss = nn.MSELoss()
    l1loss  = nn.L1Loss()

    # One optimizer per sub-network so that each can be stepped (or skipped) independently.
    optimizer_init = RAdam(net.initial_states.parameters(), lr = cfg.learning_rate)
    optimizer_encoder = RAdam(net.encoder.position_encoder.parameters() if cfg_net.encoder.gestalt_frozzen else net.encoder.parameters(), lr = cfg.learning_rate)
    optimizer_decoder = RAdam(net.decoder.parameters(), lr = cfg.learning_rate)
    optimizer_predictor = RAdam(net.predictor.parameters(), lr = cfg.learning_rate)
    optimizer_background = RAdam(net.background.parameters(), lr = cfg_net.background.learning_rate)

    if file != "":
        load_model(
            file,
            cfg,
            net,
            optimizer_init, 
            optimizer_encoder, 
            optimizer_decoder, 
            optimizer_predictor,
            optimizer_background,
            cfg.load_optimizers
        )
        print(f'loaded[{rank}] {file}', flush=True)

    if bg_file != "":
        state = th.load(bg_file, map_location=device)
        net.background.load_state_dict(state['model'])
        print(f"loaded background {bg_file} to {device}")

    if object_file != "":
        state = th.load(object_file, map_location=device)
        encoder = dict()
        decoder = dict()
        init    = dict()
        # Fix: str.lstrip("encoder.") strips any leading run of the characters
        # {e, n, c, o, d, r, .}, not the prefix, so "encoder.conv.weight" became
        # "v.weight". Slice the exact prefix off instead.
        encoder_prefix = "encoder."
        decoder_prefix = "decoder."
        for k, v in state['model'].items():
            if k.startswith(encoder_prefix):
                encoder[k[len(encoder_prefix):]] = v
            elif k.startswith(decoder_prefix):
                decoder[k[len(decoder_prefix):]] = v
            else:
                init[k] = v

        net.encoder.load_state_dict(encoder)
        net.decoder.load_state_dict(decoder)
        net.initial_states.load_pretrained(init)

        # Fix: this message used to report bg_file instead of the file actually loaded.
        print(f"loaded encoder, decoder {object_file} to {device}")
            
    trainloader = DataLoader(
        train_data, 
        pin_memory = True, 
        num_workers = cfg.num_workers, 
        batch_size = cfg_net.batch_size, 
        shuffle = True,
        drop_last = True, 
        prefetch_factor = cfg.prefetch_factor, 
        persistent_workers = True
    )

    valloader  = None
    testloader = None
    if cfg.datatype == 'cater':
        valloader = DataLoader(
            valset, 
            pin_memory = True, 
            num_workers = cfg.num_workers, 
            batch_size = cfg_net.batch_size, 
            shuffle = False,
            drop_last = True, 
            prefetch_factor = cfg.prefetch_factor, 
            persistent_workers = True
        )
        testloader = DataLoader(
            testset, 
            pin_memory = True, 
            num_workers = cfg.num_workers, 
            batch_size = cfg_net.batch_size, 
            shuffle = False,
            drop_last = True, 
            prefetch_factor = cfg.prefetch_factor, 
            persistent_workers = True
        )

    net_parallel = None
    if world_size > 0:
        net_parallel = DistributedDataParallel(net, device_ids=[rank], find_unused_parameters=True)
    else:
        net_parallel = net

    if file == "" and cfg_net.decoder.hyper:
        net.decoder.pretrain(cfg_net.decoder.pretraining_steps, device)
    
    # Snapshot of the starting state, written before any training step.
    if rank == 0:
        save_model(
            os.path.join(path, 'net0.pt'),
            net, 
            optimizer_init, 
            optimizer_encoder, 
            optimizer_decoder, 
            optimizer_predictor,
            optimizer_background
        )

    #th.autograd.set_detect_anomaly(True)

    num_updates = 0
    num_time_steps = 0
    lr = cfg.learning_rate

    # Running averages that feed the progress line printed after every step.
    avgloss                  = UEMA()
    avg_position_loss        = UEMA()
    avg_time_loss            = UEMA()
    avg_depth_loss         = UEMA()
    avg_object_loss          = UEMA()
    avg_mse_object_loss      = UEMA()
    avg_long_mse_object_loss = UEMA(100 * (cfg.sequence_len - 1 + cfg.teacher_forcing))
    avg_num_objects          = UEMA()
    avg_openings             = UEMA()
    avg_l1_distance          = UEMA()
    avg_l2_distance          = UEMA()
    avg_top1_accuracy        = UEMA()
    avg_tracking_loss        = UEMA()
    avg_gestalt              = UEMA()
    avg_gestalt2             = UEMA()
    avg_gestalt_mean         = UEMA()

    th.backends.cudnn.benchmark = True
    timer = Timer()

    for epoch in range(cfg.epochs):
        if epoch > 0 and cfg.datatype == "cater":
        #    eval_net(net_parallel, 'Val',  valset,  valloader, device, cfg, epoch)
            eval_net(net_parallel, 'Test', testset, testloader, device, cfg, epoch)
        for batch_index, input in enumerate(trainloader):
            
            # The meaning of the batch tuple entries depends on cfg.datatype and
            # on whether the background model is frozen.
            tensor_rgb        = input[0]
            tensor_depth      = input[1] if cfg.datatype != 'cater' else None
            static_background = input[1][:,:1].to(device) if cfg.datatype != 'asteroids' else input[2].to(device)
            target_position   = input[2].float().to(device) if cfg.datatype == 'cater' else None
            target_label      = input[3].to(device) if cfg.datatype == 'cater' else None

            bg_rgb      = input[2] if cfg_net.background.frozzen else None
            bg_depth    = input[3] if cfg_net.background.frozzen else None
            uncertainty = input[4] if cfg_net.background.frozzen else None

            if epoch == 0 and batch_index == 0 and not cfg_net.background.use:
                net_parallel.background.set_background(static_background[:,0])
                net_parallel.background.set_depth(th.zeros_like(static_background[:,0,:1]))

            # run complete forward pass to get a position estimation
            position    = None
            gestalt     = None
            priority    = None
            mask        = None
            object      = None
            output      = None
            loss        = th.tensor(0)

            sequence_len = (tensor_rgb.shape[1]-1) 

            input_rgb      = tensor_rgb[:,0].to(device)
            input_rgb_next = input_rgb
            target_rgb     = th.clip(input_rgb, 0, 1).detach()

            input_depth      = tensor_depth[:,0].to(device) if tensor_depth is not None else None
            input_depth_next = input_depth
            target_depth     = th.clip(input_depth, 0, 1).detach() if tensor_depth is not None else None

            bg_depth_cur  = bg_depth[:,0].to(device) if bg_depth is not None else None
            bg_depth_next = bg_depth_cur

            bg_rgb_cur  = bg_rgb[:,0].to(device) if bg_rgb is not None else None
            bg_rgb_next = bg_rgb_cur

            pos_target = target_position[:,0] if cfg.datatype == 'cater' else None
            error      = None
            
            uncertainty_cur  = net.background.uncertainty_estimation(th.cat((input_depth, input_rgb), dim=1))[0] if not cfg_net.background.frozzen else uncertainty[:,0].to(device)
            uncertainty_next = uncertainty_cur

            # Negative t are teacher-forcing steps: the first frame is fed
            # repeatedly and serves as its own target.
            for t in range(-cfg.teacher_forcing, sequence_len):
                if t >= 0:
                    num_time_steps += 1
                
                if t >= 0:
                    # Advance the sliding window: frame t is the input, frame t+1 the target.
                    input_rgb      = input_rgb_next
                    input_rgb_next = tensor_rgb[:,t+1].to(device)
                    target_rgb     = th.clip(input_rgb_next, 0, 1)

                    input_depth      = input_depth_next
                    input_depth_next = tensor_depth[:,t+1].to(device) if tensor_depth is not None else None
                    target_depth     = th.clip(input_depth_next, 0, 1) if tensor_depth is not None else None

                    bg_depth_cur  = bg_depth_next
                    bg_depth_next = bg_depth[:,t+1].to(device) if bg_depth is not None else None

                    bg_rgb_cur  = bg_rgb_next
                    bg_rgb_next = bg_rgb[:,t+1].to(device) if bg_rgb is not None else None

                    uncertainty_cur  = uncertainty_next
                    uncertainty_next = net.background.uncertainty_estimation(th.cat((input_depth, input_rgb), dim=1))[0] if not cfg_net.background.frozzen else uncertainty[:,t+1].to(device)

                    pos_target     = target_position[:,t+1] if cfg.datatype == 'cater' else None

                    
                output = net_parallel(
                    input_rgb     = input_rgb, 
                    input_depth   = input_depth,
                    error_last    = uncertainty_cur if cfg_net.background.use else error,
                    background    = bg_rgb_next,
                    bg_depth      = bg_depth_next,
                    bg_mask       = net.background.priority.expand(*uncertainty_next.shape),
                    mask_last     = mask,
                    position_last = position,
                    gestalt_last  = gestalt,
                    priority_last = priority,
                    reset         = (t == -cfg.teacher_forcing),
                )

                object_loss   = output['object_loss']   * cfg_net.object_regularizer
                position_loss = output['position_loss'] * cfg_net.position_regularizer
                time_loss     = output['time_loss']     * cfg_net.time_regularizer
                tracking_loss = l2loss(output['snitch_position'], pos_target) * cfg_net.supervision_factor if cfg.datatype == 'cater' else 0

                # Per-pixel RMS errors (mean over channels), detached so they act
                # as inputs / masks only and carry no gradient.
                bg_error_cur  = th.sqrt(reduce((input_rgb  - output['background'])**2, 'b c h w -> b 1 h w', 'mean')).detach()
                bg_error_next = th.sqrt(reduce((target_rgb - output['background'])**2, 'b c h w -> b 1 h w', 'mean')).detach()
                error         = th.sqrt(reduce((target_rgb - output['output_rgb'])**2, 'b c h w -> b 1 h w', 'mean')).detach()
                error         = th.sqrt(error) * bg_error_next

                # Recurrent state for the next step, cut off from the graph.
                mask     = output['mask'].detach()
                position = output['position'].detach()
                gestalt  = output['gestalt'].detach()
                priority = output['priority'].detach()

                fg_mask_cur   = (uncertainty_cur > 0.5).float().detach()  if cfg_net.background.use else th.gt(bg_error_cur, 0.1).float().detach()
                fg_mask_next  = (uncertainty_next > 0.5).float().detach() if cfg_net.background.use else th.gt(bg_error_next, 0.1).float().detach()

                cliped_output_rgb_next   = th.clip(output['output_rgb'], 0, 1)
                cliped_output_depth_next = th.clip(output['output_depth'], 0, 1)

                target_rgb_cliped = th.clip(target_rgb, 0, 1)
                input_rgb_cliped  = th.clip(input_rgb, 0, 1)
                if net.get_init_status() <= 0:
                    # While the init status is <= 0, only foreground pixels are used as targets.
                    target_rgb_cliped = th.clip(target_rgb_cliped * fg_mask_next, 0, 1)
                    input_rgb_cliped  = th.clip(input_rgb_cliped  * fg_mask_cur , 0, 1)
                    
                    if target_depth is not None and cfg.datatype != 'asteroids': 
                        target_depth = th.clip(target_depth * fg_mask_next, 0, 1)
                        input_depth  = th.clip(input_depth  * fg_mask_cur , 0, 1)

                rgb_loss   = bceloss(cliped_output_rgb_next, target_rgb_cliped) * cfg_net.rgb_loss_factor
                depth_loss = bceloss(cliped_output_depth_next, target_depth) if target_depth is not None else depth_smooth_loss(cliped_output_depth_next, target_rgb) * 0.001

                loss = rgb_loss + depth_loss + position_loss + object_loss + time_loss + tracking_loss

                optimizer_init.zero_grad()
                optimizer_encoder.zero_grad()

                if not cfg_net.decoder.frozzen:
                    optimizer_decoder.zero_grad()

                optimizer_predictor.zero_grad()

                #if not cfg_net.background.frozzen:
                #    optimizer_background.zero_grad()

                loss.backward()

                optimizer_init.step()
                optimizer_encoder.step()

                if not cfg_net.decoder.frozzen:
                    optimizer_decoder.step()

                optimizer_predictor.step()

                #if not cfg_net.background.frozzen:
                #    optimizer_background.step()

                avg_openings.update(net.get_openings())

                # Tracking metrics are taken on the last step of the sequence only.
                if t == sequence_len - 1 and cfg.datatype == 'cater':
                    l1, top1, top5, _ = l1distance(output['snitch_position'], target_label)
                    l2                = th.mean(l2distance(output['snitch_position'], target_position[:,-1])[0])

                    avg_l1_distance.update(l1.item())
                    avg_l2_distance.update(l2.item())
                    avg_top1_accuracy.update(top1.item())
                
                if t >= cfg.statistics_offset:
                    static_bg_error_cur  = th.sqrt(reduce((input_rgb  - static_background[:,0])**2, 'b c h w -> b 1 h w', 'mean')).detach()
                    static_bg_error_next = th.sqrt(reduce((target_rgb - static_background[:,0])**2, 'b c h w -> b 1 h w', 'mean')).detach()
                    loss_next = mseloss(output['output_rgb'] * static_bg_error_next, target_rgb * static_bg_error_next)

                    avg_position_loss.update(position_loss.item())
                    avg_object_loss.update(object_loss.item())
                    avg_time_loss.update(time_loss.item())
                    avg_depth_loss.update(depth_loss.item())
                    avg_mse_object_loss.update(loss_next.item())
                    avg_long_mse_object_loss.update(loss_next.item())

                    if cfg.datatype == 'cater': 
                        avg_tracking_loss.update(tracking_loss.item())

                    # Count mask channels (all but the last) whose maximum exceeds 0.5.
                    avg_num_objects.update(th.mean(reduce((reduce(mask[:,:-1], 'b c h w -> b c', 'max') > 0.5).float(), 'b c -> b', 'sum')).item())

                    # Distance of each gestalt value to the nearest of {0, 1}, averaged per object.
                    _gestalt  = reduce(th.min(th.abs(gestalt), th.abs(1 - gestalt)),    'b (o c) -> (b o)', 'mean', o = cfg_net.num_objects)
                    _gestalt2 = reduce(th.min(th.abs(gestalt), th.abs(1 - gestalt))**2, 'b (o c) -> (b o)', 'mean', o = cfg_net.num_objects)
                    max_mask     = (reduce(mask[:,:-1], 'b c h w -> (b c)', 'max') > 0.5).float()

                    avg_gestalt.update( (th.sum(_gestalt  * max_mask) / (1e-16 + th.sum(max_mask))).item())
                    avg_gestalt2.update((th.sum(_gestalt2 * max_mask) / (1e-16 + th.sum(max_mask))).item())
                    avg_gestalt_mean.update(th.mean(th.clip(gestalt, 0, 1)).item())

                avgloss.update(loss.item())

                num_updates += 1
                print("Epoch[{}]: {}, {}, L: {:.2e}|{:.2e}|{:.2e}, D: {:.2e}, R: {:.2e}|{:.2e}|{:.2e}, S:, {:.2e}|{:.2f}|{:.4f}|{:.4f}, i: {:d}|{:.2e}, obj: {:.1f}, open: {:.2e}, binar: {:.2e}|{:.2e}|{:.3f}".format(
                    num_updates,
                    str(timer),
                    epoch + 1,
                    np.abs(float(avgloss)),
                    float(avg_mse_object_loss),
                    float(avg_long_mse_object_loss),
                    float(avg_depth_loss),
                    float(avg_object_loss),
                    float(avg_time_loss),
                    float(avg_position_loss),
                    float(avg_tracking_loss),
                    float(avg_top1_accuracy)*100,
                    float(avg_l1_distance),
                    float(avg_l2_distance),
                    net.get_init_status(),
                    -(net.initial_states.init.get()-1),
                    float(avg_num_objects),
                    float(avg_openings),
                    float(avg_gestalt),
                    np.sqrt(float(avg_gestalt2) - float(avg_gestalt)**2),
                    float(avg_gestalt_mean),
                ), flush=True)

                # Update budget exhausted: write the final checkpoint and stop.
                if num_updates > cfg.updates:
                    if rank == 0:
                        save_model(
                            os.path.join(path, 'net.pt'),
                            net, 
                            optimizer_init, 
                            optimizer_encoder, 
                            optimizer_decoder, 
                            optimizer_predictor,
                            optimizer_background
                        )
                    return

            # Intermediate checkpoint; skipped when the last loss is inf/NaN.
            if batch_index % 10 == 0 or num_updates < 3000:
                if not np.isinf(loss.item()) and not np.isnan(loss.item()):
                    if rank == 0:
                        save_model(
                            os.path.join(path, 'net{}.{}.pt'.format(epoch, batch_index // 100)),
                            net, 
                            optimizer_init, 
                            optimizer_encoder, 
                            optimizer_decoder, 
                            optimizer_predictor,
                            optimizer_background
                        )


        # End-of-epoch checkpoint ((epoch + 1) % 1 == 0 holds for every epoch).
        if (epoch + 1) % 1 == 0 or (epoch < 10 and batch_index % 100 == 0):
            if not np.isinf(loss.item()) and not np.isnan(loss.item()):
                if rank == 0:
                    save_model(
                        os.path.join(path, 'net{}.pt'.format(epoch+1)),
                        net, 
                        optimizer_init, 
                        optimizer_encoder, 
                        optimizer_decoder, 
                        optimizer_predictor,
                        optimizer_background
                    )


def eval_net(net_parallel, prefix, dataset, dataloader, device, cfg, epoch):
    """Run the network over ``dataloader`` without gradients and print statistics.

    For every sequence the network is unrolled from ``-cfg.teacher_forcing`` to
    ``sequence_len - 1``. After each step a progress line with running averages
    is printed; on the last step of each sequence the predicted and target
    snitch label / position of every sample are printed. Once all batches are
    processed, pgfplots "addplot" coordinate lists (quantiles 0.9 / 0.75 / 0.5,
    mean, quantiles 0.25 / 0.1) of the L2 tracking distance are printed, grouped
    by the value of ``snitch_contained_time``.

    Args:
        net_parallel: Network to evaluate; its ``sampler`` attribute is set to
            ``None`` and its background is set from the first batch.
        prefix: Label placed at the start of every progress line.
        dataset: Provides ``cam`` and ``z`` for the distance / loss modules.
        dataloader: Yields batches that are indexed at positions 0 to 4.
        device: Device tensors are moved to.
        cfg: Reads ``sequence_len``, ``teacher_forcing``, ``datatype`` and
            ``statistics_offset``.
        epoch: Not used in this function.

    Returns:
        None. All results are printed.
    """

    net = net_parallel

    timer = Timer()

    bceloss    = nn.BCELoss()
    mseloss    = nn.MSELoss()

    l1distance = L1GridDistance(dataset.cam, dataset.z).to(device)
    l2distance = L2TrackingDistance(dataset.cam, dataset.z).to(device)
    l2loss     = TrackingLoss(dataset.cam, dataset.z).to(device)

    # Running sums and their counters. The counters start at 1e-30 so the
    # divisions in the progress line never divide by zero.
    # NOTE(review): several of these (e.g. avg_tracking_loss, avg_openings,
    # avg_position_loss, min_loss) are never updated below.
    avgloss = 0
    avg_position_loss = 0
    avg_position_sum = 1e-30
    avg_time_loss = 0
    avg_time_sum = 1e-30
    avg_object_loss = 0
    avg_object_sum = 1e-30
    avg_object_loss2 = 0
    avg_object_sum2 = 1e-30
    avg_object_loss3 = 0
    avg_object_sum3 = 1e-30
    avg_num_objects = 0
    avg_num_objects_sum = 1e-30
    avg_openings = 0
    avg_openings_sum = 1e-30
    avg_l1_distance = 0
    avg_l2_distance = 0
    avg_top1_accuracy = 0
    avg_top5_accuracy = 0
    avg_tracking_sum = 1e-30
    avg_tracking_loss = 0
    avgsum = 1e-30
    min_loss = 100000
    num_updates = 0
    num_time_steps = 0
    # l2_contained[c] collects L2 distances of samples whose snitch_contained_time equals c.
    l2_contained = []
    for t in range(cfg.sequence_len):
        l2_contained.append([])

    with th.no_grad():
        for batch_index, input in enumerate(dataloader):
            
            tensor            = input[0]
            static_background = input[1].to(device)
            target_position   = input[2].float().to(device)
            target_label      = input[3].to(device)
            snitch_contained  = input[4].to(device)
            snitch_contained_time = th.zeros_like(snitch_contained)

            if batch_index == 0:
                net_parallel.background.set_background(static_background[:,0])
                net_parallel.background.set_depth(th.ones_like(static_background[:,0,:1]))
            
            # Recurrence time[t] = time[t-1] * contained[t] + contained[t].
            # Assuming snitch_contained holds 0/1 flags (TODO confirm), this is a
            # counter of consecutive contained frames that resets to 0 otherwise.
            snitch_contained_time[:,0] = snitch_contained[:,0]
            for t in range(1, snitch_contained.shape[1]):
                snitch_contained_time[:,t] = snitch_contained_time[:,t-1] * snitch_contained[:,t] + snitch_contained[:,t]

            snitch_contained_time = snitch_contained_time.long()
                

            # run complete forward pass to get a position estimation
            position    = None
            gestalt     = None
            priority    = None
            mask        = None
            object      = None
            net.sampler = None
            output      = None
            loss        = th.tensor(0)

            sequence_len = (tensor.shape[1]-1) 

            # `input` is rebound here from the batch tuple to the current frame.
            input      = tensor[:,0].to(device)
            input_next = input
            target     = th.clip(input, 0, 1).detach()
            pos_target = target_position[:,0] if cfg.datatype == 'cater' else None
            error      = None

            # Negative t are teacher-forcing steps on the first frame.
            for t in range(-cfg.teacher_forcing, sequence_len):
                if t >= 0:
                    num_time_steps += 1
                
                if t >= 0:
                    # Frame t becomes the input, frame t+1 the target.
                    input      = input_next
                    input_next = tensor[:,t+1].to(device)
                    target     = th.clip(input_next, 0, 1)
                    pos_target = target_position[:,t+1] if cfg.datatype == 'cater' else None

                output = net_parallel(
                    input_rgb     = input, 
                    input_depth   = None,
                    error_last    = error,
                    mask_last     = mask,
                    position_last = position,
                    gestalt_last  = gestalt,
                    priority_last = priority,
                    reset         = (t == -cfg.teacher_forcing),
                    test          = True
                )

                # Init status clamped to [0, 1]; used below to blend the masked and unmasked target.
                init = max(0, min(1, net.get_init_status()))

                # Per-pixel RMS errors (mean over channels) against background and prediction.
                bg_error = th.sqrt(reduce((target - output['background'])**2, 'b c h w -> b 1 h w', 'mean')).detach()
                error    = th.sqrt(reduce((target - output['output_rgb'])**2, 'b c h w -> b 1 h w', 'mean')).detach()
                error    = th.sqrt(error) * bg_error

                # Recurrent state fed back into the next step.
                mask     = output['mask'].detach()
                position = output['position'].detach()
                gestalt  = output['gestalt'].detach()
                priority = output['priority'].detach()

                if t >= 0:

                    # Record the per-sample L2 distance under the sample's current contained time (only when > 0).
                    l2, _ = l2distance(output['snitch_position'], target_position[:,t+1])
                    for b in range(snitch_contained_time.shape[0]):
                        c = snitch_contained_time[b,t].item()
                        if c > 0.5:
                            l2_contained[c].append(l2[b].detach().item())


                # Final-step tracking metrics plus a per-sample report.
                if t == sequence_len - 1:
                    l1, top1, top5, label = l1distance(output['snitch_position'], target_label)
                    l2, world_position    = l2distance(output['snitch_position'], target_position[:,-1])
                    l2                    = th.mean(l2)

                    batch_size = label.shape[0]
                    for n in range(batch_size):
                        print(f"Sample {batch_size * batch_index + n:04d}: ", end="")
                        print(f"Label: {label[n].item():02d}, Target: {target_label[n].long().item():02d}, ", end="")
                        print(f"World-Position: {world_position[n,0].item():.5f}|{world_position[n,1].item():.5f}|{world_position[n,2].item():.5f}, ", end="")
                        print(f"Target-Position: {target_position[n,-1,0].item():.5f}|{target_position[n,-1,1].item():.5f}|{target_position[n,-1,2].item():.5f}, ", flush=True)

                    avg_l1_distance = avg_l1_distance + l1.item()
                    avg_l2_distance = avg_l2_distance + l2.item()
                    avg_top1_accuracy = avg_top1_accuracy + top1.item()
                    avg_top5_accuracy = avg_top5_accuracy + top5.item()
                    avg_tracking_sum  = avg_tracking_sum + 1
                
                if t >= cfg.statistics_offset:
                    # MSE weighted by the background error.
                    loss = mseloss(output['output_rgb'] * bg_error, target * bg_error)
                    avg_position_sum  = avg_position_sum + 1
                    avg_object_sum3   = avg_object_sum3 + 1
                    avg_time_sum      = avg_time_sum + 1
                    avg_object_loss   = avg_object_loss + loss.item()
                    avg_object_sum    = avg_object_sum + 1
                    avg_object_loss2  = avg_object_loss2 + loss.item()
                    avg_object_sum2   = avg_object_sum2 + 1

                    # Count mask channels (skipping channel 0) whose maximum exceeds 0.5.
                    n_objects = th.mean(reduce((reduce(mask, 'b c h w -> b c', 'max') > 0.5).float()[:,1:], 'b c -> b', 'sum')).item()
                    avg_num_objects = avg_num_objects + n_objects
                    avg_num_objects_sum = avg_num_objects_sum + 1

                fg_mask = th.gt(bg_error, 0.1).float().detach()

                cliped_output = th.clip(output['output_rgb'], 0, 1)
                # Blend between the foreground-masked target (init == 0) and the full target (init == 1).
                target        = th.clip(target * fg_mask * (1 - init) + target * init, 0, 1)

                loss = bceloss(cliped_output, target)

                avgloss  = avgloss + loss.item()
                avgsum   = avgsum + 1

                num_updates += 1
                print("{}[{}/{}/{}]: {}, Loss: {:.2e}|{:.2e}, snitch:, {:.2e}, top1: {:.4f}, top5: {:.4f}, L1:, {:.6f}, L2: {:.6f}, i: {:.2f}, obj: {:.1f}, openings: {:.2e}".format(
                    prefix,
                    num_updates,
                    num_time_steps,
                    sequence_len * len(dataloader),
                    str(timer),
                    np.abs(avgloss/avgsum),
                    avg_object_loss/avg_object_sum,
                    avg_tracking_loss / avg_time_sum,
                    avg_top1_accuracy / avg_tracking_sum * 100,
                    avg_top5_accuracy / avg_tracking_sum * 100,
                    avg_l1_distance / avg_tracking_sum,
                    avg_l2_distance  / avg_tracking_sum,
                    net.get_init_status(),
                    avg_num_objects / avg_num_objects_sum,
                    avg_openings / avg_openings_sum,
                ), flush=True)

    # The remainder prints pgfplots source: one coordinate list per statistic,
    # with one "(index, value)" pair for every non-empty l2_contained bucket.

    # 0.9 quantile
    print("\\addplot+[mark=none,name path=quantil9,trialcolorb,opacity=0.1,forget plot] plot coordinates {")
    for t in range(cfg.sequence_len):
        if len(l2_contained[t]) > 0:
            data = np.array(l2_contained[t])
            print(f"({t},{np.quantile(data, 0.9):0.4f})")
    print("};")

    # 0.75 quantile
    print("\\addplot+[mark=none,name path=quantil75,trialcolorb,opacity=0.3,forget plot] plot coordinates {")
    for t in range(cfg.sequence_len):
        if len(l2_contained[t]) > 0:
            data = np.array(l2_contained[t])
            print(f"({t},{np.quantile(data, 0.75):0.4f})")
    print("};")

    # median
    print("\\addplot+[mark=none,trialcolorb,thick,forget plot] plot coordinates {")
    for t in range(cfg.sequence_len):
        if len(l2_contained[t]) > 0:
            data = np.array(l2_contained[t])
            print(f"({t},{np.quantile(data, 0.5):0.4f})")
    print("};")
    # mean
    print("\\addplot+[mark=none,trialcolorb,thick,dotted,forget plot] plot coordinates {")
    for t in range(cfg.sequence_len):
        if len(l2_contained[t]) > 0:
            data = np.array(l2_contained[t])
            print(f"({t},{np.mean(data):0.4f})")
    print("};")

    # 0.25 quantile
    print("\\addplot+[mark=none,name path=quantil25,trialcolorb,opacity=0.3,forget plot] plot coordinates {")
    for t in range(cfg.sequence_len):
        if len(l2_contained[t]) > 0:
            data = np.array(l2_contained[t])
            print(f"({t},{np.quantile(data, 0.25):0.4f})")
    print("};")

    # 0.1 quantile
    print("\\addplot+[mark=none,name path=quantil1,trialcolorb,opacity=0.1,forget plot] plot coordinates {")
    for t in range(cfg.sequence_len):
        if len(l2_contained[t]) > 0:
            data = np.array(l2_contained[t])
            print(f"({t},{np.quantile(data, 0.1):0.4f})")
    print("};")
    # Shaded bands between the named quantile paths.
    print("\\addplot[trialcolorb,opacity=0.1] fill between[of=quantil9 and quantil1];")
    print("\\addplot[trialcolorb,opacity=0.2] fill between[of=quantil75 and quantil25];")

class L1GridDistanceTracker(nn.Module):
    """Summed L1 (Manhattan) distance between flattened grid-cell labels.

    A label ``k`` is decoded as column ``k % grid_size`` and row
    ``k // grid_size``.

    Args:
        grid_size: number of cells per grid row. Defaults to 6, the value
            that was previously hard-coded.

    Raises:
        ValueError: if ``grid_size`` is smaller than 1.
    """

    def __init__(self, grid_size: int = 6):
        super().__init__()
        if grid_size < 1:
            raise ValueError(f"grid_size must be >= 1, got {grid_size}")
        self.grid_size = grid_size

    def forward(self, label, target_label):
        """Return the L1 grid distance summed over all elements.

        Args:
            label: tensor of predicted cell indices (cast to ``long``).
            target_label: tensor of target cell indices (cast to ``long``),
                broadcastable against ``label``.

        Returns:
            A 0-dim float tensor holding the sum of
            ``|dx| + |dy|`` over all elements.
        """
        label        = label.long()
        target_label = target_label.long()

        x = label % self.grid_size
        # Floor division is required here: `/` on integer tensors is true
        # division and would produce fractional row indices.
        y = label // self.grid_size

        target_x = target_label % self.grid_size
        target_y = target_label // self.grid_size

        return th.sum((th.abs(x - target_x) + th.abs(y - target_y)).float())

def _print_l2_quantile_plot(l2_by_duration, color):
    """Print pgfplots commands for the per-duration L2 error statistics.

    For every non-empty bucket one coordinate ``(bucket index, statistic)``
    is printed for the 0.9, 0.75, 0.5 quantiles, the mean, and the 0.25 and
    0.1 quantiles, followed by two "fill between" commands.

    Args:
        l2_by_duration: list of lists of floats; the list index is used as
            the x coordinate.
        color: pgfplots color name inserted into every plot option list.
    """
    # Convert each non-empty bucket once instead of once per curve.
    series = [
        (duration, np.array(errors))
        for duration, errors in enumerate(l2_by_duration)
        if len(errors) > 0
    ]

    def emit_curve(options, statistic):
        print(f"\\addplot+[{options}] plot coordinates {{")
        for duration, errors in series:
            print(f"({duration},{statistic(errors):0.4f})")
        print("};")

    def quantile(q):
        return lambda errors: np.quantile(errors, q)

    emit_curve(f"mark=none,name path=quantil9,{color},opacity=0.1,forget plot", quantile(0.9))
    emit_curve(f"mark=none,name path=quantil75,{color},opacity=0.3,forget plot", quantile(0.75))
    emit_curve(f"mark=none,{color},thick,forget plot", quantile(0.5))
    emit_curve(f"mark=none,{color},thick,dotted,forget plot", np.mean)
    emit_curve(f"mark=none,name path=quantil25,{color},opacity=0.3,forget plot", quantile(0.25))
    emit_curve(f"mark=none,name path=quantil1,{color},opacity=0.1,forget plot", quantile(0.1))
    print(f"\\addplot[{color},opacity=0.1] fill between[of=quantil9 and quantil1];")
    print(f"\\addplot[{color},opacity=0.2] fill between[of=quantil75 and quantil25];")


def train_latent_tracker(cfg: Configuration, trainset: Dataset, valset: Dataset, testset: Dataset, file, plot_only: bool = True):
    """Train a ``CaterLocalizer`` on latent data and evaluate its tracking error.

    Each epoch runs one optimisation pass over ``trainset`` (MSE between the
    blended position output and ``positions``) followed by an evaluation pass
    over ``testset``.

    Args:
        cfg: run configuration; reads ``device``, ``model``, ``sequence_len``,
            ``learning_rate`` and ``epochs``.
        trainset: training data; must also expose ``cam`` and ``z``, which
            are handed to ``CaterLocalizer``.
        valset: validation data (only used when ``plot_only`` is False).
        testset: test data.
        file: checkpoint path to resume from, or ``""`` to start fresh.
        plot_only: when True (the default, matching the previous behaviour)
            the function prints the containment-duration L2 statistics as
            pgfplots commands after the first epoch's test pass and returns;
            validation, the epoch summary and the final checkpoint are then
            skipped. When False the full loop runs for ``cfg.epochs`` epochs
            and ``net.pt`` is written to the model path.

    NOTE(review): ``plot_only=True`` reproduces an unconditional ``return``
    that was sitting in the middle of the epoch loop; confirm whether callers
    actually want that default.
    """

    print(f'device {cfg.device} online', flush=True)
    device = th.device(cfg.device)

    path = model_path(cfg, overwrite=False)
    cfg.save(path)

    cfg_net = cfg.model

    net = CaterLocalizer(
        gestalt_size       = cfg_net.gestalt_size,
        net_width          = 1,
        num_objects        = cfg_net.num_objects,
        num_timesteps      = cfg.sequence_len,
        camera_view_matrix = trainset.cam,
        zero_elevation     = trainset.z
    )

    net = net.to(device=device)

    print(f'Loaded model with {sum([param.numel() for param in net.parameters()]):6d} parameters', flush=True)

    l1grid = L1GridDistanceTracker()

    optimizer = th.optim.Adam(
        net.parameters(),
        lr = cfg.learning_rate,
        amsgrad=True
    )

    # One bucket per containment duration. The running counter built below
    # can reach `sequence_len` itself, hence the extra bucket.
    l2_contained = [[] for _ in range(cfg.sequence_len + 1)]

    if file != "":
        # map_location keeps checkpoints loadable when they were saved on a
        # different device than the one configured for this run.
        state = th.load(file, map_location=device)
        net.load_state_dict(state["model"])
        optimizer.load_state_dict(state["optimizer"])
        print(f'loaded {file}', flush=True)

    trainloader = DeviceSideDataset(trainset, device, cfg_net.batch_size)
    testloader  = DeviceSideDataset(testset, device, cfg_net.batch_size)
    valloader   = DeviceSideDataset(valset, device, cfg_net.batch_size)

    lr = cfg.learning_rate

    th.backends.cudnn.benchmark = True
    timer = Timer()

    def blend_predictions(tensor, contained):
        """Run the net and mix its visible/hidden heads by the containment mask."""
        out_visible, out_hidden, label_visible, label_hidden = net(tensor)
        output = out_visible * (1 - contained).unsqueeze(dim=2) + out_hidden * contained.unsqueeze(dim=2)
        # The label heads are mixed using the containment state of the last step only.
        out_label = label_visible * (1 - contained[:,-1:]) + label_hidden * contained[:,-1:]
        return output, out_label

    for epoch in range(cfg.epochs):

        avg_train_loss = 0
        avg_test_loss  = 0
        avg_val_loss   = 0

        avg_train_l2_distance = 0
        avg_test_l2_distance  = 0
        avg_val_l2_distance   = 0

        avg_train_l1_distance = 0
        avg_test_l1_distance  = 0
        avg_val_l1_distance   = 0

        avg_train_sum = 0
        avg_test_sum  = 0
        avg_val_sum   = 0

        avg_train_top1 = 0
        avg_test_top1  = 0
        avg_val_top1   = 0

        avg_train_top5 = 0
        avg_test_top5  = 0
        avg_val_top5   = 0

        net.train()
        for data in trainloader:

            tensor, positions, labels, contained, _ = data[:5]

            contained = contained.squeeze(dim=2)
            labels    = labels.int()

            output, out_label = blend_predictions(tensor, contained)

            loss = th.mean((output - positions)**2)

            avg_train_top1 += th.mean((th.argmax(out_label.detach(), dim=1) == labels).float()) / len(trainloader)
            avg_train_top5 += th.mean(th.sum((th.topk(out_label.detach(), 5, dim=1)[1] == labels.unsqueeze(dim=1)).float(), dim=1)) / len(trainloader)

            avg_train_loss        += loss.item()
            # L2 error of the final time step, summed over the batch.
            avg_train_l2_distance += th.sum(th.sqrt(th.sum((output[:,-1] - positions[:,-1])**2, dim=1)), dim=0).detach().item()

            avg_train_l1_distance += l1grid(th.argmax(out_label.detach(), dim=1), labels).item()
            avg_train_sum         += output.shape[0]

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        with th.no_grad():
            net.eval()
            for data in testloader:

                tensor, positions, labels, contained, _ = data[:5]

                if plot_only:
                    # Running counter: c[t] = c[t-1] * contained[t] + contained[t].
                    # For a 0/1 mask this is the number of consecutive
                    # contained steps ending at t (assumes a binary mask —
                    # TODO confirm against the dataset).
                    contained_time = th.zeros_like(contained)
                    contained_time[:,0] = contained[:,0]
                    for t in range(1, cfg.sequence_len):
                        contained_time[:,t] = contained_time[:,t-1] * contained[:,t] + contained[:,t]

                    contained_time = contained_time.long()

                contained = contained.squeeze(dim=2)

                output, out_label = blend_predictions(tensor, contained)

                loss = th.mean((output - positions)**2)

                avg_test_top1 += th.mean((th.argmax(out_label, dim=1) == labels).float()) / len(testloader)
                avg_test_top5 += th.mean(th.sum((th.topk(out_label, 5, dim=1)[1] == labels.unsqueeze(dim=1)).float(), dim=1)) / len(testloader)

                avg_test_loss        += loss.item()
                avg_test_l2_distance += th.sum(th.sqrt(th.sum((output[:,-1] - positions[:,-1])**2, dim=1)), dim=0).detach().item()

                avg_test_l1_distance += l1grid(th.argmax(out_label.detach(), dim=1), labels).item()

                avg_test_sum         += output.shape[0]

                if plot_only:
                    for t in range(1, cfg.sequence_len):
                        l2 = th.sqrt(th.sum((output[:,t] - positions[:,t])**2, dim=1))
                        # One host transfer per time step instead of one
                        # `.item()` call per sample.
                        durations = contained_time[:,t].reshape(-1).tolist()
                        errors    = l2.reshape(-1).tolist()
                        for duration, error in zip(durations, errors):
                            if duration > 0:
                                l2_contained[duration].append(error)

            if plot_only:
                _print_l2_quantile_plot(l2_contained, "trialcolor")
                return

            for data in valloader:

                tensor, positions, labels, contained, _ = data[:5]

                contained = contained.squeeze(dim=2)

                output, out_label = blend_predictions(tensor, contained)

                loss = th.mean((output - positions)**2)

                avg_val_top1 += th.mean((th.argmax(out_label, dim=1) == labels).float()) / len(valloader)
                avg_val_top5 += th.mean(th.sum((th.topk(out_label, 5, dim=1)[1] == labels.unsqueeze(dim=1)).float(), dim=1)) / len(valloader)

                avg_val_loss        += loss.item()
                avg_val_l2_distance += th.sum(th.sqrt(th.sum((output[:,-1] - positions[:,-1])**2, dim=1)), dim=0).detach().item()

                avg_val_l1_distance += l1grid(th.argmax(out_label.detach(), dim=1), labels).item()

                avg_val_sum         += output.shape[0]

        np.set_printoptions(threshold=10000)
        np.set_printoptions(linewidth=np.inf)

        # Columns are ordered train|val|test.
        print("Epoch[{}/{}]: {}, Loss: {:.2e}|{:.2e}|{:.2e}, L2: {:.4f}|{:.4f}|{:.4f}, L1: {:.4f}|{:.4f}|{:.4f}, lr: {:.2e}, Top-1: {:.4f}%|{:.4f}%|{:.4f}%, Top-5: {:.4f}%|{:.4f}%|{:.4f}%".format(
            epoch + 1,
            cfg.epochs,
            str(timer),
            avg_train_loss/len(trainloader),
            avg_val_loss/len(valloader),
            avg_test_loss/len(testloader),
            avg_train_l2_distance / avg_train_sum,
            avg_val_l2_distance / avg_val_sum,
            avg_test_l2_distance / avg_test_sum,
            avg_train_l1_distance / avg_train_sum,
            avg_val_l1_distance / avg_val_sum,
            avg_test_l1_distance / avg_test_sum,
            lr,
            avg_train_top1 * 100,
            avg_val_top1 * 100,
            avg_test_top1 * 100,
            avg_train_top5 * 100,
            avg_val_top5 * 100,
            avg_test_top5 * 100,
        ), flush=True)

    state = { }
    state['optimizer'] = optimizer.state_dict()
    state["model"] = net.state_dict()
    th.save(state, os.path.join(path, 'net.pt'))
