import json
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Dict, Any, Union

from PIL import Image
from pddl.core import Predicate
from tp_lodge.utils.pddl_domain_syntax import parse_predicate
from python_utils.data_utils import dict_similar
import logging

from state_estimation.se_variable import SEVariable

logger = logging.getLogger(__name__)


class VariableParser:
    """Interface for converting ``SEVariable`` objects to and from plain dictionaries.

    Every method here raises ``NotImplementedError``; concrete parsers are
    expected to override all of them.
    """

    def get_printable_for_llm(self, state: SEVariable, precision: int = 2) -> SEVariable:
        """Return a variant of ``state`` intended for presentation to an LLM.

        Args:
            state: The variable to convert.
            precision: Presumably the number of decimal places kept for numeric
                values -- TODO confirm against the concrete implementations.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def from_dict(self, data: Dict) -> SEVariable:
        """Build an ``SEVariable`` from its dictionary representation.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def to_dict(self, state: SEVariable) -> Dict:
        """Convert an ``SEVariable`` into a dictionary.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError


@dataclass
class State:
    """Represents a single state in the reply buffer."""

    # Variables describing the state.
    variables: List[SEVariable]
    # Environment hashes that have been mapped onto this state.
    env_hashes: List[str]
    # Buffer hash of the state this one was reached from, if any.
    prev_state_hash: Optional[str]
    # Skill that was executed to reach this state, if any.
    executed_skill: Optional[str]
    # Buffer hash of an earlier state considered similar to this one, if any.
    similar_state: Optional[str]
    # Predicate evaluations; None means no predicates have been attached.
    predicates: Optional[dict[Predicate, Optional[bool]]] = None

    def to_dict(self, var_parser: VariableParser) -> Dict[str, Any]:
        """Convert state to dictionary for JSON serialization.

        Args:
            var_parser: Parser used to serialize each variable.

        Returns:
            A dictionary with the keys ``variables``, ``predicates``,
            ``similar_state``, ``env_hashes``, ``prev_state_hash`` and
            ``executed_skill``. Predicates are keyed by ``str(predicate)``.
        """
        # Distinguish "no predicates attached" (None) from "attached but empty"
        # ({}), so that a save/load round trip reproduces the in-memory value.
        serialized_predicates = None
        if self.predicates is not None:
            serialized_predicates = {str(pred): val for pred, val in self.predicates.items()}

        return {
            "variables": [var_parser.to_dict(var) for var in self.variables],
            "predicates": serialized_predicates,
            "similar_state": self.similar_state,
            "env_hashes": self.env_hashes,
            "prev_state_hash": self.prev_state_hash,
            "executed_skill": self.executed_skill,
        }

    @staticmethod
    def from_dict(data: Dict[str, Any], var_parser: VariableParser) -> "State":
        """Create state from dictionary loaded from JSON.

        Args:
            data: Dictionary in the layout produced by ``to_dict``.
            var_parser: Parser used to rebuild each variable.

        Returns:
            The reconstructed state.

        Raises:
            KeyError: If the required keys ``variables`` or ``env_hashes`` are missing.
        """
        variables = [var_parser.from_dict(var_data) for var_data in data["variables"]]

        raw_predicates = data.get("predicates")
        predicates = None
        if raw_predicates is not None:
            predicates = {
                parse_predicate(pred_data, only_variables=False): val for pred_data, val in raw_predicates.items()
            }

        return State(
            variables=variables,
            predicates=predicates,
            # Optional fields are read leniently, consistently with each other.
            similar_state=data.get("similar_state"),
            # Copy so the state does not alias the list inside the loaded JSON data.
            env_hashes=list(data["env_hashes"]),
            prev_state_hash=data.get("prev_state_hash"),
            executed_skill=data.get("executed_skill"),
        )


class ReplyBuffer:
    """
    A buffer that holds states containing variables, images, and optionally predicates.
    All data is saved to a directory in a human-readable format.

    State metadata lives in ``<buffer_dir>/states.json``; each state's image is
    stored as ``<buffer_dir>/images/state_<hash>.png``.
    """

    def __init__(self, buffer_dir: Union[str, Path], var_parser: VariableParser):
        """
        Initialize the reply buffer.

        Creates the buffer and image directories if needed and loads any
        previously saved data.

        Args:
            buffer_dir: Directory where the buffer data will be stored
            var_parser: Parser used to (de)serialize and compare state variables
        """
        self.buffer_dir = Path(buffer_dir)

        self.states: Dict[str, State] = {}
        self.evaluated_predicates: List[str] = []

        self.var_parser = var_parser

        # Create directory if it doesn't exist
        self.buffer_dir.mkdir(parents=True, exist_ok=True)
        self.images_dir = self.buffer_dir / "images"
        self.images_dir.mkdir(exist_ok=True)

        # Load existing data if available
        self._load_from_disk()

    def _img_filename_for_state(self, hash: str) -> str:
        """Return the image file name used for the state with the given hash."""
        return f"state_{hash}.png"

    def _find_similar_state_hash(
        self, variables: List[SEVariable], executed_skill: Optional[str]
    ) -> Optional[str]:
        """Return the hash of the first stored state similar to ``variables``, or None.

        Only states reached through the same ``executed_skill`` are candidates.
        Variables are compared through their LLM-printable dictionary form with
        ``dict_similar`` (tolerance 0.03, ignoring the ``orientation`` key).
        """

        def printable_dict(var: SEVariable) -> Dict:
            return self.var_parser.to_dict(self.var_parser.get_printable_for_llm(var))

        # Built lazily, and only once, the first time a candidate passes the skill filter.
        curr_vars_dict = None
        for candidate_hash, candidate in self.states.items():
            if candidate.executed_skill != executed_skill:
                # HACK: needed to distinguish similar states
                continue
            if curr_vars_dict is None:
                curr_vars_dict = [printable_dict(var) for var in variables]
            candidate_vars_dict = [printable_dict(var) for var in candidate.variables]
            if dict_similar(candidate_vars_dict, curr_vars_dict, tolerance=0.03, ignore_keys=["orientation"]):
                return candidate_hash
        return None

    def add_state(
        self,
        hash: str,
        variables: List[SEVariable],
        image: Image.Image,
        env_hash: str,
        prev_env_hash: Optional[str] = None,
        executed_skill: Optional[str] = None,
    ) -> None:
        """
        Add a new state to the buffer.

        If a state is already stored under ``hash``, or a stored state already
        lists ``env_hash``, no new state is created; ``env_hash`` is merely
        recorded on the existing state. Otherwise the image is written to the
        images directory and a new state is stored under ``hash``, linked to a
        similar earlier state when one is found. The buffer is saved to disk in
        either case.

        Args:
            hash: Key under which the state is stored in the buffer
            variables: List of SEVariable objects representing the state
            image: PIL Image of the environment
            env_hash: Environment hash to associate with the state
            prev_env_hash: Environment hash of the preceding state, if any
            executed_skill: Skill that was executed to reach this state, if any

        Raises:
            RuntimeError: If ``prev_env_hash`` is given but no stored state lists it.
        """
        prev_state_hash = self.get_state_hash_by_env_hash(prev_env_hash) if prev_env_hash is not None else None

        existing_state = self.states.get(hash)
        if existing_state is None:
            try:
                existing_state = self.states.get(self.get_state_hash_by_env_hash(env_hash))
            except RuntimeError:
                # env_hash has not been seen before -> this is a new state.
                pass

        if existing_state is not None:
            # add env_hash to existing state
            if env_hash not in existing_state.env_hashes:
                existing_state.env_hashes.append(env_hash)
        else:
            # we really dont want duplicate states -> check that they are significantly different
            similar_state_hash = self._find_similar_state_hash(variables, executed_skill)

            # Generate unique filename for the image
            image_path = self.images_dir / self._img_filename_for_state(hash)
            image.save(image_path)

            # Create state object and add it to the buffer
            self.states[hash] = State(
                variables=variables.copy(),
                env_hashes=[env_hash],
                prev_state_hash=prev_state_hash,
                similar_state=similar_state_hash,
                executed_skill=executed_skill,
            )

        # Save to disk
        self._save_to_disk()

    def set_predicates(self, state_hash: str, predicates: dict[Predicate, Optional[bool]]):
        """Attach predicate evaluations to a stored state and save the buffer.

        Raises:
            KeyError: If ``state_hash`` is not in the buffer.
        """
        self.states[state_hash].predicates = predicates

        self._save_to_disk()

    def set_evaluated_predicates(self, evaluated_predicates: List[str]):
        """Merge ``evaluated_predicates`` into the stored list and save the buffer.

        Existing entries are kept; new entries are appended in the order given
        and duplicates are dropped, so the saved order is stable across runs.
        """
        merged = dict.fromkeys(self.evaluated_predicates)
        merged.update(dict.fromkeys(evaluated_predicates))
        self.evaluated_predicates = list(merged)
        self._save_to_disk()

    def get_state_hash_by_env_hash(self, env_hash: str) -> str:
        """Return the hash of the first stored state that lists ``env_hash``.

        Raises:
            RuntimeError: If no stored state lists ``env_hash``.
        """
        for s_hash, state in self.states.items():
            if env_hash in state.env_hashes:
                return s_hash
        raise RuntimeError("State not found")

    def get_state(self, state_hash: str) -> State:
        """Return the state stored under ``state_hash`` (KeyError if absent)."""
        return self.states[state_hash]

    def has_state(self, state_hash: str) -> bool:
        """Return whether a state is stored under ``state_hash``."""
        return state_hash in self.states

    def get_image(self, state_hash: str) -> Image.Image:
        """
        Load and return the image for a specific state.

        Args:
            state_hash: Hash of the state to retrieve

        Returns:
            PIL Image opened from the state's image file

        Raises:
            AssertionError: If the state is not in the buffer or its image file
                does not exist (these checks are skipped when Python runs with -O)
        """
        assert state_hash in self.states, f"State {state_hash} not found in buffer"
        image_path = self.images_dir / self._img_filename_for_state(state_hash)
        assert image_path.exists(), f"Image not found for state {state_hash}"
        return Image.open(image_path)

    def get_all_states(self) -> Dict[str, State]:
        """Get all states in the buffer (a shallow copy of the internal mapping)."""
        return self.states.copy()

    def get_similar_state(self, state: State) -> tuple[str, State]:
        """Follow the ``similar_state`` links from ``state`` to the end of the chain.

        Returns:
            The hash and state of the last state in the chain, i.e. the first
            one reached whose own ``similar_state`` is None.

        Raises:
            AssertionError: If ``state.similar_state`` is None.
            KeyError: If a link points to a hash that is not in the buffer.
        """
        assert state.similar_state is not None
        sim_hash = state.similar_state
        sim_state = self.states[sim_hash]
        while sim_state.similar_state is not None:
            sim_hash = sim_state.similar_state
            sim_state = self.states[sim_hash]
        return sim_hash, sim_state

    def clear(self) -> None:
        """Clear all states from the buffer and remove saved files.

        Only states and their images are removed; ``evaluated_predicates`` is
        left untouched.
        """
        # Remove all image files
        for state_hash in self.states:
            image_path = self.images_dir / self._img_filename_for_state(state_hash)
            image_path.unlink(missing_ok=True)

        # Clear the buffer
        self.states.clear()

        # Update saved data
        self._save_to_disk()

    def size(self) -> int:
        """Get the number of states in the buffer."""
        return len(self)

    def _save_to_disk(self) -> None:
        """Save the buffer state to disk in JSON format."""
        buffer_data = {
            "states": {state_hash: state.to_dict(self.var_parser) for state_hash, state in self.states.items()},
            "evaluated_predicates": self.evaluated_predicates,
        }

        states_file = self.buffer_dir / "states.json"

        # Serialize fully before touching the file system: a serialization error
        # must not leave a truncated states.json behind.
        serialized = json.dumps(buffer_data, indent=2, ensure_ascii=False)

        # Write to a sibling temp file and swap it in, so an interrupted write
        # cannot corrupt the previously saved data.
        tmp_file = states_file.with_name(states_file.name + ".tmp")
        with open(tmp_file, "w", encoding="utf-8") as f:
            f.write(serialized)
        tmp_file.replace(states_file)

    def _load_from_disk(self) -> None:
        """Load the buffer state from disk.

        Does nothing if no saved file exists. If the file cannot be decoded or
        lacks required keys, a warning is logged and the buffer starts empty.
        """
        states_file = self.buffer_dir / "states.json"

        if not states_file.exists():
            return

        try:
            with open(states_file, "r", encoding="utf-8") as f:
                buffer_data = json.load(f)

            loaded_states = {
                state_hash: State.from_dict(state_data, self.var_parser)
                for state_hash, state_data in buffer_data["states"].items()
            }
            # A missing or null list must not discard the states that did load.
            loaded_predicates = list(buffer_data.get("evaluated_predicates") or [])
        except (json.JSONDecodeError, KeyError, ValueError) as e:
            logger.warning("Could not load buffer data from %s: %s", states_file, e)
            self.states = {}
            self.evaluated_predicates = []
            return

        self.states = loaded_states
        self.evaluated_predicates = loaded_predicates

    def __len__(self) -> int:
        """Return the number of states in the buffer."""
        return len(self.states)
