import pickle
import torch as th
from torch.utils.data import Dataset, DataLoader, Subset
from torch import nn

import os
from evaluation.plot_utils import color_slots, compute_occlusion_mask, get_highlighted_input, plot_error_batch, plot_object_view, plot_online_error, plot_online_error_slots, preprocess_multi, to_rgb, write_image
from evaluation.vp_utils import masks_to_boxes, pred_eval_step, postproc_mask
from utils.configuration import Configuration
from model.loci import Loci
from utils.utils import LambdaModule, Gaus2D, Prioritize, Shape2D, Vector2D
import numpy as np
import cv2
from einops import rearrange, repeat, reduce
import motmetrics as mm
from copy import deepcopy
import pandas as pd
import lpips
from skimage.metrics import structural_similarity as ssimloss
from skimage.metrics import peak_signal_noise_ratio as psnrloss

import torchvision.transforms as transforms

def save(cfg: Configuration, dataset: Dataset, file, active_layer, size, object_view = False, nice_view = False, individual_views = False, trace_view = False, concept = '', mota = False, latent_trace = False, net = None, plot_frequency= 1, plot_first_samples = 2, num_samples=1000):
    """Run the evaluation loop over ``dataset`` and write plots and/or metric files.

    For every evaluation set and evaluation mode returned by
    ``get_evaluation_sets`` each sequence is fed frame by frame through the
    network. With ``plot_first_samples > 0`` only that many leading samples are
    processed (batch size 1) and the requested plots are written. With
    ``plot_first_samples == 0`` the whole dataset is processed in batches of 32
    and the per-sample metrics returned by ``pred_eval_step`` are collected,
    averaged, printed and pickled into ``<root_path>/statistics``.

    Side effects on the arguments: ``cfg.skip_frames`` and
    ``cfg.model.batch_size`` are overwritten, ``dataset.burn_in_length``,
    ``dataset.rollout_length`` and ``dataset.skip_length`` are set, and the
    device entries of ``cfg`` are forced to ``"cpu"`` when CUDA is unavailable.

    Args:
        cfg: Run configuration; ``cfg.model``, ``cfg.device`` and
            ``cfg.teacher_forcing`` are read here.
        dataset: Dataset whose items provide, in this order, the frame
            sequence, a fixed background, ground-truth masks, ground-truth
            boxes and a ground-truth presence mask.
        file: Checkpoint path; used to load the model (when ``net`` is None)
            and to derive the output folders.
        active_layer: Used together with ``cfg.model.latent_size`` to derive
            the plot scale factor.
        size: Image size passed to the position helpers; ``size[0]`` is used
            for the plot scale factor.
        object_view: Write the combined object-view plot per plotted frame.
        nice_view: Only forwarded to ``setup_result_folders``.
        individual_views: Write target and prediction images per plotted frame.
        trace_view: Only forwarded to ``setup_result_folders``.
        concept: Only forwarded to ``get_evaluation_sets``.
        mota: Only forwarded to ``setup_result_folders``.
        latent_trace: Only forwarded to ``setup_result_folders``.
        net: Optional ready-made network; loaded via ``load_model`` if None.
        plot_frequency: Plot every ``plot_frequency``-th time step.
        plot_first_samples: Number of leading samples to plot; ``0`` switches
            to metric computation over the full dataset.
        num_samples: Unused in this function.
    """

    # set model configurations
    cfg_net = cfg.model
    if th.cuda.is_available():
        device = th.device(cfg.device)
    else:
        device = th.device("cpu")
        cfg.device = "cpu"
        cfg.model.device = "cpu"
        print('!!! USING CPU !!!')

    # plotting runs sample by sample, metric computation in batches of 32
    batch_size = 1 if plot_first_samples > 0 else 32
    cfg_net.batch_size = batch_size

    # create model
    if net is None:
        net = load_model(cfg, cfg_net, dataset, file, device)
    net.eval()

    # create position helpers
    gaus2d  = Gaus2D(size).to(device)
    vector2d = Vector2D(size).to(device)
    scale  = size[0] // (cfg_net.latent_size[0] * 2**active_layer)

    # NOTE(review): not used below; kept because constructing the module may
    # consume random numbers that the blackout draw further down depends on.
    prioritize = Prioritize(cfg_net.num_objects).to(device)

    # config
    plot_error = True
    root_path = None

    # get evaluation sets
    statistics_template = {'set': [], 'evalmode': [], 'scene': [], 'frame': [], 'image_error_mse': []}
    statistics_complete_slots = {'set': [], 'evalmode': [], 'scene': [], 'frame': [], 'slot':[], 'bound': [], 'slot_error': [], 'rawmask_size': [], 'alpha_pos': [], 'alpha_ges': []}
    set_test_array, evaluation_modes, errors_to_plot = get_evaluation_sets(dataset, concept)
    metric_complete = None

    # Evaluation Specifics
    burn_in_length = 6
    rollout_length = 42
    cfg.skip_frames = 2
    blackout_p = 0.2
    target_size = (64,64)
    dataset.burn_in_length = burn_in_length
    dataset.rollout_length = rollout_length
    dataset.skip_length = cfg.skip_frames

    # transformation utils
    to_small        = transforms.Resize(target_size)
    to_normalize    = transforms.Normalize((0.5, ), (0.5, ))
    to_smallnorm    = transforms.Compose([to_small, to_normalize])

    # Losses
    lpipsloss = lpips.LPIPS(net='vgg').to(device)
    mseloss = nn.MSELoss()

    for set_test in set_test_array:

        for evaluation_mode in evaluation_modes:
            print(f'Start evaluation loop: {evaluation_mode}')

            # load data (only the first samples when plotting)
            dataloader = DataLoader(
                dataset if plot_first_samples == 0 else Subset(dataset, range(plot_first_samples)),
                num_workers = 1,
                pin_memory = False,
                batch_size = batch_size,
                shuffle = False,
                drop_last = True,
            )

            # memory
            root_path, plot_path = setup_result_folders(file, set_test, evaluation_mode, trace_view, object_view, individual_views, nice_view, plot_error, latent_trace, mota, errors_to_plot)
            metric_complete = {'mse': [], 'ssim': [], 'psnr': [], 'percept_dist': [], 'ari': [], 'fari': [], 'miou': [], 'ap': [], 'ar': [], 'blackout': []}

            with th.no_grad():
                for i, batch in enumerate(dataloader):
                    print(f'Processing sample {i+1}/{len(dataloader)}', flush=True)

                    # get input frame and target frame
                    tensor = batch[0].float().to(device)
                    background_fix = batch[1].to(device)
                    gt_mask        = batch[2].to(device)
                    gt_bbox        = batch[3].to(device)
                    gt_pres_mask   = batch[4].to(device)

                    sequence_len = tensor.shape[1]

                    # initial frame
                    input  = tensor[:,0]
                    target = th.clip(tensor[:,0], 0, 1)
                    error_last  = None

                    # placeholders
                    mask_cur       = None
                    mask_last      = None
                    maskraw_last   = None
                    position_last  = None
                    gestalt_last   = None
                    priority_last  = None
                    slots_occlusionfactor = None

                    # buffers for the metric evaluation (pred_eval_step)
                    pred        = th.zeros((batch_size, rollout_length, 3, target_size[0], target_size[1])).to(device)
                    gt          = th.zeros((batch_size, rollout_length, 3, target_size[0], target_size[1])).to(device)
                    pred_mask   = th.zeros((batch_size, rollout_length, target_size[0], target_size[1])).to(device)
                    statistics_batch = deepcopy(statistics_template)
                    num_rollout = 0
                    num_burnin  = 0
                    blackout_mem = [0]

                    # loop through frames; negative t are teacher-forcing steps on frame 0
                    for t_index,t in enumerate(range(-cfg.teacher_forcing, sequence_len-1)):

                        # move to next frame
                        t_run = max(t, 0)
                        input  = tensor[:,t_run]
                        target = th.clip(tensor[:,t_run+1], 0, 1)

                        rollout_index = t_run - burn_in_length
                        if (rollout_index >= 0) and (evaluation_mode == 'blackout'):
                            # one draw per step: with probability blackout_p the
                            # input frame and the last error are zeroed
                            blackout    = (th.rand(1) < blackout_p).float().to(device)
                            input       = blackout * (input * 0)         + (1-blackout) * input
                            error_last  = blackout * (error_last * 0)    + (1-blackout) * error_last
                            blackout_mem.append(blackout.int().cpu().item())

                        elif t>=0:
                            num_burnin += 1

                        if (rollout_index >= 0) and (evaluation_mode != 'blackout'):
                            blackout_mem.append(0)

                        # obtain prediction
                        (
                            output_next,
                            position_next,
                            gestalt_next,
                            priority_next,
                            mask_next,
                            maskraw_next,
                            object_next,
                            background,
                            slots_occlusionfactor,
                            output_cur,
                            position_cur,
                            gestalt_cur,
                            priority_cur,
                            mask_cur,
                            maskraw_cur,
                            object_cur,
                            position_encoder_cur,
                            slots_bounded,
                            slots_partially_occluded_cur,
                            slots_occluded_cur,
                            slots_partially_occluded_next,
                            slots_occluded_next,
                            slots_closed,
                            output_hidden,
                            largest_object,
                            maskraw_hidden,
                            object_hidden
                        ) = net(
                            input,
                            error_last,
                            mask_last,
                            maskraw_last,
                            position_last,
                            gestalt_last,
                            priority_last,
                            background_fix,
                            slots_occlusionfactor,
                            reset = (t == -cfg.teacher_forcing),
                            evaluate=True,
                            warmup = (t < 0),
                            shuffleslots = False,
                            reset_mask = (t <= 0),
                            allow_spawn = True,
                            show_hidden = (batch_size == 1),
                            clean_slots = (t <= 0),
                        )

                        # 1. Track error for plots
                        if t >= 0:

                            # save prediction
                            if rollout_index >= -1:
                                pred[:,rollout_index+1] = to_smallnorm(output_next)
                                gt[:,rollout_index+1]   = to_smallnorm(target)
                                pred_mask[:,rollout_index+1] = postproc_mask(to_small(mask_next)[:,None,:,None])[:, 0]
                                num_rollout += 1

                            # The collected statistics are only read by the
                            # object-view plots below, so gate on the
                            # `object_view` argument (the original tested the
                            # imported function `plot_object_view`, which is
                            # always truthy).
                            if batch_size == 1 and object_view:
                                statistics_batch = store_statistics(statistics_batch,
                                                                set_test['type'],
                                                                evaluation_mode,
                                                                set_test['samples'][i],
                                                                t,
                                                                mseloss(output_next, target).item()
                                                                )

                                # compute slot-wise prediction error
                                output_slot = repeat(mask_next[:,:-1], 'b o h w -> b o 3 h w') * repeat(output_next, 'b c h w -> b o c h w', o=cfg_net.num_objects)
                                target_slot = repeat(mask_next[:,:-1], 'b o h w -> b o 3 h w') * repeat(target, 'b c h w -> b o c h w', o=cfg_net.num_objects)
                                slot_error = reduce((output_slot - target_slot)**2, 'b o c h w -> b o', 'mean')

                                # compute rawmask_size
                                rawmask_size = reduce(maskraw_hidden[:, :-1], 'b o h w-> b o', 'sum')

                                statistics_complete_slots = store_statistics(statistics_complete_slots,
                                                                            [set_test['type']] * cfg_net.num_objects,
                                                                            [evaluation_mode] * cfg_net.num_objects,
                                                                            [set_test['samples'][i]] * cfg_net.num_objects,
                                                                            [t] * cfg_net.num_objects,
                                                                            range(cfg_net.num_objects),
                                                                            slots_bounded.cpu().numpy().flatten().astype(int),
                                                                            slot_error.cpu().numpy().flatten(),
                                                                            rawmask_size.cpu().numpy().flatten(),
                                                                            slots_closed[:, :, 1].cpu().numpy().flatten(),
                                                                            slots_closed[:, :, 0].cpu().numpy().flatten(),
                                                                            extend = True)

                        # 2. Remember output (fed back into the network at the next step)
                        mask_last     = mask_next.clone()
                        maskraw_last  = maskraw_next.clone()
                        position_last = position_next.clone()
                        gestalt_last  = gestalt_next.clone()
                        priority_last = priority_next.clone()

                        # 3. Error for next frame
                        # background error
                        bg_error_next = th.sqrt(reduce((target - background)**2, 'b c h w -> b 1 h w', 'mean')).detach()

                        # prediction error, weighted by the background error
                        error_next    = th.sqrt(reduce((target - output_next)**2, 'b c h w -> b 1 h w', 'mean')).detach()
                        error_next    = th.sqrt(error_next) * bg_error_next
                        error_last    = error_next.clone()

                        # 4. plot preparation
                        if batch_size == 1 and t % plot_frequency == 0 and i < plot_first_samples and t >= -100:

                            # compute plot content
                            highlighted_input = get_highlighted_input(input, mask_cur)
                            output = th.clip(output_next, 0, 1)
                            position_cur2d = gaus2d(rearrange(position_encoder_cur, 'b (o c) -> (b o) c', o=cfg_net.num_objects))
                            velocity_next2d = vector2d(rearrange(position_next, 'b (o c) -> (b o) c', o=cfg_net.num_objects))

                            # color slots
                            slots_bounded, slots_partially_occluded_cur, slots_occluded_cur, slots_partially_occluded_next, slots_occluded_next = reshape_slots(slots_bounded, slots_partially_occluded_cur, slots_occluded_cur, slots_partially_occluded_next, slots_occluded_next)
                            position_cur2d = color_slots(position_cur2d, slots_bounded, slots_partially_occluded_cur, slots_occluded_cur)
                            velocity_next2d = color_slots(velocity_next2d, slots_bounded, slots_partially_occluded_next, slots_occluded_next)

                            # compute occlusion
                            maskraw_cur, maskraw_next = compute_occlusion_mask(maskraw_cur, maskraw_next, mask_cur, mask_next, scale)

                            # scale plot content
                            input, target, output, highlighted_input, object_next, object_cur, mask_next, background, error_next, output_hidden, bg_error_next, output_next = preprocess_multi(input, target, output, highlighted_input, object_next, object_cur, mask_next, background, error_next, output_hidden, bg_error_next, output_next, scale=scale)

                            # convert to rgb
                            bg_error_next    = to_rgb(bg_error_next)

                            # reshape
                            object_next     = rearrange(object_next, 'b (o c) h w -> b o c h w', c = cfg_net.img_channels)
                            object_cur      = rearrange(object_cur, 'b (o c) h w -> b o c h w', c = cfg_net.img_channels)
                            mask_next       = rearrange(mask_next, 'b (o 1) h w -> b o 1 h w')

                            if object_view:
                                num_objects = cfg_net.num_objects
                                error_plot_slots2 = plot_online_error_slots(statistics_complete_slots['slot_error'][-cfg_net.num_objects*(t+1):], 'Image error', target, sequence_len, root_path, statistics_complete_slots['slot_error'][-cfg_net.num_objects*(t+1):], slots_bounded, ylim=0.0001)
                                error_plot = plot_online_error(statistics_batch['image_error_mse'], 'Prediction error', target, t, i, sequence_len, root_path)
                                img = plot_object_view(error_plot, None, None, error_plot_slots2, highlighted_input, output_next, error_next, object_cur, object_next, maskraw_cur, maskraw_next, position_cur2d, velocity_next2d, output, target, slots_closed, None, None, size, num_objects, largest_object)
                                cv2.imwrite(f'{plot_path}object/gpnet-objects-{i:04d}-{t_index:03d}.jpg', img)

                            if individual_views:
                                # file index is shifted by the number of teacher-forcing steps
                                frame_index = t_index - cfg.teacher_forcing
                                write_image(f'{plot_path}/individual/input/input-{i:04d}-{frame_index:03d}.jpg', target[0])
                                write_image(f'{plot_path}/individual/prediction/prediction-{i:04d}-{frame_index:03d}.jpg', output_next[0])

                    # compute metrics per sample of the batch (full-dataset mode only)
                    if plot_first_samples == 0:
                        for b in range(batch_size):
                            metric_dict = pred_eval_step(
                                gt=gt[b:b+1],
                                pred=pred[b:b+1],
                                pred_mask=pred_mask.long()[b:b+1],
                                pred_bbox=masks_to_boxes(pred_mask.long()[b:b+1], cfg_net.num_objects+1),
                                gt_mask=gt_mask.long()[b:b+1],
                                gt_pres_mask=gt_pres_mask[b:b+1],
                                gt_bbox=gt_bbox[b:b+1],
                                lpips_fn=lpipsloss,
                                eval_traj=True,
                            )
                            metric_dict['blackout'] = blackout_mem
                            metric_complete = append_statistics(metric_dict, metric_complete)

                    # sanity check
                    # NOTE(review): get_evaluation_sets in this file only yields
                    # 'blackout' and 'open', so this check appears inactive here;
                    # confirm the intended mode name and and/or logic.
                    if (num_rollout != rollout_length) and (num_burnin != burn_in_length) and (evaluation_mode == 'rollout'):
                        raise ValueError('Number of rollout steps and burnin steps must be equal to the sequence length.')

            if plot_first_samples == 0:
                average_dic = {}
                for key in metric_complete:

                    # take average over all frames
                    average_dic[key + 'complete_average'] = np.mean(metric_complete[key])
                    average_dic[key + 'complete_std']     = np.std(metric_complete[key])
                    print(f'{key} complete average: {average_dic[key + "complete_average"]:.4f} +/- {average_dic[key + "complete_std"]:.4f}')

                    if evaluation_mode == 'blackout':
                        # split the statistics into blacked-out and visible frames
                        blackout_mask = np.array(metric_complete['blackout']) > 0
                        average_dic[key + 'blackout_average'] = np.mean(np.array(metric_complete[key])[blackout_mask])
                        average_dic[key + 'blackout_std']     = np.std(np.array(metric_complete[key])[blackout_mask])
                        average_dic[key + 'visible_average']  = np.mean(np.array(metric_complete[key])[blackout_mask == False])
                        average_dic[key + 'visible_std']      = np.std(np.array(metric_complete[key])[blackout_mask == False])

                        print(f'{key} blackout average: {average_dic[key + "blackout_average"]:.4f} +/- {average_dic[key + "blackout_std"]:.4f}')
                        print(f'{key} visible average: {average_dic[key + "visible_average"]:.4f} +/- {average_dic[key + "visible_std"]:.4f}')

                with open(os.path.join(f'{root_path}/statistics', f'{evaluation_mode}_metric_complete.pkl'), 'wb') as f:
                    pickle.dump(metric_complete, f)
                with open(os.path.join(f'{root_path}/statistics', f'{evaluation_mode}_metric_average.pkl'), 'wb') as f:
                    pickle.dump(average_dic, f)

    print('-- Evaluation Done --')
    # remove the temporary image if one was left in the root folder
    if object_view and os.path.exists(f'{root_path}/tmp.jpg'):
        os.remove(f'{root_path}/tmp.jpg')



def append_statistics(memory1, memory2, ignore=()):
    """Append every value of ``memory1`` to the list stored under the same key in ``memory2``.

    Args:
        memory1: Mapping of key -> value for one step.
        memory2: Mapping of key -> list that accumulates the values; it is
            modified in place.
        ignore: Keys of ``memory1`` to skip. The default is an immutable empty
            tuple so that no mutable object is shared between calls.

    Returns:
        ``memory2`` (the same object that was passed in).

    Raises:
        KeyError: If a non-ignored key of ``memory1`` is missing in ``memory2``.
    """
    for key, value in memory1.items():
        if key not in ignore:
            memory2[key].append(value)
    return memory2

def store_statistics(memory, *args, extend=False):
    """Distribute positional values over the lists in ``memory``, in key order.

    The first positional value goes to the first key of ``memory``, the second
    to the second key, and so on; keys without a matching value are left
    untouched. With ``extend=True`` each value is treated as an iterable and
    its items are added individually, otherwise the value is appended as one
    entry.

    Returns:
        ``memory`` (modified in place).
    """
    for key, value in zip(memory.keys(), args):
        bucket = memory[key]
        if extend:
            bucket.extend(value)
        else:
            bucket.append(value)
    return memory

def get_evaluation_sets(dataset, concept):
    """Describe the evaluation set, the evaluation modes and the errors to plot.

    A single set of type ``"test"`` covering all ``len(dataset)`` samples is
    built. ``concept`` is accepted but not used.

    Returns:
        A tuple ``(set_test_array, evaluation_modes, errors_to_plot)``.
    """
    num_items = len(dataset)

    test_set = {
        "samples": np.arange(num_items, dtype=int),
        "start": np.zeros(num_items, dtype=int),
        "critical": np.zeros(num_items, dtype=int),
        "type": "test",
    }

    return [test_set], ['blackout', 'open'], ['image_error_mse']

def setup_result_folders(file, set_test, evaluation_mode, trace_view, object_view, individual_views, nice_view, plot_error, latent_trace, mota, errors_to_plot):
    """Create the output folder tree for one evaluation set and mode.

    The root folder is ``<prefix>/plots/<net_name>/<set type>`` where
    ``<prefix>`` is the part of ``file`` before the first occurrence of
    ``'nets'`` and ``<net_name>`` is the file name up to its first dot. The
    plot folder is ``<root>/<evaluation_mode>``. Sub-folders are created only
    for the enabled views; ``<root>/statistics`` is always created. ``mota`` is
    accepted but not used.

    Returns:
        ``(root_path, plot_path)`` where ``plot_path`` ends with ``'/'``.
    """

    def _ensure_dir(*parts):
        # create the joined directory (and parents) if it does not exist yet
        os.makedirs(os.path.join(*parts), exist_ok = True)

    # derive folder names from the checkpoint path
    checkpoint_name = file.split('/')[-1]
    net_name = checkpoint_name.split('.')[0]
    prefix = file.split('nets')[0]
    root_path = os.path.join(prefix, 'plots', net_name, set_test['type'])
    plot_path = os.path.join(root_path, evaluation_mode)

    # create directories
    _ensure_dir(plot_path)
    if trace_view:
        _ensure_dir(plot_path, 'trace')
    if object_view:
        _ensure_dir(plot_path, 'object')
    if individual_views:
        _ensure_dir(plot_path, 'individual')
        individual_groups = ('error', 'input', 'background', 'prediction', 'position', 'rawmask', 'mask', 'othermask', 'imagination')
        for group in individual_groups:
            _ensure_dir(plot_path, 'individual', group)
    if nice_view:
        _ensure_dir(plot_path, 'nice')
    if plot_error:
        _ensure_dir(root_path, 'errors')
        _ensure_dir(plot_path, 'errors')
        for error_name in errors_to_plot:
            _ensure_dir(plot_path, 'errors', error_name)
    if latent_trace:
        _ensure_dir(plot_path, 'latent')
    _ensure_dir(root_path, 'statistics')

    # final directory, returned with a trailing slash
    plot_path += '/'
    print(f"save plots to {plot_path}")

    return root_path, plot_path
            
def load_model(cfg, cfg_net, dataset, file, device):
    """Build a ``Loci`` network, optionally load its weights and move it to ``device``.

    Args:
        cfg: Run configuration; ``cfg.datatype`` and ``cfg.teacher_forcing``
            are read.
        cfg_net: Model configuration passed to ``Loci``.
        dataset: Provides ``cam`` and ``z`` when ``cfg.datatype == 'cater'``.
        file: Checkpoint path; an empty string skips weight loading.
        device: Target device, also used as ``map_location`` for loading.

    Returns:
        The network on ``device``.
    """
    is_cater = cfg.datatype == 'cater'

    net = Loci(
        cfg_net,
        camera_view_matrix = dataset.cam if is_cater else None,
        zero_elevation     = dataset.z if is_cater else None,
        teacher_forcing    = cfg.teacher_forcing
    )

    # load model
    if file != '':
        print(f"load {file} to device {device}")
        checkpoint = th.load(file, map_location=device)

        # backward compatibility: drop ".module." from the parameter names
        weights = {
            name.replace(".module.", "."): tensor
            for name, tensor in checkpoint["model"].items()
        }
        net.load_state_dict(weights)

    # raise the init level once if the network reports a status below 1
    # NOTE(review): meaning of the init status is not visible here — confirm
    if net.get_init_status() < 1:
        net.inc_init_level()

    # move network to the target device
    return net.to(device=device)

def reshape_slots(slots_bounded, slots_partially_occluded_cur, slots_occluded_cur, slots_partially_occluded_next, slots_occluded_next):
    """Squeeze each slot tensor and append three trailing singleton axes.

    All five inputs are processed the same way: every size-1 dimension is
    removed with ``th.squeeze`` and three new axes of size 1 are added at the
    end.

    Returns:
        A tuple with the five reshaped tensors, in the order of the arguments.
    """

    def _expand(slot_values):
        # drop size-1 dims, then add three trailing axes of size 1
        return th.squeeze(slot_values)[..., None, None, None]

    return tuple(
        _expand(slot_values)
        for slot_values in (
            slots_bounded,
            slots_partially_occluded_cur,
            slots_occluded_cur,
            slots_partially_occluded_next,
            slots_occluded_next,
        )
    )

def store_latent(gestalt_memory, position_memory, num_objects, plot_path, i):
    """Pickle the recorded gestalt and position codes, one file per object.

    Both memories are converted to arrays and their last axis is split into
    ``num_objects`` groups. For every object index a pickle with the keys
    ``"gestalt"`` and ``"position"`` is written to
    ``<plot_path>latent/latent-<i>-<object>.pickle``.
    """
    split_pattern = 'l b (n c) -> l b n c'
    gestalt_per_object = rearrange(np.array(gestalt_memory), split_pattern, n = num_objects)
    position_per_object = rearrange(np.array(position_memory), split_pattern, n = num_objects)

    for slot in range(num_objects):
        target_file = f'{plot_path}latent/latent-{i:04d}-{slot:02d}.pickle'
        with open(target_file, "wb") as outfile:
            # slice out everything belonging to this object
            pickle.dump(
                {
                    "gestalt":  gestalt_per_object[:, :, slot],
                    "position": position_per_object[:, :, slot],
                },
                outfile,
            )
