import torch
import torch.nn as nn
import torch.nn.functional as F


def point_sample(input, point_coords, **kwargs):
    """
    Sample `input` at normalized point locations (adapted from Detectron2, point_features.py).

    Thin wrapper around :function:`torch.nn.functional.grid_sample` that additionally accepts
    a 3D list of points and uses [0, 1] x [0, 1] coordinates instead of grid_sample's
    [-1, 1] x [-1, 1] convention.
    Args:
        input (Tensor): feature map of shape (N, C, H, W).
        point_coords (Tensor): coordinates in [0, 1] x [0, 1], shaped either (N, P, 2) or
            (N, Hgrid, Wgrid, 2).
        **kwargs: forwarded unchanged to `grid_sample` (e.g. `align_corners`, `mode`).
    Returns:
        output (Tensor): sampled features, (N, C, P) for 3D `point_coords` and
            (N, C, Hgrid, Wgrid) for 4D `point_coords`.
    """
    is_point_list = point_coords.dim() == 3
    # grid_sample only understands 4D grids, so a point list becomes a (P x 1) grid.
    grid = point_coords.unsqueeze(2) if is_point_list else point_coords
    # Rescale [0, 1] coordinates to the [-1, 1] range grid_sample expects.
    sampled = F.grid_sample(input, 2.0 * grid - 1.0, **kwargs)
    # Drop the dummy grid axis again so a point list comes back as (N, C, P).
    return sampled.squeeze(3) if is_point_list else sampled


def _point_uncertainty(scores):
    """Return per-location uncertainty (higher = more uncertain) from class scores.

    `scores` has the class axis at dim 1. With two or more classes it must already be
    sorted in descending order along that axis; the uncertainty is then the negated margin
    between the best and second-best score. With a single channel there is no second
    class, so the negated absolute value is used instead (the value is assumed to be a
    logit whose decision boundary is 0, as in Detectron2's class-agnostic uncertainty).
    """
    if scores.shape[1] == 1:
        return -scores[:, 0].abs()
    return -(scores[:, 0] - scores[:, 1])


@torch.no_grad()
def sampling_points(mask, N, k=3, beta=0.75, training=True):
    """
    Follows 3.1. Point Selection for Inference and Training
    In Train:, `The sampling strategy selects N points on a feature map to train on.`
    In Inference, `then selects the N most uncertain points`
    Args:
        mask(Tensor): [B, C, H, W] class scores. C may be 1 (binary / class-agnostic
            prediction); in that case uncertainty is `-|score|` instead of the top-2 margin.
        N(int): `During training we sample as many points as there are on a stride 16 feature map of the input`
        k(int): Over generation multiplier
        beta(float): ratio of importance points
        training(bool): flag
    Return:
        training=True:
            points(Tensor): [B, N, 2] normalized (x, y) coordinates in [0, 1]. The first
            int(beta * N) are the most uncertain of k * N random candidates, the rest are
            uniformly random coverage points.
        training=False:
            (idx, points): `idx` is [B, N'] flat indices (row * W + col) of the N' most
            uncertain cells of `mask`, `points` is [B, N', 2] with the normalized (x, y)
            centers of those cells. N' = min(H * W, N).
    """
    assert mask.dim() == 4, "Dim must be N(Batch)CHW"
    device = mask.device
    B, _, H, W = mask.shape
    # Sort class scores per location so channels 0 and 1 hold the two best scores.
    mask, _ = mask.sort(1, descending=True)

    if not training:
        H_step, W_step = 1 / H, 1 / W
        N = min(H * W, N)
        uncertainty_map = _point_uncertainty(mask)
        _, idx = uncertainty_map.reshape(B, -1).topk(N, dim=1)

        # Convert flat indices to normalized cell-center coordinates, x first.
        xs = W_step / 2.0 + (idx % W).to(torch.float) * W_step
        ys = H_step / 2.0 + (idx // W).to(torch.float) * H_step
        return idx, torch.stack([xs, ys], dim=-1)

    # Official Comment : point_features.py#92
    # It is crucial to calculate uncertainty based on the sampled prediction value for the points.
    # Calculating uncertainties of the coarse predictions first and sampling them for points leads
    # to worse results. To illustrate the difference: a sampled point between two coarse predictions
    # with -1 and 1 logits has 0 logit prediction and therefore 0 uncertainty value, however, if one
    # calculates uncertainties for the coarse predictions first (-1 and -1) and samples it for the
    # center point, they will get -1 uncertainty.

    num_important = int(beta * N)

    # Over-generate k * N random candidates and score them on the interpolated prediction.
    over_generation = torch.rand(B, k * N, 2, device=device)
    over_generation_map = point_sample(mask, over_generation, align_corners=False)

    uncertainty_map = _point_uncertainty(over_generation_map)
    _, idx = uncertainty_map.topk(num_important, -1)

    # Pick the (x, y) pairs of the most uncertain candidates for every batch element.
    importance = over_generation.gather(1, idx.unsqueeze(-1).expand(-1, -1, 2))
    # Fill the remaining budget with uniformly random coverage points.
    coverage = torch.rand(B, N - num_important, 2, device=device)
    return torch.cat([importance, coverage], 1)

class dilated_conv_block(nn.Module):
    """
    Two stacked (Conv2d -> BatchNorm2d -> ReLU) stages sharing kernel size and dilation.

    The first convolution maps `in_ch` to `out_ch` channels, the second keeps `out_ch`.
    Padding is `kernel_size // 2 * dilation` with stride 1, which keeps the spatial size
    unchanged for odd kernel sizes.
    """

    def __init__(self, in_ch, out_ch, kernel_size=3,dilation=1):
        super().__init__()

        pad = (kernel_size // 2) * dilation
        layers = []
        # Stage 1 consumes the input channels, stage 2 consumes stage 1's output.
        for stage_in in (in_ch, out_ch):
            layers.append(nn.Conv2d(stage_in, out_ch, kernel_size=kernel_size,
                                    stride=1, padding=pad, bias=True,
                                    dilation=dilation))
            layers.append(nn.BatchNorm2d(out_ch))
            layers.append(nn.ReLU(inplace=True))
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv(x)


class Dilated_UNet(nn.Module):
    """
    Dilated convolutional encoder with a 1x1 prediction layer.

    Named after UNet (paper : https://arxiv.org/abs/1505.04597), but only an encoder is
    built here: three conv blocks separated by two 2x2 max-pools, followed by two dilated
    blocks (dilation 2 and 4) that run without further pooling, and a 1x1 convolution
    producing `out_ch` channels.

    `forward` returns `(e3, out)`: the output of the third block (`fine_dim` channels)
    and the prediction computed from the last block. Both come after the two max-pools.
    """
    def __init__(self, in_ch=3, out_ch=1, kernel_size=3):
        super().__init__()

        # Channel widths of the five encoder stages: 64, 128, 256, 512, 1024.
        c1, c2, c3, c4, c5 = (64 * mult for mult in (1, 2, 4, 8, 16))
        self.ks = kernel_size
        self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

        self.Conv1 = dilated_conv_block(in_ch, c1, kernel_size=self.ks)
        self.Conv2 = dilated_conv_block(c1, c2, kernel_size=self.ks)
        self.Conv3 = dilated_conv_block(c2, c3, kernel_size=self.ks)
        # The last two stages widen the receptive field by dilation instead of pooling.
        self.Conv4 = dilated_conv_block(c3, c4, kernel_size=self.ks, dilation=2)
        self.Conv5 = dilated_conv_block(c4, c5, kernel_size=self.ks, dilation=4)

        # 1x1 convolution mapping the deepest features to the output channels.
        self.Conv = nn.Conv2d(c5, out_ch, kernel_size=1)

        # Channel count of the stage-3 feature map returned by forward().
        self.fine_dim = c3

    def forward(self, x):
        stage1 = self.Conv1(x)
        stage2 = self.Conv2(self.Maxpool(stage1))
        stage3 = self.Conv3(self.Maxpool(stage2))

        deepest = self.Conv5(self.Conv4(stage3))
        prediction = self.Conv(deepest)

        return stage3, prediction



class Dilated_UNet_test(nn.Module):
    """
    Dilated convolutional encoder that predicts directly at the upsampled resolution.

    Named after UNet (paper : https://arxiv.org/abs/1505.04597). Uses the same five conv
    blocks as `Dilated_UNet` (two 2x2 max-pools, then dilation 2 and 4), but the 1x1
    prediction layer sees the stage-3 and stage-5 features concatenated along the channel
    axis, and the result is bilinearly upsampled by a factor of 4. `forward` returns only
    that upsampled prediction.
    """
    def __init__(self, in_ch=3, out_ch=1, kernel_size=3):
        super().__init__()

        # Channel widths of the five encoder stages: 64, 128, 256, 512, 1024.
        c1, c2, c3, c4, c5 = (64 * mult for mult in (1, 2, 4, 8, 16))
        self.ks = kernel_size
        self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

        self.Conv1 = dilated_conv_block(in_ch, c1, kernel_size=self.ks)
        self.Conv2 = dilated_conv_block(c1, c2, kernel_size=self.ks)
        self.Conv3 = dilated_conv_block(c2, c3, kernel_size=self.ks)
        # The last two stages widen the receptive field by dilation instead of pooling.
        self.Conv4 = dilated_conv_block(c3, c4, kernel_size=self.ks, dilation=2)
        self.Conv5 = dilated_conv_block(c4, c5, kernel_size=self.ks, dilation=4)

        # The 1x1 head consumes stage-5 and stage-3 features concatenated together.
        self.Conv = nn.Conv2d(c5 + c3, out_ch, kernel_size=1)
        # Undo the two 2x max-pools.
        self._up = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True)
        # Channel count of the stage-3 feature map.
        self.fine_dim = c3

    def forward(self, x):
        stage1 = self.Conv1(x)
        stage2 = self.Conv2(self.Maxpool(stage1))
        stage3 = self.Conv3(self.Maxpool(stage2))

        deepest = self.Conv5(self.Conv4(stage3))

        fused = torch.cat([stage3, deepest], dim=1)
        return self._up(self.Conv(fused))



class PointHead(nn.Module):
    """
    PointRend head: re-predicts class scores at selected points with a shared 1x1 MLP.

    Args:
        in_c(int): channels of the per-point feature vector, i.e. fine-feature channels
            plus coarse-prediction channels (they are concatenated before the MLP).
        num_classes(int): number of output channels of the point prediction.
        k(int): over-generation multiplier passed to `sampling_points` during training.
        beta(float): ratio of importance points passed to `sampling_points` during training.
        num_points(int): points refined per subdivision step at inference.
    """

    def __init__(self, in_c=512+21, num_classes=21, k=3, beta=0.75, num_points=1024):
        super().__init__()
        self.mlp = nn.Conv1d(in_c, num_classes, 1)
        self.k = k
        self.beta = beta
        self.num_points = num_points

    def forward(self, x, res2, coarse_pred):
        """
        1. Fine-grained features are interpolated from res2 for DeeplabV3
        2. During training we sample as many points as there are on a stride 16 feature map of the input
        3. To measure prediction uncertainty
           we use the same strategy during training and inference: the difference between the most
           confident and second most confident class probabilities.

        Args:
            x(Tensor): network input; only its last two (spatial) dimensions are used.
            res2(Tensor): [B, C_fine, h, w] fine-grained feature map.
            coarse_pred(Tensor): [B, num_classes, h', w'] coarse prediction.
        Returns:
            Training mode: {"rend": [B, num_classes, N] point predictions,
            "points": [B, N, 2] normalized coordinates they were computed at}.
            Eval mode: the result of `inference`.
        """
        if not self.training:
            return self.inference(x, res2, coarse_pred)

        # Number of cells of a stride-16 map of the input: (H / 16) * (W / 16).
        # Using the width alone would sample far fewer points than documented.
        num_train_points = (x.shape[-2] // 16) * (x.shape[-1] // 16)
        points = sampling_points(coarse_pred, num_train_points, self.k, self.beta)

        coarse = point_sample(coarse_pred, points, align_corners=False)
        fine = point_sample(res2, points, align_corners=False)

        feature_representation = torch.cat([coarse, fine], dim=1)

        rend = self.mlp(feature_representation)

        return {"rend": rend, "points": points}

    @torch.no_grad()
    def inference(self, x, res2, coarse_pred):
        """
        Iteratively upsample `coarse_pred` to the spatial size of `x`, re-predicting the
        `self.num_points` most uncertain locations after every upsampling step.

        The paper describes subdivision with N=8096
        (i.e., the number of points in the stride 16 map of a 1024×2048 image);
        this implementation defaults to 1024 points per step.

        Each step doubles height and width but never overshoots the size of `x`, so the
        loop also terminates when the input size is not an exact power-of-two multiple of
        the coarse map. If `coarse_pred` already has the size of `x` it is returned as is.

        Returns:
            {"fine": [B, num_classes, H, W] refined prediction}
        """
        target_h, target_w = x.shape[-2:]

        while coarse_pred.shape[-2:] != x.shape[-2:]:
            cur_h, cur_w = coarse_pred.shape[-2:]
            # Double the resolution, clamped to the target so the loop always ends.
            step_size = (min(2 * cur_h, target_h), min(2 * cur_w, target_w))
            coarse_pred = F.interpolate(coarse_pred, size=step_size, mode="bilinear", align_corners=True)

            # Always use the inference branch here, regardless of the module's mode.
            points_idx, points = sampling_points(coarse_pred, self.num_points, training=False)

            coarse = point_sample(coarse_pred, points, align_corners=False)
            fine = point_sample(res2, points, align_corners=False)

            feature_representation = torch.cat([coarse, fine], dim=1)

            rend = self.mlp(feature_representation)

            # Write the refined scores back into the flattened prediction map.
            B, C, H, W = coarse_pred.shape
            points_idx = points_idx.unsqueeze(1).expand(-1, C, -1)
            coarse_pred = (coarse_pred.reshape(B, C, -1)
                      .scatter_(2, points_idx, rend)
                      .view(B, C, H, W))

        return {"fine": coarse_pred}

    

class PointRend(nn.Module):
    """
    Full model: `Dilated_UNet` backbone followed by a `PointHead`.

    `forward` returns a dict holding the backbone outputs under "res2" (fine features)
    and "coarse" (coarse prediction), plus whatever keys the head returns.
    """
    def __init__(self,in_ch=3, out_ch=1):
        super().__init__()
        self.backbone = Dilated_UNet(in_ch,out_ch)
        # The head sees fine features concatenated with the coarse prediction.
        point_feature_dim = self.backbone.fine_dim + out_ch
        self.head = PointHead(in_c=point_feature_dim, num_classes=out_ch, k=3, beta=0.75)

    def forward(self, x):
        fine_features, coarse_pred = self.backbone(x)
        result = {'res2': fine_features, 'coarse': coarse_pred}
        result.update(self.head(x, fine_features, coarse_pred))
        return result


if __name__=='__main__':
    # Smoke test: run one forward pass and print the shape of every output.
    # Use the GPU when one is available, otherwise fall back to the CPU so the
    # script also runs on machines without CUDA.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    x = torch.randn(2, 3, 256, 256).to(device)
    net = PointRend(in_ch=3, out_ch=5).to(device)
    result = net(x)
    for k, v in result.items():
        print(k, v.shape)