import numpy as np 
import moviepy.editor as mpe
import scipy
from skimage.transform import resize
import matplotlib.pyplot as plt
from rpca import rpca, pcp, just_pca, pcp2

# Percentage of the 240x320 base frame size to work at; adjust it to change
# the resolution of the images (100 keeps frames at 240x320).
scale = 100
dims = tuple(int(side * (scale/100)) for side in (240, 320))
# Alternative square 360x360 base size:
#dims = (int(360 * (scale/100)), int(360 * (scale/100)))

def rgb2gray(rgb):
	"""Collapse the colour channels of an image array into one luma channel.

	Only the first three entries of the last axis (R, G, B) are used, so an
	alpha channel is ignored. They are weighted 0.299 / 0.587 / 0.114 (the
	ITU-R BT.601 luma coefficients) and summed, which drops the last axis.
	"""
	luma_weights = [0.299, 0.587, 0.114]
	colour_channels = rgb[..., :3]
	return np.dot(colour_channels, luma_weights)

def create_data_matrix_from_video(clip, k=5, scale=50, out_dims=None):
	"""Sample a video clip into a (pixels x frames) grayscale data matrix.

	Frames are taken at times ``i / k`` for ``i`` in
	``range(k * int(clip.duration))``, i.e. ``k`` frames per second, with any
	fractional final second dropped. Each frame is converted with
	``rgb2gray``, truncated to ints, resized to ``out_dims`` (keeping the
	original value range) and flattened into one column of the result.

	Args:
		clip: clip exposing ``duration`` and ``get_frame(t)`` (as a moviepy
			``VideoFileClip`` does).
		k: number of frames sampled per second of video.
		scale: unused; the frame size comes from ``out_dims`` (or the
			module-level ``dims``). Kept so existing callers keep working.
		out_dims: ``(height, width)`` of each resized frame. Defaults to the
			module-level ``dims``.

	Returns:
		Array of shape ``(height * width, n_frames)``; column ``j`` is the
		``j``-th sampled frame.

	Raises:
		ValueError: if ``k * int(clip.duration)`` is not positive (e.g. the
			clip is shorter than one second), so no frame would be sampled.
	"""
	target_dims = dims if out_dims is None else out_dims
	n_frames = k * int(clip.duration)
	if n_frames <= 0:
		raise ValueError(
			"no frames to sample: k * int(clip.duration) = %r" % (n_frames,))

	columns = []
	for i in range(n_frames):
		# Convert and downsample one frame at a time so that only a single
		# full-resolution frame is held in memory.
		frame = rgb2gray(clip.get_frame(i / float(k))).astype(int)
		column = resize(frame, target_dims, preserve_range=True).flatten()
		if i == 0:
			# Debug output: size and first pixel before and after resizing.
			print(frame.shape, frame[0, 0])
			print(column.shape, column[0])
		columns.append(column)
	return np.vstack(columns).T

def plt_images(M, A, E, index_array, dims, filename=None):
	"""Plot selected columns of ``M``, ``A`` and ``E`` side by side.

	One grid row is drawn per entry of ``index_array``. Its three cells show
	column ``i`` of ``M``, ``A`` and ``E`` respectively, each reshaped to
	``dims`` and rendered in grayscale.

	Args:
		M, A, E: 2-D matrices (dense or scipy sparse) whose columns are
			flattened images.
		index_array: column indices to display.
		dims: ``(height, width)`` used to reshape each column.
		filename: if given, the figure is also saved to this path.

	Returns:
		The created matplotlib figure.
	"""
	# Explicit import: a bare ``import scipy`` does not guarantee that the
	# ``scipy.sparse`` submodule is loaded.
	from scipy.sparse import issparse

	f = plt.figure(figsize=(15, 10))
	n_rows = len(index_array)
	for k, i in enumerate(index_array):
		for j, mat in enumerate([M, A, E]):
			sp = f.add_subplot(n_rows, 3, 3 * k + j + 1)
			sp.axis('off')
			pixels = mat[:, i]
			# Sparse column slices must be densified before reshaping;
			# issparse covers every scipy sparse format, not only CSR.
			if issparse(pixels):
				pixels = pixels.toarray()
			sp.imshow(np.reshape(pixels, dims), cmap='gray')
	if filename is not None:
		f.savefig(filename)
	return f

def plots(ims, dims, figsize=(15,20), rows=1, interp=False, titles=None):
	"""Show every image of ``ims`` in a grid of grayscale subplots.

	Args:
		ims: sequence of images; each is reshaped to ``dims`` before display.
		dims: ``(height, width)`` used to reshape every image.
		figsize: matplotlib figure size in inches.
		rows: number of grid rows (values below 1 are treated as 1). The
			column count is chosen so that every image gets its own cell.
		interp: currently unused; kept for backward compatibility.
		titles: optional sequence of per-image subplot titles.

	Returns:
		The created matplotlib figure.
	"""
	n_images = len(ims)
	n_rows = max(1, rows)
	# Ceiling division so every image gets a cell even when n_images is not
	# a multiple of n_rows (floor division would make add_subplot raise).
	n_cols = max(1, -(-n_images // n_rows))
	f = plt.figure(figsize=figsize)
	for i, im in enumerate(ims):
		sp = f.add_subplot(n_rows, n_cols, i + 1)
		sp.axis('off')
		if titles is not None:
			sp.set_title(titles[i])
		sp.imshow(np.reshape(im, dims), cmap="gray")
	return f

# Set to True to rebuild the (pixels x frames) data matrix from the source
# video and cache it to disk; otherwise a cached matrix is loaded.
REBUILD_DATA_MATRIX = False

if REBUILD_DATA_MATRIX:
	video = mpe.VideoFileClip("video.mp4")
	try:
		M = create_data_matrix_from_video(video, 100, scale)
	finally:
		# Release the reader held by the clip.
		video.close()
	np.save("video.npy", M)
else:
	M = np.load("video2.npy")
	#M = np.load("monument.npy")

# Quick previews of a single frame / the whole matrix:
#plt.imshow(np.reshape(M[:,140], dims), cmap='gray')
#plt.figure(figsize=(12, 12))
#plt.imshow(M, cmap='gray')
#plt.imsave(fname="image1.jpg", arr=np.reshape(M[:,140], dims), cmap='gray')

# Decompose M with pcp from the rpca module. L and S are presumably the
# low-rank (background) and sparse (foreground) parts; see rpca to confirm.
L, S, examples =  pcp(M, maxiter=5, k=10)
# Clip magnitudes to the 0..255 pixel range. The frames are converted to
# uint8 before writing, where 256 would wrap around to 0 (black).
L = np.clip(np.abs(L), 0, 255)
S = np.clip(np.abs(S), 0, 255)

# Alternative decompositions:
#L, S, examples =  rpca(M)
#L, S = just_pca(M)
#L, S, examples = rpca(M, loss='l2')

#plots(examples, dims, rows= int(0.5 * len(examples)))


def _write_gray_video(mat, filename, fps=100):
	"""Write each column of ``mat`` (one flattened frame) as a video frame.

	Columns are reshaped to the module-level ``dims``, cast to uint8 and
	replicated into three identical R, G, B channels.
	"""
	frames = np.transpose(mat).reshape(-1, dims[0], dims[1]).astype(np.uint8)
	rgb_frames = np.repeat(frames[..., np.newaxis], 3, axis=-1)
	out_clip = mpe.ImageSequenceClip(list(rgb_frames), fps=fps, with_mask=False)
	out_clip.write_videofile(filename)


# NOTE(review): fps=100 presumably matches the k=100 sampling used when the
# matrix was built; confirm for video2.npy.
_write_gray_video(S, "S_pcp_lambda_0.001.mp4", fps=100)
_write_gray_video(L, "L_pcp_lambda_0.001.mp4", fps=100)

#plt.show()
