import os
import pandas as pd
import matplotlib.pyplot as plt
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator

def extract_tensorboard_data(path_to_events_file, size_guidance=None):
    """Load all scalar series from a TensorBoard event file into one DataFrame.

    For every scalar tag a ``[wall_time, step, <tag>]`` frame is built (the
    ``<tag>`` column holds the scalar value), and the per-tag frames are
    concatenated column-wise in the order given by
    ``EventAccumulator.Tags()["scalars"]``. The ``wall_time`` and ``step``
    column names therefore repeat once per tag.

    Rows are aligned by position, not by step: row ``i`` holds the ``i``-th
    recorded event of each tag, and tags with fewer events are padded with
    NaN. To get the steps belonging to a tag, read the ``step`` column
    immediately to the left of that tag's column.

    Args:
        path_to_events_file: Path handed to ``EventAccumulator``.
        size_guidance: Optional mapping forwarded to ``EventAccumulator``.
            ``None`` keeps TensorBoard's default, which caps how many scalar
            events are retained per tag; pass e.g. ``{"scalars": 0}`` to
            keep every event.

    Returns:
        The combined DataFrame described above, or an empty DataFrame when
        the file contains no scalar tags (``pd.concat`` would otherwise raise
        on an empty sequence).
    """
    event_acc = EventAccumulator(path_to_events_file, size_guidance=size_guidance)
    event_acc.Reload()  # Parse the file; Tags()/Scalars() read what this loaded.

    frames = [
        pd.DataFrame(event_acc.Scalars(tag))[["wall_time", "step", "value"]]
        .rename(columns={"value": tag})
        for tag in event_acc.Tags()["scalars"]
    ]
    if not frames:
        return pd.DataFrame()
    return pd.concat(frames, axis=1)

# log_paths = {
#     "0.9 expectile": "expectile_exp/tb/am-l-p-0.9",
#     "0.7 expectile": "expectile_exp/tb/am-l-p-0.7",
#     "0.5 expectile": "expectile_exp/tb/am-l-p-0.5",
#     "iql": "../old-iql/tmp_ant/tb/antmaze-large-play-v2-1"
# }

# Experiment label -> TensorBoard log directory (kitchen-partial-v0 runs).
log_paths = dict(
    [
        ("iql", "../old-iql/tmp_kitchen/tb/kitchen-partial-v0"),
        ("0.9 expectile", "tmp_kitchen/tb/kitchen-partial-v0-0.9"),
        ("0.7 expectile", "tmp_kitchen/tb/kitchen-partial-v0-0.7"),
        ("0.5 expectile", "tmp_kitchen/tb/kitchen-partial-v0-0.5"),
        ("0.3 expectile", "tmp_kitchen/tb/kitchen-partial-v0-0.3"),
    ]
)

# log_paths = {
#     "iql": "../old-iql/tmp_locomotion/tb/walker2d-medium-replay-v2",
#     "0.9 expectile": "tmp_locomotion/tb/walker2d-medium-replay-v2-0.9",
#     "0.7 expectile": "tmp_locomotion/tb/walker2d-medium-replay-v2-0.7",
#     "0.5 expectile": "tmp_locomotion/tb/walker2d-medium-replay-v2-0.5",
# }
def a(paths=None, tag='evaluation/average_returns', ylabel='Averaged Rewards',
      output_path='figures/q_values_comparison.png'):
    """Plot one scalar tag for every experiment and save the figure.

    For each ``label -> directory`` entry, the first file (by sorted name)
    whose name contains ``events.out.tfevents`` is loaded with
    ``extract_tensorboard_data`` and the ``tag`` series is drawn against
    that tag's own ``step`` column. Experiments whose directory or event
    file is missing, or whose log lacks ``tag``, are skipped.

    Args:
        paths: Mapping of legend label to log directory; defaults to the
            module-level ``log_paths``.
        tag: Scalar tag to plot (e.g. ``'training/q1'`` for Q-values).
        ylabel: Y-axis label; the default describes the default ``tag``.
        output_path: File the figure is written to; its parent directory is
            created if it does not exist.
    """
    if paths is None:
        paths = log_paths

    plt.figure(figsize=(10, 5))

    for exp_name, path in paths.items():
        if not os.path.isdir(path):
            print(f"Skipping {exp_name}: no log directory at {path}")
            continue
        # os.listdir order is arbitrary; sort so the chosen file is deterministic.
        event_names = sorted(f for f in os.listdir(path) if 'events.out.tfevents' in f)
        event_file = os.path.join(path, event_names[0]) if event_names else None
        print(event_file)
        if event_file is None:
            continue

        df = extract_tensorboard_data(event_file)
        print(df.head())
        if tag not in df.columns:
            print(f"Skipping {exp_name}: tag {tag!r} not found in {event_file}")
            continue

        # Each tag's step column sits just left of its value column and rows
        # are aligned by position, so the first tag's step column
        # (df.iloc[:, 1]) would pair these values with another tag's steps.
        tag_pos = list(df.columns).index(tag)
        steps = df.iloc[:, tag_pos - 1]
        values = df[tag]
        logged = values.notna()  # drop NaN padding from shorter series
        plt.plot(steps[logged], values[logged], label=exp_name)

    plt.xlabel('Steps')
    plt.ylabel(ylabel)
    plt.legend()
    plt.grid(True)
    out_dir = os.path.dirname(output_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    plt.savefig(output_path)

def b(use_cached=True, output_path='figures/q_values_comparison.png'):
    """Plot ``evaluation/average_returns`` for every experiment and show it.

    Every experiment's log in ``log_paths`` is read and its (step, return)
    series is printed. By default the curves drawn come from the hard-coded
    results below. With ``use_cached=False`` the series just read from the
    logs are plotted instead. Each live curve is labelled by the experiment
    it came from, so a missing log cannot shift labels onto other curves.

    Args:
        use_cached: If True (default), plot the hard-coded results; if False,
            plot the data read from disk.
        output_path: File the figure is written to before ``plt.show()``;
            its parent directory is created if it does not exist.
    """
    tag = 'evaluation/average_returns'
    live_series = []  # (label, steps, returns) per experiment with a usable log

    for exp_name, path in log_paths.items():
        if not os.path.isdir(path):
            print(f"Skipping {exp_name}: no log directory at {path}")
            continue
        # os.listdir order is arbitrary; sort so the chosen file is deterministic.
        event_names = sorted(f for f in os.listdir(path) if 'events.out.tfevents' in f)
        if not event_names:
            continue
        df = extract_tensorboard_data(os.path.join(path, event_names[0]))
        if tag not in df.columns:
            print(f"Skipping {exp_name}: tag {tag!r} not found")
            continue

        # A tag's steps live in the column just left of its values; locate it
        # relative to the tag rather than by a fixed offset such as -5, which
        # is only right when this tag happens to be the second-to-last logged.
        tag_pos = list(df.columns).index(tag)
        x_values = df.iloc[:, tag_pos - 1].dropna().tolist()
        y_values = df[tag].dropna().tolist()
        live_series.append((exp_name, x_values, y_values))

        print(x_values)
        print(y_values)

    if use_cached:
        # Hard-coded results, one row per log_paths entry in insertion order.
        # NOTE(review): provenance not visible in this file (presumably copied
        # from an earlier printout of the series above) — confirm.
        cached_steps = [float(step) for step in range(100000, 1500001, 100000)]
        cached_returns = [
            [18.5, 18.0, 45.0, 43.5, 39.5, 40.0, 49.5, 53.5, 60.0, 63.5, 60.0, 64.5, 61.5, 64.0, 64.0],
            [36.5, 40.75, 44.5, 54.25, 53.0, 60.5, 63.75, 64.25, 56.25, 60.0, 62.25, 65.75, 67.75, 67.75, 63.25],
            [18.5, 30.0, 60.0, 60.25, 62.75, 63.5, 66.25, 65.25, 58.25, 67.5, 63.0, 66.75, 69.75, 68.5, 68.25],
            [17.0, 22.75, 49.25, 47.25, 40.25, 57.0, 66.0, 68.25, 69.0, 68.0, 63.3, 67.0, 70.75, 70.0, 69.25],
            [8.75, 19.5, 17.0, 56.0, 60.5, 72.75, 70.5, 69.0, 73.25, 70.25, 69.75, 71.5, 73.0, 73.5, 70.0],
        ]
        series = [(label, cached_steps, returns)
                  for label, returns in zip(log_paths, cached_returns)]
    else:
        series = live_series

    for label, x, y in series:
        plt.plot(x, y, label=label)

    plt.xlabel('Steps')
    plt.ylabel('Averaged Rewards')
    plt.legend()
    plt.grid(True)
    out_dir = os.path.dirname(output_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    plt.savefig(output_path)
    plt.show()

# Run the plot only when executed as a script, not when this module is imported.
if __name__ == "__main__":
    b()
