import torch as th
from torch import nn
import numpy as np
import cv2
from einops import rearrange, repeat
import matplotlib.pyplot as plt
import PIL

def preprocess(tensor, scale=1, normalize=False, mean_std_normalize=False):
    """Optionally normalise and nearest-neighbour upsample a tensor for display.

    Args:
        tensor: Input tensor. When ``scale > 1`` it is fed to ``nn.Upsample``,
            which expects a batched input such as ``(b, c, h, w)``.
        scale: Upsampling factor; no resizing is done when ``scale <= 1``.
        normalize: If True, min-max rescale the whole tensor to ``[0, 1]``.
            A constant tensor maps to all zeros (instead of NaN).
        mean_std_normalize: If True, centre on the global mean, divide by two
            standard deviations, clip to ``[-1, 1]`` and map to ``[0, 1]``.
            A tensor whose std is zero or undefined maps to 0.5.

    Returns:
        The processed tensor; the input itself when no option applies.
    """
    if normalize:
        min_ = th.min(tensor)
        max_ = th.max(tensor)
        value_range = max_ - min_
        # Avoid 0/0 = NaN for constant inputs (e.g. a blank plot image).
        if value_range == 0:
            value_range = th.ones_like(value_range)
        tensor = (tensor - min_) / value_range

    if mean_std_normalize:
        mean = th.mean(tensor)
        std = th.std(tensor)
        # std is 0 for a constant tensor and NaN for a single element.
        if not th.isfinite(std) or std == 0:
            std = th.ones_like(std)
        tensor = th.clip((tensor - mean) / (2 * std), -1, 1) * 0.5 + 0.5

    if scale > 1:
        # nn.Upsample has no parameters, so it needs no device placement.
        tensor = nn.Upsample(scale_factor=scale)(tensor)

    return tensor

def preprocess_multi(*args, scale):
    """Upsample every positional tensor with :func:`preprocess` and a shared ``scale``.

    No normalisation is applied. The result is a list in argument order.
    """
    outputs = [preprocess(item, scale=scale) for item in args]
    return outputs

def color_mask(mask):
    """Blend per-slot masks into a single colour image using the slot palette.

    Channel ``o`` of ``mask`` is multiplied by ``get_color(o)`` (RGB scaled to
    ``[0, 1]``), and the products are summed over the slot dimension.

    Args:
        mask: Per-slot masks with the slot dimension at index 1, expected as
            ``(b, o, h, w)``.

    Returns:
        Tensor of shape ``(b, 3, h, w)`` on ``mask``'s device.

    Raises:
        ValueError: If ``mask`` has more slot channels than palette colours.
    """
    # Reuse the shared palette instead of keeping a second copy of it here.
    palette = get_color(slice(None))
    num_slots = mask.shape[1]
    if num_slots > palette.shape[0]:
        raise ValueError(
            f'color_mask supports at most {palette.shape[0]} slots, got {num_slots}'
        )

    colors = palette[:num_slots].to(mask.device).view(1, num_slots, 3, 1, 1)
    # (1, o, 3, 1, 1) * (b, o, 1, h, w) -> (b, o, 3, h, w), summed over o.
    return th.sum(colors * mask.unsqueeze(dim=2), dim=1)


def priority_to_img(priority, h, w):
    """Render each slot priority as white text on a black ``h x w`` image.

    Args:
        priority: Tensor indexed as ``priority[0, 0, p]``; one image is made
            for every ``p`` in ``range(priority.shape[2])``.
        h: Image height in pixels.
        w: Image width in pixels.

    Returns:
        List of uint8 tensors of shape ``(1, 1, 3, h, w)`` on ``priority``'s
        device, one per priority entry.
    """
    font = cv2.FONT_HERSHEY_SIMPLEX
    # cv2.putText takes the text origin as (x, y) = (column, row): the
    # horizontal offset comes from the width and the vertical one from the
    # height. (Swapping them pushes the text off non-square images.)
    text_position = (w // 6, h // 2)
    font_scale = w / 256
    font_color = (255, 255, 255)
    thickness = 2
    line_type = 2

    imgs = []
    for p in range(priority.shape[2]):
        img = np.zeros((h, w, 3), np.uint8)
        cv2.putText(img, f'{priority[0,0,p].item():.2e}',
                    text_position,
                    font,
                    font_scale,
                    font_color,
                    thickness,
                    line_type)
        imgs.append(rearrange(th.tensor(img, device=priority.device), 'h w c -> 1 1 c h w'))

    return imgs

def get_color(o):
    """Look up slot colour(s) in the fixed 48-entry palette.

    Args:
        o: Any valid first-dimension index into a ``(48, 3)`` tensor, e.g. an
            int, a slice or an index tensor.

    Returns:
        ``palette[o]``. The palette is a freshly built float tensor (on the
        default device) of RGB triples rescaled from ``0..255`` to ``[0, 1]``.
    """
    distinct_rows = (
        (255, 0, 0), (0, 0, 255), (255, 255, 0), (255, 0, 255),
        (0, 255, 255), (0, 255, 0), (255, 128, 0), (128, 255, 0),
        (128, 0, 255), (255, 0, 128), (0, 255, 128), (0, 128, 255),
        (255, 128, 128), (128, 255, 128), (128, 128, 255), (255, 128, 128),
        (128, 255, 128), (128, 128, 255), (255, 128, 255), (128, 255, 255),
        (128, 255, 255), (255, 255, 128), (255, 255, 128), (255, 128, 255),
        (128, 0, 0), (0, 0, 128), (128, 128, 0), (128, 0, 128),
        (0, 128, 128), (0, 128, 0), (128, 128, 0), (128, 128, 0),
        (128, 0, 128), (128, 0, 128), (0, 128, 128), (0, 128, 128),
    )
    # The last 12 entries are all mid-grey.
    rgb_rows = distinct_rows + ((128, 128, 128),) * 12

    palette = th.tensor(rgb_rows) / 255.0
    return palette.reshape(48, 3)[o]

def to_rgb(tensor: th.Tensor):
    """Expand a tensor to three channels along ``dim=1``, brightening the first.

    Output channel 0 is ``0.6 * tensor + 0.4``; channels 1 and 2 are copies
    of ``tensor``.
    """
    brightened = tensor * 0.6 + 0.4
    channels = [brightened, tensor, tensor]
    return th.cat(channels, dim=1)

def visualise_gate(gate, h, w):
    """Draw a gate value as a horizontal bar image.

    Args:
        gate: Single-element tensor (read with ``.item()``).
        h: Bar height in pixels.
        w: Bar width in pixels.

    Returns:
        A ``(1, h, w)`` tensor on ``gate``'s device filled with 0.9, whose
        right-most ``int(w * gate)`` columns are 0. Non-positive counts leave
        the bar untouched.
    """
    bar = th.full((1, h, w), 0.9, device=gate.device)
    n_black = int(w * gate.item())
    if n_black <= 0:
        return bar
    bar[..., -n_black:] = 0
    return bar


def get_highlighted_input(input, mask_cur):
    """Overlay coloured slot masks on a greyscale version of ``input``.

    Args:
        input: Image batch with at least three channels at dim 1. Channels
            0-2 are combined with weights 0.299 / 0.587 / 0.114 into a
            single grey channel.
        mask_cur: Per-slot masks with slots at dim 1, or ``None``. The last
            slot channel is excluded (presumably the background slot; verify
            against the caller).

    Returns:
        ``input`` itself when ``mask_cur`` is None. Otherwise grey is kept
        outside the masks, dimmed to one third inside them, and the
        ``color_mask`` colours are added with weight two thirds.
    """
    if mask_cur is None:
        return input

    foreground = mask_cur[:, :-1]
    gray = input[:, 0:1] * 0.299 + input[:, 1:2] * 0.587 + input[:, 2:3] * 0.114
    coverage = th.sum(foreground, dim=1).unsqueeze(dim=1)

    result = gray * (1 - coverage)
    result += gray * coverage * 0.3333333
    return result + color_mask(foreground) * 0.6666666


def color_slots(image, slots_bounded, slots_partially_occluded_cur, slots_occluded_cur):
    """Recolour ``image`` according to three per-element slot-state weights.

    Each step is a blend ``new * weight + old * (1 - weight)``:
      1. ``slots_bounded`` blends towards the inverted image ``1 - image``.
      2. ``slots_partially_occluded_cur`` blends towards ``image - 0.3``
         clipped to ``[0, 1]``.
      3. ``slots_occluded_cur`` applies the same darkening once more.
    """
    def blend(new, old, weight):
        return new * weight + old * (1 - weight)

    def darken(x):
        return th.clip(x - 0.3, 0, 1)

    image = blend(1 - image, image, slots_bounded)
    for weight in (slots_partially_occluded_cur, slots_occluded_cur):
        image = blend(darken(image), image, weight)
    return image

def compute_occlusion_mask(maskraw_cur, maskraw_next,  mask_cur, mask_next, scale):
    """Render raw slot masks with their occluded regions highlighted.

    A slot's occluded region is where its raw mask exceeds its final mask,
    ``clip(maskraw - mask, 0, 1)``. Each raw mask is replicated into three
    channels and upsampled by ``scale``. Inside the occluded region,
    channels 0 and 1 are then zeroed so only channel 2 remains. Each time
    step uses its own occlusion.

    Args:
        maskraw_cur, maskraw_next: Raw masks, expected ``(b, o, h, w)``.
        mask_cur, mask_next: Final masks with the same layout.
        scale: Upsampling factor forwarded to ``preprocess``.

    Returns:
        ``(maskraw_cur, maskraw_next)``, each of shape ``(b, o - 1, 3, H, W)``.
        The last slot channel is dropped (presumably the background slot;
        verify against the caller).
    """
    occluded_cur = th.clip(maskraw_cur - mask_cur, 0, 1)[:, :-1]
    occluded_next = th.clip(maskraw_next - mask_next, 0, 1)[:, :-1]

    # Replicate every slot mask into three colour channels.
    maskraw_cur = repeat(maskraw_cur[:, :-1], 'b o h w -> b (o 3) h w')
    maskraw_next = repeat(maskraw_next[:, :-1], 'b o h w -> b (o 3) h w')

    occluded_cur = preprocess(occluded_cur, scale)
    occluded_next = preprocess(occluded_next, scale)
    maskraw_cur = preprocess(maskraw_cur, scale)
    maskraw_next = preprocess(maskraw_next, scale)

    maskraw_cur = rearrange(maskraw_cur, 'b (o c) h w -> b o c h w', c=3)
    maskraw_next = rearrange(maskraw_next, 'b (o c) h w -> b o c h w', c=3)

    # Highlight occlusion by keeping only channel 2. The original comment
    # calls this "red", which holds when the image is saved via OpenCV
    # (BGR order). Each mask must use the occlusion of its own time step;
    # previously maskraw_cur was wrongly masked with occluded_next.
    for raw, occluded in ((maskraw_cur, occluded_cur), (maskraw_next, occluded_next)):
        raw[:, :, 0] = raw[:, :, 0] * (1 - occluded)
        raw[:, :, 1] = raw[:, :, 1] * (1 - occluded)

    return maskraw_cur, maskraw_next


def plot_online_error_slots(errors, error_name, target, sequence_len, root_path, visibilty_memory, slots_bounded, ylim=0.3):
    """Plot each bounded slot's error curve so far and return the plots as tensors.

    Args:
        errors: Flat per-step, per-slot errors, reshaped with
            ``'(l o) -> o l'`` where ``o = len(slots_bounded)``.
        error_name: Label of the plotted curve.
        target: Tensor whose ``shape[3]`` / ``shape[2]`` set the figure width /
            height at 1/100 inch per element.
        sequence_len: x-axis length. A slot is plotted only while fewer than
            ``sequence_len`` steps have been recorded for it.
        root_path: Directory in which a scratch ``tmp.jpg`` is written.
        visibilty_memory: Flat per-step, per-slot visibility values with the
            same layout as ``errors``; steps where it equals 0 are shaded.
        slots_bounded: Per-slot values whose sum gives the number of slots
            plotted; the first that many slots are used.
        ylim: Upper y-axis limit.

    Returns:
        List of uint8 tensors of shape ``(3, H, W)``, one per plotted slot.
    """
    error_plots = []
    if len(errors) == 0:
        return error_plots

    num_total_slots = len(slots_bounded)
    # NOTE(review): keeps the first `num_slots` rows, i.e. assumes bounded
    # slots come first in slot order - confirm against the caller.
    num_slots = int(th.sum(slots_bounded).item())
    errors = rearrange(np.array(errors), '(l o) -> o l', o=num_total_slots)[:num_slots]
    visibilty_memory = rearrange(np.array(visibilty_memory), '(l o) -> o l', o=num_total_slots)[:num_slots]

    figsize = (round(target.shape[3] / 100, 2), round(target.shape[2] / 100, 2))
    for error, visibility in zip(errors, visibilty_memory):
        if len(error) >= sequence_len:
            continue

        fig, ax = plt.subplots(figsize=figsize)
        try:
            ax.plot(error, label=error_name)

            # Pad with ones ("visible") up to sequence_len, then shade the
            # steps where visibility is 0.
            visibility = np.concatenate((visibility, np.ones(sequence_len - len(error))))
            ax.fill_between(range(sequence_len), 0, 1, where=visibility == 0, color='orange', alpha=0.3, transform=ax.get_xaxis_transform())

            ax.set_xlim((0, sequence_len))
            ax.set_ylim((0, ylim))
            fig.tight_layout()
            fig.savefig(f'{root_path}/tmp.jpg')

            # Render at figure dpi and read the RGBA canvas buffer directly.
            # FigureCanvasAgg.tostring_rgb was deprecated in Matplotlib 3.8
            # and removed in 3.10.
            fig.canvas.draw()
            rgb = np.array(np.asarray(fig.canvas.buffer_rgba())[..., :3])
            error_plots.append(th.from_numpy(rgb.transpose(2, 0, 1)))
        finally:
            # Always release the figure, even if plotting or saving failed.
            plt.close(fig)

    return error_plots

def plot_online_error(error, error_name, target, t, i, sequence_len, root_path, online_surprise = False):
    """Plot an error curve recorded so far and return it as an image tensor.

    Args:
        error: 1-D sequence of per-step errors.
        error_name: Legend label and title.
        target: Tensor whose ``shape[3]`` / ``shape[2]`` set the figure width /
            height at 1/50 inch per element.
        t: Current time step; only used by the surprise check.
        i: Unused; kept for call compatibility.
        sequence_len: x-axis length.
        root_path: Directory in which a scratch ``tmp.jpg`` is written.
        online_surprise: If True and ``t > 10``, colour the figure background
            orange when the newest error exceeds the mean of the previous 10
            errors by more than two standard deviations.

    Returns:
        uint8 tensor of shape ``(3, H, W)``.
    """
    fig = plt.figure(figsize=(round(target.shape[3] / 50, 2), round(target.shape[2] / 50, 2)))
    try:
        plt.plot(error, label=error_name)

        if online_surprise:
            window = 10
            if t > window:
                # Compare the newest error with the `window` errors before it.
                history = error[-(window + 1):-1]
                if error[-1] > np.mean(history) + 2 * np.std(history):
                    fig.set_facecolor('orange')

        plt.xlim((0, sequence_len))
        plt.legend()
        plt.title(f'{error_name}', fontsize=20)
        plt.xlabel('timestep')
        plt.ylabel('error')
        plt.savefig(f'{root_path}/tmp.jpg')

        # Render at figure dpi and read the RGBA canvas buffer directly.
        # FigureCanvasAgg.tostring_rgb was deprecated in Matplotlib 3.8 and
        # removed in 3.10.
        fig.canvas.draw()
        rgb = np.array(np.asarray(fig.canvas.buffer_rgba())[..., :3])
        return th.from_numpy(rgb.transpose(2, 0, 1))
    finally:
        # Always release the figure, even if plotting or saving failed.
        plt.close(fig)

def plot_object_view(error_plot, error_plot2, error_plot_slots, error_plot_slots2, highlighted_input, output_hidden, error_next, object_cur, object_next, maskraw_cur, maskraw_next, position_cur2d, velocity_next2d, output, target, slots_closed, gt_positions_target_next, association_table, size, num_objects, largest_object):
    """Compose a dashboard image of the per-slot model state.

    Layout (``size = (h, w)`` is the size of one slot panel):
      * left column, 2x upsampled: ``highlighted_input``, ``output_hidden``
        and ``target``;
      * right column: ``error_plot`` (starting on the same row as the first
        slot panel) and ``error_plot2`` (third panel row), min-max normalised;
      * one column per slot ``o < num_objects``: a palette colour bar (only
        when ``error_plot_slots2`` has an entry for ``o``), gate bar
        ``slots_closed[:, o, 0]``, ``maskraw_next[0, o]``,
        ``object_next[:, o]``, ``error_plot_slots2[o]``, gate bar
        ``slots_closed[:, o, 1]``, ``velocity_next2d[o]`` and
        ``error_plot_slots[o]``.

    ``error_plot``, ``error_plot2``, ``error_plot_slots``,
    ``error_plot_slots2``, ``output_hidden`` and ``gt_positions_target_next``
    may be ``None``; their panels are then left as background.

    Side effects: when ``gt_positions_target_next`` is given, 10x10 squares in
    each object's palette colour are drawn in place into ``target``, into
    ``velocity_next2d`` for every slot associated with that object via
    ``association_table``, and into ``output_hidden`` when that slot is not
    ``largest_object``.

    ``error_next``, ``object_cur``, ``maskraw_cur``, ``position_cur2d`` and
    ``output`` are unused here.

    Returns:
        Float numpy array of shape ``(H, W, 3)`` with values multiplied by 255.
    """
    device = object_next.device

    # Mark ground-truth object positions (mutates the input tensors).
    if gt_positions_target_next is not None:
        half = 5
        for o in range(gt_positions_target_next.shape[1]):
            position = gt_positions_target_next[0, o] / 2 + 0.5  # maps [-1, 1] onto [0, 1]

            if position[2] > 0.0 and 0.0 < position[0] < 1.0 and 0.0 < position[1] < 1.0:
                row_idx = np.clip(int(position[0] * target.shape[2]), half, target.shape[2] - half).item()
                col_idx = np.clip(int(position[1] * target.shape[3]), half, target.shape[3] - half).item()
                rows = slice(row_idx - half, row_idx + half)
                cols = slice(col_idx - half, col_idx + half)
                color = get_color(o).view(3, 1, 1)
                target[0, :, rows, cols] = color

                # Also mark the position in every slot associated with object o.
                for s in (association_table[0] == o).nonzero().flatten():
                    velocity_next2d[s, :, rows, cols] = color
                    if output_hidden is not None and s != largest_object:
                        output_hidden[0, :, rows, cols] = color

    ch = 40                 # height of the per-slot colour bar
    gh = 40                 # vertical space reserved for a gate bar
    gh_bar = gh - 20        # drawn height of a gate bar
    gh_margin = int((gh - gh_bar) / 2)
    margin = 20
    slots_margin = 10
    height = size[0] * 6 + 18 * 5
    width = size[1] * 4 + 18 * 2 + size[1] * num_objects + 6 * (num_objects + 1) + slots_margin * (num_objects + 1)
    img = th.ones((3, height, width), device=device) * 0.4

    def paste(rows, cols, value, row_shift=0):
        """Write ``value`` into ``img`` at rows ``rows[0]:rows[1]`` (+ shift), columns ``cols``."""
        img[:, rows[0] + row_shift:rows[1] + row_shift, cols] = value

    def panel_rows(r):
        """Row span of the r-th 2x-upsampled panel in the side columns."""
        return 2 * size[0] * r + (r + 1) * margin, 2 * size[0] * (r + 1) + (r + 1) * margin

    def slot_rows(offset, r):
        """Row span of the r-th slot panel counted from ``offset``."""
        return offset + (size[0] + 6) * r, offset + size[0] * (r + 1) + 6 * r

    col1 = range(margin, margin + size[1] * 2)
    col2 = range(width - (margin + size[1] * 2), width - margin)

    paste(panel_rows(0), col1, preprocess(highlighted_input.to(device), 2)[0])
    # output_hidden is optional (see the None check above); skip its panel then.
    if output_hidden is not None:
        paste(panel_rows(1), col1, preprocess(output_hidden.to(device), 2)[0])
    paste(panel_rows(2), col1, preprocess(target.to(device), 2)[0])

    if error_plot is not None:
        # Shifted so that it starts on the same row as the first slot panel.
        paste(panel_rows(0), col2, preprocess(error_plot.to(device), normalize=True),
              row_shift=gh + ch + 2 * margin - gh_margin)
    if error_plot2 is not None:
        paste(panel_rows(2), col2, preprocess(error_plot2.to(device), normalize=True))

    upper_offset = gh + margin - gh_margin + ch + 2 * margin
    lower_offset = margin * 2 - 8
    for o in range(num_objects):
        start = 18 + size[1] * 2 + 6 + o * (6 + size[1]) + (o + 1) * slots_margin
        col = range(start, start + size[1])
        has_slot_plot2 = error_plot_slots2 is not None and len(error_plot_slots2) > o

        if has_slot_plot2:
            img[:, margin:margin + ch, col] = get_color(o).view(3, 1, 1).to(device)

        img[:, margin + ch + 2 * margin:2 * margin + gh_bar + ch + margin, col] = visualise_gate(
            slots_closed[:, o, 0].to(device), h=gh_bar, w=len(col))

        paste(slot_rows(upper_offset, 0), col, preprocess(maskraw_next[0, o].to(device)))
        paste(slot_rows(upper_offset, 1), col, preprocess(object_next[:, o].to(device)))
        if has_slot_plot2:
            paste(slot_rows(upper_offset, 2), col, preprocess(error_plot_slots2[o].to(device), normalize=True))

        vel_start, vel_stop = slot_rows(lower_offset, 4)
        img[:, vel_start - gh + gh_margin:vel_start - gh_margin, col] = visualise_gate(
            slots_closed[:, o, 1].to(device), h=gh_bar, w=len(col))
        img[:, vel_start:vel_stop, col] = preprocess(velocity_next2d[o].to(device), normalize=True)[0]
        if error_plot_slots is not None and len(error_plot_slots) > o:
            paste(slot_rows(lower_offset, 5), col, preprocess(error_plot_slots[o].to(device), normalize=True))

    img = rearrange(img * 255, 'c h w -> h w c').cpu().numpy()

    return img


def plot_object_view_debug(error_plot, error_plot_slots, error_plot_slots2, highlighted_input, output_hidden, error_next, object_cur, object_next, maskraw_cur, maskraw_next, position_cur2d, velocity_next2d, output, target, slots_closed, gt_positions_target_next, association_table, size, num_objects, largest_object):
    """Compose a detailed debug dashboard of the per-slot model state.

    Layout (``size = (h, w)`` is the size of one slot panel):
      * left column: ``error_plot`` (normalised), then 2x upsampled
        ``highlighted_input`` and ``error_next``;
      * right column, 2x upsampled: ``output``, ``target``, ``output_hidden``;
      * one column per slot ``o < num_objects``: ``object_cur[:, o]``,
        ``maskraw_cur[0, o]``, ``position_cur2d[o]``, two gate bars
        (``slots_closed[:, o, 0]`` and ``[:, o, 1]``), ``object_next[:, o]``,
        ``maskraw_next[0, o]``, ``velocity_next2d[o]``,
        ``error_plot_slots[o]`` and ``error_plot_slots2[o]``.

    ``error_plot``, ``output_hidden``, ``error_plot_slots``,
    ``error_plot_slots2`` and ``gt_positions_target_next`` may be ``None``
    (matching ``plot_object_view``); their panels are then left as background.

    Side effects: when ``gt_positions_target_next`` is given, 10x10 squares in
    each object's palette colour are drawn in place into ``target``, into
    ``velocity_next2d`` for every slot associated with that object via
    ``association_table``, and into ``output_hidden`` when that slot is not
    ``largest_object``.

    Returns:
        Float numpy array of shape ``(H, W, 3)`` with values multiplied by 255.
    """
    device = object_next.device

    # Mark ground-truth object positions (mutates the input tensors).
    if gt_positions_target_next is not None:
        half = 5
        for o in range(gt_positions_target_next.shape[1]):
            position = gt_positions_target_next[0, o] / 2 + 0.5  # maps [-1, 1] onto [0, 1]

            if position[2] > 0.0 and 0.0 < position[0] < 1.0 and 0.0 < position[1] < 1.0:
                row_idx = np.clip(int(position[0] * target.shape[2]), half, target.shape[2] - half).item()
                col_idx = np.clip(int(position[1] * target.shape[3]), half, target.shape[3] - half).item()
                rows = slice(row_idx - half, row_idx + half)
                cols = slice(col_idx - half, col_idx + half)
                color = get_color(o).view(3, 1, 1)
                target[0, :, rows, cols] = color

                # Also mark the position in every slot associated with object o.
                for s in (association_table[0] == o).nonzero().flatten():
                    velocity_next2d[s, :, rows, cols] = color
                    if output_hidden is not None and s != largest_object:
                        output_hidden[0, :, rows, cols] = color

    gateheight = 60         # extra rows reserved for the two gate bars
    gh = 30                 # vertical space per gate bar
    margin = 20
    height = size[0] * 8 + 18 * 5 + gateheight
    width = size[1] * 4 + 18 * 2 + size[1] * num_objects + 6 * (num_objects + 1)
    img = th.ones((3, height, width), device=device) * 0.2

    def paste(rows, cols, value, row_shift=0):
        """Write ``value`` into ``img`` at rows ``rows[0]:rows[1]`` (+ shift), columns ``cols``."""
        img[:, rows[0] + row_shift:rows[1] + row_shift, cols] = value

    def panel_rows(r):
        """Row span of the r-th 2x-upsampled panel in the side columns."""
        return 2 * size[0] * r + (r + 1) * margin, 2 * size[0] * (r + 1) + (r + 1) * margin

    def slot_rows(r):
        """Row span of the r-th slot panel."""
        return 18 + (size[0] + 6) * r, 12 + (size[0] + 6) * (r + 1)

    col1 = range(margin, margin + size[1] * 2)
    col2 = range(width - (margin + size[1] * 2), width - margin)

    if error_plot is not None:
        paste(panel_rows(0), col1, preprocess(error_plot.to(device), normalize=True))
    paste(panel_rows(1), col1, preprocess(highlighted_input.to(device), 2)[0])
    paste(panel_rows(2), col1, preprocess(error_next.to(device), 2)[0])

    paste(panel_rows(0), col2, preprocess(output.to(device), 2)[0])
    paste(panel_rows(1), col2, preprocess(target.to(device), 2)[0])
    if output_hidden is not None:
        paste(panel_rows(2), col2, preprocess(output_hidden.to(device), 2)[0])

    for o in range(num_objects):
        start = 18 + size[1] * 2 + 6 + o * (6 + size[1])
        col = range(start, start + size[1])

        paste(slot_rows(0), col, preprocess(object_cur[:, o].to(device), normalize=True)[0])
        paste(slot_rows(1), col, preprocess(maskraw_cur[0, o].to(device)))
        paste(slot_rows(2), col, preprocess(position_cur2d[o].to(device)))

        # Two gate bars stacked below the third slot panel.
        gate_top = slot_rows(2)[1]
        img[:, gate_top + 10:gate_top + gh - 10, col] = visualise_gate(
            slots_closed[:, o, 0].to(device), h=gh - 20, w=len(col))
        img[:, gate_top + 10 + gh:gate_top + 2 * gh - 10, col] = visualise_gate(
            slots_closed[:, o, 1].to(device), h=gh - 20, w=len(col))

        # Remaining panels sit below the gate bars.
        paste(slot_rows(3), col, preprocess(object_next[:, o].to(device), normalize=True)[0], row_shift=gateheight)
        paste(slot_rows(4), col, preprocess(maskraw_next[0, o].to(device)), row_shift=gateheight)
        paste(slot_rows(5), col, preprocess(velocity_next2d[o].to(device)), row_shift=gateheight)

        if error_plot_slots is not None and len(error_plot_slots) > o:
            paste(slot_rows(6), col, preprocess(error_plot_slots[o].to(device), normalize=True), row_shift=gateheight)

        if error_plot_slots2 is not None and len(error_plot_slots2) > o:
            paste(slot_rows(7), col, preprocess(error_plot_slots2[o].to(device), normalize=True), row_shift=gateheight)

    img = rearrange(img * 255, 'c h w -> h w c').cpu().numpy()

    return img

def plot_error_batch(statistics_batch, evaluation_phase_critical, errors_to_plot, evaluation_phase_start, plot_path, i):
    """Save one line plot per requested error series for sample ``i``.

    For every name in ``errors_to_plot``, ``statistics_batch[name]`` is
    plotted on a 10x5 inch figure and saved to
    ``f'{plot_path}errors/{name}/error-plot{i:04d}.jpg'``. ``plot_path`` is
    concatenated directly, so it must end with a path separator. When
    ``evaluation_phase_critical`` is not None, a dashed red vertical line
    marks that (rounded) time step. ``evaluation_phase_start`` is unused.
    """
    for name in errors_to_plot:
        fig, ax = plt.subplots(figsize=(10, 5))
        ax.plot(statistics_batch[name], label=name)

        if evaluation_phase_critical is not None:
            # Vertical line at the critical point.
            ax.axvline(x=round(evaluation_phase_critical), color='r', linestyle='--', label='critical point')

        ax.legend()
        ax.set_title(f'{name} over time for sample {i}')
        ax.set_xlabel('timestep')
        ax.set_ylabel('error')
        fig.savefig(f'{plot_path}errors/{name}/error-plot{i:04d}.jpg')
        plt.close(fig)


def write_image(file, img):
    """Write a ``(c, h, w)`` image tensor to ``file`` with OpenCV.

    The tensor is multiplied by 255 (so it is presumably in ``[0, 1]``),
    moved to ``(h, w, c)`` layout and passed to ``cv2.imwrite``. OpenCV
    interprets 3-channel arrays in BGR order.

    Returns:
        The boolean success flag reported by ``cv2.imwrite``.
    """
    # detach() so tensors that require grad can still be converted to numpy.
    array = rearrange(img.detach() * 255, 'c h w -> h w c').cpu().numpy()
    return cv2.imwrite(file, array)