import torch
import numpy as np
from collections import deque
from torch.nn.functional import smooth_l1_loss
from torch.nn import Module, Linear, BatchNorm1d, Sequential, ReLU
from torch.distributions import Categorical
import random
import sys
sys.path.append('../LNet/')
from LMapNet_final import *
ROOM_SIZE_METER = 12
import math

# todo ####################################################################
class LnetsplitRNN(Module):
    """Run a pretrained ``LMapNet_withPoseNet_uncertainty`` one step at a time as a state estimator.

    Each ``forward`` call feeds one observation per environment through the
    PoseNet, advances the agent LSTM and the three per-object LSTMs (which
    share ``rnn_objs``) by one step, and returns a flat, normalised NumPy
    state for the policy together with the updated recurrent states.
    """

    def __init__(self, device, hidden_size):
        super(LnetsplitRNN, self).__init__()
        self.device = device
        self.Lnet = LMapNet_withPoseNet_uncertainty(hidden_size=hidden_size, device=device)

    @staticmethod
    def _params_to_covariance(params):
        """Map ``(..., 3)`` parameters ``(log_a, b, log_c)`` to ``(..., 2, 2)`` matrices.

        The parameters define the upper-triangular factor
        ``U = [[exp(log_a), b], [0, exp(log_c)]]`` and the result is
        ``inv(U^T @ U)``. Every intermediate is derived from ``params``, so the
        result has the same device and dtype as the input.
        """
        diag_a = torch.exp(params[..., 0])
        off_diag = params[..., 1]
        diag_c = torch.exp(params[..., 2])
        zero = torch.zeros_like(off_diag)
        upper = torch.stack((torch.stack((diag_a, off_diag), dim=-1),
                             torch.stack((zero, diag_c), dim=-1)), dim=-2)
        return torch.linalg.inv(upper.transpose(-1, -2) @ upper)

    @staticmethod
    def _rotate_covariance(cov, theta):
        """Return ``R^T @ cov @ R`` with ``R = [[cos t, sin t], [-sin t, cos t]]`` per batch row.

        ``cov`` is ``(B, K, 2, 2)`` and ``theta`` is ``(B,)``; ``R`` broadcasts
        over the K dimension. ``R^T`` is the same rotation that ``forward``
        applies to object positions, so the covariance ends up in the frame of
        the rotated positions.
        """
        cos_t, sin_t = torch.cos(theta), torch.sin(theta)
        rot = torch.stack((cos_t, sin_t, -sin_t, cos_t), dim=-1).view(-1, 1, 2, 2)
        return rot.transpose(-1, -2) @ cov @ rot

    def forward(self, obs, objs_vector,
                hidden_state_agent, cell_state_agent, hidden_state_objs_1, cell_state_objs_1,
                hidden_state_objs_2, cell_state_objs_2, hidden_state_objs_3, cell_state_objs_3):
        """Advance the estimator by one step and build the policy state.

        Args:
            obs: image tensor whose first two dims are (batch B, sequence L)
                and whose remaining elements reshape to ``(3, 80, 80)`` per item.
            objs_vector: ``(B, 3, 10)`` object descriptors; repeated for every
                time step.
            hidden_state_* / cell_state_*: LSTM states for the agent and the
                three objects.

        Returns:
            ``(state, hidden)``, where:

            - ``state`` is a NumPy array of shape ``(19, B)``.
            - For each of the 3 objects it holds the normalised predicted x, y
              and the entries ``[0,0]``, ``[0,1]``, ``[1,1]`` of the predicted
              covariance.
            - It then holds the normalised agent x, y, followed by
              cos(theta) and sin(theta).
            - ``hidden`` is the updated 8-tuple of states, in the same order as
              the arguments.

        Positions are mapped with ``p / (ROOM_SIZE_METER / 2) - 1``. This is
        presumably so coordinates in ``[0, ROOM_SIZE_METER]`` land in
        ``[-1, 1]`` (TODO confirm the coordinate range).

        NOTE(review): the LSTMs are stepped with a length-1 sequence of B*L
        rows, so this appears to assume L == 1 — confirm with callers.
        """
        lmap = self.Lnet.LMapNet
        seq_len = obs.shape[1]
        B_size = obs.shape[0]
        input_obs = obs.view(-1, 3, 80, 80)  # (B*L, 3, 80, 80)

        # Repeat each env's object descriptors for every time step.
        objs_vector = objs_vector.unsqueeze(1).expand(-1, seq_len, -1, -1).reshape(-1, 3, 10)

        agent_pose, obj_poses, obj_binarys, obj_var = lmap.PoseNet(input_obs, objs_vector)
        features_agent = lmap.PoseNet.posenet.resnet18_agent(input_obs)
        features_objs = lmap.PoseNet.posenet.resnet18_objs(input_obs)

        # ---- agent LSTM: one step on [PoseNet pose, image features] ----
        agent_input = torch.cat((agent_pose, features_agent), dim=1)
        output_agent, (hidden_state_agent, cell_state_agent) = lmap.rnn_agent(
            agent_input.unsqueeze(1), (hidden_state_agent, cell_state_agent))
        predict_agent_pose_RNN = lmap.MLP_agent_mean(output_agent).squeeze(1)  # columns: x, y, theta

        # ---- object observations: local -> global using the RNN agent pose ----
        cos_t = torch.cos(predict_agent_pose_RNN[:, 2:3])
        sin_t = torch.sin(predict_agent_pose_RNN[:, 2:3])
        obj_global_poses = []
        for local in obj_poses:
            global_x = local[:, 0:1] * cos_t - local[:, 1:2] * sin_t + predict_agent_pose_RNN[:, 0:1]
            global_y = local[:, 1:2] * cos_t + local[:, 0:1] * sin_t + predict_agent_pose_RNN[:, 1:2]
            obj_global_poses.append(torch.cat((global_x, global_y), dim=1))

        # Observation covariance per object, built on obj_var's own device
        # (no CPU round trip), then rotated like the positions above.
        obj_cov = self._params_to_covariance(torch.stack(list(obj_var), dim=1))  # (N, 3, 2, 2)
        obj_cov = self._rotate_covariance(obj_cov, predict_agent_pose_RNN[:, 2])
        obj_cov_feats = torch.stack(
            (obj_cov[..., 0, 0], obj_cov[..., 0, 1], obj_cov[..., 1, 1]), dim=-1)  # (N, 3, 3)

        obj_binarys_input = lmap.sigmoid_net(torch.cat(obj_binarys, dim=1))  # one column per object

        # ---- object LSTMs: shared weights, separate recurrent states ----
        objs_states = ((hidden_state_objs_1, cell_state_objs_1),
                       (hidden_state_objs_2, cell_state_objs_2),
                       (hidden_state_objs_3, cell_state_objs_3))
        pose_preds, var_preds, new_objs_states = [], [], []
        for k, state_k in enumerate(objs_states):
            # Input: global xy (2) + covariance entries (3) + binary (1)
            # + agent pose (3) + image features.
            input_k = torch.cat((obj_global_poses[k], obj_cov_feats[:, k], obj_binarys_input[:, k:k + 1],
                                 predict_agent_pose_RNN, features_objs), dim=1).unsqueeze(1)
            output_k, state_k = lmap.rnn_objs(input_k, state_k)
            pose_preds.append(lmap.MLP_objs_mean(output_k))
            # NOTE(review): the variance head is looked up on self.Lnet while
            # the mean head is on self.Lnet.LMapNet — confirm this is intended.
            var_preds.append(self.Lnet.MLP_objs_var(output_k))
            new_objs_states.append(state_k)
        (hidden_state_objs_1, cell_state_objs_1), (hidden_state_objs_2, cell_state_objs_2), \
            (hidden_state_objs_3, cell_state_objs_3) = new_objs_states

        object_pose_predict_all = torch.cat(pose_preds, dim=1)  # next obj pos mean (B, 3, 2)
        object_var_predict_all = torch.cat(var_preds, dim=1)  # next obj pos variance params (B, 3, 3)

        # ---- detach to NumPy and build the policy state ----
        object_globalpose_mean_predict_all = object_pose_predict_all.detach().cpu().numpy()
        var_params = object_var_predict_all.detach().cpu().numpy()

        # Same inv(U^T U) construction as _params_to_covariance, done in float64 NumPy.
        factor = np.zeros((B_size, 3, 2, 2))
        factor[:, :, 0, 0] = np.exp(var_params[:, :, 0])
        factor[:, :, 0, 1] = var_params[:, :, 1]
        factor[:, :, 1, 1] = np.exp(var_params[:, :, 2])
        covar_matrix_final = np.linalg.inv(np.matmul(np.transpose(factor, (0, 1, 3, 2)), factor))

        self.predicted_agent_pose = predict_agent_pose_RNN.detach().cpu().numpy()  # (B, 3)
        half_room = ROOM_SIZE_METER / 2

        state = []
        for i in range(3):
            state.extend([object_globalpose_mean_predict_all[:, i, 0] / half_room - 1,
                          object_globalpose_mean_predict_all[:, i, 1] / half_room - 1,
                          covar_matrix_final[:, i, 0, 0], covar_matrix_final[:, i, 0, 1],
                          covar_matrix_final[:, i, 1, 1]])
        state.extend([self.predicted_agent_pose[:, 0] / half_room - 1,
                      self.predicted_agent_pose[:, 1] / half_room - 1,
                      np.cos(self.predicted_agent_pose[:, 2]),
                      np.sin(self.predicted_agent_pose[:, 2])])
        state = np.array(state)

        return state, (hidden_state_agent, cell_state_agent, hidden_state_objs_1, cell_state_objs_1,
                       hidden_state_objs_2, cell_state_objs_2, hidden_state_objs_3, cell_state_objs_3)

class PNet_PPO_GTPos(Module):
    """Actor network: an MLP from a flat state vector to a softmax over 4 actions.

    Layer widths are ``input_size -> 256 -> 128 -> 64 -> 32 -> 4`` with a ReLU
    after every hidden layer. The stack is stored as ``self.model`` (a
    ``Sequential``), so checkpoint keys are ``model.<index>.weight/bias``.
    """

    def __init__(self, device, input_size):
        super(PNet_PPO_GTPos, self).__init__()
        self.input_size = input_size
        self.device = device
        layers = []
        fan_in = self.input_size
        for fan_out in (256, 128, 64, 32):
            layers.extend((Linear(fan_in, fan_out), ReLU()))
            fan_in = fan_out
        layers.append(Linear(fan_in, 4))
        self.model = Sequential(*layers)

    def forward(self, states, batch_size):
        """Return action probabilities for ``states`` (a NumPy array).

        ``states`` is reshaped to ``(batch_size, input_size)``. The result is
        the row-wise softmax with size-1 dims squeezed out, so a batch of one
        yields a 1-D tensor of 4 probabilities.
        """
        as_tensor = torch.from_numpy(states).float().to(self.device)
        logits = self.model(as_tensor.view(batch_size, self.input_size))
        return logits.softmax(dim=1).squeeze()


class value_net(Module):
    """Critic network: an MLP from a flat state vector to one scalar value.

    Layer widths are ``input_size -> 64 -> 64 -> 32 -> 1`` with a ReLU after
    every hidden layer. The stack is stored as ``self.model`` (a
    ``Sequential``), so checkpoint keys are ``model.<index>.weight/bias``.
    """

    def __init__(self, device, input_size):
        super(value_net, self).__init__()
        self.input_size = input_size
        self.device = device
        layers = []
        fan_in = self.input_size
        for fan_out in (64, 64, 32):
            layers.extend((Linear(fan_in, fan_out), ReLU()))
            fan_in = fan_out
        layers.append(Linear(fan_in, 1))
        self.model = Sequential(*layers)

    def forward(self, states, batch_size):
        """Return value estimates for ``states`` (a NumPy array).

        ``states`` is reshaped to ``(batch_size, input_size)``. Size-1 dims are
        squeezed out, so a batch of N gives shape ``(N,)`` and a batch of one
        gives a 0-d tensor.
        """
        as_tensor = torch.from_numpy(states).float().to(self.device)
        values = self.model(as_tensor.view(batch_size, self.input_size))
        return values.squeeze()


class Agent_PNet_PPO:
    """PPO agent acting on states produced by a frozen localization network.

    Raw observations go through ``LoclizationNet`` to obtain a flat state
    vector per environment. ``LoclizationNet`` is an ``LnetsplitRNN`` whose
    ``Lnet`` weights are loaded from ``LMapNet_pretrain_model``; it is kept in
    eval mode.

    - ``target_net`` picks actions.
    - ``policy_net`` is trained with the clipped PPO objective.
    - ``value_net`` is the critic.
    - ``update_target`` copies the policy weights into the target.

    The rollout buffers (``states``, ``actions``, ``rewards``, ``dones``,
    ``rtgs``, ``actor_prob``) hold one slot per training environment. A slot
    starts as an empty list and is replaced by a ``deque`` on first write.
    ``clear_over_step`` flattens completed trajectories into arrays/tensors for
    ``learn``; ``clear_replay_memory`` restores the per-env layout.

    NOTE(review): ``LnetsplitRNN.forward`` in this file emits 3 * 5 + 4 = 19
    features per environment, while ``input_size`` defaults to 17. Confirm
    that callers always pass the matching size.
    """

    def __init__(self, device, input_size=17, lr=0.0001, batch_size=1000,
                 discount=0.90, clip=0.1, test=False, test_model="",
                 num_envs=1, num_test_envs=1, LMapNet_pretrain_model=''):
        self.input_size = input_size
        self.test = test
        self.device = device
        self.batch_size = batch_size
        self.hidden_size = 512
        self.num_envs = num_envs
        self.num_test_envs = num_test_envs

        # Frozen state estimator. Checkpoints are loaded onto the CPU so a
        # GPU-saved file also loads on CPU-only machines; load_state_dict then
        # copies the weights onto the module's device.
        self.LoclizationNet = LnetsplitRNN(self.device, self.hidden_size).to(self.device)
        self.LoclizationNet.Lnet.load_state_dict(torch.load(LMapNet_pretrain_model, map_location='cpu'))
        self.LoclizationNet.eval()

        self.policy_net = PNet_PPO_GTPos(self.device, self.input_size).to(self.device)
        self.target_net = PNet_PPO_GTPos(self.device, self.input_size).to(self.device)
        self.value_net = value_net(self.device, self.input_size).to(self.device)

        if test:
            self.target_net.load_state_dict(torch.load(test_model, map_location='cpu'))
            self.target_net.eval()
        else:
            self.lr = lr
            self.clip = clip
            self.discount = discount
            self.policy_net.train()
            self.target_net.eval()
            self.value_net.train()

            self.optimizer = torch.optim.Adam(self.policy_net.parameters(), lr=self.lr)
            # The critic learns at half the actor's rate.
            self.value_optimizer = torch.optim.Adam(self.value_net.parameters(), lr=self.lr / 2)

        self.clear_replay_memory()

    # ------------------------------------------------------------------ buffers

    def _empty_per_env(self):
        """Return a deque holding ``num_envs`` distinct empty placeholder lists."""
        return deque([] for _ in range(self.num_envs))

    @staticmethod
    def _append_slot(buffer, idx, item):
        """Append ``item`` to ``buffer[idx]``, promoting an empty placeholder to a deque."""
        if buffer[idx]:
            buffer[idx].append(item)
        else:
            buffer[idx] = deque([item])

    @staticmethod
    def _extend_slot(buffer, idx, items):
        """Extend ``buffer[idx]`` with ``items``, promoting an empty placeholder to a deque."""
        if buffer[idx]:
            buffer[idx].extend(items)
        else:
            buffer[idx] = deque(items)

    # ------------------------------------------------------------ hidden states

    def init_hidden_test(self):
        """(Re)initialise the estimator's LSTM states for the ``num_test_envs`` test envs."""
        self.hidden_batch_agent_test, self.cell_batch_agent_test = self.LoclizationNet.Lnet.init_hidden_states(
            self.num_test_envs)
        self.hidden_batch_objs_1_test, self.cell_batch_objs_1_test = self.LoclizationNet.Lnet.init_hidden_states(
            self.num_test_envs)
        self.hidden_batch_objs_2_test, self.cell_batch_objs_2_test = self.LoclizationNet.Lnet.init_hidden_states(
            self.num_test_envs)
        self.hidden_batch_objs_3_test, self.cell_batch_objs_3_test = self.LoclizationNet.Lnet.init_hidden_states(
            self.num_test_envs)

    def reset_agent(self):
        """(Re)initialise the estimator's LSTM states for the ``num_envs`` training envs."""
        self.hidden_batch_agent, self.cell_batch_agent = self.LoclizationNet.Lnet.init_hidden_states(self.num_envs)
        self.hidden_batch_objs_1, self.cell_batch_objs_1 = self.LoclizationNet.Lnet.init_hidden_states(
            self.num_envs)
        self.hidden_batch_objs_2, self.cell_batch_objs_2 = self.LoclizationNet.Lnet.init_hidden_states(
            self.num_envs)
        self.hidden_batch_objs_3, self.cell_batch_objs_3 = self.LoclizationNet.Lnet.init_hidden_states(
            self.num_envs)

    def reset_per_memory(self, idx):
        """Zero, in place, every training LSTM state for environment ``idx``.

        Index 1 of each state tensor is treated as the environment/batch axis.
        """
        for tensor in (self.hidden_batch_agent, self.cell_batch_agent,
                       self.hidden_batch_objs_1, self.cell_batch_objs_1,
                       self.hidden_batch_objs_2, self.cell_batch_objs_2,
                       self.hidden_batch_objs_3, self.cell_batch_objs_3):
            tensor[:, idx:idx + 1, :] = torch.zeros(1, 1, self.hidden_size).float().to(self.device)

    # ----------------------------------------------------------- replay memory

    def compute_rtgs(self, idx):
        """Convert env ``idx``'s pending rewards into discounted rewards-to-go.

        The rewards are turned into returns with the backward recursion
        ``G_t = r_t + discount * G_{t+1}``. The returns are appended to
        ``self.rtgs[idx]`` and the reward buffer is emptied.

        Returns:
            ``0`` when exactly one reward was pending, otherwise ``None``
            (legacy values; callers are not expected to use them).
        """
        rewards = self.rewards[idx]
        # Work on a list: indexing into the middle of a deque is O(n).
        returns = list(rewards)
        for t in range(len(returns) - 2, -1, -1):
            returns[t] += self.discount * returns[t + 1]
        if returns:
            self._extend_slot(self.rtgs, idx, returns)
        rewards.clear()
        return 0 if len(returns) == 1 else None

    def clear_replay_memory(self):
        """Reset every rollout buffer to one empty slot per training environment."""
        self.states = self._empty_per_env()
        self.actions = self._empty_per_env()
        self.rewards = self._empty_per_env()
        self.dones = self._empty_per_env()
        self.advantages = deque([])
        self.actor_prob = self._empty_per_env()
        self.rtgs = self._empty_per_env()

    def store_replay_memory(self, s, a, r, d):
        """Append one transition per environment.

        ``s``, ``a``, ``r`` and ``d`` are indexed by environment. Element
        ``idx`` of each goes into slot ``idx`` of ``states``, ``actions``,
        ``rewards`` and ``dones`` respectively.
        """
        for idx, state in enumerate(s):
            self._append_slot(self.states, idx, state)
            self._append_slot(self.actions, idx, a[idx])
            self._append_slot(self.rewards, idx, r[idx])
            self._append_slot(self.dones, idx, d[idx])

    # --------------------------------------------------------- state estimate

    def _to_torch_inputs(self, state):
        """Convert one env's ``(image, objects_vector)`` pair into batched tensors.

        The image is reshaped to ``(3, 80, 80)`` and mapped with
        ``2 * x / 255 - 1``. It is promoted to float64 first: for integer
        (e.g. uint8) frames, ``2 * x`` would otherwise wrap around.

        Returns:
            ``(obs, vec)`` with shapes ``(1, 1, 3, 80, 80)`` and
            ``(1, *objects_vector.shape)``, both float32 on ``self.device``.
        """
        obs = np.asarray(state[0], dtype=np.float64).reshape(3, 80, 80)
        obs = 2 * obs / 255 - 1
        obs_torch = torch.from_numpy(obs).float().unsqueeze(0).unsqueeze(0).to(self.device)
        vec_torch = torch.from_numpy(state[1]).float().unsqueeze(0).to(self.device)
        return obs_torch, vec_torch

    def state_estimator(self, states, test=False):
        """Run ``LoclizationNet`` one step and return the policy states.

        Args:
            states: with ``test=True``, a single ``(image, objects_vector)``
                pair; otherwise a sequence of such pairs, one per training env.
            test: selects the ``*_test`` recurrent states instead of the
                training ones. The matching set is updated in place.

        Returns:
            ``np.transpose`` of the estimator's state, i.e. one row per env.
        """
        if test:
            obs_batch, vec_batch = self._to_torch_inputs(states)
            hidden = (self.hidden_batch_agent_test, self.cell_batch_agent_test,
                      self.hidden_batch_objs_1_test, self.cell_batch_objs_1_test,
                      self.hidden_batch_objs_2_test, self.cell_batch_objs_2_test,
                      self.hidden_batch_objs_3_test, self.cell_batch_objs_3_test)
        else:
            pairs = [self._to_torch_inputs(state) for state in states]
            obs_batch = torch.cat([obs for obs, _ in pairs], dim=0)  # (num_env, 1, 3, 80, 80)
            vec_batch = torch.cat([vec for _, vec in pairs], dim=0)  # (num_env, ...)
            hidden = (self.hidden_batch_agent, self.cell_batch_agent,
                      self.hidden_batch_objs_1, self.cell_batch_objs_1,
                      self.hidden_batch_objs_2, self.cell_batch_objs_2,
                      self.hidden_batch_objs_3, self.cell_batch_objs_3)

        with torch.no_grad():
            est_state, new_hidden = self.LoclizationNet(obs_batch, vec_batch, *hidden)

        if test:
            (self.hidden_batch_agent_test, self.cell_batch_agent_test,
             self.hidden_batch_objs_1_test, self.cell_batch_objs_1_test,
             self.hidden_batch_objs_2_test, self.cell_batch_objs_2_test,
             self.hidden_batch_objs_3_test, self.cell_batch_objs_3_test) = new_hidden
        else:
            (self.hidden_batch_agent, self.cell_batch_agent,
             self.hidden_batch_objs_1, self.cell_batch_objs_1,
             self.hidden_batch_objs_2, self.cell_batch_objs_2,
             self.hidden_batch_objs_3, self.cell_batch_objs_3) = new_hidden
        return np.transpose(est_state)

    # ----------------------------------------------------------------- acting

    def step(self, state, test=False, exploration_rate=0.005):
        """Sample one action for a single state using ``target_net``.

        With probability ``exploration_rate`` the fixed action 3 is returned
        instead of a sample. The probability of the returned action is
        appended to ``self.actor_prob``.

        NOTE(review): the probability is appended to the *outer*
        ``actor_prob`` deque, after the ``num_envs`` per-env slots, so
        ``clear_over_step`` never reads it. This looks intended for evaluation
        only; confirm. ``test`` is unused.

        Returns:
            The action as a 0-d ``np.ndarray``.
        """
        # No graph is needed: the stored probability was detached anyway.
        with torch.no_grad():
            actions_prob = self.target_net(state, batch_size=1)
        distribution = Categorical(actions_prob)

        if random.random() < exploration_rate:
            self.actor_prob.append(actions_prob[3])
            return np.array(3)

        action = distribution.sample()
        self.actor_prob.append(actions_prob[action])
        return np.array(action.item())

    def step_mp(self, states, test=False, exploration_rate=0):
        """Sample one action per training env and record its probability.

        ``test`` and ``exploration_rate`` are accepted for interface
        compatibility but unused.

        Returns:
            An int ``np.ndarray`` with one action per environment.
        """
        with torch.no_grad():
            # target_net squeezes its output, so with one env the result is 1-D;
            # restore the (num_envs, num_actions) layout before iterating rows.
            actions_probs = self.target_net(states, batch_size=self.num_envs).view(self.num_envs, -1)
        actions = []
        for i, probs in enumerate(actions_probs):
            action = Categorical(probs).sample()
            actions.append(action.item())
            self._append_slot(self.actor_prob, i, probs[action])
        return np.array(actions, dtype=int)

    # --------------------------------------------------------------- learning

    def clear_over_step(self):
        """Drop incomplete trajectories and flatten the buffers for ``learn``.

        Any trailing steps of an env that have no reward-to-go yet are popped
        from ``states``, ``actions`` and ``actor_prob``. The remaining data of
        all envs is concatenated in env order:

        - ``self.states`` becomes an ``np.ndarray``;
        - ``self.actions`` becomes a flat deque;
        - ``self.actor_prob`` and ``self.rtgs`` become float tensors on
          ``self.device``.
        """
        states, actions, actor_prob, rtgs = deque(), deque(), deque(), deque()
        for i in range(self.num_envs):
            for _ in range(len(self.states[i]) - len(self.rtgs[i])):
                self.states[i].pop()
                self.actions[i].pop()
                self.actor_prob[i].pop()
            states += self.states[i]
            actions += self.actions[i]
            actor_prob += self.actor_prob[i]
            rtgs += self.rtgs[i]

        self.states = np.array(states)
        self.actions = actions
        # Stack the recorded 0-d probability tensors directly instead of
        # round-tripping each one through a Python scalar.
        old_probs = torch.stack(list(actor_prob)) if actor_prob else torch.tensor([])
        self.actor_prob = old_probs.to(self.device).float()
        self.rtgs = torch.tensor(rtgs).to(self.device).float()

    def get_adv_func(self):
        """Prepare flattened buffers for ``learn`` (alias of ``clear_over_step``)."""
        self.clear_over_step()

    def learn(self):
        """Run one clipped-PPO update of ``policy_net`` and one MSE update of ``value_net``.

        This expects ``clear_over_step`` to have flattened the buffers first.
        Advantages are ``rtgs - V(s)`` with the critic output detached.

        Returns:
            ``(actor_loss, critic_loss)`` as Python floats.
        """
        n = len(self.states)

        # One critic pass: the detached copy gives the advantages, the
        # attached copy the critic loss. view(-1) keeps a batch of one 1-D.
        values = self.value_net(self.states, batch_size=n).view(-1)
        self.advantages = self.rtgs - values.detach()

        # view(n, -1) undoes the policy's squeeze so gather also works for n == 1.
        probs = self.policy_net(self.states, batch_size=n).view(n, -1)
        action_idx = torch.tensor([int(a) for a in self.actions], dtype=torch.long, device=probs.device)
        current_prob = probs.gather(1, action_idx.unsqueeze(1)).squeeze(1).float()
        ratios = current_prob / self.actor_prob

        surr1 = ratios * self.advantages
        surr2 = torch.clamp(ratios, 1 - self.clip, 1 + self.clip) * self.advantages
        actor_loss = (-torch.min(surr1, surr2)).mean()
        crit_loss = torch.nn.MSELoss()(values, self.rtgs)

        self.optimizer.zero_grad()
        actor_loss.backward()
        self.optimizer.step()

        self.value_optimizer.zero_grad()
        crit_loss.backward()
        self.value_optimizer.step()

        return actor_loss.item(), crit_loss.item()

    def update_target(self):
        """Copy the current policy weights into ``target_net``."""
        self.target_net.load_state_dict(self.policy_net.state_dict())
