"""
Visualising animations
"""
import numpy as np
import os
import math


def reshape_and_tile_images(array, shape=(28, 28), n_cols=None):
    """Reshape every entry of ``array`` to ``shape`` and tile them into a grid.

    Images are placed row by row: image ``i`` goes to grid row
    ``i // n_cols`` and grid column ``i % n_cols``.  Grid cells without a
    matching image are filled with ``np.zeros(shape)``.

    Args:
        array: Array whose first axis indexes the images; each ``array[i]``
            must be reshapeable to ``shape``.
        shape: Shape of a single image, e.g. ``(28, 28)``.
        n_cols: Number of images per grid row.  When ``None`` it defaults to
            ``int(sqrt(array.shape[0]))``.

    Returns:
        The tiled grid: cells are joined along axis 1 within a grid row and
        the grid rows are then joined along axis 0.

    Raises:
        ValueError: If ``array`` holds no images or ``n_cols`` is not
            positive.
    """
    n_images = array.shape[0]
    if n_images == 0:
        raise ValueError('array must contain at least one image')
    if n_cols is None:
        n_cols = int(math.sqrt(n_images))
    if n_cols < 1:
        raise ValueError('n_cols must be a positive integer')
    n_rows = int(np.ceil(float(n_images) / n_cols))

    def cell(i, j):
        ind = i * n_cols + j
        if ind < n_images:
            # The reshape deliberately stays in C order for every ``shape``,
            # which is what this function has always done.
            return array[ind].reshape(*shape, order='C')
        # Past the last image: pad the grid with an all-zero cell.
        return np.zeros(shape)

    def row(i):
        return np.concatenate([cell(i, j) for j in range(n_cols)], axis=1)

    return np.concatenate([row(i) for i in range(n_rows)], axis=0)


def _tint_rgb(code, level):
    """Return the ``(r, g, b)`` tint for colour ``code`` at intensity ``level``.

    Raises:
        NotImplementedError: If ``code`` is not one of 'r', 'g', 'b', 'y'.
    """
    if code == 'r':
        return (level, 0, 0)
    if code == 'g':
        return (0, level, 0)
    if code == 'b':
        return (0, 0, level)
    if code == 'y':
        return (level, level, 0)
    raise NotImplementedError('unsupported colour code: %r' % (code,))


def plot_image(x_seq, shape, path, filename, color_list=None, n_cols=5):
    """Overlay all time steps of a batch of sequences into one PNG.

    For every time step ``t`` the images ``x_seq[:, t]`` are tiled into a
    grid and written to a temporary PNG with the 'Greys' colormap.  Each
    temporary frame is then recoloured: pure-white pixels become fully
    transparent, every other pixel gets the frame's tint (intensity
    ``int(255.0 * t / T)``) with one third of its original alpha.  The frames
    are alpha-composited onto a white canvas from the last time step to the
    first and saved as ``path + filename + '.png'``.

    Args:
        x_seq: Array indexed as ``x_seq[sample, t]``; only the first
            ``n_cols**2`` samples are used.
        shape: Shape of a single image, forwarded to
            ``reshape_and_tile_images``.
        path: Prefix prepended verbatim to ``filename`` (no separator added).
        filename: Base name of the output file, without extension.
        color_list: One colour code per time step, each of 'r', 'g', 'b' or
            'y'.  Defaults to 'r' for every step.
        n_cols: Images per grid row; ``None`` means
            ``int(sqrt(x_seq.shape[0]))``.

    Raises:
        NotImplementedError: If a frame with non-white pixels has an
            unsupported colour code.
    """
    import matplotlib
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt
    from PIL import Image

    if n_cols is None:
        n_cols = int(np.sqrt(x_seq.shape[0]))
    x_seq = x_seq[:n_cols**2]
    T = x_seq.shape[1]
    if color_list is None:
        color_list = ['r' for _ in range(T)]

    frame_names = [path + filename + '_' + str(t) + '.png' for t in range(T)]
    try:
        # first save each frame
        for t, figname in enumerate(frame_names):
            x = reshape_and_tile_images(x_seq[:, t], shape, n_cols)
            plt.imsave(fname=figname, arr=x, cmap='Greys')

        # then recolour them
        img_list = []
        for t, figname in enumerate(frame_names):
            # Close the file handle explicitly once the pixels are converted.
            with Image.open(figname) as raw:
                tmp = raw.convert('RGBA')
            # The tint only depends on the frame, so resolve it once -- and
            # only when a non-white pixel actually needs it.
            tint = None
            new_data = []
            for item in tmp.getdata():
                if item[0] == 255 and item[1] == 255 and item[2] == 255:
                    new_data.append((255, 255, 255, 0))
                else:
                    if tint is None:
                        tint = _tint_rgb(color_list[t], int(255.0 * t / T))
                    new_data.append(tint + (int(item[3] / 3),))
            tmp.putdata(new_data)
            img_list.append(tmp)

        # finally merge (last time step first, so frame 0 ends up on top)
        figname = path + filename + '.png'
        target = Image.new('RGBA', img_list[0].size, 'white')
        for layer in reversed(img_list):
            target = Image.alpha_composite(target, layer)
        target.save(figname)
    finally:
        # remove tmp results, even when something above failed
        for figname in frame_names:
            if os.path.exists(figname):
                os.remove(figname)

    print('image saved as ' + path+filename+'.png')

def _pil_save_png(fname, frame):
    """Save ``frame`` to ``fname`` as a PNG using Pillow.

    Fallback for SciPy versions without ``scipy.misc.imsave``.  Data that is
    not already ``uint8`` is rescaled linearly so that its minimum maps to 0
    and its maximum to 255.
    NOTE(review): intended to mirror the old ``scipy.misc.imsave`` scaling;
    confirm against previously generated images if exact parity matters.
    """
    from PIL import Image
    pixels = np.asarray(frame)
    if pixels.dtype != np.uint8:
        pixels = pixels.astype(np.float64)
        low = pixels.min()
        # A constant image has zero range; avoid dividing by zero.
        span = (pixels.max() - low) or 1.0
        pixels = (pixels - low) * (255.0 / span)
        pixels = (pixels.clip(0, 255) + 0.5).astype(np.uint8)
    Image.fromarray(pixels).save(fname, format="png")


def _dump_frames(out_dir, name, batch_size, streams, squeeze=False):
    """Write every ``[k, t]`` frame of each stream as an individual PNG.

    Files are named ``out_dir + name + tag + str(k) + 'fram' + str(t) +
    '.png'``.  ``streams`` is a sequence of ``(tag, array)`` pairs; the number
    of time steps is taken from axis 1 of the first array.  With
    ``squeeze=True`` each frame goes through ``np.squeeze`` before saving.
    """
    import matplotlib
    # Kept from the original functions, which switched to the Agg backend.
    matplotlib.use('Agg')

    try:
        from scipy.misc import imsave as scipy_imsave
    except ImportError:
        # scipy.misc.imsave is gone from current SciPy releases.
        scipy_imsave = None

    if not os.path.isdir(out_dir):
        os.makedirs(out_dir)

    T = streams[0][1].shape[1]
    for k in range(batch_size):
        for t in range(T):
            for tag, frames in streams:
                frame = frames[k, t]
                if squeeze:
                    frame = np.squeeze(frame)
                fname = out_dir + name + tag + str(k) + 'fram' + str(t) + '.png'
                if scipy_imsave is not None:
                    scipy_imsave(fname, frame, format="png")
                else:
                    _pil_save_png(fname, frame)


def plot_gif_one_by_one(xRz, xSz, xG, xRf, xSf, x_seq, shape, name, batch_size):
    """Save every frame of the six sequence batches as PNGs under 'figsp7/'.

    Args:
        xRz, xSz, xG, xRf, xSf, x_seq: Arrays indexed as ``[k, t]``; their
            frames are written with the tags 'RZ', 'SZ', 'G', 'RF', 'SF' and
            'P' respectively.
        shape: Unused; kept for interface compatibility.
        name: Prefix of every output file name.
        batch_size: Number of leading samples ``k`` to save.
    """
    streams = [('P', x_seq), ('RZ', xRz), ('SZ', xSz),
               ('G', xG), ('RF', xRf), ('SF', xSf)]
    _dump_frames('figsp7/', name, batch_size, streams)


def plot_gif_mg_one_by_one(xRz, xSz, xG, xRf, xSf, x_seq, shape, name, batch_size):
    """Save every frame of the six sequence batches as PNGs under 'figabdc/'.

    Same arguments and file naming as ``plot_gif_one_by_one``; only the
    output directory differs.
    """
    streams = [('P', x_seq), ('RZ', xRz), ('SZ', xSz),
               ('G', xG), ('RF', xRf), ('SF', xSf)]
    _dump_frames('figabdc/', name, batch_size, streams)


def plot_gif_mnist_one_by_one(xRz, xSz, xG, xRf, xSf, x_seq, shape, name, batch_size):
    """Save every frame of the six sequence batches as PNGs under 'figm74/'.

    Same arguments and file naming as ``plot_gif_one_by_one``, but each frame
    is passed through ``np.squeeze`` before it is written.
    """
    streams = [('P', x_seq), ('RZ', xRz), ('SZ', xSz),
               ('G', xG), ('RF', xRf), ('SF', xSf)]
    _dump_frames('figm74/', name, batch_size, streams, squeeze=True)

def plot_gif(x_seq, shape, path, filename):
    """Render a batch of sequences as an animated GIF of tiled frames.

    The first ``n_cols**2`` samples (``n_cols = int(sqrt(x_seq.shape[0]))``)
    are tiled into a grid for every time step and the grids are animated at
    200 ms per frame.  The result is written to ``path + filename + '.gif'``
    with matplotlib's 'imagemagick' writer.

    Args:
        x_seq: Array indexed as ``x_seq[sample, t]``.
        shape: Shape of a single image, forwarded to
            ``reshape_and_tile_images``.
        path: Prefix prepended verbatim to ``filename`` (no separator added).
        filename: Base name of the output file, without extension.
    """
    import matplotlib
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt
    from matplotlib.animation import FuncAnimation

    n_cols = int(np.sqrt(x_seq.shape[0]))
    print('x data shape', x_seq.shape)
    x_seq = x_seq[:n_cols**2]
    T = x_seq.shape[1]

    fig = plt.figure()
    try:
        x0 = reshape_and_tile_images(x_seq[:, 0], shape, n_cols)
        print('check output gen shape in plot_gif', x0.shape)
        im = plt.imshow(x0, animated=True, cmap='gray')
        plt.axis('off')

        def update(t):
            x_frame = reshape_and_tile_images(x_seq[:, t], shape, n_cols)
            im.set_array(x_frame)
            return im,

        anim = FuncAnimation(fig, update, frames=np.arange(T),
                             interval=200, blit=True)
        anim.save(path+filename+'.gif', writer='imagemagick')
    finally:
        # pyplot keeps every figure alive until it is closed explicitly;
        # without this, repeated calls accumulate figures in memory.
        plt.close(fig)
    print('image saved as ' + path+filename+'.gif')

def plot_gif_mnist(x_seq, shape, path, filename):
    """Render a batch of sequences as an animated GIF of tiled frames.

    Variant of ``plot_gif`` for input with singleton axes: axis 4 of
    ``x_seq`` is squeezed away up front, and axis 2 of every tiled grid is
    squeezed before it is displayed.  The first ``n_cols**2`` samples
    (``n_cols = int(sqrt(x_seq.shape[0]))``) are used, frames are shown for
    200 ms each, and the result is written to ``path + filename + '.gif'``
    with matplotlib's 'imagemagick' writer.

    Args:
        x_seq: Array indexed as ``x_seq[sample, t]`` with a size-1 axis 4.
        shape: Shape of a single image, forwarded to
            ``reshape_and_tile_images``; the tiled grid must have a size-1
            axis 2.
        path: Prefix prepended verbatim to ``filename`` (no separator added).
        filename: Base name of the output file, without extension.
    """
    import matplotlib
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt
    from matplotlib.animation import FuncAnimation

    n_cols = int(np.sqrt(x_seq.shape[0]))
    x_seq = np.squeeze(x_seq, axis=4)
    print('x data shape', x_seq.shape)
    x_seq = x_seq[:n_cols**2]
    T = x_seq.shape[1]

    def tiled_frame(t):
        # Tile time step ``t`` and drop the singleton axis 2 for imshow.
        grid = reshape_and_tile_images(x_seq[:, t], shape, n_cols)
        return np.squeeze(grid, axis=2)

    fig = plt.figure()
    try:
        x0 = tiled_frame(0)
        print('check output gen shape in plot_gif', x0.shape)
        im = plt.imshow(x0, animated=True, cmap='gray')
        plt.axis('off')

        def update(t):
            im.set_data(tiled_frame(t))
            return im,

        anim = FuncAnimation(fig, update, frames=np.arange(T),
                             interval=200, blit=True)
        anim.save(path+filename+'.gif', writer='imagemagick')
    finally:
        # pyplot keeps every figure alive until it is closed explicitly;
        # without this, repeated calls accumulate figures in memory.
        plt.close(fig)
    print('image saved as ' + path+filename+'.gif')

