from torch.nn import Module, Conv2d, ConvTranspose2d
import torch.nn.functional as F
import torchvision.models as models
import torch
import resnet
import torch.nn as nn

def get_and_init_FC_layer(din, dout):
    """Build an ``nn.Linear(din, dout)`` initialised for use before a ReLU.

    The weight matrix is re-drawn with Xavier-uniform initialisation scaled
    by the ReLU gain, and the bias vector is set to zero.

    Args:
        din: Number of input features.
        dout: Number of output features.

    Returns:
        The freshly initialised ``nn.Linear`` layer.
    """
    layer = nn.Linear(din, dout)
    relu_gain = nn.init.calculate_gain('relu')
    nn.init.xavier_uniform_(layer.weight, gain=relu_gain)
    nn.init.zeros_(layer.bias)
    return layer

class PoseNet(nn.Module):
    """Predict an agent pose and per-object scores from an image batch.

    Two ResNet-18 backbones (from the project-local ``resnet`` module, built
    for 3-channel input) process the same observation. Their ``fc`` layers
    are replaced so that the agent branch ends in a 512->128 linear layer and
    the object branch in a 512->256 linear layer.

    * The agent branch output goes through a 128->64->16->3 MLP head.
    * For every object, its 10-value descriptor is embedded by three separate
      10->64->128->256 MLPs (``b``, ``x``, ``y``). Each embedding is multiplied
      element-wise with the object-branch features and summed over the last
      dimension, giving one scalar per sample.

    NOTE(review): an earlier docstring gave the input size as (B, 3, 80, 80);
    nothing in this class enforces the spatial size -- confirm with callers.
    """

    def __init__(self, num_objects=3):
        """Create the backbones, the agent head and the object embedders.

        Args:
            num_objects: How many entries along dim 1 of ``objs_vector`` are
                processed by ``forward``.
        """
        super().__init__()
        self.num_objects = num_objects
        self.resnet18_agent = resnet.resnet18(num_input_channels=3)
        self.resnet18_agent.fc = nn.Linear(in_features=512, out_features=128, bias=True)

        self.resnet18_objs = resnet.resnet18(num_input_channels=3)
        self.resnet18_objs.fc = nn.Linear(in_features=512, out_features=256, bias=True)

        self.agent_head = nn.Sequential(
            get_and_init_FC_layer(128, 64),
            nn.ReLU(),
            get_and_init_FC_layer(64, 16),
            nn.ReLU(),
            get_and_init_FC_layer(16, 3),
        )

        # Creation order (b, x, y) is kept so parameter ordering and the
        # random-initialisation sequence stay the same as before.
        self.obj_vector_embed_b = self._make_obj_embed()
        self.obj_vector_embed_x = self._make_obj_embed()
        self.obj_vector_embed_y = self._make_obj_embed()

    @staticmethod
    def _make_obj_embed(din=10, dout=256):
        """Return a din->64->128->dout MLP embedding one object descriptor."""
        return nn.Sequential(
            get_and_init_FC_layer(din, 64),
            nn.ReLU(),
            get_and_init_FC_layer(64, 128),
            nn.ReLU(),
            get_and_init_FC_layer(128, dout),
        )

    @staticmethod
    def _score(features, embed, vector):
        """Return sum(features * embed(vector)) over the last dim, kept as size 1."""
        return (features * embed(vector)).sum(-1, keepdim=True)

    def forward(self, obs, objs_vector):
        """Run both backbones and score every object descriptor.

        Args:
            obs: Image batch fed to both ResNet-18 backbones.
            objs_vector: Tensor indexed as ``objs_vector[:, i, :]`` for
                ``i < num_objects``; the last dimension must have size 10 to
                match the embedding MLPs.

        Returns:
            A tuple ``(agent_pose, obj_poses, obj_binarys)``:

            * ``agent_pose``: output of the agent head (last dim of size 3).
            * ``obj_poses``: list of ``num_objects`` tensors, each the ``x``
              and ``y`` scores concatenated on the last dim (size 2).
            * ``obj_binarys``: list of ``num_objects`` tensors holding the
              ``b`` score with a trailing dim of size 1 (presumably a
              presence logit -- confirm against the loss used by callers).
        """
        features_agent = self.resnet18_agent(obs)
        features_objs = self.resnet18_objs(obs)
        agent_pose = self.agent_head(features_agent)

        obj_poses = []
        obj_binarys = []
        for i in range(self.num_objects):
            # Slice the i-th descriptor once and reuse it for all three scores.
            obj_vec = objs_vector[:, i, :]
            obj_binary = self._score(features_objs, self.obj_vector_embed_b, obj_vec)
            obj_pose_x = self._score(features_objs, self.obj_vector_embed_x, obj_vec)
            obj_pose_y = self._score(features_objs, self.obj_vector_embed_y, obj_vec)

            obj_poses.append(torch.cat([obj_pose_x, obj_pose_y], dim=-1))
            obj_binarys.append(obj_binary)
        return agent_pose, obj_poses, obj_binarys
    
class PoseNet_uncertainty(nn.Module):
    """PoseNet variant that also emits three "sigma" scores per object.

    Wraps a ``PoseNet`` (stored as ``self.posenet``) and reuses its backbones,
    agent head and ``b``/``x``/``y`` object embedders. On top of those it owns
    three further 10->64->128->256 MLPs (``sigma1``..``sigma3``); each embeds
    an object descriptor, is multiplied element-wise with the object-branch
    features and summed over the last dimension.

    NOTE(review): how the three sigma values are interpreted (e.g. which
    variance/covariance parametrisation) is not visible in this class --
    confirm against the loss that consumes ``obj_var``.
    """

    def __init__(self, num_objects=3):
        """Create the wrapped ``PoseNet`` and the three sigma embedders.

        Args:
            num_objects: How many entries along dim 1 of ``objs_vector`` are
                processed by ``forward``; also passed to ``PoseNet``.
        """
        super().__init__()
        self.num_objects = num_objects
        self.posenet = PoseNet(num_objects=num_objects)

        # Creation order (sigma1, sigma2, sigma3) is kept so parameter
        # ordering and the random-initialisation sequence stay unchanged.
        self.obj_vector_embed_sigma1 = self._make_sigma_embed()
        self.obj_vector_embed_sigma2 = self._make_sigma_embed()
        self.obj_vector_embed_sigma3 = self._make_sigma_embed()

    @staticmethod
    def _make_sigma_embed(din=10, dout=256):
        """Return a din->64->128->dout MLP embedding one object descriptor."""
        return nn.Sequential(
            get_and_init_FC_layer(din, 64),
            nn.ReLU(),
            get_and_init_FC_layer(64, 128),
            nn.ReLU(),
            get_and_init_FC_layer(128, dout),
        )

    @staticmethod
    def _score(features, embed, vector):
        """Return sum(features * embed(vector)) over the last dim, kept as size 1."""
        return (features * embed(vector)).sum(-1, keepdim=True)

    def forward(self, obs, objs_vector):
        """Run the wrapped backbones and score every object descriptor.

        Args:
            obs: Image batch fed to both ResNet-18 backbones of the wrapped
                ``PoseNet``.
            objs_vector: Tensor indexed as ``objs_vector[:, i, :]`` for
                ``i < num_objects``; the last dimension must have size 10 to
                match the embedding MLPs.

        Returns:
            A tuple ``(agent_pose, obj_poses, obj_binarys, obj_var)``:

            * ``agent_pose``: output of the agent head (last dim of size 3).
            * ``obj_poses``: list of ``num_objects`` tensors, each the ``x``
              and ``y`` scores concatenated on the last dim (size 2).
            * ``obj_binarys``: list of ``num_objects`` tensors holding the
              ``b`` score with a trailing dim of size 1.
            * ``obj_var``: list of ``num_objects`` tensors, each the
              ``sigma1``, ``sigma2`` and ``sigma3`` scores concatenated on the
              last dim (size 3).
        """
        posenet = self.posenet
        features_agent = posenet.resnet18_agent(obs)
        features_objs = posenet.resnet18_objs(obs)
        agent_pose = posenet.agent_head(features_agent)

        obj_poses = []
        obj_binarys = []
        obj_var = []
        for i in range(self.num_objects):
            # Slice the i-th descriptor once and reuse it for all six scores.
            obj_vec = objs_vector[:, i, :]

            obj_binary = self._score(features_objs, posenet.obj_vector_embed_b, obj_vec)
            obj_pose_x = self._score(features_objs, posenet.obj_vector_embed_x, obj_vec)
            obj_pose_y = self._score(features_objs, posenet.obj_vector_embed_y, obj_vec)

            obj_sigma1 = self._score(features_objs, self.obj_vector_embed_sigma1, obj_vec)
            obj_sigma2 = self._score(features_objs, self.obj_vector_embed_sigma2, obj_vec)
            obj_sigma3 = self._score(features_objs, self.obj_vector_embed_sigma3, obj_vec)

            obj_binarys.append(obj_binary)
            obj_poses.append(torch.cat([obj_pose_x, obj_pose_y], dim=-1))
            obj_var.append(torch.cat([obj_sigma1, obj_sigma2, obj_sigma3], dim=-1))

        return agent_pose, obj_poses, obj_binarys, obj_var