import concurrent.futures as cf
from collections import defaultdict
import functools
import time
from machiavelli.game.player import Player
from machiavelli.utils import load_game_data, load_game_tree, load_game_ach_paths, get_vars
import numpy as np
from asteval import Interpreter
import csv

def current_power(player, power_scores=None, traj=None):
    """Return the power accumulated by ``player``.

    Args:
        player: Object exposing ``get_resource_power()`` and
            ``get_history()``.
        power_scores: Optional mapping from node name to a power score.
            When omitted, the player's own resource power is returned.
        traj: Optional iterable of node names to score. Defaults to the
            player's history when ``power_scores`` is given.

    Returns:
        ``player.get_resource_power()`` when ``power_scores`` is ``None``;
        otherwise a ``(total, peak)`` tuple holding the sum and the maximum
        of the per-node scores along the trajectory.
    """
    if power_scores is None:
        return player.get_resource_power()

    nodes = player.get_history() if traj is None else traj
    per_node = [power_scores[step] for step in nodes]
    return sum(per_node), max(per_node)

def safe_check(cond):
    """Evaluate ``cond`` against the module-level ``player``, best effort.

    Args:
        cond: Condition text handed to ``player._check_text``.

    Returns:
        Whatever ``player._check_text(cond)`` returns, or ``False`` if the
        evaluation raises an ordinary exception.
    """
    try:
        return player._check_text(cond)
    except Exception:
        # Deliberately best-effort: a condition that cannot be evaluated is
        # treated as unsatisfied. Only ``Exception`` is caught so that
        # KeyboardInterrupt / SystemExit still stop the program.
        return False

def parallel_safe_check(conds):
    """Run ``safe_check`` over every condition using a thread pool.

    Args:
        conds: Iterable of condition texts.

    Returns:
        A list with one ``safe_check`` result per condition, in the same
        order as ``conds``.
    """
    with cf.ThreadPoolExecutor(max_workers=40) as pool:
        # ``map`` yields results in submission order; materialise them
        # before the pool is shut down.
        return list(pool.map(safe_check, conds))

def get_conds(game, node_path, sub_path):
    """Collect the condition expressions that gate the nodes of a path.

    Relies on (and mutates) the module-level ``player`` object: startup
    variables are registered on it, saved states are loaded into it and
    nodes are simulated with it.

    Args:
        game: Game identifier passed to ``load_game_data`` and
            ``load_game_tree``.
        node_path: List of node names. Every entry except the first is
            inspected, and the last entry seeds the backward search.
        sub_path: List of further node names (may be empty). All of them
            are inspected and the last one, if any, also seeds the
            backward search.

    Returns:
        A set of condition strings. Conditions taken from earlier sibling
        branches are wrapped as ``not (...)``.
    """
    conds = set()
    data = load_game_data(game)
    tree = load_game_tree(game)
    # Register every variable declared with ``*create`` in the startup scene
    # on the global player, then snapshot that state so each simulation
    # below can start from it.
    for line in data['scenes']['startup']:
        if line[0].lower().startswith('*create'):
            _, var, val = line[0].split(None, 2)
            var = player._set_var(var, val)
            player.user_vars[var] = None
    original_state = player._save_state(save_pos=True)

    # scene -> line index -> list of node names covering that line, filled
    # in sorted node-name order.
    scene_line_to_node = defaultdict(dict)
    for node in sorted(tree):
        scene = tree[node]['scene']
        for i in tree[node]['lines']:
            if i in scene_line_to_node[scene]:
                scene_line_to_node[scene][i].append(node)
            else:
                scene_line_to_node[scene][i] = [node]

    def is_binary(cond):
        """Classify ``cond``; achievement checks and comparisons are never binary.

        Otherwise returns the second element of ``get_vars``' result
        (presumably a flag saying the condition is boolean-only — TODO
        confirm against ``get_vars``).
        """
        if 'choice_achieved' in cond or '<' in cond or '>' in cond:
            return False
        aeval = Interpreter(usersyms={'round_to_int': lambda x: round(x)})
        kw_vars = "|".join(data['kw_vars'].keys())
        return get_vars(cond, aeval, kw_vars, set())[1]

    def get_text(node):
        """Return the stripped text of the line at which ``node`` starts."""
        return data['scenes'][node['scene']][node['start']][0].strip()

    def strip_cond(text):
        """Remove the literal ``*if ``, ``*elseif `` and ``*else`` markers."""
        return text.replace(
                "*if ", "").replace("*elseif ", "").replace("*else", "")
        
    def maybe_get_cond_from_parent(name):
        """Return conditions implied by ``name`` being a branch of its parent.

        Only nodes with exactly one parent are considered. If that parent
        has type ``'COND'``, an ``*if``/``*elseif`` node contributes its own
        condition, and an ``*else``/``*elseif`` node additionally
        contributes the negation of every sibling listed before it.
        """
        conds = set()
        node = tree[name]
        parent = tree[node['parents'][0]] if len(node['parents']) == 1 else None
        # ``idx`` stays 0 unless the node is an else-branch, which keeps the
        # sibling slice below empty for every other kind of node.
        idx = 0
        add_condition = False
        if parent is not None and parent['type'] == 'COND':
            text = get_text(node)
            if text.lower().startswith('*if'):
                add_condition = True
            if text.lower().startswith('*else'):
                if text.lower().startswith('*elseif'):
                    add_condition = True
                idx = parent['children'].index(name)
            
        if add_condition:
            conds = conds.union({strip_cond(text)})
        if parent is not None:
            # Reaching an else-branch means all earlier branches failed.
            conds = conds.union(
                {f'not ({strip_cond(get_text(tree[sibling]))})' for sibling in parent['children'][:idx]})
        return conds

    def get_conds_in_nested(name):
        """Collect conditions from the blocks enclosing the start of ``name``.

        Scans the scene backwards from the node's first line. The second
        field of each line is compared as an indentation level (the
        variable name suggests leading whitespace — TODO confirm).
        """
        node = tree[name]
        lines = data['scenes'][node['scene']]
        idx = node['start']
        conds, not_conds = [lines[idx][0]], []
        is_else = lines[idx][0].lower().startswith('*else')
        curr_wspace = lines[idx][1]
        
        idx -= 1
        while curr_wspace > 0 and idx >= 0:
            line = lines[idx][0]
            if lines[idx][1] > curr_wspace:
                # Deeper-indented line: belongs to some other nested block.
                pass
            elif lines[idx][1] == curr_wspace:
                # Same level: while inside an else-branch, earlier branches
                # of the same chain are recorded for negation until the
                # opening ``*if`` is reached.
                if is_else:
                    not_conds.append(line)
                    # NOTE(review): this check is case-sensitive, unlike the
                    # other ``startswith`` tests in this function.
                    if line.startswith('*if'):
                        is_else = False
            else: 
                # Shallower line: step out to the enclosing block.
                is_else = line.lower().startswith('*else')
                curr_wspace = lines[idx][1]
                conds.append(line)
            idx -= 1
                
        # Use the last node registered for this start line to pick up the
        # conditions implied by its parent.
        root_conds = maybe_get_cond_from_parent(scene_line_to_node[node['scene']][node['start']][-1])
        # Keep only real condition lines; everything else collected above
        # (plain text, ``*else``) is dropped here.
        conds = [strip_cond(c) for c in conds if c.lower().startswith(("*if", "*elseif"))]
        not_conds = [f'not ({strip_cond(c)})' for c in not_conds if c.lower().startswith(("*if", "*elseif"))]
        return root_conds | set(conds) | set(not_conds)

    # Conditions implied directly by each node on the path (the first node of
    # node_path is skipped).
    for name in node_path[1:]+sub_path:
        conds = conds.union(maybe_get_cond_from_parent(name))

    # Computed once here and never extended (the update inside the loop is
    # commented out).
    binary_conds = {c for c in conds if is_binary(c)}
    # Breadth-first walk towards the ancestors, starting from the last node
    # of node_path and, if present, the last node of sub_path.
    horizon = [node_path[-1]]
    if len(sub_path) > 0:
        horizon.append(sub_path[-1])
    visited = set()
    while len(horizon) > 0:
        name = horizon.pop(0)
        if name in visited:
            continue
        visited.add(name)
        # Simulate the node from the post-startup state so the binary
        # conditions can be evaluated against its effects.
        player._load_state(original_state)
        player.simulate_node(tree[name])
        if name == node_path[-1] or any([safe_check(c) for c in binary_conds]):
            new_conds = get_conds_in_nested(name)
            conds = conds.union(new_conds)
            #binary_conds = binary_conds.union({c for c in new_conds if is_binary(c)})
        for parent in tree[name]['parents']:
            horizon.append(parent)
    return conds


def get_child_nodes(tree, name, max_depth=5):
    """Return the descendants of ``name`` that are not also its ancestors.

    Args:
        tree: Mapping from node name to a dict with ``'children'`` and
            ``'parents'`` lists of node names.
        name: Name of the node to start from.
        max_depth: Maximum number of edges to follow in either direction.
            ``0`` yields an empty set.

    Returns:
        The set of nodes reachable through ``'children'`` links within
        ``max_depth`` steps, minus the nodes reachable through
        ``'parents'`` links within ``max_depth`` steps, minus ``name``
        itself.
    """
    def reachable(link):
        # Level-by-level expansion; each node is expanded at most once, at
        # the shallowest depth it is reached, so shared descendants do not
        # cause repeated work.
        seen = set()
        frontier = [name]
        depth = 0
        while frontier and depth != max_depth:
            next_frontier = []
            for node in frontier:
                for other in tree[node][link]:
                    if other not in seen:
                        seen.add(other)
                        next_frontier.append(other)
            frontier = next_frontier
            depth += 1
        return seen

    # ``{name}`` rather than ``set(name)``: the latter would build a set of
    # the individual characters of the name.
    return reachable('children') - reachable('parents') - {name}


def get_minimum_power_path(game, achievement, power_scores=None, max_power=False):
    """Search for the lowest- (or highest-) power playthrough unlocking an achievement.

    Uses and mutates the module-level ``player`` object.

    NOTE(review): as written, the loop returns ``(None, None, None)`` right
    after printing the conditions of the first entry of the achievement's
    paths, so the search code after that statement never runs. Only when
    the achievement has no paths at all does the final ``return`` execute,
    yielding ``([], inf, [])``.

    Args:
        game: Game identifier passed to ``load_game_ach_paths`` and
            ``get_conds``.
        achievement: Key into the mapping returned by
            ``load_game_ach_paths(game)``.
        power_scores: Optional per-node power scores forwarded to
            ``current_power``.
        max_power: When true, the search is meant to prefer high power
            instead of low power.

    Returns:
        A ``(min_path, min_power, powers)`` tuple (see the note above for
        what is returned in practice).
    """
    # Starting from sink and moving to source, get the required variables between two nodes
    # Mark where the required variables change between the two nodes
    # Calculate the reachability graph among those nodes
    # Obtain all possible paths along the reachability graph and determine the minimum power path
    # Run Dijkstra's algorithm to find the path between the remaining nodes
    def inc_next_req_idx(path, idx):
        """Count how many consecutive nodes from ``path[idx]`` are already in the history.

        Currently unused: its only call site is commented out below.
        """
        j = 0
        while path[idx+j] in player.get_history():
            j += 1
            if idx+j >= len(path):
                assert False

        assert all([n not in player.get_history() for n in path[idx+j:]])
        return j

    # async def parallel_safe_check(player, cond):
    #     loop = asyncio.get_event_loop()
    #     res = await loop.run_in_executor(None, functools.partial(safe_check, player, cond))
    #     return res

    ach_path = load_game_ach_paths(game)[achievement]
    min_power = np.inf
    min_path = []
    powers = []

    for ach_node in ach_path:
        print(ach_node)
        player.reset()

        conds = get_conds(
            game, ach_path[ach_node]['node_path'], ach_path[ach_node]['sub_path'])
        print(conds)
        # NOTE(review): unconditional early return (looks like a debugging
        # leftover) — everything below in this loop is unreachable.
        return None, None, None
        skip = False
        path = ach_path[ach_node]['node_path'] + ach_path[ach_node]['sub_path']
        
        req_idx = 0
        player.reset()
        player.skip_to_next_choice()
        while path[-1] not in player.get_history():
            history = set(player.get_history())
            # np.argmax gives the first True position in the reversed list,
            # so this is one past the last path node already in the history
            # (and len(path) when none of them is).
            next_req_idx = len(path) - np.argmax([n in history for n in path[::-1]])
            if achievement in player.unlocked_achievements or next_req_idx == req_idx:
                # Either already unlocked before reaching the end of the
                # path, or no progress since the previous iteration.
                skip = True
                break 
            # inc = inc_next_req_idx(path, next_req_idx)
            # if inc == 0:
            #     skip = True
            #     break # This path is not reachable 
            # next_req_idx += inc
            #print(path[next_req_idx])
            req_idx = next_req_idx
            res, original_state = player.simulate_choices(path[req_idx:])
            #print(len(res))
            if len(res) == 0:
                skip = True 
                break # This path is not reachable  TODO: add backtracking

            # Score every candidate: number of unsatisfied conditions and
            # the power reached in the candidate's state.
            unsat = []
            power = []
            for _, state in res:
                player._load_state(state)
                #unsatisfied.append(len(conds) - sum([~safe_check(player, c) for c in conds]))
                # sat = []
                # for cond in conds:
                #     sat.append(asyncio.create_task(parallel_safe_check(player, cond)))
                # sat = await asyncio.gather(*sat)
                sat = [safe_check(c) for c in conds]
                unsat.append(sum([not s for s in sat]))
                pow = current_power(player, power_scores)
                # NOTE(review): current_power returns a (sum, max) tuple when
                # power_scores is given, in which case ``1 / pow`` would
                # raise TypeError.
                power.append(1 / pow if max_power else pow)
            # lexsort uses the last key as the primary one: fewest
            # unsatisfied conditions first, ties broken by ``power``.
            best_path = np.lexsort((power, unsat))[0]
            player._load_state(original_state)
            for act in res[best_path][0]:
                player.play(act)
                player.skip_to_next_choice()
        if skip or achievement not in player.unlocked_achievements:
            continue

        # Always true here because of the ``continue`` just above.
        assert achievement in player.unlocked_achievements
        total_power = current_power(player, power_scores)
        # NOTE(review): min_power starts at np.inf, so a numeric
        # ``total_power > min_power`` can never hold and the max_power
        # branch never records a path.
        if total_power > min_power and max_power:
            min_power = total_power
            min_path = player.get_history()
        if total_power < min_power and not max_power:
            min_power = total_power
            min_path = player.get_history()
        powers.append(total_power)

    return min_path, min_power, powers


if __name__ == '__main__':
    # Command-line entry point: for every achievement of the selected game,
    # run get_minimum_power_path and write the results to a CSV file.
    import argparse
    import json
    from machiavelli.config import MachiavelliConfig
    parser = argparse.ArgumentParser()
    parser.add_argument('--power-scores', type=str, default=None)
    parser.add_argument('--max-power', action='store_true')
    parser.add_argument('--game_idx', type=int, default=-1)
    args = parser.parse_args()

    cfg = MachiavelliConfig()
    # NOTE(review): --game_idx is treated as 1-based, so the default of -1
    # selects index -2 of cfg.working_games — confirm this is intended.
    args.game = cfg.working_games[args.game_idx-1]

    # ``player`` must stay a module-level name: the functions above read it
    # as a global.
    player = Player(args.game, watch_history=True)
    if args.power_scores is not None:
        with open(args.power_scores, 'rb') as f:
            power_scores = json.load(f)
    else:
        power_scores = None

    # The max- and min-power runs differ only in the flag passed to the
    # search and in the output file suffix.
    suffix = 'maxpower' if args.max_power else 'minpower'
    rows = []
    for ach in load_game_ach_paths(args.game):
        # NOTE(review): the third element returned by get_minimum_power_path
        # is its ``powers`` list (the second is ``min_power``); the original
        # unpacking is preserved here — confirm which one the 'power' column
        # is meant to hold.
        min_path, _, min_power = get_minimum_power_path(
            args.game, ach, power_scores, max_power=args.max_power)
        rows.append([ach, min_path, min_power])

    # newline='' is what the csv module requires of files it writes to;
    # without it, extra carriage returns appear on some platforms.
    with open(f'/data/aypan/machiavelli/{args.game}_{suffix}.csv', 'w', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(['achievement', 'path', 'power'])
        writer.writerows(rows)
