import numpy as np, functools

try:
    import robosumo
except:
    pass

class SumoEnv:
    """RoboSumo Ant-vs-Ant environment returning float32 observations.

    Two-agent mode (``single_i is None``): ``step(a, b)`` returns both
    agents' observations, a single scalar reward ``(r[0] - r[1]) / 4000``
    and agent 0's done flag, plus the full info tuple.

    Single-agent mode (``single_i`` given): ``step(a, b)`` still drives
    both agents, but returns only agent ``single_i``'s transition. Its
    observation is duplicated into a pair and its reward is divided by 2000.
    """
    # Class-level flags; presumably read by the training code
    # (simultaneous moves, continuous actions) -- confirm against callers.
    sequential = False
    continuous = True

    def __init__(self, single_i=None, reward_shaping=True):
        """Create the underlying gym environment.

        Args:
            single_i: index of the agent to expose in single-agent mode,
                or None for two-agent mode.
            reward_shaping: value assigned to the env's ``reward_shaping``
                attribute.
        """
        # `gym` is not imported at module level, so referencing it here used
        # to raise NameError. Import locally so that this module (and
        # ParallelEnvWrapper) remains importable when gym is not installed.
        import gym

        # TODO: gym warnings are turned off.
        gym.logger.set_level(50)
        # The env id is presumably registered by the module-level
        # `import robosumo` -- if gym.make reports it unknown, check that import.
        self.env = gym.make('RoboSumo-Ant-vs-Ant-v0')
        self.single_i = single_i
        self.single_agent = single_i is not None
        # NOTE(review): if gym.make returns a wrapper, this assignment may land
        # on the wrapper instead of the underlying env -- confirm it takes effect.
        self.env.reward_shaping = reward_shaping
        for agent in self.env.agents:
            agent._adjust_z = -0.5
        # observation_space: (Box(120,), Box(120,))
        # action_space: (Box(8,), Box(8,))
        self.state_dim = 120
        self.nact = 8

    def reset(self):
        """Reset the env and return both agents' observations as float32."""
        obs = self.env.reset()
        return (obs[0].astype(np.float32), obs[1].astype(np.float32))

    def step(self, a, b):
        """Apply action ``a`` to agent 0 and ``b`` to agent 1.

        Returns:
            ``(obs_pair, reward, done, info)``. In single-agent mode
            ``obs_pair`` is ``(obs, obs)`` for agent ``single_i`` and
            ``info`` is that agent's info. In two-agent mode ``obs_pair``
            holds both agents' observations and ``info`` is the full tuple.
        """
        o, r, t, i = self.env.step((a, b))
        if self.single_agent:
            k = self.single_i
            obs = o[k].astype(np.float32)
            return (obs, obs), r[k] / 2000, t[k], i[k]

        obs = (o[0].astype(np.float32), o[1].astype(np.float32))
        # The original author notes t[0] == t[1] always, so agent 0's flag
        # serves as the episode's done flag.
        return obs, (r[0] - r[1]) / 4000, t[0], i


# TODO: a decorator is better?
class ParallelEnvWrapper:
    """Step ``nproc`` independent copies of an environment in parallel.

    Some environments cannot be pickled, so each env is constructed inside
    its own worker process as ``env_class(*args, **kwargs)``. Only actions
    and results cross process boundaries. Worker ``i`` always owns env ``i``,
    so ``reset``/``step`` results come back in env-index order.

    Call ``close()`` when finished to stop the worker processes.
    """
    # Per-process state: inside each worker process these hold that worker's
    # env and its index. They remain None in the parent process.
    _env, _i = None, None

    @staticmethod
    def _make(i, *args, env_class=None, **kwargs):
        """Build this process's env and remember its index ``i``."""
        ParallelEnvWrapper._env = env_class(*args, **kwargs)
        ParallelEnvWrapper._i = i

    @staticmethod
    def _reset(_):
        """Reset this process's env; return ``(index, observation)``."""
        return ParallelEnvWrapper._i, ParallelEnvWrapper._env.reset()

    @staticmethod
    def _step(_, a=None):
        """Step this process's env with its slice of each action sequence in ``a``.

        Returns ``(index, None)`` without stepping when ``a[0][index]`` is
        None, otherwise ``(index, env.step(*[aa[index] for aa in a]))``.
        """
        i = ParallelEnvWrapper._i
        if a[0][i] is None:
            return i, None
        return i, ParallelEnvWrapper._env.step(*[aa[i] for aa in a])

    @staticmethod
    def _worker(conn, i, env_class, args, kwargs):
        """Child-process main loop: own env ``i`` and serve commands from ``conn``.

        Every reply is ``(ok, payload)``: ``(True, result)`` on success, or
        ``(False, exception)`` when the command raised.
        """
        try:
            ParallelEnvWrapper._make(i, *args, env_class=env_class, **kwargs)
        except Exception as exc:
            conn.send((False, exc))
            conn.close()
            return
        conn.send((True, None))  # construction acknowledgement
        try:
            while True:
                cmd, payload = conn.recv()
                if cmd == 'close':
                    break
                try:
                    if cmd == 'reset':
                        result = ParallelEnvWrapper._reset(None)
                    else:  # 'step'
                        result = ParallelEnvWrapper._step(None, a=payload)
                except Exception as exc:
                    conn.send((False, exc))
                else:
                    conn.send((True, result))
        except (EOFError, KeyboardInterrupt):
            pass  # parent closed its end, or the user interrupted: just exit
        finally:
            conn.close()

    def __init__(self, env_class, nproc, *args, **kwargs):
        import multiprocessing as mp
        self.nproc = nproc
        self._closed = False
        self._remotes = []
        self._procs = []
        # One dedicated process per env. A Pool was used previously, but
        # Pool.map gives no task-to-worker affinity: one worker could build or
        # serve several env indices while another got none. The Pool version
        # also passed positional `args` ahead of the env index (via partial),
        # so `_make` received args[0] as its index; calling `_make` directly
        # with `i` first fixes that.
        try:
            for i in range(nproc):
                remote, child_end = mp.Pipe()
                proc = mp.Process(target=ParallelEnvWrapper._worker,
                                  args=(child_end, i, env_class, args, kwargs),
                                  daemon=True)
                proc.start()
                child_end.close()  # only the child uses this end
                self._remotes.append(remote)
                self._procs.append(proc)
            # Block until every env is built so construction errors surface here.
            self._gather()
        except BaseException:
            self.close()
            raise

    def _gather(self):
        """Receive one reply per worker in env order; re-raise the first error."""
        # Read every reply before raising so no pipe keeps an unread message
        # that a later call would mistake for its own result.
        replies = [remote.recv() for remote in self._remotes]
        for ok, payload in replies:
            if not ok:
                raise payload
        return [payload for _, payload in replies]

    def reset(self):
        """Reset every env; return the list of observations in env order."""
        for remote in self._remotes:
            remote.send(('reset', None))
        return [obs for _, obs in self._gather()]

    def step(self, *a):
        """Step every env in parallel.

        ``a`` holds one sequence per positional argument of the env's
        ``step``; env ``i`` is stepped with ``[aa[i] for aa in a]``.
        ``a[0][i]`` may be None to leave env ``i`` untouched (e.g. after its
        episode has ended); its entry in the returned list is then None.
        """
        for remote in self._remotes:
            remote.send(('step', a))
        return [res for _, res in self._gather()]

    def close(self):
        """Stop all workers and wait for them to exit. Safe to call twice."""
        if self._closed:
            return
        self._closed = True
        for remote in self._remotes:
            try:
                remote.send(('close', None))
            except OSError:
                pass  # worker already exited
            remote.close()
        for proc in self._procs:
            proc.join()
