import argparse
import numpy as np
import torch
from dm_control import suite
from tqdm.auto import tqdm

def _generate_mujoco(n_samples, window=24, var_num=14, seed=123):
    """
    Simulate randomly initialized dm_control ``hopper``/``stand`` rollouts.

    Each sample starts from a randomly drawn initial state. It then records
    ``window`` consecutive physics steps. At every step the joint positions
    (``qpos``) and joint velocities (``qvel``) are concatenated into one
    feature vector.

    Args:
        - n_samples: the number of rollouts to generate
        - window: number of time steps per rollout (``None`` falls back to 24)
        - var_num: feature dimensions; must equal ``len(qpos) + len(qvel)`` of
          the hopper model
        - seed: seed for the random initial states. The global NumPy RNG state
          is left untouched.

    Returns:
        - data: float64 array of shape ``(n_samples, window, var_num)``

    Raises:
        - ValueError: if ``var_num`` does not match the model's state size
    """
    if window is None:
        window = 24

    env = suite.load("hopper", "stand")
    physics = env.physics

    n_pos = physics.data.qpos.size
    n_vel = physics.data.qvel.size
    # Each feature vector is [qpos, qvel]. Without this check, a mismatched
    # var_num would only surface as an opaque broadcasting error in the loop.
    if var_num != n_pos + n_vel:
        raise ValueError(
            f"var_num must be {n_pos + n_vel} for the hopper model "
            f"({n_pos} qpos + {n_vel} qvel), got {var_num}."
        )

    # A private RandomState seeded with `seed` produces the same draws as
    # seeding the global RNG would. It never touches the global state, so
    # nothing needs restoring, even if the simulation raises.
    rng = np.random.RandomState(seed)

    data = np.zeros((n_samples, window, var_num))
    for i in range(n_samples):
        with physics.reset_context():
            # x and z positions of the hopper. Drawing from [0, 0.5) keeps z
            # non-negative so the hopper starts above ground.
            physics.data.qpos[:2] = rng.uniform(0, 0.5, size=2)
            physics.data.qpos[2:] = rng.uniform(
                -2, 2, size=physics.data.qpos[2:].shape
            )
            physics.data.qvel[:] = rng.uniform(
                -5, 5, size=physics.data.qvel.shape
            )

        for t in range(window):
            data[i, t, :n_pos] = physics.data.qpos
            data[i, t, n_pos:] = physics.data.qvel
            physics.step()

    return data

def _generate_sine(n_samples, window=24, var_num=14, seed=123):
    """
    Sine data generation.

    Every feature of every sample is an independent sinusoid
    ``sin(freq * t + phase)`` for ``t = 0 .. window - 1``. Both ``freq`` and
    ``phase`` are drawn uniformly from ``[0, 0.1)``. The signal is then
    rescaled from ``[-1, 1]`` to ``[0, 1]``.

    Args:
        - n_samples: the number of samples
        - window: sequence length of the time-series (``None`` falls back to 24)
        - var_num: feature dimensions
        - seed: seed for the frequency/phase draws. The global NumPy RNG state
          is left untouched.

    Returns:
        - data: float64 array of shape ``(n_samples, window, var_num)``
    """
    if window is None:
        window = 24

    # A private RandomState seeded with `seed` produces the same draws as
    # seeding the global RNG would. It never mutates the global state, so
    # there is nothing to restore if the loop is interrupted.
    rng = np.random.RandomState(seed)
    steps = np.arange(window)

    # Preallocate so the result is always 3-D, even when n_samples or
    # var_num is 0.
    data = np.empty((n_samples, window, var_num))
    for i in tqdm(range(n_samples), total=n_samples, desc="Sampling sine-dataset"):
        for k in range(var_num):
            # Draw freq then phase for each feature, sample by sample, so a
            # given seed always reproduces the same dataset.
            freq = rng.uniform(0, 0.1)
            phase = rng.uniform(0, 0.1)
            data[i, :, k] = np.sin(freq * steps + phase)

    # Normalize to [0,1]
    return (data + 1) * 0.5

if __name__ == "__main__":
    # Dataset name -> generator. The argparse `choices` are derived from this
    # mapping, so every accepted --dataset value has a generator.
    generators = {"mujoco": _generate_mujoco, "sine": _generate_sine}

    parser = argparse.ArgumentParser(
        description="Generate synthetic time-series data (MuJoCo hopper or sine)."
    )

    parser.add_argument(
        "--dataset",
        type=str,
        choices=sorted(generators),
        required=True,
        help="Dataset to generate: 'mujoco' or 'sine'.",
    )
    parser.add_argument(
        "--n_samples", type=int, required=True, help="Number of samples to generate."
    )
    parser.add_argument(
        "--n_features",
        type=int,
        required=True,
        help="Number of features in each sample.",
    )

    # Both generators iterate over the window length, so it must be an int.
    # 24 matches the generators' own default.
    parser.add_argument(
        "--window_size",
        type=int,
        default=24,
        help="Size of the time window for each sample.",
    )
    parser.add_argument(
        "--save_path",
        type=str,
        default=None,
        help="Path to save the generated data.",
    )
    parser.add_argument(
        "--seed", type=int, default=42, help="Random seed for reproducibility."
    )

    args = parser.parse_args()
    for name in ("n_samples", "n_features", "window_size"):
        if getattr(args, name) <= 0:
            parser.error(f"--{name} must be a positive integer.")

    if args.save_path is None:
        args.save_path = f"{args.dataset}.pt"

    data_generator = generators[args.dataset]
    data = data_generator(
        n_samples=args.n_samples,
        window=args.window_size,
        var_num=args.n_features,
        seed=args.seed,
    )
    # Convert to a PyTorch tensor and ensure float32 dtype.
    data = torch.from_numpy(data).float()
    torch.save(data, args.save_path)
    print(f"{args.dataset} data saved to {args.save_path}")
