import torch
import torch.nn as nn

class ResnetBlockConv1d(nn.Module):
    """Pre-activation 1D-convolutional ResNet block with additive conditioning.

    Every convolution uses kernel size 1. The block computes
    ``shortcut(x) + fc_1(relu(bn_1(fc_0(relu(bn_0(x)))))) + fc_c(c)``,
    where ``shortcut`` is the identity when ``size_in == size_out`` and a
    bias-free 1x1 convolution otherwise.

    Args:
        c_dim (int): number of channels of the conditioning tensor ``c``.
        size_in (int): number of input channels.
        size_h (int, optional): number of hidden channels. Defaults to
            ``size_in`` when ``None``.
        size_out (int, optional): number of output channels. Defaults to
            ``size_in`` when ``None``.
        norm_method (str): ``'batch_norm'`` (``nn.BatchNorm1d``) or
            ``'sync_batch_norm'`` (``nn.SyncBatchNorm``).
        legacy (bool): accepted for signature compatibility; it is not read
            anywhere in this class.

    Raises:
        ValueError: if ``norm_method`` is not one of the supported names.
    """

    def __init__(self, c_dim, size_in, size_h=None, size_out=None,
                 norm_method='batch_norm', legacy=False):
        super().__init__()
        # Hidden and output widths fall back to the input width.
        if size_h is None:
            size_h = size_in
        if size_out is None:
            size_out = size_in

        self.size_in = size_in
        self.size_h = size_h
        self.size_out = size_out

        # Resolve the normalisation layer class from its configuration name.
        if norm_method == 'batch_norm':
            norm = nn.BatchNorm1d
        elif norm_method == 'sync_batch_norm':
            norm = nn.SyncBatchNorm
        else:
            # ValueError is a subclass of Exception, so existing broad
            # handlers keep working while callers can now be specific.
            raise ValueError("Invalid norm method: %s" % norm_method)

        # NOTE: the creation order below is kept stable on purpose; it fixes
        # both the state_dict layout and the order of random initialisation.
        self.bn_0 = norm(size_in)
        self.bn_1 = norm(size_h)

        self.fc_0 = nn.Conv1d(size_in, size_h, 1)
        self.fc_1 = nn.Conv1d(size_h, size_out, 1)
        self.fc_c = nn.Conv1d(c_dim, size_out, 1)
        self.actvn = nn.ReLU()

        # A projection is only needed when the channel count changes.
        if size_in == size_out:
            self.shortcut = None
        else:
            self.shortcut = nn.Conv1d(size_in, size_out, 1, bias=False)

        # Zero the last conv's weight so the residual branch initially
        # contributes only its bias.
        nn.init.zeros_(self.fc_1.weight)

    def forward(self, x, c):
        """Apply the residual block.

        Args:
            x: feature tensor with ``size_in`` channels on dim 1,
                e.g. ``(B, size_in, N)``.
            c: conditioning tensor with ``c_dim`` channels on dim 1,
                e.g. ``(B, c_dim, N)``.

        Returns:
            Tensor with ``size_out`` channels on dim 1.
        """
        net = self.fc_0(self.actvn(self.bn_0(x)))
        dx = self.fc_1(self.actvn(self.bn_1(net)))

        x_s = x if self.shortcut is None else self.shortcut(x)

        return x_s + dx + self.fc_c(c)


class ShapeGFDecoder(nn.Module):
    """Point-wise gradient decoder conditioned by adding.

    Each point is concatenated with the (broadcast) latent code and the
    result is used both as the network input and as the conditioning signal
    added inside every residual block.

    Keys read from ``model_cfg`` (example values):
        z_dim: 128        # latent code size
        dim: 3            # per-point coordinate size
        hidden_size: 256
        n_blocks: 5
        out_dim: 3        # we are outputting the gradient

    Args:
        model_cfg: mapping providing the keys listed above.
        runtime_cfg: accepted for interface compatibility; it is not read
            anywhere in this class.
    """
    def __init__(self, model_cfg, runtime_cfg):
        super().__init__()
        self.model_cfg = model_cfg
        self.z_dim = z_dim = model_cfg["z_dim"]
        self.dim = dim = model_cfg["dim"]
        self.out_dim = out_dim = model_cfg["out_dim"]
        self.hidden_size = hidden_size = model_cfg["hidden_size"]
        self.n_blocks = n_blocks = model_cfg["n_blocks"]

        # Input = Conditional = z_dim (shape code) + dim (point coordinates).
        # No sigma channel is appended in this variant.
        c_dim = z_dim + dim
        self.conv_p = nn.Conv1d(c_dim, hidden_size, 1)
        self.blocks = nn.ModuleList([
            ResnetBlockConv1d(c_dim, hidden_size) for _ in range(n_blocks)
        ])
        self.bn_out = nn.BatchNorm1d(hidden_size)
        self.conv_out = nn.Conv1d(hidden_size, out_dim, 1)
        self.actvn_out = nn.ReLU()
        # Tensors consumed by get_loss(). forward() does NOT fill this dict;
        # the caller has to populate it before calling get_loss().
        self.forward_dict = {}

    def get_loss(self, tb_dict=None):
        """Compute the sigma-weighted squared error on the predicted gradient.

        Reads these entries from ``self.forward_dict`` (which the caller must
        have populated):
            'perturbed_point_clouds', 'framed_point_clouds',
            'used_sigmas', 'pred_point_grad'.

        The target is ``-(perturbed - framed)`` reshaped to
        ``(used_sigmas.shape[0], -1, 3)``; the squared error is scaled by
        ``1 / used_sigmas``, summed over the last dim and averaged.

        Args:
            tb_dict (dict, optional): dict that receives the ``'loss'``
                entry. A new dict is created when ``None``.

        Returns:
            tuple: ``(loss, tb_dict)``.

        Raises:
            KeyError: if any required entry is absent from ``forward_dict``.
        """
        if tb_dict is None:
            tb_dict = {}

        required = ('perturbed_point_clouds', 'framed_point_clouds',
                    'used_sigmas', 'pred_point_grad')
        missing = [key for key in required if key not in self.forward_dict]
        if missing:
            raise KeyError(
                "forward_dict is missing %s; populate it before calling "
                "get_loss()" % missing
            )

        perturbed_point_clouds = self.forward_dict['perturbed_point_clouds']
        framed_point_clouds = self.forward_dict['framed_point_clouds']
        sigmas = self.forward_dict['used_sigmas']

        pred_grad = self.forward_dict['pred_point_grad']
        gt_grad = -(perturbed_point_clouds - framed_point_clouds).reshape(sigmas.shape[0], -1, 3)

        # NOTE(review): assumes used_sigmas is (B, 1) so the weight becomes
        # (B, 1, 1) and broadcasts over points and coordinates - TODO confirm.
        lambda_sigmas = 1.0 / sigmas.unsqueeze(-1)
        loss = 0.5 * ((pred_grad - gt_grad).square() * lambda_sigmas).sum(dim=2).mean()

        tb_dict['loss'] = loss
        return loss, tb_dict

    def forward(self, latent_z, points_input):
        """Predict a per-point gradient.

        Args:
            latent_z [B, z_dim]: latent code for each point cloud.
            points_input [B, N, dim]: input point clouds.
        Returns:
            point_grad [B, N, out_dim]: point-wise gradient.

        This method does not write to ``self.forward_dict``.
        """
        # Conv1d works channel-first: [B, N, dim] -> [B, dim, N].
        p = points_input.transpose(1, 2)
        c = latent_z

        # Unpacking also enforces that the points tensor is 3-dimensional.
        _, _, num_points = p.size()

        # Repeat the latent code for every point: [B, z_dim] -> [B, z_dim, N].
        c_expand = c.unsqueeze(2).expand(-1, -1, num_points)
        c_xyz = torch.cat([p, c_expand], dim=1)

        net = self.conv_p(c_xyz)

        for block in self.blocks:
            net = block(net, c_xyz)

        point_grad = self.conv_out(self.actvn_out(self.bn_out(net))).transpose(1, 2)
        return point_grad


