"""
Logging Utilities

Written by Patrick Coady (pat-coady.github.io)
"""
import os
import csv


class Logger(object):
    """ Simple training logger: saves to file and optionally prints to stdout

    Fields are accumulated with log() and emitted as one CSV row per call to
    write(). The CSV column set is fixed by the first row written. The logger
    can be used as a context manager so the file is closed even when the
    enclosing code raises.
    """
    def __init__(self, parent_folder_name: str, env_name: str, agent_name: str, oracle_name: str, now: str) -> None:
        """ Create the log directory and open 'log.csv' inside it for writing

        The directory is the arguments joined in order:
        parent_folder_name/env_name/agent_name/oracle_name/now

        Args:
            parent_folder_name: root folder of the log tree
            env_name: first sub-folder level
            agent_name: second sub-folder level
            oracle_name: third sub-folder level
            now: last sub-folder level

        Raises:
            FileExistsError: if the target directory already exists
        """
        self._save_path = os.path.join(parent_folder_name, env_name, agent_name, oracle_name, now)
        # Deliberately no exist_ok: the file is opened in 'w' mode, so re-using
        # a directory would silently truncate an existing log.csv.
        os.makedirs(self.save_path())
        self.write_header = True
        self.log_entry = {}
        # newline='' lets the csv module control line endings itself; without
        # it, platforms that translate '\n' emit an extra '\r' per row.
        self.f = open(os.path.join(self.save_path(), 'log.csv'), 'w', newline='')
        self.writer = None  # DictWriter created with first call to write() method

    def __enter__(self) -> 'Logger':
        """ Return the logger itself so it can be used in a `with` statement """
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        """ Close the log file on leaving a `with` block; exceptions propagate """
        self.close()

    def save_path(self) -> str:
        """ Return the directory that contains log.csv """
        return self._save_path

    def write(self, display=True) -> None:
        """ Write 1 log entry to file, and optionally to stdout
        Log fields preceded by '_' will not be printed to stdout
        (apart from the summary fields that disp() reads explicitly)

        The first call fixes the CSV columns to the keys collected so far and
        writes the header row. The pending entry is cleared afterwards.

        Args:
            display: boolean, print to stdout
        """
        if display:
            self.disp(self.log_entry)
        if self.write_header:
            self.writer = csv.DictWriter(self.f, fieldnames=list(self.log_entry))
            self.writer.writeheader()
            self.write_header = False
        self.writer.writerow(self.log_entry)
        # Push the row to disk right away so an interrupted run keeps
        # everything logged so far and the file can be followed live.
        self.f.flush()
        self.log_entry = {}

    @staticmethod
    def disp(log) -> None:
        """Print metrics to stdout

        Prints a summary line built from '_Episode', '_AvgRewardSum',
        '_StdRewardSum', '_AvgDiscountedRewardSum' and
        '_StdDiscountedRewardSum' (all required). If '_lambda_opt' is present,
        a second line is printed from '_lambda_opt', '_lambda_scalar',
        '_d_min', '_d_max' and '_d_true'. Every key that does not start with
        '_' is then printed in sorted order, one per line.

        Args:
            log: dictionary mapping field names to values

        Raises:
            KeyError: if one of the required summary fields is missing
        """
        # print episode data
        string_episode = 'Episode {}'.format(log['_Episode'])
        string_return = 'Mean Return = {:.2f} (Std = {:.2f})'.format(
            log['_AvgRewardSum'], log['_StdRewardSum'])
        string_discounted_return = 'Mean Discounted Return = {:.2f} (Std = {:.2f})'.format(
            log['_AvgDiscountedRewardSum'], log['_StdDiscountedRewardSum'])
        print(string_episode + ': ' + string_return + ', ' + string_discounted_return)

        if "_lambda_opt" in log:
            string_lambda = 'Dual opt. = {:.4f}, Dual scalar = {:.4f}'.format(
                log['_lambda_opt'], log['_lambda_scalar'])
            string_distance = 'd_min = {:.4f}, d_max. = {:.4f}, d_true = {:.4f}'.format(
                log['_d_min'], log['_d_max'], log['_d_true'])
            # indent so this line starts under the text following 'Episode N: '
            print(' '*(len(string_episode)+2) + string_lambda + ', ' + string_distance)

        # print all other keys not starting with _
        for key in sorted(log):
            if key.startswith('_'):  # don't display log items with leading '_'
                continue
            value = log[key]
            try:
                line = '{:s}: {:.3g}'.format(key, value)
            except (TypeError, ValueError):
                # value does not accept a float format spec (e.g. a string or
                # None); show it as-is instead of aborting the whole display
                line = '{:s}: {}'.format(key, value)
            print(line)
        print('\n')

    def log(self, items: dict) -> None:
        """ Update fields in log (does not write to file, used to collect updates).

        Args:
            items: dictionary of items to update
        """
        self.log_entry.update(items)

    def close(self) -> None:
        """ Close log file - log cannot be written after this """
        self.f.close()
