import sys
sys.path.append(".")
from source.policy.fit_Q_values_minigrid import tabular_Q_belman
import numpy as np
import argparse
import os
import pandas as pd 

# Root directory of the datasets; per-dataset/per-goal output files are
# written beneath it. The path is relative, so it resolves against the current
# working directory (presumably the repository root, given the
# sys.path.append(".") at the top of the file — confirm how the script is run).
data_folder = "./data/datasets/"

def improve_policy():
    """Fit tabular Q-values for one goal and save the greedy policy as CSV.

    Command-line arguments (parsed from ``sys.argv``):

    * ``--dataset`` (required): dataset name, a directory under ``data_folder``.
    * ``--goal`` (required): goal name; selects the output sub-directory.
    * ``--env_id``: environment id sub-directory (default ``"0"``).
    * ``--device``: device index forwarded to ``tabular_Q_belman``.
    * ``--gamma`` / ``--alpha`` / ``--max_epochs``: hyperparameters forwarded
      to ``tabular_Q_belman`` (defaults 0.9 / 0.1 / 4000).
    * ``--skip_existing``: when set, do nothing if the policy file already
      exists. By default the policy is always recomputed and overwritten.

    The policy is written to
    ``<data_folder>/<dataset>/env/<env_id>/<goal>/Q_policy.csv`` with one row
    per entry of ``Q`` and columns ``state`` and ``action`` (plus the pandas
    index column written by ``DataFrame.to_csv``).
    """
    parser = argparse.ArgumentParser(description="Run the experiment")
    parser.add_argument("--dataset", type=str, required=True,
                        help="The name of the dataset")
    parser.add_argument("--goal", type=str, required=True,
                        help="The goal to compute the policy for")
    parser.add_argument("--env_id", type=str, default="0",
                        help="The environment id")
    parser.add_argument("--device", type=int, default=0,
                        help="Device index forwarded to tabular_Q_belman")
    parser.add_argument("--gamma", type=float, default=0.9,
                        help="Discount factor forwarded to tabular_Q_belman")
    parser.add_argument("--alpha", type=float, default=0.1,
                        help="Learning rate forwarded to tabular_Q_belman")
    parser.add_argument("--max_epochs", type=int, default=4000,
                        help="Maximum epochs forwarded to tabular_Q_belman")
    parser.add_argument("--skip_existing", action="store_true",
                        help="Skip the computation if Q_policy.csv already exists")
    args = parser.parse_args()

    dataset = args.dataset
    goal = args.goal

    output_dir = os.path.join(data_folder, dataset, "env", args.env_id, goal)
    path = os.path.join(output_dir, "Q_policy.csv")

    print("Improving policy for ", goal, " in ", dataset)

    # Recomputing is the default (the previous code forced it with
    # `if True or ...`); the existence check only applies on request.
    if args.skip_existing and os.path.exists(path):
        print(f"Policy {goal} already exists for {dataset}, skipping")
        return

    Q, list_state, list_action = tabular_Q_belman(
        dataset,
        goal,
        args.env_id,
        gamma=args.gamma,
        alpha=args.alpha,
        max_epochs=args.max_epochs,
        verbose=True,
        device=args.device,
    )

    # Greedy policy: row i of Q is paired with list_state[i], and the index of
    # its largest value selects the action from list_action (np.argmax returns
    # the first index on ties).
    q_star_policy = [
        {"state": list_state[i], "action": list_action[np.argmax(Q[i])]}
        for i in range(Q.shape[0])
    ]

    # to_csv does not create missing parent directories.
    os.makedirs(output_dir, exist_ok=True)
    pd.DataFrame(q_star_policy).to_csv(path)

    print(f"Policy {goal} created for {dataset}")



# Run improve_policy() only when this file is executed as a script, not when
# it is imported.
if __name__ == "__main__":
    improve_policy()
