import os
import sys
import torch
import random
import numpy as np
from datetime import datetime

class _TimestampedStdout:
    """Write-through wrapper for a text stream that timestamps finished chunks.

    When a written chunk ends with a newline, every newline in that chunk is
    replaced by a space, a bracketed ``dd/mm HH:MM:SS`` timestamp and the
    newline before being forwarded. Other chunks are forwarded unchanged.
    When ``silent`` is true, all writes are discarded.

    Attributes not defined here (``encoding``, ``isatty``, ``fileno``, ...)
    are delegated to the wrapped stream, so code that inspects
    ``sys.stdout`` keeps working.
    """

    def __init__(self, stream, silent):
        self._stream = stream
        self.silent = silent

    def write(self, x):
        if not self.silent:
            if x.endswith("\n"):
                stamp = datetime.now().strftime("%d/%m %H:%M:%S")
                self._stream.write(x.replace("\n", " [{}]\n".format(stamp)))
            else:
                self._stream.write(x)
        # Text streams report the number of characters written. Report the
        # caller's length even when output is suppressed or stamped.
        return len(x)

    def flush(self):
        self._stream.flush()

    def __getattr__(self, name):
        # Only reached for attributes not found normally. Guard against
        # infinite recursion if ``_stream`` has not been set yet (e.g. on copy).
        if name == "_stream":
            raise AttributeError(name)
        return getattr(self._stream, name)


def safe_state(cfg, silent=False):
    """Install a timestamping stdout, seed the RNGs and select the CUDA device.

    Args:
        cfg: Config object. ``cfg.general.random_seed`` seeds ``random``,
            NumPy and torch. ``cfg.general.device`` is formatted into the
            device string ``cuda:<device>``.
        silent: If true, everything subsequently written to ``sys.stdout``
            is discarded.

    Returns:
        The ``torch.device`` for ``cuda:<cfg.general.device>``. It is also
        made the current CUDA device.
    """
    # Wrap whatever stdout currently is, so the timestamps apply to all
    # subsequent print() output.
    sys.stdout = _TimestampedStdout(sys.stdout, silent)

    seed = cfg.general.random_seed
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    device = torch.device("cuda:{}".format(cfg.general.device))
    torch.cuda.set_device(device)

    return device

class IOStream:
    """Append-mode log file that also echoes every message to stdout.

    Can be used as a context manager; the file is closed on exit.
    """

    def __init__(self, path):
        """Open ``path`` for appending, creating missing parent directories.

        Args:
            path: Log file path. A bare file name (no directory component)
                is opened relative to the current working directory.
        """
        parent = os.path.dirname(path)
        # os.makedirs("") raises FileNotFoundError, so only create a
        # directory when the path actually names one.
        if parent:
            os.makedirs(parent, exist_ok=True)

        self.f = open(path, 'a')

    def cprint(self, text):
        """Print ``text`` to stdout and append it, plus a newline, to the file.

        ``text`` is converted with ``str()`` (as ``print`` does), so non-string
        values are logged the same way they are printed.
        """
        print(text)
        self.f.write(str(text) + '\n')
        # Flush every message so the file is current even if the process
        # terminates abruptly.
        self.f.flush()

    def close(self):
        """Close the underlying file."""
        self.f.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False

# Equals 1 / (2 * sqrt(pi)), the value of the constant degree-0 real
# spherical-harmonic basis function.
C0 = 0.28209479177387814


def SH2RGB(sh):
    """Convert degree-0 spherical-harmonic coefficients to RGB values.

    Scales ``sh`` by ``C0`` and offsets the result by 0.5. Accepts Python
    scalars as well as array/tensor objects supporting ``*`` and ``+`` with
    floats; the result has the same type/shape as ``sh``.
    """
    scaled = C0 * sh
    return scaled + 0.5