import torch.nn as nn
import torch.nn.functional as F

import numpy as np
import torch as th


class ObsRewardEncoder(nn.Module):
    """Predict each agent's next observation and reward.

    An agent's observation together with the one-hot actions of the *other*
    agents is encoded into a state latent; the agent's own one-hot action is
    encoded into an action latent.  The concatenation of both latents is
    decoded into a next-observation prediction and a scalar reward prediction.
    """

    def __init__(self, args):
        super(ObsRewardEncoder, self).__init__()
        self.args = args
        self.n_agents = args.n_agents
        self.n_actions = args.n_actions
        self.mixing_embed_dim = args.mixing_embed_dim
        self.action_latent_dim = args.action_latent_dim

        self.state_dim = int(np.prod(args.state_shape))
        self.obs_dim = int(np.prod(args.obs_shape))

        # Input: own observation + one-hot actions of the other (n_agents - 1) agents.
        self.obs_encoder_avg = nn.Sequential(
            nn.Linear(self.obs_dim + self.n_actions * (self.n_agents - 1), args.state_latent_dim * 2),
            nn.ReLU(),
            nn.Linear(args.state_latent_dim * 2, args.state_latent_dim))
        self.obs_decoder_avg = nn.Sequential(
            nn.Linear(args.state_latent_dim + args.action_latent_dim, args.state_latent_dim),
            nn.ReLU(),
            nn.Linear(args.state_latent_dim, self.obs_dim))

        self.action_encoder = nn.Sequential(nn.Linear(self.n_actions, args.state_latent_dim * 2),
                                            nn.ReLU(),
                                            nn.Linear(args.state_latent_dim * 2, args.action_latent_dim))

        self.reward_decoder_avg = nn.Sequential(
            nn.Linear(args.state_latent_dim + args.action_latent_dim, args.state_latent_dim),
            nn.ReLU(),
            nn.Linear(args.state_latent_dim, 1))

    def predict(self, obs, actions):
        """Predict next observations and rewards (used in learners for training).

        :param obs: observations; flattened to ``[-1, obs_dim]`` so its rows must
            line up with the ``bs * n_agents`` rows built from ``actions``.
        :param actions: one-hot actions ``[bs, n_agents, n_actions]``.
        :return: ``(next_obs_pred, reward_pred)`` shaped
            ``[bs, n_agents, obs_dim]`` and ``[bs, n_agents, 1]``.
        """
        other_actions = self.other_actions(actions)
        obs_reshaped = obs.contiguous().view(-1, self.obs_dim)
        inputs = th.cat([obs_reshaped, other_actions], dim=-1)

        obs_latent_avg = self.obs_encoder_avg(inputs)
        actions = actions.contiguous().view(-1, self.n_actions)
        action_latent_avg = self.action_encoder(actions)

        pred_avg_input = th.cat([obs_latent_avg, action_latent_avg], dim=-1)
        no_pred_avg = self.obs_decoder_avg(pred_avg_input)
        r_pred_avg = self.reward_decoder_avg(pred_avg_input)

        return no_pred_avg.view(-1, self.n_agents, self.obs_dim), r_pred_avg.view(-1, self.n_agents, 1)

    def forward(self):
        """Encode every one-hot action.

        :return: action latents ``[n_actions, action_latent_dim]`` on ``args.device``.
        """
        actions = th.eye(self.n_actions, device=self.args.device)
        return self.action_encoder(actions)

    def other_actions(self, actions):
        """For every agent, concatenate the actions of all *other* agents.

        :param actions: ``[bs, n_agents, n_actions]``.
        :return: ``[bs * n_agents, (n_agents - 1) * n_actions]``; row
            ``b * n_agents + i`` holds the actions of agents ``j != i`` in
            ascending ``j``.  With a single agent the result has zero columns
            (the original loop crashed on ``th.cat([])`` in that case).
        """
        assert actions.shape[1] == self.n_agents, (
            f"expected {self.n_agents} agents, got actions of shape {tuple(actions.shape)}")

        n = self.n_agents
        # others[i] = indices of every agent except i -> [n_agents, n_agents - 1]
        others = th.tensor([[j for j in range(n) if j != i] for i in range(n)],
                           dtype=th.long, device=actions.device)
        # Advanced indexing -> [bs, n_agents, n_agents - 1, n_actions]
        gathered = actions[:, others]
        # Explicit row count: reshape(-1, 0) is ambiguous when n_agents == 1.
        return gathered.reshape(actions.shape[0] * n, (n - 1) * self.n_actions)


class HyperActionEncoderInput(nn.Module):
    """First layer of the hyper action encoder.

    Action indices ``0 .. N_MOVE_ACTIONS - 1`` are embedded from their one-hot
    vector with ``normal_action_embedding``; action ``N_MOVE_ACTIONS + k`` is
    embedded from the observed features of enemy ``k`` with
    ``attack_action_embedding``.
    """

    # Leading non-attack actions: (no_op, stop, up, down, right, left).
    N_MOVE_ACTIONS = 6

    def __init__(self, args):
        super(HyperActionEncoderInput, self).__init__()
        self.n_actions = args.n_actions
        self.latent_dim = args.state_latent_dim * 2

        # obs_component = (move, enemy, ally, own); the enemy entry is [n_enemies, feat_dim].
        _, enemy_feats_dim, _, _ = args.obs_component
        self.enemy_feats_dim = enemy_feats_dim[-1]

        self.normal_action_embedding = nn.Linear(self.n_actions, self.latent_dim)
        self.attack_action_embedding = nn.Linear(self.enemy_feats_dim, self.latent_dim)

    def forward(self, tuple_inputs):
        """Embed either the selected action or every action.

        :param tuple_inputs: ``(obs_enemies, onehot_action, action, for_selected_action)``

            * ``obs_enemies``: ``[bs * n_agents, n_enemies, enemy_feats_dim]``.
            * ``onehot_action``: ``[bs * n_agents, n_actions]`` when
              ``for_selected_action``; otherwise the one-hot rows of the move
              actions, ``[N_MOVE_ACTIONS, n_actions]``.
            * ``action``: selected action indices ``[bs * n_agents, 1]`` (or
              ``[bs * n_agents]``); unused (may be ``None``) otherwise.
            * ``for_selected_action``: selects the mode.
        :return: ``[bs * n_agents, hidden]`` embeddings of the selected actions,
            or ``[bs * n_agents, n_move + n_enemies, hidden]`` embeddings of
            every action (move actions first, then one per enemy).
        """
        obs_enemies, onehot_action, action, for_selected_action = tuple_inputs
        n_move = self.N_MOVE_ACTIONS
        # [bs * n_agents, n_enemies, hidden]
        attack_embedding = self.attack_action_embedding(obs_enemies)

        if not for_selected_action:
            # [n_move, hidden] -> [bs * n_agents, n_move, hidden]
            move_embedding = self.normal_action_embedding(onehot_action).unsqueeze(0).expand(
                attack_embedding.shape[0], -1, -1)
            return th.cat([move_embedding, attack_embedding], dim=1)  # [bs * n_agents, n_actions, hidden]

        # Accept both [B] and [B, 1] index tensors.
        action = action.reshape(-1, 1)
        move_embedding = self.normal_action_embedding(onehot_action)  # [bs * n_agents, hidden]
        # Enemy index of an attack action; clamped to 0 for move actions, whose
        # gathered value is discarded by the where() below.
        attack_index = th.clamp_min(action - n_move, min=0).unsqueeze(-1).expand(-1, -1, self.latent_dim)
        selected_attack_embedding = th.gather(attack_embedding, dim=1, index=attack_index).squeeze(1)
        return th.where(action < n_move, move_embedding, selected_attack_embedding)


class HyperObsRewardEncoder(nn.Module):
    """Predict each agent's next observation and reward with a hyper action encoder.

    Same structure as a plain obs/reward encoder, except that actions are
    embedded by ``HyperActionEncoderInput``: move actions from their one-hot
    vector and attack actions from the observed features of the targeted enemy.
    """

    # Leading non-attack action indices; must agree with HyperActionEncoderInput.
    N_MOVE_ACTIONS = 6

    def __init__(self, args):
        super(HyperObsRewardEncoder, self).__init__()
        self.args = args
        self.n_agents = args.n_agents
        self.n_enemies = args.n_enemies
        self.n_actions = args.n_actions
        self.mixing_embed_dim = args.mixing_embed_dim
        self.action_latent_dim = args.action_latent_dim
        move_feats_dim, enemy_feats_dim, ally_feats_dim, own_feats_dim = self.args.obs_component
        self.enemy_feats_dim = enemy_feats_dim[-1]  # enemy_feats_dim is [n_enemies, feat_dim]
        # Widths of the (move, enemy, ally, own) sections of a flat observation,
        # as plain Python ints for th.split (np.prod returns numpy integers).
        self.obs_delimiter = (int(move_feats_dim), int(np.prod(enemy_feats_dim)),
                              int(np.prod(ally_feats_dim)), int(own_feats_dim))

        self.obs_dim = int(np.prod(args.obs_shape))

        # Input: own observation + one-hot actions of the other (n_agents - 1) agents.
        self.obs_encoder_avg = nn.Sequential(
            nn.Linear(self.obs_dim + self.n_actions * (self.n_agents - 1), args.state_latent_dim * 2),
            nn.ReLU(),
            nn.Linear(args.state_latent_dim * 2, args.state_latent_dim))
        self.obs_decoder_avg = nn.Sequential(
            nn.Linear(args.state_latent_dim + args.action_latent_dim, args.state_latent_dim),
            nn.ReLU(),
            nn.Linear(args.state_latent_dim, self.obs_dim))

        # Input is the list [enemy_features, onehot_actions, action_indices, for_selected_action]
        # consumed by HyperActionEncoderInput.
        self.action_encoder = nn.Sequential(HyperActionEncoderInput(args),
                                            nn.ReLU(),
                                            nn.Linear(args.state_latent_dim * 2, args.action_latent_dim))

        self.reward_decoder_avg = nn.Sequential(
            nn.Linear(args.state_latent_dim + args.action_latent_dim, args.state_latent_dim),
            nn.ReLU(),
            nn.Linear(args.state_latent_dim, 1))

    def predict(self, obs, onehot_actions, actions):
        """Predict next observations and rewards (used in learners for training).

        :param obs: observations whose last dim is ``obs_dim``; rows must line up
            with the ``bs * n_agents`` rows built from the actions.
        :param onehot_actions: one-hot actions ``[bs, n_agents, n_actions]``.
        :param actions: selected action indices, flattened to ``[bs * n_agents, 1]``.
        :return: ``(next_obs_pred, reward_pred)`` shaped
            ``[bs, n_agents, obs_dim]`` and ``[bs, n_agents, 1]``.
        """
        other_actions = self.other_actions(onehot_actions)
        obs_reshaped = obs.contiguous().view(-1, self.obs_dim)
        inputs = th.cat([obs_reshaped, other_actions], dim=-1)

        obs_latent_avg = self.obs_encoder_avg(inputs)
        onehot_actions = onehot_actions.contiguous().view(-1, self.n_actions)

        action_latent_avg = self.action_encoder(
            [self._get_enemy_features_from_obs(obs), onehot_actions, actions.contiguous().view(-1, 1), True])

        pred_avg_input = th.cat([obs_latent_avg, action_latent_avg], dim=-1)
        no_pred_avg = self.obs_decoder_avg(pred_avg_input)
        r_pred_avg = self.reward_decoder_avg(pred_avg_input)

        return no_pred_avg.view(-1, self.n_agents, self.obs_dim), r_pred_avg.view(-1, self.n_agents, 1)

    def forward(self, obs):
        """Embed every action for every agent.

        :param obs: observations ``[bs, n_agents, obs_dim]`` (attack embeddings
            depend on the observed enemy features).
        :return: ``[bs, n_agents, n_actions, action_latent_dim]``; requires
            ``N_MOVE_ACTIONS + n_enemies == n_actions``.
        """
        onehot_move_actions = th.eye(self.n_actions, device=obs.device)[:self.N_MOVE_ACTIONS]
        actions_latent_avg = self.action_encoder(
            [self._get_enemy_features_from_obs(obs), onehot_move_actions, None, False])
        return actions_latent_avg.view(obs.shape[0], self.n_agents, self.n_actions, self.action_latent_dim)

    def other_actions(self, actions):
        """For every agent, concatenate the actions of all *other* agents.

        :param actions: ``[bs, n_agents, n_actions]``.
        :return: ``[bs * n_agents, (n_agents - 1) * n_actions]``; row
            ``b * n_agents + i`` holds the actions of agents ``j != i`` in
            ascending ``j``.  With a single agent the result has zero columns.
        """
        assert actions.shape[1] == self.n_agents, (
            f"expected {self.n_agents} agents, got actions of shape {tuple(actions.shape)}")

        n = self.n_agents
        # others[i] = indices of every agent except i -> [n_agents, n_agents - 1]
        others = th.tensor([[j for j in range(n) if j != i] for i in range(n)],
                           dtype=th.long, device=actions.device)
        # Advanced indexing -> [bs, n_agents, n_agents - 1, n_actions]
        gathered = actions[:, others]
        # Explicit row count: reshape(-1, 0) is ambiguous when n_agents == 1.
        return gathered.reshape(actions.shape[0] * n, (n - 1) * self.n_actions)

    def _get_enemy_features_from_obs(self, obs):
        """Slice the enemy section out of ``obs``.

        :return: ``[bs * n_agents, n_enemies, enemy_feats_dim]`` (all leading
            dims of ``obs`` are flattened into the first dim).
        """
        _, enemy_feats_t, _, _ = th.split(obs, self.obs_delimiter, dim=-1)
        return enemy_feats_t.reshape(-1, self.n_enemies, self.enemy_feats_dim)
