import numpy as np
import cupy as cp
import weights
import inv_mean
import matplotlib.pyplot as plt
import time
# Load the stacked frames from disk and move them to the GPU.
# NOTE(review): V.npy is assumed to be 4-D, (n0, n1, n2, N), with the last
# axis indexing the N frames -- inferred from the indexing below; confirm.
V = np.load("V.npy")
print(V.shape)
V = cp.array(V)

# Preview every frame on a 3-row grid. The grid width is derived from the
# actual frame count rather than hard-coding 18 frames / 6 columns. This way
# a stack with fewer than 18 frames no longer raises IndexError here, and
# N == 18 reproduces the original 3 x 6 layout.
n_frames = V.shape[-1]
n_cols = (n_frames + 2) // 3  # ceil(n_frames / 3) without importing math
for i in range(n_frames):
    plt.subplot(3, n_cols, i + 1)
    # .get() copies the CuPy array back to host memory for matplotlib.
    # NOTE(review): each slice is 3-D, so imshow only accepts it if n2 is
    # 3 or 4 (RGB/RGBA) -- confirm the channel layout of V.npy.
    plt.imshow(V[:, :, :, i].get())
plt.show()

# Collapse each row of the grid above into one image by summing its frames
# along the last (frame) axis.
for i in range(3):
    plt.subplot(1, 3, i + 1)
    plt.imshow(cp.sum(V[:, :, :, n_cols * i:n_cols * (i + 1)], -1).get())
plt.show()
n0, n1, n2, N = V.shape

# Fixed-point iteration. Each pass does the following for every frame j:
# compute a weight row over all N frames, then replace frame j with the
# weighted invariant mean of the current stack. Every frame of a pass is
# computed from the same V (Jacobi-style); V is only replaced once the whole
# pass has finished.
V_next = V.copy()
# NOTE(review): presumably the image centre for 64 x 64 frames -- confirm
# against n0/n1 if the input size changes.
center = cp.array([[32], [32]])
MAX_ITER = 3
# Subplot grid for the per-frame means. Use ceil(N / 3) columns so that
# j + 1 is always a valid subplot index. The previous int(N / 3) raised
# ValueError whenever N was not a multiple of 3.
n_cols = (N + 2) // 3
W = cp.zeros((N, N))
for idx in range(MAX_ITER):
    # Row j holds the weights used to build the mean for frame j.
    Weights = cp.zeros((N, N))
    for j in range(N):
        # Previously these helpers were called as bare names, which are not
        # defined anywhere in this script (NameError). They are now
        # qualified by module.
        # NOTE(review): assumes compute_weights_j_cupy lives in `weights`
        # and invariant_mean_batch_sigma_schedule_cupy in `inv_mean`, the
        # only project modules imported here -- confirm.
        # Only the weight row is used; the other returned diagnostics are
        # discarded.
        Weights[j], *_ = weights.compute_weights_j_cupy(
            V,
            sigma=32,
            center=center,
            t=1,
            j=j,
            MAX_ITER=115,
            gamma=2,
            whichMotionModel='euclidean',
        )
        print(idx, j, Weights[j])
        # The fifth returned value is the mean image; the rest are unused.
        _, _, _, _, mu, _ = inv_mean.invariant_mean_batch_sigma_schedule_cupy(
            V=V,
            sigma=14,
            sigma_schedule=True,
            center=center,
            t=1,
            MAX_ITER=30,
            weights=Weights[j],
            generate_gif=False,
        )
        V_next[:, :, :, j] = mu
        plt.subplot(3, n_cols, j + 1)
        plt.imshow(mu.get())
        plt.title(str(j))
    W = Weights
    plt.show()
    # Commit the pass: the next iteration works on the updated stack.
    V = V_next.copy()
    plt.imshow(Weights.get())
    plt.colorbar()
    plt.show()