import ast
from pddl.core import Predicate
from pddl.logic.base import Variable
import os
import re
from PIL import Image, ImageDraw, ImageFont
from collections import defaultdict
import json
import argparse
import subprocess
from pathlib import Path
from dotenv import load_dotenv
from pddl.logic.base import Not

from llm_utils import Chat, ImageMessageContent, TextGenApi, TextMessageContent, UserMessage
from tp_lodge.task_planning.models.pddl.pddl_domain import PDDLDomain
from tp_lodge.utils.pddl_domain_syntax import parse_formula
from tp_lodge.utils.pddl_utils import (
    get_list_of_predicates,
)
from tqdm import tqdm


def find_trajectories_covering_all_transitions(transitions):
    """
    Greedily build trajectories that together traverse every transition.

    The procedure is deterministic: the same input always yields the same
    trajectories in the same order. It is a greedy heuristic, so the number of
    trajectories is kept small but is not guaranteed to be minimal.

    Args:
        transitions: dict mapping a prev_state to a list of its next_states.
            Duplicate next_states are treated as a single transition.
    Returns:
        List of trajectories. Each trajectory is a list of state hashes in which
        every consecutive pair is a transition taken from ``transitions``.
    """
    # Edges not yet traversed, keyed by source state. Sources with an empty
    # successor list are kept (with an empty set) exactly as given.
    pending = {src: set(dsts) for src, dsts in transitions.items()}

    # A state that no transition points to is a natural place to start a walk.
    # (all states) - (targets) reduces to (sources) - (targets). Sorting keeps
    # the order reproducible.
    targets = set().union(*transitions.values())
    roots = sorted(set(transitions) - targets)

    def out_degree(state):
        return len(pending.get(state, ()))

    def walk(origin):
        """Follow pending edges from ``origin``, consuming every edge taken."""
        path = [origin]
        here = origin
        while pending.get(here):
            # Prefer the successor with the most pending outgoing edges and
            # break ties with the smallest state hash.
            step = min(pending[here], key=lambda s: (-out_degree(s), s))
            pending[here].discard(step)
            if not pending[here]:
                del pending[here]
            path.append(step)
            here = step
        return path

    result = []

    # Phase 1: walks from root states, in sorted order.
    for root in roots:
        if root not in pending:
            continue
        path = walk(root)
        if len(path) > 1:
            result.append(path)

    # Phase 2: mop up leftover edges (e.g. cycles), always starting from the
    # source with the most pending edges, then the smallest hash.
    while pending:
        origin = min(pending, key=lambda s: (-out_degree(s), s))
        path = walk(origin)
        if len(path) == 1:
            # Nothing to traverse from here (empty successor set). Drop the
            # entry so the loop terminates.
            del pending[origin]
        else:
            result.append(path)

    return result


def _get_demonstrations(out_dir: Path):
    """
    Build image/skill demonstrations from the replay buffer stored under ``out_dir``.

    Reads ``out_dir / "reply_buffer/states.json"``, whose top-level ``"states"``
    entry maps a state hash to a record with ``"prev_state_hash"`` (None for
    states without a predecessor) and ``"executed_skill"``. The screenshot of
    each state is read from ``out_dir / "reply_buffer/images/state_<hash>.png"``.

    The transitions (prev_state_hash -> state hash) are covered with
    ``find_trajectories_covering_all_transitions``. Each trajectory becomes one
    demonstration.

    Args:
        out_dir: Run directory that contains the ``reply_buffer`` folder.

    Returns:
        List of demonstrations. Each is an alternating list
        ``[image_0, skill_1, image_1, ..., skill_n, image_n]``, where ``skill_k``
        is the ``"executed_skill"`` recorded on the k-th state of the trajectory.
    """
    imgs_dir = out_dir / "reply_buffer/images"
    reply_buffer = json.loads((out_dir / "reply_buffer/states.json").read_text())
    states = reply_buffer["states"]

    def _load_image(state_hash):
        # Image.open is lazy and keeps the file open until the pixels are read.
        # Reading them inside a context manager closes the file right away
        # instead of holding one descriptor per state of every demonstration.
        with Image.open(imgs_dir / f"state_{state_hash}.png") as img:
            return img.copy()  # copy() loads the pixel data before the file closes

    # Group successors by their predecessor state.
    transitions = defaultdict(list)
    for state_hash, state in states.items():
        prev_hash = state["prev_state_hash"]
        if prev_hash is not None:
            transitions[prev_hash].append(state_hash)

    demonstrations = []
    for trajectory in find_trajectories_covering_all_transitions(transitions):
        first_hash, *next_hashes = trajectory
        demonstration = [_load_image(first_hash)]
        for n_state_hash in next_hashes:
            demonstration.append(states[n_state_hash]["executed_skill"])
            demonstration.append(_load_image(n_state_hash))
        demonstrations.append(demonstration)

    return demonstrations


def _demos_dir(args):
    """
    Return the benchmark root directory for ``args.domain``.

    Domains starting with "fb-" map to ``furniturebench``; all others map to
    ``ipc``. Both live two directory levels above this file.
    """
    benchmark = "furniturebench" if args.domain.startswith("fb-") else "ipc"
    repo_dir = Path(__file__).parent.parent
    return repo_dir.joinpath(benchmark)


def _data_source_dir(args):
    """
    Return the directory of the source run whose outputs ``_load_data`` reads.

    Layout: ``<demos dir>/results/<domain>/<model dir>/hi-tamp/<variant>/sample-0``.
    ``<variant>`` is "planning-with-pred-learning" for "fb-" domains and
    "iclr-w-dk-w-ai-shared" otherwise. ``<model dir>`` is the ``model_dir`` of
    the first connection of ``TextGenApi.default(args.llm)``.
    """
    if args.domain.startswith("fb-"):
        variant = "planning-with-pred-learning"
    else:
        variant = "iclr-w-dk-w-ai-shared"
    first_connection = TextGenApi.default(args.llm).connections.connections[0]
    relative = f"results/{args.domain}/{first_connection.model_dir}/hi-tamp/{variant}/sample-0"
    return _demos_dir(args) / relative


def add_text_heading(
    image,
    text,
    font_size=40,
    font_color="black",
    position="top-center",
    padding=20,
    background_color="white",
    font_path=None,
):
    """
    Return a copy of ``image`` with ``text`` drawn on it as a heading.

    Args:
        image: PIL Image object; it is not modified.
        text: Text to draw.
        font_size: Font size (default: 40). It is not applied when no TrueType
            font can be loaded and Pillow's default font is used instead.
        font_color: Text fill color (default: 'black').
        position: 'top-center', 'top-left', 'top-right', 'bottom-center',
            'bottom-left', 'bottom-right', or an explicit (x, y) tuple.
            Any other value falls back to 'top-center'.
        padding: Distance in pixels from the image edge for the named
            positions (default: 20).
        background_color: Color of a box drawn behind the text, or a falsy
            value for no box (default: 'white'). A color whose string form
            contains "white" is drawn as a translucent white box
            (alpha 100/255, about 39% opaque). Other colors are drawn opaque.
        font_path: Path to a TrueType font file. If it is None or missing, a
            few common system fonts are tried, then Pillow's default font.

    Returns:
        A new RGB PIL Image with the heading drawn on it.
    """
    # Work on a copy so the caller's image is left untouched.
    img_copy = image.copy()
    draw = ImageDraw.Draw(img_copy)

    # Use the explicit font if it exists; otherwise try common system fonts.
    if font_path and os.path.exists(font_path):
        font_candidates = [font_path]
    else:
        font_candidates = [
            path
            for path in (
                "/System/Library/Fonts/Arial.ttf",  # macOS
                "/Windows/Fonts/arial.ttf",  # Windows
                "/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf",  # Linux
            )
            if os.path.exists(path)
        ]
    font = None
    for candidate in font_candidates:
        try:
            font = ImageFont.truetype(candidate, font_size)
            break
        except (OSError, ValueError, ImportError):
            # OSError: unreadable/unsupported font file; ValueError: invalid
            # size; ImportError: Pillow built without FreeType. Try the next one.
            continue
    if font is None:
        font = ImageFont.load_default()

    # Size of the rendered text, used for alignment.
    bbox = draw.textbbox((0, 0), text, font=font)
    text_width = bbox[2] - bbox[0]
    text_height = bbox[3] - bbox[1]

    img_width, img_height = img_copy.size
    if isinstance(position, tuple):
        x, y = position
    else:
        center_x = (img_width - text_width) // 2
        right_x = img_width - text_width - padding
        bottom_y = img_height - text_height - padding
        anchors = {
            "top-center": (center_x, padding),
            "top-left": (padding, padding),
            "top-right": (right_x, padding),
            "bottom-center": (center_x, bottom_y),
            "bottom-left": (padding, bottom_y),
            "bottom-right": (right_x, bottom_y),
        }
        if not isinstance(position, str) or position not in anchors:
            position = "top-center"  # unknown positions fall back to top-center
        x, y = anchors[position]

    if background_color:
        bg_padding = 10
        # Measure at the real draw origin so the font's bearing offsets (e.g.
        # the gap above the glyphs) are included and the box encloses the text.
        left, top, right, bottom = draw.textbbox((x, y), text, font=font)
        rect_coords = [left - bg_padding, top - bg_padding, right + bg_padding, bottom + bg_padding]

        if "white" in str(background_color).lower():
            # Translucent white box: draw on a transparent overlay and
            # alpha-composite it, since drawing directly would be opaque.
            overlay = Image.new("RGBA", img_copy.size, (0, 0, 0, 0))
            ImageDraw.Draw(overlay).rectangle(rect_coords, fill=(255, 255, 255, 100))
            if img_copy.mode != "RGBA":
                img_copy = img_copy.convert("RGBA")
            img_copy = Image.alpha_composite(img_copy, overlay)
            # Compositing returns a new image, so the text must be drawn on it.
            draw = ImageDraw.Draw(img_copy)
        else:
            draw.rectangle(rect_coords, fill=background_color)

    draw.text((x, y), text, fill=font_color, font=font)

    return img_copy.convert("RGB")


def _parse_ground_atoms(response_text, objs, skill_names):
    """
    Extract ground atoms written as ``name(arg1, arg2, ...)`` from free-form text.

    Args:
        response_text: Raw text returned by the model.
        objs: Allowed object names. Atoms with any other argument are skipped
            (and reported with a print).
        skill_names: Names of executed skills. Matches with these names are
            skill invocations rather than predicates and are skipped.

    Returns:
        Distinct (name, [arg, ...]) tuples in order of first appearance.
    """
    allowed_objs = set(objs)
    excluded_names = set(skill_names)
    # dict.fromkeys de-duplicates while keeping first-appearance order. Iterating
    # a set of strings would give a hash-randomized, run-dependent order.
    matches = dict.fromkeys(re.findall(r"([a-zA-Z_][a-zA-Z0-9_]*\([^)]*\))", response_text))

    parsed_ground_atoms = []
    seen = set()
    for match in matches:
        # Split on the first "(" only. The pattern admits a nested "(" in the
        # argument text (e.g. "on(a, near(b)"), and a plain split() then
        # produced three parts and crashed the unpacking.
        f_name, _, raw_args = match.partition("(")
        raw_args = raw_args.rstrip(")")
        if f_name in excluded_names:
            continue
        if raw_args.strip():
            f_args = [arg.replace("'", "").replace('"', "").strip() for arg in raw_args.split(",")]
        else:
            f_args = []
        if any(f_arg not in allowed_objs for f_arg in f_args):
            print(f"Skipping function: {f_name} with invalid args: {f_args}")
            continue
        # Differently spaced spellings of the same atom collapse to one entry.
        key = (f_name, tuple(f_args))
        if key in seen:
            continue
        seen.add(key)
        parsed_ground_atoms.append((f_name, f_args))
    return parsed_ground_atoms


def _propose_atoms(demo, textgen_api: TextGenApi, objs: list[str], out_dir: Path):
    """
    Ask the LLM to propose ground atoms that describe one demonstration.

    The result is cached in ``out_dir / "proposed_atoms.json"``. If that file
    exists, it is returned without querying the model. On a fresh run, the
    chat transcript and raw response are also written to ``chat.txt`` and
    ``response.txt`` in ``out_dir``.

    Args:
        demo: Alternating list ``[image_0, skill_1, image_1, ..., skill_n, image_n]``.
            Images are PIL images; skills are strings such as ``"pick(a, b)"``.
        textgen_api: Client used to query the model.
        objs: Object names the proposed atoms may take as arguments.
        out_dir: Directory for the cache and transcripts (created if needed).

    Returns:
        List of (predicate_name, [arg, ...]) tuples.
    """
    propose_atoms_file = out_dir / "proposed_atoms.json"
    if propose_atoms_file.is_file():
        # JSON has no tuples; restore the (name, args) shape of a fresh run.
        return [(name, args) for name, args in json.loads(propose_atoms_file.read_text())]

    skills = demo[1::2]
    images = demo[0::2]
    assert len(images) == len(skills) + 1

    skill_names = [skill.split("(")[0] for skill in skills]

    # Number the frames so the model can refer to them.
    stamped_images = [
        add_text_heading(image, text=f"Timestamp {i}", background_color="white") for i, image in enumerate(images)
    ]

    objs_str = ", ".join(objs)
    skills_str = "\n".join(skills)

    prompt = f"""
    You are a robotic vision system whose job is to output a structured set of predicates useful for describing important concepts in the following demonstration of a task. You will be provided with a list of actions used during the task, as well as images of states before and after every action execution. Please provide predicates in terms of the following objects: {objs_str}. For each predicate, output it in the following format: predicate name(obj1, obj2, obj3...). Start by generating predicates that change before and after each action. After this, generate any other predicates that perhaps do not change but are still important to describing the demonstration shown. For each predicate you generate, also generate some predicates that are synonyms and antonyms so that any predicate that is even tangentially relevant to the demonstrations is generated.
    Skills executed in trajectory:
    {skills_str}
    """

    chat = Chat(
        messages=[
            UserMessage(
                [
                    TextMessageContent(text=prompt),
                    *[ImageMessageContent(image=i) for i in stamped_images],
                ]
            )
        ]
    )

    response = textgen_api.do_call(chat)
    chat.add_message(response)
    response_text = response.content[0].text

    out_dir.mkdir(parents=True, exist_ok=True)
    (out_dir / "chat.txt").write_text(str(chat))
    (out_dir / "response.txt").write_text(str(response_text))

    parsed_ground_atoms = _parse_ground_atoms(response_text, objs, skill_names)

    propose_atoms_file.write_text(json.dumps(parsed_ground_atoms, indent=2))

    return parsed_ground_atoms


def _lift_atoms(atoms: list[tuple[str, list[str]]], out_dir: Path, objs: dict[str, str]):
    """
    Convert (name, args) ground atoms into pddl ``Predicate`` objects.

    Each argument becomes a ``Variable`` with the same name, typed with
    ``objs[arg]``. The string form of every predicate is written, one per line,
    to ``out_dir / "lifted_atoms.txt"``.

    Args:
        atoms: (predicate_name, [object_name, ...]) pairs.
        out_dir: Existing directory that receives ``lifted_atoms.txt``.
        objs: Mapping from object name to type name.

    Returns:
        The predicates, in the same order as ``atoms``.
    """

    def to_predicate(atom):
        name, arg_names = atom
        variables = (Variable(arg, [objs[arg]]) for arg in arg_names)
        return Predicate(name, *variables)

    predicates = list(map(to_predicate, atoms))
    listing = "\n".join(map(str, predicates))
    (out_dir / "lifted_atoms.txt").write_text(listing)
    return predicates


def _load_data(exp_dir: Path, args):
    """
    Prepare handover inputs (predicates, action predicates, episode) in ``exp_dir``.

    Work in progress. Only the "fb-lamp" domain is accepted, enforced with
    asserts (NOTE(review): asserts are stripped under ``python -O``). For every
    demonstration returned by ``_get_demonstrations`` it:
      1. asks the LLM to propose ground atoms (``_propose_atoms``);
      2. lifts them (``_lift_atoms``);
      3. keeps the first lifted atom per predicate name in ``all_lifted_atoms``.

    NOTE(review): the function then raises NotImplementedError unconditionally.
    Everything after that statement is unreachable, so ``preds.json``,
    ``action-preds.json`` and ``episode.json`` are never written.

    Args:
        exp_dir: Experiment output directory. Per-demo outputs go to
            ``exp_dir / "demos/demo-<i>"``.
        args: Parsed CLI arguments; ``args.domain`` and ``args.llm`` are read.

    Raises:
        NotImplementedError: always, after the atoms have been proposed and lifted.
    """
    is_fb = args.domain.startswith("fb-")
    assert is_fb
    assert args.domain == "fb-lamp"

    textgen_api = TextGenApi.default(args.llm)
    demos_dir = _demos_dir(args)
    data_dir = demos_dir / "data" / args.domain
    data_source_dir = _data_source_dir(args)

    # Top-level statements of function_stubs.py. The unreachable code below
    # reads .name and .args.args from each one, i.e. it assumes they are all
    # function definitions — TODO confirm.
    function_stubs = ast.parse((data_dir / "function_stubs.py").read_text()).body

    demos = _get_demonstrations(data_source_dir)
    # Object name -> type name; the types are used for the lifted variables.
    objs = {"lamp_base": "part", "lamp_bulb": "part", "lamp_hood": "part", "table": "table", "robot_arm": "robot"}

    # Predicate name -> first lifted atom seen with that name, across all demos.
    all_lifted_atoms = {}
    for i, demo in enumerate(tqdm(demos)):
        demo_out_dir = exp_dir / f"demos/demo-{i}"
        atoms = _propose_atoms(demo=demo, textgen_api=textgen_api, objs=list(objs.keys()), out_dir=demo_out_dir)
        lifted_atoms = _lift_atoms(atoms, out_dir=demo_out_dir, objs=objs)
        for lifted_atom in lifted_atoms:
            if lifted_atom.name not in all_lifted_atoms:
                all_lifted_atoms[lifted_atom.name] = lifted_atom

    # NOTE(review): unfinished. Every statement below this raise is unreachable.
    raise NotImplementedError()
    # Every action argument is typed "part" (hard-coded).
    action_predicates = {
        expr.name: {"name": expr.name, "arity": len(expr.args.args), "var_types": ["part" for _ in expr.args.args]}
        for expr in function_stubs
    }

    gen_domain = PDDLDomain.loads((data_source_dir / "domain.json").read_text())
    handover_dir = exp_dir
    # list(t.type_tags)[0] keeps only one of a term's type tags. If a term had
    # several tags, which one is kept would depend on iteration order — confirm.
    pddl_preds = {
        p.definition.name: {
            "name": p.definition.name,
            "arity": len(p.definition.terms),
            "var_types": [list(t.type_tags)[0] for t in p.definition.terms],
        }
        for p in gen_domain.predicates
    }

    def parse_preds(ps):
        # Keep only positive literals (Not(...) entries are dropped).
        # The inner comprehension re-binds `p` to each term; `p.terms` is
        # evaluated with the outer predicate first, so this works, but it reads
        # confusingly.
        return [
            {"predicate_name": p.name, "variables": [p.name for p in p.terms]} for p in ps if not isinstance(p, Not)
        ]

    def parse_action(a):
        # Parse a call expression such as "pick(0, 1)" (presumably literal
        # arguments — confirm against the data).
        a = ast.parse(a).body[0].value

        predicate = a.func.id
        # NOTE(review): `.n` is a deprecated alias of ast.Constant.value;
        # prefer `arg.value`.
        variables = [arg.n for arg in a.args]

        return {"action_pred_name": predicate, "variables": variables}

    # NOTE(review): `transitions` is not defined in this function, so this
    # would raise NameError if it were ever reached.
    episode = [
        # state, action, next-state, None
        (parse_preds(t["state"]), parse_action(t["action"]), parse_preds(t["next_state"]), None)
        for t in transitions
    ]

    (handover_dir / "preds.json").write_text(json.dumps(pddl_preds, indent=2))
    (handover_dir / "action-preds.json").write_text(json.dumps(action_predicates, indent=2))
    (handover_dir / "episode.json").write_text(json.dumps(episode, indent=2))


def _out_dir(args):
    """
    Return the directory where this script writes its "pix2pred" results.

    Layout: ``<demos dir>/results/<domain>/<model dir>/pix2pred``.
    ``<demos dir>`` comes from ``_demos_dir(args)``. ``<model dir>`` is the
    ``model_dir`` of the first connection of ``TextGenApi.default(args.llm)``.
    """
    # Reuse _demos_dir instead of re-deriving the benchmark root here, so the
    # two paths cannot drift apart.
    demos_root_dir = _demos_dir(args)
    connection = TextGenApi.default(args.llm).connections.connections[0]
    return demos_root_dir / "results" / args.domain / connection.model_dir / "pix2pred"


def _cluster_dir():
    """
    Return the path of the ``3rdparty/LOFT_IROS_2021`` checkout.

    The path is resolved relative to this file: three directory levels up,
    then ``3rdparty/LOFT_IROS_2021``.
    """
    here = Path(__file__)
    repo_root = here.parent.parent.parent
    return repo_root / "3rdparty" / "LOFT_IROS_2021"


def run_cluster(args):
    """
    Prepare the experiment inputs and run the baseline's ``learn-nrst.py`` on them.

    Steps:
      1. Create the output directory from ``_out_dir(args)``.
      2. Fill it via ``_load_data``.
      3. Run ``learn-nrst.py`` with the baseline's own interpreter
         (``.pixi/envs/default/bin/python`` under ``_cluster_dir()``), with
         the baseline directory as the working directory.

    NOTE(review): ``_load_data`` currently raises NotImplementedError before
    returning, so step 3 is not reached yet.

    Raises:
        RuntimeError: if ``learn-nrst.py`` exits with a non-zero status.
    """
    exp_dir = _out_dir(args)
    exp_dir.mkdir(parents=True, exist_ok=True)

    _load_data(exp_dir, args)

    baseline_root = _cluster_dir()
    python_path = baseline_root / ".pixi/envs/default/bin/python"

    # Output is not captured: the child inherits this process's stdout and
    # stderr, so its progress and errors already appear on the console.
    # response.stdout / response.stderr are therefore always None.
    response = subprocess.run(
        [
            python_path,
            baseline_root / "learn-nrst.py",
            "--out_dir",
            exp_dir,
        ],
        cwd=baseline_root,
    )
    if response.returncode != 0:
        print("Error running learn-nrst.py:")
        print(response.returncode)
        raise RuntimeError("Failed to run learn-nrst.py")


def main(args):
    """Entry point: run ``run_cluster`` with the parsed command-line ``args``."""
    return run_cluster(args)


if __name__ == "__main__":
    argparser = argparse.ArgumentParser()
    # No default: the argument is required, so a default would never be used.
    # _load_data currently only supports "fb-lamp".
    argparser.add_argument("--domain", type=str, required=True, help='Domain name, e.g. "fb-lamp".')
    argparser.add_argument("--llm", type=str, required=True, help="Model name passed to TextGenApi.default().")
    # Populate environment variables from a .env file (if present) before running.
    load_dotenv()
    main(argparser.parse_args())
