import math

import cv2
import numpy as np
import torch

def rotate(img, angle):
    """Rotate a 2-D image about its centre, keeping the original size.

    Args:
        img: Single-channel image whose ``shape`` unpacks as
            ``(height, width)``.
        angle: Rotation angle in degrees, forwarded unchanged to
            ``cv2.getRotationMatrix2D`` (scale is fixed at 1).

    Returns:
        The rotated image as produced by ``cv2.warpAffine`` with an output
        size of ``width`` x ``height``.
    """
    h, w = img.shape
    # OpenCV takes points as (x, y) and sizes as (width, height), which is
    # the reverse of NumPy's (rows, cols) shape order.  Passing (h, w) here
    # transposes the output and shifts the pivot for non-square images.
    center = (w // 2, h // 2)
    M = cv2.getRotationMatrix2D(center, angle, 1)
    return cv2.warpAffine(img, M, (w, h))

def gen_rotation(img, size=(11, 11)):
    """Build a rotation x row-shift x column-shift grid of 64x64 images.

    The sprite is rotated to 40 angles (``k * 180 / 40`` degrees for
    ``k = 0..39``) and each rotated copy is pasted at every offset of a
    40 x 40 translation grid on an otherwise zero canvas.

    Args:
        img: 2-D sprite of shape ``size``.
        size: ``(height, width)`` of ``img``.

    Returns:
        ``(imgs, labels)`` where ``imgs`` is a float32 tensor of shape
        ``(40, 40, 40, 64, 64)`` (about 1 GiB) and ``labels`` is a float32
        tensor of shape ``(40, 40, 40, 3)`` with
        ``labels[k, i, j] == [k, i, j]``.
    """
    h, w = size
    n_steps = 40
    canvas = 64

    # OpenCV wants the pivot as (x, y) and the output size as
    # (width, height); using (h, w) only works for square sprites.
    center = (w // 2, h // 2)
    rotated = np.zeros((n_steps, h, w), np.float32)
    for k in range(n_steps):
        M = cv2.getRotationMatrix2D(center, k * 180 / n_steps, 1)
        rotated[k] = cv2.warpAffine(img, M, (w, h))

    # Allocate float32 up front: a float64 buffer of this shape is ~2 GiB
    # and would be copied and down-cast again when converted to a tensor.
    imgs = np.zeros((n_steps, n_steps, n_steps, canvas, canvas), np.float32)
    s = 12 - h // 2
    for i in range(n_steps):
        for j in range(n_steps):
            # One write per offset pastes all rotations at once.
            imgs[:, i, j, s + i:s + i + h, s + j:s + j + w] = rotated

    steps = np.arange(n_steps, dtype=np.float32)
    labels = np.stack(
        np.meshgrid(steps, steps, steps, indexing="ij"), axis=-1
    )
    return torch.from_numpy(imgs), torch.from_numpy(labels)


def gen_translation(img, size=(11, 11)):
    """Paste ``img`` at every offset of a 40 x 40 grid on a 64x64 canvas.

    For grid cell ``(i, j)`` the sprite's top-left corner is written at row
    ``s + i`` and column ``s + j`` with ``s = 12 - height // 2``; everything
    else stays zero.

    Args:
        img: Sprite to paste; it is assigned into a ``height`` x ``width``
            window, so it must be broadcastable to that shape.
        size: ``(height, width)`` of the window written for each cell.

    Returns:
        ``(imgs, labels)`` where ``imgs`` is a float32 tensor of shape
        ``(40, 40, 64, 64)`` and ``labels`` is a float32 tensor of shape
        ``(40, 40, 2)`` with ``labels[i, j] == [i, j]``.
    """
    h, w = size
    n_steps = 40
    canvas = 64

    # Build the buffer as float32 directly instead of allocating float64
    # and down-casting a copy afterwards.
    imgs = np.zeros((n_steps, n_steps, canvas, canvas), np.float32)
    s = 12 - h // 2
    for i in range(n_steps):
        for j in range(n_steps):
            imgs[i, j, s + i:s + i + h, s + j:s + j + w] = img

    steps = np.arange(n_steps, dtype=np.float32)
    labels = np.stack(np.meshgrid(steps, steps, indexing="ij"), axis=-1)
    return torch.from_numpy(imgs), torch.from_numpy(labels)


def gen_polar(img, size=(11, 11)):
    """Shift ``img`` along 40 radii x 40 angles on a 64x64 canvas.

    Radii are ``np.linspace(1, 20, 40)`` and angles are
    ``np.linspace(0, 3.14 * 2, 40)`` (radians, with ``3.14`` standing in for
    pi).  Each sample is ``cv2.warpAffine`` applied with a pure translation
    of ``(r * cos(t) + offset, r * sin(t) + offset)``.

    Args:
        img: Source image handed to ``cv2.warpAffine``.
        size: ``(height, width)``; only the height is used, to derive the
            single offset ``32 - height // 2`` applied on both axes.

    Returns:
        ``(imgs, labels)`` where ``imgs`` is a float32 tensor of shape
        ``(40, 40, 64, 64)`` indexed ``[radius, angle]`` and ``labels`` is a
        float32 tensor of shape ``(40, 40, 2)`` holding the applied
        ``(dx, dy)`` translation of each sample.
    """
    height, _ = size
    offset = 32 - height // 2
    radii = np.linspace(1, 20, 40)
    angles = np.linspace(0, 3.14 * 2, 40)

    imgs = np.zeros((40, 40, 64, 64), np.float32)
    labels = np.zeros((40, 40, 2), np.float32)
    for ri, radius in enumerate(radii):
        for ti, theta in enumerate(angles):
            shift_x = radius * math.cos(theta) + offset
            shift_y = radius * math.sin(theta) + offset
            # 2x3 affine matrix with identity rotation: translation only.
            affine = np.float32([[1, 0, shift_x], [0, 1, shift_y]])
            imgs[ri, ti] = cv2.warpAffine(img, affine, (64, 64))
            labels[ri, ti] = [shift_x, shift_y]

    images = torch.tensor(imgs).float()
    targets = torch.from_numpy(labels).reshape(40, 40, 2)
    return images, targets.float()

class Translation:
    """Indexable dataset of a sprite at every cell of a 40 x 40 offset grid.

    The grid comes from ``gen_translation``; it is kept in grid layout as
    ``o_imgs`` / ``o_labels`` and flattened into ``imgs`` (``N x 1 x 64 x
    64``) and ``labels`` (``N x 2``) for item access.
    """

    # Latent factor names and the number of values each one takes.
    lat_names = ('posY', 'posX')
    lat_sizes = np.array([40, 40])
    img_size = (1, 64, 64)

    def __init__(self, img):
        """Generate the grid for ``img`` (a 2-D array; its shape is the sprite size)."""
        grid_imgs, grid_labels = gen_translation(img, img.shape)
        self.o_imgs = grid_imgs
        self.o_labels = grid_labels
        # Flatten the grid axes and add a channel axis for the images.
        self.imgs = grid_imgs.reshape(-1, 1, 64, 64)
        self.labels = grid_labels.reshape(-1, 2)

    def __getitem__(self, item):
        """Return the ``(image, label)`` pair stored at ``item``."""
        image = self.imgs[item]
        label = self.labels[item]
        return image, label

    def __len__(self):
        """Return the number of flattened samples."""
        return len(self.imgs)

class RotationTranslation:
    """Indexable dataset over a 40 x 40 x 40 rotation/offset grid of a sprite.

    The grid comes from ``gen_rotation``; it is kept in grid layout as
    ``o_imgs`` / ``o_labels`` and flattened into ``imgs`` (``N x 1 x 64 x
    64``) and ``labels`` (``N x 3``) for item access.
    """

    # Latent factor names and the number of values each one takes.
    lat_names = ('orientation', 'posY', 'posX')
    lat_sizes = np.array([40, 40, 40])
    img_size = (1, 64, 64)

    def __init__(self, img):
        """Generate the grid for ``img`` (a 2-D array; its shape is the sprite size)."""
        grid_imgs, grid_labels = gen_rotation(img, img.shape)
        self.o_imgs = grid_imgs
        self.o_labels = grid_labels
        # Flatten the grid axes and add a channel axis for the images.
        self.imgs = grid_imgs.reshape(-1, 1, 64, 64)
        self.labels = grid_labels.reshape(-1, 3)

    def __getitem__(self, item):
        """Return the ``(image, label)`` pair stored at ``item``."""
        image = self.imgs[item]
        label = self.labels[item]
        return image, label

    def __len__(self):
        """Return the number of flattened samples."""
        return len(self.imgs)

class PolarTranslation:
    """Indexable dataset of a sprite shifted along a radius/angle grid.

    The grid comes from ``gen_polar``; it is kept in grid layout as
    ``o_imgs`` / ``o_labels`` and flattened into ``imgs`` (``N x 1 x 64 x
    64``) and ``labels`` (``N x 2``) for item access.
    """

    # Latent factor names and the number of values each one takes.
    # NOTE(review): gen_polar indexes its first grid axis by radius and the
    # second by angle, and its labels hold (dx, dy) shifts rather than
    # (theta, r) -- confirm this name order is what consumers expect.
    lat_names = ('theta', 'r')
    lat_sizes = np.array([40, 40])
    img_size = (1, 64, 64)

    def __init__(self, img):
        """Generate the grid for ``img`` (a 2-D array; its shape is the sprite size)."""
        grid_imgs, grid_labels = gen_polar(img, img.shape)
        self.o_imgs = grid_imgs
        self.o_labels = grid_labels
        # Flatten the grid axes and add a channel axis for the images.
        self.imgs = grid_imgs.reshape(-1, 1, 64, 64)
        self.labels = grid_labels.reshape(-1, 2)

    def __getitem__(self, item):
        """Return the ``(image, label)`` pair stored at ``item``."""
        image = self.imgs[item]
        label = self.labels[item]
        return image, label

    def __len__(self):
        """Return the number of flattened samples."""
        return len(self.imgs)

class RelatedTranslation:
    """Dataset of a sprite translated over an hourglass-shaped set of offsets.

    For each vertical step ``dy`` in ``0..39`` only a ``dy``-dependent range
    of horizontal steps ``dx`` is generated: the range is widest at
    ``dy = 0`` and ``dy = 39`` (``0..40``) and narrowest at ``dy = 19`` and
    ``dy = 20`` (``9..31``), so the two factors are not independent.

    ``imgs`` has shape ``N x 1 x 64 x 64`` (float32) and ``labels`` has shape
    ``N x 2`` holding integer ``[dy, dx]`` pairs.
    """

    # Latent factor names and nominal sizes.
    # NOTE(review): dx actually runs 0..40 inclusive (41 values) and the
    # samples do not form a full grid -- confirm lat_sizes is still what
    # consumers expect.
    lat_names = ('posY', 'posX')
    lat_sizes = np.array([40, 40])
    img_size = (1, 64, 64)

    def __init__(self, img):
        """Render every allowed ``(dy, dx)`` shift of the 2-D sprite ``img``."""
        h, _ = img.shape
        s = 12 - h // 2
        imgs = []
        labels = []
        for dy in range(40):
            half = dy // 2
            if dy < 20:
                lo, hi = half, 40 - half
            else:
                lo, hi = 19 - half, half + 21
            for dx in range(lo, hi + 1):
                # Translation-only affine matrix: shift by (dx + s, dy + s).
                M = np.float32([[1, 0, dx + s], [0, 1, dy + s]])
                imgs.append(cv2.warpAffine(img, M, (64, 64)))
                labels.append([dy, dx])

        # Stack in NumPy first: building a tensor straight from a Python
        # list of ndarrays converts element by element and is very slow.
        imgs = torch.from_numpy(np.stack(imgs).astype(np.float32))
        self.imgs = imgs.reshape(-1, 1, 64, 64)
        self.labels = torch.tensor(labels).reshape(-1, 2)

    def __getitem__(self, item):
        """Return the ``(image, label)`` pair stored at ``item``."""
        return self.imgs[item], self.labels[item]

    def __len__(self):
        """Return the number of samples."""
        return len(self.imgs)


class Scaling:
    """Dataset of a sprite rendered at 40 increasing scale factors.

    Step ``k`` uses scale ``1 + (max_scale - 1) * k / 40`` for
    ``k = 0..39``, so the largest scale generated stops one step short of
    ``max_scale``.  ``imgs`` has shape ``40 x 1 x 64 x 64`` (float32) and
    ``labels`` has shape ``40 x 1`` holding the integer step index.
    """

    # A one-element tuple needs the trailing comma; without it the value is
    # a plain string, unlike the tuples used by the other datasets here.
    lat_names = ('scaling',)
    lat_sizes = np.array([40])
    img_size = (1, 64, 64)

    def __init__(self, img, max_scale):
        """Render the scaled copies of the 2-D sprite ``img``.

        Args:
            img: 2-D source image.
            max_scale: Upper bound of the scale range (see class docstring).
        """
        h, _ = img.shape
        imgs = np.zeros((40, 64, 64), np.float32)
        for k in range(40):
            s = (max_scale - 1) / 40 * k + 1
            # Same offset on both axes, derived from the scaled height.
            l = 32 - h * s / 2 + 1
            M = np.float32([[s, 0, l], [0, s, l]])
            # warpAffine returns a new array, so no defensive copy of img
            # is needed per iteration.
            imgs[k] = cv2.warpAffine(img, M, (64, 64))

        self.imgs = torch.from_numpy(imgs).reshape(-1, 1, 64, 64)
        self.labels = torch.arange(40).reshape(-1, 1)

    def __getitem__(self, item):
        """Return the ``(image, label)`` pair stored at ``item``."""
        return self.imgs[item], self.labels[item]

    def __len__(self):
        """Return the number of samples."""
        return len(self.imgs)


class TranslationX:
    """Dataset of an image shifted horizontally in 40 equal steps.

    Step ``j`` translates the image by ``scope / 40 * j`` pixels along x
    (no vertical shift) onto a 64x64 output.  ``imgs`` has shape
    ``40 x 1 x 64 x 64`` and ``labels`` has shape ``40 x 1`` holding the
    step index as float32; ``o_imgs`` / ``o_labels`` keep the unflattened
    ``40 x 64 x 64`` images and the same labels.
    """

    # A one-element tuple needs the trailing comma; without it the value is
    # a plain string, unlike the tuples used by the other datasets here.
    # NOTE(review): the shift below is along x, yet the factor is named
    # 'posY' -- kept as-is for compatibility; confirm whether it should be
    # 'posX'.
    lat_names = ('posY',)
    lat_sizes = np.array([40])
    img_size = (1, 64, 64)

    def __init__(self, img, scope):
        """Render the shifted copies of ``img``.

        Args:
            img: Source image handed to ``cv2.warpAffine``.
            scope: Total horizontal range; each step moves ``scope / 40``.
        """
        frames = np.zeros((40, 64, 64), np.float32)
        for j in range(40):
            dx = scope / 40 * j
            # Translation-only affine matrix: shift by dx along x.
            M = np.float32([[1, 0, dx], [0, 1, 0]])
            frames[j] = cv2.warpAffine(img, M, (64, 64))

        imgs = torch.from_numpy(frames)
        labels = torch.arange(40, dtype=torch.float32).reshape(-1, 1)

        self.imgs = imgs.reshape(-1, 1, 64, 64)
        self.labels = labels.reshape(-1, 1)
        self.o_imgs, self.o_labels = imgs, labels

    def __getitem__(self, item):
        """Return the ``(image, label)`` pair stored at ``item``."""
        return self.imgs[item], self.labels[item]

    def __len__(self):
        """Return the number of samples."""
        return len(self.imgs)