import time
import numpy as np
import matplotlib.pyplot as plt
import time

from tqdm import tqdm

import cupy as cp

import imageio
import os

def affine_to_vf_batch_cupy(A, b, M, N):
    """Evaluate a batch of 2-D affine maps on an M x N grid of pixel indices.

    For every grid point (i, j) and batch element k::

        tau1[i, j, k] = A[0, 0, k] * i + A[0, 1, k] * j + b[0, k]
        tau2[i, j, k] = A[1, 0, k] * i + A[1, 1, k] * j + b[1, k]

    Args:
        A: linear parts of the maps, shape (2, 2, batch_size).
        b: translations, shape (2, batch_size).
        M, N: dimensions of the transformation field.

    Returns:
        Tuple (tau1, tau2), each of shape (M, N, batch_size).

    NOTE(review): the batch axis is last here, while the bicubic interpolation
    routines take batch-first (batch, N0, N1) coordinates; callers presumably
    move this axis to the front -- confirm at the call sites.
    """
    # Row / column indices as float64 grids shaped (M, 1, 1, 1) and (1, N, 1, 1)
    # so they broadcast against the (2, batch_size) coefficient slices.
    rows = cp.arange(M, dtype=cp.float64)[:, cp.newaxis, cp.newaxis, cp.newaxis]
    cols = cp.arange(N, dtype=cp.float64)[cp.newaxis, :, cp.newaxis, cp.newaxis]

    # tau has shape (M, N, 2, batch_size); axis 2 selects the output coordinate.
    tau = A[:, 0, :] * rows + A[:, 1, :] * cols + b

    return tau[:, :, 0, :], tau[:, :, 1, :]

def image_interpolation_bicubic_batch_cupy(x, tau1, tau2):
    """Bicubically resample a batch of channel-last images at fractional coordinates.

    Uses the separable Keys cubic-convolution kernel (a = -0.5)::

        phi(u) =  1.5 |u|^3 - 2.5 |u|^2 + 1             0 <= |u| <= 1
                 -0.5 |u|^3 + 2.5 |u|^2 - 4 |u| + 2     1 <= |u| <= 2

    Each output sample combines the 4 x 4 pixels around (tau1, tau2). Pixels
    outside the image are treated as zeros.

    Args:
        x: image batch of shape (B, N0, N1, N2); N2 is the channel axis.
        tau1: axis-1 (row) sample coordinates, shape (B, H, W). (H, W) is
            usually (N0, N1) but may differ; an unbatched (H, W) field is
            broadcast over the batch.
        tau2: axis-2 (column) sample coordinates, same shape as tau1.

    Returns:
        Array of shape (B, H, W, N2) with the interpolated values.
    """
    B, N0, N1, N2 = x.shape
    offsets = (-1, 0, 1, 2)

    # Embed each image in a one-pixel frame of zeros. Tap indices are clamped
    # to [0, N + 1], so every tap that falls outside the image reads a zero.
    xx = cp.zeros((B, N0 + 2, N1 + 2, N2))
    xx[:, 1:N0 + 1, 1:N1 + 1, :] = x

    # Batch index for the gathers; broadcasts against the tap index arrays
    # instead of materialising one full index array per batch element.
    b_idx = cp.arange(B)[:, cp.newaxis, cp.newaxis]

    def taps_and_weights(tau, upper):
        # Shift into the padded frame; float64 keeps the weight arithmetic in
        # double precision regardless of the dtype of tau.
        tau = cp.asarray(tau + 1, dtype=cp.float64)
        ft = cp.floor(tau)
        idx = [cp.minimum(cp.maximum(ft + k, 0), upper).astype(int)
               for k in offsets]
        # Distances to the (unclamped) taps lie in [1, 2], [0, 1], [-1, 0] and
        # [-2, -1], so each tap uses one fixed branch of phi with |u| resolved
        # by its sign. Unclamped positions keep the weights bounded even for
        # samples far outside the image, whose taps read zeros anyway.
        u0, u1, u2, u3 = (tau - (ft + k) for k in offsets)
        weights = [-.5 * u0 ** 3 + 2.5 * u0 ** 2 - 4 * u0 + 2,
                   1.5 * u1 ** 3 - 2.5 * u1 ** 2 + 1,
                   -1.5 * u2 ** 3 - 2.5 * u2 ** 2 + 1,
                   .5 * u3 ** 3 + 2.5 * u3 ** 2 + 4 * u3 + 2]
        return idx, weights

    t1, w1 = taps_and_weights(tau1, N0 + 1)
    t2, w2 = taps_and_weights(tau2, N1 + 1)

    # Weighted sum of the 16 neighbours, accumulated in row-major tap order.
    x_pr = sum((w1[i] * w2[j])[..., cp.newaxis] * xx[b_idx, t1[i], t2[j]]
               for i in range(4) for j in range(4))
    return x_pr

def dimage_interpolation_bicubic_dtau1_batch_cupy(x, tau1, tau2):
    """Partial derivative of the bicubic interpolant with respect to tau1.

    The interpolant is the separable Keys kernel phi (a = -0.5) over the 4 x 4
    neighbourhood of (tau1, tau2), with zeros outside the image. For this
    derivative the axis-1 weights use phi' and the axis-2 weights use phi::

        phi(u)  =  1.5 |u|^3 - 2.5 |u|^2 + 1            0 <= |u| <= 1
                  -0.5 |u|^3 + 2.5 |u|^2 - 4 |u| + 2    1 <= |u| <= 2
        phi'(u) =  4.5 sgn(u) u^2 - 5 u                 0 <= |u| <= 1
                  -1.5 sgn(u) u^2 + 5 u - 4 sgn(u)      1 <= |u| <= 2

    The tap positions floor(tau) - 1 .. floor(tau) + 2 are held fixed, i.e.
    this is the derivative within each unit cell.

    Args:
        x: image batch of shape (B, N0, N1, N2); N2 is the channel axis.
        tau1: axis-1 (row) sample coordinates, shape (B, H, W). (H, W) is
            usually (N0, N1) but may differ; an unbatched (H, W) field is
            broadcast over the batch.
        tau2: axis-2 (column) sample coordinates, same shape as tau1.

    Returns:
        Array of shape (B, H, W, N2) holding d(interpolated x) / d(tau1).
    """
    B, N0, N1, N2 = x.shape
    offsets = (-1, 0, 1, 2)

    # Embed each image in a one-pixel frame of zeros. Tap indices are clamped
    # to [0, N + 1], so every tap that falls outside the image reads a zero.
    xx = cp.zeros((B, N0 + 2, N1 + 2, N2))
    xx[:, 1:N0 + 1, 1:N1 + 1, :] = x

    # Batch index for the gathers; broadcasts against the tap index arrays.
    b_idx = cp.arange(B)[:, cp.newaxis, cp.newaxis]

    def taps_and_weights(tau, upper, derivative):
        # Shift into the padded frame, in float64 whatever the dtype of tau.
        tau = cp.asarray(tau + 1, dtype=cp.float64)
        # cp.floor keeps the computation in CuPy (np.floor on a CuPy array
        # depends on __array_ufunc__ dispatch).
        ft = cp.floor(tau)
        idx = [cp.minimum(cp.maximum(ft + k, 0), upper).astype(int)
               for k in offsets]
        # Distances to the (unclamped) taps lie in [1, 2], [0, 1], [-1, 0] and
        # [-2, -1]; each tap uses one fixed branch with the sign resolved.
        # Unclamped positions keep the weights bounded for far-away samples.
        u0, u1, u2, u3 = (tau - (ft + k) for k in offsets)
        if derivative:
            weights = [-1.5 * u0 ** 2 + 5 * u0 - 4,
                       4.5 * u1 ** 2 - 5 * u1,
                       -4.5 * u2 ** 2 - 5 * u2,
                       1.5 * u3 ** 2 + 5 * u3 + 4]
        else:
            weights = [-.5 * u0 ** 3 + 2.5 * u0 ** 2 - 4 * u0 + 2,
                       1.5 * u1 ** 3 - 2.5 * u1 ** 2 + 1,
                       -1.5 * u2 ** 3 - 2.5 * u2 ** 2 + 1,
                       .5 * u3 ** 3 + 2.5 * u3 ** 2 + 4 * u3 + 2]
        return idx, weights

    t1, w1 = taps_and_weights(tau1, N0 + 1, derivative=True)
    t2, w2 = taps_and_weights(tau2, N1 + 1, derivative=False)

    # Weighted sum of the 16 neighbours, accumulated in row-major tap order.
    dx_pr_dtau1 = sum((w1[i] * w2[j])[..., cp.newaxis] * xx[b_idx, t1[i], t2[j]]
                      for i in range(4) for j in range(4))
    return dx_pr_dtau1

def dimage_interpolation_bicubic_dtau2_batch_cupy(x, tau1, tau2):
    """Partial derivative of the bicubic interpolant with respect to tau2.

    The interpolant is the separable Keys kernel phi (a = -0.5) over the 4 x 4
    neighbourhood of (tau1, tau2), with zeros outside the image. For this
    derivative the axis-1 weights use phi and the axis-2 weights use phi'::

        phi(u)  =  1.5 |u|^3 - 2.5 |u|^2 + 1            0 <= |u| <= 1
                  -0.5 |u|^3 + 2.5 |u|^2 - 4 |u| + 2    1 <= |u| <= 2
        phi'(u) =  4.5 sgn(u) u^2 - 5 u                 0 <= |u| <= 1
                  -1.5 sgn(u) u^2 + 5 u - 4 sgn(u)      1 <= |u| <= 2

    The tap positions floor(tau) - 1 .. floor(tau) + 2 are held fixed, i.e.
    this is the derivative within each unit cell.

    Args:
        x: image batch of shape (B, N0, N1, N2); N2 is the channel axis.
        tau1: axis-1 (row) sample coordinates, shape (B, H, W). (H, W) is
            usually (N0, N1) but may differ; an unbatched (H, W) field is
            broadcast over the batch.
        tau2: axis-2 (column) sample coordinates, same shape as tau1.

    Returns:
        Array of shape (B, H, W, N2) holding d(interpolated x) / d(tau2).
    """
    B, N0, N1, N2 = x.shape
    offsets = (-1, 0, 1, 2)

    # Embed each image in a one-pixel frame of zeros. Tap indices are clamped
    # to [0, N + 1], so every tap that falls outside the image reads a zero.
    xx = cp.zeros((B, N0 + 2, N1 + 2, N2))
    xx[:, 1:N0 + 1, 1:N1 + 1, :] = x

    # Batch index for the gathers; broadcasts against the tap index arrays.
    b_idx = cp.arange(B)[:, cp.newaxis, cp.newaxis]

    def taps_and_weights(tau, upper, derivative):
        # Shift into the padded frame, in float64 whatever the dtype of tau.
        tau = cp.asarray(tau + 1, dtype=cp.float64)
        ft = cp.floor(tau)
        idx = [cp.minimum(cp.maximum(ft + k, 0), upper).astype(int)
               for k in offsets]
        # Distances to the (unclamped) taps lie in [1, 2], [0, 1], [-1, 0] and
        # [-2, -1]; each tap uses one fixed branch with the sign resolved.
        # Unclamped positions keep the weights bounded for far-away samples.
        u0, u1, u2, u3 = (tau - (ft + k) for k in offsets)
        if derivative:
            weights = [-1.5 * u0 ** 2 + 5 * u0 - 4,
                       4.5 * u1 ** 2 - 5 * u1,
                       -4.5 * u2 ** 2 - 5 * u2,
                       1.5 * u3 ** 2 + 5 * u3 + 4]
        else:
            weights = [-.5 * u0 ** 3 + 2.5 * u0 ** 2 - 4 * u0 + 2,
                       1.5 * u1 ** 3 - 2.5 * u1 ** 2 + 1,
                       -1.5 * u2 ** 3 - 2.5 * u2 ** 2 + 1,
                       .5 * u3 ** 3 + 2.5 * u3 ** 2 + 4 * u3 + 2]
        return idx, weights

    t1, w1 = taps_and_weights(tau1, N0 + 1, derivative=False)
    t2, w2 = taps_and_weights(tau2, N1 + 1, derivative=True)

    # Weighted sum of the 16 neighbours, accumulated in row-major tap order.
    dx_pr_dtau2 = sum((w1[i] * w2[j])[..., cp.newaxis] * xx[b_idx, t1[i], t2[j]]
                      for i in range(4) for j in range(4))
    return dx_pr_dtau2