import json
import logging
import sys
from dataclasses import asdict, dataclass
from itertools import chain
from pathlib import Path
from typing import Any, Optional

import numpy as np
import pyRDDLGym
from tqdm import tqdm

from regawa.model.base_grounded_model import GroundValue
from regawa.rddl.rddl_grounded_model import RDDLGroundedModel
from regawa.rddl.rddl_utils import rddl_ground_to_tuple

logger = logging.getLogger(__name__)

# Types for the raw, string-keyed recording data as read from JSON.
RecordingObs = dict[str, Any]  # grounded fluent name (string form) -> value
RecordingAction = dict[str, int]  # grounded action name (string form) -> value
# NOTE(review): a dict is unhashable and cannot be a dict key, so this alias
# cannot describe real data. combine_data reads entries as string-keyed dicts
# with "state", "actions", "reward" and "round_reward" keys; presumably this
# should be dict[str, Any] — confirm against other users of the alias.
RecordingEntry = dict[RecordingObs, RecordingAction]
Recording = list[RecordingEntry]

# The same data re-keyed by parsed grounding tuples (see ground_to_tuple).
GroundAction = dict[GroundValue, bool]
GroundObs = dict[GroundValue, Any]

# Not used in this file; presumably one index per action component — TODO confirm.
IndexedAction = tuple[int, ...]


class JsonSerializer(json.JSONEncoder):
    """JSON encoder that additionally understands NumPy scalars and arrays.

    The standard encoder rejects NumPy types such as ``np.bool_`` or
    ``np.int32``; this encoder converts them to the matching Python builtins.
    Any other unsupported object still raises ``TypeError``.
    """

    def default(self, o: Any) -> Any:
        """Convert NumPy values to JSON-compatible builtins.

        Raises:
            TypeError: if ``o`` is neither a NumPy scalar nor an array (raised
                by the base-class implementation).
        """
        # np.bool_ is not a subclass of bool, so it needs its own branch and
        # must be checked before the generic numeric families.
        if isinstance(o, np.bool_):
            return bool(o)
        # Cover every integer/float width, not only the 64-bit variants.
        if isinstance(o, np.integer):
            return int(o)
        if isinstance(o, np.floating):
            return float(o)
        if isinstance(o, np.ndarray):
            return o.tolist()
        # Returning None here would silently emit `null`; the base class
        # raises TypeError for unsupported objects instead.
        return super().default(o)


def ground_to_tuple(s: str) -> GroundValue:
    """Parse a grounded RDDL name string into its tuple representation.

    Delegates entirely to ``rddl_ground_to_tuple`` and returns its result
    unchanged.
    """
    parsed = rddl_ground_to_tuple(s)
    return parsed


def convert_state_to_tuples(d: RecordingObs) -> dict[GroundValue, Any]:
    """Re-key a string-keyed observation by parsed grounding tuples.

    Each key is parsed with ``ground_to_tuple``. Values are passed through
    untouched. If two keys parse to the same tuple, the later one wins.
    """
    converted: dict[GroundValue, Any] = {}
    for name, value in d.items():
        converted[ground_to_tuple(name)] = value
    return converted


def convert_actions_to_tuples(d: RecordingAction) -> dict[GroundValue, bool]:
    """Re-key a recorded action dict by parsed grounding tuples.

    An empty (or otherwise falsy) action dict is replaced by the sentinel
    ``{("None", "None"): True}``. Anything else goes through
    ``convert_state_to_tuples``.
    """
    if not d:
        # Represent "no action recorded" with an explicit sentinel entry.
        return {("None", "None"): True}
    return convert_state_to_tuples(d)


@dataclass
class Transistion:
    """One recorded step, flattened into JSON-friendly string keys.

    NOTE(review): the class name is a misspelling of "Transition"; it is kept
    because other modules may import it under this name.
    """

    state: dict[str, Any]  # fluent name -> value; combine_data joins key tuples with "__" and merges in constants
    actions: dict[str, bool]  # action name -> value, keys joined with "__" like `state`
    step: int  # running counter that combine_data keeps across all combined files
    done: bool  # combine_data sets this when the source record contains "round_reward"
    reward: float  # recorded reward; combine_data uses 0 when the record has none


def _get_constants(domain: str, instance: str) -> dict[Any, Any]:
    """Ground ``instance`` of ``domain`` with pyRDDLGym and return its constant fluent values."""
    env = pyRDDLGym.make(domain, instance, enforce_action_constraints=True)  # type: ignore
    grounded_rddl_model = RDDLGroundedModel(env.model)
    return {
        g: grounded_rddl_model.constant_value(g)
        for g in grounded_rddl_model.constant_groundings
    }


def _join_keys(mapping: dict[Any, Any]) -> dict[str, Any]:
    """Flatten tuple keys such as ``("move", "x1")`` into ``"move__x1"``."""
    return {"__".join(key): value for key, value in mapping.items()}


def combine_data(
    domain: str,
    prost_data_path: str,
    instances: list[int],
    *,
    expected_total_files: Optional[int] = 10,
) -> list[dict[str, Any]]:
    """Merge PROST recordings of selected instances into one list of transitions.

    Recordings are looked up as ``<prost_data_path>/<domain>/**/data_*.json``.
    The directory two levels above each file is the instance name. It is a
    1-based number, so directory ``"1"`` corresponds to index ``0`` in
    ``instances``. Directories with non-numeric names are ignored.

    Args:
        domain: RDDL domain name; also the sub-folder of ``prost_data_path``.
        prost_data_path: Root folder of the PROST recordings.
        instances: 0-based instance indices to include.
        expected_total_files: Number of ``data_*.json`` matches the domain
            folder must contain before filtering. ``None`` disables the check.

    Returns:
        One ``asdict(Transistion)`` dict per recorded step, ordered by instance
        number. Each file is presumably a list of episodes, each a list of step
        records (it is flattened one level). Constant fluents of the instance
        are merged into every state. ``step`` counts up across all files.

    Raises:
        ValueError: if the number of matched files differs from
            ``expected_total_files``, or if there is not exactly one data file
            per requested instance.
    """
    domain_folder = Path(prost_data_path) / domain

    # A single "**" already matches any directory depth (including zero).
    all_files = list(domain_folder.glob("**/data_*.json"))
    if expected_total_files is not None and len(all_files) != expected_total_files:
        raise ValueError(
            f"Expected {expected_total_files} data files under {domain_folder}, "
            f"found {len(all_files)}"
        )

    wanted = set(instances)
    selected: list[tuple[int, Path]] = []
    for data_file in all_files:
        instance_dir = data_file.parent.parent.name
        # Skip anything that is not <number>/<...>/data_*.json rather than
        # crashing on int() of an unrelated directory name.
        if not instance_dir.isdecimal() or not data_file.is_file():
            continue
        # Directory names are 1-based while `instances` is 0-based.
        if int(instance_dir) - 1 in wanted:
            selected.append((int(instance_dir), data_file))

    if len(selected) != len(instances):
        raise ValueError(
            f"Expected one data file per instance {instances} for {domain}, "
            f"found {len(selected)}"
        )

    # glob() yields files in filesystem order; sort so `step` numbering is
    # reproducible across machines and runs.
    selected.sort()

    combined_data: list[dict[str, Any]] = []
    step = 0
    for _, data_file in selected:
        try:
            data = json.loads(data_file.read_text(encoding="utf-8"))
        except json.JSONDecodeError:
            logger.error(f"Error in {data_file}")
            logger.error("Skipping this file due to JSON decode error.")
            continue

        # Built only after the file parsed, since creating the env is costly.
        constants = _get_constants(domain, data_file.parent.parent.name)

        for record in chain.from_iterable(data):
            state = {**convert_state_to_tuples(record["state"]), **constants}
            actions = convert_actions_to_tuples(record["actions"])
            entry = Transistion(
                state=_join_keys(state),
                actions=_join_keys(actions),
                step=step,
                done="round_reward" in record,
                reward=record.get("reward", 0),
            )
            combined_data.append(asdict(entry))
            step += 1

    return combined_data


def main():
    """Command-line entry point: print the combined PROST data as JSON on stdout.

    Usage: ``<script> PROST_DATA_PATH DOMAIN INSTANCES``, where ``INSTANCES``
    is a JSON list of 0-based instance indices, e.g. ``"[0, 1, 2]"``.
    Missing or malformed arguments produce a usage message and exit status 2.
    """
    import argparse

    def parse_instances(text: str) -> list[int]:
        """Parse and validate the JSON list of instance indices."""
        try:
            value = json.loads(text)
        except json.JSONDecodeError as err:
            raise argparse.ArgumentTypeError(f"not valid JSON: {err}") from err
        # bool is a subclass of int; reject JSON true/false explicitly.
        if not isinstance(value, list) or not all(
            isinstance(i, int) and not isinstance(i, bool) for i in value
        ):
            raise argparse.ArgumentTypeError(
                "expected a JSON list of integers, e.g. [0, 1]"
            )
        return value

    parser = argparse.ArgumentParser(
        description="Combine PROST recordings into a JSON list of transitions."
    )
    parser.add_argument(
        "prost_data_path", help="root folder containing the PROST recordings"
    )
    parser.add_argument(
        "domain", help="RDDL domain name (sub-folder of prost_data_path)"
    )
    parser.add_argument(
        "instances",
        type=parse_instances,
        help="JSON list of 0-based instance indices, e.g. '[0, 1]'",
    )
    args = parser.parse_args()

    combined_data = combine_data(args.domain, args.prost_data_path, args.instances)
    print(json.dumps(combined_data, cls=JsonSerializer))


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
