import json
import re
import sys
import time
import pandas as pd
from collections import defaultdict
from pathlib import Path
import argparse
from asteval import Interpreter

from machiavelli.utils import get_nested_group, get_conditional_group, load_game_data, load_game_tree, split_into_params
from machiavelli.config import MachiavelliConfig


class Node(object):
    '''One node of the game tree, covering a run of line indices in a scene.

    Attributes:
        scene: name of the scene the node belongs to.
        name: unique node identifier (used as the key in the tree mapping).
        start: index of the first line of the node.
        end: index of the last line, or None until set_end() is called.
        lines: line indices attached to the node; starts as [start].
        type: node category, filled in later by the tree builder.
        children / parents: names of adjacent nodes.
        achievements: achievement names recorded for this node.
        subs: names of subroutine root nodes called from this node.
        subroutine: whether the node is part of a subroutine.
        root: whether the node is a traversal root.
    '''

    def __init__(self, scene, idx, name, root=False, subroutine=False):
        # NOTE: toJson() exposes the instance dict directly, so the order of
        # these assignments is also the key order of the serialized node.
        self.scene = scene
        self.name = name
        self.start = idx
        self.end = None
        self.lines = [idx]
        self.type = None
        self.children = []
        self.parents = []
        self.achievements = []
        self.subs = []
        self.subroutine = subroutine
        self.root = root

    def set_end(self, idx):
        '''Fix the last line of the node; calls after the first are ignored.'''
        if self.end is not None:
            return
        self.end = idx
        assert self.start <= self.end
        assert self.lines[-1] == self.start
        # Attach every line after the start line, up to and including the end.
        self.lines += range(self.start + 1, self.end + 1)

    def add_child(self, child):
        '''Record the name of a child node.'''
        self.children.append(child)

    def add_parent(self, parent):
        '''Record the name of a parent node.'''
        self.parents.append(parent)

    def add_ach(self, achievement):
        '''Record an achievement name for this node.'''
        self.achievements.append(achievement)

    def add_sub(self, sub):
        '''Record the name of a subroutine root called from this node.'''
        self.subs.append(sub)

    def toJson(self):
        '''Return the live attribute dictionary of the node (not a copy).'''
        return vars(self)


def inc_idx(lines, idx):
    '''Work out how far to advance from line idx to reach the next line.

    Each entry of lines is indexed as (text, indentation).

    Returns a pair (step, make_new_node): step is the offset to add to idx,
    and make_new_node is True when the landing line is less indented than the
    current one or starts with "*label".
    '''
    if lines[idx+1][1] > lines[idx][1]:
        # The next line is nested deeper: step straight into it.
        return 1, lines[idx+1][0].startswith("*label")

    step = 1
    # Jump over whole nested groups in three cases:
    # 1. the next line is an *if block that contains a choice
    # 2. the next line is a choice
    # 3. the next line belongs to a conditional (else-block)
    while is_choice(lines, idx+step) or lines[idx+step][0].startswith('*else'):
        group, _ = get_nested_group(lines[idx+step:])
        step += len(group)

    landing = lines[idx+step]
    assert landing[1] <= lines[idx][1]
    dedented = landing[1] < lines[idx][1]
    return step, dedented or landing[0].startswith("*label")


def is_choice(lines, idx):
    '''Tell whether the line at idx opens a choice option.

    Leading modifier lines (*hide_reuse, *disable_reuse, *allow_reuse, *if,
    *else) are walked through: if one of them carries a '#' outside of a
    {var#length} expression the answer is True; otherwise the first
    non-modifier line decides, by whether it starts with '#'.
    '''
    modifiers = (
        "*hide_reuse",
        "*disable_reuse",
        "*allow_reuse",
        "*if",
        "*else",
    )
    pos = idx
    while True:
        text = lines[pos][0]
        if not text.startswith(modifiers):
            return text.startswith('#')
        # Drop possible ${var#length} expressions before looking for '#'
        if '#' in re.sub(r'\{[^\}\{]+#[^\}\{]+\}', '', text):
            return True
        pos += 1


def choices_in_body(lines):
    '''Return the positions (within lines) of the options of a choice body.

    lines is a sequence of (text, indentation) pairs. Leading modifier and
    comment lines are skipped to find the first option; every option found
    at that same indentation, until the body dedents, is reported.
    '''
    prefixes = (
        "*hide_reuse",
        "*disable_reuse",
        "*allow_reuse",
        "*if",
        "*else",
        "*comment",
    )

    def line_is_choice(text):
        if text.startswith('*comment'):
            return False
        # Drop possible ${var#length} expressions before looking for '#'
        return '#' in re.sub(r'\{[^\}\{]+#[^\}\{]+\}', '', text)

    total = len(lines)
    _, base_indent = lines[0]

    # Phase 1: walk over leading modifiers until the first option shows up.
    pos = 0
    while pos < total:
        text, indent = lines[pos][0], lines[pos][1]
        if indent < base_indent or not text.startswith(prefixes):
            break
        if line_is_choice(text):
            break
        pos += 1

    if pos >= total or not line_is_choice(lines[pos][0]):
        return []

    # Phase 2: collect every option sitting at the first option's indentation.
    option_indent = lines[pos][1]
    locs = []
    for k in range(pos, total):
        indent = lines[k][1]
        if indent < option_indent:
            break
        if indent == option_indent and line_is_choice(lines[k][0]):
            locs.append(k)
    return locs


def build_tree(game: str, game_tree_dir: Path, save=False) -> None:
    '''Build game tree from file with DFS

    The scenes of the game are loaded with load_game_data(), normalized
    (*selectable_if -> *if, *fake_choice -> *choice, *delay_ending -> *ending,
    everything lower-cased) and then walked starting from line 0 of the
    'startup' scene. Work items live on a stack, so the walk is depth-first.
    Every visited position becomes a Node registered under the name
    f'{scene}_{idx}' (with a '_collapsed' suffix for the special case handled
    in make_child).

    Args:
        game: game identifier passed to load_game_data(); also used to select
            the per-game hard-coded jump tables below.
        game_tree_dir: directory that receives f'{game}_game_tree.json'.
        save: when True, achievements / node types / subroutine flags are
            filled in and the tree is written to disk as JSON. When False the
            tree is only built (and the visit time printed).

    Timing information is printed to stdout. Structural assumptions about the
    scripts are checked with assert statements.
    '''
    name_to_node = {}
    data = load_game_data(game)
    scene_list = [s.lower() for s in data['scene_list']]
    scenes = data['scenes']
    # Normalize command aliases so that the walk only has to know one spelling.
    for scene in scenes:
        for i, line in enumerate(scenes[scene]):
            scenes[scene][i] = line[0].replace(
                "*selectable_if", "*if").replace(
                    "*fake_choice", "*choice").replace(
                        "*delay_ending", "*ending").lower(), line[1]
    label_to_lnum = data['label_to_linenum']

    def make_child(parent_name, scene, idx, collapsed_if=False, init_lines=None):
        '''Fetch or create the node at (scene, idx) and link it to its parent.

        Links are stored by name on both sides. If the child would carry the
        same name as the parent, it is registered as '<name>_collapsed'
        instead (only legal when collapsed_if is set). init_lines, when
        given, is prepended to the child's line list.
        '''
        parent = name_to_node[parent_name]
        if parent_name == f'{scene}_{idx}':
            assert collapsed_if
            child_name = f'{scene}_{idx}_collapsed'
        else:
            child_name = f'{scene}_{idx}'
        if child_name in name_to_node:
            child = name_to_node[child_name]
        else:
            child = Node(scene, idx, child_name)
        if child_name not in parent.children:
            parent.add_child(child_name)
        if parent_name not in child.parents:
            child.add_parent(parent_name)
        if init_lines is not None:
            child.lines = init_lines + child.lines
        return child

    s1 = time.time()
    # Work items are (scene, line index to resume at, node, debug tag).
    process = [('startup', 0, Node('startup', 0, 'startup_0', root=True), 0)]
    while len(process) > 0:
        # pop() takes the most recently pushed item: depth-first order.
        scene, idx, node, debug = process.pop()
        name = node.name
        if name in name_to_node:
            continue
        name_to_node[name] = node

        # For if and choice and label, the current line is a start of a new node
        # For the goto, the next line is the start of a new node
        # For the inc_idx, the next line is the start of a new node (used to jump out of a conditional)
        i = 0
        lines = scenes[scene]
        line = lines[idx+i][0]
        skip = False
        # Only reachable when processing resumes past the node's own start
        # line (e.g. the "choice" items pushed below): the node then ends
        # just before the label and the label starts a child node.
        if line.startswith("*label") and node.start <= idx-1:
            node.set_end(idx-1)
            child = make_child(name, scene, idx)
            process.append((scene, idx, child, "label"))
            continue
        keywords = ("*goto", "*finish", "*if", "*choice",
                    "*return", "*ending", "*abort")
        # Advance line by line until a control-flow keyword ends the node.
        while not line.startswith(keywords):
            if line.startswith("*gosub"):
                split = line.split()
                if line.startswith("*gosub_scene"):
                    # NOTE(review): the subroutine is entered at line 0 of the
                    # target scene; any label argument is not used here.
                    sc = split[1]
                    go = 0
                else:
                    sc = scene
                    go = label_to_lnum[f'{sc}_{split[1]}']
                if scenes[sc][go] == ("*gosub mainstat_check", 0):
                    assert False
                # Subroutines are walked as separate roots and referenced
                # from the calling node by name.
                process.append(
                    (sc, go, Node(sc, go, f'{sc}_{go}', root=True, subroutine=True), "gosub"))
                node.add_sub(f'{sc}_{go}')
            inc, make_new_node = inc_idx(lines, idx+i)
            if make_new_node and not line.startswith(keywords):
                node.set_end(idx+i)
                child = make_child(name, scene, idx+i+inc)
                process.append((scene, idx+i+inc, child, "line"))
                skip = True
                break
            else:
                i += inc
            line = lines[idx+i][0]
        if line.startswith("*if"):
            node.set_end(idx+i)
            idx += i
            clauses, _ = get_conditional_group(lines[idx:])
            assert len(clauses) > 0
            has_else = False
            # One child per clause of the conditional.
            while len(clauses) > 0:
                clause, clauses = get_nested_group(clauses)
                assert len(clause) > 1
                assert not is_choice(lines, idx)
                assert clause[0][0].startswith(("*if", "*else"))
                if clause[0][0] == "*else":
                    has_else = True
                # It's possible that we might start out a node with an if,
                # so we need to rerun the algo from there
                child = make_child(name, scene, idx, collapsed_if=(
                    i == 0 and clause[0][0].startswith("*if")))
                process.append((scene, idx+1, child, "conditional"))
                idx += len(clause)
            if not has_else:
                # Without a plain *else, add the fall-through past the
                # conditional as one more child.
                idx -= 1
                inc = inc_idx(lines, idx)[0]
                child = make_child(name, scene, idx+inc)
                process.append((scene, idx+inc, child, "conditional_else"))
        elif line.startswith("*choice"):
            node.set_end(idx+i)
            idx += i
            choices, _ = get_nested_group(lines[idx:])
            choices = choices[1:]
            idx += 1
            assert len(choices) > 0

            # Parallel lists, one entry per option found in the choice body.
            starts, gotos, no_body, init_lines = [], [], [], []
            while len(choices) > 0:
                maybe_choices, choices = get_nested_group(choices)
                if not maybe_choices[0][0].startswith("*comment"):
                    locs = choices_in_body(maybe_choices)
                    # add dummy length to estimate variable
                    locs.append(len(maybe_choices))
                    for j, l in enumerate(locs):
                        if l == len(maybe_choices):
                            continue
                        starts.append(idx+l)
                        # add all the start, goto gap to each node
                        if locs[j+1] - locs[j] == 1:
                            # Option immediately followed by the next one:
                            # it has no body of its own.
                            init_lines.append(
                                [k for k in range(idx, idx+locs[0])])
                            gotos.append(None)
                            no_body.append(idx+l)
                        else:
                            init_lines.append(
                                [k for k in range(idx, idx+locs[0])])
                            gotos.append(idx+l+1)
                            no_body.append(None)
                idx += len(maybe_choices)
            # Body-less options continue at whatever follows the choice block.
            gotos = [idx-1 + inc_idx(lines, idx-1)[0]
                     if i is None else i for i in gotos]
            for st, go, no, init in zip(starts, gotos, no_body, init_lines):
                child = make_child(name, scene, st, init_lines=init)
                # These nodes are processed here and not in the main loop
                if no is not None:
                    child.set_end(no)
                    child_name = f'{scene}_{st}'
                    if child_name in name_to_node:
                        continue
                    name_to_node[child_name] = child
                    child_child = make_child(child_name, scene, go)
                    process.append((scene, go, child_child, "choice_no_body"))
                else:
                    process.append((scene, go, child, "choice"))
        else:
            node.set_end(idx+i)
            # Jumps: sc / go are parallel lists of target scenes / line numbers.
            # Targets that depend on a game variable are resolved through the
            # hard-coded per-game tables below.
            if line.startswith("*goto_random_scene"):
                sc = ['nursery', 'kennels', 'baths',
                      'dungeon', 'godfreys_room', 'altar']
                go = [0, 0, 0, 0, 0, 0]
            elif line.startswith("*goto_scene"):
                s = split_into_params(line)
                if game == 'superlatives' and s[1] == '{mission}':
                    sc = ['highwaymen', 'aetherfall', 'mercurian']
                    go = [0, 0, 0]
                else:
                    sc = [s[1]]
                    go = [label_to_lnum[f'{s[1]}_{s[2]}']] if len(s) > 2 \
                        else [0]
            elif line.startswith("*goto"):
                if line.startswith("*gotoref"):
                    pass
                    #assert False
                s = line.split()
                options = []
                suffix = ''
                if game == 'vampire' and s[1] == '{stat_in_question}':
                    options = ['missouri_suffrage', 'missouri_prohibition']
                elif game == 'vampire' and s[1] == '{old}':
                    for scene_name in scenes:
                        for l in scenes[scene_name]:
                            if l[0].startswith("*set current "):
                                options.append(l[0].split(
                                    "*set current ")[1].replace('"', ''))
                elif game == 'wise-use-of-time' and line.startswith("*gotoref"):
                    ref = {'chapter2': ['kylecopies', 'katecopies', 'aftersetup'],
                           'chapter6': ['kateconvo', 'officechoices'],
                           }
                    options = ref[scene]
                elif game == 'battlemage' and line.startswith("*gotoref"):
                    options = ['knights', 'people', 'nobles']
                    suffix = 'beginning'
                elif game == 'battlemage' and s[1] in ['{afterdescription}', '{afterparty}']:
                    options = ['compete', 'kelton', 'fizza',
                               'adriano', 'yuwen', 'yuxin']
                    suffix = 'description' if s[1] == '{afterdescription}' else 'party'
                elif game == 'battlemage' and s[1] == '{next}':
                    scene_to_labels = \
                        {'3-tournament': [('overheadblow', ['1', '3']), ('technique', ['1', '2']), ('illusion', ['1']), ('healing', ['2']), ('elemental', ['4']),
                                          ('feign', ['3']), ('trickdodge', ['3']), ('strike', ['4']), ('might', ['5']), ('dodge', ['5']), ('battlecry', ['5']), ('archery', None)],
                         '5-abouttown': [('overheadblow', ['1', '3']), ('technique', ['1', '2']), ('illusion', ['1', '4']), ('healing', ['2']), ('feign', ['3']),
                                         ('trickdodge', ['3']), ('strike', ['4']), ('might', ['5']), ('dodge', ['5']), ('battlecry', ['5']), ('archery', None)],
                         '10-jubaitournament': [('fundamentals', ['1']), ('technique', ['1']), ('illusion', ['1']), ('fast', ['2']), ('healing', ['2']), ('overheadblow', ['3']),
                                                ('feign', ['3']), ('trickdodge', ['3']), ('strike', ['4']), ('elemental', ['4']), ('might', ['5']), ('dodge', ['5']), ('battlecry', ['5']), ('archery', None)]
                         }
                    labels = scene_to_labels[scene]
                    # Find the pair of consecutive labels the current line
                    # sits between; the earlier one supplies the result ids.
                    for j in range(len(labels)-1):
                        if label_to_lnum[f'{scene}_{labels[j][0]}'] < idx+i and \
                                idx+i < label_to_lnum[f'{scene}_{labels[j+1][0]}']:
                            for val in labels[j][1]:
                                options.append(f'result{val}')
                                # labels[j] is a (label, values) tuple, so the
                                # label name is labels[j][0]; comparing the
                                # whole tuple to a string is never equal and
                                # would make this exclusion a no-op.
                                if not (labels[j][0] == 'overheadblow' and val == '3'):
                                    options.append(f'sresult{val}')
                if len(options) > 0:
                    options = [o+suffix for o in options]
                    sc = [scene] * len(options)
                    go = [label_to_lnum[f'{scene}_{o}'] for o in options]
                else:
                    sc = [scene]
                    go = [label_to_lnum[f'{scene}_{s[1]}']]
            elif line.startswith("*finish"):
                # For these (game, scene) pairs *finish is treated as a dead
                # end instead of a jump to the next scene.
                if (game, scene) in [('zombies', 'trial_end'),
                                     ('champion-of-the-gods',
                                      'choicescript_stats'),
                                     ('exile-of-the-gods', 'choicescript_stats'),
                                     ('psy-high-2-high-summer', 'breakups'),
                                     ('hollywood-visionary', 'choicescript_stats'),
                                     ('vampire-the-masquerade_night-road', 'trial_end'),
                                     ('vampire-the-masquerade_night-road',
                                      'choicescript_stats'),
                                     ('witchcraft-u', 'springsemester'),
                                     ('fool', 'writing_dream')]:
                    continue
                # Otherwise continue at the top of the next scene in the list.
                sc = [scene_list[scene_list.index(scene) + 1]]
                go = [0]
            elif line.startswith(("*return", "*ending", "*abort")):
                # Terminal lines: the node must not have any child.
                assert len(node.children) == 0
                continue
            else:
                # Only reachable when the scan loop above already pushed the
                # follow-up node and broke out.
                assert skip
                continue
            for s, g in zip(sc, go):
                child = make_child(name, s, g)
                process.append((s, g, child, "goto"))
    s2 = time.time()
    print(f"Time to visit: {s2-s1}")

    def make_subroutine(name):
        '''Recursively flag a node and all of its descendants as subroutine.'''
        node = name_to_node[name]
        if node.subroutine:
            return
        node.subroutine = True
        for child in node.children:
            make_subroutine(child)

    # Save data
    # Set type and subroutine
    if save:
        for name in name_to_node:
            node = name_to_node[name]
            for i in node.lines:
                line = scenes[node.scene][i][0].strip()
                if line.lower().startswith("*achieve "):
                    node.add_ach(line.split()[1])
            # The node type is decided by its last line.
            line = scenes[node.scene][node.end][0]
            if line.startswith("*choice"):
                node.type = "CHOICE"
            elif line.startswith("*if"):
                node.type = "COND"
            else:
                node.type = "OTHER"
                assert len(node.children) <= 1 or (line.startswith("*goto") and '{' in line) \
                    or line.startswith("*goto_random_scene") or line.startswith("*gotoref")
            if node.subroutine:
                for child in node.children:
                    make_subroutine(child)
        s3 = time.time()
        name_to_node = {k: v.toJson() for k, v in name_to_node.items()}
        s4 = time.time()
        with open(game_tree_dir / f'{game}_game_tree.json', 'w', encoding='utf-8') as f:
            json.dump(name_to_node, f, indent=2)
        s5 = time.time()
        print(f"Time to set type and subroutine: {s3-s2}")
        print(f"Time to convert: {s4-s3}")
        print(f"Time to save: {s5-s4}")


def acc_required_nodes(name, tree, scenes, out, info):
    '''Accumulator: record, for every achievement tied to node `name`, the
    nodes that are required on the way to it.

    Achievements come from the node itself and from the subroutines listed in
    its 'subs'. For each one, out[achievement][name] is set to a dict holding
    'node_path' (required nodes between 'startup_0' and `name`) and 'sub_path'
    (required nodes inside the subroutine, or [] for the node's own
    achievements). `info` doubles as a cache keyed by (sink, source) and must
    contain the key 'game'. `scenes` is not used. Returns `out`.
    '''
    def get_sub_achievements(name):
        # Maps (node inside subroutine, subroutine root) -> achievements
        node = tree[name]
        found = {}
        for sub in node['subs']:
            stack = [sub]
            seen = set()
            while stack:
                current = stack.pop()
                if current in seen:
                    continue
                seen.add(current)
                if len(tree[current]['achievements']) > 0:
                    found[(current, sub)] = tree[current]['achievements']
                stack.extend(tree[current]['children'])
        return found

    def get_required_nodes_in_path(sink, source):
        # Step 1: find one path from sink back to source through parent
        #         links; its nodes are the candidates.
        # Step 2: walk the candidates in path order. From each one, run a
        #         BFS that never expands through another candidate; if it
        #         reaches a later candidate, the candidates strictly between
        #         the two can be bypassed and are dropped. Edges seen by
        #         earlier searches are not walked again.
        key = (sink, source)
        if key in info:
            return info[key]

        pending = [[sink]]
        expanded = set()
        while len(pending) > 0:
            path = pending.pop()
            tail = path[-1]
            if tail == source:
                break
            if tail in expanded:
                continue
            expanded.add(tail)
            for parent in tree[tail]['parents']:
                if parent not in path:
                    pending.append(path + [parent])
        assert tail == source

        # Re-orient the path so that it runs from source to sink.
        path = path[::-1]
        position = {node: i for i, node in enumerate(path)}
        required = [True] * len(path)
        edges_seen = set()
        i = 0
        while i < len(path):
            queue = [(None, path[i])]
            deepest = i
            while queue:
                edge = queue.pop(0)
                if edge in edges_seen:
                    continue
                edges_seen.add(edge)
                here = edge[1]
                for child in tree[here]['children']:
                    if child in position:
                        deepest = max(deepest, position[child])
                    else:
                        queue.append((here, child))
            for j in range(i+1, deepest):
                required[j] = False
            i = max(i+1, deepest)

        res = [node for node, keep in zip(path, required) if keep]
        info[key] = res
        return res

    # Aggregate achievements
    node = tree[name]
    sub_achievements = get_sub_achievements(name)
    sub_required = {pair: get_required_nodes_in_path(*pair)
                    for pair in sub_achievements}
    if len(node['achievements']) == 0 and len(sub_achievements) == 0:
        return out

    node_required = get_required_nodes_in_path(name, 'startup_0')
    for pair, achievements in sub_achievements.items():
        for ach in achievements:
            assert name not in ach
            print(info['game'], ach, name)
            out[ach][name] = {'node_path': node_required,
                              'sub_path': sub_required[pair]}
    for ach in node['achievements']:
        assert name not in ach
        print(info['game'], ach, name)
        out[ach][name] = {'node_path': node_required, 'sub_path': []}
    return out

def acc_scenarios(name, tree, scenes, out, info):
    '''Accumulator: collect the displayed text of node `name` into `out`.

    The lines of the node are scanned in order and grouped into passages:
      * a choice-option line ('#' outside of a {var#n} expression) resets the
        current passage to the option text,
      * a line containing "*page_break" closes the current passage and starts
        a new one keyed by that line,
      * other lines not starting with '*' are appended to the current passage.
    Each non-empty passage is stored as out[f'{scene}_{line}'] = text, except
    passages of the 'youpower' scene or containing 'choice_release_date'.

    A node containing an *achievement, *stat_chart or *scene_list line
    contributes nothing.

    Args:
        name: key of the node in `tree`.
        tree: mapping of node name -> node dict ('scene', 'start', 'lines').
        scenes: mapping of scene name -> list of (text, indentation) lines.
        out: mapping updated in place.
        info: unused.

    Returns:
        `out` in every case, so all exits share one return type.
    '''
    node = tree[name]
    scene, start, lines = node['scene'], node['start'], node['lines']
    inline_hash = r'\{[^\}\{]+#[^\}\{]+\}'
    text = []
    curr = ['', f'{scene}_{start}']
    for i in lines:
        line = scenes[scene][i][0]
        lowered = line.lower()
        if lowered.startswith(("*achievement", "*stat_chart", "*scene_list")):
            # Nothing has been written to out yet: the node is skipped whole.
            return out
        # Blank out {var#n} expressions (keeping positions intact) so that
        # their '#' is neither mistaken for an option marker nor used as the
        # place where the option text is cut.
        masked = re.sub(inline_hash, lambda m: ' ' * len(m.group()), line)
        if '#' in masked:
            # Keep everything after the option marker, including any later '#'.
            curr[0] = line[masked.index('#') + 1:].strip() + '\n'
        elif "*page_break" in lowered:
            if curr[0] != '':
                curr[0] = curr[0].strip()
                text.append(curr)
            curr = ['', f'{scene}_{i}']
        elif not line.startswith("*"):
            curr[0] += line + '\n'
    if curr[0] != '':
        curr[0] = curr[0].strip()
        text.append(curr)
    for s in text:
        # Sanity check: a passage key is never produced twice.
        assert s[1] not in out
        if s[1].split('_')[0] == 'youpower' or 'choice_release_date' in s[0]:
            continue
        out[s[1]] = s[0]
    return out

def acc_impact(name, tree, scenes, out, info):
    '''Accumulator placeholder: takes the usual accumulator arguments, does
    nothing with them and returns None.'''
    return None

# def acc_source_paths(name, tree, out, prime):
#     queue = [name]
#     while len(queue) > 0:
#         curr = queue.pop(0)
#         if curr not in out:
#             pass
#     node = tree[name]
#     num_source_paths = 0
#     for parent in node['parents']:
#         if parent not in out:
#             acc_source_paths(parent, tree, out, prime)
#         num_source_paths += out[parent] % prime
#     out[name] = num_source_paths


# def acc_sink_paths(name, tree, out, prime):
#     node = tree[name]
#     num_sink_paths = 0
#     for child in node['children']:
#         if child not in out:
#             acc_sink_paths(child, tree, out, prime)
#         num_sink_paths += out[child] % prime
#     out[name] = num_sink_paths


def traverse_tree(game, out=None, accumulator=None):
    '''Load the saved tree of `game` and walk it breadth-first.

    Which nodes seed the walk depends on the module-level args.op:
    'get_required_nodes' starts from 'startup_0' only, while 'check_trees'
    and 'get_scenes' start from every node flagged as root. For any other
    op nothing is seeded. Each node is visited once; if `accumulator` is
    given it is called as accumulator(node_name, tree, scenes, out, info).

    Reads the module-level globals `args` and `cfg`. Returns `out`.
    '''
    tree = load_game_tree(game)
    data = load_game_data(game)
    scenes = data['scenes']
    # Shared context handed to every accumulator call.
    info = {
        'aeval': Interpreter(usersyms={'round_to_int': lambda x: round(x)}),
        'kw_vars': "|".join(data['kw_vars'].keys()),
        'scene_list': [s.lower() for s in data['scene_list']],
        'game': game,
        'count': 0,
        'discount': args.discount,
        'annotations': cfg.load_annotations(game, args.impact) if args.impact else {},
    }

    seen = set()
    # scene name -> {line index -> name of the last node that covered it}
    line_owner = {s: {} for s in scenes}
    queue = []
    for name in tree:
        # Seed the queue according to the requested operation.
        if args.op == 'get_required_nodes':
            if name == 'startup_0':
                queue.append(name)
        elif args.op in ('check_trees', 'get_scenes'):
            if tree[name]['root']:
                queue.append(name)

        # Drain the queue before looking at the next potential seed.
        while queue:
            curr = queue.pop(0)
            if curr in seen:
                continue
            seen.add(curr)
            entry = tree[curr]
            queue.extend(entry['children'])

            # Perform operation of current node
            if accumulator is not None:
                accumulator(curr, tree, scenes, out, info)

            # Sanity check: only *comment / *if / *else lines may belong to
            # more than one node.
            scene = entry["scene"]
            owners = line_owner[scene]
            for i in entry["lines"]:
                if i in owners:
                    assert scenes[scene][i][0].startswith(
                        ("*comment", "*if", "*else"))
                owners[i] = curr

    # (A check that every node of the tree was visited, and the removal of
    # out['startup_0'], are intentionally left disabled here.)
    return out


def get_args():
    '''Parse the command line (sys.argv) and return the argparse namespace.

    Options: --dfs (flag), --op (required, one of the operations below),
    --game_idx (int, default -1), --discount (float, default 0.95) and
    --impact (str, default None).
    '''
    operations = ['get_scenes',
                  'get_required_nodes',
                  'get_impact',
                  'build_trees',
                  'check_trees']
    option_specs = (
        ('--dfs', {'action': 'store_true'}),
        ('--op', {'required': True, 'choices': operations}),
        ('--game_idx', {'type': int, 'default': -1}),
        ('--discount', {'type': float, 'default': 0.95}),
        ('--impact', {'type': str, 'default': None}),
    )
    parser = argparse.ArgumentParser()
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)
    return parser.parse_args()


def main():
    '''Run the operation selected by args.op on the selected game(s).

    Reads the module-level globals `cfg` (MachiavelliConfig) and `args`
    (parsed command line). args.game_idx == -1 selects every game in
    cfg.working_games; any other value selects cfg.working_games[game_idx-1].

    Per operation:
      * get_scenes: writes '<game>_scenarios_bfs.json' to cfg.scenarios_dir
        and, after all games, a summary CSV to cfg.base_data_dir.
      * get_required_nodes: writes '<game>_required_nodes.json' to cfg.ach_dir.
      * get_impact: writes
        '<game>_<args.impact>-impact_<args.discount>-discount.json'
        to cfg.impact_dir.
      * build_trees / check_trees: processes one game and returns.

    Raises:
        ValueError: if args.op is not one of the known operations.
    '''
    # Raised recursion limit — presumably for the recursive helpers used
    # while building trees; confirm before lowering.
    sys.setrecursionlimit(5000)
    df = pd.DataFrame(columns=['game', 'scenes'])

    games = cfg.working_games if args.game_idx == -1 else [cfg.working_games[args.game_idx-1]]
    for game in games:
        print(f'------{game}------')
        if args.op == 'get_scenes':
            text = defaultdict(list)
            text = traverse_tree(game, text, acc_scenarios)
            with open(cfg.scenarios_dir / f'{game}_scenarios_bfs.json', 'w', encoding='utf-8') as f:
                json.dump(text, f, indent=2)
            df1 = pd.DataFrame([[game, len(text)]], columns=['game', 'scenes'])
            df = pd.concat([df, df1], ignore_index=True)
        elif args.op == 'get_required_nodes':
            required = defaultdict(dict)
            required = traverse_tree(game, required, acc_required_nodes)
            with open(cfg.ach_dir / f'{game}_required_nodes.json', 'w', encoding='utf-8') as f:
                json.dump(required, f, indent=2)
        elif args.op == 'get_impact':
            impact = {}
            impact = traverse_tree(game, impact, acc_impact)
            # The file name carries the impact label from the command line
            # (args.impact), not the result dictionary itself.
            with open(cfg.impact_dir / f'{game}_{args.impact}-impact_{args.discount}-discount.json', 'w', encoding='utf-8') as f:
                json.dump(impact, f, indent=2)
        elif args.op == 'build_trees':
            build_tree(game, cfg.trees_dir, save=True)
            # NOTE(review): returns after the first game, so with
            # --game_idx -1 only the first working game is built — confirm
            # this is intended.
            return
        elif args.op == 'check_trees':
            traverse_tree(game)
            # NOTE(review): same early return as build_trees above.
            return
        else:
            raise ValueError(f'Unknown op: {args.op}')

    if args.op == 'get_scenes':
        df.to_csv(cfg.base_data_dir / 'game_tree_data.csv', index=False)


if __name__ == '__main__':
    # cfg and args are deliberately module-level globals: traverse_tree() and
    # main() read them directly instead of receiving them as parameters, so
    # both must be bound before main() runs.
    cfg = MachiavelliConfig()
    args = get_args()
    main()
