import os
import random
from datasets import Dataset
from scienceworld import ScienceWorldEnv

def _load_train_data(env):
    ############################
    ## filter agentboard data ##
    ############################
    remove_data = {}
    with open("./data_in_agentboard.log", 'r', encoding='utf8') as f:
        for line in f:
            task_name = line.split("task_name:")[-1].split(", var:")[0].strip()
            var = int(line.split("var:")[-1].split(", Your task is to")[0].strip())
            if task_name not in remove_data:
                remove_data[task_name] = [var]
            else:
                remove_data[task_name].append(var)

    task_names = env.get_task_names()
    new_data = {}
    total_num = 0
    remove_num = 0
    train_actor_task = []
    for task_num in range(30):
        task_name = task_names[task_num]
        # print(task_name)
        env.load(task_name)
        variations = env.get_variations_train()
        test_len = min(100, len(variations))
        random.seed(1)
        random.shuffle(variations)
        variations = variations[:test_len]
        total_num += len(variations)

        new_data[task_name] = []
        for var in variations:
            if (task_name in remove_data) and (var in remove_data[task_name]):
                remove_num += 1
            else:
                new_data[task_name].append({
                    "var": var,
                    "used_for_actor": True
                })
    print(total_num)
    print("remove num: ", remove_num)
    return new_data


def _load_dev_data(env):
    task_names = env.get_task_names()
    new_data = {}
    total_num = 0
    for task_num in range(30):
        task_name = task_names[task_num]
        env.load(task_name)
        variations = env.get_variations_dev()
        test_len = min(10, len(variations))
        random.seed(1)
        random.shuffle(variations)
        variations = variations[:test_len]
        total_num += len(variations)

        new_data[task_name] = []
        for var in variations:
            new_data[task_name].append({
                "var": var,
                "used_for_actor": False
            })

    print(total_num)
    return new_data


def main():
    """Build the SciWorld train/test index datasets and write them as parquet.

    Train rows come from ``_load_train_data``, which drops AgentBoard-overlapping
    variations. Test rows come from ``_load_dev_data``. Each row carries a
    placeholder prompt plus ``extra_info`` identifying the task, the variation
    and the split. Output goes to ``../../../data/sciworld``, which is
    resolved relative to the current working directory.
    """
    env = ScienceWorldEnv("", None, envStepLimit=200)
    try:
        train_data_index = _load_train_data(env)
        dev_data_index = _load_dev_data(env)
    finally:
        # Release the environment even if sampling fails.
        # NOTE(review): assumes ScienceWorldEnv.close() shuts down its backend — confirm.
        env.close()

    def _create_data(data_index, split):
        """Flatten ``{task_name: [item, ...]}`` into dataset rows tagged with ``split``."""
        data = []
        idx = 0
        for task_name, task_list in data_index.items():
            for item in task_list:
                var = item["var"]
                data.append({
                    "data_source": "sciworld",
                    "prompt": [{"content": "prompt_with_chat_template"}],
                    "extra_info": {
                        # Previously hard-coded to "train" for both splits,
                        # mislabeling every row written to test.parquet.
                        "split": split,
                        "index": idx,
                        "task_name": task_name,
                        "uid": f"{task_name}_{var}",
                        "env_name": "sciworld",
                        "used_for_actor": item["used_for_actor"],
                        "var": var,
                    }
                })
                idx += 1
        return data

    train_dataset = Dataset.from_list(_create_data(train_data_index, "train"))
    test_dataset = Dataset.from_list(_create_data(dev_data_index, "test"))

    output_path = "../../../data/sciworld"
    # to_parquet does not create missing parent directories.
    os.makedirs(output_path, exist_ok=True)
    train_dataset.to_parquet(os.path.join(output_path, 'train_remove_agentboard.parquet'))
    test_dataset.to_parquet(os.path.join(output_path, 'test.parquet'))

if __name__ == "__main__":
    # Build and write the datasets only when run as a script, not on import.
    main()
    









