import numpy as np
import pandas as pd
import os

class travel_time_metric():
    """Accumulate per-vehicle travel times observed in an environment.

    The environment must provide ``get_vehicles(include_waiting=True)``,
    returning an iterable of hashable vehicle identifiers, and
    ``get_current_time()``, returning a numeric timestamp. A vehicle's
    travel time is the difference between the timestamp at which it was
    first seen by ``update`` and the timestamp at which it was no longer
    reported (or at which ``update(done=True)`` was called).
    """

    def __init__(self, env):
        """Bind the metric to *env* and start with no recorded data."""
        self.env = env
        # Completed travel times, in the order vehicles were finalized.
        self.travel_times = []
        # vehicle id -> timestamp at which the vehicle was first observed.
        self.enter_time = {}

    def update(self, done=False):
        """Record newly seen vehicles and finalize those that are gone.

        Args:
            done: When true, every tracked vehicle is finalized at the
                current time, whether or not the environment still
                reports it.
        """
        vehicles = self.env.get_vehicles(include_waiting=True)
        current_time = self.env.get_current_time()

        # Register first-seen timestamps; known vehicles keep theirs.
        for vehicle in vehicles:
            self.enter_time.setdefault(vehicle, current_time)

        if done:
            finished = list(self.enter_time)
        else:
            # Build the membership set once so each lookup is O(1) rather
            # than a scan of the environment's vehicle collection.
            present = set(vehicles)
            finished = [v for v in self.enter_time if v not in present]

        for vehicle in finished:
            self.travel_times.append(
                current_time - self.enter_time.pop(vehicle)
            )

    def get_travel_time(self):
        """Return the mean completed travel time, or 0.0 if none exist."""
        if not self.travel_times:
            return 0.0
        return float(np.mean(self.travel_times))

    def log_travel_time(self, save_path):
        """Write the mean travel time to ``<save_path>/travel_time.csv``.

        The file holds a single row under the column ``"travel time"``.
        When no travel time has been recorded the value written is 0.0,
        matching ``get_travel_time``, instead of NaN.

        Args:
            save_path: Directory in which ``travel_time.csv`` is written.
        """
        data = [self.get_travel_time()]
        df = pd.DataFrame(data, columns=["travel time"], dtype=float)
        df.to_csv(os.path.join(save_path, "travel_time.csv"))
        