from tf_agents.environments import tf_py_environment

from MultiSM import MultiStateMachine
from tfagents.atomic_commit.AtomicCommitSMEnv import AtomicCommitSMEnv
from collections import defaultdict

import numpy as np
import tensorflow as tf

from inference import get_meanings
from atomic_commit.State import State, meanings
from utils import construct_bits_string
import distributed_locking.State as dl_state

# Integer codes of the atomic-commit ``State`` members used by the tracker,
# bound to short module-level names.
LA, LC, DN, LOST1, LOST2, COMMIT, ABORT, DUMMY = (
    int(member.value)
    for member in (
        State.LocalAbort,
        State.LocalCommit,
        State.DoNothingCommit,
        State.Lost_R1,
        State.Lost_R2,
        State.Commit,
        State.Abort,
        State.Dummy,
    )
)


class Tracker:
    def __init__(self, players, agent, eval, protocol):
        """Create the environments and the empty bookkeeping tables.

        Args:
            players: Forwarded to ``MultiStateMachine`` and
                ``AtomicCommitSMEnv``; also stored as ``self.players``.
            agent: Agent object; its ``policy`` attribute is stored as
                ``self.policy``.
            eval: Evaluator object, stored as ``self.evaluator``.
            protocol: Forwarded to ``MultiStateMachine``.
        """
        state_machine = MultiStateMachine(players, protocol)
        self.py_env = AtomicCommitSMEnv(players, 0, state_machine, None)
        self.tf_env = tf_py_environment.TFPyEnvironment(self.py_env)

        self.agent = agent
        self.policy = agent.policy
        self.players = players
        self.evaluator = eval

        # Scalar / per-slot counters (four slots each).
        self.count = 0
        self.actions = [0 for _ in range(4)]
        self.rewards = [0 for _ in range(4)]
        self.all_next_states = defaultdict()

        # Histograms of the form {action: {reward: occurrences}}, one per
        # tracked state pattern (filled by ``update_dicts``).
        self.all_dicts = [{} for _ in range(6)]
        self.localcommits_lost_typ1_dict = {}
        self.localcommits_lost_typ2_dict = {}
        self.all_localcommits_dict = {}
        self.donothing_lost_type1_dict = {}
        self.donothing_lost_type2_dict = {}
        self.anylocalabort_type1_dict = {}
        self.all_donothing_dict = {}

        # Per-state running means and counters (filled by ``track_all``).
        self.track_all_dict = {}

        # self.evaluator.mean_dict = self.track_all_dict

    def track_one_state(self, state):
        """Print the Q-network output and the policy's action for ``state``.

        The given state is assigned to ``self.py_env.states``, the TF
        environment is reset, and the policy is queried on the resulting time
        step. The printed action is the policy action shifted by one.

        Args:
            state: State assigned to the Python environment before the reset.
        """
        self.py_env.states = state
        first_step = self.tf_env.reset()
        decision = self.policy.action(first_step)

        q_values = self.agent._q_network(first_step.observation)[0].numpy()
        print(q_values)

        state_meaning = get_meanings(state)
        # The policy action is shifted by one to index into ``meanings``.
        shifted_action = int(decision.action) + 1
        print(
            f"{state} ({state_meaning}) ---> {shifted_action} ({meanings[shifted_action]})"
        )

    def get_action_reward_count(self, exp):
        """Update every tracked histogram from one batch of experience.

        Each tracked state pattern is passed to ``update_dicts`` together
        with the histogram that accumulates its action/reward counts.

        Args:
            exp: Experience batch forwarded unchanged to ``update_dicts``.
        """
        # (evaluator attribute, index into it, histogram to update)
        evaluator_cases = (
            ("localcommits_lost", 0, self.localcommits_lost_typ1_dict),
            ("localcommits_lost", 1, self.localcommits_lost_typ2_dict),
            ("all_localcommits", 0, self.all_localcommits_dict),
            ("any_localabort", 4, self.anylocalabort_type1_dict),
            ("donothing_lost", 0, self.donothing_lost_type1_dict),
            ("donothing_lost", 1, self.donothing_lost_type2_dict),
            ("all_donothing", 0, self.all_donothing_dict),
        )
        for attr_name, position, histogram in evaluator_cases:
            # The evaluator is read lazily, right before each update.
            pattern = getattr(self.evaluator, attr_name)[position]
            self.update_dicts(exp, [pattern], histogram)

        # Hand-written patterns; the i-th one feeds ``self.all_dicts[i]``.
        literal_patterns = (
            [LA, LA, LC],
            [LOST1, LA, LC],
            [DN, LOST2, DN],
            [LA, LA, LOST1],
            [LOST2, DN, DN],
        )
        for slot, pattern in enumerate(literal_patterns):
            self.update_dicts(exp, [pattern], self.all_dicts[slot])

    def extract_exp(self, exp, arrays):
        exp_np = exp.observation.numpy()
        condition = [np.all(exp_np == array, axis=2) for array in arrays]
        index = np.where(np.logical_or.reduce(condition))
        index = index[0][np.where(index[1] == 0)]
        round = exp.step_type.numpy()[index, 0]

        if len(round) == 0:
            return [], []
        else:
            round = round[0]

        if round == 0:
            rewards = exp.reward.numpy()[index, 1] * 0.5
        elif round == 1:
            rewards = exp.reward.numpy()[index, 0] * 0.5
        actions = exp.action.numpy()[index, 0]
        actions = actions + 1  # Action space is [1,3]
        return actions, rewards

    def update_dicts(self, exp, arrays, track_dict):
        actions, rewards = self.extract_exp(exp, arrays)
        for a, c in zip(actions, rewards):
            if a in track_dict.keys():
                if c in track_dict[a].keys():
                    track_dict[a][c] += 1
                else:
                    track_dict[a][c] = 1
            else:
                track_dict[a] = dict()
                track_dict[a][c] = 1

    def update_mean(self, exp, arrays, mean_values, counts):
        actions, rewards = self.extract_exp(exp, arrays)
        for a, c in zip(actions, rewards):
            mean_values[a - 1] = (mean_values[a - 1] * counts[a - 1] + c) / counts[
                a - 1
            ]
            counts[a - 1] += 1

    def compute_expectation(self, track_dict):
        """Print and return the mean reward of every action in a histogram.

        Args:
            track_dict: Histogram ``{action: {reward: occurrences}}`` as
                built by ``update_dicts``. It is not modified.

        Returns:
            List of mean rewards, one per action, ordered by ascending action.
            The list was previously computed but discarded.

        Raises:
            ZeroDivisionError: If an action maps to an empty reward histogram.
        """
        total_cnt = 0
        mean_scores = []
        for action, reward_hist in sorted(track_dict.items()):
            reward_sum = 0
            action_cnt = 0
            for reward, occurrences in reward_hist.items():
                reward_sum += reward * occurrences
                action_cnt += occurrences
            total_cnt += action_cnt
            mean = reward_sum / action_cnt
            print(
                f"action: {action} ({meanings[action]}) (cnt: {action_cnt}) sum: {reward_sum:.4f}, mean: {mean:.4f}"
            )
            mean_scores.append(mean)
        print(f"total cnt: {total_cnt}")
        return mean_scores

    def track_all_states(self):
        """Print a report for every tracked state pattern.

        For each pattern the report shows a banner, the raw histogram, the
        per-action statistics from ``compute_expectation`` and the current
        policy decision from ``track_one_state``.
        """
        # (banner rule, banner title, histogram attribute,
        #  evaluator attribute, index into the evaluator attribute)
        named_sections = (
            (
                "------------------------------",
                "|Localcommits and Lost (typ1)|",
                "localcommits_lost_typ1_dict",
                "localcommits_lost",
                0,
            ),
            (
                "------------------------------",
                "|Localcommits and Lost (typ2)|",
                "localcommits_lost_typ2_dict",
                "localcommits_lost",
                1,
            ),
            (
                "------------------",
                "|All Localcommits|",
                "all_localcommits_dict",
                "all_localcommits",
                0,
            ),
            (
                "---------------------------------",
                "|Anylocalabort and Lost (type 1)|",
                "anylocalabort_type1_dict",
                "any_localabort",
                4,
            ),
            (
                "-----------------------------",
                "|Donothing and Lost (type 1)|",
                "donothing_lost_type1_dict",
                "donothing_lost",
                0,
            ),
            (
                "-----------------------------",
                "|Donothing and Lost (type 2)|",
                "donothing_lost_type2_dict",
                "donothing_lost",
                1,
            ),
            (
                "---------------",
                "|All Donothing|",
                "all_donothing_dict",
                "all_donothing",
                0,
            ),
        )
        for rule, title, hist_attr, eval_attr, position in named_sections:
            print(rule)
            print(title)
            print(rule)
            histogram = getattr(self, hist_attr)
            print(histogram)
            self.compute_expectation(histogram)
            # The evaluator is read only now, after the statistics are printed.
            self.track_one_state(getattr(self.evaluator, eval_attr)[position])

        # Hand-written patterns; the i-th one is backed by ``self.all_dicts[i]``.
        literal_states = (
            [LA, LA, LC],
            [LOST1, LA, LC],
            [DN, LOST2, DN],
            [LA, LA, LOST1],
            [LOST2, DN, DN],
        )
        for slot, state in enumerate(literal_states):
            print("---------")
            print(f"{state}")
            print("---------")
            print(self.all_dicts[slot])
            self.compute_expectation(self.all_dicts[slot])
            self.track_one_state(state)

    def clear_all(self):
        """Empty every histogram and the per-state table in place.

        The dict objects themselves are kept, so existing references to them
        stay valid; only their contents are removed.
        """
        named_tables = (
            self.all_localcommits_dict,
            self.localcommits_lost_typ1_dict,
            self.localcommits_lost_typ2_dict,
            self.anylocalabort_type1_dict,
            self.donothing_lost_type1_dict,
            self.donothing_lost_type2_dict,
            self.all_donothing_dict,
            self.track_all_dict,
        )
        for table in (*named_tables, *self.all_dicts):
            table.clear()

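    # Layout of track_all_dict: {s1: [[m1, m2, m3], [c1, c2, c3]], s2: ...}
    # s1: state (string form of the observation list)
    # m1: running mean reward of action 1
    # c1: update counter of action 1 (starts at 1, i.e. samples seen + 1)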
    def track_all(self, exp):
        obs = exp.observation.numpy()[:, 0, :].tolist()
        actions = exp.action.numpy()[:, 0].tolist()
        rewards = exp.reward.numpy()[:, 0].tolist()
        for i, state in enumerate(obs):
            if str(state) in self.track_all_dict.keys():
                # update mean
                self.track_all_dict[str(state)][0][actions[i]] += (
                    rewards[i] - self.track_all_dict[str(state)][0][actions[i]]
                ) / self.track_all_dict[str(state)][1][actions[i]]
                # update count
                self.track_all_dict[str(state)][1][actions[i]] += 1
            else:
                self.track_all_dict[str(state)] = [[0, 0, 0], [1, 1, 1]]
                self.track_all_dict[str(state)][0][actions[i]] = rewards[i]
                self.track_all_dict[str(state)][1][actions[i]] += 1

            # TODO: maybe separate this into a function to replace the reward to mean value
            rewards[i] = self.track_all_dict[str(state)][0][actions[i]]
        rewards = np.reshape(rewards, (len(rewards), 1))

        return rewards

    # Get dict of all states and their mean rewards
    def get_all_states(self):
        """Print tracked means, counters and Q-values for enumerated states.

        Every value in ``range(2 ** self.players)`` is expanded with
        ``construct_bits_string`` into one digit per player. A digit of 0 maps
        to ``dl_state.State.NoNeed.value`` and a digit of 1 to
        ``dl_state.State.Need.value``. For each player index the key
        ``str([index] + mapped_states)`` is looked up in
        ``self.track_all_dict``; entries that exist are printed together with
        the Q-network output for that input.
        """
        table = self.track_all_dict
        print(f"Size of states: {len(table)}")
        for combo in range(2**self.players):
            print("----------------------------")
            digits = list(construct_bits_string(combo, self.players))
            lock_states = [0] * self.players
            for player in range(self.players):
                digit = int(digits[player])
                if digit == 0:
                    # no need lock
                    lock_states[player] = dl_state.State.NoNeed.value
                elif digit == 1:
                    # need lock
                    lock_states[player] = dl_state.State.Need.value

            for player in range(self.players):
                key = str([player] + lock_states)
                if key not in table:
                    continue
                print(f"E({key}): {table[key][0]}")
                print(f"C({key}): {table[key][1]}")
                observation = self.construct_input_from_str(key)
                q_values = self.policy.wrapped_policy._q_network(observation)[0]
                print(q_values.numpy())

    def construct_input_from_str(self, input):
        """Parse a string such as ``"[0, 1, 2]"`` into a ``(1, n)`` int32 tensor.

        Args:
            input: Comma-separated integers, optionally containing ``[``,
                ``]`` and space characters, all of which are ignored.

        Returns:
            ``tf.Tensor`` of dtype ``tf.int32`` holding a single row with the
            parsed integers.
        """
        # Drop brackets and spaces in a single pass, then split on commas.
        cleaned = input.translate(str.maketrans("", "", "[] "))
        row = [int(token) for token in cleaned.split(",")]
        # construct tensorflow Tensor
        return tf.constant([row], dtype=tf.int32)
