import json
import os
import random
import re
import ast
from typing import Optional

import numpy as np
from .config import MachiavelliConfig
### Utils for player.py and scraper.py ###

# Module-wide configuration instance; the loader helpers in this file read
# their data directories (scenarios_dir, ach_dir, trees_dir, ...) from it.
cfg = MachiavelliConfig()


def random_seed(seed: int) -> None:
    """
    Seed the global RNGs of `random` and NumPy with the same value.

    seed: integer passed unchanged to each library's seeding function.
    """
    # Stdlib first, then NumPy. PyTorch seeding is not performed here.
    for seeder in (random.seed, np.random.seed):
        seeder(seed)


def get_slurm_id() -> Optional[str]:
    """
    Look up the SLURM job identifier in the process environment.

    Returns "<SLURM_ARRAY_JOB_ID>_<SLURM_ARRAY_TASK_ID>" when both array
    variables are set, otherwise the value of SLURM_JOBID, otherwise None.
    """
    env = os.environ
    array_job = env.get('SLURM_ARRAY_JOB_ID')
    array_task = env.get('SLURM_ARRAY_TASK_ID')

    # The array job/task pair takes precedence over the plain job id.
    if array_job is not None and array_task is not None:
        return f"{array_job}_{array_task}"
    return env.get('SLURM_JOBID')


def clean(text):
    """
    Strip surrounding whitespace from `text`, then delete the markup tags
    [i], [/i], [b] and [/b].

    Tags are removed one kind at a time in that order, after the strip, so
    whitespace exposed by removing a tag is kept.
    """
    result = text.strip()
    for tag in ("[i]", "[/i]", "[b]", "[/b]"):
        result = result.replace(tag, "")
    return result


def delete_choice_release_data(game):
    """
    Rewrite the BFS scenario files of `game` without release-date entries.

    For both `<game>_scenarios_bfs.json` and `<game>_scenarios_bfs_filled.json`
    under `cfg.scenarios_dir`: load the JSON mapping, drop every entry that
    contains the key 'choice_release_date', and overwrite the file in place.
    """
    for ending in ('.json', '_filled.json'):
        filename = game + ('_scenarios_bfs' + ending)
        path = cfg.scenarios_dir / filename

        with open(path, encoding="utf-8") as src:
            loaded = json.load(src)

        kept = {
            key: entry
            for key, entry in loaded.items()
            if 'choice_release_date' not in entry
        }

        with open(path, 'w', encoding="utf-8") as dst:
            json.dump(kept, dst, indent=2)


def get_conditional_group(lines):
    """
    Split off the conditional group that starts at lines[0].

    Each entry of `lines` is indexed as (text, indentation). lines[0] must
    start with "*if". The group extends over every following entry that is
    indented deeper than the *if, plus same-level entries starting with
    *elseif / *elsif / *else. It stops at the first other same-level entry
    or at any shallower entry.

    returns: (group, remainder) as two slices of `lines`
    """
    assert lines[0][0].startswith("*if")
    base_indent = lines[0][1]
    branch_keywords = ("*elseif", "*elsif", "*else")

    end = 1
    for entry in lines[1:]:
        indent = entry[1]
        if indent < base_indent:
            # Dedent: the conditional is over.
            break
        if indent == base_indent and not entry[0].startswith(branch_keywords):
            # Same level but not a continuation branch.
            break
        end += 1
    return lines[:end], lines[end:]


def get_nested_group(lines):
    """
    Split `lines` into the first entry together with the run of entries
    directly after it that are indented deeper, and everything else.

    Each entry is indexed as (text, indentation); `lines` must be non-empty.

    returns: (group, remainder) as two slices of `lines`
    """
    assert len(lines) > 0
    base_indent = lines[0][1]
    # Position of the first following entry that is NOT nested deeper.
    cut = next(
        (pos for pos, entry in enumerate(lines[1:], start=1)
         if not entry[1] > base_indent),
        len(lines),
    )
    return lines[:cut], lines[cut:]


def get_scenes_from_scenelist(lines):
    """
    Return the scene names listed under a *scene_list command.

    lines[0] must start with "*scene_list"; each following entry contributes
    its text (element 0). For an entry starting with "$", the second
    space-separated token is used instead. The name "startup" is left out.
    """
    assert lines[0][0].startswith("*scene_list")
    raw_names = (entry[0] for entry in lines[1:])
    names = (
        name.split(" ")[1] if name.startswith("$") else name
        for name in raw_names
    )
    return [name for name in names if name != "startup"]


def htmlify(text):
    """
    Wrap `text` in a <p> element, turning each newline into "<br><br>" and
    the [i]/[b] markup tags into <i>/<strong> HTML tags.

    No HTML escaping is applied to the rest of the text.
    """
    # Applied one after another, in this order.
    substitutions = (
        ("\n", "<br><br>"),
        ("[i]", "<i>"),
        ("[/i]", "</i>"),
        ("[b]", "<strong>"),
        ("[/b]", "</strong>"),
    )
    body = text
    for marker, html in substitutions:
        body = body.replace(marker, html)
    return "<p>" + body + "</p>"


def load_game_ach_paths(game):
    """
    Load and return the parsed JSON content of
    `<game>_required_nodes.json` from `cfg.ach_dir`.
    """
    path = cfg.ach_dir / f'{game}_required_nodes.json'
    with open(path, encoding='utf-8') as handle:
        required_nodes = json.load(handle)
    return required_nodes


def load_game_scenarios(game, filled=False):
    """
    Load the BFS scenario file of `game` as a list of (idx, scenario) pairs.

    game: game identifier used as the file-name prefix
    filled: when truthy, `cfg.filled_suffix` is inserted before ".json"

    returns: list of (key, value) tuples in the order the keys appear in the
        loaded JSON object
    """
    variant = cfg.filled_suffix if filled else ''
    filename = game + ('_scenarios_bfs' + variant + '.json')
    with open(cfg.scenarios_dir / filename, encoding="utf-8") as handle:
        loaded = json.load(handle)
    return [(idx, loaded[idx]) for idx in loaded]


def load_game_summaries():
    """
    Load and return the parsed JSON content of `game_summaries.json`
    from `cfg.base_data_dir`.
    """
    path = cfg.base_data_dir / 'game_summaries.json'
    with open(path, encoding='utf-8') as handle:
        summaries = json.load(handle)
    return summaries


def load_game_tree(game):
    """
    Load and return the parsed JSON content of `<game>_game_tree.json`
    from `cfg.trees_dir`.
    """
    path = cfg.trees_dir / f'{game}_game_tree.json'
    with open(path, encoding='utf-8') as handle:
        tree = json.load(handle)
    return tree


def load_game_data(game):
    """
    Load `<game>.json` from `cfg.sources_dir` and pad every scene with
    terminator lines.

    Each scene's line list gets two extra entries appended:
    ("*ending", 0) for the scene named last in data['scene_list'], or
    ("*finish", 0) for every other scene, followed by an empty ("", 0) line
    that serves as a buffer for computing the next line.

    returns: the loaded data dict, with its scene line lists extended
    """
    path = cfg.sources_dir / f'{game}.json'
    with open(path, encoding='utf-8') as handle:
        data = json.load(handle)

    scene_list = data['scene_list']
    scenes = data['scenes']
    for name, body in scenes.items():
        # Only the last scene of the list ends the game; others just finish.
        terminator = "*ending" if name == scene_list[-1] else "*finish"
        body.append((terminator, 0))
        body.append(("", 0))
    return data


def make_lowercase(m):
    """
    Replacement callback for re.sub: lowercase capture group 1 of match `m`.

    When group 1 did not participate in the match, the whole match is
    returned unchanged.
    """
    captured = m.group(1)
    return m.group(0) if captured is None else captured.lower()


def make_underscore(m):
    """
    Replacement callback for re.sub: lowercase capture group 1 of match `m`
    and append an underscore.

    When group 1 did not participate in the match, the whole match is
    returned unchanged.
    """
    captured = m.group(1)
    if captured is not None:
        return f"{captured.lower()}_"
    return m.group(0)


def split_into_params(text):
    """
    Split `text` into whitespace-separated parameters.

    Whitespace inside parentheses or inside a double-quoted section does not
    split. Each piece is stripped and empty pieces are dropped. Quote state
    is a simple toggle on '"', so the result is wrong for nested quotes.
    """
    pieces = []
    buffer = []
    depth = 0        # current parenthesis nesting level
    quoted = False   # True while inside a "..." section

    for ch in text:
        if ch.isspace() and depth == 0 and not quoted:
            # Separator at top level: close the current piece.
            pieces.append("".join(buffer))
            buffer = []
            continue
        if ch == '(':
            depth += 1
        elif ch == ')':
            depth -= 1
        elif ch == '"':
            quoted = not quoted
        buffer.append(ch)
    pieces.append("".join(buffer))

    trimmed = (piece.strip() for piece in pieces)
    return [piece for piece in trimmed if piece != '']


def get_vars(text, aeval, kw_vars, vars):
    """
    Collect the variable names referenced by an expression into `vars`.

    The expression is rewritten step by step into Python syntax, parsed with
    `aeval`, and every name loaded in the resulting AST is added (lowercased)
    to `vars`. Nested `${...}`, `{...}` and `length(...)` sub-expressions are
    handled by recursive calls and replaced by the placeholder "1".

    text: string; must not contain '@{' or '#'
    aeval: asteval Interpreter (only its `parse` method is used here)
    kw_vars: string of the form "var1|var2|var3" with vars that are Python keywords
    vars: set of variables accumulated in text (mutated in place)

    returns: None if text contains 'choice_time_stamp'. Otherwise a tuple
        (placeholder, flag): placeholder is always the string "1"; flag is
        True when the dumped AST contains neither `value=<digit>` nor
        `value='`, i.e. no numeric or string literal was detected.
    """
    assert '@{' not in text and '#' not in text
    # NOTE(review): this path returns None, not a tuple. The recursive calls
    # below index the result with [0], so a nested sub-expression containing
    # 'choice_time_stamp' would raise TypeError.
    if 'choice_time_stamp' in text:
        return
    # Lookahead that only succeeds when the rest of the string holds an even
    # number of quote characters, i.e. the position is outside quoted text.
    expr = r"(?=(?:[^\"']*[\"'][^\"']*[\"'])*[^\"']*$)"
    text = re.sub(r"(TRUE|true)"+expr, "True", text)
    text = re.sub(r"(FALSE|false)"+expr, "False", text).strip()

    # Process text
    # Replace every ${...}, $!{...} and $!!{...} substitution with the
    # placeholder returned by a recursive call on its (lowercased) content.
    for key in ['${', '$!{', '$!!{']:
        while key in text:
            # i: index of the first character after the opening key.
            i = text.find(key) + len(key)
            j = 0
            count = 1
            # Scan forward to the matching '}', tracking nested braces.
            # An unbalanced expression runs past the end and raises IndexError.
            while count > 0:
                if text[i+j] == '{':
                    count += 1
                if text[i+j] == '}':
                    count -= 1
                j += 1
            # Content between the key and its matching closing brace.
            match = text[i:i+j-1]
            sub = get_vars(match.lower().split('#')[0], aeval, kw_vars, vars)[0]
            if key == '$!{':
                sub = sub.capitalize()
            if key == '$!!{':
                sub = sub.upper()
            # Splice the placeholder in place of the whole key...} span.
            text = text[:i-len(key)] + sub + text[i+j:]

    # Ensure that text has '==' not '=' as the equality operator
    text = text.replace('=', '==')
    # Repair the comparison operators that the blanket replacement doubled.
    text = text.replace('<==', '<=').replace('>==', '>=').replace('!==', '!=')
    # Map the remaining operators onto Python syntax.
    text = text.replace('%+', '+').replace('%-', '-')
    text = text.replace('modulo', '%')
    text = text.replace('^', '**')
    text = text.replace('round(', 'round_to_int(')
    # Replace length(exp) with "1" and get vars of exp
    text = re.sub(r'length\((\w+)\)',
                  lambda m: get_vars(m.group(1), aeval, kw_vars, vars)[0], text)

    if kw_vars != '':
        # Double-quoted strings match the first alternative and are left
        # untouched; keyword-named variables (group 1) are lowercased and
        # suffixed with '_' by make_underscore so they parse as identifiers.
        expr = r'"[^"]*"|((?<!\w)(?:'+kw_vars+r')(?!\w))'
        text = re.sub(expr, make_underscore, text)

    # Resolve remaining {...} references from the innermost outwards, each
    # replaced by the placeholder of a recursive call.
    while '{' in text:
        assert '}' in text
        text = re.sub(
            r'\{([^\{\}]+)\}', lambda m: get_vars(m.group(1), aeval, kw_vars, vars)[0], text)

    if '&' in text:
        # Drop all parentheses, then turn every '&' that lies outside
        # double-quoted strings into '+'.
        text = text.replace('(', '').replace(')', '')
        expr = r"&(?=(?:[^\"]*\"[^\"]*\")*[^\"]*$)"
        text = re.sub(expr, lambda m: '+', text)

    # Add to vars
    # Names are extracted textually from the AST dump.
    # NOTE(review): names of called functions (e.g. round_to_int) presumably
    # also appear as Load-context Name nodes and would be added - confirm.
    call = ast.dump(aeval.parse(text))
    uncased_vars = set(re.findall(r"Name\(id='(\w+)', ctx=Load\(\)\)", call))
    for var in uncased_vars:
        vars.add(var.lower())

    return "1", not bool(re.search(r"(value=\d)|(value=')", call))
