import numpy as np
import torch
from torch import nn

from augment.utils import Augment


def to_grid(img_h, img_w, device, scale=False):
    """Build a dense (img_h, img_w, 2) float32 grid of pixel coordinates.

    Entry ``[i, j]`` holds ``(i, j)``. When ``scale`` is true, the row index is
    divided by ``img_h`` and the column index by ``img_w``, giving coordinates
    in ``[0, 1)``.

    Args:
        img_h: Number of rows.
        img_w: Number of columns.
        device: Device on which the grid is allocated.
        scale: Normalize coordinates by the image size.

    Returns:
        Tensor of shape (img_h, img_w, 2); channel 0 is the row coordinate and
        channel 1 the column coordinate.
    """
    rows = torch.arange(img_h, dtype=torch.float32, device=device)
    cols = torch.arange(img_w, dtype=torch.float32, device=device)
    if scale:
        rows = rows / img_h
        cols = cols / img_w
    # Broadcast the 1-D coordinate vectors to full (H, W) planes, then pair them.
    row_plane = rows.unsqueeze(1).expand(img_h, img_w)
    col_plane = cols.unsqueeze(0).expand(img_h, img_w)
    return torch.stack((row_plane, col_plane), dim=-1)


class TriMatrix(nn.Module):
    """Lower-triangular (Cholesky) factor ``L`` of a 2x2 covariance matrix.

    ``L`` is initialised so that ``L @ L.T`` equals
    ``R(angle) @ diag(scale / ratio, scale * ratio) @ R(angle).T``, where
    ``R(angle)`` is a 2-D rotation by ``angle`` degrees.

    Args:
        scale: Geometric mean of the two principal variances.
        ratio: Aspect factor; the principal variances are ``scale / ratio``
            and ``scale * ratio``.
        angle: Rotation of the principal axes, in degrees.
        requires_grad: Whether the stored factor is trainable.
    """

    def __init__(self, scale, ratio, angle, requires_grad=False):
        super().__init__()
        self.elements = nn.Parameter(
            self.initialize_cov_matrix(scale, ratio, angle),
            requires_grad=requires_grad,
        )
        # Constant bool data belongs in a buffer, not an nn.Parameter.
        # A bool tensor in .parameters() breaks generic loops such as
        # `p.requires_grad_(True)`, because only floating tensors can require
        # grad. The state_dict key ("mask") is unchanged.
        self.register_buffer(
            "mask", torch.tril(torch.ones(2, 2, dtype=torch.bool))
        )

    @staticmethod
    def initialize_cov_matrix(scale, ratio, angle):
        """Return the lower Cholesky factor of the rotated diagonal covariance.

        Args:
            scale: Geometric mean of the principal variances.
            ratio: Aspect factor between the principal variances.
            angle: Rotation in degrees.

        Returns:
            A (2, 2) float32 lower-triangular tensor ``L`` with
            ``L @ L.T == R @ diag(scale / ratio, scale * ratio) @ R.T``.
        """
        # np.pi equals torch.pi, and np.pi is available on every torch version.
        theta = np.pi * angle / 180
        rot_matrix = torch.tensor([
            [np.cos(theta), -np.sin(theta)],
            [np.sin(theta), np.cos(theta)]
        ], dtype=torch.float32)
        # Square roots of the principal variances (standard deviations).
        scl_matrix = torch.tensor([
            [scale / ratio, 0],
            [0, scale * ratio]
        ], dtype=torch.float32).sqrt()
        cov_matrix = rot_matrix.matmul(scl_matrix) \
            .matmul(scl_matrix) \
            .matmul(rot_matrix.t())
        return torch.linalg.cholesky(cov_matrix)

    def forward(self):
        """Return the stored factor with its strictly-upper entry zeroed.

        Masking keeps the result lower triangular whatever value is held in
        ``elements[0, 1]``.
        """
        return self.mask * self.elements


class CutDiff(Augment):
    """Subtract a soft Gaussian-shaped patch from each image, then clamp to [0, 1].

    The patch value at a pixel is ``exp(-d / scaler)``. Here ``d`` is the
    squared Mahalanobis distance from a random centre, measured in normalized
    ``[0, 1)`` coordinates, under the covariance ``Sigma = L @ L.T``. ``L`` is
    the lower-triangular factor held by ``self.tri_matrix``.

    Args:
        scale, ratio, angle: Initial covariance shape, forwarded to TriMatrix.
        requires_grad: Whether the covariance factor is trainable.
        uniform_batch: If True, one patch centre is shared by the whole batch;
            otherwise a centre is drawn per image.
    """

    def __init__(self, scale=0.02, ratio=1.0, angle=0, requires_grad=False,
                 uniform_batch=False):
        super().__init__()
        self.scale = scale
        self.tri_matrix = TriMatrix(scale, ratio, angle, requires_grad)
        self.uniform_batch = uniform_batch

    def to_patch(self, grid, mu, scaler=0.24):
        """Evaluate the Gaussian patch on ``grid`` around centres ``mu``.

        Args:
            grid: Normalized coordinates with a trailing dim of 2; ``forward``
                passes shape (1, H, W, 2).
            mu: Patch centres, broadcastable against ``grid``; ``forward``
                passes shape (N, 1, 1, 2).
            scaler: Divisor applied to the squared distance before ``exp``.

        Returns:
            Tensor of shape (N, H, W) with values in [0, 1].
        """
        # With Sigma = L @ L.T, (x - mu)^T Sigma^-1 (x - mu) = ||L^-1 (x - mu)||^2.
        l_inv = torch.inverse(self.tri_matrix())
        diff = torch.einsum('nijl,kl->nijk', grid - mu, l_inv)
        dist_sq = torch.einsum('nijk,nijk->nij', diff, diff)
        return torch.exp(-dist_sq / scaler)

    def forward(self, imgs):
        """Return ``imgs`` minus a randomly placed Gaussian patch, clamped to [0, 1].

        Args:
            imgs: Image batch whose last two dims are (H, W). The patch is
                broadcast over dim 1, presumably channels (N, C, H, W) —
                confirm against callers.
        """
        device = imgs.device
        img_h, img_w = imgs.shape[-2:]

        with torch.no_grad():
            tri = self.tri_matrix()
            # The covariance is Sigma = L @ L.T, the same matrix to_patch
            # inverts. L @ L would drop L[1, 0]**2 from Sigma[1, 1], which
            # underestimates the column-wise spread of rotated patches.
            cov = tri.matmul(tri.t())
            # Marginal variances scaled to pixels, used as a border margin
            # for the patch centre.
            patch_h = cov[0, 0] * img_h
            patch_w = cov[1, 1] * img_w

        pos_len = 1 if self.uniform_batch else len(imgs)
        pos_i = torch.rand(pos_len, 1, 1, device=device)
        pos_j = torch.rand(pos_len, 1, 1, device=device)
        # Centres lie in [patch/2, size - patch/2] pixels, then are normalized.
        mu_i = (img_h - patch_h) * pos_i + patch_h / 2
        mu_j = (img_w - patch_w) * pos_j + patch_w / 2
        mu = torch.stack([mu_i / img_h, mu_j / img_w], dim=3)

        grid = to_grid(img_h, img_w, device, scale=True).unsqueeze(0)
        patch = self.to_patch(grid, mu)
        return torch.clamp(imgs - patch.unsqueeze(1), 0, 1)

    def get_parameters(self):
        """Return the three free entries of the factor as floats: [L00, L10, L11]."""
        return [
            self.tri_matrix.elements[0, 0].item(),
            self.tri_matrix.elements[1, 0].item(),
            self.tri_matrix.elements[1, 1].item(),
        ]
