######################## 文件简介 ########################
# 训练策略的实现，策略中规则模块使用的是软规则self-attention网络
import random
from smec_rl_components.smec_sampler import *
from smec_rl_components.gnn_module import GraphCNN
from pytorch_rl.a2c_ppo_acktr.base_modules import *
from smec_rl_components.smec_graph_build import *
from smec_rl_components.optimization import *
from attention.myattention import SelfAttention
import time


class SmecGraphEncoder(nn.Module):
    """Encode a batched, fused elevator state into per-floor-node features.

    Node embeddings come from a 3-layer ``GraphCNN``. Two blocks of floor
    nodes are selected from them: the first ``floor_num`` nodes ("up" floors)
    and the ``floor_num`` nodes starting at offset ``lift_num + floor_num``
    ("down" floors). These are summed with an MLP encoding of ``distances``.
    The result is then concatenated with the output of a ``SelfAttention``
    module fed with the hall calls and the historical floor mask.

    ``self.hidden_dim`` is the width of the output's last axis: the GNN width
    plus the attention width.
    """

    def __init__(self, lift_num, floor_num, hidden_dim=64, attention_hidden=64, device='cpu'):
        super().__init__()
        # Width of the concatenated [GNN + relation | attention] features.
        self.hidden_dim = hidden_dim + attention_hidden
        # Keep the registration order (gnn -> relation -> attention): it fixes
        # the parameter-init RNG sequence and the state_dict layout.
        self.gnn_encoder = GraphCNN(
            num_layers=3,
            num_mlp_layers=2,
            input_dim=3,
            hidden_dim=hidden_dim,
            learn_eps=False,
            neighbor_pooling_type='sum',
        )
        self.relation_encoder = MLPModule(input_size=lift_num, hidden_size=hidden_dim)
        self.AttentionFactor = SelfAttention(
            floor_num * 2, attention_hidden, 128, 4, 0.3,
            elevator_num=lift_num,
            floor_num=floor_num,
            device=device,
        )
        self.lift_num = lift_num
        self.floor_num = floor_num
        # forward() reads the "down" floor nodes starting at this node offset.
        self.half_node = self.lift_num + self.floor_num

    def forward(self, batch_fused_state):
        """Return per-floor-node features for a batch of fused states.

        ``batch_fused_state`` must provide the keys ``'adj_m'``,
        ``'node_feature_m'``, ``'distances'``, ``'floor_mask'`` and
        ``'hall_calls'``.

        The result concatenates, along the last axis, (GNN floor features +
        relation features) with the attention output. Along dim 1 it holds
        the "up" floor nodes followed by the "down" floor nodes. This assumes
        the GNN output is (batch, nodes, dim) and that the relation/attention
        outputs broadcast to (batch, 2 * floor_num, ...); TODO confirm.
        """
        n_floor = self.floor_num
        down_start = self.half_node

        # Embed every graph node, then pick out the two floor-node blocks.
        node_embed = self.gnn_encoder(
            batch_fused_state['node_feature_m'],
            padded_nei=None,
            adj=batch_fused_state['adj_m'],
        )
        up_block = node_embed[:, :n_floor, :]
        down_block = node_embed[:, down_start:down_start + n_floor, :]
        floor_embed = torch.cat([up_block, down_block], dim=1)

        relation_embed = self.relation_encoder(batch_fused_state['distances'])

        # The floor mask is passed as both 2nd and 3rd argument (presumably
        # key and value, with hall calls as query -- verify SelfAttention).
        floor_mask = batch_fused_state['floor_mask'].float()
        attention_out = self.AttentionFactor(batch_fused_state['hall_calls'].float(), floor_mask, floor_mask)

        return torch.cat((floor_embed + relation_embed, attention_out), -1)


class SmecBase(nn.Module):
    """Actor-critic trunk built from two independent ``SmecGraphEncoder``s.

    The critic encoder feeds a single-output linear head that produces the
    value estimate. The actor encoder's features are returned unchanged for
    a downstream action head. ``a_output`` exposes the actor feature width.
    ``use_graph`` is stored but not otherwise used in this class.
    """

    def __init__(self, lift_num, floor_num, use_graph=True, device='cpu'):
        super().__init__()
        self.use_graph = use_graph

        def init_(module):
            # Orthogonal weights with gain sqrt(2) and zero bias for the value head.
            return init(module, nn.init.orthogonal_, lambda bias: nn.init.constant_(bias, 0), np.sqrt(2))

        # Construction order (actor, critic, value head) is kept: it fixes the
        # parameter-init RNG sequence and the state_dict layout.
        self.actor = SmecGraphEncoder(lift_num, floor_num, device=device)
        self.critic = SmecGraphEncoder(lift_num, floor_num, device=device)
        self.a_output = self.actor.hidden_dim
        self.critic_linear = init_(nn.Linear(self.critic.hidden_dim, 1))
        self.train()

    def forward(self, inputs):
        """Return ``(value, actor_features, rule)`` for a batch of observations.

        ``value`` is the critic head output (last axis of size 1).
        ``rule`` is the index of the largest value along dim 1, with dims kept.
        """
        # Critic runs before actor, as before, so any stochastic layers
        # (e.g. dropout) consume RNG in the same order.
        critic_hidden = self.critic(inputs)
        actor_hidden = self.actor(inputs)
        value = self.critic_linear(critic_hidden)
        rule = value.argmax(dim=1, keepdim=True)
        return value, actor_hidden, rule


class SmecPolicy(nn.Module):
    """Actor-critic policy over graph-encoded elevator observations.

    Wraps a ``SmecBase`` trunk, which yields the value, actor features and a
    "rule" index. It also wraps a ``SmecSampler`` head that turns the actor
    features into an action distribution.

    Args:
        lift_num: number of elevators; forwarded to the trunk and sampler.
        floor_num: number of floors; forwarded to the trunk.
        use_graph: must be True; only the graph encoder is supported.
        open_mask: when True, ``inputs_obs['legal_masks']`` is passed to the
            sampler as ``legal_mask``.
        initialization: when True, all weights are re-initialised with
            ``initialize_weights(self, "orthogonal")`` after construction.
        device: device string forwarded to the trunk.

    Raises:
        AssertionError: if ``use_graph`` is false.
    """

    def __init__(self, lift_num, floor_num, use_graph=True, open_mask=True, initialization=False, device='cpu'):
        super().__init__()
        # Validate before building two full encoders so an unsupported
        # configuration fails immediately.
        assert use_graph, "SmecPolicy only supports use_graph=True"
        self.base = SmecBase(lift_num, floor_num, use_graph=use_graph, device=device)
        self.dist = SmecSampler(self.base.a_output, lift_num, use_graph)

        self.open_mask = open_mask
        self.elevator_num = lift_num
        self.floor_num = floor_num
        self.initialization = initialization
        # Defined up front so the attribute exists even before reset() is called.
        self.hidden_cell = None

        if initialization:
            initialize_weights(self, "orthogonal")

    def forward(self, inputs_obs):
        """Equivalent to ``act(inputs_obs)`` with default (stochastic) sampling."""
        return self.act(inputs_obs)

    def reset(self):
        """Clear ``hidden_cell`` (not read anywhere in this class)."""
        self.hidden_cell = None

    def _distribution(self, inputs_obs):
        """Run the trunk and build the action distribution; return ``(value, dist, rule)``."""
        value, actor_features, rule = self.base(inputs_obs)
        legal_mask = inputs_obs['legal_masks'] if self.open_mask else None
        dist = self.dist(actor_features, legal_mask=legal_mask)
        return value, dist, rule

    def act(self, inputs_obs, deterministic=False, train=False):
        """Choose actions for ``inputs_obs``.

        Args:
            inputs_obs: observation mapping consumed by the trunk; must also
                contain ``'legal_masks'`` when ``open_mask`` is True.
            deterministic: take ``dist.mode()`` instead of ``dist.sample()``.
            train: unused; kept for API compatibility.

        Returns:
            ``(value, action, action_log_probs, rule)``.
        """
        value, dist, rule = self._distribution(inputs_obs)
        action = dist.mode() if deterministic else dist.sample()
        action_log_probs = dist.log_probs(action)
        return value, action, action_log_probs, rule

    def get_value(self, inputs):
        """Return only the critic value for ``inputs``."""
        value, _, _ = self.base(inputs)
        return value

    def evaluate_actions(self, inputs_obs, masks, action):
        """Score previously taken ``action``s under the current policy.

        Args:
            inputs_obs: observation mapping. It may contain
                ``'valid_action_mask'``; a missing key or ``None`` value
                means no entropy masking.
            masks: unused; kept for API compatibility.
            action: actions whose log-probabilities are evaluated.

        Returns:
            ``(value, action_log_probs, dist_entropy)``, where
            ``dist_entropy`` is a scalar mean.
        """
        value, dist, _ = self._distribution(inputs_obs)

        action_log_probs = dist.log_probs(action)
        dist_entropy = dist.entropy()

        # Zero out entropy for actions that were not actually taken. The key
        # is optional: previously a missing key raised KeyError even though
        # the None-check shows the mask was meant to be optional.
        try:
            valid_action_mask = inputs_obs['valid_action_mask']
        except KeyError:
            valid_action_mask = None
        if valid_action_mask is not None:
            dist_entropy = dist_entropy * valid_action_mask
        # NOTE(review): this averages over *all* entries, including masked
        # ones, so the entropy bonus shrinks with the fraction of valid
        # actions. Kept as-is to preserve the loss scale; consider
        # sum(entropy * mask) / mask.sum() if that was not intended.
        dist_entropy = dist_entropy.mean()

        return value, action_log_probs, dist_entropy

