# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import copy
import datetime
import csv
import json
import logging
import os
import time
from typing import Dict


def gather_metadata() -> Dict:
    """Snapshot metadata describing the current run at start-up.

    The returned dict holds:

    * ``date_start``: local wall-clock time formatted as
      ``'%Y-%m-%d %H:%M:%S.%f'``.
    * ``date_end``: ``None`` placeholder (not set here).
    * ``successful``: ``False``.
    * ``slurm``: ``None`` unless ``SLURM_JOB_ID`` is in the environment, in
      which case a mapping built from every environment variable whose name
      starts with ``SLURM``; names have ``SLURM_`` and then ``SLURMD_``
      removed and are lower-cased.
    * ``env``: a plain-dict copy of ``os.environ``.
    """
    started = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')

    def _short_name(var):
        # str.replace removes every occurrence, not only a leading prefix.
        return var.replace('SLURM_', '').replace('SLURMD_', '').lower()

    slurm = None
    if 'SLURM_JOB_ID' in os.environ:
        slurm = {
            _short_name(var): value
            for var, value in os.environ.items()
            if var.startswith('SLURM')
        }

    return {
        'date_start': started,
        'date_end': None,
        'successful': False,
        'slurm': slurm,
        'env': os.environ.copy(),
    }


class FileWriter:
    """Write an experiment's messages, metrics and metadata to disk.

    Everything goes under ``<rootdir>/<xpid>`` (``self.basepath``):

    * ``out.log``    -- messages emitted through the ``'palaas/out'`` logger.
    * ``logs.csv``   -- one CSV row per :meth:`log` call.
    * ``fields.csv`` -- the ordered list of CSV column names used so far.
    * ``meta.json``  -- run metadata from ``gather_metadata()`` plus the
      deep-copied experiment args and the xpid.

    Call :meth:`close` at the end of the run to finalize ``meta.json`` and to
    detach (and close) the logging handlers this writer installed.
    """

    def __init__(self,
                 xpid: str = None,
                 xp_args: dict = None,
                 rootdir: str = '~/palaas',
                 symlink_latest: bool = True,
                 ):
        """Create the log directory, attach log handlers and save metadata.

        Args:
            xpid: experiment id, used as the sub-directory name. When empty,
                ``'<pid>_<unixtime>'`` is used.
            xp_args: experiment arguments, deep-copied into the metadata.
            rootdir: parent directory; ``~`` and environment variables are
                expanded.
            symlink_latest: make ``<rootdir>/latest`` a symlink to this run's
                directory, unless ``latest`` exists and is not a symlink.
        """
        if not xpid:
            # make unique id
            xpid = '{proc}_{unixtime}'.format(
                proc=os.getpid(), unixtime=int(time.time()))
        self.xpid = xpid
        self._tick = 0

        # metadata gathering
        if xp_args is None:
            xp_args = {}
        self.metadata = gather_metadata()
        # we need to copy the args, otherwise when we close the file writer
        # (and rewrite the args) we might have non-serializable objects (or
        # other nasty stuff).
        self.metadata['args'] = copy.deepcopy(xp_args)
        self.metadata['xpid'] = self.xpid

        formatter = logging.Formatter('%(message)s')
        # The named logger is process-wide and shared by every FileWriter, so
        # remember the handlers added here; close() detaches exactly these.
        self._logger = logging.getLogger('palaas/out')
        self._handlers = []

        # Console handler (StreamHandler() with no argument writes to stderr).
        self._add_handler(logging.StreamHandler(), formatter)
        self._logger.setLevel(logging.INFO)

        rootdir = os.path.expandvars(os.path.expanduser(rootdir))
        self.basepath = os.path.join(rootdir, self.xpid)
        if not os.path.exists(self.basepath):
            self._logger.info('Creating log directory: %s', self.basepath)
            os.makedirs(self.basepath, exist_ok=True)
        else:
            self._logger.info('Found log directory: %s', self.basepath)

        if symlink_latest:
            # Add 'latest' as symlink unless it exists and is no symlink.
            symlink = os.path.join(rootdir, 'latest')
            if os.path.islink(symlink):
                os.remove(symlink)
            if not os.path.exists(symlink):
                os.symlink(self.basepath, symlink)
                self._logger.info('Symlinked log directory: %s', symlink)

        self.paths = dict(
            msg=os.path.join(self.basepath, 'out.log'),
            logs=os.path.join(self.basepath, 'logs.csv'),
            fields=os.path.join(self.basepath, 'fields.csv'),
            meta=os.path.join(self.basepath, 'meta.json'),
        )

        self._logger.info('Saving arguments to %s', self.paths['meta'])
        if os.path.exists(self.paths['meta']):
            # Note: close() still rewrites meta.json unconditionally.
            self._logger.warning('Path to meta file already exists. '
                                 'Not overriding meta.')
        else:
            self._save_metadata()

        self._logger.info('Saving messages to %s', self.paths['msg'])
        if os.path.exists(self.paths['msg']):
            self._logger.warning('Path to message file already exists. '
                                 'New data will be appended.')
        # FileHandler opens its file in append mode by default.
        self._add_handler(logging.FileHandler(self.paths['msg']), formatter)

        self._logger.info('Saving logs data to %s', self.paths['logs'])
        self._logger.info('Saving logs\' fields to %s', self.paths['fields'])
        self.fieldnames = self._load_fieldnames()

    def _add_handler(self, handler: logging.Handler,
                     formatter: logging.Formatter) -> None:
        """Format, attach and remember ``handler`` so close() can detach it."""
        handler.setFormatter(formatter)
        self._logger.addHandler(handler)
        self._handlers.append(handler)

    def _load_fieldnames(self) -> list:
        """Return the CSV columns to use, resuming from fields.csv if possible.

        A fresh run (no logs.csv) starts with ``['_tick', '_time']``. When
        logs.csv already exists, the first row of fields.csv is reused; if
        fields.csv is missing or empty, fall back to the defaults instead of
        crashing.
        """
        default = ['_tick', '_time']
        if not os.path.exists(self.paths['logs']):
            return default
        self._logger.warning('Path to log file already exists. '
                             'New data will be appended.')
        try:
            with open(self.paths['fields'], 'r', newline='') as csvfile:
                row = next(csv.reader(csvfile), None)
        except FileNotFoundError:
            row = None
        if not row:
            self._logger.warning('Could not read log fields from %s; '
                                 'using %s.', self.paths['fields'], default)
            return default
        return row

    def log(self, to_log: Dict, tick: int = None,
            verbose: bool = False) -> None:
        """Append one row of values to logs.csv.

        ``to_log`` is modified in place: ``_tick`` (a per-instance counter
        starting at 0, also when appending to an existing logs.csv) and
        ``_time`` (``time.time()``) are added. Keys not seen before extend
        ``self.fieldnames`` and fields.csv is rewritten.

        Args:
            to_log: column name -> value mapping for this row.
            tick: explicit tick value; not supported.
            verbose: also emit the row through the logger.

        Raises:
            NotImplementedError: if ``tick`` is given.
        """
        if tick is not None:
            raise NotImplementedError
        to_log['_tick'] = self._tick
        self._tick += 1
        to_log['_time'] = time.time()

        new_fields = [k for k in to_log if k not in self.fieldnames]
        if new_fields:
            self.fieldnames.extend(new_fields)
            # newline='' as required by the csv module for writer targets.
            with open(self.paths['fields'], 'w', newline='') as csvfile:
                csv.writer(csvfile).writerow(self.fieldnames)
            self._logger.info('Updated log fields: %s', self.fieldnames)

        if verbose:
            self._logger.info('LOG | %s', ', '.join(
                ['{}: {}'.format(k, to_log[k]) for k in sorted(to_log)]))

        with open(self.paths['logs'], 'a', newline='') as f:
            if to_log['_tick'] == 0:
                # Commented header with the columns known at the first row.
                f.write('# %s\n' % ','.join(self.fieldnames))
            csv.DictWriter(f, fieldnames=self.fieldnames).writerow(to_log)

    def close(self, successful: bool = True) -> None:
        """Finalize meta.json and detach this writer's log handlers.

        Args:
            successful: stored as ``metadata['successful']``.
        """
        self.metadata['date_end'] = datetime.datetime.now().strftime(
            '%Y-%m-%d %H:%M:%S.%f')
        self.metadata['successful'] = successful
        self._save_metadata()
        # Release out.log and stop later FileWriters (which share the same
        # named logger) from emitting every message through stale handlers.
        for handler in self._handlers:
            self._logger.removeHandler(handler)
            handler.close()
        self._handlers = []

    def _save_metadata(self) -> None:
        """(Re)write ``self.metadata`` to meta.json as indented, sorted JSON."""
        with open(self.paths['meta'], 'w') as jsonfile:
            json.dump(self.metadata, jsonfile, indent=4, sort_keys=True)
