import time
import numpy as np
import matplotlib.pyplot as plt
import time

from tqdm import tqdm

import cupy as cp

import imageio
import os

def objective_value_batch_sigma_returner_cupy( V, mu, sigma, phis, bs, center, t, weights, whichMotionModel = 'euclidean'):
  '''
  Score a smoothing level: take one gradient step with Gaussian width `sigma`
  and return the resulting (unsmoothed) invariant-mean objective.

  Recommended: function to be used only within an invariant_mean function,
  which calls it once per candidate sigma and keeps the sigma with the
  smallest returned value.

  Inputs:
    V ---------------- set of images of shape (n0,n1,c,N)
    mu --------------- current mean image of shape (n0,n1,c); only used to form
                       the residual that drives the gradient step
    sigma ------------ smoothing level handed to gaussian_filter_2d_cupy
    phis ------------- current rotation angles, shape (N,); not modified
    bs --------------- current shifts, shape (2,N); not modified
    center ----------- origin w.r.t. which the rotations are applied; the
                       `[:, 0, :]` indexing applied after cp.dot(temp, center)
                       implies a 2-d column such as shape (2,1) -- TODO confirm
    t ---------------- learning rate (shifts use t, angles use t/100)
    weights ---------- per-image weights, shape (N,)
    whichMotionModel - currently unused; only rotation + shift is implemented

  Output:
    obj_val ---------- 0.5 * sum_i weights[i] * ||cur_V_i - mu_new||^2 after the
                       step, where mu_new is the weighted mean of the re-warped
                       images. NOTE: unlike the main optimization loop, this
                       residual is NOT smoothed by the Gaussian, presumably so
                       that scores for different sigmas are comparable.
  '''
  n0, n1, c, N = V.shape

  # pixel coordinates relative to the rotation center
  n0_vec = cp.arange(n0) - center[0]
  n1_vec = cp.arange(n1) - center[1]

  # normalized Gaussian smoothing kernel used for the gradient step
  g = gaussian_filter_2d_cupy(n0, n1, sigma)
  g = g / cp.sum(g)

  def _warp_fields(phi, b):
    # Transformation fields for rotations by `phi` about `center` plus shifts `b`.
    sin_phi = cp.sin(phi)
    cos_phi = cp.cos(phi)
    # A[:, :, i] = [[cos, -sin], [sin, cos]], shape (2,2,N)
    row0 = cp.vstack((cos_phi, -sin_phi)).T
    row1 = cp.vstack((sin_phi, cos_phi)).T
    A = cp.moveaxis(cp.stack((row0, row1), axis = 1), 0, -1)
    # (I - A) @ center is added to the shift so the rotation is about `center`
    temp = cp.moveaxis(cp.eye(2)[..., cp.newaxis] - A, -1, 0)
    corr = cp.moveaxis(cp.dot(temp, center), 0, -1)[:, 0, :]
    return affine_to_vf_batch_cupy(A, b + corr, n0, n1)

  all_phi = phis
  all_b   = bs

  # batch-first views used by the interpolation routines
  V_batch = cp.moveaxis(V, -1, 0)
  tau_u, tau_v = _warp_fields(all_phi, all_b)
  tau_u_batch = cp.moveaxis(tau_u, -1, 0)
  tau_v_batch = cp.moveaxis(tau_v, -1, 0)

  # current warped images, shape (n0,n1,c,N), and their smoothed residual w.r.t. `mu`
  cur_V = cp.moveaxis(image_interpolation_bicubic_batch_cupy(V_batch, tau_u_batch, tau_v_batch), 0, -1)
  FWres = cconv_fourier_cupy(g[..., cp.newaxis, cp.newaxis], cur_V - mu[..., cp.newaxis])

  # derivatives of the warped images w.r.t. the two warp-field components, shape (n0,n1,c,N)
  V_dot_u = cp.moveaxis(dimage_interpolation_bicubic_dtau1_batch_cupy(V_batch, tau_u_batch, tau_v_batch), 0, -1)
  V_dot_v = cp.moveaxis(dimage_interpolation_bicubic_dtau2_batch_cupy(V_batch, tau_u_batch, tau_v_batch), 0, -1)
  dphi_dV = cconv_fourier_cupy(dsp_flip_cupy(g)[..., cp.newaxis, cp.newaxis], FWres)
  tau_u_dot = cp.sum(dphi_dV * V_dot_u, 2) # shape (n0, n1, N)
  tau_v_dot = cp.sum(dphi_dV * V_dot_v, 2) # shape (n0, n1, N)

  tau_u_dot_rowsum = cp.sum(tau_u_dot, 1) # shape (n0, N)
  tau_u_dot_colsum = cp.sum(tau_u_dot, 0) # shape (n1, N)
  tau_v_dot_rowsum = cp.sum(tau_v_dot, 1) # shape (n0, N)
  tau_v_dot_colsum = cp.sum(tau_v_dot, 0) # shape (n1, N)

  # gradients w.r.t. the 2x2 matrices (coordinates taken relative to `center`) and the shifts
  grad_A = cp.zeros((2, 2, N))
  grad_A[0, 0, :] = cp.matmul(n0_vec, tau_u_dot_rowsum)
  grad_A[1, 0, :] = cp.dot(n0_vec, tau_v_dot_rowsum)
  grad_A[0, 1, :] = cp.matmul(n1_vec, tau_u_dot_colsum)
  grad_A[1, 1, :] = cp.matmul(n1_vec, tau_v_dot_colsum)

  grad_b = cp.zeros((2, N))
  grad_b[0] = cp.sum(tau_u_dot_rowsum, axis = 0)
  grad_b[1] = cp.sum(tau_v_dot_rowsum, axis = 0)

  # chain rule through dA/dphi = [[-sin, -cos], [cos, -sin]], shape (2,2,N)
  sphi = cp.sin(all_phi)
  cphi = cp.cos(all_phi)
  r1 = cp.vstack((-sphi, -cphi)).T
  r2 = cp.vstack((cphi, -sphi)).T
  G = cp.moveaxis(cp.stack((r1, r2), axis = 1), 0, -1)
  grad_phi = cp.sum(G * grad_A, axis = (0, 1))

  # one gradient-descent step (angles use a 100x smaller step)
  all_b = all_b - t * grad_b
  all_phi = all_phi - (t / 100) * grad_phi

  # re-warp the images with the updated parameters
  tau_u, tau_v = _warp_fields(all_phi, all_b)
  cur_V = cp.moveaxis(image_interpolation_bicubic_batch_cupy(V_batch, cp.moveaxis(tau_u, -1, 0), cp.moveaxis(tau_v, -1, 0)), 0, -1)

  # new weighted mean of the re-warped images, shape (n0,n1,c)
  mu = cp.sum(cur_V * weights[cp.newaxis, cp.newaxis, cp.newaxis, ...], -1) / cp.sum(weights)

  # unsmoothed residual -> weighted objective
  F2 = (cur_V - mu[..., cp.newaxis]) ** 2
  obj_val = .5 * cp.dot(F2.sum(axis = tuple(range(F2.ndim - 1))), weights)
  return obj_val

def _write_gif_and_remove_frames(files, gif_path = 'new.gif'):
  '''
  Assemble the PNG frames listed in `files` (in order) into an animated GIF at
  `gif_path`, then delete the frame files from disk.
  '''
  with imageio.get_writer(gif_path, mode='I') as writer:
    for filename in files:
      writer.append_data(imageio.imread(filename))
  for filename in set(files):
    os.remove(filename)


def invariant_mean_batch_sigma_schedule_cupy( V, sigma, sigma_schedule , center, t, MAX_ITER, weights, whichMotionModel = 'euclidean', generate_gif = False, tol = 0.005, show_progress = False, gif_path = 'new.gif' ):
  '''
  Jointly estimate a rotation + shift for every image and their weighted
  invariant mean by gradient descent on a Gaussian-smoothed objective.

  Inputs:
    V ---------------- set of images of shape (n0,n1,c=color_channel,N=batch_size)
    sigma ------------ initial smoothing level (width passed to gaussian_filter_2d_cupy).
                       Must be an int when `sigma_schedule` is set, since it is
                       used as range(1, sigma).
    sigma_schedule --- if truthy, after every step the smoothing level is re-chosen
                       among 1 .. sigma-1 as the one whose one-step objective
                       (objective_value_batch_sigma_returner_cupy) is smallest
    center ----------- origin w.r.t. which transformations are applied, typically
                       the image center; the `[:, 0, :]` indexing applied after
                       cp.dot(temp, center) implies a 2-d column such as shape
                       (2,1) -- TODO confirm
    t ---------------- learning rate of gradient descent (angles use t/100)
    MAX_ITER --------- maximum number of gradient-descent iterations
    weights ---------- per-image weights, a 1-d vector of shape (N,)
    whichMotionModel - currently unused; only rotation + shift is implemented
    generate_gif ----- if truthy, save a frame of the mean at every iteration and
                       write them to `gif_path` (the frame files are deleted)
    tol -------------- stop early once the objective is <= tol (default 0.005)
    show_progress ---- if truthy, display the current mean with plt.show() after
                       every step (blocks on interactive matplotlib backends)
    gif_path --------- output file of the GIF (default 'new.gif')

  Outputs:
    tau_u ------------ transformation vector field, shape (n0,n1,N) as initialized
    tau_v ------------ transformation vector field, shape (n0,n1,N) as initialized
    all_phi ---------- rotation angles, shape (N,)
    all_b ------------ shifts, shape (2,N)
    mu --------------- the invariant mean, shape (n0,n1,c=color_channels)
    obj_vals --------- objective after each gradient step, shape (MAX_ITER,);
                       entries after an early stop remain inf
  '''
  n0, n1, c, N = V.shape

  # pixel coordinates relative to the rotation center
  n0_vec = cp.arange(n0) - center[0]
  n1_vec = cp.arange(n1) - center[1]

  # normalized Gaussian smoothing kernel used in the cost function
  g = gaussian_filter_2d_cupy(n0, n1, sigma)
  g = g / cp.sum(g)

  # every image starts from the identity transformation field
  tau_u_Id, tau_v_Id = identity_vf_cupy(n0, n1)
  tau_u = cp.zeros((n0, n1, N))
  tau_v = cp.zeros((n0, n1, N))
  tau_u[...] = tau_u_Id[..., cp.newaxis]
  tau_v[...] = tau_v_Id[..., cp.newaxis]

  # initial optimization parameters (rotation angles and shifts)
  all_phi = cp.zeros((N,))
  all_b   = cp.zeros((2, N))

  # obj_vals[idx] holds the objective after gradient step idx
  obj_vals = cp.inf * cp.ones((MAX_ITER,))

  # initial mean (just the unweighted euclidean mean)
  mu = cp.sum(V, 3) / N

  # batch-first view of the images used by the interpolation routines
  V_batch = cp.moveaxis(V, -1, 0)

  # initial warped images and the smoothed residual that drives the first step
  cur_V = cp.moveaxis(image_interpolation_bicubic_batch_cupy(V_batch, cp.moveaxis(tau_u, -1, 0), cp.moveaxis(tau_v, -1, 0)), 0, -1)
  FWres = cconv_fourier_cupy(g[..., cp.newaxis, cp.newaxis], cur_V - mu[..., cp.newaxis])

  files = []
  for idx in tqdm(range(MAX_ITER)):
      if generate_gif:
          # one frame per iteration showing the current mean; a fresh figure per
          # frame keeps earlier images/colorbars out of the saved PNG
          name = str(idx) + '.png'
          fig = plt.figure()
          plt.imshow(mu.get())
          plt.title(name)
          fig.savefig(name)
          plt.close(fig)
          files.append(name)

      # derivatives of the warped images w.r.t. the warp fields, shape (n0,n1,c,N)
      tau_u_batch = cp.moveaxis(tau_u, -1, 0)
      tau_v_batch = cp.moveaxis(tau_v, -1, 0)
      V_dot_u = cp.moveaxis(dimage_interpolation_bicubic_dtau1_batch_cupy(V_batch, tau_u_batch, tau_v_batch), 0, -1)
      V_dot_v = cp.moveaxis(dimage_interpolation_bicubic_dtau2_batch_cupy(V_batch, tau_u_batch, tau_v_batch), 0, -1)
      dphi_dV = cconv_fourier_cupy(dsp_flip_cupy(g)[..., cp.newaxis, cp.newaxis], FWres)
      tau_u_dot = cp.sum(dphi_dV * V_dot_u, 2) # shape (n0, n1, N)
      tau_v_dot = cp.sum(dphi_dV * V_dot_v, 2) # shape (n0, n1, N)

      tau_u_dot_rowsum = cp.sum(tau_u_dot, 1) # shape (n0, N)
      tau_u_dot_colsum = cp.sum(tau_u_dot, 0) # shape (n1, N)
      tau_v_dot_rowsum = cp.sum(tau_v_dot, 1) # shape (n0, N)
      tau_v_dot_colsum = cp.sum(tau_v_dot, 0) # shape (n1, N)

      # gradients w.r.t. the 2x2 matrices (coordinates relative to `center`) and the shifts
      grad_A = cp.zeros((2, 2, N))
      grad_A[0, 0, :] = cp.matmul(n0_vec, tau_u_dot_rowsum)
      grad_A[1, 0, :] = cp.dot(n0_vec, tau_v_dot_rowsum)
      grad_A[0, 1, :] = cp.matmul(n1_vec, tau_u_dot_colsum)
      grad_A[1, 1, :] = cp.matmul(n1_vec, tau_v_dot_colsum)

      grad_b = cp.zeros((2, N))
      grad_b[0] = cp.sum(tau_u_dot_rowsum, axis = 0)
      grad_b[1] = cp.sum(tau_v_dot_rowsum, axis = 0)

      # chain rule through dA/dphi = [[-sin, -cos], [cos, -sin]], shape (2,2,N)
      sphi = cp.sin(all_phi)
      cphi = cp.cos(all_phi)
      r1 = cp.vstack((-sphi, -cphi)).T
      r2 = cp.vstack((cphi, -sphi)).T
      G = cp.moveaxis(cp.stack((r1, r2), axis = 1), 0, -1)
      grad_phi = cp.sum(G * grad_A, axis = (0, 1))

      # gradient-descent update (angles use a 100x smaller step)
      all_b = all_b - t * grad_b
      all_phi = all_phi - (t / 100) * grad_phi

      # new rotation matrices A[:, :, i] = [[cos, -sin], [sin, cos]], shape (2,2,N)
      sphi = cp.sin(all_phi)
      cphi = cp.cos(all_phi)
      r1 = cp.vstack((cphi, -sphi)).T
      r2 = cp.vstack((sphi, cphi)).T
      all_A = cp.moveaxis(cp.stack((r1, r2), axis = 1), 0, -1)

      # (I - A) @ center is added to the shift so the rotation is about `center`
      temp = cp.moveaxis(cp.eye(2)[..., cp.newaxis] - all_A, -1, 0)
      all_corr = cp.moveaxis(cp.dot(temp, center), 0, -1)[:, 0, :]
      tau_u, tau_v = affine_to_vf_batch_cupy(all_A, all_b + all_corr, n0, n1)
      cur_V = cp.moveaxis(image_interpolation_bicubic_batch_cupy(V_batch, cp.moveaxis(tau_u, -1, 0), cp.moveaxis(tau_v, -1, 0)), 0, -1)

      # new weighted mean of the warped images, shape (n0,n1,c)
      mu = cp.sum(cur_V * weights[cp.newaxis, cp.newaxis, cp.newaxis, ...], -1) / cp.sum(weights)

      if show_progress:
          fig = plt.figure()
          plt.imshow(mu.get())
          plt.colorbar()
          plt.show()
          plt.close(fig)

      if sigma_schedule:
          # NOTE(review): candidates are 1 .. sigma-1; the initial `sigma` itself is
          # only used when that range is empty -- confirm whether
          # range(1, sigma + 1) was intended.
          least_obj = np.inf
          sigma_ = sigma
          for sig in range(1, sigma):
              obj = objective_value_batch_sigma_returner_cupy( V, mu.copy(), sig, all_phi.copy(), all_b.copy(), center, t, weights, whichMotionModel = 'euclidean')
              if obj <= least_obj:
                  least_obj = obj
                  sigma_ = sig
          g = gaussian_filter_2d_cupy(n0, n1, sigma_)
          g = g / cp.sum(g)

      # smoothed residual for the objective and the next gradient step
      FWres = cconv_fourier_cupy(g[..., cp.newaxis, cp.newaxis], cur_V - mu[..., cp.newaxis])
      F2 = FWres ** 2
      obj_vals[idx] = .5 * cp.dot(F2.sum(axis = tuple(range(FWres.ndim - 1))), weights)

      if obj_vals[idx] <= tol:
          break

  if generate_gif:
      _write_gif_and_remove_frames(files, gif_path)

  return tau_u, tau_v, all_phi, all_b, mu, obj_vals