import jax 
import jax.numpy as jnp
import numpy as np
from inv_large import solve_right_inverse
from read_mnist import *
n1 = 3072  # input size: one flattened image (presumably 32*32*3 CIFAR) — confirm against load_binary_cifar
n2 = 1600  # hidden width
n3 = 16    # output width; the scalar prediction is the sum of these outputs

classes_to_include = (0, 1)
train_images, train_labels, test_images, test_labels = load_binary_cifar(classes_to_include)

key = jax.random.PRNGKey(42)
# JAX PRNG keys must not be reused: sampling W1 and W2 from the same key makes
# their entries statistically dependent. Split once, one subkey per consumer.
key_w1, key_w2 = jax.random.split(key)

# Sample input scaled by 1/255 (assumes raw pixel values in [0, 255]).
x = train_images[0, :].reshape(n1, 1) / 255

# Upper-triangular initialisation. The scale factors are the original
# MNIST-era constants (28*28, 16*16), kept to preserve the init magnitude.
W1 = jnp.triu(jax.random.normal(key_w1, (n2, n1))) / (28 * 28)  # (n2, n1)
W2 = jnp.triu(jax.random.normal(key_w2, (n3, n2))) / (16 * 16)  # (n3, n2)


def linear_layer(W, x):
    """Apply a bias-free linear map.

    Args:
        W: Weight matrix.
        x: Input array compatible with ``W`` under matrix multiplication.

    Returns:
        The matrix product ``W @ x``.
    """
    product = W @ x
    return product

def relu(x):
    """Rectified linear unit, element-wise ``max(0, x)``."""
    floor = 0
    return jnp.maximum(floor, x)

def leaky_relu(x, alpha=0.01):
    """Leaky ReLU, element-wise.

    Args:
        x: Input array.
        alpha: Slope used where ``x <= 0``.

    Returns:
        ``x`` where ``x > 0``, otherwise ``x * alpha``.
    """
    is_active = x > 0
    return jnp.where(is_active, x * 1, alpha * x)

def forward(W2, W1, x, y):
    """Squared-error loss of the two-layer network for a single example.

    The network is ``W2 @ leaky_relu(W1 @ x)``; its outputs are summed into a
    scalar prediction ``y_hat``.

    Args:
        W2: Second-layer weights.
        W1: First-layer weights.
        x: Input.
        y: Scalar target.

    Returns:
        ``(y - y_hat) ** 2``.
    """
    hidden = leaky_relu(linear_layer(W1, x))
    y_hat = jnp.sum(linear_layer(W2, hidden))
    return (y - y_hat) ** 2

def test_forward(W2, W1, x):
    """Inference-only forward pass.

    Args:
        W2: Second-layer weights.
        W1: First-layer weights.
        x: Input.

    Returns:
        The scalar prediction: the sum of ``W2 @ leaky_relu(W1 @ x)``.
    """
    hidden = leaky_relu(linear_layer(W1, x))
    outputs = linear_layer(W2, hidden)
    return outputs.sum()

def forward_input_grad(W2, W1, x, y=None):
    """Compute dL/dx for L = (y - sum(W2 @ leaky_relu(W1 @ x))) ** 2.

    Between the forward and backward pass, only the leaky-ReLU slopes of the
    hidden layer are kept. The input gradient is then
    ``err * 1^T @ W2 @ diag(slopes) @ W1``.

    Args:
        W2: Second-layer weights ((n3, n2) in this script).
        W1: First-layer weights ((n2, n1) in this script).
        x: Input column vector ((n1, 1) in this script).
        y: Scalar target. If omitted, the module-level ``y`` is used. This
            keeps existing callers working, since the training loop binds
            ``y`` at module scope. Passing it explicitly is preferred.

    Returns:
        A tuple ``(grad, err)``. ``grad`` is dL/dx summed over output rows,
        a 1-D vector with one entry per input. ``err`` is
        dL/dy_hat = -2 * (y - y_hat), NOT the squared loss.

    Raises:
        NameError: if ``y`` is omitted and no module-level ``y`` exists.
    """
    if y is None:
        # Backward-compatible fallback for the implicit global dependency.
        try:
            y = globals()["y"]
        except KeyError:
            raise NameError(
                "forward_input_grad: target 'y' was not passed and no "
                "module-level 'y' is defined"
            ) from None

    # Forward pass.
    z1 = linear_layer(W1, x)
    z2 = leaky_relu(z1)
    # Leaky-ReLU derivative per hidden unit: 1 where active, else 0.01
    # (matches leaky_relu's default alpha).
    slopes = jnp.where(z2 > 0, 1, 0.01)
    y_hat = jnp.sum(linear_layer(W2, z2))
    err = -2 * (y - y_hat)

    # Backward pass: ``W1 * slopes`` scales row i of W1 by slopes[i], i.e.
    # diag(slopes) @ W1, so ``jac`` is the Jacobian d(z3)/d(x).
    jac = W2 @ (W1 * slopes)
    return err * jac.sum(0), err

def w1_grad(W2, W1, x, input_grad):
    """Reconstruct the first-layer weight gradient from the input gradient.

    Assumes ``solve_right_inverse(W1)`` returns R with ``W1 @ R == I``
    (TODO confirm in inv_large). Under that assumption, since dL/dx equals
    delta1^T @ W1 with delta1 = dL/dz1, ``input_grad @ R`` recovers delta1.

    Args:
        W2: Not used here; presumably kept so the signature mirrors w2_grad.
        W1: First-layer weights.
        x: Input column vector ((n1, 1) in this script).
        input_grad: dL/dx, 1-D, as returned by forward_input_grad.

    Returns:
        ``delta1 * x``, broadcast to (len(x), len(delta1)). This is the
        transpose of dL/dW1; the training loop applies ``.T``.
    """
    delta1 = input_grad @ solve_right_inverse(W1)
    return delta1 * x

def w2_grad(W2, W1, x, input_grad):
    """Reconstruct the second-layer weight gradient from the input gradient.

    The steps are:

    1. Undo W1 with its right inverse to get dL/dz1.
    2. Undo the leaky-ReLU slopes to get dL/d(hidden).
    3. Undo W2 with its right inverse to get dL/dz3.
    4. Form the outer product with the recomputed hidden activation.

    Assumes ``solve_right_inverse(W)`` returns R with ``W @ R == I``
    (TODO confirm in inv_large).

    Args:
        W2: Second-layer weights.
        W1: First-layer weights.
        x: Input column vector.
        input_grad: dL/dx, 1-D, as returned by forward_input_grad.

    Returns:
        The transpose of dL/dW2 (the training loop applies ``.T``).
    """
    delta_pre = input_grad @ solve_right_inverse(W1)
    # Hidden activation is recomputed here rather than stored.
    hidden = leaky_relu(linear_layer(W1, x))
    slopes = jnp.where(hidden > 0, 1, 0.01)
    delta_hidden = delta_pre / slopes.T
    delta_out = delta_hidden @ solve_right_inverse(W2)
    return delta_out * hidden



def svd_inv(w):
    """Return the Moore-Penrose pseudo-inverse of ``w`` computed via SVD.

    Works for square, wide (m < n) and tall (m > n) matrices. The previous
    implementation padded Sigma^+ by hand. For tall inputs it built a zeros
    block of shape (m - n, n) and hstacked it with an (n, n) block, which
    crashed unless m == 2 * n.

    Args:
        w: 2-D array of shape (m, n).

    Returns:
        Array of shape (n, m) equal to ``Vh.T @ diag(1 / S) @ U.T``.

    Note:
        Singular values are inverted without thresholding, so a
        rank-deficient ``w`` produces non-finite entries.
    """
    # Thin SVD: U (m, k), S (k,), Vh (k, n) with k = min(m, n). No padding
    # is needed because the discarded factors would only multiply zeros.
    U, S, Vh = jnp.linalg.svd(w, full_matrices=False)
    # Dividing the columns of Vh.T by S is Vh.T @ diag(1 / S) without
    # materialising the diagonal matrix.
    return (Vh.T / S) @ U.T


from jax.example_libraries import optimizers

learning_rate = 1e-4
# Adam triple: opt_init packs parameters into a state, opt_update(step, grads,
# state) returns the next state, get_params unpacks parameters from a state.
opt_init, opt_update, get_params = optimizers.adam(step_size=learning_rate)

def extract_upper_triangular(matrix):
    """Return the strictly-upper-triangular entries (above the diagonal).

    The diagonal is excluded because reconstruct_from_upper_triangular
    re-imposes it as ones. These are therefore the only free parameters.

    Args:
        matrix (jax.numpy.array): Input matrix of shape (m, n).

    Returns:
        jax.numpy.array: 1-D array of the entries with column > row, in
        row-major order.
    """
    n_rows, n_cols = matrix.shape[-2:]
    rows, cols = jnp.triu_indices(n_rows, k=1, m=n_cols)
    return matrix[..., rows, cols].reshape(-1)

def reconstruct_from_upper_triangular(flat_values, shape):
    """Rebuild a full matrix from its strictly-upper-triangular entries.

    This is the inverse of extract_upper_triangular. Entries below the
    diagonal are zero, and the main diagonal is forced to ones.

    Args:
        flat_values (jax.numpy.array): Entries with column > row, in
            row-major order.
        shape (tuple): Target shape (m, n).

    Returns:
        jax.numpy.array: Matrix of shape (m, n) with ones on the diagonal.
    """
    n_rows, n_cols = shape
    upper = jnp.triu_indices(n_rows, k=1, m=n_cols)
    diag = jnp.arange(min(n_rows, n_cols))
    rebuilt = jnp.zeros((n_rows, n_cols)).at[upper].set(flat_values)
    return rebuilt.at[diag, diag].set(1.0)

# The optimizer only tracks the strictly-upper-triangular entries of each
# weight matrix; the diagonal is re-imposed as ones when reconstructing.
W1, W2 = map(extract_upper_triangular, (W1, W2))
opt_state = opt_init((W1, W2))


# Training loop: one Adam step per training example (batch size 1).
for i, (x, y) in enumerate(zip(train_images, train_labels)):
    # get_params() yields flattened strictly-upper-triangular entries;
    # rebuild the full matrices (diagonal fixed to ones) for this step.
    W1 = reconstruct_from_upper_triangular(W1, (n2, n1))
    W2 = reconstruct_from_upper_triangular(W2, (n3, n2))

    x = x.reshape(-1, 1) / 255

    # forward_input_grad is called without a target, so it reads this loop's
    # module-level ``y``. Its second return value is dL/dy_hat, not the
    # squared loss, despite the name ``loss``.
    input_grad, loss = forward_input_grad(W2, W1, x)

    # Weight gradients reconstructed from the input gradient, transposed to
    # the (rows, cols) layout of W1 / W2.
    W1_grad = w1_grad(W2, W1, x, input_grad).T
    W2_grad = w2_grad(W2, W1, x, input_grad).T

    # Only the free (strictly-upper) entries are optimized.
    W1_grad = extract_upper_triangular(W1_grad)
    W2_grad = extract_upper_triangular(W2_grad)

    # Adam's bias correction depends on the step index, so pass the actual
    # iteration count from enumerate.
    opt_state = opt_update(i, (W1_grad, W2_grad), opt_state)

    W1, W2 = get_params(opt_state)

# NOTE: after the loop, W1 / W2 are still the flattened parameter vectors;
# use reconstruct_from_upper_triangular to obtain full matrices.
