import os
import torch
from datetime import datetime


class Logger():
    """Echo messages to stdout and, optionally, mirror them into a log directory.

    When ``log`` is true, the directory ``log_name`` is created if needed and
    ``log.txt`` inside it is opened for writing. :meth:`print` then writes to
    both stdout and that file, :meth:`write_config` writes ``config.txt`` there,
    and :meth:`save` stores ``state_dict()`` checkpoints there. When ``log`` is
    false, only stdout output happens and nothing touches the filesystem.

    Call :meth:`close` (or use the logger as a context manager) to release the
    log file deterministically; ``__del__`` only calls :meth:`close` as a
    best-effort fallback.
    """

    def __init__(self, log_name, log=True):
        """Create the logger.

        Args:
            log_name: Directory that receives ``log.txt``, ``config.txt`` and
                saved checkpoints. An empty string means the current working
                directory.
            log: If false, disable all file output (stdout only).

        Raises:
            OSError: If the directory cannot be created or ``log.txt`` cannot
                be opened (only when ``log`` is true).
        """
        self.log = log
        self.dir = log_name
        # Set before any I/O so close()/__del__ work even if the code below raises.
        self.f = None
        if self.log:
            # os.makedirs("") raises FileNotFoundError; an empty dir means the cwd.
            if self.dir:
                os.makedirs(self.dir, exist_ok=True)
            self.f = open(os.path.join(self.dir, "log.txt"), "w", encoding="utf-8")

    @staticmethod
    def _print_models(models, file=None):
        """Print a list of models (each followed by a blank line) or a single model."""
        # file=None makes print() fall back to the current sys.stdout.
        if isinstance(models, list):
            for model in models:
                print(model, file=file)
                print(file=file)
        else:
            print(models, file=file)

    def write_config(self, args, models, name):
        """Write the run configuration to ``config.txt`` and echo it to stdout.

        Args:
            args: Arguments object to record (written with ``print``, i.e. its str()).
            models: A single model or a list of models; each is printed.
            name: Label written as the first line of ``config.txt``; it is not
                echoed to stdout.
        """
        if self.log:
            config_path = os.path.join(self.dir, "config.txt")
            with open(config_path, "w", encoding="utf-8") as f:
                print(name, file=f)
                print(args, file=f)
                print(file=f)
                self._print_models(models, file=f)
        print(args)
        print()
        self._print_models(models)

    def print(self, x=None):
        """Print ``x`` to stdout and to ``log.txt`` (if open); no argument prints a blank line."""
        # print(None) would emit the text "None"; an omitted x means an empty line.
        parts = () if x is None else (x,)
        print(*parts, flush=True)
        # self.f is None when logging is disabled or after close(); passing
        # file=None would print to stdout a second time, so skip it then.
        if self.f is not None:
            print(*parts, file=self.f, flush=True)

    def close(self):
        """Close ``log.txt`` if it is open; safe to call more than once."""
        # getattr guards instances whose __init__ never ran (e.g. built via __new__).
        f = getattr(self, "f", None)
        if f is not None:
            f.close()
            self.f = None

    def __enter__(self):
        """Return the logger itself so it can be used in a ``with`` statement."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Close the log file; exceptions from the ``with`` body propagate."""
        self.close()

    def __del__(self):
        # Best-effort fallback only; prefer calling close() explicitly.
        self.close()

    def save(self, model, name):
        """Save ``model.state_dict()`` to ``<dir>/<name>`` with ``torch.save``.

        Does nothing when logging is disabled.
        """
        if self.log:
            torch.save(model.state_dict(), os.path.join(self.dir, name))