"""Create 2D density plots."""
import matplotlib.pyplot as plt
from pylab import text
import time
import os
import numpy as np
import cv2


def plot_2d_hist(data, base_folder, data_name, time_in_name=True, xlim=None, ylim=None, obs=None, marker_mode='yellow', weights=None):
    """Create 2D histogram plot."""
    if time_in_name:
        filename = os.path.join(base_folder,  f'{data_name}_2D_hist_{time.time()}.png')
    else:
        filename = os.path.join(base_folder,  f'{data_name}_2D_hist.png')
    if not os.path.isdir(base_folder):
        os.makedirs(base_folder)
    np_data = data.cpu().detach().numpy()
    if xlim is not None and ylim is not None:
        ranges = np.array(np.stack([xlim, ylim]))
    else:
        ranges = None
    plt.hist2d(np_data[:, 0], np_data[:, 1], bins=(100, 100), cmap=plt.cm.jet,  range=ranges, weights=weights)

    if obs is not None:
        if marker_mode == 'yellow':
            marker = '+'
            color = 'yellow'
            display_text = 'Observation time'
        else:
            marker = 'o'
            color = 'red'
            display_text = ''
        for obs_i in obs:
            plt.plot(obs_i[0], obs_i[1],marker=marker, markersize='50', markeredgecolor=color, markerfacecolor='none', fillstyle='none') 
        text(obs[0, 0], obs[0, 1],display_text,
            fontsize=15,
            color='yellow',
            horizontalalignment='center',
            verticalalignment='center')
    plt.savefig(filename, bbox_inches='tight')
    plt.clf()
    return filename


def plot_as_video(base_folder, images, video_name, fps=5):
    """Encode a sequence of image files into an MP4 (``mp4v``) video.

    Args:
        base_folder: Directory the video is written to; created if missing.
        images: Ordered iterable of image file paths, one per frame.
        video_name: Prefix of the output file. A ``time.time()`` timestamp is
            appended so repeated calls do not overwrite each other.
        fps: Frames per second of the output video.

    Returns:
        Path of the written ``.mp4`` file.

    Raises:
        ValueError: If ``images`` is empty, since there is no frame size to use.
        OSError: If an image file cannot be read by OpenCV.
    """
    images = list(images)
    if not images:
        raise ValueError('plot_as_video needs at least one image')
    video_file_name = os.path.join(base_folder, f'{video_name}_{time.time()}.mp4')
    if base_folder:
        os.makedirs(base_folder, exist_ok=True)

    def read_frame(path):
        frame = cv2.imread(path)
        if frame is None:  # imread reports failure by returning None, not by raising
            raise OSError(f'Could not read image {path}')
        return frame

    first = read_frame(images[0])
    height, width = first.shape[:2]
    size = (width, height)
    out = cv2.VideoWriter(video_file_name, fourcc=cv2.VideoWriter_fourcc(*'mp4v'), fps=fps, frameSize=size)
    try:
        out.write(first)
        # Stream the remaining frames instead of holding them all in memory.
        for path in images[1:]:
            frame = read_frame(path)
            if frame.shape[:2] != (height, width):
                # VideoWriter expects every frame to match frameSize; mismatched frames can be
                # silently skipped. PNGs saved with bbox_inches='tight' may differ slightly in size.
                frame = cv2.resize(frame, size)
            out.write(frame)
    finally:
        out.release()
    print(f'Video written to file {video_file_name}')
    return video_file_name


def plot_trajectory_video(plot_folder, rand_select, obs_ts, particles, xlim=None, ylim=None, filename='', weights=None):
    """Render one 2D histogram per step of *particles* and assemble them into a video.

    Frames are written under ``<plot_folder>/trajectories`` and the video under
    ``<plot_folder>/videos``. The intermediate PNGs are deleted afterwards, even
    if plotting or encoding fails.

    Args:
        plot_folder: Root output directory.
        rand_select: Container of step indices at which observations are drawn
            (tested with ``i in rand_select``).
        obs_ts: Observations; step ``i`` uses ``obs_ts[:, i]``.
        particles: Samples; step ``i`` uses ``particles[i]``.
        xlim, ylim: Optional histogram ranges forwarded to plot_2d_hist.
        filename: Prefix for the frame and video file names.
        weights: Optional per-step weights; step ``i`` uses ``weights[i]``.
    """
    frame_folder = os.path.join(plot_folder, 'trajectories')
    plot_files = []
    try:
        # NOTE(review): the last entry along axis 0 of `particles` is never plotted;
        # confirm this is intended (e.g. obs_ts has one fewer step).
        for i in range(particles.shape[0] - 1):
            selected = i in rand_select
            plot_files.append(plot_2d_hist(
                particles[i],
                base_folder=frame_folder,
                data_name=f'{filename}_{i}',
                xlim=xlim,
                ylim=ylim,
                # Unselected steps get no observations, so the 'red' style never actually renders here.
                obs=obs_ts[:, i] if selected else None,
                marker_mode='yellow' if selected else 'red',
                weights=None if weights is None else weights[i],
            ))
        plot_as_video(os.path.join(plot_folder, 'videos'), plot_files, filename)
    finally:
        # Remove the intermediate frames whether or not the video was written.
        for path in plot_files:
            os.remove(path)

