from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from feature_extractor.cutie.cutie.model.channel_attn import CAResBlock


def interpolate_groups(g: torch.Tensor, ratio: float, mode: str,
                       align_corners: bool) -> torch.Tensor:
    batch_size, num_objects = g.shape[:2]
    g = F.interpolate(g.flatten(start_dim=0, end_dim=1),
                      scale_factor=ratio,
                      mode=mode,
                      align_corners=align_corners)
    g = g.view(batch_size, num_objects, *g.shape[1:])
    return g


def upsample_groups(g: torch.Tensor,
                    ratio: float = 2,
                    mode: str = 'bilinear',
                    align_corners: bool = False) -> torch.Tensor:
    return interpolate_groups(g, ratio, mode, align_corners)


def downsample_groups(g: torch.Tensor,
                      ratio: float = 1 / 2,
                      mode: str = 'area',
                      align_corners: bool = None) -> torch.Tensor:
    return interpolate_groups(g, ratio, mode, align_corners)


class GConv2d(nn.Conv2d):
    def forward(self, g: torch.Tensor) -> torch.Tensor:
        batch_size, num_objects = g.shape[:2]
        g = super().forward(g.flatten(start_dim=0, end_dim=1))
        return g.view(batch_size, num_objects, *g.shape[1:])


class GroupResBlock(nn.Module):
    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()

        if in_dim == out_dim:
            self.downsample = nn.Identity()
        else:
            self.downsample = GConv2d(in_dim, out_dim, kernel_size=1)

        self.conv1 = GConv2d(in_dim, out_dim, kernel_size=3, padding=1)
        self.conv2 = GConv2d(out_dim, out_dim, kernel_size=3, padding=1)

    def forward(self, g: torch.Tensor) -> torch.Tensor:
        out_g = self.conv1(F.relu(g))
        out_g = self.conv2(F.relu(out_g))

        g = self.downsample(g)

        return out_g + g


class MainToGroupDistributor(nn.Module):
    def __init__(self,
                 x_transform: Optional[nn.Module] = None,
                 g_transform: Optional[nn.Module] = None,
                 method: str = 'cat',
                 reverse_order: bool = False):
        super().__init__()

        self.x_transform = x_transform
        self.g_transform = g_transform
        self.method = method
        self.reverse_order = reverse_order

    def forward(self, x: torch.Tensor, g: torch.Tensor, skip_expand: bool = False) -> torch.Tensor:
        num_objects = g.shape[1]

        if self.x_transform is not None:
            x = self.x_transform(x)

        if self.g_transform is not None:
            g = self.g_transform(g)

        if not skip_expand:
            x = x.unsqueeze(1).expand(-1, num_objects, -1, -1, -1)
        if self.method == 'cat':
            if self.reverse_order:
                g = torch.cat([g, x], 2)
            else:
                g = torch.cat([x, g], 2)
        elif self.method == 'add':
            g = x + g
        elif self.method == 'mulcat':
            g = torch.cat([x * g, g], dim=2)
        elif self.method == 'muladd':
            g = x * g + g
        else:
            raise NotImplementedError

        return g


class GroupFeatureFusionBlock(nn.Module):
    def __init__(self, x_in_dim: int, g_in_dim: int, out_dim: int):
        super().__init__()

        x_transform = nn.Conv2d(x_in_dim, out_dim, kernel_size=1)
        g_transform = GConv2d(g_in_dim, out_dim, kernel_size=1)

        self.distributor = MainToGroupDistributor(x_transform=x_transform,
                                                  g_transform=g_transform,
                                                  method='add')
        self.block1 = CAResBlock(out_dim, out_dim)
        self.block2 = CAResBlock(out_dim, out_dim)

    def forward(self, x: torch.Tensor, g: torch.Tensor) -> torch.Tensor:
        batch_size, num_objects = g.shape[:2]

        g = self.distributor(x, g)

        g = g.flatten(start_dim=0, end_dim=1)

        g = self.block1(g)
        g = self.block2(g)

        g = g.view(batch_size, num_objects, *g.shape[1:])

        return g