import pickle
import torch as th
from torch.utils.data import Dataset, DataLoader, Subset
from torch import nn

import os
from evaluation.plot_utils import color_slots, compute_occlusion_mask, get_highlighted_input, plot_error_batch, plot_object_view, plot_online_error, plot_online_error_slots, preprocess, preprocess_multi, to_rgb, write_image
from utils.configuration import Configuration
from model.loci import Loci
from utils.utils import LambdaModule, Gaus2D, Prioritize, Shape2D, Vector2D
import numpy as np
import cv2
from einops import rearrange, repeat, reduce
import motmetrics as mm
from copy import deepcopy
import pandas as pd

def save(cfg: Configuration, dataset: Dataset, file, active_layer, size, object_view = False, nice_view = False, individual_views = False, trace_view = False, concept = '', mota = False, latent_trace = False, net = None, plot_frequency= 1, plot_first_samples = 2, num_samples=1000):
    """Run the network over the evaluation sets and store statistics and plots.

    For every evaluation set returned by ``get_evaluation_sets`` and every
    evaluation mode, each sample is processed frame by frame with batch size 1
    and gradients disabled. Per-frame errors, per-slot statistics and
    (optionally) MOT accumulators, latent traces and images are collected.
    At the end three CSV files (``trialframe.csv``, ``slotframe.csv`` and
    ``accframe.csv``) are written to ``<root_path>/statistics``.

    Args:
        cfg: Run configuration. ``cfg.model`` is modified in place
            (``batch_size = 1``, ``closed_prediction = True``); without CUDA
            the device entries are additionally set to ``"cpu"``.
        dataset: Dataset to evaluate. Each batch is read as ``[0]`` frames,
            ``[1]`` fixed background, ``[3]`` ground-truth object positions,
            ``[4]`` ground-truth object visibility, ``[5]`` occluder mask.
        file: Path of the network checkpoint; also determines the result
            folders (see ``setup_result_folders``).
        active_layer: Together with ``size[0]`` and ``cfg.model.latent_size[0]``
            determines the scale factor used for plotting.
        size: Passed to ``Gaus2D``, ``Vector2D`` and ``plot_object_view``.
        object_view: Write the combined object view image per plotted frame.
        nice_view: Only forwarded to ``setup_result_folders`` in this function.
        individual_views: Write the individual images (error, input, masks, ...).
        trace_view: Only forwarded to ``setup_result_folders`` in this function.
        concept: Forwarded to ``get_evaluation_sets``; for any value other than
            ``'test'`` the object view is drawn with 4 objects.
        mota: Accumulate and summarize MOT metrics with ``motmetrics``.
        latent_trace: Store the per-frame gestalt and position codes via
            ``store_latent``.
        net: Network to evaluate; if ``None`` it is created by ``load_model``.
        plot_frequency: Only frames with ``t % plot_frequency == 0`` are plotted.
        plot_first_samples: Plotting stops once this many printable samples
            have been processed (per evaluation set and mode).
        num_samples: Maximum number of samples taken from each evaluation set.

    Returns:
        None.
    """

    # set model configurations
    cfg_net = cfg.model
    if th.cuda.is_available():
        device = th.device(cfg.device)
        verbose = False
    else:
        device = th.device("cpu")
        verbose = True
        cfg.device = "cpu"
        cfg.model.device = "cpu"
        print('!!! USING CPU !!!')
    # NOTE(review): `verbose` is not read anywhere else in this function.

    # evaluation always runs one sequence at a time
    cfg_net.batch_size = 1
    cfg_net.closed_prediction = True

    # create model
    if net is None:
        net = load_model(cfg, cfg_net, dataset, file, device)
    net.eval()

    # create position helpers
    gaus2d  = Gaus2D(size).to(device)
    vector2d = Vector2D(size).to(device)
    scale  = size[0] // (cfg_net.latent_size[0] * 2**active_layer)

    # init
    # init = max(min(net.get_init_status() - 1, 1), 0)
    # NOTE(review): `init` and `prioritize` are not read anywhere else in this function.
    init = 1
    prioritize = Prioritize(cfg_net.num_objects).to(device)

    # config
    plot_error = True
    root_path = None

    # get evaluation sets
    set_test_array, evaluation_modes, errors_to_plot = get_evaluation_sets(dataset, concept)
    # one row per evaluated frame
    statistics_template = {'set': [], 'evalmode': [], 'scene': [], 'frame': [], 'image_error': [],  'TE': []}
    statistics_complete = deepcopy(statistics_template)
    # one row per evaluated frame and slot; store_statistics fills the columns in this key order
    statistics_complete_slots = {'set': [], 'evalmode': [], 'scene': [], 'frame': [], 'slot':[], 'TE': [], 'visible': [], 'bound': [], 'occluder': [], 'inimage': [], 'slot_error': [], 'mask_size': [], 'rawmask_size': [],  'rawmask_size_hidden': [], 'alpha_pos': [], 'alpha_ges': [], 'object_id': [], 'vanishing': []}
    acc_memory_complete = None

    for set_test in set_test_array:

        for evaluation_mode in evaluation_modes:
            print(f'Start evaluation loop: {evaluation_mode}')

            # load data
            dataloader = DataLoader(
                Subset(dataset, set_test['samples'][:num_samples]),
                num_workers = 1,
                pin_memory = False,
                batch_size = 1,
                shuffle = False
            )

            # memory
            mseloss = nn.MSELoss()
            # root_path is overwritten for every set, so the CSV files written at the
            # end of this function land in the folder of the last evaluated set
            root_path, plot_path = setup_result_folders(file, set_test, evaluation_mode, trace_view, object_view, individual_views, nice_view, plot_error, latent_trace, mota, errors_to_plot)
            acc_memory_eval = []
            num_plotted = 0

            with th.no_grad():
                for i, input in enumerate(dataloader):
                    print(f'Processing sample {i+1}/{len(dataloader)}', flush=True)

                    # get input frame and target frame
                    tensor = input[0].float().to(device)
                    background_fix  = input[1].to(device)
                    gt_object_positions = input[3].to(device)
                    gt_object_visibility = input[4].to(device)
                    gt_occluder_mask = input[5].to(device)

                    # apply skip frames: keep every cfg.skip_frames-th frame along dim 1
                    gt_object_positions = gt_object_positions[:,range(0, tensor.shape[1], cfg.skip_frames)]
                    gt_object_visibility = gt_object_visibility[:,range(0, tensor.shape[1], cfg.skip_frames)]
                    tensor = tensor[:,range(0, tensor.shape[1], cfg.skip_frames)]
                    sequence_len = tensor.shape[1]

                    # initial frame
                    input  = tensor[:,0]
                    target = th.clip(tensor[:,0], 0, 1)
                    error_last  = None

                    # evaluation phase: frame indices are converted to the subsampled sequence
                    evaluation_phase_start = round(set_test['start'][i]/cfg.skip_frames)
                    evaluation_phase_critical = None
                    printable = True
                    if "critical" in set_test and i < len(set_test['critical']):
                        evaluation_phase_critical = round(set_test['critical'][i]/cfg.skip_frames)
                    if "printable" in set_test and i < len(set_test['printable']):
                        printable = set_test['printable'][i]

                    # placeholders
                    mask_cur       = None
                    mask_last      = None
                    maskraw_last   = None
                    position_last  = None
                    gestalt_last   = None
                    priority_last  = None
                    gt_positions_target = None
                    slots_occlusionfactor = None
                    # -1 marks a slot that is not yet associated with a ground-truth object
                    association_table = th.ones(cfg_net.num_objects).to(device) * -1
                    if mota:
                        acc = mm.MOTAccumulator(auto_id=True)

                    # memory
                    statistics_batch = deepcopy(statistics_template)
                    gestalt_memory = []
                    position_memory = []
                    slots_vanishing_memory = np.zeros(cfg_net.num_objects)

                    # loop through frames; the iterations with t < 0 are the teacher forcing
                    # steps, which all operate on the first frame (t_run is clamped to 0)
                    for t_index,t in enumerate(range(-cfg.teacher_forcing, sequence_len-1)):

                        # move to next frame
                        t_run = max(t, 0)
                        input  = tensor[:,t_run]
                        target = th.clip(tensor[:,t_run+1], 0, 1)
                        gt_positions_target = gt_object_positions[:,t_run]
                        gt_positions_target_next = gt_object_positions[:,t_run+1]
                        gt_visibility_target = gt_object_visibility[:,t_run]

                        # closed loop: use own prediction as input
                        # NOTE(review): relies on output_next / error_last from a previous
                        # iteration; get_evaluation_sets currently only returns 'open'.
                        if t >= evaluation_phase_start and evaluation_mode in ['closed_outer', 'closed_outer_error']:
                            input  = output_next
                            if evaluation_mode == 'closed_outer':
                                error_last  = th.zeros_like(error_last)

                        # obtain prediction
                        (
                            output_next,
                            position_next,
                            gestalt_next,
                            priority_next,
                            mask_next,
                            maskraw_next,
                            object_next,
                            background,
                            slots_occlusionfactor,
                            output_cur,
                            position_cur,
                            gestalt_cur,
                            priority_cur,
                            mask_cur,
                            maskraw_cur,
                            object_cur,
                            position_encoder_cur,
                            slots_bounded,
                            slots_partially_occluded_cur,
                            slots_occluded_cur,
                            slots_partially_occluded_next,
                            slots_occluded_next,
                            slots_closed,
                            output_hidden,
                            largest_object,
                            maskraw_hidden,
                            object_hidden
                        ) = net(
                            input,
                            error_last,
                            mask_last,
                            maskraw_last,
                            position_last,
                            gestalt_last,
                            priority_last,
                            background_fix,
                            slots_occlusionfactor,
                            reset = (t == -cfg.teacher_forcing),  # only on the first iteration of a sample
                            evaluate=True,
                            warmup = (t < 0),                     # teacher forcing iterations
                            shuffleslots = False,
                            reset_mask = (t <= 0),
                            allow_spawn = True,
                            show_hidden = True,
                            clean_slots = (t <= 0),
                        )

                        # 1. Track error (not during the teacher forcing iterations)
                        if t >= 0:

                            if latent_trace:
                                gestalt_memory.append(gestalt_cur.detach().cpu().numpy())
                                position_memory.append(position_cur.detach().cpu().numpy())

                            # position error: distance between predicted position and target position
                            tracking_error, tracking_error_perslot, association_table, slots_visible, slots_in_image, slots_occluder = calculate_tracking_error(gt_positions_target, gt_visibility_target, position_cur, cfg_net, slots_bounded, slots_occluded_cur, association_table, gt_occluder_mask)

                            statistics_batch = store_statistics(statistics_batch,
                                                                set_test['type'],
                                                                evaluation_mode,
                                                                set_test['samples'][i],
                                                                t,
                                                                mseloss(output_next, target).item(),
                                                                tracking_error)

                            # compute rawmask_size (the last mask channel is excluded)
                            rawmask_size = reduce(maskraw_next[:, :-1], 'b o h w-> b o', 'sum')
                            rawmask_size_hidden = reduce(maskraw_hidden[:, :-1], 'b o h w-> b o', 'sum')
                            mask_size = reduce(mask_next[:, :-1], 'b o h w-> b o', 'sum')

                            # compute slot-wise prediction error: squared error of prediction and
                            # target, both weighted with the slot mask and divided by the mask size
                            output_slot = repeat(mask_next[:,:-1], 'b o h w -> b o 3 h w') * repeat(output_next, 'b c h w -> b o c h w', o=cfg_net.num_objects)
                            target_slot = repeat(mask_next[:,:-1], 'b o h w -> b o 3 h w') * repeat(target, 'b c h w -> b o c h w', o=cfg_net.num_objects)
                            slot_error = reduce((output_slot - target_slot)**2, 'b o c h w -> b o', 'mean')/mask_size

                            # check if objects vanish: the third position entry changes by more
                            # than 0.2 between this frame and the next one
                            objects_vanishing = th.abs(gt_positions_target[:,:,2] - gt_positions_target_next[:,:,2]) > 0.2
                            objects_vanishing = th.where(objects_vanishing.flatten())[0]
                            slots_vanishing = [(obj.item() in objects_vanishing) for obj in association_table[0]]
                            # list + ndarray is an element-wise numpy addition: counts per slot
                            slots_vanishing_memory = slots_vanishing + slots_vanishing_memory

                            # the positional arguments fill the columns 'set' ... 'object_id' in key
                            # order; 'vanishing' is filled after the frame loop
                            statistics_complete_slots = store_statistics(statistics_complete_slots,
                                                                        [set_test['type']] * cfg_net.num_objects,
                                                                        [evaluation_mode] * cfg_net.num_objects,
                                                                        [set_test['samples'][i]] * cfg_net.num_objects,
                                                                        [t] * cfg_net.num_objects,
                                                                        range(cfg_net.num_objects),
                                                                        tracking_error_perslot.cpu().numpy().flatten(),
                                                                        slots_visible.cpu().numpy().flatten().astype(int),
                                                                        slots_bounded.cpu().numpy().flatten().astype(int),
                                                                        slots_occluder.cpu().numpy().flatten().astype(int),
                                                                        slots_in_image.cpu().numpy().flatten().astype(int),
                                                                        slot_error.cpu().numpy().flatten(),
                                                                        mask_size.cpu().numpy().flatten(),
                                                                        rawmask_size.cpu().numpy().flatten(),
                                                                        rawmask_size_hidden.cpu().numpy().flatten(),
                                                                        slots_closed[:, :, 1].cpu().numpy().flatten(),
                                                                        slots_closed[:, :, 0].cpu().numpy().flatten(),
                                                                        association_table[0].cpu().numpy().flatten().astype(int),
                                                                        extend = True)

                            if mota:
                                acc = update_mota_acc(acc, gt_positions_target, position_cur, slots_bounded, cfg_net, gt_occluder_mask, slots_occluder, maskraw_next)

                        # 2. Remember output (fed back into the network in the next iteration)
                        mask_last     = mask_next.clone()
                        maskraw_last  = maskraw_next.clone()
                        position_last = position_next.clone()
                        gestalt_last  = gestalt_next.clone()
                        priority_last = priority_next.clone()

                        # 3. Error for next frame
                        # background error
                        # NOTE(review): bg_error_cur is not read anywhere else in this function.
                        bg_error_cur  = th.sqrt(reduce((input - background)**2, 'b c h w -> b 1 h w', 'mean')).detach()
                        bg_error_next = th.sqrt(reduce((target - background)**2, 'b c h w -> b 1 h w', 'mean')).detach()

                        # prediction error
                        error_next    = th.sqrt(reduce((target - output_next)**2, 'b c h w -> b 1 h w', 'mean')).detach()
                        # NOTE(review): a second square root is taken here before weighting
                        # with the background error - confirm that this is intended.
                        error_next    = th.sqrt(error_next) * bg_error_next
                        error_last    = error_next.clone()

                        # 4. plot preparation
                        # From here on several network outputs are overwritten with their plot
                        # versions; the values needed by the next iteration were cloned in step 2.
                        if t % plot_frequency == 0 and num_plotted < plot_first_samples and t >= 0 and printable:

                            # compute postfix: marks the start ('_s') and critical ('_c') frame
                            if t == evaluation_phase_start:
                                postfix = '_s'
                            elif (evaluation_phase_critical is not None) and t == (evaluation_phase_critical):
                                postfix = '_c'
                            else:
                                postfix = ''

                            # compute plot content
                            highlighted_input = get_highlighted_input(input, mask_cur)
                            output = th.clip(output_next, 0, 1)
                            position_cur2d = gaus2d(rearrange(position_encoder_cur, 'b (o c) -> (b o) c', o=cfg_net.num_objects))
                            velocity_next2d = vector2d(rearrange(position_next, 'b (o c) -> (b o) c', o=cfg_net.num_objects))

                            # color slots
                            slots_bounded, slots_partially_occluded_cur, slots_occluded_cur, slots_partially_occluded_next, slots_occluded_next = reshape_slots(slots_bounded, slots_partially_occluded_cur, slots_occluded_cur, slots_partially_occluded_next, slots_occluded_next)
                            position_cur2d = color_slots(position_cur2d, slots_bounded, slots_partially_occluded_cur, slots_occluded_cur)
                            velocity_next2d = color_slots(velocity_next2d, slots_bounded, slots_partially_occluded_next, slots_occluded_next)

                            # compute occlusion
                            maskraw_next_original = maskraw_next.clone()
                            maskraw_cur_l, maskraw_next_l = compute_occlusion_mask(maskraw_cur, maskraw_next, mask_cur, mask_next, scale)
                            # the condition is hard-coded: the hidden raw masks and objects are always
                            # plotted, except for `largest_object`, which keeps the regular ones
                            if True:
                                maskraw_cur_h, maskraw_next_h = compute_occlusion_mask(maskraw_cur, maskraw_hidden, mask_cur, mask_next, scale)
                                maskraw_cur_h[:,largest_object] = maskraw_cur_l[:,largest_object]
                                maskraw_next_h[:,largest_object] = maskraw_next_l[:,largest_object]
                                maskraw_cur = maskraw_cur_h
                                maskraw_next = maskraw_next_h

                                object_hidden[:, largest_object] = object_next[:, largest_object]
                                object_next = object_hidden
                            else:
                                maskraw_cur = maskraw_cur_l
                                maskraw_next = maskraw_next_l

                            # scale plot content
                            input, target, output, highlighted_input, object_next, object_cur, mask_next, background, error_next, output_hidden, bg_error_next = preprocess_multi(input, target, output, highlighted_input, object_next, object_cur, mask_next, background, error_next, output_hidden, bg_error_next, scale=scale)

                            # convert to rgb
                            bg_error_next    = to_rgb(bg_error_next)

                            # reshape
                            object_next     = rearrange(object_next, 'b (o c) h w -> b o c h w', c = cfg_net.img_channels)
                            object_cur      = rearrange(object_cur, 'b (o c) h w -> b o c h w', c = cfg_net.img_channels)
                            mask_next       = rearrange(mask_next, 'b (o 1) h w -> b o 1 h w')

                            if object_view:
                                num_objects = cfg_net.num_objects if concept == 'test' else 4
                                # the slices select the slot statistics of frames 0..t of the current sample
                                error_plot_slots = plot_online_error_slots(statistics_complete_slots['TE'][-cfg_net.num_objects*(t+1):], 'Tracking error', target, sequence_len, root_path, statistics_complete_slots['visible'][-cfg_net.num_objects*(t+1):], slots_bounded)
                                error_plot_slots2 = plot_online_error_slots(statistics_complete_slots['slot_error'][-cfg_net.num_objects*(t+1):], 'Image error', target, sequence_len, root_path, statistics_complete_slots['visible'][-cfg_net.num_objects*(t+1):], slots_bounded, ylim=0.0001)
                                error_plot = plot_online_error(statistics_batch['image_error'], 'Prediction error', target, t, i, sequence_len, root_path)
                                error_plot2 = plot_online_error(statistics_batch['TE'], 'Tracking error', target, t, i, sequence_len, root_path)
                                img = plot_object_view(error_plot, error_plot2, error_plot_slots, error_plot_slots2, highlighted_input, output_hidden, error_next, object_cur, object_next, maskraw_cur, maskraw_next, position_cur2d, velocity_next2d, output, target, slots_closed, gt_positions_target_next, association_table, size, num_objects, largest_object)
                                cv2.imwrite(f'{plot_path}object/gpnet-objects-{i:04d}-{t_index:03d}{postfix}.jpg', img)

                            if individual_views:
                                # ['error', 'input', 'background', 'prediction', 'position', 'rawmask', 'mask', 'othermask']:
                                write_image(f'{plot_path}/individual/error/error-{i:04d}-{t_index:03d}{postfix}.jpg', error_next[0])
                                write_image(f'{plot_path}/individual/input/input-{i:04d}-{t_index:03d}{postfix}.jpg', input[0])
                                write_image(f'{plot_path}/individual/background/background-{i:04d}-{t_index:03d}{postfix}.jpg', mask_next[0,-1])

                                #write_image(f'{plot_path}/individual/imagination/imagination-{i:04d}-{t_index:03d}{postfix}.jpg', output_hidden[0])
                                #write_image(f'{plot_path}/individual/prediction/prediction-{i:04d}-{t_index:03d}{postfix}.jpg', output_next[0])


                                # hard-coded index of the slot that is written to the individual views
                                o = 2
                                position_next2d, maskraw_next_original = preprocess_multi(gaus2d(rearrange(position_next, 'b (o c) -> (b o) c', o=cfg_net.num_objects)), maskraw_next_original, scale=scale)
                                write_image(f'{plot_path}/individual/prediction/prediction-{i:04d}-{t_index:03d}{postfix}.jpg', object_next[0, o])
                                write_image(f'{plot_path}/individual/position/position-{i:04d}-{t_index:03d}{postfix}.jpg', position_next2d[o])
                                write_image(f'{plot_path}/individual/rawmask/rawmask-{i:04d}-{t_index:03d}{postfix}.jpg', maskraw_next_original[0, o][None])
                                write_image(f'{plot_path}/individual/mask/mask-{i:04d}-{t_index:03d}{postfix}.jpg', mask_next[0, o])
                                write_image(f'{plot_path}/individual/othermask/othermask-{i:04d}-{t_index:03d}{postfix}.jpg', th.sum(mask_next[0,:-1], dim=0) - mask_next[0, o])



                    # fill vanishing statistics: `t` still holds the last frame index of the loop
                    # above, so the per-slot counts are repeated once per tracked frame (t+1)
                    statistics_complete_slots['vanishing'].extend(np.tile(slots_vanishing_memory.astype(int), t+1))

                    # store the gestalt and position codes of all tracked frames, one pickle file per slot
                    if latent_trace:
                        store_latent(gestalt_memory, position_memory, cfg_net.num_objects, plot_path, i)

                    if mota:
                        acc_memory_eval.append(acc)

                    # plot error over time for that batch/sample
                    if plot_error:
                        plot_error_batch(statistics_batch, evaluation_phase_critical, errors_to_plot, evaluation_phase_start, plot_path, i)

                    # every printable sample counts towards the plot_first_samples limit
                    if printable:
                        num_plotted += 1

                    # store batch statistics in complete statistics
                    statistics_complete = append_statistics(statistics_complete, statistics_batch)

            # summarize the MOT accumulators of this set and evaluation mode
            if mota:
                mh = mm.metrics.create()
                summary = mh.compute_many(acc_memory_eval, metrics=mm.metrics.motchallenge_metrics, generate_overall=True)
                summary['set'] = set_test['type']
                summary['evalmode'] = evaluation_mode
                if acc_memory_complete is None:
                    acc_memory_complete = summary.copy()
                else:
                    acc_memory_complete = pd.concat([acc_memory_complete, summary])

    print('-- Evaluation Done --')
    # root_path refers to the folder of the last evaluated set at this point
    pd.DataFrame(statistics_complete).to_csv(f'{root_path}/statistics/trialframe.csv')
    pd.DataFrame(statistics_complete_slots).to_csv(f'{root_path}/statistics/slotframe.csv')
    # acc_memory_complete is still None here when mota is False
    pd.DataFrame(acc_memory_complete).to_csv(f'{root_path}/statistics/accframe.csv')
    # presumably a temporary file left behind by the online error plots - TODO confirm
    if object_view and os.path.exists(f'{root_path}/tmp.jpg'):
        os.remove(f'{root_path}/tmp.jpg')
    pass

def update_mota_acc(acc, gt_positions, estimated_positions, slots_bounded, cfg_net, gt_occluder_mask, slots_occluder, maskraw, ignore_occluder = False):
    """Add the ground truth and slot positions of one frame to a MOT accumulator.

    Only batch element 0 is evaluated. Ground-truth objects count as targets
    if they lie inside the (stretched) image; slots count as hypotheses if
    they are bound, lie inside the image and have a raw mask sum above 100.

    Args:
        acc: ``motmetrics.MOTAccumulator`` that is updated in place.
        gt_positions: Ground-truth positions, indexed as ``[batch, object, coordinate]``;
            the first two coordinates are used. The tensor is not modified.
        estimated_positions: Slot positions, indexed as ``[batch, (slot coordinate)]``;
            the first two coordinates of every slot are used.
        slots_bounded: Per-slot indicator of bound slots - assumed shape (1, num_slots).
        cfg_net: Network configuration providing ``num_objects`` (number of slots).
        gt_occluder_mask: Per-object indicator of occluder objects (batch first).
        slots_occluder: Per-slot indicator of slots bound to an occluder.
        maskraw: Raw masks, indexed as ``[batch, slot, h, w]``; the last channel is ignored.
        ignore_occluder: If True, occluder objects and occluder slots are excluded.

    Returns:
        The updated accumulator ``acc``.
    """

    # num objects
    num_objects = len(gt_positions[0])

    # get rid of batch dimension and priority dimension
    pos = rearrange(estimated_positions.detach()[0], '(o c) -> o c', o=cfg_net.num_objects)[:, :2]
    # Copy the targets: the indexing only creates a view, so stretching it in place
    # would rescale the caller's ground-truth tensor on every call.
    targets = gt_positions[0, :, :2].clone()

    # stretch positions to account for frame ratio, Specific for ADEPT!
    pos = th.clip(pos, -1, 1)
    pos[:, 0] = pos[:, 0] * 1.5
    targets[:, 0] = targets[:, 0] * 1.5

    # remove objects that are not in the image
    edge = 1
    in_image = th.cat([targets[:, 0] < (1.5 * edge), targets[:, 0] > (-1.5 * edge), targets[:, 1] < (1 * edge), targets[:, 1] > (-1 * edge)])
    in_image = th.all(rearrange(in_image, '(c o) -> o c', o=num_objects), dim=1)

    if ignore_occluder:
        in_image = (gt_occluder_mask[0] == 0) * in_image
    targets = targets[in_image]

    # test if position estimates in image
    in_image_pos = th.cat([pos[:, 0] < (1.5 * edge), pos[:, 0] > (-1.5 * edge), pos[:, 1] < (1 * edge), pos[:, 1] > (-1 * edge)])
    in_image_pos = th.all(rearrange(in_image_pos, '(c o) -> c o', o=cfg_net.num_objects), dim=0, keepdim=True)

    # only position estimates that are in image and bound
    maskraw_size = reduce(maskraw[:, :-1], 'b o h w-> b o', 'sum')
    m = (slots_bounded * in_image_pos * (maskraw_size > 100)).bool()
    if ignore_occluder:
        m = (m * (1 - slots_occluder)).bool()

    # boolean indexing flattens the result, so restore the (slot, coordinate) layout
    pos = pos[repeat(m, '1 o -> o 2')]
    pos = rearrange(pos, '(o c) -> o c', c = 2)

    # compute pairwise distances; pairs above the threshold are treated as non-matchable
    diagonal_length = th.sqrt(th.sum(th.tensor([2,3])**2)).item()
    C = mm.distances.norm2squared_matrix(targets.cpu().numpy(), pos.cpu().numpy(), max_d2=diagonal_length*0.1)

    # update accumulator with the ids of the remaining objects and slots
    acc.update( (th.where(in_image)[0]).cpu(), (th.where(m)[1]).cpu(), C)

    return acc

def calculate_tracking_error(gt_positions_target, gt_visibility_target, position_cur, cfg_net, slots_bounded, slots_occluded_cur, association_table,  gt_occluder_mask):
    """Associate slots with ground-truth objects and compute the tracking error.

    Only batch element 0 is evaluated. Slots that are bound, not occluded and
    not yet associated (entry ``-1``) are assigned the index of the closest
    ground-truth object. The tracking error of a slot is the Euclidean distance
    to its associated object, normalized by the image diagonal.

    The ground-truth tensor is not modified; its x-coordinates are stretched
    exactly once on a private copy.

    Args:
        gt_positions_target: Ground-truth positions, indexed as
            ``[batch, object, coordinate]`` with at least three coordinates.
        gt_visibility_target: Per-object visibility (1 = visible), batch first.
        position_cur: Slot positions, indexed as ``[batch, (slot coordinate)]``;
            the first two coordinates of every slot are used.
        cfg_net: Network configuration providing ``num_objects`` (number of slots).
        slots_bounded: Per-slot indicator of bound slots - assumed shape (1, num_slots).
        slots_occluded_cur: Per-slot indicator of occluded slots - assumed shape (1, num_slots).
        association_table: Index of the associated ground-truth object per slot,
            ``-1`` for slots without association.
        gt_occluder_mask: Per-object indicator of occluder objects (1 = occluder), batch first.

    Returns:
        Tuple ``(tracking_error, tracking_error_perslot, association_table,
        slots_visible, slots_in_image, slots_occluder)``: the mean error over
        all tracked slots as float, the per-slot errors (zero for untracked
        slots), the updated association table and three per-slot indicators.
    """

    # tracking utils
    pdist = nn.PairwiseDistance(p=2).to(position_cur.device)

    # stretch factor to account for frame ratio, Specific for ADEPT!
    x_stretch = 1.5
    diagonal_length = th.sqrt(th.sum(th.tensor([2,3])**2))

    # num objects
    num_objects = len(gt_positions_target[0])

    # Work on a copy of the targets: indexing only creates a view, so stretching in
    # place would rescale the caller's tensor and stretch x once more with every use.
    gt_stretched = gt_positions_target[0, :, :3].clone()
    gt_stretched[:, 0] = gt_stretched[:, 0] * x_stretch

    # 1. association of newly bounded slots to ground truth objects
    # get rid of batch dimension and priority dimension
    pos = rearrange(position_cur.clone()[0], '(o c) -> o c', o=cfg_net.num_objects)[:, :2]
    pos = th.clip(pos, -1, 1)
    pos[:, 0] = pos[:, 0] * x_stretch

    # reshape and repeat for comparison: every slot is paired with every object
    pos = repeat(pos, 'o c -> (o r) c', r=num_objects)
    targets = repeat(gt_stretched[:, :2], 'o c -> (r o) c', r=cfg_net.num_objects)

    # comparison
    distance = pdist(pos, targets)
    distance = rearrange(distance, '(o r) -> o r', r=num_objects)

    # find closest target for each slot (element 1 of the result holds the indices)
    closest = th.min(distance, dim=1, keepdim=True)

    # update association table: existing associations are kept
    slots_newly_bounded = slots_bounded * (association_table == -1) * (1-slots_occluded_cur)
    association_table = association_table * (1-slots_newly_bounded) + slots_newly_bounded * closest[1].T

    # 2. position error
    # get rid of batch dimension and priority dimension (positions are not clipped here)
    pos = rearrange(position_cur.clone()[0], '(o c) -> o c', o=cfg_net.num_objects)[:, :2]
    pos[:, 0] = pos[:, 0] * x_stretch

    # gather targets according to association table
    # (an entry of -1 selects the last object; such slots only count if they are bound)
    targets = gt_stretched[association_table.long()][0]

    # determine which slots are within the image
    slots_in_image = th.cat([targets[:, 0] < 1.5, targets[:, 0] > -1.5, targets[:, 1] < 1, targets[:, 1] > -1, targets[:, 2] > 0])
    slots_in_image = rearrange(slots_in_image, '(c o) -> o c', o=cfg_net.num_objects)
    slots_in_image = th.all(slots_in_image, dim=1)

    # define which slots to consider for tracking error
    slots_to_track = slots_bounded * slots_in_image

    # compute position error
    targets = targets[:, :2]
    tracking_error_perslot = th.sqrt(th.sum((pos - targets)**2, dim=1))/diagonal_length
    tracking_error_perslot = tracking_error_perslot[None, :] * slots_to_track
    # max(..., 1) avoids a division by zero when no slot is tracked
    tracking_error = th.sum(tracking_error_perslot).item()/max(th.sum(slots_to_track).item(), 1)

    # compute which slots are visible
    visible_objects = th.where(gt_visibility_target[0] == 1)[0]
    slots_visible = th.tensor([[int(obj.item()) in visible_objects for obj in association_table[0]]]).float().to(slots_to_track.device)
    slots_visible = slots_visible * slots_to_track

    # determine which objects are bound to the occluder
    occluder_objects = th.where(gt_occluder_mask[0] == 1)[0]
    slots_occluder = th.tensor([[int(obj.item()) in occluder_objects for obj in association_table[0]]]).float().to(slots_to_track.device)
    slots_occluder = slots_occluder * slots_to_track

    return tracking_error, tracking_error_perslot, association_table, slots_visible, slots_in_image, slots_occluder

def append_statistics(memory1, memory2, ignore=()):
    """Merge the statistic lists of ``memory1`` into ``memory2``.

    For every key of ``memory1`` that is not listed in ``ignore``, the entry of
    ``memory2`` is replaced by ``memory2[key] + memory1[key]``, i.e. the values
    of ``memory2`` come first. ``memory2`` is modified in place and returned;
    ``memory1`` is left untouched.

    Args:
        memory1: Dict mapping statistic names to lists.
        memory2: Dict with (at least) the non-ignored keys of ``memory1``.
        ignore: Keys of ``memory1`` that are skipped.

    Returns:
        The updated ``memory2``.

    Raises:
        KeyError: If a non-ignored key of ``memory1`` is missing in ``memory2``.
    """
    for key in memory1:
        if key not in ignore:
            memory2[key] = memory2[key] + memory1[key]
    return memory2

def store_statistics(memory, *args, extend=False):
    """Add one value per column to the statistics dict ``memory``.

    The positional arguments are matched with the keys of ``memory`` in
    insertion order. Keys without a matching argument stay untouched and
    surplus arguments are ignored.

    Args:
        memory: Dict mapping column names to lists; modified in place.
        *args: Values for the first ``len(args)`` columns.
        extend: If True every argument is an iterable whose items are added to
            its column, otherwise the argument itself is appended.

    Returns:
        The updated ``memory``.
    """
    # zip stops at the shorter sequence, which pairs the leading keys with the arguments
    for column_name, value in zip(memory.keys(), args):
        column = memory[column_name]
        if extend:
            column.extend(value)
        else:
            column.append(value)
    return memory

def get_evaluation_sets(dataset, concept):
    """Define the evaluation sets, evaluation modes and errors to plot.

    For ``concept == 'test'`` a single set of type ``'test'`` covering all
    ``len(dataset)`` samples is returned. For every other concept the samples
    are split by ``sample.case`` into a control set (cases 0 and 3) and a
    surprise set (case 1), returned in that order.

    Args:
        dataset: Dataset; ``dataset.samples`` (objects with a ``case``
            attribute) is read unless ``concept`` is ``'test'``.
        concept: Name of the evaluated concept, used as prefix of the set types.

    Returns:
        Tuple ``(set_test_array, evaluation_modes, errors_to_plot)``. Every set
        is a dict with the keys ``samples``, ``start`` and ``type`` plus either
        ``critical`` (test set) or ``printable`` (control/surprise sets).
    """
    evaluation_modes = ['open'] # 'closed_outer'
    errors_to_plot = ['image_error', 'TE']

    if concept == 'test':
        # switch to evaluate a hand-picked list of scenes instead of the whole dataset
        use_handpicked_scenes = False
        if use_handpicked_scenes:
            test_set = {"samples": [0,13,22,24,25,27,33,34,37,38,46,56,59,60,64,69,76,80,82,88,91,93], "start": [56,33,51,29,12,51,27,48,38,36,19,49,34,52,38,38,41,32,22,40,24,42], "critical": [76,54,99,67,57,76,80,72,63,64,28,82,61,84,63,65,83,53,56,62,74,70], "type": "occlusion"}
        else:
            num_scenes = len(dataset)
            test_set = {
                "samples": np.arange(num_scenes, dtype=int),
                "start": np.zeros(num_scenes, dtype=int),
                "critical": np.zeros(num_scenes, dtype=int),
                "type": "test",
            }
        return [test_set], evaluation_modes, errors_to_plot

    # indices of the samples that may be plotted
    if concept == 'createdown':
        printable = [1,2,3,4,5,10,11,13,14,15,16,17,18,19,21,23,25,26,27,31,32,33,45]
    else:
        printable = range(len(dataset.samples))

    def build_set(cases, suffix):
        # collect all samples whose case is listed in `cases`
        case_mask = [sample.case in cases for sample in dataset.samples]
        selected = np.where(case_mask)[0].tolist()
        return {
            "samples": selected,
            "start": np.zeros(np.sum(case_mask), dtype=int),
            "type": concept + suffix,
            "printable": [index in printable for index in selected],
        }

    set_surprise = build_set([1], '_surprise')
    set_control = build_set([0, 3], '_control')

    return [set_control, set_surprise], evaluation_modes, errors_to_plot

def setup_result_folders(file, set_test, evaluation_mode, trace_view, object_view, individual_views, nice_view, plot_error, latent_trace, mota, errors_to_plot):
    """Create the folders for the results of one evaluation set and mode.

    The folders are located at ``<prefix>/plots/<net name>/<set type>/`` where
    ``<prefix>`` is the part of ``file`` in front of the first ``'nets'`` and
    ``<net name>`` is the file name up to its first dot. Existing folders are
    kept.

    Args:
        file: Path of the network checkpoint.
        set_test: Evaluation set; only ``set_test['type']`` is read.
        evaluation_mode: Name of the sub folder for the plots.
        trace_view, object_view, individual_views, nice_view, latent_trace:
            Create the sub folder(s) of the respective output if True.
        plot_error: Create the error folders, including one per entry of
            ``errors_to_plot``.
        mota: Not used by this function.
        errors_to_plot: Names of the error sub folders.

    Returns:
        Tuple ``(root_path, plot_path)``; ``plot_path`` is
        ``<root_path>/<evaluation_mode>`` with a trailing ``'/'``.
    """
    net_name = file.split('/')[-1].split('.')[0]
    root_path = os.path.join(file.split('nets')[0], 'plots', net_name, set_test['type'])
    plot_path = os.path.join(root_path, evaluation_mode)

    # collect all required folders first
    #if os.path.exists(plot_path):
    #    shutil.rmtree(plot_path)
    folders = [plot_path]
    if trace_view:
        folders.append(os.path.join(plot_path, 'trace'))
    if object_view:
        folders.append(os.path.join(plot_path, 'object'))
    if individual_views:
        individual_path = os.path.join(plot_path, 'individual')
        folders.append(individual_path)
        folders.extend(
            os.path.join(individual_path, group)
            for group in ['error', 'input', 'background', 'prediction', 'position', 'rawmask', 'mask', 'othermask', 'imagination']
        )
    if nice_view:
        folders.append(os.path.join(plot_path, 'nice'))
    if plot_error:
        folders.append(os.path.join(root_path, 'errors'))
        folders.append(os.path.join(plot_path, 'errors'))
        folders.extend(os.path.join(plot_path, 'errors', error_name) for error_name in errors_to_plot)
    if latent_trace:
        folders.append(os.path.join(plot_path, 'latent'))
    folders.append(os.path.join(root_path, 'statistics'))

    # create directories
    for folder in folders:
        os.makedirs(folder, exist_ok = True)

    # final directory: callers append file names directly to plot_path
    plot_path = plot_path + '/'
    print(f"save plots to {plot_path}")

    return root_path, plot_path
            
def load_model(cfg, cfg_net, dataset, file, device):
    """Create a Loci network, optionally load its weights and move it to ``device``.

    Args:
        cfg: Run configuration; ``datatype`` and ``teacher_forcing`` are read.
        cfg_net: Network configuration passed to ``Loci``.
        dataset: Dataset; ``dataset.cam`` and ``dataset.z`` are only read if
            ``cfg.datatype`` is ``'cater'``.
        file: Path of the checkpoint; an empty string skips loading.
        device: Device the checkpoint is mapped to and the network is moved to.

    Returns:
        The network on ``device``.
    """
    uses_cater = cfg.datatype == 'cater'
    net = Loci(
        cfg_net,
        camera_view_matrix = dataset.cam if uses_cater else None,
        zero_elevation     = dataset.z if uses_cater else None,
        teacher_forcing    = cfg.teacher_forcing
    )

    # load model
    if file != '':
        print(f"load {file} to device {device}")
        # NOTE(review): depending on the torch version th.load unpickles the file,
        # so only trusted checkpoints should be loaded.
        checkpoint = th.load(file, map_location=device)

        # backward compatibility: drop the ".module." part of the parameter names
        weights = {
            name.replace(".module.", "."): weight
            for name, weight in checkpoint["model"].items()
        }
        net.load_state_dict(weights)

    # make sure the init level of the network is at least raised once
    # NOTE(review): the purpose of the init level is not visible here - confirm
    if net.get_init_status() < 1:
        net.inc_init_level()

    # move the network to the target device
    return net.to(device=device)

def reshape_slots(slots_bounded, slots_partially_occluded_cur, slots_occluded_cur, slots_partially_occluded_next, slots_occluded_next):
    """Squeeze the slot indicators and append three singleton dimensions.

    All dimensions of size one are removed from every input and three trailing
    dimensions of size one are added, e.g. shape ``(1, n)`` becomes
    ``(n, 1, 1, 1)``.

    Returns:
        Tuple with the five reshaped tensors in the order of the arguments.
    """
    def expand(slots):
        return th.squeeze(slots)[..., None, None, None]

    return tuple(
        expand(slots)
        for slots in (
            slots_bounded,
            slots_partially_occluded_cur,
            slots_occluded_cur,
            slots_partially_occluded_next,
            slots_occluded_next,
        )
    )

def store_latent(gestalt_memory, position_memory, num_objects, plot_path, i):
    """Write the stored gestalt and position codes to one pickle file per slot.

    Both memories are stacked into arrays indexed as ``[frame, batch, (slot code)]``
    and split along the slot dimension. The file of slot ``n`` is
    ``<plot_path>latent/latent-<i>-<n>.pickle`` and contains a dict with the
    keys ``"gestalt"`` and ``"position"``.

    Args:
        gestalt_memory: Sequence with one array per frame.
        position_memory: Sequence with one array per frame.
        num_objects: Number of slots.
        plot_path: Result folder including a trailing separator.
        i: Index of the sample, used in the file name.
    """
    split_slots = 'l b (n c) -> l b n c'
    gestalt_per_slot = rearrange(np.array(gestalt_memory), split_slots, n = num_objects)
    position_per_slot = rearrange(np.array(position_memory), split_slots, n = num_objects)

    for slot in range(num_objects):
        target_file = f'{plot_path}latent/latent-{i:04d}-{slot:02d}.pickle'
        with open(target_file, "wb") as outfile:
            pickle.dump(
                {
                    "gestalt":  gestalt_per_slot[:, :, slot],
                    "position": position_per_slot[:, :, slot],
                },
                outfile,
            )
