import sys
import os
import yaml
import random
import numpy as np

import torch
import torch.backends.cudnn as cudnn

class _MissingKeyError(AttributeError, KeyError):
    """Raised by ``DotDict`` when an attribute-style lookup finds no key.

    It derives from ``AttributeError`` so that ``hasattr``,
    three-argument ``getattr`` and ``copy.deepcopy`` treat the miss as a
    normal "no such attribute", and from ``KeyError`` so that callers
    which already catch ``KeyError`` around ``d.name`` keep working.
    """


class DotDict(dict):
    """Dictionary whose keys can also be read, set and deleted as attributes.

    ``d.name`` is equivalent to ``d["name"]`` for any key that does not
    collide with a real ``dict`` attribute (``__getattr__`` is only
    consulted after normal attribute lookup fails, so e.g. ``d.items``
    is still the method).
    """

    # Attribute assignment always stores an item; nothing can fail here.
    __setattr__ = dict.__setitem__

    def __getattr__(self, name):
        """Return ``self[name]``; raise an AttributeError/KeyError hybrid if absent."""
        try:
            return self[name]
        except KeyError:
            # A bare KeyError would escape hasattr()/getattr(..., default).
            raise _MissingKeyError(name) from None

    def __delattr__(self, name):
        """Delete ``self[name]``; raise an AttributeError/KeyError hybrid if absent."""
        try:
            del self[name]
        except KeyError:
            raise _MissingKeyError(name) from None

class StreamWrapper:
    """Tee for a text stream: forwards writes and appends them to a log file.

    Intended as a drop-in replacement for ``sys.stdout`` / ``sys.stderr``.
    Any attribute not defined here (``isatty``, ``fileno``, ``encoding``,
    ...) is looked up on the wrapped stream, so code that inspects the
    stream keeps working after the replacement.
    """

    def __init__(self, original_stream, log_file):
        """Wrap *original_stream* and start *log_file* empty.

        Args:
            original_stream: Stream that keeps receiving every message.
            log_file: Path of the file that receives a copy of each message.
        """
        self.original_stream = original_stream
        self.log_file = log_file

        # Truncate (or create) the log file so every run starts clean.
        with open(self.log_file, "w"):
            pass

    def __getattr__(self, name):
        """Delegate attributes not defined on the wrapper to the wrapped stream."""
        # Guard against infinite recursion if the instance is inspected
        # before __init__ has assigned its own attributes.
        if name in ("original_stream", "log_file"):
            raise AttributeError(name)
        return getattr(self.original_stream, name)

    def write(self, message):
        """Write *message* to the wrapped stream and append it to the log file.

        Non-``str`` messages are ignored.

        Returns:
            The number of characters written (0 for ignored messages).
        """
        if not isinstance(message, str):
            return 0
        self.original_stream.write(message)
        # Re-open per write so the text reaches the file immediately and
        # several wrappers can share one log file.
        with open(self.log_file, "a") as file:
            file.write(message)
        return len(message)

    def flush(self):
        """Flush the wrapped stream (the log file is closed after each write)."""
        self.original_stream.flush()

def log_std(path, log_name='std', incl_stdout=True, incl_stderr=True):
    """Replace ``sys.stdout`` and/or ``sys.stderr`` with logging wrappers.

    Each selected stream is replaced by a ``StreamWrapper`` around the
    current stream, targeting ``<path>/<log_name>.log``.

    Args:
        path: Directory in which the log file is placed.
        log_name: Log file name without the ``.log`` extension.
        incl_stdout: Wrap ``sys.stdout`` when truthy.
        incl_stderr: Wrap ``sys.stderr`` when truthy.
    """
    target = os.path.join(path, f'{log_name}.log')
    # stdout is handled before stderr, each wrapping the stream currently
    # installed on ``sys``.
    selections = ((incl_stdout, 'stdout'), (incl_stderr, 'stderr'))
    for enabled, stream_name in selections:
        if enabled:
            current = getattr(sys, stream_name)
            setattr(sys, stream_name, StreamWrapper(current, target))

def set_all_seeds(seed):
    """Seed torch, NumPy and ``random`` with *seed* and pin the cuDNN flags.

    Sets ``cudnn.deterministic`` to True and ``cudnn.benchmark`` to False
    before seeding.

    Args:
        seed: Value passed unchanged to each library's seeding function.
    """
    cudnn.deterministic = True
    cudnn.benchmark = False
    # Same order as always: torch, then NumPy, then the stdlib generator.
    for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
        seed_fn(seed)

def get_wandb_args(team_path, exp_name, run_name):
    """Load the latest wandb run config of a run as a ``DotDict``.

    Reads ``<team_path>/<exp_name>/<run_name>/wandb/latest-run/files/config.yaml``,
    drops the ``_wandb`` and ``wandb_version`` entries if present, and maps
    every remaining key to the ``'value'`` field of its entry.

    Args:
        team_path: Root directory holding the experiments.
        exp_name: Experiment directory name under *team_path*.
        run_name: Run directory name under the experiment.

    Returns:
        DotDict mapping each config key to its entry's ``'value'``.
    """
    path_parts = (
        team_path, exp_name, run_name,
        'wandb', 'latest-run', 'files', 'config.yaml',
    )
    config_file = os.path.join(*path_parts)

    with open(config_file, 'r') as handle:
        raw_config = yaml.safe_load(handle)

    # These entries are bookkeeping, not run arguments.
    for meta_key in ('_wandb', 'wandb_version'):
        raw_config.pop(meta_key, None)

    args = DotDict()
    for name, entry in raw_config.items():
        args[name] = entry['value']
    return args