######################## File overview ########################
# Implementation of the training policy; the rule module in the policy uses hard-coded rules.
import random
from smec_rl_components.smec_sampler import *
from smec_rl_components.gnn_module import GraphCNN
from pytorch_rl.a2c_ppo_acktr.base_modules import *
from smec_rl_components.smec_graph_build import *
from smec_rl_components.optimization import *
import time
from torch.nn import functional as F
from attention.myattention import SelfAttention


class SmecGraphEncoderV1(nn.Module):
    """GNN-based encoder that embeds two ranges of floor nodes plus a distance term.

    Args:
        lift_num: ``input_size`` of ``relation_encoder`` (presumably the last
            dimension of ``batch_fused_state['distances']`` -- confirm against
            the state builder).
        floor_num: number of nodes taken from each of the two node ranges.
        hidden_dim: hidden size handed to both ``GraphCNN`` and ``MLPModule``.
    """

    def __init__(self, lift_num, floor_num, hidden_dim=128):
        super(SmecGraphEncoderV1, self).__init__()
        self.hidden_dim = hidden_dim
        # Sub-modules are created in this order on purpose: reordering would
        # change how a seeded RNG initializes their weights.
        self.gnn_encoder = GraphCNN(
            num_layers=3,
            num_mlp_layers=2,
            input_dim=3,
            hidden_dim=hidden_dim,
            learn_eps=False,
            neighbor_pooling_type='sum',
        )
        self.relation_encoder = MLPModule(input_size=lift_num, hidden_size=hidden_dim)
        self.lift_num = lift_num
        self.floor_num = floor_num
        # Offset (along the node axis) of the second, "down", floor-node range.
        self.half_node = lift_num + floor_num

    def forward(self, batch_fused_state):
        """Encode a batched state dict.

        Reads ``'adj_m'``, ``'node_feature_m'`` and ``'distances'`` from
        ``batch_fused_state``. Runs the GNN, keeps node slices
        ``[0, floor_num)`` ("up" floors) and
        ``[half_node, half_node + floor_num)`` ("down" floors), concatenates
        them along dim 1 and adds ``relation_encoder(distances)`` (broadcast).
        Assumes GNN output is (batch, nodes, hidden) -- TODO confirm against
        ``GraphCNN``.
        """
        nodes = self.gnn_encoder(
            batch_fused_state['node_feature_m'],
            padded_nei=None,
            adj=batch_fused_state['adj_m'],
        )
        width = self.floor_num
        offset = self.half_node
        up_part = nodes[:, :width, :]
        down_part = nodes[:, offset:offset + width, :]
        relation = self.relation_encoder(batch_fused_state['distances'])
        return torch.cat([up_part, down_part], dim=1) + relation


class SmecGraphEncoderV2(nn.Module):
    """Attention-based encoder over the floor mask and the distance tensor.

    Args:
        lift_num: number of copies the floor mask is tiled into along its new
            last axis; also forwarded to ``SelfAttention``.
        floor_num: forwarded to ``SelfAttention``.
        hidden_dim: hidden size of the (currently unused) ``relation_encoder``.
        device: forwarded to ``SelfAttention``.
    """

    def __init__(self, lift_num, floor_num, hidden_dim=128, device='cpu'):
        super(SmecGraphEncoderV2, self).__init__()
        self.hidden_dim = hidden_dim
        # Not used by forward(); kept so parameter/state_dict keys and the
        # seeded initialization order stay unchanged.
        self.relation_encoder = MLPModule(input_size=lift_num + 1, hidden_size=hidden_dim)
        self.lift_num = lift_num
        self.floor_num = floor_num
        self.AttentionFactor = SelfAttention(
            lift_num, 128, 256, 4, 0.3,
            elevator_num=lift_num,
            floor_num=floor_num,
            device=device,
        )

    def forward(self, batch_fused_state):
        """Run the attention module on ``(floor_mask, distances, distances)``.

        ``floor_mask`` is cast to float and tiled ``lift_num`` times along a
        new trailing axis (assumes it is 2-D, e.g. (batch, floors) -- confirm
        against the state builder). The argument order presumably maps to
        query/key/value -- confirm against ``SelfAttention``.
        """
        mask = batch_fused_state['floor_mask'].float()
        mask = mask.unsqueeze(-1).repeat(1, 1, self.lift_num)
        dist = batch_fused_state['distances']
        return self.AttentionFactor(mask, dist, dist)


class SmecBase(nn.Module):
    """Shared trunk: an actor encoder, two critic encoders and one value head.

    ``forward(inputs)`` returns ``(v1, v2, hidden_actor, rule, env_loss)``:

    * ``v1`` / ``v2``: ``critic_linear`` applied to the ``critic1`` (graph) and
      ``critic2`` (attention) encodings -- both critics share the same head.
    * ``hidden_actor``: the actor encoder's output.
    * ``rule``: ``argmax`` of ``v1`` along dim 1 (``keepdim=True``).
    * ``env_loss``: ``0.00025 * mse_loss(z1, z2)`` between the two critic
      encodings (requires their shapes to match).
    """

    def __init__(self, lift_num, floor_num, use_graph=True, device='cpu'):
        super(SmecBase, self).__init__()
        self.use_graph = use_graph

        def init_(module):
            # orthogonal_ weight init, constant-0 bias init, gain sqrt(2).
            return init(module, nn.init.orthogonal_, lambda x: nn.init.constant_(x, 0), np.sqrt(2))

        # Keep this construction order: it determines seeded weight initialization.
        self.actor = SmecGraphEncoderV1(lift_num, floor_num)
        self.critic1 = SmecGraphEncoderV1(lift_num, floor_num)
        self.critic2 = SmecGraphEncoderV2(lift_num, floor_num, device=device)
        self.a_output = self.actor.hidden_dim
        # Single value head shared by both critics.
        self.critic_linear = init_(nn.Linear(self.critic1.hidden_dim, 1))
        # The module is left in eval mode at construction.
        self.eval()

    def forward(self, inputs):
        """Encode ``inputs`` with all three encoders; see the class docstring."""
        # Call order (actor, critic1, critic2) is kept in case any encoder
        # consumes RNG (e.g. dropout) in training mode.
        hidden_actor = self.actor(inputs)
        z1 = self.critic1(inputs)
        z2 = self.critic2(inputs)
        v1, v2 = map(self.critic_linear, (z1, z2))
        rule = v1.argmax(dim=1, keepdim=True)
        env_loss = F.mse_loss(z1, z2) * 0.00025
        return v1, v2, hidden_actor, rule, env_loss


class SmecPolicy(nn.Module):
    """Actor-critic policy built on ``SmecBase`` with a ``SmecSampler`` action head.

    Args:
        lift_num: forwarded to ``SmecBase``; also the sampler's output size
            (``lift_num + 1`` when ``use_advice`` is True, adding an extra
            "advice" choice).
        floor_num: forwarded to ``SmecBase``.
        use_graph: forwarded to ``SmecBase`` and ``SmecSampler``; must be True.
        open_mask: if True, ``inputs_obs['legal_masks']`` is passed to the
            sampler as ``legal_mask``; otherwise ``None`` is passed.
        initialization: if True, re-initialize weights with
            ``initialize_weights(self, "orthogonal")``.
        use_advice: see ``lift_num``.
        device: forwarded to ``SmecBase``.
        pure_random: debugging switch (default False). When True, the actor
            features are replaced by ones before sampling, so the action
            distribution ignores the observation. It is applied in both
            ``act`` and ``evaluate_actions`` so the log-probabilities they
            compute come from the same distribution.
    """

    def __init__(self, lift_num, floor_num, use_graph=True, open_mask=True, initialization=False,
                 use_advice=False, device='cpu', pure_random=False):
        super(SmecPolicy, self).__init__()
        self.base = SmecBase(lift_num, floor_num, use_graph=use_graph, device=device)
        assert use_graph
        # modified by JY: the advice variant adds one extra choice to the sampler.
        num_choices = lift_num + 1 if use_advice else lift_num
        self.dist = SmecSampler(self.base.a_output, num_choices, use_graph)

        self.open_mask = open_mask
        self.initialization = initialization
        self.pure_random = pure_random

        if initialization:
            initialize_weights(self, "orthogonal")

    def forward(self, inputs_obs):
        """Alias for :meth:`act` with default arguments."""
        return self.act(inputs_obs)

    def reset(self):
        """Set ``self.hidden_cell`` to None."""
        self.hidden_cell = None

    def _sampler_features(self, actor_features):
        """Return the sampler input: ones when ``pure_random`` is on, else ``actor_features``."""
        # getattr: whole-module pickles saved before ``pure_random`` existed
        # lack the attribute; treat them as the normal (non-random) policy.
        if getattr(self, 'pure_random', False):
            return torch.ones_like(actor_features)
        return actor_features

    def act(self, inputs_obs, deterministic=False, train=False):
        """Pick an action for ``inputs_obs``.

        Args:
            inputs_obs: batched observation; besides what ``SmecBase`` reads,
                ``'legal_masks'`` is read when ``open_mask`` is True.
            deterministic: use ``dist.mode()`` instead of ``dist.sample()``.
            train: unused; kept for interface compatibility.

        Returns:
            ``(v1, action, action_log_probs, rule)``; ``v1`` and ``rule`` come
            straight from ``SmecBase``.
        """
        v1, _, actor_features, rule, _ = self.base(inputs_obs)
        # Use the real actor features unless the pure_random debug switch is on
        # (previously they were always overwritten with ones here, which made
        # actions ignore the observation and disagree with evaluate_actions).
        features = self._sampler_features(actor_features)
        legal_mask = inputs_obs['legal_masks'] if self.open_mask else None
        dist = self.dist(features, legal_mask=legal_mask)

        action = dist.mode() if deterministic else dist.sample()
        action_log_probs = dist.log_probs(action)

        return v1, action, action_log_probs, rule

    def get_value(self, inputs):
        """Return ``v2`` (the critic-2 value) from ``SmecBase``.

        NOTE(review): ``act`` and ``evaluate_actions`` return ``v1`` (critic 1)
        while this returns ``v2`` (critic 2); if this is used to bootstrap
        returns, confirm that mixing the two critics is intended.
        """
        _, v2, _, _, _ = self.base(inputs)
        return v2

    def evaluate_actions(self, inputs_obs, masks, action):
        """Score ``action`` under the current policy.

        Args:
            inputs_obs: batched observation; ``'valid_action_mask'`` must be
                present (it may be None), and ``'legal_masks'`` is read when
                ``open_mask`` is True.
            masks: unused; kept for interface compatibility.
            action: actions to evaluate.

        Returns:
            ``(v1, action_log_probs, dist_entropy, env_loss)`` where
            ``dist_entropy`` is a scalar (mean over all entries).
        """
        v1, _, actor_features, _, env_loss = self.base(inputs_obs)
        # Same feature transform as act() so both use one distribution.
        features = self._sampler_features(actor_features)
        legal_mask = inputs_obs['legal_masks'] if self.open_mask else None
        dist = self.dist(features, legal_mask=legal_mask)
        action_log_probs = dist.log_probs(action)

        dist_entropy = dist.entropy()

        # added by JY: multiply by the mask to drop the contribution of
        # actions that were not actually taken (assumes a 0/1 mask).
        valid_action_mask = inputs_obs['valid_action_mask']
        if valid_action_mask is not None:
            dist_entropy = dist_entropy * valid_action_mask
        # NOTE(review): mean() divides by all entries, including masked-out
        # ones, so the entropy term shrinks with the fraction of invalid
        # entries; a masked mean would be sum() / valid_action_mask.sum().
        dist_entropy = dist_entropy.mean()
        return v1, action_log_probs, dist_entropy, env_loss
