import string
import random

from sacred.observers import S3Observer
from sacred.observers.base import RunObserver
from sacred.observers.s3_observer import s3_join

# Default ``priority`` passed to ImprovedS3Observer (SetRunId defaults to 100).
# NOTE(review): in sacred, observer priority presumably controls call order
# (higher first), letting SetRunId supply the run id before the S3 observer
# runs — confirm against sacred's observer documentation.
DEFAULT_S3_PRIORITY = 20


class SetRunId(RunObserver):
    """Observer whose only job is to hand sacred a predetermined run id.

    When ``queued_event`` or ``started_event`` receives an ``_id`` that is
    already set (truthy), that id is passed back untouched; otherwise the
    ``run_id`` supplied to the constructor is returned.
    """

    def __init__(self, run_id, priority=100):
        """Store the fallback ``run_id`` and this observer's ``priority``."""
        # NOTE(review): 100 is well above DEFAULT_S3_PRIORITY (20); presumably
        # chosen so sacred consults this observer first — confirm ordering.
        self.priority = priority
        self.run_id = run_id

    def _choose_id(self, existing_id):
        """Return ``existing_id`` if it is truthy, else the configured ``run_id``."""
        return existing_id or self.run_id

    def queued_event(
        self, ex_info, command, queue_time, config, meta_info, _id
    ):
        """Return the id sacred should use for a queued run."""
        return self._choose_id(_id)

    def started_event(
        self, ex_info, command, host_info, start_time, config, meta_info, _id
    ):
        """Return the id sacred should use for a started run."""
        return self._choose_id(_id)


class ImprovedS3Observer(S3Observer):
    """S3Observer that uploads heartbeat state only every N-th heartbeat.

    Every heartbeat updates the in-memory ``info``, ``run_entry`` and
    captured output. The S3 uploads (``run.json``, ``cout.txt`` and, when
    non-empty, ``info.json``) happen only once ``write_every_n_hearbeats``
    heartbeats have accumulated. State deferred by this throttling is flushed
    when the run completes, fails, or is interrupted, so the final captured
    output is never lost.

    Args:
        bucket, basedir, resource_dir, source_dir, priority, region:
            Forwarded unchanged to ``S3Observer.__init__``.
        write_every_n_hearbeats: Number of heartbeats between uploads.
            Values <= 1 upload on every heartbeat. (The name keeps its
            historical misspelling for backward compatibility.)
    """

    def __init__(
        self,
        bucket,
        basedir,
        resource_dir=None,
        source_dir=None,
        priority=DEFAULT_S3_PRIORITY,
        region=None,
        write_every_n_hearbeats=10,
    ):
        self.write_every_n_hearbeats = write_every_n_hearbeats
        # Heartbeats seen since the last upload.
        self._heartbeat_count = 0

        super().__init__(
            bucket=bucket,
            basedir=basedir,
            resource_dir=resource_dir,
            source_dir=source_dir,
            priority=priority,
            region=region,
        )

    def save_cout(self):
        """Upload ``self.cout`` (UTF-8 encoded) to ``cout.txt`` under ``self.dir``."""
        # NOTE(review): relies on ``self.cout`` already existing — set by
        # heartbeat_event here, presumably also by the base started_event.
        binary_data = self.cout.encode("utf-8")
        key = s3_join(self.dir, "cout.txt")
        self.put_data(key, binary_data)

    def _flush_deferred(self):
        """Upload captured output and info that throttling may have skipped.

        Writes ``cout.txt`` and, when ``self.info`` is non-empty,
        ``info.json``. Then resets the heartbeat counter, so the next upload
        window (or the next run of a reused observer) starts from zero.
        """
        self.save_cout()
        if self.info:
            self.save_json(self.info, "info.json")
        self._heartbeat_count = 0

    def heartbeat_event(self, info, captured_out, beat_time, result):
        """Record heartbeat state and upload it every N-th heartbeat."""
        self.info = info
        self.run_entry["heartbeat"] = beat_time.isoformat()
        self.run_entry["result"] = result
        self.cout = captured_out
        self._heartbeat_count += 1

        if self._heartbeat_count >= self.write_every_n_hearbeats:
            self.save_json(self.run_entry, "run.json")
            self._flush_deferred()

    def completed_event(self, stop_time, result):
        """Mark the run COMPLETED and upload run.json plus deferred state."""
        self.run_entry["stop_time"] = stop_time.isoformat()
        self.run_entry["result"] = result
        self.run_entry["status"] = "COMPLETED"
        self.save_json(self.run_entry, "run.json")
        self._flush_deferred()

    def failed_event(self, fail_time, fail_trace):
        """Let the base observer record the failure, then flush deferred state.

        Without this, output captured during up to N-1 throttled heartbeats
        (typically the lines leading up to the error) would never reach S3.
        """
        super().failed_event(fail_time, fail_trace)
        self._flush_deferred()

    def interrupted_event(self, interrupt_time, status):
        """Let the base observer record the interruption, then flush deferred state."""
        super().interrupted_event(interrupt_time, status)
        self._flush_deferred()
