import sys
import torch
import torch.nn as nn 
import numpy as np
import resnet
from PoseNet_final import PoseNet_uncertainty

def get_and_init_FC_layer(din, dout):
    """Build a fully connected layer mapping ``din`` features to ``dout``.

    Despite the name, no custom weight initialization is applied here: the
    layer keeps whatever defaults ``nn.Linear`` gives it.

    Args:
        din: number of input features.
        dout: number of output features.

    Returns:
        A freshly constructed ``nn.Linear(din, dout)``.
    """
    return nn.Linear(din, dout)

class LMapNet_withPoseNet(nn.Module):
    """Recurrent mapping network fed by per-frame PoseNet estimates.

    For every timestep the PoseNet produces an agent pose, per-object local
    positions, per-object presence logits and per-object uncertainty
    parameters. The agent branch runs an LSTM over (agent pose, agent image
    features). The object branch moves each object's position and covariance
    into the global frame using the ground-truth agent pose (teacher forcing)
    and runs one shared LSTM per object, each with its own recurrent state.

    Inputs of ``forward``:
        obs: 5-D observation sequence indexed as (B, L, ...); presumably
            (B, L, 3, 80, 80) per the original inline note -- TODO confirm.
        objs_vector: passed through untouched to the PoseNet.
        agent_GT_pose: (B, L, 3) ground-truth agent pose as (x, y, theta).
        hidden/cell states: one (h, c) pair for the agent LSTM and one pair
            per object, as produced by ``init_hidden_states``.

    Outputs of ``forward``:
        Two Python lists of length L (not stacked tensors):
        agent poses, each (B, 3), and object positions, each (B, 3, 2).
    """

    # Number of tracked objects; fixed by the three explicit object-state
    # argument pairs in the ``forward`` signature.
    _NUM_OBJS = 3

    def __init__(self, hidden_size, device):
        """Create the LSTMs, the PoseNet and the regression heads.

        Args:
            hidden_size: hidden width shared by both LSTMs.
            device: device the LSTMs are moved to and on which
                ``init_hidden_states`` allocates its tensors.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.device = device

        # Agent LSTM input: 128 agent image features + 3 PoseNet pose values.
        self.rnn_agent = nn.LSTM(128+3, self.hidden_size, batch_first=True).to(device)
        # Object LSTM input: 2 (global xy) + 3 (global covariance terms)
        # + 1 (presence probability) + 3 (predicted agent pose) + 256 (object
        # image features).
        self.rnn_objs = nn.LSTM(2+3+1+3+256, self.hidden_size, batch_first=True).to(device)

        # NOTE(review): only the LSTMs are moved to ``device`` here; the other
        # sub-modules stay where they are constructed, so callers presumably
        # move the whole model afterwards -- confirm.
        self.PoseNet = PoseNet_uncertainty()
        self.sigmoid_net = nn.Sigmoid()

        # Head regressing the agent pose (3 values) from the agent LSTM output.
        self.MLP_agent_mean = nn.Sequential(
            get_and_init_FC_layer(self.hidden_size, 64),
            nn.ReLU(),
            get_and_init_FC_layer(64, 16),
            nn.ReLU(),
            get_and_init_FC_layer(16, 3),
        )

        # Head regressing an object position (2 values) from the object LSTM
        # output; shared by all objects.
        self.MLP_objs_mean = nn.Sequential(
            get_and_init_FC_layer(self.hidden_size, 64),
            nn.ReLU(),
            get_and_init_FC_layer(64, 16),
            nn.ReLU(),
            get_and_init_FC_layer(16, 2),
        )

    def init_hidden_states(self, bsize):
        """Return zeroed float32 ``(h, c)`` of shape (1, bsize, hidden_size)."""
        shape = (1, bsize, self.hidden_size)
        h = torch.zeros(shape, dtype=torch.float32, device=self.device)
        c = torch.zeros(shape, dtype=torch.float32, device=self.device)
        return h, c

    @staticmethod
    def _to_global_xy(local_xy, agent_pose_t):
        """Rotate a local (B, 2) position by the agent heading and translate it.

        ``agent_pose_t`` is the (B, 3) ground-truth (x, y, theta) of one step.
        """
        cos_t = torch.cos(agent_pose_t[:, 2:3])
        sin_t = torch.sin(agent_pose_t[:, 2:3])
        global_x = local_xy[:, 0:1] * cos_t - local_xy[:, 1:2] * sin_t + agent_pose_t[:, 0:1]
        global_y = local_xy[:, 1:2] * cos_t + local_xy[:, 0:1] * sin_t + agent_pose_t[:, 1:2]
        return torch.cat((global_x, global_y), dim=1)

    @staticmethod
    def _global_covariance(obj_var, heading):
        """Turn PoseNet uncertainty parameters into global-frame 2x2 matrices.

        Each object's three parameters (a, b, c) define the upper-triangular
        factor ``U = [[exp(a), b], [0, exp(c)]]``; the local covariance is
        ``inv(U^T U)``. It is then rotated as ``R^T C R`` with
        ``R = [[cos, sin], [-sin, cos]]`` built from ``heading`` (shape (B,)).

        Returns a (B, n_obj, 2, 2) tensor.
        """
        params = torch.stack([v for v in obj_var], dim=1)  # (B, n_obj, 3)

        # Allocate next to the PoseNet output (same device and dtype). A bare
        # torch.zeros(...) would live on the CPU and force a device->host copy
        # plus a CPU matrix inverse at every timestep.
        factor = params.new_zeros((params.shape[0], params.shape[1], 2, 2))
        factor[:, :, 0, 0] = torch.exp(params[:, :, 0])
        factor[:, :, 0, 1] = params[:, :, 1]
        factor[:, :, 1, 1] = torch.exp(params[:, :, 2])
        local_cov = torch.linalg.inv(torch.matmul(factor.transpose(-1, -2), factor))

        cos_t = torch.cos(heading)
        sin_t = torch.sin(heading)
        # (B, 1, 2, 2): the singleton object axis broadcasts inside matmul.
        rotation = torch.stack((cos_t, sin_t, -sin_t, cos_t), dim=1).view(-1, 2, 2).unsqueeze(1)
        return torch.matmul(torch.matmul(rotation.transpose(-1, -2), local_cov), rotation)

    def forward(self, obs, objs_vector, agent_GT_pose, hidden_state_agent, cell_state_agent,
                hidden_state_objs_1, cell_state_objs_1, hidden_state_objs_2, cell_state_objs_2,
                hidden_state_objs_3, cell_state_objs_3):
        """Run the network over the whole sequence.

        Returns:
            ``(predict_agent_pose, predicted_objs_pose_global)``: two lists of
            length ``obs.shape[1]`` holding, per step, the agent pose (B, 3)
            and the object positions (B, 3, 2). The updated recurrent states
            are not returned.
        """
        seq_len = obs.shape[1]
        agent_state = (hidden_state_agent, cell_state_agent)
        objs_states = [
            (hidden_state_objs_1, cell_state_objs_1),
            (hidden_state_objs_2, cell_state_objs_2),
            (hidden_state_objs_3, cell_state_objs_3),
        ]

        predict_agent_pose = []
        predicted_objs_pose_global = []

        for step in range(seq_len):
            frame = obs[:, step]
            # Per the original notes: agent_pose is (B, 3), obj_poses a list of
            # (B, 2), obj_binarys a list of (B, 1) -- shapes come from PoseNet.
            agent_pose, obj_poses, obj_binarys, obj_var = self.PoseNet(frame, objs_vector)
            features_agent = self.PoseNet.posenet.resnet18_agent(frame)
            features_objs = self.PoseNet.posenet.resnet18_objs(frame)

            # ---- agent branch ----
            agent_input = torch.cat((agent_pose, features_agent), dim=1).unsqueeze(1)
            output_agent, agent_state = self.rnn_agent(agent_input, agent_state)
            agent_pred = self.MLP_agent_mean(output_agent).squeeze(1)  # (B, 3)
            predict_agent_pose.append(agent_pred)

            # ---- object branch: move PoseNet estimates to the global frame ----
            gt_pose = agent_GT_pose[:, step]  # (B, 3)
            global_xy = [self._to_global_xy(local_xy, gt_pose) for local_xy in obj_poses]
            global_cov = self._global_covariance(obj_var, gt_pose[:, 2])
            # Keep the three distinct entries used as features: c00, c01, c11.
            cov_terms = torch.stack(
                (global_cov[:, :, 0, 0], global_cov[:, :, 0, 1], global_cov[:, :, 1, 1]),
                dim=-1,
            )  # (B, n_obj, 3)
            presence = self.sigmoid_net(torch.cat(obj_binarys, dim=1))  # (B, n_obj)

            step_predictions = []
            for k in range(self._NUM_OBJS):
                rnn_input = torch.cat(
                    (global_xy[k], cov_terms[:, k], presence[:, k:k + 1], agent_pred, features_objs),
                    dim=1,
                ).unsqueeze(1)  # (B, 1, 2+3+1+3+256)
                output_obj, objs_states[k] = self.rnn_objs(rnn_input, objs_states[k])
                step_predictions.append(self.MLP_objs_mean(output_obj))  # (B, 1, 2)

            predicted_objs_pose_global.append(torch.cat(step_predictions, dim=1))  # (B, 3, 2)

        return predict_agent_pose, predicted_objs_pose_global


class LMapNet_withPoseNet_uncertainty(nn.Module):
    """``LMapNet_withPoseNet`` extended with a per-object uncertainty head.

    The wrapped ``self.LMapNet`` supplies the PoseNet, both LSTMs, the sigmoid
    and the mean heads; this class adds ``MLP_objs_var``, which maps each
    object LSTM output to three extra values. The ground-truth agent pose is
    used to move PoseNet object estimates into the global frame (teacher
    forcing).

    Inputs of ``forward``:
        obs: 5-D observation sequence indexed as (B, L, ...); presumably
            (B, L, 3, 80, 80) per the original inline note -- TODO confirm.
        objs_vector: passed through untouched to the PoseNet.
        agent_GT_pose: (B, L, 3) ground-truth agent pose as (x, y, theta).
        hidden/cell states: one (h, c) pair for the agent LSTM and one pair
            per object, as produced by ``init_hidden_states``.
        generate: when true, the raw PoseNet-derived quantities are returned
            as well.

    Outputs of ``forward``:
        Python lists of length L (not stacked tensors); see ``forward``.
    """

    # Number of tracked objects; fixed by the three explicit object-state
    # argument pairs in the ``forward`` signature.
    _NUM_OBJS = 3

    def __init__(self, hidden_size, device):
        """Create the wrapped network and the uncertainty head.

        Args:
            hidden_size: hidden width of the LSTMs in the wrapped network.
            device: forwarded to the wrapped network and used by
                ``init_hidden_states`` to allocate its tensors.
        """
        super().__init__()

        self.LMapNet = LMapNet_withPoseNet(hidden_size=hidden_size, device=device)

        self.hidden_size = hidden_size
        self.device = device

        # Head mapping an object LSTM output to 3 uncertainty values; shared
        # by all objects.
        self.MLP_objs_var = nn.Sequential(
            get_and_init_FC_layer(self.hidden_size, 64),
            nn.ReLU(),
            get_and_init_FC_layer(64, 16),
            nn.ReLU(),
            get_and_init_FC_layer(16, 3),
        )

    def init_hidden_states(self, bsize):
        """Return zeroed float32 ``(h, c)`` of shape (1, bsize, hidden_size)."""
        shape = (1, bsize, self.hidden_size)
        h = torch.zeros(shape, dtype=torch.float32, device=self.device)
        c = torch.zeros(shape, dtype=torch.float32, device=self.device)
        return h, c

    @staticmethod
    def _to_global_xy(local_xy, agent_pose_t):
        """Rotate a local (B, 2) position by the agent heading and translate it.

        ``agent_pose_t`` is the (B, 3) ground-truth (x, y, theta) of one step.
        """
        cos_t = torch.cos(agent_pose_t[:, 2:3])
        sin_t = torch.sin(agent_pose_t[:, 2:3])
        global_x = local_xy[:, 0:1] * cos_t - local_xy[:, 1:2] * sin_t + agent_pose_t[:, 0:1]
        global_y = local_xy[:, 1:2] * cos_t + local_xy[:, 0:1] * sin_t + agent_pose_t[:, 1:2]
        return torch.cat((global_x, global_y), dim=1)

    @staticmethod
    def _global_covariance(obj_var, heading):
        """Turn PoseNet uncertainty parameters into global-frame 2x2 matrices.

        Each object's three parameters (a, b, c) define the upper-triangular
        factor ``U = [[exp(a), b], [0, exp(c)]]``; the local covariance is
        ``inv(U^T U)``. It is then rotated as ``R^T C R`` with
        ``R = [[cos, sin], [-sin, cos]]`` built from ``heading`` (shape (B,)).

        Returns a (B, n_obj, 2, 2) tensor.
        """
        params = torch.stack([v for v in obj_var], dim=1)  # (B, n_obj, 3)

        # Allocate next to the PoseNet output (same device and dtype). A bare
        # torch.zeros(...) would live on the CPU and force a device->host copy
        # plus a CPU matrix inverse at every timestep.
        factor = params.new_zeros((params.shape[0], params.shape[1], 2, 2))
        factor[:, :, 0, 0] = torch.exp(params[:, :, 0])
        factor[:, :, 0, 1] = params[:, :, 1]
        factor[:, :, 1, 1] = torch.exp(params[:, :, 2])
        local_cov = torch.linalg.inv(torch.matmul(factor.transpose(-1, -2), factor))

        cos_t = torch.cos(heading)
        sin_t = torch.sin(heading)
        # (B, 1, 2, 2): the singleton object axis broadcasts inside matmul.
        rotation = torch.stack((cos_t, sin_t, -sin_t, cos_t), dim=1).view(-1, 2, 2).unsqueeze(1)
        return torch.matmul(torch.matmul(rotation.transpose(-1, -2), local_cov), rotation)

    def forward(self, obs, objs_vector, agent_GT_pose, hidden_state_agent, cell_state_agent, hidden_state_objs_1, cell_state_objs_1,
                hidden_state_objs_2, cell_state_objs_2, hidden_state_objs_3, cell_state_objs_3, generate = False):
        """Run the network over the whole sequence.

        Returns:
            Lists of length ``obs.shape[1]`` holding, per step:
            the agent pose (B, 3), the object positions (B, 3, 2) and the
            object uncertainty values (B, 3, 3). With ``generate`` set, three
            more lists follow: the raw PoseNet agent pose, the PoseNet object
            positions in the global frame concatenated to (B, 6), and the
            global-frame covariance matrices (B, 3, 2, 2). The updated
            recurrent states are not returned.
        """
        seq_len = obs.shape[1]
        base = self.LMapNet
        agent_state = (hidden_state_agent, cell_state_agent)
        objs_states = [
            (hidden_state_objs_1, cell_state_objs_1),
            (hidden_state_objs_2, cell_state_objs_2),
            (hidden_state_objs_3, cell_state_objs_3),
        ]

        predict_agent_pose = []
        predicted_objs_pose_global = []
        predicted_objs_var_global = []

        predict_agent_pose_posnet = []
        predict_posnet_pose = []
        predict_posnet_var = []

        for step in range(seq_len):
            frame = obs[:, step]
            # Per the original notes: agent_pose is (B, 3), obj_poses a list of
            # (B, 2), obj_binarys a list of (B, 1) -- shapes come from PoseNet.
            agent_pose, obj_poses, obj_binarys, obj_var = base.PoseNet(frame, objs_vector)
            features_agent = base.PoseNet.posenet.resnet18_agent(frame)
            features_objs = base.PoseNet.posenet.resnet18_objs(frame)

            # ---- agent branch ----
            agent_input = torch.cat((agent_pose, features_agent), dim=1).unsqueeze(1)
            output_agent, agent_state = base.rnn_agent(agent_input, agent_state)
            agent_pred = base.MLP_agent_mean(output_agent).squeeze(1)  # (B, 3)
            predict_agent_pose.append(agent_pred)

            # ---- object branch: move PoseNet estimates to the global frame ----
            gt_pose = agent_GT_pose[:, step]  # (B, 3)
            global_xy = [self._to_global_xy(local_xy, gt_pose) for local_xy in obj_poses]
            global_cov = self._global_covariance(obj_var, gt_pose[:, 2])
            # Keep the three distinct entries used as features: c00, c01, c11.
            cov_terms = torch.stack(
                (global_cov[:, :, 0, 0], global_cov[:, :, 0, 1], global_cov[:, :, 1, 1]),
                dim=-1,
            )  # (B, n_obj, 3)
            presence = base.sigmoid_net(torch.cat(obj_binarys, dim=1))  # (B, n_obj)

            step_means = []
            step_vars = []
            for k in range(self._NUM_OBJS):
                rnn_input = torch.cat(
                    (global_xy[k], cov_terms[:, k], presence[:, k:k + 1], agent_pred, features_objs),
                    dim=1,
                ).unsqueeze(1)  # (B, 1, 2+3+1+3+256)
                output_obj, objs_states[k] = base.rnn_objs(rnn_input, objs_states[k])
                step_means.append(base.MLP_objs_mean(output_obj))  # (B, 1, 2)
                step_vars.append(self.MLP_objs_var(output_obj))  # (B, 1, 3)

            predicted_objs_pose_global.append(torch.cat(step_means, dim=1))  # (B, 3, 2)
            predicted_objs_var_global.append(torch.cat(step_vars, dim=1))  # (B, 3, 3)

            # Raw PoseNet-side quantities, only handed back when ``generate``.
            predict_agent_pose_posnet.append(agent_pose)
            predict_posnet_pose.append(torch.cat(global_xy, dim=1))  # (B, 6)
            predict_posnet_var.append(global_cov)

        if generate:
            return (predict_agent_pose, predicted_objs_pose_global, predicted_objs_var_global,
                    predict_agent_pose_posnet, predict_posnet_pose, predict_posnet_var)

        return predict_agent_pose, predicted_objs_pose_global, predicted_objs_var_global

    

