import random
import time
from collections import defaultdict

from .AtomicCommit import reset_envs, step_envs
from tf_agents.policies import py_tf_eager_policy
from tf_agents.trajectories import time_step as ts
from tf_agents.trajectories import trajectory

import numpy as np

from utils import log, hash_observation
from atomic_commit.State import State


_LOG_TAG = "collect_data"


def collect_data(
    envs, agent, observer, freezer, steps, players, all_combs, is_refine, failed_cases
):
    """Roll out ``steps`` episodes with the collect policy and store the transitions.

    For each episode every environment is reset (optionally to a previously
    failing setting when refining). Each player then picks an action per round,
    the environments are stepped in lockstep until ``envs[0]`` reports a LAST
    step, and the recorded transitions are pushed into ``observer[0]``.

    Args:
        envs: One environment per player; ``envs[0].crash_nodes`` holds the
            indices of the players treated as crashed in the current episode.
        agent: Sequence whose first element provides ``collect_policy``.
        observer: Sequence whose first element is called with each trajectory
            to store (presumably a replay-buffer writer -- confirm with caller).
        freezer: Object exposing ``is_freezed()`` and ``apply_freeze_rules()``,
            used to force a fixed action for selected observations.
        steps: Number of episodes to collect (the counter advances once per
            episode, not once per environment step).
        players: Number of participants / environments to drive.
        all_combs: Table of settings indexed by the entries of ``failed_cases``.
        is_refine: When true, the first half of the episodes replays settings
            drawn from ``failed_cases``.
        failed_cases: Indices into ``all_combs`` of settings to replay. If it is
            empty, refinement falls back to default settings instead of raising.
    """
    policy = py_tf_eager_policy.PyTFEagerPolicy(
        agent[0].collect_policy, use_tf_function=True
    )
    # random.choice() raises IndexError on an empty sequence; with nothing to
    # replay, use default settings rather than aborting the whole run.
    replay_failures = bool(is_refine) and len(failed_cases) > 0
    if is_refine and not replay_failures:
        log(_LOG_TAG, "no failed cases to replay; using default settings")

    num_steps = 0
    while num_steps < steps:
        setting = None
        # The first half of a refinement run replays previously failing settings.
        if replay_failures and num_steps < steps / 2:
            index = random.choice(failed_cases)
            setting = all_combs[index]
            log(_LOG_TAG, f"pick crash comb id: {index}")
            log(_LOG_TAG, "crash setting: " + str(setting))
        reset_envs(envs, players, input=setting)

        time_steps = [envs[i].reset() for i in range(players)]
        # Slot k holds the latest transition that started from step type k
        # (tf_agents StepType: 0 = FIRST, 1 = MID); None until recorded.
        trajs = [[None, None] for _ in range(players)]

        while not time_steps[0].is_last():
            log(_LOG_TAG, f"{time_steps}")
            actions, action_steps = _select_actions(
                envs, policy, freezer, time_steps, players
            )
            log(_LOG_TAG, f"action selected: {actions}")
            # `setting` is None outside the replay phase, as passed to reset_envs.
            step_envs(envs, actions, players, input=setting)

            next_time_steps = []
            for i in range(players):
                next_step = envs[i].step(action_steps[i])
                next_time_steps.append(next_step)
                traj = trajectory.from_transition(
                    time_steps[i], action_steps[i], next_step
                )
                # int() also accepts a 1-element (batched) step_type array.
                trajs[i][int(time_steps[i].step_type)] = traj
            time_steps = next_time_steps

        _store_episode(envs, observer, trajs, players)
        num_steps += 1


def _select_actions(envs, policy, freezer, time_steps, players):
    """Choose one action per player for the current round.

    Returns:
        ``(actions, action_steps)``: the action chosen for every player, and the
        matching policy steps with the action overridden where a rule applies.
    """
    actions = []
    action_steps = []
    # Players that see the same observation must make the same decision, so the
    # first policy choice for each observation hash is reused by later players.
    seen_actions = {}
    seen_steps = {}
    for i in range(players):
        action_step = policy.action(time_steps[i])

        # A crashed node always chooses Lost (assumed only for round-2 states).
        if i in envs[0].crash_nodes:
            lost = int(State.Lost.value)
            actions.append(lost)
            action_steps.append(
                action_step.replace(action=np.array(lost, dtype=np.int32))
            )
            continue

        # Freeze rules force the desired action for matching observations.
        if freezer.is_freezed():
            action = freezer.apply_freeze_rules(
                time_steps[i].observation.tolist(), players
            )
            if action is not None:
                actions.append(action)
                action_steps.append(
                    action_step.replace(action=np.array(action, dtype=np.int32))
                )
                continue

        # Regular policy selection, deduplicated by observation.
        obs_key = hash_observation(time_steps[i].observation.tolist())
        if obs_key not in seen_actions:
            seen_actions[obs_key] = int(action_step.action)
            seen_steps[obs_key] = action_step
        actions.append(seen_actions[obs_key])
        action_steps.append(seen_steps[obs_key])
    return actions, action_steps


def _store_episode(envs, observer, trajs, players):
    """Push one episode's recorded transitions for every player into ``observer[0]``.

    Before storing, the reward of the slot-1 transition (the one that started
    from a MID step) is copied onto the slot-0 transition (the one that started
    from the FIRST step). Missing slots are skipped. Crashed players store only
    their first transition, and only when their state machine holds a final
    decision.
    """
    for i in range(players):
        first, later = trajs[i]
        if first is not None and later is not None:
            first = first.replace(reward=np.array(later.reward, dtype=np.float32))

        if i in envs[0].crash_nodes:
            # Nodes that decided in the 1st round but crashed in the 2nd round
            # still contribute their first-round transition.
            if first is not None and envs[i].state_machine.final_decision is not None:
                log(_LOG_TAG, f"Store {first} into buffer 0")
                observer[0](first)
            continue

        for traj in (first, later):
            if traj is None:
                continue
            log(_LOG_TAG, f"Store {traj} into buffer 0")
            observer[0](traj)
