######################## File overview ########################
# Implementation of the training policy; the rule module inside the policy uses hard (hand-coded) rules
import random
from smec_rl_components.smec_sampler import *
from smec_rl_components.gnn_module import GraphCNN
from pytorch_rl.a2c_ppo_acktr.base_modules import *
from smec_rl_components.smec_graph_build import *
from smec_rl_components.optimization import *
import time
from torch.nn import functional as F
from attention.myattention import SelfAttention2


class Flatten(nn.Module):
    """Collapse every dimension after the first into a single one."""

    def forward(self, x):
        """Return ``x`` viewed as ``(x.size(0), -1)``.

        Uses ``view``, so the input must be laid out such that a view is
        possible (no copy is made).
        """
        leading = x.size(0)
        flat = x.view(leading, -1)
        return flat


class SmecGraphEncoderV1(nn.Module):
    """Encode a fused elevator state into one feature vector per hall-call slot.

    Four feature groups are built and concatenated on the last axis before a
    final MLP fuses them:

    * waiting features   -- ``wait_encoder`` applied to the up/down wait values
    * loading features   -- conv embedding of ``loading``, repeated for every slot
    * location features  -- conv embedding of ``location``, repeated for every slot
    * relation features  -- ``relation_encoder`` applied to ``distances``

    Args:
        lift_num: number of elevators; input width of ``relation_encoder``.
        floor_num: number of floors; ``forward`` repeats the elevator
            embeddings ``2 * floor_num`` times.
        hidden_dim: ``hidden_size`` handed to every ``MLPModule``.
        device: accepted for signature compatibility; not used by this class.
    """

    # Width of each elevator-level (loading / location) embedding.
    _ELEV_FEATURE_DIM = 64

    def __init__(self, lift_num, floor_num, hidden_dim=128, device='cpu'):
        super(SmecGraphEncoderV1, self).__init__()
        self.hidden_dim = hidden_dim
        # Two independent stacks with identical architecture; built in the
        # same order as before so parameter ordering is unchanged.
        self.elev_loading_conv = self._build_elev_conv()
        self.elev_location_conv = self._build_elev_conv()

        self.wait_encoder = MLPModule(input_size=1, hidden_size=hidden_dim)
        self.relation_encoder = MLPModule(input_size=lift_num, hidden_size=hidden_dim)
        # forward() concatenates: wait + loading + location + relation.
        # Assuming MLPModule emits `hidden_size` features, that is
        # 2 * hidden_dim + 2 * 64.  The previous `hidden_dim * 3` only matched
        # this width when hidden_dim == 128 (the default), so any other
        # hidden_dim produced a shape mismatch at the first forward pass.
        fused_dim = 2 * hidden_dim + 2 * self._ELEV_FEATURE_DIM
        self.total_encoder = MLPModule(input_size=fused_dim, hidden_size=hidden_dim)
        self.lift_num = lift_num
        self.floor_num = floor_num

    @classmethod
    def _build_elev_conv(cls):
        """Build the conv -> flatten -> linear stack used for elevator-level inputs."""
        # NOTE(review): the Linear input size 15 is hard-coded.  With
        # Conv2d(1, 3, (4, 8), stride=2) it equals 3 channels * 5 positions,
        # which would fit e.g. a (4 lifts x 16 floors) input -- presumably the
        # training configuration; confirm before using other sizes.
        # `init_` is expected to be provided by the module's wildcard imports.
        return nn.Sequential(
            nn.Conv2d(1, 3, (4, 8), 2),
            nn.ReLU(),
            Flatten(),
            init_(nn.Linear(15, cls._ELEV_FEATURE_DIM)),
            nn.ReLU(),
        )

    def forward(self, batch_fused_state):
        """Return the fused per-slot features.

        Args:
            batch_fused_state: mapping that provides ``'up_wait'``,
                ``'dn_wait'``, ``'loading'``, ``'location'`` and
                ``'distances'`` tensors.  Per the original comments the
                expected layouts are (env, floor) for the wait tensors,
                (env, lift, floor) for loading/location and
                (env, 2 * floor, lift) for distances -- TODO confirm.

        Returns:
            Output of ``total_encoder`` applied to the concatenated features.
        """
        # Stack up/down waits along the last axis and add a trailing
        # singleton feature axis for the MLP (input_size=1).
        up_dn_wait = torch.cat(
            [batch_fused_state['up_wait'], batch_fused_state['dn_wait']], dim=-1
        ).unsqueeze(-1)
        wait_features = self.wait_encoder(up_dn_wait)

        # Add a channel axis so the 2-D convolutions can consume the inputs.
        loading = batch_fused_state['loading'].unsqueeze(1)
        location = batch_fused_state['location'].unsqueeze(1)

        # One embedding per environment, copied to each of the 2 * floor_num slots.
        slot_repeat = (1, self.floor_num * 2, 1)
        loading_features = self.elev_loading_conv(loading).unsqueeze(1).repeat(slot_repeat)
        location_features = self.elev_location_conv(location).unsqueeze(1).repeat(slot_repeat)

        distances = batch_fused_state['distances']
        rel_agent_features = self.relation_encoder(distances)

        all_agent_features = torch.cat(
            [wait_features, loading_features, location_features, rel_agent_features],
            dim=-1,
        )
        return self.total_encoder(all_agent_features)


class SmecGraphEncoderV2(nn.Module):
    """Attention-based encoder driven by the floor mask and the distances.

    Args:
        lift_num: number of elevators.
        floor_num: number of floors; the output has ``2 * floor_num`` slots.
        hidden_dim: ``hidden_size`` of ``relation_encoder``.
        device: forwarded to ``SelfAttention2``.

    NOTE(review): ``relation_encoder`` is constructed but never called in
    ``forward``, and the attention module is built with literal sizes
    (128, 256, 4, 0.3) rather than ``hidden_dim`` -- confirm both are intended.
    """

    def __init__(self, lift_num, floor_num, hidden_dim=128, device='cpu'):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.relation_encoder = MLPModule(
            input_size=lift_num + 1,
            hidden_size=self.hidden_dim,
        )
        self.lift_num = lift_num
        self.floor_num = floor_num
        self.AttentionFactor = SelfAttention2(
            self.lift_num, 128, 256, 4, 0.3,
            elevator_num=lift_num,
            floor_num=floor_num,
            device=device,
        )

    def forward(self, batch_fused_state):
        """Run attention with the floor mask as first input and distances as the other two.

        Args:
            batch_fused_state: mapping providing ``'distances'`` and
                ``'floor_mask'``; per the original comments these are
                (env, 2 * floor, lift) and (env, 2 * floor) -- TODO confirm.

        Returns:
            The attention output viewed as ``(env, 2 * floor_num, -1)``.
        """
        dist = batch_fused_state['distances']
        # Cast the mask to float and copy it once per elevator on a new last axis.
        mask = batch_fused_state['floor_mask'].float()
        mask = mask.unsqueeze(-1).repeat((1, 1, self.lift_num))

        attended = self.AttentionFactor(mask, dist, dist)
        n_env = mask.shape[0]
        return attended.view((n_env, self.floor_num * 2, -1))


class SmecBase(nn.Module):
    """Actor encoder plus two critic encoders that share one value head.

    Args:
        lift_num: number of elevators, forwarded to every encoder.
        floor_num: number of floors, forwarded to every encoder.
        use_graph: stored on the instance; not otherwise used here.
        device: forwarded to every encoder.

    The constructor finishes by switching the module to evaluation mode.
    """

    def __init__(self, lift_num, floor_num, use_graph=True, device='cpu'):
        super().__init__()
        self.use_graph = use_graph

        def _orthogonal(module):
            # Orthogonal weights with gain sqrt(2), biases set to zero.
            return init(
                module,
                nn.init.orthogonal_,
                lambda x: nn.init.constant_(x, 0),
                np.sqrt(2),
            )

        self.actor = SmecGraphEncoderV1(lift_num, floor_num, device=device)
        self.critic1 = SmecGraphEncoderV1(lift_num, floor_num, device=device)
        self.critic2 = SmecGraphEncoderV2(lift_num, floor_num, device=device)
        self.a_output = self.actor.hidden_dim
        value_head = nn.Linear(self.critic1.hidden_dim, 1)
        self.critic_linear = _orthogonal(value_head)
        self.eval()

    def forward(self, inputs):
        """Compute both value estimates, the actor features, the rule index and the alignment loss.

        Returns:
            Tuple ``(v1, v2, hidden_actor, rule, env_loss)`` where ``v1`` /
            ``v2`` are ``critic_linear`` applied to the two critic encodings,
            ``rule`` is the argmax of ``v1`` along dim 1 (dim kept) and
            ``env_loss`` is ``0.00025 * mse(z1, z2)`` between the two critic
            encodings.
        """
        actor_hidden = self.actor(inputs)

        latent_a = self.critic1(inputs)
        latent_b = self.critic2(inputs)

        # The same linear head scores both critic encodings.
        value_a = self.critic_linear(latent_a)
        value_b = self.critic_linear(latent_b)

        rule = value_a.argmax(dim=1, keepdim=True)

        alignment_weight = 0.00025
        env_loss = alignment_weight * F.mse_loss(latent_a, latent_b)
        return value_a, value_b, actor_hidden, rule, env_loss


class SmecPolicy(nn.Module):
    """Policy wrapper: a ``SmecBase`` network plus a ``SmecSampler`` action head.

    Args:
        lift_num: number of elevators.
        floor_num: number of floors.
        use_graph: must be truthy (asserted); forwarded to base and sampler.
        open_mask: when true, ``inputs_obs['legal_masks']`` is passed to the
            sampler; otherwise no mask is passed.
        initialization: when true, ``initialize_weights(self, "orthogonal")``
            is applied after construction.
        use_advice: when true the sampler gets one extra choice
            (``lift_num + 1``).
        device: forwarded to ``SmecBase``.

    NOTE(review): ``act`` and ``get_value`` return the first value estimate
    (v1) while ``evaluate_actions`` returns the second (v2) -- the original
    marks these as switchable variants; confirm the mix is intended.
    """

    def __init__(self, lift_num, floor_num, use_graph=True, open_mask=True, initialization=False, use_advice=False, device='cpu'):
        super().__init__()
        self.base = SmecBase(lift_num, floor_num, use_graph=use_graph, device=device)
        assert use_graph
        # The advice option adds one more selectable output to the sampler.
        num_choices = lift_num + 1 if use_advice else lift_num
        self.dist = SmecSampler(self.base.a_output, num_choices, use_graph)

        self.open_mask = open_mask
        self.initialization = initialization

        if initialization:
            initialize_weights(self, "orthogonal")

    def _legal_mask(self, inputs_obs):
        """Return the legal-action mask from the observation, or None when masking is off."""
        if self.open_mask:
            return inputs_obs['legal_masks']
        return None

    def forward(self, inputs_obs):
        """Alias for ``act`` with its default arguments."""
        return self.act(inputs_obs)

    def reset(self):
        """Clear the ``hidden_cell`` attribute."""
        self.hidden_cell = None

    def act(self, inputs_obs, deterministic=False, train=False):
        """Choose an action for each entry of the observation batch.

        Args:
            inputs_obs: observation mapping consumed by ``SmecBase``; must
                contain ``'legal_masks'`` when ``open_mask`` is set.
            deterministic: take ``dist.mode()`` instead of ``dist.sample()``.
            train: unused.

        Returns:
            ``(v1, action, action_log_probs, rule)``.
        """
        value, _, features, rule, _ = self.base(inputs_obs)
        # NOTE(review): the actor features are replaced by a tensor of ones
        # (labelled "test pure random" in the original), so the sampler input
        # here does not depend on the observation, whereas evaluate_actions
        # feeds the real features -- confirm this is intended.
        features = torch.ones_like(features)
        dist = self.dist(features, legal_mask=self._legal_mask(inputs_obs))

        action = dist.mode() if deterministic else dist.sample()
        log_probs = dist.log_probs(action)

        return value, action, log_probs, rule

    def get_value(self, inputs):
        """Return the first value estimate (v1) of the base network."""
        first_value = self.base(inputs)[0]
        return first_value

    def evaluate_actions(self, inputs_obs, masks, action):
        """Score previously taken actions under the current parameters.

        Args:
            inputs_obs: observation mapping; must contain
                ``'valid_action_mask'`` (may be None) and, when ``open_mask``
                is set, ``'legal_masks'``.
            masks: unused.
            action: actions whose log-probabilities are requested.

        Returns:
            ``(v2, action_log_probs, dist_entropy, env_loss)`` where
            ``dist_entropy`` is the mean over all entries of the (optionally
            masked) entropy.
        """
        _, value, features, _, env_loss = self.base(inputs_obs)
        dist = self.dist(features, legal_mask=self._legal_mask(inputs_obs))
        log_probs = dist.log_probs(action)

        entropy = dist.entropy()

        # Zero the entropy of entries whose action was not actually taken.
        taken = inputs_obs['valid_action_mask']
        if taken is not None:
            entropy = entropy * taken

        return value, log_probs, entropy.mean(), env_loss
