from os import path
from simple_rl.run_experiments import run_agents_on_mdp
from simple_rl.tasks import GridWorldMDP
from simple_rl.agents import QLearningAgent

import context
import rlang
from rlang.agents import RLangQLearningAgent
import openai
import sys
import json
import pandas as pd
import shutil
import time

sys.path.append( path.dirname( path.dirname( path.abspath(__file__) ) ) )
from utils import translate_named, read_named_examples, unroll, parse_vocab, read_input, save_translation_results

# sys.path.append('/Users/katieta/code/')
# sys.path.insert(0, '/Users/katieta/code/nl2rlang')

# Absolute directory containing this script; data files and .rlang programs are
# resolved against it. Using abspath (as the sys.path setup above does) avoids
# an empty string when __file__ is a bare filename, so joins below never depend
# on the current working directory.
script_dir = path.dirname(path.abspath(__file__))

def create_mdp():
    """Build the 6x6 lava gridworld used by the experiments.

    The grid has one wall at (3, 1), lava at (3, 2), (1, 4), (2, 4) and
    (2, 5), and a single goal at (5, 1); slip_prob is 0.0 and step_cost is 0.

    Returns:
        A ``(mdp, states)`` pair: the ``GridWorldMDP`` and a list of every
        ``(x, y)`` cell, with x varying in the outer loop and y in the inner.
    """
    grid_w, grid_h = 6, 6
    mdp = GridWorldMDP(
        grid_w,
        grid_h,
        walls=[(3, 1)],
        lava_locs=[(3, 2), (1, 4), (2, 4), (2, 5)],
        goal_locs=[(5, 1)],
        slip_prob=0.0,
        step_cost=0,
    )
    # Enumerate cells from the MDP's own dimensions rather than the locals.
    all_cells = [(x, y) for x in range(mdp.width) for y in range(mdp.height)]
    return mdp, all_cells


def simple_experiment():
    """Run a lone baseline Q-learning agent on the gridworld MDP."""
    # Only the MDP is needed here; the state list from create_mdp is unused.
    mdp = create_mdp()[0]
    run_agents_on_mdp([QLearningAgent(mdp.get_actions())], mdp)


def rlang_experiment():
    """Compare a plain Q-learning agent with an RLang-informed one.

    Builds the gridworld via ``create_mdp``, parses
    ``enhanced_gridworld.rlang`` from the script directory into a knowledge
    object, and runs a baseline ``QLearningAgent`` alongside a
    ``RLangQLearningAgent`` on the same MDP with ``run_agents_on_mdp``.
    """
    # The RLang agent needs both the action set and the enumerated states.
    mdp, states = create_mdp()

    # Resolve against script_dir: the enhanced program is written there by
    # append_to_rlang, so a cwd-relative path would miss it (or pick up an
    # unrelated file) when the script is launched from another directory.
    knowledge = rlang.parse_file(path.join(script_dir, "enhanced_gridworld.rlang"))

    # Baseline agent without any RLang knowledge.
    baseline_agent = QLearningAgent(mdp.get_actions())

    # Agent informed by the parsed RLang program.
    rlang_agent = RLangQLearningAgent(
        actions=mdp.get_actions(), states=states, knowledge=knowledge)

    run_agents_on_mdp([baseline_agent, rlang_agent], mdp)


# def read_input(input_file_name):
#     with open(input_file_name, 'r') as f:
#         return f.readlines()

# def parse_vocab(path_to_rlang_file):
#     with open(path.join(script_dir, path_to_rlang_file), 'r') as f:
#         lines = f.readlines()
#         vocab = {}
        
#         for line in lines:
#             split = line.split()
#             if len(split) > 1 and split[0] in set(['Factor', 'Action', 'Proposition', 'Feature', 'MarkovFeature', ]):
#                 print(split[1])
#                 if split[0] in vocab.keys():
#                     vocab[split[0]].append(split[1])
#                 else:
#                     vocab[split[0]] = [split[1]]

#         final_vocab = []
#         print(vocab)
#         for k in vocab.keys():
#             final_vocab = final_vocab + vocab[k]

#     return ', '.join(final_vocab)

def translate_input(user_input, vocab):
    """Translate one English instruction into RLang.

    Loads the named "effect" examples from the ``data/`` directory next to
    this script with ``read_named_examples``, passes them together with
    ``user_input`` and ``vocab`` to ``translate_named``, and returns element
    ``[1]`` of its result (presumably the RLang translation; confirm against
    ``utils.translate_named``).
    """
    # QUESTION:  how do we determine if we should pass effect vs. policy examples?
    english_file = path.join(script_dir, 'data/effect_english.txt')
    names_file = path.join(script_dir, 'data/effect_names.txt')
    rlang_file = path.join(script_dir, 'data/effect_rlang.txt')
    examples = read_named_examples(english_file, names_file, rlang_file)

    translation = translate_named(user_input, examples, vocab)
    return translation[1]

def _reset_enhanced_rlang():
    """Overwrite enhanced_gridworld.rlang with a fresh copy of gridworld.rlang."""
    # Both paths are anchored at script_dir so the copy and later appends
    # always target the same file, regardless of the working directory.
    shutil.copy(path.join(script_dir, 'gridworld.rlang'),
                path.join(script_dir, 'enhanced_gridworld.rlang'))


def append_to_rlang(phrase, *, reset=True):
    """Append one translated phrase to enhanced_gridworld.rlang.

    Args:
        phrase: A translation (as returned by ``translate_input``); it is
            passed through ``unroll`` before being written on its own line.
        reset: When True (the default, matching the original behaviour),
            first replace the enhanced file with a fresh copy of
            ``gridworld.rlang``. Pass False to keep what is already there,
            which is required when appending several translations in a row.
    """
    if reset:
        _reset_enhanced_rlang()
    with open(path.join(script_dir, 'enhanced_gridworld.rlang'), 'a') as out_f:
        out_f.write(f'\n{unroll(phrase)}\n')

# NOTE: legacy copy kept for reference; the script uses the
# save_translation_results imported from utils.
# def save_translation_results(translation_results):
#     results_csv_path = path.join(script_dir, 'translations.csv')
#     columns = ['english','rlang_translation']

#     if path.exists(results_csv_path):
#        rlang_translations_df = pd.read_csv(results_csv_path, header=0)
#     else:
#        rlang_translations_df = pd.DataFrame(columns=columns)

#     new_row = pd.DataFrame(translation_results, columns=columns)
#     rlang_translations_df = pd.concat([rlang_translations_df, new_row])
#     rlang_translations_df.to_csv(results_csv_path)

def main(argv):
    """Translate English instructions to RLang, then run the gridworld experiment.

    Args:
        argv: Command-line vector. ``argv[1]`` must name a ``.txt`` file
            (joined onto the script directory) whose lines are read with
            ``read_input``.

    Returns:
        1 when the arguments are invalid, otherwise None.
    """
    if len(argv) != 2:
        print('Invalid number of arguments', file=sys.stderr)
        print('Expected input: `python run_gridworld.py <PATH_TO_ENGLISH_INPUT>`',
              file=sys.stderr)
        return 1
    if not argv[1].lower().endswith('.txt'):
        print(f'Invalid argument "{argv[1]}". Input file must be .txt file',
              file=sys.stderr)
        return 1

    user_input = read_input(path.join(script_dir, argv[1]))
    vocab = parse_vocab('gridworld.rlang')
    print('vocab?', vocab)

    # Reset the enhanced program exactly once per run. Every translation below
    # is then appended to it, so all of them reach the experiment. Previously
    # each append re-copied the base file and only the last one survived.
    # This also ensures stale translations from an earlier run are discarded,
    # even when the input is empty.
    _reset_enhanced_rlang()

    translation_results = []
    for input_line in user_input:
        # NOTE(review): assumes read_input yields text lines (possibly with a
        # trailing newline); whitespace-only lines carry no instruction.
        english = input_line.strip()
        if not english:
            continue
        translation = translate_input(english, vocab)
        print('translation? ', translation)
        append_to_rlang(translation, reset=False)
        translation_results.append([english, translation])

    save_translation_results(translation_results)

    # TODO: save experiment per input?
    rlang_experiment()


if __name__ == '__main__':
    # Use main's return value as the process exit status; sys.exit(None)
    # exits with status 0, so a successful run is unaffected.
    sys.exit(main(sys.argv))