import os
from vagen.env import REGISTERED_ENV
import numpy as np
import yaml
import argparse
from datasets import Dataset, load_dataset
from vagen.env.utils.env_utils import permanent_seed
def create_dataset_from_yaml(yaml_file_path: str, force_gen=False, seed=42, train_path='./train.parquet', test_path='./test.parquet'):
    """
    Create train/test parquet datasets from a YAML configuration.

    Args:
        yaml_file_path (str): Path to the YAML configuration file. An already
            parsed mapping (any non-str value) is also accepted and used as-is.
        force_gen (bool): Whether to force regeneration of existing datasets
        seed (int): Global seed passed to ``permanent_seed`` before any
            per-environment seeds are drawn
        train_path (str): Path to save the training dataset
        test_path (str): Path to save the testing dataset

    Returns:
        tuple: ``(train_path, test_path)`` exactly as passed in. A split with
        zero instances is not written, so its path may not exist on disk.

    The YAML file should have the following structure:
    ```
    env1:
        env_name: sokoban  # or frozenlake
        env_config:
            # parameters to override the default env config
        train_size: 100  # number of instances
        test_size: 100
    env2:
        env_name: frozenlake
        env_config:
            # parameters to override the default env config
        train_size: 100  # number of instances
        test_size: 100
    ```

    ``train_size`` and ``test_size`` default to 100 when omitted; a missing or
    empty ``env_config`` is treated as no overrides.

    If the environment config class (e.g., SokobanEnvConfig, FrozenLakeEnvConfig) has a
    generate_seeds(size) method, it will be used to generate seeds for that environment.
    Otherwise seeds are drawn with ``np.random.randint``.
    """

    if isinstance(yaml_file_path, str):
        with open(yaml_file_path, 'r') as f:
            yaml_config = yaml.safe_load(f)
    else:
        yaml_config = yaml_file_path
    # An empty YAML document parses to None; treat it as "no environments".
    yaml_config = yaml_config or {}

    for output_path in (train_path, test_path):
        parent_dir = os.path.dirname(output_path)
        # dirname is '' for a bare filename, and os.makedirs('') raises.
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

    if not force_gen and os.path.exists(train_path) and os.path.exists(test_path):
        print(f"Dataset files already exist at {train_path} and {test_path}. Skipping generation.")
        print("Use --force_gen to override and regenerate the dataset.")
        return train_path, test_path

    train_instances = []
    test_instances = []

    # Seed once up front so the seeds drawn below are reproducible.
    permanent_seed(seed)

    for value in yaml_config.values():
        env_name = value.get('env_name')
        # `env_config:` with no children parses to None, which cannot be **-unpacked.
        custom_env_config = value.get('env_config') or {}
        train_size = value.get('train_size', 100)
        test_size = value.get('test_size', 100)

        env_config = REGISTERED_ENV[env_name]["config_cls"](**custom_env_config)
        # Train seeds are always drawn before test seeds; keep this order so a
        # given global seed keeps producing the same split.
        if hasattr(env_config, 'generate_seeds'):
            seeds_for_env_train = env_config.generate_seeds(train_size)
            seeds_for_env_test = env_config.generate_seeds(test_size)
            print(f"Using {len(seeds_for_env_train)} train seeds generated by {env_name} config's generate_seeds method")
            print(f"Using {len(seeds_for_env_test)} test seeds generated by {env_name} config's generate_seeds method")
        else:
            seeds_for_env_train = np.random.randint(0, 2**31 - 1, size=train_size).tolist()
            seeds_for_env_test = np.random.randint(0, 2**31 - 1, size=test_size).tolist()

        train_instances.extend(
            _build_env_instances(env_name, custom_env_config, seeds_for_env_train, "train")
        )
        test_instances.extend(
            _build_env_instances(env_name, custom_env_config, seeds_for_env_test, "test")
        )

    def make_map_fn(split):
        # Identity hook: kept as the place to add per-split post-processing.
        def process_fn(example, idx):
            return example
        return process_fn

    # Create datasets
    if train_instances:
        train_dataset = Dataset.from_list(train_instances)
        train_dataset = train_dataset.map(function=make_map_fn('train'), with_indices=True)
        train_dataset.to_parquet(train_path)
        print(f"Train dataset with {len(train_instances)} instances saved to {train_path}")

    if test_instances:
        test_dataset = Dataset.from_list(test_instances)
        test_dataset = test_dataset.map(function=make_map_fn('test'), with_indices=True)
        test_dataset.to_parquet(test_path)
        print(f"Test dataset with {len(test_instances)} instances saved to {test_path}")

    if not train_instances and not test_instances:
        print("No instances were generated. Check your YAML configuration.")

    return train_path, test_path


def _build_env_instances(env_name, env_config, seeds, split):
    """Return one dataset row per seed for the given environment and split."""
    return [
        {
            "data_source": env_name,
            "prompt": [{"role": "user", "content": ''}],
            "extra_info": {
                "split": split,
                'env_name': env_name,
                'env_config': env_config,
                'seed': env_seed,
            },
        }
        for env_seed in seeds
    ]
        
        
        
if __name__ == "__main__":
    # CLI entry point: generate (or reuse) the datasets, then reload them and
    # print a couple of rows per split as a sanity check.
    parser = argparse.ArgumentParser()
    parser.add_argument("--yaml_path", type=str, required=True, help="Path to YAML configuration file")
    parser.add_argument("--force_gen", action="store_true", help="Force regenerate dataset even if exists")
    parser.add_argument("--train_path", type=str, default="./train.parquet", help="Path to save the training dataset")
    parser.add_argument("--test_path", type=str, default="./test.parquet", help="Path to save the testing dataset")
    parser.add_argument("--seed", type=int, default=42, help="Seed for random number generation")
    args = parser.parse_args()
    print(args)
    train_path, test_path = create_dataset_from_yaml(args.yaml_path, args.force_gen, args.seed, args.train_path, args.test_path)

    # Load the datasets back from parquet and print examples
    train_dataset = load_dataset('parquet', data_files={"train": train_path}, split="train")
    test_dataset = load_dataset('parquet', data_files={"test": test_path}, split="test")
    for dataset in (train_dataset, test_dataset):
        # Cap the preview at the split size so small datasets do not raise IndexError.
        for i in range(min(2, len(dataset))):
            example = dataset[i]
            print(example)
            extra_info = example["extra_info"]
            env_config_cls = REGISTERED_ENV[extra_info["env_name"]]["config_cls"]
            # Rebuild the config from the stored overrides to confirm they round-trip.
            env_config = env_config_cls(**extra_info["env_config"])
            print(env_config.config_id())