import time
import numpy as np
import matplotlib.pyplot as plt
import time

from tqdm import tqdm

import cupy as cp

import imageio
import os

def _circular_flip_2d(x):
    """Return ``y`` with ``y[i, j, ...] = x[-i mod n0, -j mod n1, ...]``.

    This is the index reversal that the adjoint of a circular convolution
    applies to its kernel. It is mathematically identical to applying an
    orthonormal 2-D FFT twice over axes 0 and 1, but exact and far cheaper.
    """
    return cp.roll(cp.flip(x, axis=(0, 1)), shift=(1, 1), axis=(0, 1))


def _rot_stack(phi, derivative=False):
    """Stack per-slice 2x2 rotation matrices (or their d/dphi) into shape (2, 2, N).

    With ``derivative=False`` slice ``[:, :, k]`` is
    ``[[cos phi_k, -sin phi_k], [sin phi_k, cos phi_k]]``.
    With ``derivative=True`` it is the element-wise derivative of that matrix
    w.r.t. ``phi_k``: ``[[-sin, -cos], [cos, -sin]]``.
    ``phi`` is expected to be 1-D of length N (the vstack/.T construction
    below relies on that).
    """
    s = cp.sin(phi)
    c = cp.cos(phi)
    if derivative:
        row0, row1 = (-s, -c), (c, -s)
    else:
        row0, row1 = (c, -s), (s, c)
    # vstack(...).T -> (N, 2) rows; stacking on axis 1 -> (N, 2, 2).
    mats = cp.stack((cp.vstack(row0).T, cp.vstack(row1).T), axis=1)
    return cp.moveaxis(mats, 0, -1)


def _warp_grad_fields(V, tau_u, tau_v):
    """Return (dV/dtau_u, dV/dtau_v) of the bicubic warp, batch axis moved back to last.

    The batch helpers take the slice axis first; inputs and outputs here keep
    it last, matching ``V``'s (n0, n1, c, N) layout.
    """
    V_b = cp.moveaxis(V, -1, 0)
    tu_b = cp.moveaxis(tau_u, -1, 0)
    tv_b = cp.moveaxis(tau_v, -1, 0)
    d_u = dimage_interpolation_bicubic_dtau1_batch_cupy(V_b, tu_b, tv_b)
    d_v = dimage_interpolation_bicubic_dtau2_batch_cupy(V_b, tu_b, tv_b)
    return cp.moveaxis(d_u, 0, -1), cp.moveaxis(d_v, 0, -1)


def _affine_grads_from_warp(dobj_dwarped, V_dot_u, V_dot_v, n0_vec, n1_vec, N):
    """Chain-rule the objective gradient through the warp onto the affine (A, b).

    Assumes (TODO confirm against ``affine_to_vf_batch_cupy``) that
    ``tau_u = A00*x0 + A01*x1 + b0`` and ``tau_v = A10*x0 + A11*x1 + b1``
    in centred coordinates ``x0 = n0_vec``, ``x1 = n1_vec``.
    Returns ``grad_A`` of shape (2, 2, N) and ``grad_b`` of shape (2, N).
    """
    # Contract over the channel axis: per-pixel gradient w.r.t. each field.
    tau_u_dot = cp.sum(dobj_dwarped * V_dot_u, 2)  # (n0, n1, N)
    tau_v_dot = cp.sum(dobj_dwarped * V_dot_v, 2)  # (n0, n1, N)

    u_rowsum = cp.sum(tau_u_dot, 1)  # (n0, N)
    u_colsum = cp.sum(tau_u_dot, 0)  # (n1, N)
    v_rowsum = cp.sum(tau_v_dot, 1)  # (n0, N)
    v_colsum = cp.sum(tau_v_dot, 0)  # (n1, N)

    grad_A = cp.zeros((2, 2, N))
    grad_A[0, 0, :] = cp.matmul(n0_vec, u_rowsum)
    grad_A[1, 0, :] = cp.matmul(n0_vec, v_rowsum)
    grad_A[0, 1, :] = cp.matmul(n1_vec, u_colsum)
    grad_A[1, 1, :] = cp.matmul(n1_vec, v_colsum)

    grad_b = cp.zeros((2, N))
    grad_b[0] = cp.sum(u_rowsum, axis=0)
    grad_b[1] = cp.sum(v_rowsum, axis=0)
    return grad_A, grad_b


def _rotate_shift_warp(V, phi, b, center, n0, n1):
    """Warp every slice of ``V`` by rotation ``phi`` about ``center`` plus shift ``b``.

    ``corr = (I - A) @ center`` makes the rotation pivot on ``center``.
    The ``[:, 0, :]`` index below requires ``cp.dot(temp, center)`` to be
    3-D, i.e. ``center`` appears to be a (2, 1) column (TODO confirm with
    callers).
    """
    A = _rot_stack(phi)
    temp = cp.moveaxis(cp.eye(2)[..., cp.newaxis] - A, -1, 0)  # (N, 2, 2)
    corr = cp.moveaxis(cp.dot(temp, center), 0, -1)[:, 0, :]   # (2, N)
    tau_u, tau_v = affine_to_vf_batch_cupy(A, b + corr, n0, n1)
    warped = image_interpolation_bicubic_batch_cupy(
        cp.moveaxis(V, -1, 0),
        cp.moveaxis(tau_u, -1, 0),
        cp.moveaxis(tau_v, -1, 0),
    )
    return cp.moveaxis(warped, 0, -1)


def image_registration_obj_returner(center, V, V_j, cur_V, cur_V_j, tau_u, tau_v,
                                    tau_u_j, tau_v_j, all_phi, all_phi_j, all_b,
                                    all_b_j, sigmas, t, phi_step_divisor=100):
    """Take one gradient step on both rigid transforms and return the new objective.

    The objective per slice k is
    ``0.5 * sum_{x,y,ch} (g_k * (warp(V_j) - warp(V)))**2``, where ``*`` is
    ``cconv_fourier_cupy`` (presumably circular convolution) with a
    normalised Gaussian kernel ``g_k`` of width ``sigmas[k]``.
    The gradient is taken w.r.t. each slice's rotation angle and shift, for
    both the ``V`` and ``V_j`` stacks, at the current parameters. One step of
    size ``t`` is applied (``t / phi_step_divisor`` for the angles), both
    stacks are re-warped, and the objective is re-evaluated.
    The updated parameters are NOT returned, only the objective values, so
    this is suited to evaluating a trial step size.

    Args:
        center: rotation pivot; indexed as ``center[0]``, ``center[1]`` and
            used in ``cp.dot`` in a way that appears to require a (2, 1)
            column (TODO confirm).
        V, V_j: image stacks of shape (n0, n1, c, N).
        cur_V, cur_V_j: the stacks warped at the current parameters
            (presumably same shape as ``V``).
        tau_u, tau_v, tau_u_j, tau_v_j: current deformation fields, slice
            axis last.
        all_phi, all_phi_j: 1-D rotation angles, one per slice.
        all_b, all_b_j: shifts, presumably of shape (2, N).
        sigmas: Gaussian widths; at most N entries. Missing entries leave
            the corresponding kernel slice zero.
        t: gradient step size for the shifts.
        phi_step_divisor: the angle step is ``t / phi_step_divisor``
            (default 100, as originally hard-coded).

    Returns:
        Array of shape (N,): the objective per slice after the step.
    """
    n0, n1, c, N = V.shape
    n0_vec = cp.arange(n0) - center[0]
    n1_vec = cp.arange(n1) - center[1]

    # One normalised Gaussian kernel per slice, repeated across channels.
    g_stack = cp.zeros((n0, n1, c, N))
    for k, sigma in enumerate(sigmas):
        g = gaussian_filter_2d_cupy(n0, n1, sigma)
        g = g / cp.sum(g)
        g_stack[:, :, :, k] = cp.repeat(g[..., cp.newaxis], c, axis=2)

    # Filtered residual at the current parameters.
    # NOTE(review): this is per-slice (3-D input) while the final residual
    # below passes the full 4-D stack; confirm cconv_fourier_cupy convolves
    # over axes (0, 1) in both cases.
    FWres = cp.zeros((n0, n1, c, N))
    for i in range(N):
        FWres[:, :, :, i] = cconv_fourier_cupy(
            g_stack[:, :, :, i], cur_V_j[:, :, :, i] - cur_V[:, :, :, i])

    # Gradient of the objective w.r.t. the warped images: adjoint
    # convolution (kernel index-reversed) applied to the filtered residual.
    dphi_dV_j = cconv_fourier_cupy(_circular_flip_2d(g_stack), FWres)
    dphi_dV = -dphi_dV_j  # residual is (V_j - V), so the V side has opposite sign

    # Chain rule through the warps onto the affine parameters.
    grad_A, grad_b = _affine_grads_from_warp(
        dphi_dV, *_warp_grad_fields(V, tau_u, tau_v), n0_vec, n1_vec, N)
    grad_A_j, grad_b_j = _affine_grads_from_warp(
        dphi_dV_j, *_warp_grad_fields(V_j, tau_u_j, tau_v_j), n0_vec, n1_vec, N)

    # Then onto the angles: dObj/dphi = sum_{rc} (dA/dphi)_{rc} * dObj/dA_{rc}.
    grad_phi = cp.sum(_rot_stack(all_phi, derivative=True) * grad_A, axis=(0, 1))
    grad_phi_j = cp.sum(_rot_stack(all_phi_j, derivative=True) * grad_A_j, axis=(0, 1))

    # One gradient step (angles use a smaller step).
    new_b = all_b - t * grad_b
    new_phi = all_phi - (t / phi_step_divisor) * grad_phi
    new_b_j = all_b_j - t * grad_b_j
    new_phi_j = all_phi_j - (t / phi_step_divisor) * grad_phi_j

    # Re-warp both stacks with the stepped parameters and re-evaluate.
    new_V = _rotate_shift_warp(V, new_phi, new_b, center, n0, n1)
    new_V_j = _rotate_shift_warp(V_j, new_phi_j, new_b_j, center, n0, n1)
    FWres = cconv_fourier_cupy(g_stack, new_V_j - new_V)
    return 0.5 * cp.sum(FWres ** 2, axis=(0, 1, 2))