from datasets import load_dataset, concatenate_datasets
import os
import sys
sys.path.append(".")
from source.language_models.prompt_templates import prompt_policy_minigrid, prompt_sequence_minigrid
import glob
import random

def prompt_map_minigrid(sample, goal, sequence):
    """Add ``prompt`` and ``completion`` fields to a single dataset row.

    Args:
        sample: A row mapping with at least ``"state"`` and ``"action"`` keys;
            it is modified in place and also returned.
        goal: Goal text forwarded to the prompt template.
        sequence: When truthy, build the prompt with
            ``prompt_sequence_minigrid``; otherwise with
            ``prompt_policy_minigrid``. Both receive a fixed ``(22, 22)``
            size argument.

    Returns:
        The same ``sample`` object, now carrying ``"prompt"`` and
        ``"completion"``. The completion is ``str(action)`` with every space
        and square bracket removed, surrounding whitespace stripped, and
        ``"<|eot_id|>"`` appended.
    """
    template = prompt_sequence_minigrid if sequence else prompt_policy_minigrid
    sample["prompt"] = template((22, 22), sample["state"], goal)

    # Delete spaces and brackets in one pass (e.g. "[1, 2]" -> "1,2").
    action_text = str(sample["action"]).translate(str.maketrans("", "", " []"))
    sample["completion"] = action_text.strip() + "<|eot_id|>"
    return sample



def _goal_data_path(dataset_folder, env_id, goal, obs, sequence):
    """Return the CSV path for one (environment, goal) pair.

    ``obs`` takes precedence over ``sequence``; with neither set the
    ``Q_policy.csv`` file is used.
    """
    goal_dir = f"{dataset_folder}/env/{env_id}/{goal}"
    if obs:
        return f"{goal_dir}/demonstration_policy_obs.csv"
    if sequence:
        return f"{goal_dir}/sequence.csv"
    return f"{goal_dir}/Q_policy.csv"


def _load_goal_datasets(dataset_folder, envs, goals_per_env, selected_goals, obs, sequence):
    """Load and prompt-format every (env, goal) CSV whose goal is selected.

    Args:
        dataset_folder: Root folder containing ``env/<env_id>/<goal>/``.
        envs: Environment ids to scan, in order.
        goals_per_env: Mapping env id -> list of goal directory names.
        selected_goals: Goal names to keep; all others are skipped.
        obs, sequence: Passed to ``_goal_data_path`` to select the CSV file,
            and ``sequence`` is also forwarded to ``prompt_map_minigrid``.

    Returns:
        A list of mapped ``datasets`` splits (the ``"train"`` split of each
        loaded CSV), in env/goal iteration order.

    A CSV that fails to load or map is reported on stdout and skipped
    (deliberate best-effort, as in the original loader).
    """
    selected = set(selected_goals)
    loaded = []
    for env_id in envs:
        for goal in goals_per_env[env_id]:
            if goal not in selected:
                continue
            path = _goal_data_path(dataset_folder, env_id, goal, obs, sequence)
            try:
                data = load_dataset("csv", data_files=path, cache_dir="/tmp/cache")
                # fn_kwargs binds goal/sequence explicitly instead of capturing
                # the loop variable in a lambda closure.
                mapped = data.map(
                    prompt_map_minigrid,
                    fn_kwargs={"goal": goal, "sequence": sequence},
                    batched=False,
                )
                loaded.append(mapped["train"])
            except Exception as e:
                print("path : ", path)
                print(e)
    return loaded


def load_dataset_multi_env_minigrid(
    dataset_folder,
    single_env=True,
    frac_test_env=0.1,
    frac_test_goal=0.1,
    fraction=0.99999,
    fraction_goal=1,
    obs=False,
    sequence=False,
    seed=0,
    split_goal=True,
    frac_test_sample=None,
):
    """Build train/validation prompt datasets from a MiniGrid demonstration tree.

    The folder layout read here is ``<dataset_folder>/env/<env_id>/<goal>/``.
    Only goal directories that contain a ``sequence.csv`` are considered,
    whichever CSV (``obs``/``sequence``/``Q_policy``) is then loaded.

    Args:
        dataset_folder: Root folder of the dataset.
        single_env: If True, use a single (seeded-random) environment for both
            training and validation, splitting only by goal.
        frac_test_env: Fraction of environments held out for validation when
            ``single_env`` is False.
        frac_test_goal: Fraction of the retained goals held out for validation
            (at least one goal is always held out).
        fraction: If not exactly 1, both returned datasets are randomly
            subsampled to roughly this fraction and returned immediately,
            bypassing ``frac_test_sample``.
            NOTE(review): ``fraction`` is passed as ``test_size`` to
            ``train_test_split``, which raises when the complementary side
            would be empty. With the default 0.99999 that happens for
            datasets smaller than ~100k rows; confirm the default is intended.
        fraction_goal: Fraction of the pooled training-env goals to keep
            before the goal split.
        obs: Load ``demonstration_policy_obs.csv`` files.
        sequence: Load ``sequence.csv`` files and use the sequence prompt
            (ignored for file selection when ``obs`` is True).
        seed: Seed for the environment/goal shuffles, the dataset shuffles and
            the random subsampling splits.
        split_goal: Currently unused; kept for call compatibility.
        frac_test_sample: When ``fraction == 1`` and this is below 1, ignore
            the goal-based validation set and instead split the training data
            randomly, using this fraction as the held-out part.

    Returns:
        A ``(train_dataset, val_dataset)`` tuple of ``datasets.Dataset``
        objects with ``prompt`` and ``completion`` columns (``state`` and
        ``action`` removed).

    Raises:
        ValueError: If no training or no validation CSV could be loaded.
    """
    # A private RNG keeps the split seeded without reseeding the process-wide
    # ``random`` module as a side effect.
    rng = random.Random(seed)

    # glob order is filesystem-dependent; sort so the seeded shuffles below
    # give the same split on every machine.
    list_env = sorted(
        os.path.basename(os.path.normpath(p)) for p in glob.glob(f"{dataset_folder}/env/*/")
    )
    goals_per_env = {
        env_id: sorted(
            os.path.basename(p)
            for p in glob.glob(f"{dataset_folder}/env/{env_id}/*")
            if os.path.exists(os.path.join(p, "sequence.csv"))
        )
        for env_id in list_env
    }

    rng.shuffle(list_env)

    if single_env:
        train_envs = list_env[:1]
        test_envs = list_env[:1]
    else:
        n_test_env = int(frac_test_env * len(list_env))
        test_envs = list_env[:n_test_env]
        train_envs = list_env[n_test_env:]

    print("train envs : ", train_envs)
    print("test envs : ", test_envs)

    # Goals are pooled from the training environments only; the held-out goal
    # names are later looked up inside ``test_envs``. De-duplicate (keeping
    # order) so a goal name shared by several envs cannot end up in both the
    # train and the test goal lists.
    pooled_goals = list(
        dict.fromkeys(goal for env_id in train_envs for goal in goals_per_env[env_id])
    )
    rng.shuffle(pooled_goals)
    goals = pooled_goals[:int(fraction_goal * len(pooled_goals))]

    n_test_goal = max(int(frac_test_goal * len(goals)), 1)
    test_goals = goals[:n_test_goal]
    train_goals = goals[n_test_goal:]

    print("train goals : ", len(train_goals))
    print("test goals : ", len(test_goals))

    list_dataset = _load_goal_datasets(
        dataset_folder, train_envs, goals_per_env, train_goals, obs, sequence
    )
    list_val_dataset = _load_goal_datasets(
        dataset_folder, test_envs, goals_per_env, test_goals, obs, sequence
    )

    if not list_dataset:
        raise ValueError(
            f"No training data could be loaded from {dataset_folder!r} "
            f"({len(train_goals)} train goals selected)."
        )
    if not list_val_dataset:
        raise ValueError(
            f"No validation data could be loaded from {dataset_folder!r} "
            f"({len(test_goals)} test goals selected)."
        )

    train_data = concatenate_datasets(list_dataset).shuffle(seed=seed).remove_columns(["state", "action"])
    val_data = concatenate_datasets(list_val_dataset).shuffle(seed=seed).remove_columns(["state", "action"])

    if fraction != 1:
        # ``fraction`` is the test_size, so the "test" side is the kept subsample.
        return (
            train_data.train_test_split(fraction, seed=seed)["test"],
            val_data.train_test_split(fraction, seed=seed)["test"],
        )

    if (frac_test_sample is not None) and (frac_test_sample < 1):
        dataset = train_data.train_test_split(frac_test_sample, seed=seed)
        return dataset["train"], dataset["test"]
    return train_data, val_data




