import os
import sys
import time
import pickle
import numpy as np


class TextLogger(object):
    """Tee text written to it into both a stream and a log file.

    Meant to be installed as ``sys.stdout`` / ``sys.stderr`` so that output
    is shown on the console and copied to a file at the same time.

    Args:
        filename (str): the file that receives a copy of the output. It is
            opened in append mode with UTF-8 encoding.
        stream: the stream that also receives the output. Default: the
            ``sys.stdout`` that was active when this module was imported
            (the default is evaluated once, at definition time).
    """
    def __init__(self, filename, stream=sys.stdout):
        self.terminal = stream
        # Explicit UTF-8 so non-ASCII messages cannot raise UnicodeEncodeError
        # under a narrow locale encoding (e.g. cp1252 on Windows).
        self.log = open(filename, 'a', encoding='utf-8')

    def write(self, message):
        self.terminal.write(message)
        self.log.write(message)
        # Flush after every write so the file is current even if the process
        # dies abruptly.
        self.flush()

    def flush(self):
        self.terminal.flush()
        self.log.flush()

    def close(self):
        """Flush the stream and close the log file.

        The stream itself is deliberately left open: it is usually the
        process's real ``sys.stdout``, and closing it would make every later
        ``print`` fail with "I/O operation on closed file".
        """
        self.terminal.flush()
        self.log.close()


class CompleteLogger:
    """
    Tee ``sys.stdout`` and ``sys.stderr`` into a timestamped text file.

    On construction, ``root`` is created if needed. A ``TextLogger`` writing to
    ``<root>/<phase>-<YYYY-mm-dd-HH_MM_SS>.txt`` is installed as both
    ``sys.stdout`` and ``sys.stderr``. Because that ``TextLogger`` uses its
    default stream, stderr output is echoed to the stdout stream as well.
    Call :meth:`close` to restore the original streams and close the file.

    Args:
        root (str): the root directory of logger
        phase (str): the phase of training; used as the log file name prefix.

    """

    def __init__(self, root, phase='train'):
        self.root = root
        self.phase = phase
        self.epoch = 0

        os.makedirs(self.root, exist_ok=True)

        # redirect std out
        now = time.strftime("%Y-%m-%d-%H_%M_%S", time.localtime(time.time()))
        log_filename = os.path.join(self.root, "{}-{}.txt".format(phase, now))
        # TextLogger opens in append mode; remove a same-named file so the
        # log starts empty.
        if os.path.exists(log_filename):
            os.remove(log_filename)
        self.logger = TextLogger(log_filename)

        # Remember the streams being replaced so close() can put them back.
        self._saved_stdout = sys.stdout
        self._saved_stderr = sys.stderr
        sys.stdout = self.logger
        sys.stderr = self.logger

    def close(self):
        """Restore ``sys.stdout``/``sys.stderr`` and close the log file.

        Only undoes this logger's own redirection: if another component has
        since replaced a std stream, that replacement is left in place.
        """
        if sys.stdout is self.logger:
            sys.stdout = self._saved_stdout
        if sys.stderr is self.logger:
            sys.stderr = self._saved_stderr
        self.logger.close()
        
class Archiver(object):
    """Lay out and write per-seed run artifacts on disk.

    After ``set_path(log, seed)`` the layout is::

        <log>/<seed>/
            checkpoints/   objects written by save_model (pickled)
            ac_records/    integer arrays written by save_ac_samples (text)
    """

    def __init__(self):
        # All paths stay None until set_path() is called.
        self.basic_path = None
        self.checkpoint_path = None
        self.ac_records_path = None

    def set_path(self, log, seed):
        """Create (if needed) the directory tree for ``seed`` under ``log``.

        Missing parent directories, including ``log`` itself, are created.
        Calling this again for an existing tree is harmless.

        Args:
            log (str): base directory for all runs.
            seed: run identifier; ``str(seed)`` names the run sub-directory.
        """
        basic_path = os.path.join(log, str(seed))
        checkpoint_path = os.path.join(basic_path, 'checkpoints')
        ac_records_path = os.path.join(basic_path, 'ac_records')

        # makedirs(exist_ok=True) tolerates a missing parent directory (plain
        # os.mkdir raised FileNotFoundError there) and avoids the
        # exists-then-mkdir race.
        os.makedirs(checkpoint_path, exist_ok=True)
        os.makedirs(ac_records_path, exist_ok=True)

        self.basic_path = basic_path
        self.checkpoint_path = checkpoint_path
        self.ac_records_path = ac_records_path

    def save_model(self, model, name):
        """Pickle ``model`` to ``<checkpoint_path>/<name>`` and print the path."""
        path = os.path.join(self.checkpoint_path, name)
        with open(path, 'wb') as f:
            pickle.dump(model, f)
        print('save model >>> ', path)

    def load_model(self, name):
        """Unpickle and return the object stored at ``<checkpoint_path>/<name>``.

        NOTE: ``pickle.load`` can execute arbitrary code. Only load checkpoints
        written by this program or another trusted source.
        """
        with open(os.path.join(self.checkpoint_path, name), 'rb') as f:
            return pickle.load(f)

    def save_ac_samples(self, npy, name):
        """Write ``npy`` as integers to ``<ac_records_path>/<name>``.

        The input is converted with ``np.asarray(...).astype('int')``, so float
        values are truncated toward zero. ``np.savetxt`` writes one value per
        line for 1-D input and one space-separated row per line for 2-D input.
        """
        path = os.path.join(self.ac_records_path, name)
        np.savetxt(path, np.asarray(npy).astype('int'), fmt='%i')
        print('save ac samples >>> ', path)
        
        