import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F


class decoder_default:
    """Soft-argmax decoder turning heatmaps into normalised (x, y) coordinates.

    Args:
        weight: multiplier applied to the heatmap before decoding when
            ``use_weight_map`` is true (a scalar, or anything that broadcasts
            against a ``batch x npoints x h x w`` heatmap).
        use_weight_map: whether to multiply the heatmap by ``weight``.
    """

    def __init__(self, weight=1, use_weight_map=False):
        self.weight = weight
        self.use_weight_map = use_weight_map

    def _make_grid(self, h, w):
        """Return ``(yy, xx)``, each of shape ``(h, w)``, with values in [-1, 1].

        ``yy`` varies along rows and ``xx`` along columns ('ij' layout).
        ``h`` and ``w`` must be greater than 1; a size of 1 yields NaN (0 / 0).
        """
        ys = torch.arange(h).float() / (h - 1) * 2 - 1
        xs = torch.arange(w).float() / (w - 1) * 2 - 1
        # Broadcasting reproduces torch.meshgrid(ys, xs) with 'ij' indexing
        # without depending on meshgrid's implicit `indexing` default, which
        # newer PyTorch versions warn about.
        yy = ys.view(h, 1).expand(h, w)
        xx = xs.view(1, w).expand(h, w)
        return yy, xx

    def get_coords_from_heatmap(self, heatmap):
        """Compute the heatmap-weighted mean position of every channel.

            inputs:
            - heatmap: batch x npoints x h x w

            outputs:
            - coords: batch x npoints x 2 (x, y). For non-negative heatmaps the
              values lie in [-1, +1]; x follows the width (last) axis and y the
              height axis. An all-zero channel decodes to (0, 0).
        """
        _, _, h, w = heatmap.shape
        if self.use_weight_map:
            heatmap = heatmap * self.weight

        yy, xx = self._make_grid(h, w)
        # Match the heatmap's dtype and device.
        yy = yy.view(1, 1, h, w).to(heatmap)
        xx = xx.view(1, 1, h, w).to(heatmap)

        # Clamp so an all-zero channel does not divide by zero.
        heatmap_sum = torch.clamp(heatmap.sum([2, 3]), min=1e-6)

        yy_coord = (yy * heatmap).sum([2, 3]) / heatmap_sum  # batch x npoints
        xx_coord = (xx * heatmap).sum([2, 3]) / heatmap_sum  # batch x npoints
        coords = torch.stack([xx_coord, yy_coord], dim=-1)

        return coords


class AddCoordsTh(nn.Module):
    """Append normalised coordinate channels (CoordConv-style) to a feature map.

    The coordinate ramps are built from the ``x_dim``/``y_dim`` given at
    construction time, not from the input's shape, so the input's spatial size
    must be ``(x_dim, y_dim)`` for the concatenation to succeed.

    Args:
        x_dim: size of the input along dim 2.
        y_dim: size of the input along dim 3.
        with_r: also append a radius channel ``sqrt(xx^2 + yy^2)`` scaled so
            that its maximum is 1.
        with_boundary: when ``forward`` receives a ``heatmap``, also append the
            xx/yy channels zeroed everywhere except where the heatmap's last
            channel exceeds 0.05.
    """

    def __init__(self, x_dim, y_dim, with_r=False, with_boundary=False):
        super(AddCoordsTh, self).__init__()
        self.x_dim = x_dim
        self.y_dim = y_dim
        self.with_r = with_r
        self.with_boundary = with_boundary

    def forward(self, input_tensor, heatmap=None):
        """
        input_tensor: (batch, c, x_dim, y_dim)
        heatmap: optional; only its last channel is used, and it must broadcast
            against (batch, 1, x_dim, y_dim).

        Returns ``input_tensor`` concatenated along dim 1 with, in order:
        xx, yy, [rr if with_r], [xx_boundary, yy_boundary if with_boundary and
        a heatmap is given]. The coordinate channels are float32.
        """
        batch_size = input_tensor.shape[0]
        device = input_tensor.device

        # Index ramps mapped linearly onto [-1, 1] (NaN if a dim is 1).
        x_ramp = torch.arange(self.x_dim, device=device).float() / (self.x_dim - 1) * 2 - 1
        y_ramp = torch.arange(self.y_dim, device=device).float() / (self.y_dim - 1) * 2 - 1

        # Both channels are (batch, 1, x_dim, y_dim). xx varies along dim 2 and
        # yy along dim 3.
        grid_shape = (batch_size, 1, self.x_dim, self.y_dim)
        xx_channel = x_ramp.view(1, 1, self.x_dim, 1).expand(grid_shape)
        yy_channel = y_ramp.view(1, 1, 1, self.y_dim).expand(grid_shape)

        pieces = [input_tensor, xx_channel, yy_channel]

        if self.with_r:
            rr = torch.sqrt(torch.pow(xx_channel, 2) + torch.pow(yy_channel, 2))
            pieces.append(rr / torch.max(rr))

        if self.with_boundary and heatmap is not None:
            boundary_mask = torch.clamp(heatmap[:, -1:, :, :], 0.0, 1.0) > 0.05
            zeros = torch.zeros_like(xx_channel)
            pieces.append(torch.where(boundary_mask, xx_channel, zeros))
            pieces.append(torch.where(boundary_mask, yy_channel, zeros))

        return torch.cat(pieces, dim=1)


class CoordConvTh(nn.Module):
    """CoordConv layer: ``AddCoordsTh`` followed by ``nn.Conv2d``, with optional
    BatchNorm and ReLU (applied in that order).

    Args:
        x_dim, y_dim: spatial size the coordinate grid is built for.
        with_r: add the radius channel (one extra input channel).
        with_boundary: enable the two boundary channels. They are counted in
            the conv's input channels only when ``first_one`` is false.
        in_channels: channels of the incoming feature map (before coords).
        out_channels: conv output channels.
        first_one: when true together with ``with_boundary``, ``forward``
            expects no heatmap and no boundary channels are added.
        relu, bn: append ReLU / BatchNorm2d after the conv.
        *args, **kwargs: forwarded to ``nn.Conv2d`` after the channel counts
            (e.g. ``kernel_size``, ``stride``, ``padding``).
    """

    def __init__(self, x_dim, y_dim, with_r, with_boundary,
                 in_channels, out_channels, first_one=False, relu=False, bn=False, *args, **kwargs):
        super(CoordConvTh, self).__init__()
        self.addcoords = AddCoordsTh(x_dim=x_dim, y_dim=y_dim, with_r=with_r,
                                     with_boundary=with_boundary)
        in_channels += 2  # xx and yy
        if with_r:
            in_channels += 1
        if with_boundary and not first_one:
            in_channels += 2
        # The channel counts are passed positionally so that extra positional
        # *args (kernel_size, stride, ...) line up after them. With keyword
        # channel counts, *args[0] would be bound to in_channels and raise TypeError.
        self.conv = nn.Conv2d(in_channels, out_channels, *args, **kwargs)
        self.relu = nn.ReLU() if relu else None
        self.bn = nn.BatchNorm2d(out_channels) if bn else None

        self.with_boundary = with_boundary
        self.first_one = first_one

    def forward(self, input_tensor, heatmap=None):
        """Add coordinate channels, then apply conv, optional BN, optional ReLU.

        A heatmap must be given exactly when ``with_boundary`` is set and this
        is not the ``first_one`` layer.
        """
        expects_heatmap = self.with_boundary and not self.first_one
        assert expects_heatmap == (heatmap is not None), (
            f"CoordConvTh(with_boundary={self.with_boundary}, first_one={self.first_one}) "
            f"{'requires' if expects_heatmap else 'does not accept'} a heatmap"
        )
        ret = self.addcoords(input_tensor, heatmap)
        ret = self.conv(ret)
        if self.bn is not None:
            ret = self.bn(ret)
        if self.relu is not None:
            ret = self.relu(ret)

        return ret


class Activation(nn.Module):
    """Optional normalisation followed by an optional activation.

    ``kind`` is either ``'<act>'`` or ``'<norm>+<act>'``, where:

    * ``<norm>``: ``'in'`` (``F.instance_norm`` with no weight/bias), ``'bn'``,
      ``'bn_noaffine'`` or ``'none'``.
    * ``<act>``: ``'relu'``, ``'softplus'``, ``'exp'``, ``'sigmoid'``,
      ``'tanh'`` or ``'none'``.

    Unknown names raise ``KeyError``. ``channel`` is only needed by the
    batch-norm variants.
    """

    def __init__(self, kind: str = 'relu', channel=None):
        super().__init__()
        self.kind = kind

        if '+' in kind:
            norm_str, act_str = kind.split('+')
        else:
            norm_str, act_str = 'none', kind

        # Use factories so that only the selected normalisation is constructed.
        # Building every entry eagerly would call nn.BatchNorm2d(channel) even
        # for 'in'/'none', which fails when channel is None
        # (e.g. Activation('relu')).
        norm_factories = {
            'in': lambda: F.instance_norm,
            'bn': lambda: nn.BatchNorm2d(channel),
            'bn_noaffine': lambda: nn.BatchNorm2d(channel, affine=False, track_running_stats=True),
            'none': lambda: None,
        }
        self.norm_fn = norm_factories[norm_str]()

        self.act_fn = {
            'relu': F.relu,
            'softplus': nn.Softplus(),
            'exp': torch.exp,
            'sigmoid': torch.sigmoid,
            'tanh': torch.tanh,
            'none': None
        }[act_str]

        self.channel = channel

    def forward(self, x):
        """Apply the configured normalisation, then the activation."""
        if self.norm_fn is not None:
            x = self.norm_fn(x)
        if self.act_fn is not None:
            x = self.act_fn(x)
        return x

    def extra_repr(self):
        return f'kind={self.kind}, channel={self.channel}'


class ConvBlock(nn.Module):
    """Conv2d (with bias) optionally followed by BatchNorm2d and then ReLU.

    Padding is ``(kernel_size - 1) // 2``, which keeps the spatial size at
    stride 1 for odd kernel sizes.
    """

    def __init__(self, inp_dim, out_dim, kernel_size=3, stride=1, bn=False, relu=True, groups=1):
        super(ConvBlock, self).__init__()
        self.inp_dim = inp_dim
        pad = (kernel_size - 1) // 2
        self.conv = nn.Conv2d(inp_dim, out_dim, kernel_size, stride,
                              padding=pad, groups=groups, bias=True)
        # Disabled stages are stored as None and skipped in forward.
        self.relu = nn.ReLU() if relu else None
        self.bn = nn.BatchNorm2d(out_dim) if bn else None

    def forward(self, x):
        """Run conv, then BN (if any), then ReLU (if any)."""
        out = self.conv(x)
        for stage in (self.bn, self.relu):
            if stage is not None:
                out = stage(out)
        return out


class ResBlock(nn.Module):
    """Residual block whose main path is three BN -> ReLU -> conv stages.

    The stages are a 1x1 conv to ``mid_dim``, a 3x3 conv, and a 1x1 conv to
    ``out_dim``. ``mid_dim`` defaults to ``out_dim // 2``. The shortcut is the
    identity when ``inp_dim == out_dim`` and a 1x1 ``ConvBlock`` otherwise.

    ``skip_layer`` is always constructed (so it appears in ``state_dict``),
    even when the identity shortcut is used.
    """

    def __init__(self, inp_dim, out_dim, mid_dim=None):
        super(ResBlock, self).__init__()
        hidden = out_dim // 2 if mid_dim is None else mid_dim
        # Module creation order is kept so parameter init and state_dict
        # layout are unchanged.
        self.relu = nn.ReLU()
        self.bn1 = nn.BatchNorm2d(inp_dim)
        self.conv1 = ConvBlock(inp_dim, hidden, 1, relu=False)
        self.bn2 = nn.BatchNorm2d(hidden)
        self.conv2 = ConvBlock(hidden, hidden, 3, relu=False)
        self.bn3 = nn.BatchNorm2d(hidden)
        self.conv3 = ConvBlock(hidden, out_dim, 1, relu=False)
        self.skip_layer = ConvBlock(inp_dim, out_dim, 1, relu=False)
        self.need_skip = inp_dim != out_dim

    def forward(self, x):
        """Return main_path(x) + shortcut(x)."""
        shortcut = self.skip_layer(x) if self.need_skip else x
        out = x
        for norm, conv in ((self.bn1, self.conv1),
                           (self.bn2, self.conv2),
                           (self.bn3, self.conv3)):
            out = conv(self.relu(norm(out)))
        out += shortcut
        return out


class Hourglass(nn.Module):
    """Recursive hourglass module.

    The output is the sum of two branches:

    * a full-resolution ResBlock branch (``up1``), and
    * a lower branch: 2x2 max-pool -> ResBlock -> (nested hourglass, or a
      ResBlock when ``n == 1``) -> ResBlock -> upsample by 2.

    Args:
        n: recursion depth.
        f: channel count at this level (input and output).
        increase: extra channels inside the lower branch (``f + increase``).
            The same increase is passed to the nested hourglass.
        up_mode: ``mode`` for the ``nn.Upsample`` that restores resolution.
        add_coord: prepend a 1x1 ``CoordConvTh`` (radius and boundary
            channels) at this level only. Nested levels are built without one.
        first_one: forwarded to ``CoordConvTh``. With ``with_boundary=True``
            it makes that layer expect no heatmap.
        x_dim, y_dim: spatial size the ``CoordConvTh`` grid is built for.
    """

    def __init__(self, n, f, increase=0, up_mode='nearest', add_coord=False, first_one=False, x_dim=64, y_dim=64):
        super(Hourglass, self).__init__()
        inner_f = f + increase

        # Creation order is preserved: coordconv, up1, pool1, low1, low2, low3, up2.
        self.coordconv = CoordConvTh(
            x_dim=x_dim, y_dim=y_dim,
            with_r=True, with_boundary=True,
            relu=False, bn=False,
            in_channels=f, out_channels=f,
            first_one=first_one,
            kernel_size=1, stride=1, padding=0,
        ) if add_coord else None
        self.up1 = ResBlock(f, f)

        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.low1 = ResBlock(f, inner_f)
        self.n = n
        if n <= 1:
            self.low2 = ResBlock(inner_f, inner_f)
        else:
            self.low2 = Hourglass(n=n - 1, f=inner_f, increase=increase, up_mode=up_mode, add_coord=False)
        self.low3 = ResBlock(inner_f, f)
        self.up2 = nn.Upsample(scale_factor=2, mode=up_mode)

    def forward(self, x, heatmap=None):
        """Apply the optional CoordConv, then sum the skip and lower branches."""
        if self.coordconv is not None:
            x = self.coordconv(x, heatmap)
        skip = self.up1(x)
        deep = self.low3(self.low2(self.low1(self.pool1(x))))
        return skip + self.up2(deep)


class E2HTransform(nn.Module):
    """Fixed (buffer-only) 1x1 convolution mapping edge maps to point heatmaps.

    Output channel ``p`` is the sum of the input channels of every edge that
    lists point ``p``. Points listed by no edge get a constant bias of 1.
    Otherwise their output would always be 0.

    Args:
        edge_info: iterable of ``(is_closed, point_indices)`` pairs, one per
            edge. Only ``point_indices`` is used here.
        num_points: number of output channels.
        num_edges: number of input channels.
    """

    def __init__(self, edge_info, num_points, num_edges):
        super().__init__()

        # incidence[p, e] == 1 iff edge e lists point p.
        incidence = np.zeros([num_points, num_edges])
        for edge_id, (_, point_ids) in enumerate(edge_info):
            for point_id in point_ids:
                incidence[point_id, edge_id] = 1
        incidence = torch.from_numpy(incidence).float()

        # Conv weight layout: (num_points, num_edges, 1, 1).
        self.register_buffer('weight', incidence.view(*incidence.shape, 1, 1))

        # 1-D bias of length num_points: 1.0 for points no edge covers, else 0.0.
        uncovered = (incidence.sum(dim=1) < 0.5).float()
        self.register_buffer('bias', uncovered)

    def forward(self, edgemaps):
        """Map (batch, num_edges, H, W) edge maps to (batch, num_points, H, W)."""
        return F.conv2d(edgemaps, self.weight, self.bias)


class StylizedFacePoint(nn.Module):
    """Stacked-hourglass landmark network with five heatmap heads per stack.

    Args:
        args: config object. ``args.nstack`` gives the number of hourglass
            stacks and ``args.input_size`` the input side length (assumed
            square and divisible by 16 — TODO confirm against callers).
        device: stored as ``self.device``.
        nlevels: recursion depth of each hourglass.
        in_channel: feature width inside the hourglasses.
        increase: per-level channel increase passed to ``Hourglass``.
    """

    def __init__(self, args, device='cpu', nlevels=4, in_channel=512, increase=0):
        super(StylizedFacePoint, self).__init__()

        self.classes_num = [98, 9, 98]
        self.add_coord = True
        self.decoder = decoder_default()
        self.nstack = args.nstack
        self.device = device

        self.num_heats = self.classes_num[0]

        if self.add_coord:
            convBlock_1 = CoordConvTh(x_dim=args.input_size, y_dim=args.input_size,
                                      with_r=True, with_boundary=False,
                                      relu=True, bn=True,
                                      in_channels=3, out_channels=64,
                                      kernel_size=7,
                                      stride=2, padding=3)
        else:
            convBlock_1 = ConvBlock(3, 64, 7, 2, bn=True, relu=True)

        if self.add_coord:
            convBlock_2 = CoordConvTh(x_dim=int(args.input_size/4), y_dim=int(args.input_size/4),
                                      with_r=True, with_boundary=False,
                                      relu=True, bn=True,
                                      in_channels=128, out_channels=128,
                                      kernel_size=7,
                                      stride=2, padding=3)
        else:
            convBlock_2 = ConvBlock(128, 128, 7, 2, bn=True, relu=True)

        pool = nn.MaxPool2d(kernel_size=2, stride=2)

        Block = ResBlock

        self.pre = nn.Sequential(
            convBlock_1,
            Block(64, 128),
            pool,
            convBlock_2,
            Block(128, 256),
            pool,
            Block(256, 256),
            Block(256, in_channel)
        )

        # `pre` has two stride-2 convs and two 2x2 max-pools, so for input
        # sizes divisible by 16 its output is input_size / 16 on each side. The
        # hourglass CoordConv grids must match that size (16 for 256 inputs).
        hg_dim = int(args.input_size / 16)
        self.hgs = nn.ModuleList(
            [Hourglass(n=nlevels, f=in_channel, increase=increase, add_coord=self.add_coord,
                       first_one=(i == 0), x_dim=hg_dim, y_dim=hg_dim)
             for i in range(self.nstack)])

        self.features = nn.ModuleList([
            nn.Sequential(
                Block(in_channel, in_channel),
                ConvBlock(in_channel, in_channel, 1, bn=True, relu=True)
            ) for _ in range(self.nstack)])

        self.out_heatmaps = nn.ModuleList(
            [ConvBlock(in_channel, self.num_heats, 1, relu=False, bn=False)
             for _ in range(self.nstack)])
        self.x_out_heatmaps = nn.ModuleList(
            [ConvBlock(in_channel, self.num_heats, 1, relu=False, bn=False)
             for _ in range(self.nstack)])
        self.y_out_heatmaps = nn.ModuleList(
            [ConvBlock(in_channel, self.num_heats, 1, relu=False, bn=False)
             for _ in range(self.nstack)])
        self.nb_x_out_heatmaps = nn.ModuleList(
            [ConvBlock(in_channel, self.num_heats * 3, 1, relu=False, bn=False)
             for _ in range(self.nstack)])
        self.nb_y_out_heatmaps = nn.ModuleList(
            [ConvBlock(in_channel, self.num_heats * 3, 1, relu=False, bn=False)
             for _ in range(self.nstack)])

        self.merge_features = nn.ModuleList(
            [ConvBlock(in_channel, in_channel, 1, relu=False, bn=False)
             for _ in range(self.nstack - 1)])
        # 9 = 1 + 1 + 1 + 3 + 3 head outputs, each a multiple of num_heats.
        self.merge_heatmaps = nn.ModuleList(
            [ConvBlock(self.num_heats * 9, in_channel, 1, relu=False, bn=False)
             for _ in range(self.nstack - 1)])

        self.heatmap_act = Activation("in+relu", self.num_heats)

        self.inference = False

    def set_inference(self, inference):
        self.inference = inference

    def forward(self, x):
        """Run the network.

        Returns a tuple of five tensors ``(heatmaps, x_heatmaps, y_heatmaps,
        nb_x_heatmaps, nb_y_heatmaps)``. Each is the per-stack output of the
        corresponding head (after ``heatmap_act``), concatenated along dim 1 in
        stack order.
        """
        x = self.pre(x)

        heads = (self.out_heatmaps, self.x_out_heatmaps, self.y_out_heatmaps,
                 self.nb_x_out_heatmaps, self.nb_y_out_heatmaps)
        # Per-head lists of stage outputs. They are concatenated once at the end,
        # so the results follow x's device/dtype.
        collected = tuple([] for _ in heads)
        heatmaps = None  # the first hourglass receives no heatmap

        for i in range(self.nstack):
            hg = self.hgs[i](x, heatmap=heatmaps)
            feature = self.features[i](hg)

            stage_maps = [self.heatmap_act(head[i](feature)) for head in heads]
            # The main heatmaps feed the next hourglass's CoordConv boundary channels.
            heatmaps = stage_maps[0]

            if i < self.nstack - 1:
                heatmaps_concat = torch.cat(stage_maps, dim=1)
                x = x + self.merge_features[i](feature) + self.merge_heatmaps[i](heatmaps_concat)

            for bucket, maps in zip(collected, stage_maps):
                bucket.append(maps)

        return tuple(torch.cat(bucket, dim=1) if bucket else x.new_empty(0)
                     for bucket in collected)