# calc6
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from smac.env.multiagentenv import MultiAgentEnv

import atexit
from operator import attrgetter
from copy import deepcopy
import numpy as np
import enum
import math
from absl import logging
import random

# Numeric base of the task: the random state digits and the per-agent actions
# take values in [0, jinzhi). calc_add asserts that this stays 2.
jinzhi=2

def calc_add(a, b):
    """Return the digit-wise sum of ``a`` and ``b`` modulo ``jinzhi``.

    No carry is propagated between positions. Only the first ``len(a)``
    entries of ``b`` are read, and the result is a freshly allocated
    ``np.zeros`` array of length ``len(a)``.
    """
    # The carry-less sum is only meaningful here for the binary case.
    assert jinzhi == 2
    total = np.zeros((len(a)))
    for pos, digit in enumerate(a):
        total[pos] = (digit + b[pos]) % jinzhi
    return total

class CalcdbEnv(MultiAgentEnv):
    """Single-step cooperative digit task behind the MultiAgentEnv interface.

    Each episode samples two random digit rows of length ``n_states`` plus a
    one-hot "goal" row. Agent ``i`` observes column ``i`` of those three rows.
    With ``xor = calc_add(row_a, row_b)``, the target action of agent ``i`` is
    ``(xor[i] + xor[goal_id]) % 2``. The reward is the number of agents whose
    action equals their target, and the episode counts as won when all do.
    Every episode ends after one call to ``step``.
    """
    def __init__(
            self,
            n_agents=3,
            reward_win=10,
            obs_last_action=False,
            state_last_action=True,
            is_print=False,
            print_rew=False,
            print_steps=1000,
            seed=None
    ):
        """Create the environment and sample the first task.

        Args:
            n_agents: number of agents; also the length of each state row.
            reward_win: only used to normalise the average printed when
                ``print_rew`` is enabled.
            obs_last_action: stored on the instance; not read by this class.
            state_last_action: stored on the instance; not read by this class.
            is_print: print step number, state and actions on every step.
            print_rew: print the running average reward every
                ``print_steps`` steps.
            print_steps: period (in steps) of the ``print_rew`` report.
            seed: seed for NumPy's global RNG. When ``None`` a random seed
                in [0, 9999] is drawn instead.
        """
        # Debug / reporting options
        self.print_rew = print_rew
        self.is_print = is_print
        self.print_steps = print_steps

        # Honour the caller-supplied seed; previously it was stored but a
        # random seed was always the one actually applied.
        if seed is None:
            seed = random.randint(0, 9999)
        self._seed = seed
        np.random.seed(self._seed)

        self.n_agents = n_agents
        self.n_states = n_agents

        # Observations and state
        self.obs_last_action = obs_last_action
        self.state_last_action = state_last_action

        # Rewards args
        self.reward_win = reward_win

        # Actions: one action per possible digit value
        self.n_actions = jinzhi

        # Statistics
        self._episode_count = 0
        self._episode_steps = 0
        self._total_steps = 0
        self.battles_won = 0
        self.battles_game = 0

        self.p_step = 0
        self.rew_gather = []
        self.is_print_once = False

        self.last_action = np.zeros((self.n_agents, self.n_actions))

        self.episode_limit = 1

        # Sample the first task
        self._sample_task()

    def _sample_task(self):
        """Draw a fresh random state and derive the per-agent goal actions."""
        digits = np.random.randint(0, jinzhi, (self.n_states * 2))
        self.goal_id = np.random.randint(0, self.n_states)
        self.is_goal = np.zeros((self.n_states))
        self.is_goal[self.goal_id] = 1
        # State layout: [row_a | row_b | one-hot goal marker]
        self.state = np.concatenate([digits, self.is_goal])
        xor = calc_add(self.state[:self.n_states],
                       self.state[self.n_states:2 * self.n_states])
        self.goal = (xor + xor[self.goal_id]) % 2

    def step(self, actions):
        """Score ``actions`` against the goal.

        Returns:
            reward: number of agents whose action equals their goal entry.
            terminated: always ``True`` (episodes last a single step).
            info: dict with ``'battle_won'`` set when every agent matched.
        """
        self._total_steps += 1
        self._episode_steps += 1

        if self.is_print:
            print('t_steps: %d' % self._episode_steps)
            print(self.state)
            # Tensor-like actions expose .cpu(); anything else prints as-is.
            if hasattr(actions, 'cpu'):
                print(actions.cpu().numpy())
            else:
                print(actions)

        matches = (actions == self.goal)
        reward = matches.sum()
        won = bool(matches.all())
        self.battles_won += int(won)
        info = {'battle_won': won}

        # The task is scored and finished in the same step.
        terminated = True
        self._episode_count += 1
        self.battles_game += 1

        if self.print_rew:
            self.p_step += 1
            self.rew_gather.append(reward)
            if self.p_step % self.print_steps == 0:
                print('steps: %d, average rew: %.3lf' % (self.p_step,
                                                         float(np.mean(self.rew_gather)) / self.reward_win))
                self.is_print_once = True

        return reward, terminated, info

    def get_obs(self):
        """Returns all agent observations in a list."""
        return [self.get_obs_agent(i) for i in range(self.n_agents)]

    def get_obs_agent(self, agent_id):
        """Returns observation for agent_id.

        The state is viewed as 3 rows of ``n_states`` entries and the agent
        receives column ``agent_id`` of that view.
        """
        return self.state.reshape(3, self.n_states)[:, agent_id]

    def get_obs_size(self):
        """Returns the size of the observation."""
        return 3

    def get_state(self):
        """Returns the global state."""
        return self.state

    def get_state_size(self):
        """Returns the size of the global state."""
        return self.n_states * 3

    def get_avail_actions(self):
        """Returns the available actions of all agents in a list."""
        return [self.get_avail_agent_actions(i) for i in range(self.n_agents)]

    def get_avail_agent_actions(self, agent_id):
        """Returns the available actions for agent_id (all are available)."""
        return [1] * self.n_actions

    def get_total_actions(self):
        """Returns the total number of actions an agent could ever take."""
        return self.n_actions

    def reset(self):
        """Samples a new task and returns initial observations and state."""
        self._episode_steps = 0
        self._sample_task()
        return self.get_obs(), self.get_state()

    def render(self):
        """No-op: there is nothing to render."""
        pass

    def close(self):
        """No-op: no resources are held."""
        pass

    def seed(self):
        """Returns the seed that was applied to NumPy's RNG at construction."""
        return self._seed

    def save_replay(self):
        """Save a replay (no-op)."""
        pass

    def get_env_info(self):
        """Returns the shape/size description of this environment."""
        env_info = {"state_shape": self.get_state_size(),
                    "obs_shape": self.get_obs_size(),
                    "n_actions": self.get_total_actions(),
                    "n_agents": self.n_agents,
                    "episode_limit": self.episode_limit}
        return env_info

    def get_stats(self):
        """Returns win statistics; ``win_rate`` is 0.0 before any episode ends."""
        if self.battles_game:
            win_rate = self.battles_won / self.battles_game
        else:
            # Avoid dividing by zero when no episode has finished yet.
            win_rate = 0.0
        stats = {
            "battles_won": self.battles_won,
            "battles_game": self.battles_game,
            "win_rate": win_rate
        }
        return stats

    def clean(self):
        """Resets the reward-reporting counters used by ``print_rew``."""
        self.p_step = 0
        self.rew_gather = []
        self.is_print_once = False
