import numpy as np
import time
import sys
import os
#import deep_sea_treasure
#from deep_sea_treasure import DeepSeaTreasureV0
#from deep_sea_treasure import FuelWrapper
filepath = os.path.join(os.getcwd(), "deep_sea_treasure_v2")
if filepath not in sys.path:
    sys.path.append(filepath)
from deep_sea_treasure_v2 import DeepSeaTreasureV0, FuelWrapper
import pandas as pd
import json

def read_paretos(file):
    """Load a Pareto front description from a JSON file.

    The file must contain a top-level ``"pareto_front"`` list whose entries
    carry ``treasure``, ``time``, ``fuel`` and ``action_sequence`` keys.

    Args:
        file: Path to the JSON file.

    Returns:
        A ``(pareto_front, rew_matrix)`` tuple. ``pareto_front`` maps the
        0-based position of each Pareto point to its ``action_sequence``.
        ``rew_matrix`` is a DataFrame with columns ``Treasure``, ``Time`` and
        ``Fuel`` holding the reward vector of each point, in the same order.

    Raises:
        KeyError: if the top-level key or a per-point key is missing.
    """
    # ``with`` guarantees the handle is closed even when the JSON is malformed.
    with open(file, encoding="utf-8") as f:
        data = json.load(f)["pareto_front"]

    pareto_front = {j: point["action_sequence"] for j, point in enumerate(data)}
    # Build the frame in a single call. Enlarging a DataFrame row by row with
    # ``.loc`` reallocates it on every insert and can leave object dtypes.
    rew_matrix = pd.DataFrame(
        [[point["treasure"], point["time"], point["fuel"]] for point in data],
        columns=["Treasure", "Time", "Fuel"],
    )
    return pareto_front, rew_matrix
    
def init_dst():
    """Create the Deep Sea Treasure environment used by this script.

    A ``DeepSeaTreasureV0`` instance is built with a 50-step limit, a maximum
    velocity of 4.0, grid and treasure values rendered, and the implicit
    collision constraint disabled. It is passed through ``FuelWrapper.new``,
    and the wrapped environment is returned.
    """
    # All settings are fixed here so that every run uses the same environment
    # configuration.
    base_env = DeepSeaTreasureV0.new(
        max_steps=50,
        render_treasure_values=True,
        max_velocity=4.0,
        implicit_collision_constraint=False,
        render_grid=True,
    )
    return FuelWrapper.new(base_env)

def print_results(received, preferred, actions, pareto_front, pareto_rew_matrix):
    """Print how an executed policy compares with preferred rewards and the Pareto front.

    Two tables are printed:

    * the received and preferred reward vectors, and their difference, for
      treasure/time/fuel;
    * for every Pareto point, whether it has the same treasure value as the
      received one, and what fraction of the executed actions match it
      position by position.

    Args:
        received: Reward vector obtained, ordered (treasure, time, fuel).
            Must support element-wise ``-`` with ``preferred`` (e.g. a NumPy
            array).
        preferred: Preferred reward vector, in the same order.
        actions: Sequence of actions the policy executed.
        pareto_front: Pareto action sequences indexable by ``0..k-1`` (the
            dict returned by ``read_paretos``).
        pareto_rew_matrix: DataFrame with a ``Treasure`` column and one row
            per entry of ``pareto_front``, in the same order.
    """
    print("Actions:", actions)
    labels = ["treasure", "time", "fuel"]
    rews = {
        "received": received,
        "preferred": preferred,
        "difference": received - preferred,
    }
    rew_df = pd.DataFrame(data=rews, index=labels)

    # One boolean per Pareto point: same treasure value as the one received.
    same_treasure = pareto_rew_matrix["Treasure"] == received[0]
    pareto_sim = pd.DataFrame(data={"Same treasure": same_treasure})

    n = len(actions)
    ratios = []
    for i in range(len(pareto_front)):
        seq = pareto_front[i]
        # Count positional matches over the overlapping prefix.
        # NOTE(review): steps are compared with ``==``. If the Pareto
        # sequences hold lists while ``actions`` holds tuples (as ``main``
        # builds them), no step will ever match. Confirm the representations.
        matches = sum(1 for expected, taken in zip(seq, actions) if expected == taken)
        # The ratio is normalised by the executed policy length. An empty
        # policy has no matching actions, so report 0.0 rather than divide by 0.
        ratios.append(matches / n if n else 0.0)
    pareto_sim["Ratio of same actions"] = ratios

    # Scope the display precision to this report instead of changing pandas'
    # global options for the rest of the process.
    with pd.option_context("display.precision", 3):
        print("\nDifference between the preferred and received rewards:")
        print(rew_df)
        print("\nDifference to Pareto optimal solutions")
        print(pareto_sim)

def array_to_json(arr):
    """Convert a one-hot action vector into a signed integer offset.

    The index of the first element equal to 1 (along the first axis) is
    shifted by -3. For the 7-element vectors used in ``main``, indices 0..6
    therefore map to -3..3, and index 3 maps to 0.

    Args:
        arr: One-hot vector, as a NumPy array or any array-like.

    Returns:
        The offset as a built-in ``int``. This can be passed to ``json.dumps``
        directly; NumPy integer scalars are not JSON serializable.

    Raises:
        ValueError: if ``arr`` contains no element equal to 1.
    """
    hot = np.where(np.asarray(arr) == 1)[0]
    if hot.size == 0:
        raise ValueError(f"expected a one-hot vector, got {arr!r}")
    return int(hot[0]) - 3

def main():
    """Replay a fixed, hand-written action script in Deep Sea Treasure.

    Steps performed:

    * Load the 3-objective Pareto front.
    * Print a baseline: the column-wise mean of every Pareto reward vector
      except the last.
    * Step the environment through ``actions``, rendering after each step and
      pausing one second.

    The process exits via ``sys.exit`` if the velocity becomes zero in a way
    ``previous velocity + action`` does not explain; this is reported as a
    collision. The run stops when the episode is done or the scripted actions
    are exhausted, whichever comes first.
    """
    # Pareto front and baseline
    pareto_front, pareto_rew_matrix = read_paretos("Pipeline/data/3-objective.json")
    # NOTE(review): the last Pareto point is left out of the baseline. The
    # reason is not visible here; confirm it is intended.
    baseline = pareto_rew_matrix.iloc[:-1].mean(axis=0).values
    dst = init_dst()

    policy = []
    time_reward = 0.0
    fuel_reward = 0.0
    priority = np.array([1., 1., 1.])
    # Scripted steps. Each entry is a pair of 7-element one-hot vectors handed
    # to ``dst.step``; ``array_to_json`` turns each one into an offset in -3..3.
    actions = [(np.array([0., 0., 0., 1., 0., 0., 0.]), np.array([0., 0., 0., 0., 1., 0., 0.])),
               (np.array([0., 0., 0., 0., 1., 0., 0.]), np.array([0., 0., 0., 0., 1., 0., 0.])),
               (np.array([0., 0., 0., 0., 1., 0., 0.]), np.array([0., 0., 0., 0., 1., 0., 0.])),
               (np.array([0., 0., 0., 0., 1., 0., 0.]), np.array([0., 0., 1., 0., 0., 0., 0.])),
               (np.array([0., 0., 0., 1., 0., 0., 0.]), np.array([0., 0., 0., 1., 0., 0., 0.]))]

    print("\nHello! Conciliator steering has started.\n")
    print(f"Priority: {priority}")
    print(f"Baseline: {baseline}")

    dst.render()
    done = False
    # Iterate the script directly, so that running out of actions before the
    # episode ends stops the loop instead of indexing past the end of the list.
    for step_action in actions:
        previous_velo = dst.sub_vel.flatten()
        _next_state, reward, done, _debug_info = dst.step(step_action)
        action = (array_to_json(step_action[0]), array_to_json(step_action[1]))
        policy.append(action)

        next_velo = dst.sub_vel.flatten()
        # Zero velocity that ``previous_velo + action`` cannot account for is
        # reported as a collision.
        if np.all(next_velo == 0.0) and np.any(previous_velo + np.asarray(action) != next_velo):
            sys.exit(f"Collision occurred with a policy {policy}!")
        time_reward += reward[1]
        fuel_reward += reward[2]

        dst.render()
        time.sleep(1)

        if done:
            # Treasure comes from the final step; time and fuel are accumulated.
            received_rews = np.asarray([reward[0], time_reward, fuel_reward])
            print(f"Received rewards: {received_rews}")
            dst.reset()
            break

    if not done:
        print(f"Scripted actions exhausted after {len(policy)} steps before the episode ended.")

if __name__ == "__main__":
    # Run the scripted steering demo only when executed as a script, not on import.
    main()