import jax.numpy as jnp
import jax
from inv_large import solve_right_inverse

# Example signal plus the two kernels fed to ``forward``:
# w2 is the 3-tap kernel of the first layer, w1 the 2-tap kernel of the second.
x = jnp.arange(1.0, 6.0)
w2 = jnp.arange(1.0, 4.0)
w1 = jnp.array([1.0, -2.0])
print(jnp.shape(x))

def conv_3(w,x):
    """Valid 1-D cross-correlation of a 3-tap kernel with a length-5 signal.

    Output entry i is ``w[0]*x[i] + w[1]*x[i+1] + w[2]*x[i+2]`` for
    i = 0, 1, 2 (the kernel is not flipped). Only w[0:3] and x[0:5] are read.
    """
    return jnp.array([w[0] * x[i] + w[1] * x[i + 1] + w[2] * x[i + 2]
                      for i in range(3)])

def conv_2(w,x):
    """Valid 1-D cross-correlation of a 2-tap kernel with a length-3 signal.

    Output entry i is ``w[0]*x[i] + w[1]*x[i+1]`` for i = 0, 1.
    Only w[0:2] and x[0:3] are read.
    """
    return jnp.array([w[0] * x[i] + w[1] * x[i + 1] for i in range(2)])

def grad_input_3(w,x):
    """Jacobian of ``conv_3`` output with respect to its signal input.

    Entry [i, j] equals d out[i] / d x[j] = w[j - i] (zero outside the band),
    giving a (3, 5) banded matrix. ``x`` is accepted but not used.
    """
    a, b, c = w[0], w[1], w[2]
    return jnp.array([[a, b, c, 0, 0],
                      [0, a, b, c, 0],
                      [0, 0, a, b, c]])

def grad_weight_3(w,x):
    """Jacobian of ``conv_3`` output with respect to the kernel.

    Entry [i, k] is x[i + k] (= d out[i] / d w[k]), a (3, 3) matrix built
    from x[0:5]. ``w`` is accepted but not used.
    """
    return jnp.array([[x[row + tap] for tap in range(3)] for row in range(3)])
 
def grad_weight_2(w,x):
    """Jacobian of ``conv_2`` output with respect to the kernel.

    Entry [i, k] is x[i + k] (= d out[i] / d w[k]), a (2, 2) matrix built
    from x[0:3]. ``w`` is accepted but not used.
    """
    return jnp.array([[x[row + tap] for tap in range(2)] for row in range(2)])


def grad_input_2(w,x):
    """Jacobian of ``conv_2`` output with respect to its signal input.

    Entry [i, j] equals d out[i] / d x[j] = w[j - i] (zero outside the band),
    giving a (2, 3) banded matrix. ``x`` is accepted but not used.
    """
    a, b = w[0], w[1]
    return jnp.array([[a, b, 0],
                      [0, a, b]])

def forward(w2,w1,x):
    """Scalar objective: sum of the two-layer output conv_2(w1, conv_3(w2, x))."""
    return conv_2(w1, conv_3(w2, x)).sum()

def grad_input(w2,w1,x):
    """Hand-assemble the gradient of ``forward`` with respect to ``x``.

    Chain rule: d(sum z2)/dx = 1^T @ (dz2/dz1) @ (dz1/dx), where dz2/dz1 is
    ``grad_input_2`` (2x3) and dz1/dx is ``grad_input_3`` (3x5).

    Returns:
        A (1, 5) row vector. It holds the same values as
        ``jax.grad(forward, 2)``, which has shape (5,).
    """
    A = grad_input_3(w2, x)   # dz1/dx, shape (3, 5)
    # dz2/dz1, shape (2, 3). grad_input_2 ignores its second argument, so
    # passing x here (rather than conv_3's output) does not affect the result.
    B = grad_input_2(w1, x)
    # Upstream gradient of ``.sum()`` w.r.t. z2: one 1 per entry of z2.
    upstream = jnp.ones((1, B.shape[0]))
    return upstream @ B @ A

def grad_wrt_weights_3(w2,w1,x):
    """Recover d(forward)/d(w2) from the input gradient via a right inverse.

    ``grad_input`` yields g_x = g_z1 @ A with A = dz1/dx (3x5). Multiplying by
    a right inverse A_inv (A @ A_inv = I) recovers g_z1 = dL/dz1. Then
    g_z1 @ dz1/dw2 (``grad_weight_3``) is the weight gradient.

    Returns:
        A (1, 3) row vector. Compare with ``jax.grad(forward, 0)``, which has
        shape (3,).
    """
    inp_grad = grad_input(w2, w1, x)   # dL/dx, shape (1, 5)
    A = grad_input_3(w2, x)            # dz1/dx, shape (3, 5)
    # NOTE(review): assumes solve_right_inverse returns R with A @ R == I
    # (shape (5, 3)), which needs A to have full row rank. Confirm in inv_large.
    A_inv = solve_right_inverse(A)
    dL_dz1 = inp_grad @ A_inv          # (1, 3)
    return dL_dz1 @ grad_weight_3(w2, x)

def grad_wrt_weights_2(w2,w1,x):
    """Recover d(forward)/d(w1) by peeling both Jacobians off the input gradient.

    inp_grad = 1^T @ B @ A, with B = dz2/dz1 (2x3) and A = dz1/dx (3x5).
    Right-multiplying by A_inv and then B_inv recovers dL/dz2. That is then
    multiplied by dz2/dw1 (``grad_weight_2`` evaluated on the layer-1
    activations z1).

    Returns:
        A (1, 2) row vector. Compare with ``jax.grad(forward, 1)``, which has
        shape (2,).
    """
    # Layer-1 activations: the signal that conv_2 (and hence dz2/dw1) sees.
    z1 = conv_3(w2, x)

    inp_grad = grad_input(w2, w1, x)   # dL/dx, shape (1, 5)
    A = grad_input_3(w2, x)            # dz1/dx, shape (3, 5)
    B = grad_input_2(w1, x)            # dz2/dz1, shape (2, 3); second arg unused

    # NOTE(review): assumes solve_right_inverse returns R with M @ R == I
    # (full row rank required). Confirm in inv_large.
    A_inv = solve_right_inverse(A)
    B_inv = solve_right_inverse(B)

    return inp_grad @ A_inv @ B_inv @ grad_weight_2(w1, z1)


# print(jax.grad(forward,2)(w2,w1,x))
# print(jax.jacfwd(forward,-1)(w2,w1,x))
# C = grad_input(w2,w1,x)
# print(C)

# Print the JAX autodiff gradient followed by the hand-assembled one,
# first for w2 (argnum 0), then for w1 (argnum 1).
for _argnum, _manual in ((0, grad_wrt_weights_3), (1, grad_wrt_weights_2)):
    print(jax.grad(forward, _argnum)(w2, w1, x))
    print(_manual(w2, w1, x))
# A = grad_input_3(w2,x)
# print(jnp.ones(3)@A)
# B = grad_input_2(w1,x)


# print(A.shape,B.shape)
# print(grad_input(w2,w1,x))
# print(A.shape)
# z = jnp.array([1,1.,1.])@A
# print(z.shape,jnp.array([1,1.,1.]).shape,A.shape)

# print(z.reshape(1,5)@solve_right_inverse(A))
