import argparse
import json
import numpy as np
import h5py
import gym
import d4rl

# Import wrappers directly
from .risky_wrappers import RewardHighVelocity, RewardUnhealthyPose

class DatasetWriter:
    """
    In-memory transition buffer that is flushed to a single HDF5 file.

    Transitions (including a per-step risk flag) are accumulated in Python
    lists by ``append``; nothing touches disk until ``close`` writes one
    gzip-compressed dataset per field.

    Args:
        filename: Path of the HDF5 file created by ``close`` (overwritten).
        buffer_size: Kept on the instance for callers; this class does not
            read it.
    """

    # Field names, in the order they are stored and written out.
    _FIELDS = (
        'observations',
        'actions',
        'rewards',
        'next_observations',
        'terminals',
        'risky_states',
    )

    def __init__(self, filename, buffer_size=1_000_000):
        self.filename = filename
        self.buffer_size = buffer_size
        self.count = 0
        # One list per field; converted to arrays only when closing.
        self.data = {field: [] for field in self._FIELDS}

    def append(self, obs, act, rew, next_obs, done, risky):
        """Buffer one transition; arguments map onto ``_FIELDS`` in order."""
        values = (obs, act, rew, next_obs, done, risky)
        for field, value in zip(self._FIELDS, values):
            self.data[field].append(value)
        self.count += 1

    def close(self):
        """Write every buffered field to ``self.filename`` as an HDF5 dataset."""
        print(f"Total transitions to store: {self.count}")
        with h5py.File(self.filename, 'w') as hf:
            for field in self._FIELDS:
                stacked = np.array(self.data[field])
                hf.create_dataset(field, data=stacked, compression="gzip")
        print(f"Saved dataset to {self.filename}")


def apply_wrappers(env, env_name, wrapper_configs):
    """
    Wrap ``env`` with the risk-shaping wrappers described by ``wrapper_configs``.

    Entries are applied in list order, so each later wrapper wraps the
    result of the earlier ones.

    Args:
        env: Environment to wrap.
        env_name: Forwarded to ``RewardHighVelocity``.
        wrapper_configs: Iterable of dicts. Each has a ``'type'`` of
            ``'RewardHighVelocity'`` or ``'RewardUnhealthyPose'`` plus optional
            keyword overrides (defaults shown below). ``None`` is treated as
            "no wrappers" so a ``"wrappers": null`` config does not crash.

    Returns:
        The wrapped environment (``env`` itself if nothing was applied).
        Entries with an unknown or missing ``'type'`` are skipped with a
        printed warning.
    """
    for cfg in wrapper_configs or ():
        wtype = cfg.get('type')
        if wtype == 'RewardHighVelocity':
            env = RewardHighVelocity(
                env,
                env_name,
                max_vel=cfg.get('max_vel', 2.0),
                prob_vel_penal=cfg.get('prob_vel_penal', 0.3),
                cost_vel=cfg.get('cost_vel', -5.0),
            )
        elif wtype == 'RewardUnhealthyPose':
            env = RewardUnhealthyPose(
                env,
                prob_pose_penal=cfg.get('prob_pose_penal', 0.3),
                cost_pose=cfg.get('cost_pose', -10.0),
                # JSON yields a list; the wrapper receives a (low, high) tuple.
                healthy_angle_range=tuple(cfg.get('healthy_angle_range', [-0.5, 0.5])),
                done_if_exceed_factor=cfg.get('done_if_exceed_factor', 2.0),
            )
        else:
            print(f"[Warning] Unknown wrapper type: {wtype}")
    return env


def _is_trajectory_start(index, observations, next_observations, terminals):
    """
    Return True if transition ``index`` is the first step of a trajectory.

    A terminal flag on the previous transition is not sufficient by itself:
    ``d4rl.qlearning_dataset`` is understood to drop the last step of
    time-limited episodes without flagging it terminal (verify against the
    installed d4rl version). A boundary is therefore also declared wherever
    the previous transition's next observation differs from this
    transition's observation.
    """
    if index == 0:
        return True
    if terminals[index - 1]:
        return True
    return not np.array_equal(next_observations[index - 1], observations[index])


def _sync_env_state(env, observation):
    """
    Best-effort: put the simulator of ``env`` into the state of ``observation``.

    Assumes a MuJoCo-style env whose ``unwrapped`` exposes ``model.nq``,
    ``model.nv`` and ``set_state(qpos, qvel)``, and whose observation is
    ``qpos[1:]`` followed by ``qvel`` (the Gym Hopper/Walker2d/HalfCheetah
    layout -- confirm for other tasks). The root x-position is not part of
    that observation, so it is set to 0.0. Velocities that were clipped in the
    observation are used as-is.

    Returns:
        True if the state was set. False, leaving ``env`` untouched, when the
        env or the observation does not fit that layout.
    """
    unwrapped = getattr(env, 'unwrapped', env)
    set_state = getattr(unwrapped, 'set_state', None)
    model = getattr(unwrapped, 'model', None)
    nq = getattr(model, 'nq', None)
    nv = getattr(model, 'nv', None)
    if set_state is None or nq is None or nv is None:
        return False

    flat_obs = np.asarray(observation, dtype=np.float64)
    if flat_obs.shape != (nq - 1 + nv,):
        return False

    qpos = np.concatenate(([0.0], flat_obs[:nq - 1]))
    qvel = flat_obs[nq - 1:]
    set_state(qpos, qvel)
    return True


def create_risky_dataset_from_config(config_path, max_traj=-1):
    """
    Read env_name, output_file and wrappers from a JSON file and write an
    HDF5 dataset that re-evaluates the D4RL transitions of ``env_name`` under
    the configured risk-penalizing wrappers.

    For each D4RL transition, the wrapped env is placed in the logged state
    when possible (see ``_sync_env_state``) and stepped with the logged
    action. The logged observation, action and terminal flag are stored
    together with the env's reward, next observation and
    ``info['risky_state']`` flag (as 0/1).

    Args:
        config_path: Path to a JSON config with keys ``env_name``,
            ``output_file`` and optional ``wrappers``.
        max_traj: If > 0, stop after this many trajectories.
    """
    with open(config_path, 'r') as f:
        config = json.load(f)

    env_name = config['env_name']
    output_file = config['output_file']
    wrapper_configs = config.get('wrappers', [])

    base_env = gym.make(env_name)
    dataset = d4rl.qlearning_dataset(base_env)

    obs = dataset['observations']
    acts = dataset['actions']
    next_obs_d4rl = dataset['next_observations']
    dones = dataset['terminals']

    # Create wrapped environment
    env = gym.make(env_name)
    env = apply_wrappers(env, env_name, wrapper_configs)

    writer = DatasetWriter(filename=output_file)

    total_trajectories = 0
    warned_no_sync = False

    for i in range(len(obs)):
        if _is_trajectory_start(i, obs, next_obs_d4rl, dones):
            total_trajectories += 1
            if (max_traj > 0) and (total_trajectories > max_traj):
                print("Reached the limit of trajectories.")
                break
            # Reset so time-limit counters and wrapper state start fresh.
            env.reset()

        state_d4rl = obs[i]
        action_d4rl = acts[i]
        done_d4rl = dones[i]

        # Step from the logged state, not from wherever the simulator drifted;
        # otherwise (state_d4rl, next_state_env) would not be a real transition.
        if not _sync_env_state(env, state_d4rl) and not warned_no_sync:
            print(f"[Warning] Cannot set {env_name} state from observations; "
                  "rewards are computed from the simulator's own trajectory.")
            warned_no_sync = True

        next_state_env, rew_risky, _done_env, info = env.step(action_d4rl)
        risky_flag = 1 if info.get('risky_state', False) else 0

        writer.append(
            state_d4rl,
            action_d4rl,
            rew_risky,
            next_state_env,
            done_d4rl,
            risky_flag,
        )

    writer.close()


if __name__ == "__main__":
    # Command-line entry point: parse the config path and optional
    # trajectory cap, then build the dataset.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        '--config',
        type=str,
        required=True,
        help='Path to a JSON config file that contains env_name, output_file, wrappers, etc.',
    )
    cli.add_argument(
        '--max_traj',
        type=int,
        default=-1,
        help='If > 0, limit the number of trajectories to process',
    )
    cli_args = cli.parse_args()
    create_risky_dataset_from_config(config_path=cli_args.config, max_traj=cli_args.max_traj)