import time
import numpy as np
import matplotlib.pyplot as plt
import time

from tqdm import tqdm

import cupy as cp

import imageio
import os

def rotation_mat_cupy(theta):
    """Return the 2x2 rotation matrix for angle `theta` (radians) as a CuPy array.

    The matrix is [[cos(theta), -sin(theta)], [sin(theta), cos(theta)]].
    """
    c = cp.cos(theta)
    s = cp.sin(theta)
    rows = [[c, -s],
            [s, c]]
    return cp.array(rows)
def identity_vf_cupy(M, N, RM=None, RN=None):
    """Get vector field for the identity transformation.

    This returns the vector field (tau_u, tau_v) corresponding to the identity
    transformation, which maps the image plane to itself.
    For more details on these vector fields, see the doc for affine_to_vf

    inputs:
    --------
    M : int
        vertical (number of rows) size of image plane being worked with
    N : int
        horizontal (number of cols) size of image plane being worked with
    RM : int (optional)
        number of points in the M direction desired. by default, this is M,
        giving the identity transformation. when a number other than M is
        provided, this corresponds to a resampling in the vertical direction.
        (we put this operation in this function because it is so naturally
        related)
    RN : int (optional)
        number of points in the N direction desired. by default, this is N,
        giving the identity transformation. when a number other than N is
        provided, this corresponds to a resampling in the horizontal direction.
        (we put this operation in this function because it is so naturally
        related)

    outputs:
    -------
    eu : cupy.ndarray (shape (RM, RN))
        first (row-axis) component of the vector field corresponding to
        (I, 0): eu[i, j] is the i-th of RM evenly spaced points in [0, M-1],
        so it is constant along each row.
    ev : cupy.ndarray (shape (RM, RN))
        second (column-axis) component of the vector field corresponding to
        (I, 0): ev[i, j] is the j-th of RN evenly spaced points in [0, N-1],
        so it is constant along each column.

    """
    if RM is None:
        RM = M
    if RN is None:
        RN = N

    m_vec = cp.linspace(0, M - 1, RM)
    n_vec = cp.linspace(0, N - 1, RN)

    # 'ij' indexing lays m_vec along axis 0 and n_vec along axis 1. This gives
    # the same grids as an outer product with a ones vector, with no matmul.
    eu, ev = cp.meshgrid(m_vec, n_vec, indexing='ij')

    return (eu, ev)

def cconv_fourier_cupy(x, y):
    """Circularly convolve two arrays over their first two axes using the FFT.

    Written by hand because scipy.signal.fftconvolve seems to handle
    restriction in its 'same' mode incorrectly.

    The 2D transform runs over axes 0 and 1 only, so any trailing axes
    (channels, batch, ...) are convolved independently. If the batch/channel
    axis comes first in your data, permute dims before calling.

    Requires:
    x and y need to have the same shape / be broadcastable (no automatic
    padding is done).

    Returns the real part of the inverse 2D FFT of the product of the two
    spectra.
    """
    spatial = (0, 1)

    # Pointwise product of spectra == circular convolution in the signal domain.
    spectrum = (cp.fft.fft2(x, axes=spatial, norm='backward')
                * cp.fft.fft2(y, axes=spatial, norm='backward'))
    convolved = cp.fft.ifft2(spectrum, axes=spatial)

    return cp.real(convolved)

def gaussian_filter_1d_cupy(N, sigma=1, offset=0):
    """Return a length-N periodic Gaussian window, normalized to unit L1 norm.

    Sample i is weighted by exp(-d**2 / (2 * sigma**2)), where d is the
    signed circular distance from i to `offset`. The distance is wrapped
    modulo N into a window of length N around zero, so the bump wraps
    around the ends of the array. The result is divided by its L1 norm.
    """
    half = (N - 1) / 2
    idx = cp.arange(0, N)

    # Signed wrap-around distance of each index from `offset`.
    dist = (idx - offset + half) % N - half

    variance = sigma ** 2
    # The 1/sqrt(2*pi*sigma^2) prefactor cancels in the L1 normalization below.
    weights = 1 / cp.sqrt(2 * cp.pi * variance) * cp.exp(-dist ** 2 / 2 / variance)

    return weights / cp.linalg.norm(weights, ord=1)

def gaussian_filter_2d_cupy(M, N=None, sigma_u=1, sigma_v=None, offset_u=0,
        offset_v=0):
    """Return an M x N periodic Gaussian filter built as a separable product.

    Entry (i, j) is row_filter[i] * col_filter[j], where both factors come
    from gaussian_filter_1d_cupy: the row factor has length M, width
    sigma_u and center offset_u; the column factor has length N, width
    sigma_v and center offset_v. N defaults to M, and sigma_v defaults to
    sigma_u.
    """
    N = M if N is None else N
    sigma_v = sigma_u if sigma_v is None else sigma_v

    row_filter = gaussian_filter_1d_cupy(M, sigma=sigma_u, offset=offset_u)
    col_filter = gaussian_filter_1d_cupy(N, sigma=sigma_v, offset=offset_v)

    # Separable filter: the 2D kernel is the outer product of the 1D factors.
    return cp.outer(row_filter, col_filter)

def dsp_flip_cupy(X):
    """Return the DSP (circular) flip of X along every axis.

    For each axis of length N, output index n takes input index (-n) mod N:
    Y[n_0, n_1, ...] = X[-n_0 mod N_0, -n_1 mod N_1, ...]. Index 0 stays in
    place and the remaining entries are reversed. This is the same
    permutation that applying the orthonormal DFT twice per axis produces.
    It is computed here by pure indexing, so it is exact (no FFT roundoff)
    and O(X.size).

    inputs:
    --------
    X : cupy.ndarray
        array of any dimensionality

    outputs:
    -------
    cupy.ndarray with the same shape as X, holding the real part of the
    flipped array. For complex X the imaginary part is discarded, so the
    output stays real-valued. For real X the values and dtype are those of X.
    """
    axes = tuple(range(X.ndim))
    # Reversing maps n -> N-1-n; rolling by +1 then maps it to N-n == -n mod N.
    flipped = cp.roll(cp.flip(X, axis=axes), 1, axis=axes)
    return cp.real(flipped)
