from collections import defaultdict

import numpy as np
import random
from tf_agents.policies import py_tf_eager_policy
from tf_agents.trajectories import trajectory

from primary_backup.PrimaryBackup import reset_envs, step_envs
from primary_backup.State import State
from utils import log, hash_observation


def collect_data(
    envs, agent, observer, freezer, steps, players, all_combs, is_refine, failed_cases
):
    """Roll out ``steps`` episodes across all node environments and buffer them.

    Each iteration of the outer loop plays one full episode:

    1. reset the environments, optionally to a previously failing setting;
    2. every round, pick one action per node and step all environments;
    3. hand each node's per-round trajectories to ``buffer_trace``.

    Args:
        envs: per-node environments; ``envs[i]`` belongs to node ``i``.
            ``envs[0]`` is also consulted for the shared ``crash_nodes``.
        agent: sequence whose first element provides ``collect_policy``.
        observer: training-data sink, forwarded to ``buffer_trace``.
        freezer: object exposing ``is_freezed()`` / ``apply_freeze_rules()``.
            It is used to force actions for selected observations.
        steps: number of *episodes* to collect. Despite the name, one unit is
            one full episode, not one environment step.
        players: number of nodes.
        all_combs: indexable collection of environment settings.
        is_refine: when true, the first half of the episodes replay settings
            drawn at random from ``failed_cases``.
        failed_cases: indices into ``all_combs`` of previously failing
            settings. Must be non-empty when ``is_refine`` is true, since it
            is sampled with ``random.choice``.
    """
    this_func = "PrimaryBackup_collect_data"
    num_steps = 0
    policy = py_tf_eager_policy.PyTFEagerPolicy(
        agent[0].collect_policy, use_tf_function=True
    )
    while num_steps < steps:
        # In refinement mode, spend the first half of the budget replaying
        # settings that previously failed; otherwise let the env choose.
        setting = None
        if is_refine and num_steps < steps / 2:
            index = random.choice(failed_cases)
            setting = all_combs[index]
            log(this_func, f"pick crash comb id: {index}")
            log(this_func, "crash setting: " + str(setting))
        reset_envs(envs, players, input=setting)

        time_steps = [envs[i].reset() for i in range(players)]
        # One list of trajectories per node, appended as rounds complete.
        # The lists grow per round rather than being pre-sized to num_round.
        # This way an episode that ends early (or runs long) cannot leave
        # ``None`` entries, or overflow the list, before ``buffer_trace``
        # reads them.
        trajs = [[] for _ in range(players)]

        while not time_steps[0].is_last():
            log(this_func, f"{time_steps}")
            actions, action_steps = _select_actions(
                envs, policy, freezer, time_steps, players
            )
            log(this_func, f"action selected: {actions}")
            # ``setting`` is already None whenever is_refine is false.
            step_envs(envs, actions, players, input=setting)

            next_time_steps = []
            for i in range(players):
                next_step = envs[i].step(action_steps[i])
                trajs[i].append(
                    trajectory.from_transition(
                        time_steps[i], action_steps[i], next_step
                    )
                )
                next_time_steps.append(next_step)
            time_steps = next_time_steps

        for i in range(players):
            buffer_trace(envs[i], i, observer, trajs[i])
        num_steps += 1


def _select_actions(envs, policy, freezer, time_steps, players):
    """Choose one action per node for the current round.

    Precedence per node:

    1. Crashed nodes (``envs[0].crash_nodes``) always play ``State.Lost``.
    2. Otherwise, when the freezer is active and its rules yield an action,
       that action is used.
    3. Otherwise the policy's choice is used. It is reused for every node
       whose hashed observation was already seen this round, so identical
       observations get identical actions.

    Returns:
        ``(actions, action_steps)``. ``actions`` holds the integer actions
        passed to ``step_envs``. ``action_steps`` holds the matching policy
        steps passed to each env's ``step``.
    """
    actions = []
    action_steps = []
    # Policy steps already chosen this round, keyed by hash_observation(...)
    # (presumably a hashable digest of the observation list).
    chosen_by_obs = {}
    for i in range(players):
        # Always query the policy first. For crashed or frozen nodes, its step
        # serves as a template whose ``action`` field is overwritten below.
        action_step = policy.action(time_steps[i])

        if i in envs[0].crash_nodes:
            lost = int(State.Lost.value)
            actions.append(lost)
            action_steps.append(
                action_step.replace(action=np.array(lost, dtype=np.int32))
            )
            continue

        if freezer.is_freezed():
            action = freezer.apply_freeze_rules(
                time_steps[i].observation.tolist(), players
            )
            if action is not None:
                actions.append(action)
                action_steps.append(
                    action_step.replace(action=np.array(action, dtype=np.int32))
                )
                continue

        obs_key = hash_observation(time_steps[i].observation.tolist())
        if obs_key not in chosen_by_obs:
            chosen_by_obs[obs_key] = action_step
        cached_step = chosen_by_obs[obs_key]
        actions.append(int(cached_step.action))
        action_steps.append(cached_step)
    return actions, action_steps


def buffer_trace(env, idx, observer, traj_list):
    """Push one node's per-round trajectories into the training buffer.

    Every stored trajectory has its reward replaced by the reward of the last
    entry of ``traj_list``, as a float32 array. All stored rounds therefore
    share the episode's final reward.

    Which trajectories are stored:

    * Node not in ``env.crash_nodes``: every trajectory.
    * Crashed node whose state machine has a ``final_decision``: only the
      rounds before its crash. ``get_transitions()`` is laid out as
      ``[initial_state, action_1, action_2, ...]``. If ``State.Lost`` first
      appears at position k (crash in round k), the first k - 1 trajectories
      are stored. Those entries of ``traj_list`` are replaced in place.
    * Crashed node without a final decision: nothing.

    Args:
        env: this node's environment (a PrimaryBackupEnv). Only
            ``crash_nodes`` and ``state_machine`` are read.
        idx: the node index.
        observer: sequence whose first element is called with each stored
            trajectory (the training data buffer).
        traj_list: the node's trajectories, in round order.
    """
    this_func = "PrimaryBackup_buffer_trace"

    if idx not in env.crash_nodes:
        for round_traj in traj_list:
            relabeled = round_traj.replace(
                reward=np.array(traj_list[-1].reward, dtype=np.float32)
            )
            log(this_func, f"Store {relabeled} into buffer")
            observer[0](relabeled)
        return

    # A crashed node contributes data only if it reached a final decision.
    if env.state_machine.final_decision is None:
        return

    history = env.state_machine.get_transitions()
    assert State.Lost.value in history
    # history[0] is the initial state, so a Lost at position k means round k.
    # Keep the k - 1 rounds played before the crash.
    rounds_before_crash = history.index(State.Lost.value) - 1
    for round_idx in range(rounds_before_crash):
        traj_list[round_idx] = traj_list[round_idx].replace(
            reward=np.array(traj_list[-1].reward, dtype=np.float32)
        )
        log(this_func, f"Store {traj_list[round_idx]} into buffer")
        observer[0](traj_list[round_idx])
