from matplotlib import pyplot as plt
from PIL import Image
import numpy as np
import pickle
from scipy import linalg

import numpy as np
import random
import cv2

def sp_noise(image, prob):
    """Return a copy of ``image`` corrupted by salt-and-pepper noise.

    One uniform draw in [0, 1) is made per pixel location (the first two
    axes of ``image``).
    - Draw below ``prob / 2``: the pixel is set to 0 (pepper).
    - Draw above ``1 - prob / 2``: the pixel is set to 255 (salt).

    The input array itself is never modified.
    """
    half = prob / 2
    noisy = image.copy()
    draws = np.random.random(image.shape[:2])
    noisy[draws < half] = 0
    # Applied second, so salt wins wherever both conditions hold (prob > 1).
    noisy[draws > 1 - half] = 255
    return noisy

def _load_gray(path):
    """Open the image file at ``path`` and return it as an 8-bit grayscale array."""
    return np.asarray(Image.open(path).convert("L"))


# Ground-truth image and its occluded counterpart.
image_orig = _load_gray('images/building_org.png')
image = _load_gray('images/building_occluded.png')

# Boolean mask of pixels whose value the occlusion left unchanged
# (True = observed). The objective classes only penalise these pixels.
known = image_orig == image

# The noisy observation was produced once with:
#     image = known * sp_noise(image, 0.1) + 255*(1-known)
#     cv2.imwrite('images/building_occluded_spnoise.png', image)
# NOTE(review): that recipe writes a .png, but the next line reads a .jpg.
# Lossy JPEG compression would perturb the salt-and-pepper pixels.
# Confirm which file is intended.
image = _load_gray('images/building_occluded_spnoise.jpg')

# Rescale both images from [0, 255] integers to [0, 1] floats.
image_orig, image = image_orig / 255, image / 255

# portions of code taken from https://ee227c.github.io/code/lecture5.html#nuclear-norm

class L2Impainting():
    """Inpainting objective f(X) = ||known * (X - image)||_2 (not squared).

    Reads the module-level ``known`` mask and ``image`` observation. Besides
    the objective and its gradient, it exposes two projections:
    - ``proj_g``: onto the box [0, 1].
    - ``proj_h``: onto the nuclear-norm ball of radius ``c_1``.

    How these pieces are combined into an algorithm is up to the caller
    (not visible here).

    Args:
        c_1: radius of the nuclear-norm ball used by ``proj_h``.
    """

    def __init__(self, c_1=750):
        self.c_1 = c_1

    def simplex_projection(self, s):
        """Project a vector onto the unit simplex unless it is already feasible.

        Returns ``s`` unchanged when ``s >= 0`` and ``sum(s) <= 1``.
        Otherwise it returns the Euclidean projection of ``s`` onto the
        unit simplex {v >= 0, sum(v) = 1}, using the sort-and-threshold
        algorithm. For a non-negative ``s`` (e.g. singular values), this
        equals the projection onto the unit l1 ball.
        """
        # np.alltrue was deprecated and removed in NumPy 2.0; np.all is equivalent.
        if np.sum(s) <= 1 and np.all(s >= 0):
            return s
        u = np.sort(s)[::-1]  # sorted copy of s, decreasing order
        cssv = np.cumsum(u)   # cumulative sums of the sorted copy
        # Index of the last component that remains positive after thresholding.
        rho = np.nonzero(u * np.arange(1, len(u) + 1) > (cssv - 1))[0][-1]
        # Lagrange multiplier of the sum-to-one constraint (threshold).
        theta = (cssv[rho] - 1) / (rho + 1.0)
        return np.maximum(s - theta, 0)

    def nuclear_projection(self, A):
        """Project matrix ``A`` onto the unit nuclear-norm ball.

        The singular values are projected with ``simplex_projection`` and
        the matrix is rebuilt from the thin SVD.
        """
        U, s, Vh = np.linalg.svd(A, full_matrices=False)
        s = self.simplex_projection(s)
        # Scaling U's columns by s equals U @ diag(s) without building the
        # dense k x k diagonal matrix.
        return (U * s) @ Vh

    def df(self, x):
        """Return the gradient of ``obj`` at ``x``.

        For f(X) = ||known * (X - image)||_2 the gradient is
        known * (X - image) / f(X). At (numerically) zero residual, f is
        not differentiable. A zero array is returned there, which is a
        valid subgradient.
        """
        obj_x = self.obj(x)
        epsilon = 1e-10
        if obj_x < epsilon:
            return np.zeros(image.shape)
        return known * (x - image) / obj_x

    def proj_g(self, x):
        """Project ``x`` element-wise onto the box [0, 1]."""
        return np.clip(x, 0, 1)

    def proj_h(self, x):
        """Project ``x`` onto the nuclear-norm ball of radius ``self.c_1``."""
        return self.nuclear_projection(x / self.c_1) * self.c_1

    def obj(self, x):
        """Euclidean (Frobenius) norm of the residual on the known pixels."""
        return np.sqrt(np.sum((known * (x - image)) ** 2))

class L2SqImpainting():
    """Inpainting objective f(X) = 0.5 * ||known * (X - image)||_2^2.

    Reads the module-level ``known`` mask and ``image`` observation. Besides
    the objective and its gradient, it exposes two projections:
    - ``proj_g``: onto the box [0, 1].
    - ``proj_h``: onto the nuclear-norm ball of radius ``c_1``.

    How these pieces are combined into an algorithm is up to the caller
    (not visible here).

    Args:
        c_1: radius of the nuclear-norm ball used by ``proj_h``.
    """

    def __init__(self, c_1=750):
        self.c_1 = c_1

    def simplex_projection(self, s):
        """Project a vector onto the unit simplex unless it is already feasible.

        Returns ``s`` unchanged when ``s >= 0`` and ``sum(s) <= 1``.
        Otherwise it returns the Euclidean projection of ``s`` onto the
        unit simplex {v >= 0, sum(v) = 1}, using the sort-and-threshold
        algorithm. For a non-negative ``s`` (e.g. singular values), this
        equals the projection onto the unit l1 ball.
        """
        # np.alltrue was deprecated and removed in NumPy 2.0; np.all is equivalent.
        if np.sum(s) <= 1 and np.all(s >= 0):
            return s
        u = np.sort(s)[::-1]  # sorted copy of s, decreasing order
        cssv = np.cumsum(u)   # cumulative sums of the sorted copy
        # Index of the last component that remains positive after thresholding.
        rho = np.nonzero(u * np.arange(1, len(u) + 1) > (cssv - 1))[0][-1]
        # Lagrange multiplier of the sum-to-one constraint (threshold).
        theta = (cssv[rho] - 1) / (rho + 1.0)
        return np.maximum(s - theta, 0)

    def nuclear_projection(self, A):
        """Project matrix ``A`` onto the unit nuclear-norm ball.

        The singular values are projected with ``simplex_projection`` and
        the matrix is rebuilt from the thin SVD.
        """
        U, s, Vh = np.linalg.svd(A, full_matrices=False)
        s = self.simplex_projection(s)
        # Scaling U's columns by s equals U @ diag(s) without building the
        # dense k x k diagonal matrix.
        return (U * s) @ Vh

    def df(self, x):
        """Return the gradient of ``obj`` at ``x``: known * (x - image).

        The exact gradient is known * known * (x - image). Because ``known``
        is a 0/1 (boolean) mask, known * known == known.
        """
        return known * (x - image)

    def proj_g(self, x):
        """Project ``x`` element-wise onto the box [0, 1]."""
        return np.clip(x, 0, 1)

    def proj_h(self, x):
        """Project ``x`` onto the nuclear-norm ball of radius ``self.c_1``."""
        return self.nuclear_projection(x / self.c_1) * self.c_1

    def obj(self, x):
        """Half the squared residual norm on the known pixels."""
        return 0.5 * np.sum((known * (x - image)) ** 2)

class L1Impainting():
    """Inpainting objective f(X) = ||known * (X - image)||_1 (sum of absolute values).

    Reads the module-level ``known`` mask and ``image`` observation. Besides
    the objective and a subgradient, it exposes two projections:
    - ``proj_g``: onto the box [0, 1].
    - ``proj_h``: onto the nuclear-norm ball of radius ``c_1``.

    How these pieces are combined into an algorithm is up to the caller
    (not visible here).

    Args:
        c_1: radius of the nuclear-norm ball used by ``proj_h``.
    """

    def __init__(self, c_1=750):
        self.c_1 = c_1

    def simplex_projection(self, s):
        """Project a vector onto the unit simplex unless it is already feasible.

        Returns ``s`` unchanged when ``s >= 0`` and ``sum(s) <= 1``.
        Otherwise it returns the Euclidean projection of ``s`` onto the
        unit simplex {v >= 0, sum(v) = 1}, using the sort-and-threshold
        algorithm. For a non-negative ``s`` (e.g. singular values), this
        equals the projection onto the unit l1 ball.
        """
        # np.alltrue was deprecated and removed in NumPy 2.0; np.all is equivalent.
        if np.sum(s) <= 1 and np.all(s >= 0):
            return s
        u = np.sort(s)[::-1]  # sorted copy of s, decreasing order
        cssv = np.cumsum(u)   # cumulative sums of the sorted copy
        # Index of the last component that remains positive after thresholding.
        rho = np.nonzero(u * np.arange(1, len(u) + 1) > (cssv - 1))[0][-1]
        # Lagrange multiplier of the sum-to-one constraint (threshold).
        theta = (cssv[rho] - 1) / (rho + 1.0)
        return np.maximum(s - theta, 0)

    def nuclear_projection(self, A):
        """Project matrix ``A`` onto the unit nuclear-norm ball.

        The singular values are projected with ``simplex_projection`` and
        the matrix is rebuilt from the thin SVD.
        """
        U, s, Vh = np.linalg.svd(A, full_matrices=False)
        s = self.simplex_projection(s)
        # Scaling U's columns by s equals U @ diag(s) without building the
        # dense k x k diagonal matrix.
        return (U * s) @ Vh

    def df(self, x):
        """Return a subgradient of ``obj`` at ``x``: known * sign(x - image).

        np.sign yields 0 where the residual is exactly zero. That is a valid
        choice from the subdifferential of |.| at 0.
        """
        return known * np.sign(x - image)

    def proj_g(self, x):
        """Project ``x`` element-wise onto the box [0, 1]."""
        return np.clip(x, 0, 1)

    def proj_h(self, x):
        """Project ``x`` onto the nuclear-norm ball of radius ``self.c_1``."""
        return self.nuclear_projection(x / self.c_1) * self.c_1

    def obj(self, x):
        """Sum of absolute residuals on the known pixels."""
        return np.sum(np.abs(known * (x - image)))
