import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import warnings

from PIL import Image
from torchvision.transforms import transforms

# Install a global "ignore" filter: every Python warning raised after this
# module is imported is suppressed, not only the ones produced by this file.
warnings.filterwarnings("ignore")  # remove warning
# Half-width of the uniform range [-sigma, sigma] from which the layers in this
# file initialise their `delta` (relaxation) parameters. The trailing list is
# the author's note of alternative settings.
sigma = 0.1  # 0, 0.1, 0.2, 0.4, 0.6, 0.8


def relaxed_rot(x, i, g, delta):
    """Resample ``x`` under the ``i``-th rotation of a ``g``-fold group, perturbed by ``delta``.

    The 2x2 rotation matrix for the angle ``2 * pi * i / g`` is built, ``delta``
    is added to it, and the sum (with a zero translation column) is used as the
    affine map for ``F.affine_grid`` / ``F.grid_sample``.

    Args:
        x: Tensor with at least 4 dimensions. The last two dimensions are the
            spatial plane that is resampled. For inputs with 5 or more
            dimensions, every dimension between the first and the last two is
            flattened into the channel axis for the resampling.
        i: Index of the rotation inside the group.
        g: Order of the rotation group (number of steps in a full turn).
        delta: Tensor that can be added to a 2x2 matrix; it is moved to the
            dtype and device of ``x``. Gradients flow back to it.

    Returns:
        The resampled tensor, with the same shape as the input ``x``.
    """
    in_shape = x.shape
    # Remember the rank *before* flattening; checking it afterwards would
    # always see a 4-D tensor and the original layout would never be restored.
    needs_restore = len(in_shape) >= 5
    if needs_restore:
        x = x.view(in_shape[0], -1, in_shape[-2], in_shape[-1])

    theta = 2 * np.pi / g * i
    cos_t, sin_t = float(np.cos(theta)), float(np.sin(theta))
    rot_mat = torch.tensor(
        [[cos_t, -sin_t],
         [sin_t, cos_t]],
        dtype=x.dtype,
        device=x.device,
    )
    # Out-of-place add keeps the freshly built constant untouched and lets
    # autograd track the dependency on ``delta``.
    rot_mat = rot_mat + delta.to(dtype=x.dtype, device=x.device)

    # Append a zero translation column to obtain the 2x3 affine matrix.
    affine = torch.cat((rot_mat, rot_mat.new_zeros(2, 1)), dim=1)
    affine = affine.repeat(x.shape[0], 1, 1)

    # ``align_corners=False`` is the library default; stating it explicitly
    # keeps the result unchanged and avoids the default-behaviour warning.
    grid = F.affine_grid(affine, x.size(), align_corners=False)
    out = F.grid_sample(x, grid, align_corners=False)
    return out.view(in_shape) if needs_restore else out


def strictly_rot(x, i, g):
    """Resample ``x`` under the exact ``i``-th rotation of a ``g``-fold group.

    The 2x3 affine matrix for a pure rotation by ``2 * pi * i / g`` (no
    translation) is passed to ``F.affine_grid`` / ``F.grid_sample``.

    Args:
        x: Tensor with at least 4 dimensions. The last two dimensions are the
            spatial plane that is resampled. For inputs with 5 or more
            dimensions, every dimension between the first and the last two is
            flattened into the channel axis for the resampling.
        i: Index of the rotation inside the group.
        g: Order of the rotation group (number of steps in a full turn).

    Returns:
        The resampled tensor, with the same shape as the input ``x``.
    """
    in_shape = x.shape
    # Remember the rank *before* flattening; checking it afterwards would
    # always see a 4-D tensor and the original layout would never be restored.
    needs_restore = len(in_shape) >= 5
    if needs_restore:
        x = x.view(in_shape[0], -1, in_shape[-2], in_shape[-1])

    theta = 2 * np.pi / g * i
    cos_t, sin_t = float(np.cos(theta)), float(np.sin(theta))
    affine = torch.tensor(
        [[cos_t, -sin_t, 0.0],
         [sin_t, cos_t, 0.0]],
        dtype=x.dtype,
        device=x.device,
    )
    affine = affine.repeat(x.shape[0], 1, 1)

    # ``align_corners=False`` is the library default; stating it explicitly
    # keeps the result unchanged and avoids the default-behaviour warning.
    grid = F.affine_grid(affine, x.size(), align_corners=False)
    out = F.grid_sample(x, grid, align_corners=False)
    return out.view(in_shape) if needs_restore else out


# Default rotation-group order: the value the layer classes in this file use
# for their `g` argument when the caller does not pass one. The rotation
# helpers above turn by 2*pi/g per step, so 4 means 90-degree steps.
g_order = 4


class R2Lift(nn.Module):
    """Convolution that expands a 4-D input into ``g`` rotation channels.

    One learnable kernel of shape ``(c2, c1, k, k)`` is resampled under ``g``
    relaxed rotations (one learnable 2x2 ``delta`` per rotation). The input is
    convolved with all of them and the result is returned with an explicit
    group axis: ``(batch, c2, g, H_out, W_out)``.

    Args:
        c1: Number of input channels.
        c2: Number of output channels per rotation.
        k: Spatial kernel size.
        s: Convolution stride.
        p: Convolution padding.
        g: Order of the rotation group.
    """

    def __init__(self, c1, c2, k, s, p, g=g_order):
        super().__init__()
        self.c1, self.c2 = c1, c2
        self.k, self.s, self.p = k, s, p
        self.g = g

        # Base kernel shared by all rotations.
        self.w = nn.Parameter(torch.empty(c2, c1, k, k))
        nn.init.kaiming_uniform_(self.w, a=(5 ** 0.5))

        # One 2x2 perturbation of the rotation matrix per group element,
        # drawn uniformly from [-sigma, sigma].
        self.delta = nn.Parameter(torch.empty(g, 2, 2))
        nn.init.uniform_(self.delta, a=-sigma, b=sigma)

    def build_filters(self):
        """Return the rotated filter bank of shape ``(c2 * g, c1, k, k)``."""
        bank = []
        for idx in range(self.g):
            # strictly_rot(self.w, idx, self.g) is the unperturbed alternative.
            bank.append(relaxed_rot(self.w, idx, self.g, self.delta[idx]))
        stacked = torch.stack(bank, dim=1)
        return stacked.view(self.c2 * self.g, self.c1, self.k, self.k)

    def forward(self, x):
        """Convolve ``x`` with the filter bank and split out the group axis."""
        weight = self.build_filters()
        y = F.conv2d(x, weight, bias=None, stride=self.s, padding=self.p)
        batch, height, width = y.shape[0], y.shape[-2], y.shape[-1]
        return y.view(batch, -1, self.g, height, width)


# Point-wise operator for R2GConv
class PR2GConv(nn.Module):
    """1x1 convolution over a tensor that carries a rotation-group axis.

    The learnable weight has shape ``(c2, c1, g, 1, 1)``. The full filter bank
    is formed from its ``g`` cyclic shifts along the group axis, and is applied
    as an ordinary 1x1 convolution to the input with channel and group axes
    merged.

    Args:
        c1: Number of input channels (per group element).
        c2: Number of output channels (per group element).
        g: Order of the rotation group.
    """

    def __init__(self, c1, c2, g=g_order):
        super().__init__()
        self.c1, self.c2 = c1, c2
        # Fixed point-wise geometry: 1x1 kernel, unit stride, no padding.
        self.k, self.s, self.p = 1, 1, 0
        self.g = g
        self.pw = nn.Parameter(torch.empty(c2, c1, g, 1, 1))
        nn.init.kaiming_uniform_(self.pw, a=(5 ** 0.5))

    def build_pw_filters(self):
        """Return the filter bank of shape ``(c2 * g, c1 * g, 1, 1)``."""
        # Shift number ``step`` rolls the weight by ``step`` along its group axis.
        shifted = [torch.roll(self.pw, step, dims=-3) for step in range(self.g)]
        bank = torch.stack(shifted, dim=1)
        return bank.view(self.c2 * self.g, self.c1 * self.g, 1, 1)

    def forward(self, x):
        """Apply the point-wise filters; the output is ``(batch, c2, g, H, W)``."""
        merged = x.view(x.shape[0], -1, x.shape[-2], x.shape[-1])
        y = F.conv2d(merged, self.build_pw_filters(), bias=None)
        batch, height, width = y.shape[0], y.shape[-2], y.shape[-1]
        return y.view(batch, -1, self.g, height, width)


# Depth-wise operator for R2GConv
class DR2GConv(nn.Module):
    """Depth-wise convolution over a tensor that carries a rotation-group axis.

    A learnable kernel of shape ``(c2, 1, k, k)`` is resampled under ``g``
    relaxed rotations (one learnable 2x2 ``delta`` per rotation), giving one
    single-channel filter per (channel, rotation) pair. The input's channel
    and group axes are merged and convolved with ``groups = c2 * g``.

    Args:
        c1: Number of input channels; must equal ``c2``.
        c2: Number of output channels.
        k: Spatial kernel size.
        s: Convolution stride.
        p: Convolution padding.
        g: Order of the rotation group.
    """

    def __init__(self, c1, c2, k, s, p, g=g_order):
        super().__init__()
        assert c1 == c2
        self.c1, self.c2 = c1, c2
        self.k, self.s, self.p = k, s, p
        self.g = g
        self.groups = c2 * g

        # Base depth-wise kernel shared by all rotations.
        self.dw = nn.Parameter(torch.empty(c2, 1, k, k))
        nn.init.kaiming_uniform_(self.dw, a=(5 ** 0.5))

        # One 2x2 perturbation of the rotation matrix per group element,
        # drawn uniformly from [-sigma, sigma].
        self.delta = nn.Parameter(torch.empty(g, 2, 2))
        nn.init.uniform_(self.delta, a=-sigma, b=sigma)

    def build_dw_filters(self):
        """Return the rotated filter bank of shape ``(c2 * g, 1, k, k)``."""
        bank = []
        for idx in range(self.g):
            # strictly_rot(self.dw, idx, self.g) is the unperturbed alternative.
            bank.append(relaxed_rot(self.dw, idx, self.g, self.delta[idx]))
        stacked = torch.stack(bank, dim=1)
        return stacked.view(self.c2 * self.g, 1, self.k, self.k)

    def forward(self, x):
        """Apply the depth-wise filters and split out the group axis again."""
        merged = x.view(x.shape[0], -1, x.shape[-2], x.shape[-1])
        y = F.conv2d(
            merged,
            self.build_dw_filters(),
            bias=None,
            stride=self.s,
            padding=self.p,
            groups=self.groups,
        )
        batch, height, width = y.shape[0], y.shape[-2], y.shape[-1]
        return y.view(batch, -1, self.g, height, width)


# Depth-wise operator for R2GUp, 2 x upsample with stride=2
class DR2GUp(nn.Module):
    """Depth-wise 2x upsampling over a tensor that carries a rotation-group axis.

    A learnable 2x2 kernel of shape ``(c2, 1, 2, 2)`` is resampled under ``g``
    relaxed rotations (one learnable 2x2 ``delta`` per rotation). The input's
    channel and group axes are merged and passed through a transposed
    convolution with stride 2, no padding and ``groups = c2 * g``.

    Args:
        c1: Number of input channels; must equal ``c2``.
        c2: Number of output channels.
        g: Order of the rotation group.
    """

    def __init__(self, c1, c2, g=g_order):
        super().__init__()
        assert c1 == c2
        self.c1, self.c2 = c1, c2
        # Fixed geometry: 2x2 kernel with stride 2 and no padding.
        self.k, self.s, self.p = 2, 2, 0
        self.g = g
        self.groups = c2 * g

        # Base depth-wise kernel shared by all rotations.
        self.dw = nn.Parameter(torch.empty(c2, 1, self.k, self.k))
        nn.init.kaiming_uniform_(self.dw, a=(5 ** 0.5))

        # One 2x2 perturbation of the rotation matrix per group element,
        # drawn uniformly from [-sigma, sigma].
        self.delta = nn.Parameter(torch.empty(g, 2, 2))
        nn.init.uniform_(self.delta, a=-sigma, b=sigma)

    def build_dw_filters(self):
        """Return the rotated filter bank of shape ``(c2 * g, 1, k, k)``."""
        bank = []
        for idx in range(self.g):
            # strictly_rot(self.dw, idx, self.g) is the unperturbed alternative.
            bank.append(relaxed_rot(self.dw, idx, self.g, self.delta[idx]))
        stacked = torch.stack(bank, dim=1)
        return stacked.view(self.c2 * self.g, 1, self.k, self.k)

    def forward(self, x):
        """Upsample ``x`` with a transposed convolution and split out the group axis."""
        merged = x.view(x.shape[0], -1, x.shape[-2], x.shape[-1])
        y = F.conv_transpose2d(
            merged,
            self.build_dw_filters(),
            bias=None,
            stride=self.s,
            padding=self.p,
            groups=self.groups,
        )
        batch, height, width = y.shape[0], y.shape[-2], y.shape[-1]
        return y.view(batch, -1, self.g, height, width)


# ER2GConv (Efficient R2GConv) = PR2GConv + DR2GConv, i.e., Point-wise + Depth-wise operators for R2GConv
class ER2GConv(torch.nn.Module):
    """Point-wise ``PR2GConv`` followed by depth-wise ``DR2GConv``.

    Args:
        c1: Number of input channels of the point-wise stage.
        c2: Number of output channels; also the channel count of the
            depth-wise stage.
        k: Kernel size of the depth-wise stage.
        s: Stride of the depth-wise stage.
        p: Padding of the depth-wise stage.
        g: Optional rotation-group order forwarded to both stages. When left
            as ``None`` each stage keeps its own default, which is the
            behaviour of the original five-argument constructor.
    """

    def __init__(self, c1, c2, k, s, p, g=None):
        super(ER2GConv, self).__init__()
        # Only pass ``g`` on when the caller chose one, so the default path is
        # exactly the original construction.
        group_kwargs = {} if g is None else {"g": g}
        self.conv = nn.Sequential(
            PR2GConv(c1, c2, **group_kwargs),
            DR2GConv(c2, c2, k, s, p, **group_kwargs),
        )

    def forward(self, x):
        """Run the point-wise stage, then the depth-wise stage, on ``x``."""
        return self.conv(x)


# ER2GUp = PR2GConv + DR2GUp, i.e., Point-wise + Depth-wise operators for R2GUp, 2 x upsample with stride = 2
class ER2GUp(torch.nn.Module):
    """Point-wise ``PR2GConv`` followed by the depth-wise upsampler ``DR2GUp``.

    Args:
        c1: Number of input channels of the point-wise stage.
        c2: Number of output channels; also the channel count of the
            upsampling stage.
        g: Optional rotation-group order forwarded to both stages. When left
            as ``None`` each stage keeps its own default, which is the
            behaviour of the original two-argument constructor.
    """

    def __init__(self, c1, c2, g=None):
        super(ER2GUp, self).__init__()
        # Only pass ``g`` on when the caller chose one, so the default path is
        # exactly the original construction.
        group_kwargs = {} if g is None else {"g": g}
        self.up = nn.Sequential(
            PR2GConv(c1, c2, **group_kwargs),
            DR2GUp(c2, c2, **group_kwargs),
        )

    def forward(self, x):
        """Run the point-wise stage, then the upsampling stage, on ``x``."""
        return self.up(x)
