from __future__ import absolute_import

import torch
import torch.nn as nn

'''
This folder and its code are modified based on the SemGCN code,
https://github.com/garyzhao/SemGCN
and implement the "simple yet effective baseline" model (Martinez et al., 2017).
'''

def init_weights(m):
    """Re-initialise the weight of an ``nn.Linear`` module with Kaiming-normal.

    Intended for use with ``module.apply(init_weights)``. Any module that is
    not an ``nn.Linear`` is left untouched, and the bias of a linear layer is
    not modified.
    """
    if not isinstance(m, nn.Linear):
        return
    nn.init.kaiming_normal_(m.weight)


class Linear(nn.Module):
    """Residual MLP block: ``x + F(x)``.

    ``F`` is two identical stages of Linear -> BatchNorm1d -> ReLU -> Dropout,
    all of width ``linear_size``. The width never changes, which is what lets
    the input be added back to the output.

    Args:
        linear_size: feature width of the block (input and output).
        p_dropout: dropout probability applied after each stage.
    """

    def __init__(self, linear_size, p_dropout=0.5):
        super().__init__()
        size = linear_size
        self.l_size = size

        # Parameter-free activation / dropout shared by both stages.
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p_dropout)

        # Registration order (w1, batch_norm1, w2, batch_norm2) is kept stable
        # so state_dict keys and seeded initialisation stay reproducible.
        self.w1, self.batch_norm1 = nn.Linear(size, size), nn.BatchNorm1d(size)
        self.w2, self.batch_norm2 = nn.Linear(size, size), nn.BatchNorm1d(size)

    def forward(self, x):
        """Return ``x`` plus the output of the two-stage transform of ``x``."""
        h = x
        for fc, bn in ((self.w1, self.batch_norm1), (self.w2, self.batch_norm2)):
            h = self.dropout(self.relu(bn(fc(h))))
        return x + h


class LinearModel(nn.Module):
    """Multi-branch fully-connected network lifting 2D joints to 3D joints.

    A shared trunk (input projection + ``num_stage`` residual ``Linear``
    blocks) feeds ``num_branches`` independent regression heads. Each head
    predicts ``output_size`` pose values plus one extra scalar ``s``. That
    scalar is turned into a divisor ``relu(s) + 1`` (always >= 1), and the
    pose is divided by it.

    Args:
        input_size: flattened size of one 2D input sample (e.g. 16 * 2 = 32).
        output_size: flattened size of one predicted pose (e.g. 16 * 3 = 48).
        linear_size: hidden width of the trunk.
        num_stage: number of residual ``Linear`` blocks in the trunk.
        p_dropout: dropout probability used in the trunk.
        num_branches: number of regression heads (one output per head).
    """

    def __init__(self, input_size, output_size, linear_size=1024, num_stage=2, p_dropout=0.5, num_branches=5):
        super(LinearModel, self).__init__()

        self.linear_size = linear_size
        self.p_dropout = p_dropout
        self.num_stage = num_stage
        self.num_branches = num_branches
        # NOTE(review): not read anywhere in this class; kept for external code.
        self.noise = True

        self.input_size = input_size    # 2D joints, flattened (e.g. 16 * 2)
        self.output_size = output_size  # 3D joints, flattened (e.g. 16 * 3)

        # Trunk: project the input to the hidden width ...
        self.w1 = nn.Linear(self.input_size, self.linear_size)
        self.batch_norm1 = nn.BatchNorm1d(self.linear_size)

        # ... followed by a stack of residual blocks.
        self.linear_stages = nn.ModuleList(
            [Linear(self.linear_size, self.p_dropout) for _ in range(num_stage)]
        )

        # Heads: one regressor per branch; the "+ 1" output unit is the
        # scale logit consumed in forward().
        self.regression_stages = nn.ModuleList(
            [nn.Linear(self.linear_size, self.output_size + 1) for _ in range(self.num_branches)]
        )

        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(self.p_dropout)
        # NOTE(review): dropout_mh and sigmoid are not used by forward(); kept
        # (they hold no parameters) so code referencing them keeps working.
        self.dropout_mh = nn.Dropout(0.1)
        self.sigmoid = nn.Sigmoid()

        # Joint-index groups by body part. Not referenced in this class;
        # presumably used by training/evaluation code elsewhere — verify.
        self.arms = [10, 11, 12, 13, 14, 15]
        self.torsos = [7, 8, 9]
        self.legs = [1, 2, 3, 4, 5, 6]

    def forward(self, x):
        """Predict one scaled pose per branch from 2D joint input.

        Args:
            x: batch of 2D inputs whose per-sample element count equals
                ``input_size``, given flattened (``[B, input_size]``) or not
                (e.g. ``[B, 16, 2]``).

        Returns:
            Tensor of shape ``[num_branches, B, output_size]``. Each slice is
            one head's pose divided by that head's scale ``relu(s) + 1``.
        """
        # Flatten all non-batch dims. reshape (not view) also accepts
        # non-contiguous inputs. Using input_size instead of a hard-coded
        # 16 * 2 lets the model work with other joint counts.
        x = x.reshape(x.shape[0], self.input_size)

        y = self.dropout(self.relu(self.batch_norm1(self.w1(x))))

        for stage in self.linear_stages:
            y = stage(y)

        scaled_poses = []
        for head in self.regression_stages:
            y_hat = head(y)
            pose = y_hat[:, :self.output_size]
            # The last unit becomes a divisor >= 1, so the division is always safe.
            scale = self.relu(y_hat[:, self.output_size:]) + 1
            scaled_poses.append(pose / scale)

        return torch.stack(scaled_poses, dim=0)  # [num_branches, B, output_size]

