from absl import app, flags
import csv
import numpy as np
import gym
import d4rl

from wsrl.envs.d4rl_dataset import get_d4rl_dataset
from wsrl.envs.adroit_binary_dataset import get_hand_dataset_with_mc_calculation
from wsrl.envs.env_common import get_env_type


FLAGS = flags.FLAGS


flags.DEFINE_list(
    "envs",
    None,
    "Comma-separated list of environments to evaluate. If unset, uses --env",
)
flags.DEFINE_string("env", None, "Single environment to evaluate if --envs unset")

# Defaults to satisfy modules that expect these global flags (e.g., Adroit loader)
flags.DEFINE_float("reward_scale", 1.0, "Reward scale (default for loaders)")
flags.DEFINE_float("reward_bias", 0.0, "Reward bias (default for loaders)")

flags.DEFINE_float(
    "clip_action",
    0.99999,
    "Clip actions to be between [-n, n]. This is needed for tanh policies.",
)
# Shared by the AntMaze and Adroit-binary splitters: a transition ends a
# trajectory when ||obs[i+1] - next_obs[i]|| exceeds this value. A negative
# tolerance would mark every transition as a boundary, so it is rejected.
flags.DEFINE_float(
    "antmaze_tol",
    1e-6,
    "Tolerance for detecting concatenation boundaries (obs[i+1] vs next_obs[i]) "
    "in AntMaze and Adroit-binary datasets.",
    lower_bound=0.0,
)
# Forwarded as `gamma` to the Adroit loader; only values in [0, 1] are valid
# discount factors.
flags.DEFINE_float(
    "discount",
    0.99,
    "Discount factor used when loading Adroit datasets (for MC returns).",
    lower_bound=0.0,
    upper_bound=1.0,
)
flags.DEFINE_string(
    "output_csv",
    "dataset_eval.csv",
    "Path to write CSV summary: env_name, transitions, trajectories, mean, std",
)

# Environments evaluated when neither --envs nor --env is supplied.
# Order matters only for the printed report / CSV row order.
DEFAULT_ENVS = [
    # Kitchen
    *(f"kitchen-{task}-v0" for task in ("mixed", "partial", "complete")),
    # Adroit (binary)
    *(f"{hand}-binary-v0" for hand in ("door", "pen", "relocate")),
    # AntMaze
    "antmaze-ultra-diverse-v2",
    "antmaze-large-diverse-v2",
    "antmaze-large-play-v2",
    # MuJoCo locomotion: grouped by dataset quality, then by body
    *(
        f"{body}-{quality}-v2"
        for quality in ("medium", "medium-replay", "medium-expert", "random")
        for body in ("halfcheetah", "walker2d", "hopper")
    ),
]



def split_antmaze_by_boundaries(env_name: str, tol: float = 1e-6):
    """Load an AntMaze D4RL dataset and split it into per-trajectory rewards.

    The dataset comes from ``d4rl.qlearning_dataset`` on the unwrapped env.
    Trajectory boundaries are found with ``split_by_observation_mismatch``:
    transition i ends a trajectory when ||obs[i+1] - next_obs[i]|| > tol, and
    the last transition always ends one. No length-based filtering is applied.

    Args:
        env_name: D4RL environment id passed to ``gym.make``.
        tol: Boundary tolerance on the observation mismatch.

    Returns:
        ``(trajectories, num_transitions)``: a list of reward arrays (one per
        trajectory) and the number of rows in ``ds["actions"]``.
    """
    env = gym.make(env_name).unwrapped
    ds = d4rl.qlearning_dataset(env)

    # Delegate to the shared splitter instead of duplicating its loop; it also
    # handles an empty dataset, where indexing done_marks[T - 1] would raise.
    trajectories, _ = split_by_observation_mismatch(ds, tol=tol)
    num_transitions = ds["actions"].shape[0]
    return trajectories, num_transitions



def split_kitchen_by_dones(dataset):
    """Split a Kitchen dataset into per-trajectory reward arrays using ``dones``.

    Only the first step of each run of consecutive ``dones`` (a rising edge)
    closes a trajectory, so a run of repeated done flags does not produce
    many one-step fragments. Each trajectory spans from the step after the
    previous rising edge up to and including the current one.

    Note: transitions after the last rising edge are not returned (and a
    dataset with no done flags yields ``[]``).
    NOTE(review): confirm whether dropping that trailing segment is intended.

    Args:
        dataset: Mapping with 1-D ``"rewards"`` and ``"dones"`` arrays.

    Returns:
        List of reward-array views, one per trajectory.
    """
    reward_seq = dataset["rewards"]
    done_flags = dataset["dones"].astype(bool)

    if not done_flags.size:
        return []

    # A step is a rising edge when it is done and the previous step was not;
    # the step before index 0 is treated as "not done".
    previous_done = np.concatenate([np.array([False]), done_flags[:-1]])
    episode_ends = np.flatnonzero(done_flags & ~previous_done)

    pieces = []
    last_end = -1
    for end in episode_ends:
        pieces.append(reward_seq[last_end + 1 : end + 1])
        last_end = end
    return pieces


def split_by_observation_mismatch(dataset, tol: float = 1e-6):
    """Split a concatenated dataset into trajectories by obs/next_obs mismatch.

    Transition i ends a trajectory if ||obs[i+1] - next_obs[i]|| > tol, where
    the norm is the L2 norm of the flattened per-step difference. The final
    transition always ends a trajectory.

    Args:
        dataset: Mapping with ``"observations"``, ``"next_observations"`` and
            ``"rewards"``; the observation arrays are indexed along axis 0 in
            step with ``rewards``.
        tol: Boundary tolerance on the observation mismatch.

    Returns:
        ``(trajectories, T)``: a list of reward-array views (one per
        trajectory) and ``T = len(rewards)``. An empty dataset gives ``([], 0)``.
    """
    observations = np.asarray(dataset["observations"])
    next_observations = np.asarray(dataset["next_observations"])
    rewards = dataset["rewards"]

    T = len(rewards)
    if T == 0:
        return [], 0

    # Vectorized replacement for a per-step np.linalg.norm loop: compare where
    # each transition ended with where the following one starts.
    gaps = observations[1:T] - next_observations[: T - 1]
    feature_axes = tuple(range(1, gaps.ndim))  # () for scalar observations
    gap_norms = np.sqrt(np.sum(np.square(gaps), axis=feature_axes))

    done_marks = np.zeros(T, dtype=bool)
    done_marks[: T - 1] = gap_norms > tol
    done_marks[T - 1] = True

    end_idxs = np.flatnonzero(done_marks)
    start_idxs = np.concatenate(([0], end_idxs[:-1] + 1))

    trajectories = [rewards[s : e + 1] for s, e in zip(start_idxs, end_idxs)]
    return trajectories, T


def compute_mujoco_normalized_scores(dataset, env):
    """Return one normalized score per trajectory of a MuJoCo-style dataset.

    Rewards are summed until a step whose ``dones`` flag is set; that sum is
    passed to ``env.get_normalized_score`` and the running sum resets. Any
    rewards left after the final done flag are scored as one more (partial)
    trajectory.

    Args:
        dataset: Mapping with ``"rewards"`` and ``"dones"`` arrays of equal length.
        env: Object exposing ``get_normalized_score(raw_return)``.

    Returns:
        List of normalized scores in trajectory order.
    """
    episode_ends = dataset["dones"].astype(bool)
    normalized = []
    running_total = 0.0
    has_pending_steps = False

    for reward, is_end in zip(dataset["rewards"], episode_ends):
        running_total += reward
        has_pending_steps = True
        if not is_end:
            continue
        normalized.append(env.get_normalized_score(running_total))
        running_total = 0.0
        has_pending_steps = False

    # Score the trailing segment that never saw a done flag.
    if has_pending_steps:
        normalized.append(env.get_normalized_score(running_total))

    return normalized



def _evaluate_adroit_binary(env_name: str):
    """Adroit binary: split by obs mismatch; success if any reward equals 0."""
    hand_dataset = get_hand_dataset_with_mc_calculation(
        env_name,
        gamma=FLAGS.discount,
        clip_action=FLAGS.clip_action,
    )
    episodes, transition_count = split_by_observation_mismatch(
        hand_dataset, tol=FLAGS.antmaze_tol
    )
    successes = [0.0 if not np.any(episode == 0.0) else 1.0 for episode in episodes]
    return transition_count, len(episodes), successes, "success rate"


def _evaluate_antmaze(env_name: str):
    """AntMaze: split by concatenation boundaries; success if any reward equals 1."""
    episodes, transition_count = split_antmaze_by_boundaries(
        env_name, tol=FLAGS.antmaze_tol
    )
    successes = [1.0 if np.any(episode == 1) else 0.0 for episode in episodes]
    return transition_count, len(episodes), successes, "success rate"


def _evaluate_kitchen(env_name: str):
    """Kitchen: split on rising-edge dones; score = clip(max reward / 4, 0, 1)."""
    kitchen_dataset = get_d4rl_dataset(env_name, clip_action=FLAGS.clip_action)
    transition_count = kitchen_dataset["actions"].shape[0]
    episodes = split_kitchen_by_dones(kitchen_dataset)
    if episodes:
        peak_rewards = np.array([float(np.max(episode)) for episode in episodes])
        successes = list(np.clip(peak_rewards / 4.0, 0.0, 1.0))
    else:
        # np.max would fail on nothing; report no scores instead.
        successes = []
    return transition_count, len(episodes), successes, "success rate"


def _evaluate_mujoco(env_name: str):
    """MuJoCo-style datasets: per-trajectory D4RL normalized return."""
    locomotion_dataset = get_d4rl_dataset(env_name, clip_action=FLAGS.clip_action)
    transition_count = locomotion_dataset["actions"].shape[0]
    unwrapped_env = gym.make(env_name).unwrapped
    normalized = compute_mujoco_normalized_scores(locomotion_dataset, unwrapped_env)
    return transition_count, len(normalized), normalized, "normalized score"


def evaluate_env(env_name: str):
    """Summarize one offline dataset.

    Dispatch order:
      1. Adroit binary, when ``get_env_type`` reports ``"adroit-binary"``.
      2. AntMaze, when ``"antmaze"`` appears in the lower-cased name.
      3. Kitchen, when ``"kitchen"`` appears in the lower-cased name.
      4. Otherwise, a MuJoCo-style normalized return.

    Returns:
        ``(num_transitions, num_trajectories, scores, score_type)`` where
        ``scores`` is a per-trajectory list and ``score_type`` is
        ``"success rate"`` or ``"normalized score"``.
    """
    # Any failure to classify is treated as "not Adroit binary" so the
    # name-based checks below still run.
    # NOTE(review): this also hides unexpected errors from get_env_type.
    try:
        env_type = get_env_type(env_name)
    except Exception:
        env_type = None

    if env_type == "adroit-binary":
        return _evaluate_adroit_binary(env_name)

    lowered_name = env_name.lower()
    if "antmaze" in lowered_name:
        return _evaluate_antmaze(env_name)
    if "kitchen" in lowered_name:
        return _evaluate_kitchen(env_name)
    return _evaluate_mujoco(env_name)



def main(_):
    """Evaluate each requested dataset, print a summary, and write a CSV.

    Environments come from ``--envs``, else ``--env``, else ``DEFAULT_ENVS``.
    The CSV is opened before any evaluation and each row is flushed as soon
    as it is computed. A failure on a later environment therefore keeps the
    rows already produced, and an unwritable ``--output_csv`` fails
    immediately instead of after every dataset has been loaded.
    """
    if FLAGS.envs:
        env_names = FLAGS.envs
    elif FLAGS.env:
        env_names = [FLAGS.env]
    else:
        env_names = DEFAULT_ENVS

    with open(FLAGS.output_csv, mode="w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow([
            "env_name",
            "number_of_transitions",
            "number_of_trajectories",
            "mean_of_score",
            "std_of_score",
        ])
        f.flush()

        for env_name in env_names:
            num_transitions, num_trajs, scores, score_type = evaluate_env(env_name)

            # No trajectories -> no meaningful statistics; report NaN.
            has_scores = len(scores) > 0
            mean_score = float(np.mean(scores)) if has_scores else float("nan")
            std_score = float(np.std(scores)) if has_scores else float("nan")

            print("\n============================================================")
            print(f"Environment: {env_name}")
            print(f"Number of transitions: {num_transitions}")
            print(f"Number of trajectories: {num_trajs}")
            print(
                f"Score ({score_type}) - mean: {mean_score:.4f}, std: {std_score:.4f}"
                if np.isfinite(mean_score)
                else f"Score ({score_type}) - mean: nan, std: nan"
            )

            writer.writerow([
                env_name,
                int(num_transitions),
                int(num_trajs),
                mean_score,
                std_score,
            ])
            # Persist each row immediately so partial results survive a crash.
            f.flush()

    print(f"\nSaved summary CSV to: {FLAGS.output_csv}")


if __name__ == "__main__":
    app.run(main)


