from vllm.distributed.parallel_state import destroy_model_parallel
import torch
import gc
import sys
sys.path.append(".")
from source.abstraction.abstraction_function import LLM_abstraction
import pickle 
import dill


if __name__ == "__main__":
    # Build one LLM-generated abstraction function per environment, then
    # serialize the abstraction builder (without its LLM) with dill.

    # Per-environment data from an earlier cache step. Each entry is expected
    # to expose "obs" and "goal" keys (read below).
    with open('./cache/dict_env_id.pkl', 'rb') as handle:
        dict_env_id = pickle.load(handle)

    abstraction_builder = LLM_abstraction()

    for env_id, env_data in dict_env_id.items():
        abstraction_builder.generate_new_abstraction_function(
            env_data["obs"], env_id, env_data["goal"]
        )

    # Tear down vLLM's model-parallel state and drop references to the engine's
    # worker and model before pickling. The builder's `llm` is set to None so
    # dill does not attempt to serialize the model/engine.
    destroy_model_parallel()
    del abstraction_builder.llm.model.llm_engine.model_executor.driver_worker
    del abstraction_builder.llm.model

    abstraction_builder.llm = None

    with open('./cache/dict_abstract.pkl', 'wb') as handle:
        dill.dump(abstraction_builder, handle)

    # Release remaining Python/GPU memory.
    del abstraction_builder
    gc.collect()
    torch.cuda.empty_cache()
    # Only destroy the default process group if one exists. Calling
    # destroy_process_group() without an initialized group raises, which would
    # crash the script after the results were already written.
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        torch.distributed.destroy_process_group()
    print("LLM abstraction done")