# radar1
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from smac.env.multiagentenv import MultiAgentEnv

import atexit
from operator import attrgetter
from copy import deepcopy
import numpy as np
import enum
import math
from absl import logging
import random

def cal_dist(state, agent_id):
    """Return the Manhattan distance from ``agent_id``'s position to the mover.

    ``state`` stores one (x, y) pair per entity: agent ``agent_id`` at indices
    ``2*agent_id`` / ``2*agent_id + 1`` and the reference (mover) position in
    the last two entries.
    """
    agent_x = state[2 * agent_id]
    agent_y = state[2 * agent_id + 1]
    mover_x, mover_y = state[-2], state[-1]
    dist_x = abs(mover_x - agent_x)
    dist_y = abs(mover_y - agent_y)
    return dist_x + dist_y

class Radar2Env(MultiAgentEnv):
    """Grid-world "radar" task exposed through SMAC's ``MultiAgentEnv`` API.

    No StarCraft II process is involved. The global state is a flat array of
    (x, y) pairs: one per radar in ``state[:-2]`` and the mover's position in
    ``state[-2:]``. The mover is the last agent and is the only one whose
    action changes the state. The remaining ``n_agents - 1`` agents are
    radars that stay where they were placed at reset. An episode ends when
    the mover lands on a radar or after ``episode_limit`` steps.
    """

    def __init__(
            self,
            n_agents=3,
            reward_win=10,
            obs_last_action=False,
            state_last_action=True,
            is_print=False,
            print_rew=False,
            print_steps=1000,
            seed=None,
            n_actions=4,
    ):
        """Create the environment.

        Args:
            n_agents: total number of agents; must be 9 (8 radars + 1 mover).
                The default of 3 is rejected by the check below.
            reward_win: reward scale; reaching a rewarding radar after ``t``
                steps yields ``reward_win / t``.
            obs_last_action, state_last_action: stored only; not used here.
            is_print: print state/actions on every step (debugging).
            print_rew: accumulate terminal rewards and print their mean
                every ``print_steps`` steps.
            print_steps: printing interval used with ``print_rew``.
            seed: seed for NumPy's global RNG; a random seed in [0, 9999]
                is drawn when None.
            n_actions: size of the action space. The default 4 exposes
                hold/left/right/up; pass 5 to also enable "down" (action 4).

        Raises:
            ValueError: if ``n_agents`` does not give exactly 8 radars.
        """
        # Printing / logging options
        self.print_rew = print_rew
        self.is_print = is_print
        self.print_steps = print_steps

        # Previously the ``seed`` argument was ignored: a random seed was
        # always used for NumPy and ``_seed`` was then overwritten with the
        # raw argument. Honor it when given and record the seed actually used.
        self._seed = seed if seed is not None else random.randint(0, 9999)
        np.random.seed(self._seed)

        self.n_agents = n_agents
        # One (x, y) pair per agent: 8 radars plus the mover.
        self.n_states = 2 * n_agents

        # Observations and state (stored for interface compatibility)
        self.obs_last_action = obs_last_action
        self.state_last_action = state_last_action

        # Rewards args
        self.reward_win = reward_win

        # Actions: 0 hold, 1 left, 2 right, 3 up, 4 down (needs n_actions=5)
        self.n_actions = n_actions

        # config
        self.scale = 5  # mover coordinates are clipped to [-scale, scale]
        self.n_radars = self.n_agents - 1
        # Placement draws 8 distinct directions and 8 distinct distances
        # (2..9), and target selection needs distances 2 and 3 to exist.
        if self.n_radars != 8:
            raise ValueError(
                "Radar2Env requires n_agents == 9 (8 radars + 1 mover), "
                "got n_agents=%d" % self.n_agents)
        self.dir_list = np.array([[1, 0], [1, 1], [0, 1], [-1, 1],
                                  [-1, 0], [-1, -1], [0, -1], [1, -1]])

        # Statistics
        self._episode_count = 0
        self._episode_steps = 0
        self._total_steps = 0
        self.battles_won = 0
        self.battles_game = 0

        self.p_step = 0
        self.rew_gather = []
        self.is_print_once = False

        self.last_action = np.zeros((self.n_agents, self.n_actions))

        self.episode_limit = 2 * (self.scale * 2 + 1)

        # initialize agents
        self._reset_positions()

    def _reset_positions(self):
        """Randomly place the radars, put the mover at the origin, pick the target.

        Each radar gets a distinct Manhattan distance in 2..9 and a distinct
        direction from ``dir_list``. Also sets ``best`` (index of the target
        radar) and ``min_dist`` (its Manhattan distance from the origin).
        """
        self.distance = np.random.choice(8, self.n_radars, replace=False) + 2  # 2-9
        direction_idx = np.random.choice(8, self.n_radars, replace=False)
        self.direction = self.dir_list[direction_idx]
        # Zero-initialised, so the mover (last two entries) starts at (0, 0).
        self.state = np.zeros((self.n_radars * 2 + 2))
        for i in range(self.n_radars):
            dx, dy = self.direction[i]
            dist = self.distance[i]
            if abs(dx) + abs(dy) == 1:
                # Axis-aligned direction: the whole distance lies on one axis.
                self.state[2 * i:2 * i + 2] = dist * self.direction[i]
            elif dist % 2 == 0:
                # Diagonal with even distance: split evenly between x and y.
                self.state[2 * i:2 * i + 2] = dist * self.direction[i] // 2
            else:
                # Diagonal with odd distance: one random axis gets the extra unit.
                the_more = int(random.random() * 2)
                self.state[2 * i] = dx * (dist // 2 + the_more)
                self.state[2 * i + 1] = dy * (dist // 2 + (1 - the_more))
        # Reaching radar 0 gives no reward (see ``step``), so when the
        # distance-2 radar is radar 0 the target becomes the distance-3 one.
        self.best = np.where(self.distance == 2)[0][0]
        self.min_dist = 2
        if self.best == 0:
            self.best = np.where(self.distance == 3)[0][0]
            self.min_dist = 3

    def step(self, actions):
        """Apply the mover's action and return ``(reward, terminated, info)``.

        Only ``actions[-1]`` (the mover's action) is used:
        0 hold, 1 left (x-1), 2 right (x+1), 3 up (y+1), 4 down (y-1).
        Coordinates are clipped to ``[-scale, scale]``.
        """
        self._total_steps += 1
        self._episode_steps += 1
        info = {}

        if self.is_print:
            print('t_steps: %d' % self._episode_steps)
            # Was ``self.state_n``, which is never defined (AttributeError).
            print(self.state)
            # Presumably a torch tensor from the runner; also accept arrays/lists.
            print(actions.cpu().numpy() if hasattr(actions, 'cpu') else actions)

        action = actions[-1]
        x, y = self.state[-2], self.state[-1]
        if action == 1:  # left
            x = max(-self.scale, x - 1)
        elif action == 2:  # right
            x = min(self.scale, x + 1)
        elif action == 3:  # up -- previously computed from x by mistake
            y = min(self.scale, y + 1)
        elif action == 4:  # down -- previously computed from x by mistake
            y = max(-self.scale, y - 1)
        # action 0 (hold) leaves the mover in place.
        self.state[-2], self.state[-1] = x, y

        # First radar (if any) whose position coincides with the mover.
        radar_pos = self.state[:-2].reshape(self.n_radars, 2)
        hits = np.flatnonzero((radar_pos == self.state[-2:]).all(axis=1))
        arrival = hits.size > 0
        arrive_radar = int(hits[0]) if arrival else None

        reward = 0
        terminated = False
        info['battle_won'] = False

        if arrival:
            reward = self.reward_win / self._episode_steps if arrive_radar != 0 else 0
            terminated = True
            # Counted as a win only when the target radar is reached in
            # exactly ``min_dist`` steps (its Manhattan distance from origin).
            self.battles_won += int(self.best == arrive_radar
                                    and self._episode_steps == self.min_dist)
            # NOTE: set for any arrival, not only the stricter win above.
            info['battle_won'] = True

        if self._episode_steps >= self.episode_limit:
            terminated = True

        if terminated:
            self._episode_count += 1
            self.battles_game += 1
        if self.print_rew:
            self.p_step += 1
            if terminated:
                self.rew_gather.append(reward)
            if self.p_step % self.print_steps == 0:
                # Avoid np.mean's empty-slice warning; prints "nan" either way.
                mean_rew = (float(np.mean(self.rew_gather))
                            if self.rew_gather else float('nan'))
                print('steps: %d, average rew: %.3lf' % (self.p_step,
                                                         mean_rew / self.reward_win))
                self.is_print_once = True

        return reward, terminated, info

    def get_obs(self):
        """Returns all agent observations in a list."""
        return [self.get_obs_agent(i) for i in range(self.n_agents)]

    def get_obs_agent(self, agent_id):
        """Return ``[manhattan distance, dx, dy]`` between agent ``agent_id`` and the mover.

        ``dx``/``dy`` are the agent's coordinates minus the mover's. For the
        mover itself (the last agent) these index the mover's own position,
        so the observation is all zeros.
        """
        return np.array([cal_dist(self.state, agent_id),
                         self.state[2 * agent_id] - self.state[-2],
                         self.state[2 * agent_id + 1] - self.state[-1]])

    def get_obs_size(self):
        """Returns the size of the observation."""
        return 3

    def get_state(self):
        """Returns the global state (radar positions followed by the mover's)."""
        return self.state

    def get_state_size(self):
        """Returns the size of the global state (``2 * n_agents``)."""
        return self.n_states

    def get_avail_actions(self):
        """Returns the available actions of all agents in a list."""
        return [self.get_avail_agent_actions(i) for i in range(self.n_agents)]

    def get_avail_agent_actions(self, agent_id):
        """Returns the available actions for agent_id (all actions, always)."""
        return [1] * self.n_actions

    def get_total_actions(self):
        """Returns the total number of actions an agent could ever take."""
        return self.n_actions

    def reset(self):
        """Start a new episode and return ``(observations, state)``."""
        self._episode_steps = 0
        self._reset_positions()
        return self.get_obs(), self.get_state()

    def render(self):
        pass

    def close(self):
        pass

    def seed(self):
        """Return the seed used for NumPy's global RNG."""
        return self._seed

    def save_replay(self):
        """Save a replay."""
        pass

    def get_env_info(self):
        env_info = {"state_shape": self.get_state_size(),
                    "obs_shape": self.get_obs_size(),
                    "n_actions": self.get_total_actions(),
                    "n_agents": self.n_agents,
                    "episode_limit": self.episode_limit}
        return env_info

    def get_stats(self):
        """Return win statistics; ``win_rate`` is 0.0 before any episode ends."""
        stats = {
            "battles_won": self.battles_won,
            "battles_game": self.battles_game,
            "win_rate": (self.battles_won / self.battles_game
                         if self.battles_game else 0.0),
        }
        return stats

    def clean(self):
        """Reset the reward-printing accumulators."""
        self.p_step = 0
        self.rew_gather = []
        self.is_print_once = False
