import numpy as np
import torch
from torch.nn import functional as F


def make_affine_matrix(sc, ro, tx, ty, use_opencv=False):
    """Build a 2x3 affine matrix from scale, rotation and translation.

    The linear part is ``sc * R(ro)``, where ``R`` is the 2-D rotation matrix.
    The translation column is that same linear part applied to ``(tx, ty)``.
    The resulting map is therefore ``p -> sc * R(ro) @ (p + (tx, ty))``.

    Args:
        sc, ro, tx, ty: single-value torch tensors. They are read through
            ``.data`` / ``.device``, so plain Python numbers are not accepted.
            ``ro`` is fed directly to cos/sin, i.e. it is in radians.
        use_opencv: when True, the parameters are converted to Python floats
            and a NumPy array of shape ``(2, 3)`` is returned. Otherwise a
            torch tensor of shape ``(1, 2, 3)`` (created with ``torch.zeros``)
            is returned.

    Returns:
        The affine matrix in the container described above.
    """
    if not use_opencv:
        theta = torch.zeros((1, 2, 3))

        # for making visualization for resnet18 and googlenet:
        # parameters living on an accelerator are pulled back to the host
        # (detached via .data) before being written into the CPU matrix.
        if sc.device.type != 'cpu':
            sc, ro, tx, ty = (p.data.cpu() for p in (sc, ro, tx, ty))

        cos_r = torch.cos(ro)
        sin_r = torch.sin(ro)
        top_row = (sc * cos_r, -sc * sin_r, tx * sc * cos_r - ty * sc * sin_r)
        bottom_row = (sc * sin_r, sc * cos_r, tx * sc * sin_r + ty * sc * cos_r)
        for col, (upper, lower) in enumerate(zip(top_row, bottom_row)):
            theta[0, 0, col] = upper
            theta[0, 1, col] = lower
        return theta

    sc, ro, tx, ty = (float(p.data.cpu()) for p in (sc, ro, tx, ty))
    cos_r = np.cos(ro)
    sin_r = np.sin(ro)
    return np.array([
        [sc * cos_r, -sc * sin_r, tx * sc * cos_r - ty * sc * sin_r],
        [sc * sin_r, sc * cos_r, tx * sc * sin_r + ty * sc * cos_r],
    ])


def create_next_frame(s, r, x, y, data, device, use_opencv=False):
    """Warp ``data`` with the affine transform built from ``(s, r, x, y)``.

    Args:
        s, r, x, y: scale, rotation and translation, forwarded unchanged to
            ``make_affine_matrix``.
        data: torch tensor of shape ``(H, W)`` or ``(C, H, W)``.
        device: unused; kept for backward compatibility. The sampling grid
            is placed on ``data.device`` instead, because ``grid_sample``
            requires input and grid to live on the same device.
        use_opencv: forwarded to ``make_affine_matrix``. That path returns a
            NumPy ``(2, 3)`` array with the same entries, which is converted
            to a ``(1, 2, 3)`` tensor here.

    Returns:
        The warped tensor, with the same shape as ``data``.
    """
    # Accept both the torch (1, 2, 3) and the NumPy (2, 3) forms of theta.
    theta = torch.as_tensor(make_affine_matrix(s, r, x, y, use_opencv)).reshape(1, 2, 3)

    is_2d = len(data.shape) == 2
    if is_2d:
        h_, w_ = data.shape
        c_ = 1
        batch = data.unsqueeze(0).unsqueeze(0)
    else:
        c_, h_, w_ = data.shape
        batch = data.unsqueeze(0)

    # theta is built on the CPU; grid_sample needs the grid on the input's
    # device and, for floating-point inputs, in the input's dtype.
    grid_dtype = data.dtype if data.is_floating_point() else theta.dtype
    theta = theta.to(device=data.device, dtype=grid_dtype)

    # align_corners=False is PyTorch's default since 1.3; passing it
    # explicitly keeps that behavior and silences the UserWarning.
    grid = F.affine_grid(theta, size=[1, c_, h_, w_], align_corners=False)
    warped = F.grid_sample(batch, grid, align_corners=False)
    return warped[0, 0] if is_2d else warped[0]


def normalize(data, use_opencv=False, z=255.):
    """Min-max rescale ``data`` into ``[0, z]``, writing the result in place.

    The minimum value maps to 0 and the maximum value maps to ``z``. A
    constant input (max == min) has no spread, so it maps entirely to 0.

    Args:
        data: a 2-D or 3-D array-like supporting ``min()``, ``max()``,
            arithmetic and item assignment (torch tensors and NumPy arrays
            both qualify).
        use_opencv: when True, a single global min/max is used for all
            dimensionalities. When False, a 3-D input is normalized
            independently along its first axis (presumably channels, i.e.
            ``(C, H, W)`` — confirm against callers). Any other
            dimensionality uses a global min/max.
        z: upper bound of the output range.

    Returns:
        ``data`` itself, after being overwritten with the rescaled values.
    """
    y = 0.  # lower bound of the output range

    def _rescale(values, lo, hi):
        # Map [lo, hi] linearly onto [y, z]; a zero span divides by 1 so a
        # constant input maps to y instead of producing NaN/inf.
        span = (hi - lo) or 1.
        return (values - lo) / span * (z - y) + y

    if use_opencv or len(data.shape) != 3:
        data[...] = _rescale(data, float(data.min()), float(data.max()))
    else:
        for i in range(data.shape[0]):
            channel = data[i]
            data[i] = _rescale(channel, float(channel.min()), float(channel.max()))

    return data



