import ast
import re
from typing import Dict, Set, Any
from termcolor import colored
import astor
import parso
import astunparse

# Import preamble for generated task scripts. The only reference visible in
# this module (inside `postprocess`) is currently commented out.
CODE_HEADER = """
from env import setup_environment, shutdown_environment
from skill_code import *
import argparse
import importlib
import numpy as np
from pyrep.objects.shape import Shape
from pyrep.objects.joint import Joint
from pyrep.objects.dummy import Dummy
from pyrep.objects.proximity_sensor import ProximitySensor
from rlbench.const import colors
"""

def _try_parse(code: str) -> bool:
    try:
        ast.parse(code)
        return True
    except SyntaxError:
        return False

def extract_from_code_block(code: str):
    """Pull the first ```python fenced block out of *code*.

    Returns:
        tuple: ``(body, True)`` where *body* is the fenced content with
        surrounding whitespace stripped, or ``(code, False)`` with the input
        untouched when no ```python fence is present.
    """
    found = re.search(r'```python(?:\s*\n)?(.*?)```', code, flags=re.DOTALL)
    if found is None:
        return code, False
    return found.group(1).strip(), True


def _get_docstring_line_indices(code: str) -> Set[int]:
    lines = code.splitlines()
    doc_lines = set()
    try:
        tree = ast.parse(code)
    except SyntaxError:
        return _heuristic_triple_quote_blocks(code)

    def visit(node):
        if isinstance(node, (ast.Module, ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
            doc = ast.get_docstring(node, clean=False)
            if doc is not None and node.body:
                first = node.body[0]
                if isinstance(first, ast.Expr) and isinstance(first.value, (ast.Constant,)):
                    start = first.lineno - 1
                    end = getattr(first, "end_lineno", start) - 1
                    for i in range(start, end + 1):
                        doc_lines.add(i)
        for child in ast.iter_child_nodes(node):
            visit(child)

    visit(tree)
    return doc_lines

def _heuristic_triple_quote_blocks(code: str) -> Set[int]:
    lines = code.splitlines()
    doc_lines = set()
    open_quote = None
    for idx, line in enumerate(lines):
        if open_quote:
            doc_lines.add(idx)
            if open_quote in line:
                if line.count(open_quote) % 2 == 1:
                    open_quote = None
        else:
            m = re.search(r'("""|\'\'\')', line)
            if m:
                quote = m.group(1)
                open_quote = quote
                doc_lines.add(idx)
                if line.count(quote) >= 2:
                    open_quote = None
    return doc_lines

_keyword_start = re.compile(
    r'^(?:def|class|import|from|return|if|elif|else|for|while|try|except|with|as|lambda|assert|yield|async|await)\b'
)

def _is_code_like_line(line: str) -> bool:
    stripped = line.strip()
    if not stripped:
        return True
    if stripped.startswith("
        return True
    if _keyword_start.match(stripped):
        return True
    code_indicators = [
        "=", ":", "->", "(", ")", "[", "]", "{", "}", "+", "-", "*", "/", "%",
    ]
    if any(ind in line for ind in code_indicators):
        return True
    if re.match(r"\s*[A-Za-z_][A-Za-z0-9_]*\s*=", line):
        # print(colored(line, "blue"))
        return True
    if re.search(r"[A-Za-z_][A-Za-z0-9_]*\s*\(", line):
        # print(colored(line, "yellow"))
        return True
    # print(colored(line, "red"))
    return False

def _is_prose_line(line: str) -> bool:
    stripped = line.strip()
    if not stripped or stripped.startswith("
        return False
    if _is_code_like_line(line):
        return False
    # words = stripped.split()
    # if len(words) >= 5:
    #     if any(c in stripped for c in [".", ",", ";", ":", "?"]):
    #         return True
    #     english_connectors = {"and", "or", "but", "because", "however", "therefore"}
    #     korean_connectors = {"", "", "", "", "", ""}
    #     lower_words = {w.strip(".,!?").lower() for w in words}
    #     if lower_words & english_connectors:
    #         return True
    #     avg_len = sum(len(w) for w in words) / len(words)
    #     if avg_len >= 4 and len(words) >= 6:
    #         return True
    return True

def _multiline_string_line_indices(tree: ast.AST) -> Set[int]:
    """Return the 0-based indices of all lines spanned by multi-line string literals.

    Text inside such literals (e.g. ``x = '''...'''`` or a multi-line f-string)
    can look like prose, but it is program data and must never be stripped.
    """
    indices: Set[int] = set()
    for node in ast.walk(tree):
        is_string = isinstance(node, ast.JoinedStr) or (
            isinstance(node, ast.Constant) and isinstance(node.value, (str, bytes))
        )
        end = getattr(node, "end_lineno", None)
        if is_string and end is not None and end > node.lineno:
            indices.update(range(node.lineno - 1, end))
    return indices


def _find_parseable_removal(lines, prose_indices) -> Set[int]:
    """Return a set of prose line indices whose removal makes *lines* parse, or an empty set.

    With at most 8 candidates, every subset is tried, largest first (at most
    255 parse attempts). With more candidates, removals are accepted greedily,
    one at a time, whenever the code still parses. Indices always refer to
    the original *lines*, so one removal never shifts another.
    """
    from itertools import combinations

    def parses_without(removed) -> bool:
        return _try_parse("\n".join(ln for i, ln in enumerate(lines) if i not in removed))

    if len(prose_indices) <= 8:
        for size in range(len(prose_indices), 0, -1):
            for combo in combinations(prose_indices, size):
                if parses_without(set(combo)):
                    return set(combo)
        return set()

    removed: Set[int] = set()
    for i in prose_indices:
        if parses_without(removed | {i}):
            removed = removed | {i}
    return removed


def remove_non_code_text(code: str) -> str:
    """Strip natural-language (prose) lines that an LLM mixed into Python code.

    Docstring lines are never removed.

    * If *code* already parses, prose-looking lines outside docstrings and
      multi-line string literals are dropped, but only if the result still
      parses. Otherwise the code is returned unchanged.
    * If *code* does not parse, the function searches for a set of prose
      lines whose removal makes it parse (see ``_find_parseable_removal``).
      If none is found, every prose line is dropped as a best effort, even
      though the result may still not parse.

    Returns:
        str: the kept lines joined with ``"\\n"`` (a trailing newline in *code*
        is not preserved).
    """
    doc_lines = _get_docstring_line_indices(code)
    lines = code.splitlines()

    def keep_all_but(removed) -> str:
        return "\n".join(ln for i, ln in enumerate(lines) if i not in removed)

    try:
        tree = ast.parse(code)
    except (SyntaxError, ValueError):
        tree = None

    if tree is not None:
        protected = doc_lines | _multiline_string_line_indices(tree)
        prose = {i for i, ln in enumerate(lines) if i not in protected and _is_prose_line(ln)}
        cleaned = keep_all_but(prose)
        # The code parsed, so any "prose" line here was valid code. Only
        # accept the deletion if the program still parses afterwards
        # (e.g. never leave a block empty).
        return cleaned if _try_parse(cleaned) else "\n".join(lines)

    prose_indices = [i for i, ln in enumerate(lines) if i not in doc_lines and _is_prose_line(ln)]
    cleaned = keep_all_but(_find_parseable_removal(lines, prose_indices))
    if not _try_parse(cleaned):
        # No subset worked: drop every prose line outside docstrings.
        cleaned = keep_all_but(set(prose_indices))
    return cleaned

# FOR ENVIRONMENT ERRORS
## handle error: WrongObjectType
import ast

def fix_type_mismatches(code: str, type_map: Dict[str, str]) -> str:
    """Rewrite constructor calls whose class disagrees with the expected object type.

    Applies to every single-target assignment ``var = Ctor("name", ...)`` or
    ``var = mod.Ctor("name", ...)`` where ``type_map["name"]`` names a
    different class. The constructor is replaced by that class. For example,
    ``Joint('a')`` becomes ``Shape('a')`` when ``type_map == {'a': 'Shape'}``.

    Args:
        code: Python source to rewrite.
        type_map: object name -> expected class name. ``None`` is treated as
            an empty map.

    Returns:
        str: source regenerated by ``astor``. It is re-formatted even when
        nothing was changed.

    Raises:
        SyntaxError: if *code* cannot be parsed.
        RuntimeError: if ``astor.to_source`` raises ``AttributeError``.
    """
    type_map = type_map or {}

    class TypeFixer(ast.NodeTransformer):
        def visit_Assign(self, node):
            call = node.value
            single_name_target = len(node.targets) == 1 and isinstance(node.targets[0], ast.Name)
            if single_name_target and isinstance(call, ast.Call) and call.args:
                first_arg = call.args[0]
                # ast.Constant covers all string literals on Python 3.8+;
                # ast.Str is deprecated and removed in newer versions.
                if (isinstance(first_arg, ast.Constant)
                        and isinstance(first_arg.value, str) and first_arg.value):
                    expected_type = type_map.get(first_arg.value)
                    if expected_type:
                        self._retarget(call, expected_type)
            return self.generic_visit(node)

        @staticmethod
        def _retarget(call, expected_type):
            # Swap only the class name; for `mod.Ctor` keep the `mod.` prefix.
            func = call.func
            if isinstance(func, ast.Name) and func.id != expected_type:
                call.func = ast.copy_location(
                    ast.Name(id=expected_type, ctx=ast.Load()), func
                )
            elif isinstance(func, ast.Attribute) and func.attr != expected_type:
                call.func = ast.copy_location(
                    ast.Attribute(value=func.value, attr=expected_type, ctx=ast.Load()), func
                )

    tree = ast.parse(code)
    fixed_tree = TypeFixer().visit(tree)
    ast.fix_missing_locations(fixed_tree)
    try:
        return astor.to_source(fixed_tree)
    except AttributeError as err:
        raise RuntimeError("  astor.to_source   .") from err

# tmp_code = """
# def get_object_positions(index: int):
#     a_obj = Joint('a')
#     b_obj = Shape('b')
#     return True

# def tmp2():
#     return False
# """

# tmp_dict = {
#     'a': 'Shape',
#     'b': 'Joint'
# }

# print(fix_type_mismatches(tmp_code, tmp_dict))

## handle error:  Handle does not exist
### 1. object    
### 2. dictionary   + dictionary   
###  LLM planning   ..?

def build_type_map_from_scene(scene_info: Dict[str, Any]) -> Dict[str, str]:
    """Map object names to type names using a scene description.

    Reads ``scene_info["objects"]``, a list of mappings with ``"name"`` and
    ``"type"`` keys, e.g. ``{"name": "a", "type": "Shape"}``. Entries that are
    not mappings, or that lack either value, are skipped. On duplicate names
    the later entry wins. A falsy *scene_info* or a missing/``None``
    ``objects`` list yields ``{}`` instead of raising.
    """
    if not scene_info:
        return {}
    type_map: Dict[str, str] = {}
    for obj in scene_info.get("objects") or []:
        if not hasattr(obj, "get"):
            continue
        name, obj_type = obj.get("name"), obj.get("type")
        if name and obj_type:
            type_map[name] = obj_type
    return type_map


def fix_object_types_from_scene(code: str, scene_info: Dict[str, Any]) -> str:
    """Best-effort: correct constructor classes in *code* from *scene_info*.

    Builds a name -> type map with ``build_type_map_from_scene`` and applies
    ``fix_type_mismatches``. Never raises. If *code* is returned unchanged,
    it is because the map is empty or a step failed; failures are printed.
    """
    try:
        # Build the map inside the try as well: a malformed scene must not
        # escape this best-effort step.
        type_map = build_type_map_from_scene(scene_info)
        if not type_map:
            return code
        return fix_type_mismatches(code, type_map)
    except Exception as e:
        print(f"[POSTPROCESS] fix_object_types_from_scene failed: {e}")
        return code


def remove_unnecessary_success_checks(code: str) -> str:
    """Strip success-reporting boilerplate from generated code.

    The following are removed:

    * ``if <obj>.detect(): ...`` statements, including any ``else`` branch,
      when ``<obj>``'s name contains "success" or a branch directly prints a
      string containing "success"/"failed";
    * ``if done: return ...`` statements whose body is a single ``return``;
    * ``print("...")`` expression statements whose first argument mentions
      "success!", "failed!", "task complete" or "task failed". These are
      replaced by ``pass``.

    Any block left empty by the removals is padded with ``pass`` so the
    regenerated code stays valid.

    Returns:
        str: regenerated source (via ``astor``), or *code* unchanged if it
        does not parse or cannot be regenerated.
    """
    try:
        tree = ast.parse(code)
    except (SyntaxError, ValueError):
        return code

    class SuccessCheckRemover(ast.NodeTransformer):
        def visit_If(self, node: ast.If):
            # Inspect the original branches BEFORE recursing: generic_visit
            # turns "Success!"/"Failed!" prints into `pass`, which would
            # otherwise hide them from this check.
            prints_outcome = self._has_success_fail_print(node)
            self.generic_visit(node)

            test = node.test
            if (isinstance(test, ast.Call) and isinstance(test.func, ast.Attribute)
                    and test.func.attr == 'detect'):
                obj_name = ""
                if isinstance(test.func.value, ast.Name):
                    obj_name = test.func.value.id.lower()
                if 'success' in obj_name or prints_outcome:
                    print(f"[POSTPROCESS] Removed unnecessary success check: if {obj_name}.detect()...")
                    return None

            if isinstance(test, ast.Name) and test.id == 'done':
                if len(node.body) == 1 and isinstance(node.body[0], ast.Return):
                    print("[POSTPROCESS] Removed unnecessary 'if done: return' statement")
                    return None

            return node

        def visit_Expr(self, node: ast.Expr):
            # Replace outcome prints with `pass` rather than deleting them,
            # so the enclosing block keeps at least one statement.
            if isinstance(node.value, ast.Call):
                if isinstance(node.value.func, ast.Name) and node.value.func.id == 'print':
                    if node.value.args:
                        first_arg = node.value.args[0]
                        if isinstance(first_arg, ast.Constant) and isinstance(first_arg.value, str):
                            msg = first_arg.value.lower()
                            if any(kw in msg for kw in ['success!', 'failed!', 'task complete', 'task failed']):
                                print(f"[POSTPROCESS] Replaced unnecessary print with pass: {first_arg.value[:50]}...")
                                return ast.Pass()
            return node

        def _has_success_fail_print(self, if_node: ast.If) -> bool:
            """True if a direct statement of either branch prints "success"/"failed"."""
            for stmt in if_node.body + if_node.orelse:
                if isinstance(stmt, ast.Expr) and isinstance(stmt.value, ast.Call):
                    if isinstance(stmt.value.func, ast.Name) and stmt.value.func.id == 'print':
                        if stmt.value.args:
                            first_arg = stmt.value.args[0]
                            if isinstance(first_arg, ast.Constant) and isinstance(first_arg.value, str):
                                msg = first_arg.value.lower()
                                if 'success' in msg or 'failed' in msg:
                                    return True
            return False

    new_tree = SuccessCheckRemover().visit(tree)

    # Removing an `if` that was the only statement of a function/loop/branch
    # leaves an empty body, which cannot be emitted as valid Python.
    for node in ast.walk(new_tree):
        body = getattr(node, "body", None)
        if isinstance(body, list) and not body and not isinstance(node, ast.Module):
            node.body = [ast.Pass()]
    ast.fix_missing_locations(new_tree)

    try:
        return astor.to_source(new_tree)
    except Exception as e:
        print(f"[POSTPROCESS] remove_unnecessary_success_checks failed to regenerate code: {e}")
        return code


def remove_code_starts_here_suffix(code: str) -> str:
    """Cut *code* at the first ``# CODE STARTS HERE`` marker.

    The marker and everything after it are discarded, and trailing whitespace
    of the remaining head is stripped. Without a marker, *code* is returned
    as-is.
    """
    head, marker_found, _ = code.partition("# CODE STARTS HERE")
    if not marker_found:
        return code
    trimmed = head.rstrip()
    print(f"[POSTPROCESS] Removed '# CODE STARTS HERE' suffix ({len(code) - len(trimmed)} chars)")
    return trimmed


def adjust_indentation(code: str) -> str:
    """Normalise leading indentation to multiples of four spaces.

    Rules:

    * column-0 lines that begin a top-level construct (``def``,
      ``async def``, ``class``, ``import``, ``from`` or a ``@decorator``)
      stay at column 0;
    * every other column-0 line is indented by four spaces, which puts it in
      the body of a preceding function header (as used by ``postprocess``);
    * indentation that is not a multiple of four is rounded up
      (1-4 -> 4, 5-8 -> 8, ...);
    * leading tabs count as four columns, and whitespace-only lines are left
      untouched.

    Line endings are preserved.
    """
    # Word boundary so e.g. `default = 3` or `imports = []` are not mistaken
    # for `def` / `import` statements.
    top_level = re.compile(r"(?:(?:async\s+)?def|class|import|from)\b|@")
    adjusted = []
    for line in code.splitlines(keepends=True):
        body = line.lstrip(" \t")
        if not body.strip():
            adjusted.append(line)
            continue
        indent = len(line[: len(line) - len(body)].expandtabs(4))
        if indent == 0:
            new_indent = 0 if top_level.match(body) else 4
        else:
            new_indent = ((indent - 1) // 4 + 1) * 4
        adjusted.append(" " * new_indent + body)
    return "".join(adjusted)

def add_return_statement(code: str) -> str:
    """Append ``return True`` to each ``if done:``-style block that does not already end in ``return``.

    The recognised conditions are ``done``, ``done == True``,
    ``True == done``, ``done is True`` and ``True is done``. Other tests,
    e.g. ``if not done:``, are left alone. Nested ``if`` statements are
    processed too.

    Args:
        code: Python source; it must parse.

    Returns:
        str: source regenerated with ``astunparse``.

    Raises:
        SyntaxError: if *code* does not parse.
    """
    tree = ast.parse(code)

    def _is_done(node: ast.AST) -> bool:
        return isinstance(node, ast.Name) and node.id == "done"

    def _is_true(node: ast.AST) -> bool:
        return isinstance(node, ast.Constant) and node.value is True

    def _is_if_done(test: ast.AST) -> bool:
        if _is_done(test):
            return True
        if (isinstance(test, ast.Compare) and len(test.ops) == 1 and len(test.comparators) == 1
                and isinstance(test.ops[0], (ast.Eq, ast.Is))):
            left, right = test.left, test.comparators[0]
            return (_is_done(left) and _is_true(right)) or (_is_done(right) and _is_true(left))
        return False

    class _ForceReturn(ast.NodeTransformer):
        def visit_If(self, node: ast.If):
            self.generic_visit(node)
            if _is_if_done(node.test) and not (node.body and isinstance(node.body[-1], ast.Return)):
                # A Name node spelled `True` (rather than ast.Constant) is
                # kept, apparently for astunparse compatibility. It unparses
                # as `True` but could not be passed to compile().
                node.body.append(ast.Return(value=ast.Name(id='True', ctx=ast.Load())))
            return node

    new_tree = _ForceReturn().visit(tree)
    ast.fix_missing_locations(new_tree)
    return astunparse.unparse(new_tree)

def add_indent(text, indent_spaces=4):
    """Prefix every non-blank line of *text* with *indent_spaces* spaces.

    Whitespace-only lines are kept verbatim. Lines are re-joined with
    ``"\\n"``, so a trailing newline in *text* is dropped.
    """
    pad = " " * indent_spaces
    indented = []
    for raw in text.splitlines():
        if raw.strip():
            indented.append(pad + raw)
        else:
            indented.append(raw)
    return "\n".join(indented)

def postprocess(code: str, type_map: Dict[str, str] = None, initial_instruction: str = None):
    """Turn a raw LLM response into cleaned Python code.

    Steps:

    1. Take the first fenced Python code block and indent it by four
       spaces. If there is no fence, strip any stray fence markers.
    2. If *initial_instruction* is given, put the code under a synthetic
       ``def temp_interface():`` header so it can be parsed as a function
       body.
    3. Normalise indentation (``adjust_indentation``) and drop prose lines
       (``remove_non_code_text``).
    4. If *initial_instruction* is given, remove the synthetic header and
       prepend *initial_instruction* verbatim.

    Args:
        code: raw model output.
        type_map: currently unused (the type-fixing step is disabled); kept
            for API compatibility.
        initial_instruction: text prepended to the result. Presumably the
            real function header/preamble; confirm against callers.

    Returns:
        str: the cleaned code.
    """
    temp_interface = "def temp_interface():\n"

    code, is_code_block = extract_from_code_block(code)
    if is_code_block:
        code = add_indent(code)
    else:
        code = code.replace("```python", "").replace("```", "").strip()
    if initial_instruction:
        code = temp_interface + code
    code = adjust_indentation(code)
    code = remove_non_code_text(code)

    if initial_instruction:
        code = code.replace(temp_interface, "")
        # If nothing followed the header, remove_non_code_text returns it
        # without a trailing newline, so the replace above misses it.
        if code.strip() == temp_interface.strip():
            code = ""
        code = initial_instruction + code
    return code

# with open('./utils/tmp.py', 'r') as f:
#     initial_code = f.read()
# with open('./tasks/ReachTarget.py', 'r') as f:
#     original_code = f.read()

# new_code = postprocess(original_code, initial_instruction=initial_code)
# print("Extracted Code Block:\n", new_code)


def _find_comment_start(line: str, open_triple=None):
    """Locate the first ``#`` of *line* that lies outside every string literal.

    Args:
        line: one source line, without its line ending.
        open_triple: the triple-quote delimiter still open from a previous
            line, or ``None``.

    Returns:
        tuple: ``(index, still_open)``. *index* is the position of the
        comment's ``#``, or -1 if there is none. *still_open* is the
        triple-quote delimiter left open at the end of the line, or ``None``.
    """
    quote = open_triple  # current string delimiter; None when outside a string
    i = 0
    while i < len(line):
        ch = line[i]
        if quote is None:
            if ch == "#":
                return i, None
            if ch in ("'", '"'):
                triple = line[i:i + 3]
                quote = triple if triple in ('"""', "'''") else ch
                i += len(quote)
                continue
        elif ch == "\\":
            i += 2  # skip the escaped character (also correct for raw strings)
            continue
        elif line.startswith(quote, i):
            i += len(quote)
            quote = None
            continue
        i += 1
    # Only triple-quoted strings may legitimately continue onto the next line.
    return -1, (quote if quote is not None and len(quote) == 3 else None)


def remove_comments(code: str) -> str:
    """Strip ``#`` comments from Python source while leaving string literals intact.

    Whole-line comments are dropped. Trailing comments are cut together with
    the whitespace before them. Blank lines are kept. A ``#`` inside a
    single-, double- or triple-quoted string (including multi-line strings
    and docstrings) is never touched.

    Args:
        code: Python source code.

    Returns:
        str: the source without comments. A trailing newline in *code* is
        preserved.
    """
    cleaned_lines = []
    open_triple = None  # triple-quote delimiter still open from a previous line

    for line in code.splitlines():
        comment_start, open_triple = _find_comment_start(line, open_triple)
        if comment_start < 0:
            cleaned_lines.append(line)
            continue
        code_part = line[:comment_start].rstrip()
        if code_part:
            cleaned_lines.append(code_part)
        # else: the whole line was a comment -> drop it

    result = "\n".join(cleaned_lines)
    if cleaned_lines and code.endswith("\n"):
        result += "\n"
    if result != code:
        print("[POSTPROCESS] Removed comments from code")
    return result


def apply_code_postprocessing(code: str, scene_info: Dict[str, Any]) -> str:
    try:
        code = remove_code_starts_here_suffix(code)
        code = remove_comments(code)
        code = fix_object_types_from_scene(code, scene_info)
        code = remove_unnecessary_success_checks(code)
        return code
    except Exception as e:
        print(f"[POSTPROCESS] apply_code_postprocessing failed: {e}")
        return code