import torch
import numpy as np

import multiprocessing as mp

@torch.no_grad()
def act(model, states, device=None, temp=None):
    """Sample one action from ``model``'s policy for a single state.

    Args:
        model: callable invoked as ``model(batch, av=1)``; the first row of its
            output is used as the action logits.
        states: a single state (anything ``torch.as_tensor`` accepts); a leading
            batch dimension of size 1 is added before calling the model.
        device: device for the input tensor (None keeps torch's default).
        temp: optional temperature; when truthy the logits are divided by it
            (None or 0 leaves the logits unscaled).

    Returns:
        The tensor returned by ``Categorical(logits=...).sample()`` — a 0-d
        action index when the logits row is 1-D.
    """
    batch = torch.as_tensor(states, device=device)[None]
    logits = model(batch, av=1)[0]
    if temp:
        # Out-of-place on purpose: ``/=`` would modify the model's output
        # tensor in place, corrupting it if the model returns a shared or
        # cached tensor.
        logits = logits / temp
    return torch.distributions.Categorical(logits=logits).sample()


def rollout(env, *args, **kwargs):
    """Collect training rollouts, dispatching on the kind of game.

    Turn-based environments (``env.sequential`` truthy) are handled by
    ``rollout_seq``; simultaneous-move environments by ``rollout_sim``.
    All remaining arguments are forwarded unchanged.
    """
    collector = rollout_seq if env.sequential else rollout_sim
    return collector(env, *args, **kwargs)

def test_rollout(env, *args, **kwargs):
    """Evaluate two agents, dispatching on the kind of game.

    Turn-based environments (``env.sequential`` truthy) are handled by
    ``test_rollout_seq``; simultaneous-move environments by
    ``test_rollout_sim``. All remaining arguments are forwarded unchanged.
    """
    evaluator = test_rollout_seq if env.sequential else test_rollout_sim
    return evaluator(env, *args, **kwargs)


# simultaneous game
# T*B time-major order is more friendly to parallelism
def rollout_sim(env, T, B, S, x, y, record=3, device=None):
    """Play B episodes of a simultaneous-move game between policies x and y.

    Args:
        env: environment with ``reset() -> state`` and
            ``step(a, b) -> (state, reward, terminal, info)``.
        T: maximum number of steps per episode.
        B: number of episodes (batch size).
        S: size of the last dimension of the recorded states; each state is
            written into a length-S int32 row.
        x, y: models passed to ``act`` to choose actions for each player.
        record: 0 records rewards only; any non-zero value also records
            states and masks. ``record & 1`` additionally records x's actions,
            ``record & 2`` y's actions.
        device: if truthy, recorded arrays are converted to torch tensors on it.

    Returns:
        If ``record == 0``: the (T, B) float32 reward array for x.
        Otherwise ``(states, acts_x, rews, masks, states, acts_y, -rews,
        masks)``: the trajectory from x's view, then from y's view with the
        rewards negated. ``acts_x`` / ``acts_y`` are None when not recorded.
    """
    rews = np.zeros((T, B), dtype=np.float32)  # rewards from x's perspective
    if record:
        states = np.zeros((T + 1, B, S), dtype=np.int32)
        acts_x = np.zeros((T, B), dtype=np.int32) if record & 1 else None
        acts_y = np.zeros((T, B), dtype=np.int32) if record & 2 else None
        # builtin ``bool``: the ``np.bool`` alias was removed in NumPy 1.24
        masks = np.zeros((T + 1, B), dtype=bool)

    for i in range(B):
        state = env.reset()
        terminal = False  # must be bound even if T == 0 and the loop never runs
        t = 0
        while t < T:
            a = int(act(x, state, device=device))
            b = int(act(y, state, device=device))
            if record:
                states[t, i] = state
                masks[t, i] = 1
                if acts_x is not None:
                    acts_x[t, i] = a
                if acts_y is not None:
                    acts_y[t, i] = b
            # Step unconditionally: the env must advance (and ``reward`` be
            # bound) even when nothing is being recorded.
            state, reward, terminal, _ = env.step(a, b)
            rews[t, i] = reward
            if terminal:
                break
            t += 1
        if record and not terminal:
            # Episode ran out of steps: also keep the state reached after the
            # last step (index t == T).
            states[t, i] = state
            masks[t, i] = 1

    if record == 0:
        return rews

    if device:
        states = torch.as_tensor(states, device=device)
        if acts_x is not None:
            acts_x = torch.as_tensor(acts_x, device=device)
        if acts_y is not None:
            acts_y = torch.as_tensor(acts_y, device=device)
        rews = torch.as_tensor(rews, device=device)
        masks = torch.as_tensor(masks, device=device)
    return states, acts_x, rews, masks,  states, acts_y, -rews, masks

def test_rollout_sim(env, T, B, act_x, act_y, gamma):
    rews = np.zeros(B, dtype=np.float32)
    for i in range(B):
        state = env.reset()
        act_x(None); act_y(None) # reset the agents
        r = 0
        for t in range(T):
            a = int(act_x(state))
            b = int(act_y(state))
            state, reward, terminal, _ = env.step(a, b)
            r += gamma * reward
            gamma *= gamma
            if terminal:
                break
        rews[i] = r
    return rews


# sequential game
def rollout_seq(env, T, B, S, x, y, record=3, device=None):
    """Play B episodes of a turn-based game between policies x and y.

    Args:
        env: environment with ``reset() -> (state, player)`` and
            ``step(a) -> (state, reward, terminal, player)``; ``player == 0``
            means x moves next, any other value means y moves next.
        T: maximum number of moves per player.
        B: number of episodes (batch size).
        S: size of the last dimension of the recorded states.
        x, y: models passed to ``act`` to choose each player's actions.
        record: ``record & 1`` records x's states/actions/masks;
            ``record & 2`` records y's states/actions/masks/rewards.
        device: if truthy, recorded arrays are converted to torch tensors on it.

    Each player's data is indexed by its own move count. The episode stops at
    a terminal step, or when the player to move has already made T moves (that
    state is then recorded at index T for the recorded player).

    Returns:
        ``rews_x`` if ``record == 0``, otherwise ``(states_x, acts_x, rews_x,
        masks_x, states_y, acts_y, rews_y, masks_y)`` with None for the
        entries of a player that is not recorded.
    """
    rews_x = np.zeros((T, B), dtype=np.float32)  # rewards for x

    if record & 1:
        states_x = np.zeros((T + 1, B, S), dtype=np.int32)
        acts_x = np.zeros((T, B), dtype=np.int32)
        masks_x = np.zeros((T + 1, B), dtype=np.float32)
    else:
        states_x = acts_x = masks_x = None
    if record & 2:
        states_y = np.zeros((T + 1, B, S), dtype=np.int32)
        acts_y = np.zeros((T, B), dtype=np.int32)
        masks_y = np.zeros((T + 1, B), dtype=np.float32)
        rews_y = np.zeros((T, B), dtype=np.float32)  # rew y = -rew x
    else:
        states_y = acts_y = masks_y = rews_y = None

    for i in range(B):
        state, player = env.reset()
        terminal = False
        tx, ty = 0, 0  # moves made so far by x and by y
        while True:
            if player == 0:
                if tx == T:
                    # x is out of moves: keep its final state and stop
                    if record & 1:
                        states_x[tx, i] = state
                        masks_x[tx, i] = 1
                    break
                a = int(act(x, state, device=device))
                if record & 1:
                    states_x[tx, i] = state
                    masks_x[tx, i] = 1
                    acts_x[tx, i] = a
                state, reward, terminal, player = env.step(a)
                tx += 1
            else:  # any non-zero player id: y moves
                if ty == T:
                    # y is out of moves: keep its final state and stop
                    if record & 2:
                        states_y[ty, i] = state
                        masks_y[ty, i] = 1
                    break
                a = int(act(y, state, device=device))
                if record & 2:
                    states_y[ty, i] = state
                    masks_y[ty, i] = 1
                    acts_y[ty, i] = a
                state, reward, terminal, player = env.step(a)
                ty += 1
            if terminal:
                # TODO: assuming terminal reward only (intermediate rewards
                # are dropped).
                # Credit the terminal reward to each player's last move. A
                # player that never moved has a count of 0; without the guard
                # the write would hit index -1, i.e. silently land on step T-1.
                if tx > 0:
                    rews_x[tx - 1, i] = reward
                if record & 2 and ty > 0:
                    rews_y[ty - 1, i] = -reward
                break

    if record == 0:
        return rews_x

    if device:
        if record & 1:
            states_x = torch.as_tensor(states_x, device=device)
            acts_x = torch.as_tensor(acts_x, device=device)
            rews_x = torch.as_tensor(rews_x, device=device)
            masks_x = torch.as_tensor(masks_x, device=device)
        if record & 2:
            states_y = torch.as_tensor(states_y, device=device)
            acts_y = torch.as_tensor(acts_y, device=device)
            rews_y = torch.as_tensor(rews_y, device=device)
            masks_y = torch.as_tensor(masks_y, device=device)
    return states_x, acts_x, rews_x, masks_x,  states_y, acts_y, rews_y, masks_y


def test_rollout_seq(env, T, B, act_x, act_y, gamma):
    assert gamma == 1.0
    # g = 1.0
    rews = np.zeros(B, dtype=np.float32)
    for i in range(B):
        state, player = env.reset()
        act_x(None); act_y(None) # reset the agents
        terminal = False
        tx, ty = 0, 0
        # r = 0
        while True:
            if player == 0:
                if tx == T:
                    break
                a = int(act_x(state))
                state, reward, terminal, player = env.step(a)
                if terminal: break
                tx += 1
            else: # player == 1
                if ty == T:
                    break
                a = int(act_y(state))
                state, reward, terminal, player = env.step(a)
                if terminal: break
                ty += 1
            # r += g * reward
            # g *= gamma
        rews[i] = reward  # assume only terminal reward
    return rews
