import torch
from torch import nn
from torch.nn.functional import affine_grid, grid_sample

from augment.utils import Augment


def to_theta(angle):
    """Convert an angle from degrees to radians.

    Works element-wise on tensors as well as on plain Python numbers.
    """
    scaled = angle * torch.pi
    return scaled / 180


def to_angle(theta):
    """Convert an angle from radians to degrees.

    Inverse of ``to_theta``; accepts tensors or plain Python numbers.
    """
    scaled = theta * 180
    return scaled / torch.pi


def to_rotation_matrix(angle):
    """Build a 2x3 affine matrix for a rotation by ``angle`` radians.

    The result is ``[[cos, -sin, 0], [sin, cos, 0]]``: the 2-D rotation
    matrix with a zero translation column, in the layout expected by
    ``torch.nn.functional.affine_grid`` for a single image.

    Args:
        angle: Rotation in radians. May be a one-element tensor of shape
            ``(1,)``, a 0-d tensor, or a plain Python number. Tensors are used
            as-is, so gradients flow back to a learnable angle.

    Returns:
        Tensor of shape ``(2, 3)`` for a single-element angle, in the angle's
        floating dtype (integer inputs follow ``torch.sin``'s promotion).
    """
    # as_tensor returns an existing tensor unchanged (autograd graph intact)
    # and wraps plain Python numbers.
    angle = torch.as_tensor(angle)
    if angle.dim() == 0:
        # torch.cat cannot concatenate zero-dimensional tensors.
        angle = angle.unsqueeze(0)
    sin = torch.sin(angle)
    cos = torch.cos(angle)
    # Build the translation column in the same dtype/device as the trig terms
    # so the concatenation does not mix dtypes.
    zero = torch.zeros(1, dtype=sin.dtype, device=angle.device)
    return torch.stack([
        torch.cat([cos, -sin, zero]),
        torch.cat([sin, cos, zero]),
    ])


def transform_images(imgs, matrix, mode='bilinear'):
    """Resample images through an affine transform.

    Args:
        imgs: Batched images of shape ``(N, C, H, W)``, or a single unbatched
            image of shape ``(C, H, W)``.
        matrix: Affine matrix of shape ``(2, 3)``, shared by every image in
            the batch, or ``(N, 2, 3)`` with one matrix per image. It is moved
            to the images' device and, for floating-point images, cast to
            their dtype.
        mode: Interpolation mode forwarded to ``grid_sample``.

    Returns:
        The resampled images, with the same shape as ``imgs``. Samples that
        fall outside the image take the nearest border value
        (``padding_mode='border'``).
    """
    unbatched = imgs.dim() == 3
    if unbatched:
        imgs = imgs.unsqueeze(0)
    if imgs.is_floating_point():
        # affine_grid emits the grid in the matrix's dtype, and grid_sample
        # requires input and grid to share a dtype.
        matrix = matrix.to(imgs.dtype)
    matrix = matrix.to(imgs.device)
    if matrix.dim() == 2:
        # Share one matrix across the whole batch (a view, no copy).
        matrix = matrix.expand(imgs.size(0), -1, -1)
    grid = affine_grid(matrix, imgs.size(), align_corners=False)
    out = grid_sample(imgs, grid, mode, padding_mode='border',
                      align_corners=False)
    return out.squeeze(0) if unbatched else out


class Rotate(Augment):
    """Rotation augmentation with an optionally learnable angle.

    The stored parameter ``theta`` holds the angle in radians divided by
    ``scaler``. ``forward`` multiplies it back by ``scaler``, so gradients
    with respect to the stored parameter are scaled by ``scaler`` (presumably
    to tune its effective learning rate; confirm with the training code).
    """

    def __init__(self, angle=90, scaler=4, requires_grad=False):
        """Create the augmentation.

        Args:
            angle: Initial rotation in degrees.
            scaler: Factor between the stored parameter and the real angle in
                radians; must be non-zero.
            requires_grad: Whether the angle is trainable.
        """
        super().__init__()
        self.scaler = scaler
        initial = to_theta(angle) / scaler
        self.theta = nn.Parameter(
            torch.tensor([initial], dtype=torch.float32),
            requires_grad=requires_grad,
        )

    def forward(self, imgs):
        """Rotate ``imgs`` by the current angle."""
        radians = self.theta * self.scaler
        return transform_images(imgs, to_rotation_matrix(radians))

    def get_parameters(self):
        """Return the current rotation as a one-element list, in degrees."""
        degrees = to_angle(self.theta.item() * self.scaler)
        return [degrees]
