import itertools

from tf_agents.environments import tf_py_environment

from tfagents.atomic_commit.AtomicCommitSMEnv import AtomicCommitSMEnv
from MultiSM import MultiStateMachine
from checkpoint import restore_policy
from atomic_commit.State import State, meanings
from config import IS_ENCODED_HISTORY
from utils import parse_arg, action_map

import numpy as np

class evaluator():
    """Check a trained policy against the expected action for groups of joint states.

    The constructor enumerates groups of joint states (lists holding one
    ``State`` value per player) together with every ordering of them.
    :meth:`evaluate` feeds each ordering to the policy and compares the
    chosen action with the action expected for that group.

    Args:
        players: number of participants; every joint state has this length.
        agent: optional agent; when given, its ``_q_network`` output is
            printed for every evaluated state.
    """

    def __init__(self, players, agent=None):
        # label -> list of joint states; used by the caller to count combinations.
        self.all_combs = {}
        self.all_localcommits = all_localcommits(players)
        print(f"all local commits: {self.all_localcommits}")
        self.all_localcommits_allcomb = get_all_combs(self.all_localcommits)
        self.all_combs["all_localcommits"] = self.all_localcommits_allcomb
        print(f"all local commits (all comb): {self.all_localcommits_allcomb}")
        self.localcommits_lost = localcommits_lost(players)
        print(f"localcommits + lost: {self.localcommits_lost}")
        self.localcommits_lost_allcomb = get_all_combs(self.localcommits_lost)
        self.all_combs["localcommits_lost"] = self.localcommits_lost_allcomb
        print(f"localcommits + lost (all comb): {self.localcommits_lost_allcomb}")
        # The typ1/typ2 subsets only exist for large enough player counts;
        # _combs_at yields [] instead of raising IndexError when they are absent.
        self.localcommits_lost_typ1 = self._combs_at(self.localcommits_lost, 0)
        print(f"localcommits + lost (type1) (all comb): {self.localcommits_lost_typ1}")
        self.localcommits_lost_typ2 = self._combs_at(self.localcommits_lost, 1)
        print(f"localcommits + lost (type2) (all comb): {self.localcommits_lost_typ2}")
        self.any_localabort = any_localabort(players)
        print(f"any localabort: {self.any_localabort}")
        self.any_localabort_allcomb = get_all_combs(self.any_localabort)
        self.all_combs["any_localabort"] = self.any_localabort_allcomb
        print(f"any localabort (all comb): {self.any_localabort_allcomb}")
        self.any_localabort_type1_allcomb = self._combs_at(self.any_localabort, 4)
        print(f"any localabort (type1) (all comb): {self.any_localabort_type1_allcomb}")
        self.any_abort, self.onlyabort_lost = any_abort(players)
        print(f"any abort (except only abort and lost): {self.any_abort}")
        print(f"only abort and lost: {self.onlyabort_lost}")
        self.any_abort_allcomb = get_all_combs(self.any_abort)
        self.all_combs["any_abort"] = self.any_abort_allcomb
        self.onlyabort_lost_allcomb = get_all_combs(self.onlyabort_lost)
        self.all_combs["onlyabort_lost"] = self.onlyabort_lost_allcomb
        print(f"any abort (all comb): {self.any_abort_allcomb}")
        print(f"only abort and lost (all comb): {self.onlyabort_lost_allcomb}")
        self.all_donothing = all_donothing(players)
        print(f"all do nothing: {self.all_donothing}")
        self.all_donothing_allcomb = get_all_combs(self.all_donothing)
        self.all_combs["all_donothing"] = self.all_donothing_allcomb
        print(f"all do nothing (all comb): {self.all_donothing_allcomb}")
        self.donothing_lost = donothing_lost(players)
        print(f"do nothing + lost: {self.donothing_lost}")
        self.donothing_lost_allcomb = get_all_combs(self.donothing_lost)
        self.all_combs["donothing_lost"] = self.donothing_lost_allcomb
        print(f"do nothing + lost (all comb): {self.donothing_lost_allcomb}")
        self.donothing_lost_typ1_allcomb = self._combs_at(self.donothing_lost, 0)
        self.donothing_lost_typ2_allcomb = self._combs_at(self.donothing_lost, 1)
        print(f"do nothing + lost (type1) (all comb): {self.donothing_lost_typ1_allcomb}")
        print(f"do nothing + lost (type2) (all comb): {self.donothing_lost_typ2_allcomb}")
        self.any_commit, self.onlycommit_lost = any_commit(players)
        print(f"any commit (except only commit and lost): {self.any_commit}")
        print(f"only commit and lost: {self.onlycommit_lost}")
        self.any_commit_allcomb = get_all_combs(self.any_commit)
        self.all_combs["any_commit"] = self.any_commit_allcomb
        self.onlycommit_lost_allcomb = get_all_combs(self.onlycommit_lost)
        self.all_combs["onlycommit_lost"] = self.onlycommit_lost_allcomb
        print(f"any commit (all comb): {self.any_commit_allcomb}")
        print(f"only commit and lost (all comb): {self.onlycommit_lost_allcomb}")
        self.solution = None
        self.players = players
        self.agent = agent
        # str(comb) -> entry read as [0] (a list whose argmax is compared with
        # the Q-network's) and [1]; presumably filled in by the caller.
        self.mean_dict = {}

    @staticmethod
    def _combs_at(groups, index):
        """Return every ordering of ``groups[index]``, or ``[]`` if that group does not exist.

        ``groups[index]`` only exists for sufficiently many players (e.g.
        ``localcommits_lost`` has ``players - 1`` entries), so indexing
        unconditionally made the constructor fail for small player counts.
        """
        if index >= len(groups):
            return []
        return get_all_combs([groups[index]])

    def evaluate(self, env, tf_env, policy):
        """Run every expected-action check against ``policy``.

        Args:
            env: Python environment whose ``states`` attribute is assigned
                before each reset.
            tf_env: TF wrapper around ``env`` used to obtain time steps.
            policy: policy queried via ``policy.action(time_step)``.

        Returns:
            True only if every check passed.
        """
        # correct=None: this group is only reported, it can never fail.
        self.eval_comb(env, tf_env, policy, self.all_localcommits, "all local commits", None)

        # Probe the all-LocalCommit state once more and print the raw Q-values.
        env.states = [State.LocalCommit.value] * self.players
        time_step = tf_env.reset()
        # The chosen action is no longer inspected (automatic solution detection
        # is disabled), but the query is kept so the policy sees the same calls.
        policy.action(time_step)
        if self.agent:
            print(f"policy: {self.agent._q_network(time_step.observation)[0].numpy()}")
        # NOTE: solution detection (DoNothingCommit -> 1, Commit -> 2) is
        # disabled; solution 2 is always assumed.
        self.solution = 2

        abort = State.Abort.value
        # (combinations, label, expected action, DoNothingCommit also accepted)
        checks = [
            (self.localcommits_lost_allcomb, "local commits + lost", abort, True),
            (self.any_localabort_allcomb, "any local aborts", abort, True),
            (self.any_abort_allcomb, "any abort (except only abort and lost)", abort, True),
            (self.onlyabort_lost_allcomb, "only abort and lost", abort, True),
            (self.all_donothing_allcomb, "all do nothing", abort, False),
            (self.donothing_lost_allcomb, "do nothing + lost", State.DoNothingCommit.value, False),
        ]
        if self.solution == 2:
            commit = State.Commit.value
            checks += [
                (self.onlycommit_lost_allcomb, "only commit and lost", commit, True),
                (self.any_commit_allcomb, "any commit (except only commit and lost)", commit, True),
            ]

        ret = True
        for setting, label, correct, multiple in checks:
            # Keep going after a failure so every wrong answer gets reported.
            if not self.eval_comb(env, tf_env, policy, setting, label, correct,
                                  is_multiple_answer=multiple):
                ret = False
        return ret

    def eval_comb(self, env, tf_env, policy, setting, name, correct, is_multiple_answer=False):
        """Query ``policy`` on every joint state in ``setting`` and check its action.

        Args:
            env: environment whose ``states`` is set to each comb before reset
                (assumes the environment picks ``states`` up on reset — confirm).
            tf_env: TF wrapper of ``env``.
            policy: policy queried with ``policy.action``.
            setting: list of joint states to evaluate.
            name: label printed before the group.
            correct: expected action value; ``None`` accepts any action
                (only honoured when ``is_multiple_answer`` is False).
            is_multiple_answer: if True, ``State.DoNothingCommit.value`` is
                accepted in addition to ``correct``.

        Returns:
            True if every comb produced an accepted action.
        """
        print(name)
        if IS_ENCODED_HISTORY:
            setting = add_local_state(setting)
        ret = True
        for comb in setting:
            env.states = comb
            time_step = tf_env.reset()
            action_step = policy.action(time_step)
            action_value = action_map(int(action_step.action))

            # Stays None without an agent, so there is nothing to compare below.
            policy_max_index = None
            if self.agent:
                policy_output = self.agent._q_network(time_step.observation)[0].numpy()
                print(f"policy: {policy_output}")
                policy_max_index = np.argmax(policy_output)

            key = str(comb)
            if key in self.mean_dict:
                print(f"E({key}): {self.mean_dict[key][0]}")
                print(f"C({key}): {self.mean_dict[key][1]}")
                mean_max_index = self.mean_dict[key][0].index(max(self.mean_dict[key][0]))
                if policy_max_index is not None and policy_max_index != mean_max_index:
                    print("training result doesn't match with mean value!")
            else:
                print(f"Didn't see {key} before")

            # Compare values with != (equality), never `is not` (identity).
            if is_multiple_answer:
                wrong = action_value != correct and action_value != State.DoNothingCommit.value
            else:
                wrong = correct is not None and action_value != correct

            if wrong:
                print(f"wrong answer of {comb} ({get_meanings(comb)}) ---> "
                      f"{action_value} ({meanings[action_value]})")
                print(f"correct answer should be: {correct} ({meanings[correct]})")
                ret = False
            else:
                print(f"Correct: {comb} ---> {action_value} ({meanings[action_value]})")
        return ret

def add_local_state(combs):
    """Expand each joint state with every possible encoded-history prefix.

    For each ``comb`` and each distinct value ``l`` in it that is neither
    ``State.Lost_R1`` nor ``State.Lost_R2``, emit a new list ``[l] + comb``.
    ``eval_comb`` applies this itself when ``IS_ENCODED_HISTORY`` is set.
    """
    lost_states = (int(State.Lost_R1.value), int(State.Lost_R2.value))
    expanded = []
    for comb in combs:
        for own_state in set(comb):
            if own_state in lost_states:
                continue
            expanded.append([own_state] + comb)
    return expanded

def all_localcommits(players):
    """Return a one-element list holding the state where every player is LocalCommit."""
    everyone_local_commit = [State.LocalCommit.value for _ in range(players)]
    return [everyone_local_commit]

def localcommits_lost(players):
    """Return joint states made of LocalCommit players followed by Lost_R1 players.

    One state per split ``k = 1 .. players - 1``: ``k`` LocalCommit entries,
    then ``players - k`` Lost_R1 entries.
    """
    local_commit = State.LocalCommit.value
    lost = State.Lost_R1.value
    return [[local_commit] * k + [lost] * (players - k) for k in range(1, players)]

def any_localabort(players):
    """Return joint states that contain at least one LocalAbort player.

    Each state is ``a`` LocalAbort entries, then ``l`` Lost_R1 entries, then
    LocalCommit for the remaining players, for every ``a >= 1``, ``l >= 0``
    with ``a + l <= players`` (ordered by ``a``, then ``l``).
    """
    local_abort = State.LocalAbort.value
    lost = State.Lost_R1.value
    local_commit = State.LocalCommit.value
    return [
        [local_abort] * n_abort
        + [lost] * n_lost
        + [local_commit] * (players - n_abort - n_lost)
        for n_abort in range(1, players + 1)
        for n_lost in range(players - n_abort + 1)
    ]

def any_abort(players):
    """Enumerate joint states containing Abort and split them into two groups.

    Candidates are ``a >= 1`` Abort entries, then ``r`` Lost_R2 entries, then
    DoNothingCommit for the rest of the players. Candidates consisting only
    of Abort/lost values (this includes the all-Abort state) go into the
    second list; every other candidate goes into the first.

    Returns:
        ``(combinations, onlyabort_lost)``
    """
    mixed = []
    abort_and_lost_only = []
    for n_abort in range(1, players + 1):
        for n_lost in range(players - n_abort + 1):
            n_idle = players - n_abort - n_lost
            comb = ([State.Abort.value] * n_abort
                    + [State.Lost_R2.value] * n_lost
                    + [State.DoNothingCommit.value] * n_idle)
            only_abort_or_lost = isonly_abort_lost(comb) or isall_abort(comb)
            (abort_and_lost_only if only_abort_or_lost else mixed).append(comb)
    return mixed, abort_and_lost_only


def all_donothing(players):
    """Return a one-element list holding the state where every player is DoNothingCommit."""
    everyone_idle = [State.DoNothingCommit.value for _ in range(players)]
    return [everyone_idle]

def donothing_lost(players):
    """Return joint states made of DoNothingCommit players followed by Lost_R2 players.

    One state per split ``k = 1 .. players - 1``: ``k`` DoNothingCommit
    entries, then ``players - k`` Lost_R2 entries.
    """
    idle = State.DoNothingCommit.value
    lost = State.Lost_R2.value
    return [[idle] * k + [lost] * (players - k) for k in range(1, players)]


def any_commit(players):
    """Enumerate joint states containing Commit and split them into two groups.

    Candidates are ``c >= 1`` Commit entries, then ``r`` Lost_R2 entries, then
    DoNothingCommit for the rest of the players. This covers commits + lost,
    commits + do nothing, commits + lost + do nothing, and all commits.
    Candidates consisting only of Commit/lost values (commits + lost, and
    all commits) go into the second list; the rest go into the first.

    Returns:
        ``(combinations, onlycommit_lost)``
    """
    mixed = []
    commit_and_lost_only = []
    for n_commit in range(1, players + 1):
        for n_lost in range(players - n_commit + 1):
            n_idle = players - n_commit - n_lost
            comb = ([State.Commit.value] * n_commit
                    + [State.Lost_R2.value] * n_lost
                    + [State.DoNothingCommit.value] * n_idle)
            only_commit_or_lost = isonly_commit_lost(comb) or isall_commit(comb)
            (commit_and_lost_only if only_commit_or_lost else mixed).append(comb)
    return mixed, commit_and_lost_only

# In the two cases below, doing nothing is also acceptable: this node has
# already committed/aborted and all other nodes have crashed.
def isonly_abort_lost(arrays):
    """Return True if every entry is Abort, Lost_R1 or Lost_R2 (True for empty input)."""
    allowed = [State.Abort.value, State.Lost_R1.value, State.Lost_R2.value]
    return all(ele in allowed for ele in arrays)

def isonly_commit_lost(arrays):
    """Return True if every entry is Commit, Lost_R1 or Lost_R2 (True for empty input)."""
    allowed = [State.Commit.value, State.Lost_R1.value, State.Lost_R2.value]
    return all(ele in allowed for ele in arrays)

def isall_abort(array):
    """Return True if ``array`` is non-empty and every entry equals Abort."""
    distinct = set(array)
    return distinct == {State.Abort.value}

def isall_commit(array):
    """Return True if ``array`` is non-empty and every entry equals Commit."""
    distinct = set(array)
    return distinct == {State.Commit.value}

def get_meanings(array):
    """Translate each state value in ``array`` into its entry in ``meanings``."""
    return list(map(meanings.__getitem__, array))

def get_all_combs(array):
    """Return every distinct ordering of each group in ``array``.

    ``array`` is a list of sequences. For each one, all unique permutations
    (duplicates caused by repeated elements removed) are emitted as lists,
    group by group in input order.
    """
    return [
        list(ordering)
        for group in array
        for ordering in set(itertools.permutations(group))
    ]


if __name__ == "__main__":
    import sys

    opt = parse_arg()
    print(opt)
    players = opt.players

    # Fail fast with a non-zero exit status: bare exit() reported success
    # (status 0) and relies on the interactive-only `site` helper.
    if not opt.load_dir:
        print("ERROR! no load dir provided")
        sys.exit(1)

    # One environment per player, all sharing the same multi-state machine.
    multi_sm = MultiStateMachine(players)
    envs = []
    tf_envs = []
    for i in range(players):
        env = AtomicCommitSMEnv(players, i, multi_sm)
        envs.append(env)
        tf_envs.append(tf_py_environment.TFPyEnvironment(env))

    print(f"LOAD policy from {opt.load_dir}")
    policy = restore_policy(opt.load_dir)

    evl = evaluator(players)
    length = sum(len(items) for items in evl.all_combs.values())
    print(f"There are {length} combinations in total")
    # Only player 0's environment is used for the evaluation.
    ret = evl.evaluate(envs[0], tf_envs[0], policy)
    if ret:
        print("PASS ALL!")