import torch
import numpy as np
from tqdm import tqdm
# from envs import ParallelEnvWrapper

@torch.no_grad()
def act(model, states, device=None, temp=None):
    """Sample one discrete action from a categorical policy.

    Args:
        model: callable; ``model(batch, av=1)[0]`` is used as the action
            logits for the single state in ``batch``.
        states: one state; converted with ``torch.as_tensor`` and given a
            leading batch dimension of size 1.
        device: device for the state tensor.
        temp: optional temperature; when truthy the logits are divided by it
            (``None``/``0`` leave the logits unchanged).

    Returns:
        The sampled action index as a Python ``int``.
    """
    batch = torch.as_tensor(states, device=device)[None]
    act_logits = model(batch, av=1)[0]
    if temp:
        # Out-of-place on purpose: ``model(...)[0]`` is a view, possibly of a
        # tensor owned by the model, and an in-place divide under no_grad
        # would silently overwrite it.
        act_logits = act_logits / temp
    dist = torch.distributions.Categorical(logits=act_logits)
    return int(dist.sample())

# continuous version, Gaussian policy
@torch.no_grad()
def act_cont(model, states, device=None):
    """Sample a continuous action from a diagonal Gaussian policy.

    ``model(batch, av=1)`` must return ``(mean, log_std)``; the second output
    is exponentiated to obtain the standard deviation. Only the first (and
    only) batch row is used.

    Returns:
        ``(raw, squashed)`` numpy arrays: the raw Gaussian sample and its
        ``tanh``-squashed counterpart.
    """
    batch = torch.as_tensor(states, device=device).unsqueeze(0)
    mean, log_std = model(batch, av=1)
    gaussian = torch.distributions.Normal(mean[0], torch.exp(log_std[0]))
    sample = gaussian.sample()
    squashed = sample.tanh()
    return sample.cpu().numpy(), squashed.cpu().numpy()

def rollout(env, *args, **kwargs):
    """Dispatch a training rollout to the runner matching the env type.

    - ``env.sequential`` truthy -> ``rollout_seq`` (sequential game)
    - ``env.continuous`` truthy -> ``rollout_sim_cont`` (attribute optional)
    - otherwise                 -> ``rollout_sim``

    All extra positional and keyword arguments are forwarded unchanged.
    """
    if env.sequential:
        runner = rollout_seq
    elif getattr(env, "continuous", False):
        runner = rollout_sim_cont
    else:
        runner = rollout_sim
    return runner(env, *args, **kwargs)

def test_rollout(env, *args, **kwargs):
    """Dispatch an evaluation rollout to the sequential or simultaneous runner.

    Sequential envs go to ``test_rollout_seq``. Every other env goes to
    ``test_rollout_sim``; that includes continuous simultaneous envs, which
    have no dedicated evaluator here. Extra arguments are forwarded unchanged.
    """
    evaluator = test_rollout_seq if env.sequential else test_rollout_sim
    return evaluator(env, *args, **kwargs)

# simultaneous game
# T*B time-major order is more friendly to parallelism
def rollout_sim(env, T, B, S, x, y, record=3, device=None):
    """Collect B episodes (at most T steps each) of a discrete simultaneous game.

    Arrays are time-major (T x B), which is friendlier to parallel processing.

    Args:
        env: exposes ``reset() -> state`` and
            ``step(a, b) -> (state, reward, terminal, info)``; ``reward`` is
            player x's reward.
        T: maximum number of steps per episode.
        B: number of episodes.
        S: length of a state vector (states are stored as int32).
        x, y: policies for the two players, sampled via ``act``.
        record: bitmask. 0 records only rewards; any non-zero value records
            states and masks; bit 1 adds x's actions, bit 2 adds y's actions.
        device: when truthy, recorded arrays are converted to torch tensors
            on this device.

    Returns:
        If ``record == 0``: the ``(T, B)`` float32 reward array for x.
        Otherwise the 8-tuple
        ``(states, acts_x, rews, masks, states, acts_y, -rews, masks)``;
        an action entry is ``None`` when its bit is not set. ``masks[t, i]``
        is 1 for every recorded state, including the final non-terminal one
        at index ``t`` after the loop.
    """
    rews = np.zeros((T, B), dtype=np.float32)  # reward of x; y gets -rews
    acts_x = acts_y = None
    if record:
        states = np.zeros((T + 1, B, S), dtype=np.int32)
        if record & 1:
            acts_x = np.zeros((T, B), dtype=np.int32)
        if record & 2:
            acts_y = np.zeros((T, B), dtype=np.int32)
        masks = np.zeros((T + 1, B), dtype=np.float32)

    for i in range(B):
        state = env.reset()
        terminal = False  # defined even if T == 0 and the loop never runs
        t = 0
        while t < T:
            a = act(x, state, device=device)
            b = act(y, state, device=device)
            if record:
                states[t, i] = state
                masks[t, i] = 1
                if acts_x is not None:
                    acts_x[t, i] = a
                if acts_y is not None:
                    acts_y[t, i] = b
            # Step the env whether or not anything is recorded; with
            # record == 0 the reward is still needed.
            state, reward, terminal, _ = env.step(a, b)
            rews[t, i] = reward
            if terminal:
                break
            t += 1
        if record and not terminal:
            # Episode was truncated at T: keep the last observed state too.
            states[t, i] = state
            masks[t, i] = 1

    if record == 0:
        return rews

    if device:
        states = torch.as_tensor(states, device=device)
        if acts_x is not None:
            acts_x = torch.as_tensor(acts_x, device=device)
        if acts_y is not None:
            acts_y = torch.as_tensor(acts_y, device=device)
        rews = torch.as_tensor(rews, device=device)
        masks = torch.as_tensor(masks, device=device)
    return states, acts_x, rews, masks,  states, acts_y, -rews, masks

def test_rollout_sim(env, T, B, act_x, act_y, gamma):
    rews = np.zeros(B, dtype=np.float32)
    for i in range(B):
        state = env.reset()
        act_x(None); act_y(None) # reset the agents
        r = 0; g = 1.0
        for t in range(T):
            a = act_x(state)
            b = act_y(state)
            state, reward, terminal, _ = env.step(a, b)
            r += g * reward
            g *= gamma
            if terminal:
                break
        rews[i] = r
        # if win_stat is not None:
        #     win_stat[i] = reward
    return rews


# simultaneous continuous control version
# T*B time-major order is more friendly to parallelism
def rollout_sim_cont(env, T, B, S, x, y, record=3, device=None, pbar=False):
    """Collect B episodes of a continuous-control simultaneous game.

    Arrays are time-major (T x B). ``env.reset()`` and the first element of
    ``env.step(aa, bb)`` return a pair of observations: ``state[0]`` for x and
    ``state[1]`` for y. Actions are sampled with ``act_cont``. The env receives
    the tanh-squashed actions, while the raw (pre-tanh) samples are recorded.

    Args:
        env: simultaneous env; also needs ``env.nact`` (action dimension) when
            recording actions.
        T: maximum steps per episode. B: number of episodes.
        S: observation length per player (stored as float32).
        x, y: Gaussian policies for the two players.
        record: bitmask. Bit 1 records x's states/actions, bit 2 records y's;
            masks are recorded whenever ``record`` is non-zero.
        device: when truthy, outputs are converted to torch tensors.
        pbar: wrap the episode loop in a ``tqdm`` progress bar.

    Returns:
        ``(rews,)`` if ``record == 0``. Otherwise the 8-tuple
        ``(states_x, acts_x, rews, masks, states_y, acts_y, -rews, masks)``,
        with the unrecorded player's four entries set to ``None`` when
        ``record`` is 1 or 2.
    """
    rews = np.zeros((T, B), dtype=np.float32)  # reward of x; y gets -rews
    if record:
        if record & 1:
            states_x = np.zeros((T + 1, B, S), dtype=np.float32)
            acts_x = np.zeros((T, B, env.nact), dtype=np.float32)
        if record & 2:
            states_y = np.zeros((T + 1, B, S), dtype=np.float32)
            acts_y = np.zeros((T, B, env.nact), dtype=np.float32)
        # np.bool was removed in NumPy 1.24; np.bool_ is the same dtype.
        masks = np.zeros((T + 1, B), dtype=np.bool_)

    episodes = tqdm(range(B)) if pbar else range(B)
    for i in episodes:
        state = env.reset()
        terminal = False  # defined even if T == 0 and the loop never runs
        t = 0
        while t < T:
            a, aa = act_cont(x, state[0], device=device)
            b, bb = act_cont(y, state[1], device=device)
            if record:
                if record & 1:
                    states_x[t, i] = state[0]
                    acts_x[t, i] = a  # raw, pre-tanh sample
                if record & 2:
                    states_y[t, i] = state[1]
                    acts_y[t, i] = b
                masks[t, i] = 1
            # Always step (also when record == 0); the env gets tanh actions.
            state, reward, terminal, _ = env.step(aa, bb)
            rews[t, i] = reward
            if terminal:
                break
            t += 1
        if record and not terminal:
            # Episode was truncated at T: keep the last observed states too.
            if record & 1:
                states_x[t, i] = state[0]
            if record & 2:
                states_y[t, i] = state[1]
            masks[t, i] = 1

    if record == 0:
        return (rews,)

    if device:
        if record & 1:
            states_x = torch.as_tensor(states_x, device=device)
            acts_x = torch.as_tensor(acts_x, device=device)
        if record & 2:
            states_y = torch.as_tensor(states_y, device=device)
            acts_y = torch.as_tensor(acts_y, device=device)
        rews = torch.as_tensor(rews, device=device)
        masks = torch.as_tensor(masks, device=device)
    if record == 1:
        return states_x, acts_x, rews, masks,  None, None, None, None
    elif record == 2:
        return None, None, None, None,  states_y, acts_y, -rews, masks
    return states_x, acts_x, rews, masks,  states_y, acts_y, -rews, masks


# sequential game

def rollout_seq(env, T, B, S, x, y, record=3, device=None):
    """Collect B episodes of a discrete sequential (turn-based) game.

    ``env.reset()`` returns ``(state, player)`` and ``env.step(a)`` returns
    ``(state, reward, terminal, player)``. ``player == 0`` means x moves;
    any other value means y moves. Each player's trajectory is indexed by its
    own move counter (``tx``/``ty``). An episode ends when the env reports
    ``terminal``, or when the player to move has already made T moves. In the
    second case that player's current state is stored at index T.

    Only a terminal reward is assumed. It is credited to x's last move, and
    its negation to y's last move.

    Args:
        T: maximum moves per player. B: number of episodes.
        S: state length (stored as int32).
        x, y: policies sampled via ``act``.
        record: bitmask; bit 1 records x's trajectory, bit 2 records y's.
        device: when truthy, recorded arrays are converted to torch tensors.

    Returns:
        ``rews_x`` (a ``(T, B)`` array) if ``record == 0``, else
        ``(states_x, acts_x, rews_x, masks_x, states_y, acts_y, rews_y,
        masks_y)``. An unrecorded player's entries are ``None``, except
        ``rews_x``, which is always an array.
    """
    rews_x = np.zeros((T, B), dtype=np.float32)  # rew x

    if record & 1:
        states_x = np.zeros((T + 1, B, S), dtype=np.int32)
        acts_x = np.zeros((T, B), dtype=np.int32)
        # np.bool was removed in NumPy 1.24; np.bool_ is the same dtype.
        masks_x = np.zeros((T + 1, B), dtype=np.bool_)
    else:
        states_x = acts_x = masks_x = None
    if record & 2:
        states_y = np.zeros((T + 1, B, S), dtype=np.int32)
        acts_y = np.zeros((T, B), dtype=np.int32)
        masks_y = np.zeros((T + 1, B), dtype=np.bool_)
        rews_y = np.zeros((T, B), dtype=np.float32)  # rew y = -rew x
    else:
        states_y = acts_y = masks_y = rews_y = None

    for i in range(B):
        state, player = env.reset()
        terminal = False
        tx, ty = 0, 0
        while True:
            if player == 0:
                if tx == T:
                    # x has used its move budget: store its final state only.
                    if record & 1:
                        states_x[tx, i] = state
                        masks_x[tx, i] = 1
                    break
                a = act(x, state, device=device)
                if record & 1:
                    states_x[tx, i] = state
                    masks_x[tx, i] = 1
                    acts_x[tx, i] = a
                state, reward, terminal, player = env.step(a)
                tx += 1
            else:  # player == 1
                if ty == T:
                    if record & 2:
                        states_y[ty, i] = state
                        masks_y[ty, i] = 1
                    break
                a = act(y, state, device=device)
                if record & 2:
                    states_y[ty, i] = state
                    masks_y[ty, i] = 1
                    acts_y[ty, i] = a
                state, reward, terminal, player = env.step(a)
                ty += 1
            if terminal:
                # TODO: assuming terminal reward only.
                # A player that never moved (counter 0) gets no reward;
                # indexing with -1 would write into the last row instead.
                if tx > 0:
                    rews_x[tx - 1, i] = reward
                if record & 2 and ty > 0:
                    rews_y[ty - 1, i] = -reward
                break

    if record == 0:
        return rews_x

    if device:
        if record & 1:
            states_x = torch.as_tensor(states_x, device=device)
            acts_x = torch.as_tensor(acts_x, device=device)
            rews_x = torch.as_tensor(rews_x, device=device)
            masks_x = torch.as_tensor(masks_x, device=device)
        if record & 2:
            states_y = torch.as_tensor(states_y, device=device)
            acts_y = torch.as_tensor(acts_y, device=device)
            rews_y = torch.as_tensor(rews_y, device=device)
            masks_y = torch.as_tensor(masks_y, device=device)
    return states_x, acts_x, rews_x, masks_x,  states_y, acts_y, rews_y, masks_y


def test_rollout_seq(env, T, B, act_x, act_y, gamma):
    assert gamma == 1.0

    rews = np.zeros(B, dtype=np.float32)
    for i in range(B):
        state, player = env.reset()
        act_x(None); act_y(None) # reset the agents
        terminal = False
        tx, ty = 0, 0
        # g = 1.0
        while True:
            if player == 0:
                if tx == T:
                    break
                a = int(act_x(state))
                state, reward, terminal, player = env.step(a)
                if terminal: break
                tx += 1
            else: # player == 1
                if ty == T:
                    break
                a = int(act_y(state))
                state, reward, terminal, player = env.step(a)
                if terminal: break
                ty += 1
            # assume episodic reward?
            # r += g * reward
            # g *= gamma
        rews[i] = reward
        # if win_stat is not None:
        #     win_stat[i] = reward
    return rews



## parallel runner
import multiprocessing as mp

# For now, only supports Sumo env
class ParallelRunner:
    """Farm Sumo rollouts and evaluations out to spawned worker processes.

    For now only ``envs.SumoEnv`` is supported. Every worker builds its own
    env and two ``model.MLPControl`` policies. It then serves requests from a
    shared input queue until it receives ``None``.
    """

    def __init__(self, nproc):
        # A private 'spawn' context avoids mp.set_start_method(), which raises
        # RuntimeError once a start method has already been set (e.g. when a
        # second ParallelRunner is created).
        ctx = mp.get_context('spawn')
        self.qin = ctx.Queue()
        self.qout = ctx.Queue()
        self.procs = [ctx.Process(target=ParallelRunner.process_eval,
                                  args=(self.qin, self.qout, idx))
                      for idx in range(nproc)]
        for p in self.procs:
            p.start()
        self.nproc = nproc

    def close(self):
        """Stop all workers and close the queues; safe to call repeatedly."""
        procs, self.procs = self.procs, []
        # One sentinel per worker: a worker exits after reading a single None.
        for _ in procs:
            self.qin.put(None)
        for p in procs:
            p.join()
        self.qin.close()
        self.qout.close()

    def __del__(self):
        # Best-effort cleanup: attributes may be missing if __init__ failed,
        # and globals may already be torn down at interpreter exit.
        try:
            self.close()
        except Exception:
            pass

    @staticmethod
    def _split_batch(B, nproc):
        """Split B episodes into nproc chunk sizes differing by at most one."""
        n0, m = divmod(B, nproc)
        # The first m workers each take one extra episode.
        return [n0 + int(i < m) for i in range(nproc)]

    @staticmethod
    @torch.no_grad()
    def process_eval(qin, qout, proc_idx):
        """Worker loop: serve requests from ``qin`` until it yields ``None``.

        A request is ``(cmd, T, B, x, y, *extra)``. ``x``/``y`` are
        ``(kind, payload)`` pairs:
        - kind 'random' samples that player's env action space;
        - kind 'builtin' returns that sample multiplied by zero;
        - any other kind loads ``payload`` as a state_dict into the worker's
          MLP policy.
        ``cmd == 'test'`` runs ``test_rollout_sim(..., *extra)``. Any other
        cmd runs ``rollout_sim_cont(..., *extra)``. Each result is put on
        ``qout``.
        """
        # Project modules are imported lazily, inside the worker process.
        import envs, model

        device = torch.device("cpu")
        # ensure proper randomness
        # np.random.seed(proc_idx)
        # torch.manual_seed(proc_idx)
        # torch.cuda.manual_seed_all(proc_idx)
        print ("[proc %d] started" % proc_idx, device, torch.randn(()), np.random.rand())

        env = envs.SumoEnv()
        state_dim = env.state_dim

        MODEL = model.MLPControl

        xnn = MODEL().to(device).eval()
        ynn = MODEL().to(device).eval()

        def get_act(a, player_a_or_b):
            # Build a state -> action function for player 0 (x) or 1 (y).
            # Each act_func returns None when called with None (agent reset).
            if a[0] == 'random':
                def act_func(state):
                    if state is not None:
                        return env.env.action_space[player_a_or_b].sample()
            elif a[0] == 'builtin':
                def act_func(state):
                    if state is not None:
                        return env.env.action_space[player_a_or_b].sample() * 0
            else:
                x = xnn if player_a_or_b == 0 else ynn
                x.policy.load_state_dict(a[1])
                def act_func(state):
                    if state is not None:
                        # act_cont returns (raw_act, tanh(raw_act))
                        return act_cont(x, state[player_a_or_b], device=device)[1]
            return act_func

        while True:
            params = qin.get()
            if params is None:
                break
            cmd, T, B, x, y, *extra = params
            act_x = get_act(x, 0)
            act_y = get_act(y, 1)
            if cmd == 'test':
                ret = test_rollout_sim(env, T, B, act_x, act_y, *extra)
            else:
                # get_act above already loaded the weights into xnn/ynn.
                ret = rollout_sim_cont(env, T, B, state_dim, xnn, ynn, *extra)
            qout.put(ret)

    # env, S, device are not used by the workers
    def rollout(self, env, T, B, S, x, y, record=3, device=None):
        """Run ``rollout_sim_cont`` for B episodes split across the workers.

        ``x``/``y`` must expose ``.policy.state_dict()``; their weights are
        shipped to the workers. The per-worker outputs are concatenated along
        axis 1 (the batch axis) and converted to tensors on ``device``.
        ``None`` entries (players not recorded) stay ``None``.
        """
        x = ('mlp', x.policy.state_dict())
        y = ('mlp', y.policy.state_dict())
        for n in self._split_batch(B, self.nproc):
            self.qin.put(('train', T, n, x, y, record))
        parts = [self.qout.get() for _ in range(self.nproc)]
        return [torch.from_numpy(np.concatenate(a, 1)).to(device) if a[0] is not None else None
                for a in zip(*parts)]

    # env, device are not used by the workers
    def test_rollout(self, env, T, B, x, y, gamma, device=None):
        """Run ``test_rollout_sim`` for B episodes split across the workers.

        ``x``/``y`` are forwarded as-is. They must be ``(kind, payload)``
        pairs as understood by ``process_eval``. Returns the concatenated
        per-episode returns (order across workers is not preserved).
        """
        for n in self._split_batch(B, self.nproc):
            self.qin.put(('test', T, n, x, y, gamma))
        parts = [self.qout.get() for _ in range(self.nproc)]
        return np.concatenate(parts)
