from smac.env.multiagentenv import MultiAgentEnv
from smac.env.starcraft2.maps import get_map_params

import torch as th
import torch.nn as nn
import torch.nn.functional as F
import enum
import numpy as np

# Fixed task-layout constants consumed by NavigationDecomposer below.
N_AGENTS: int = 2  # number of agents; copied into NavigationDecomposer.n_agents
FIELD_SIZE = [9, 9]  # copied into NavigationDecomposer.field_size; presumably grid height/width -- TODO confirm
SIGHT: int = 1  # not used in this part of the file; presumably an agent sight radius -- TODO confirm
EPISODE_LIMIT: int = 50  # fallback episode length used by NavigationDecomposer.__init__
REACH_RANGE: int = 0  # not used in this part of the file; presumably the landmark "reached" distance -- TODO confirm


class NavigationDecomposer:
    """Split flat navigation-task state/observation vectors into per-entity chunks.

    The layout is fixed: ``N_AGENTS`` agents (the "allies") and a single
    landmark (handled as the "enemy" entity), each described by 2 features.
    """

    def __init__(self, args):
        """Read the task configuration and precompute feature dimensions.

        Args:
            args: config object whose ``env_args`` attribute is a mapping. It
                must contain ``"task_id"`` and may contain ``"episode_limit"``
                (falls back to ``EPISODE_LIMIT`` when absent).
        """
        self.args = args
        env_args = args.env_args
        # Obtain task id from the task config.
        self.task_id = env_args["task_id"]
        # Fixed task layout.
        self.n_agents = N_AGENTS
        self.field_size = FIELD_SIZE
        self.n_landmarks = 1
        self.n_actions = 5

        # env_args is a mapping (subscripted for "task_id" above), so the
        # override must be looked up as a key: getattr() on a dict always
        # returned the default and silently ignored a configured limit.
        self.episode_limit = env_args.get("episode_limit", EPISODE_LIMIT)

        (self.own_feats, self.ally_feats, self.enemy_feats,
         self.obs_nf_en, self.obs_nf_al) = self.get_obs_size()
        self.own_obs_dim = self.own_feats
        self.obs_dim = self.own_obs_dim + self.enemy_feats + self.ally_feats

        (self.enemy_state_dim, self.ally_state_dim,
         self.state_nf_en, self.state_nf_al) = self.get_state_size()
        self.state_dim = self.enemy_state_dim + self.ally_state_dim

    def get_state_size(self):
        """Return ``(enemy_state_dim, ally_state_dim, nf_en, nf_al)`` for the global state.

        Every agent and every landmark contributes 2 features (presumably a
        2-D position -- TODO confirm against the environment).
        """
        nf_al = nf_en = 2
        enemy_state = self.n_landmarks * nf_en
        ally_state = self.n_agents * nf_al
        return enemy_state, ally_state, nf_en, nf_al

    def get_obs_size(self):
        """Return ``(own_feats, ally_feats, enemy_feats, nf_en, nf_al)`` for one agent's observation.

        An observation holds the agent's own 2 features, 2 features for each of
        the other ``n_agents - 1`` agents, and 2 features per landmark.
        """
        nf_al = nf_en = 2
        own_feats = 2
        enemy_feats = self.n_landmarks * nf_en
        ally_feats = (self.n_agents - 1) * nf_al
        return own_feats, ally_feats, enemy_feats, nf_en, nf_al

    @staticmethod
    def _split_last(tensor, start, width, count):
        """Return ``count`` consecutive ``width``-wide slices of the last axis, beginning at ``start``."""
        return [tensor[..., start + i * width:start + (i + 1) * width] for i in range(count)]

    def decompose_state(self, state_input):
        """Split a global state into per-agent and per-landmark feature chunks.

        The last axis is laid out as ``[ally_0, ..., ally_{n-1}, landmark_0, ...]``.
        Callers were documented as passing ``[batch_size, seq_len, state_dim]``;
        slicing uses ``...`` so any number of leading dimensions works.

        Returns:
            ``(ally_states, enemy_states)``: lists with one slice per agent and
            one slice per landmark.
        """
        ally_states = self._split_last(state_input, 0, self.state_nf_al, self.n_agents)
        base = self.n_agents * self.state_nf_al
        enemy_states = self._split_last(state_input, base, self.state_nf_en, self.n_landmarks)
        return ally_states, enemy_states

    def decompose_obs(self, obs_input):
        """Split an observation into own, landmark and other-agent feature chunks.

        The last axis is laid out as ``[own, ally_0, ..., ally_{n-2}, landmark_0, ...]``.
        Slicing uses ``...`` so any number of leading dimensions works.

        Returns:
            ``(own_feats, enemy_feats, ally_feats)``: ``own_feats`` is a single
            slice; the other two are lists with one slice per landmark / per
            other agent.
        """
        own_feats = obs_input[..., :self.own_feats]
        base = self.own_feats
        ally_feats = self._split_last(obs_input, base, self.obs_nf_al, self.n_agents - 1)
        base += self.obs_nf_al * (self.n_agents - 1)
        enemy_feats = self._split_last(obs_input, base, self.obs_nf_en, self.n_landmarks)
        # NOTE: returned as (own, enemy, ally) although allies precede enemies
        # in the vector; this order is the existing public contract.
        return own_feats, enemy_feats, ally_feats