import numpy as np
from tqdm import tqdm


class OF:
    """Sparse-coding model holding a dictionary ``Phi`` and latent activities ``r``.

    ``Phi`` has shape ``(num_inputs, num_units)`` and ``r`` has shape
    ``(batch_size, num_units)``.  Calling the instance performs one update of
    ``r`` and, optionally, one update of ``Phi``.

    NOTE(review): the class name and the "original paper" remark in the
    initial comments suggest the Olshausen & Field model -- confirm.
    """

    def __init__(self, num_inputs, num_units, batch_size, lr_r=1e-2, lr_Phi=1e-2, lmda=5e-3):
        """Store the hyper-parameters and draw a random initial dictionary.

        Args:
            num_inputs: dimensionality of one input sample.
            num_units: number of latent units (dictionary elements).
            batch_size: number of samples whose latents are tracked in ``r``.
            lr_r: step size used when updating ``r``.
            lr_Phi: step size used when updating ``Phi``.
            lmda: threshold / sparsity weight.
        """
        # Problem dimensions
        self.num_inputs = num_inputs
        self.num_units = num_units
        self.batch_size = batch_size

        # Hyper-parameters
        self.lr_r = lr_r
        self.lr_Phi = lr_Phi
        self.lmda = lmda

        # The dictionary starts as float32 standard-normal noise.
        self.Phi = np.random.randn(self.num_inputs, self.num_units).astype(np.float32)

        # Latent activities (the sparse representation) start at zero.
        self.initialize_states()

    def initialize_states(self):
        """Reset the latent activities ``r`` to an all-zero array."""
        state_shape = (self.batch_size, self.num_units)
        self.r = np.zeros(state_shape)

    def normalize_rows(self):
        """Rescale ``Phi`` so that each of its columns has unit L2 norm.

        Despite the method name, the norm is taken along axis 0, i.e. over the
        ``num_inputs`` entries of every column.  Norms are clamped from below
        at 1e-8 so an all-zero column does not cause a division by zero.
        """
        column_norms = np.linalg.norm(self.Phi, ord=2, axis=0, keepdims=True)
        safe_norms = np.maximum(column_norms, 1e-8)
        self.Phi = self.Phi / safe_norms

    def soft_thresholding_func(self, x, lmda):
        """Shrink ``x`` towards zero by ``lmda`` (soft threshold for S(x)=|x|)."""
        above = np.maximum(x - lmda, 0)
        below = np.maximum(-x - lmda, 0)
        return above - below

    def calculate_total_error(self, error):
        """Return mean squared ``error`` plus ``lmda`` times the mean of ``|r|``."""
        mse_term = np.mean(error ** 2)
        sparsity_term = self.lmda * np.mean(np.abs(self.r))
        return mse_term + sparsity_term

    def __call__(self, inputs, training=True):
        """Run one update of ``r`` and, when ``training`` is true, of ``Phi``.

        Args:
            inputs: array compatible with ``r @ Phi.T``.
            training: when true, ``Phi`` is updated in place after ``r``.

        Returns:
            ``(error, r)``.  With ``training=True`` the error is the residual
            recomputed from the updated ``r``; with ``training=False`` it is
            the residual computed *before* ``r`` was updated.
        """
        residual = inputs - self.r @ self.Phi.T

        # ``*`` and ``@`` share precedence and associate left to right, so the
        # residual is scaled first and projected onto the dictionary second.
        step = (self.lr_r * residual) @ self.Phi
        self.r = self.soft_thresholding_func(self.r + step, self.lmda)

        if not training:
            return residual, self.r

        # Dictionary step uses the residual of the freshly updated latents.
        residual = inputs - self.r @ self.Phi.T
        self.Phi += self.lr_Phi * (residual.T @ self.r)  # in-place update
        return residual, self.r


def run_of_experiment(model, nt_max, eps, num_iter, batch_size, inputs_list):
    """Train ``model`` on consecutive mini-batches sliced from ``inputs_list``.

    For each of ``num_iter`` iterations the next ``batch_size`` inputs are
    taken, the latent state is reset and the dictionary re-normalised.  The
    latents are then updated with the dictionary frozen until their relative
    change drops below ``eps``; only then is a single dictionary update
    performed.  If the latents have not converged after ``nt_max`` steps the
    batch is skipped (no dictionary update) and a message is printed.

    Args:
        model: object providing ``r``, ``Phi``, ``initialize_states()``,
            ``normalize_rows()`` and ``model(inputs, training=...)`` returning
            ``(error, r)``.
        nt_max: maximum number of latent update steps per batch.
        eps: convergence threshold on the relative change of ``r``; it is also
            added to the denominator to avoid dividing by zero.
        num_iter: number of batches to process.
        batch_size: number of inputs per batch.
        inputs_list: sliceable sequence of inputs.

    Returns:
        ``model.Phi.T`` after the last iteration.

    Raises:
        IndexError: if ``inputs_list`` holds fewer than ``num_iter`` full
            batches of ``batch_size`` inputs.
    """
    # Batch boundaries do not depend on the iteration, so build them once.
    idx = np.arange(0, len(inputs_list) + 1, batch_size)

    # Fail before touching the model instead of part-way through training.
    num_batches = max(len(idx) - 1, 0)
    if num_iter > num_batches:
        raise IndexError(
            f"num_iter={num_iter} exceeds the {num_batches} full batch(es) of "
            f"size {batch_size} available in inputs_list"
        )

    for iter_ in tqdm(range(num_iter)):
        inputs = np.array(inputs_list[idx[iter_]:idx[iter_ + 1]])

        model.initialize_states()  # Reset states
        model.normalize_rows()  # Normalize weights

        # Latent state of the previous step (t minus 1).
        r_tm1 = model.r
        dr_norm = np.inf  # reported as-is when nt_max allows no step at all

        for _ in range(nt_max):
            # Update r only; the dictionary stays fixed.
            _, r = model(inputs, training=False)

            # Relative change of r.  NOTE: if r is 2-D, ``ord=2`` selects the
            # spectral norm (largest singular value), not the Frobenius norm;
            # kept unchanged so the convergence criterion stays the same.
            dr_norm = np.linalg.norm(r - r_tm1, ord=2) / (eps + np.linalg.norm(r_tm1, ord=2))
            r_tm1 = r

            # Once r has converged, take a single dictionary step.
            if dr_norm < eps:
                model(inputs, training=True)
                break
        else:
            # All nt_max steps were used without meeting the criterion.
            print("Error with iter:", iter_)
            print(dr_norm)

    return model.Phi.T