__author__ = 'XF'
__date__ = '2023/07/03'

'''
Frequently used general-purpose helper functions for FairAD.

'''
import os
from os import path as osp
import pickle
import time
import numpy as np
from builtins import print as b_print
import json 
from argparse import ArgumentParser, ArgumentTypeError

# Absolute path of the directory that contains this module.
ROOT_DIR = osp.abspath(osp.dirname(__file__))
# <ROOT_DIR>/data -- presumably the dataset directory; this module does not create it.
DATA_DIR = osp.join(ROOT_DIR, 'data')

# Banner lines written by Log: `begining_line` when a Log is created and
# `ending_line` when its `ending` property is accessed.
# NOTE: the misspelled name `begining_line` is kept because Log refers to it by that name.
begining_line = '=============================== Begin ======================================='
ending_line =   '================================ End ========================================'


# object serialization
def obj_save(path, obj):
    """Pickle ``obj`` into the file at ``path`` (overwriting it).

    A ``None`` object is not written at all; a short notice is printed instead
    and no file is created.

    :param path: destination file path
    :param obj: any picklable object
    """
    if obj is None:
        print('object is None!')
        return
    with open(path, 'wb') as handle:
        pickle.dump(obj, handle)


# object instantiation
def obj_load(path):
    """Unpickle and return the object stored in the file at ``path``.

    WARNING: ``pickle.load`` can execute arbitrary code -- only load files
    produced by trusted code (e.g. ``obj_save``).

    :param path: path of a pickle file
    :return: the deserialized object
    :raises OSError: if ``path`` does not exist
    """
    if not os.path.exists(path):
        raise OSError('no such path:%s' % path)
    with open(path, 'rb') as handle:
        return pickle.load(handle)

# logging
class Log(object):
    """Tee logger: each message is printed to stdout and appended to a text file.

    The file lives in ``log_dir`` and is named by ``generate_filename`` from the
    ``log_name`` parts plus a timestamp and a '.txt' suffix. Creating the logger
    immediately writes the begin banner and the current date.
    """

    def __init__(self, log_dir, log_name):
        """
        :param log_dir: existing directory for the log file (it is not created here)
        :param log_name: a single name string, or an iterable of name parts that
            are joined with '_'
        """
        filename = generate_filename('.txt', *self._name_parts(log_name), timestamp=True)
        self.log_path = osp.join(log_dir, filename)
        self.print(begining_line)
        self.print('date: %s' % time.strftime('%Y/%m/%d-%H:%M:%S'))

    @staticmethod
    def _name_parts(log_name):
        """Return ``log_name`` as a tuple of file-name parts."""
        # Unpacking a bare string would split it into characters
        # ('train' -> 't_r_a_i_n'), so a string counts as a single part.
        if isinstance(log_name, str):
            return (log_name,)
        return tuple(log_name)

    def print(self, *args, end='\n'):
        """Append ``args`` to the log file, then echo them to stdout.

        The file is reopened in append mode on every call, so each message is
        flushed to the file by the time the call returns.
        """
        with open(file=self.log_path, mode='a', encoding='utf-8') as console:
            b_print(*args, file=console, end=end)
        b_print(*args, end=end)

    @property
    def ending(self):
        """Write the closing date and the end banner.

        NOTE: this is a property with side effects -- it runs on plain attribute
        access (``log.ending``), not on a call, and evaluates to None.
        """
        self.print('date: %s' % time.strftime('%Y/%m/%d-%H:%M:%S'))
        self.print(ending_line)



def generate_filename(suffix, *args, sep='_', timestamp=False):

    '''
    Build a file name from name parts, an optional timestamp and a suffix.

    The parts are joined with ``sep`` and every space in the result is replaced
    with '_'.

    :param suffix: file extension, with or without the leading '.'; an empty
        suffix adds no extension
    :param args: name parts; non-string parts (e.g. numbers) are converted with str()
    :param sep: separator placed between name parts, default '_'
    :param timestamp: if True, append '_YYYYmmddHHMMSS' (local time) for uniqueness
    :return: the generated file name
    '''

    filename = sep.join(str(part) for part in args).replace(' ', '_')
    if timestamp:
        filename += time.strftime('_%Y%m%d%H%M%S')
    if suffix:
        # Accept both '.txt' and 'txt'.
        filename += suffix if suffix.startswith('.') else '.' + suffix

    return filename


def json_load(path):
    """Parse the UTF-8 encoded JSON file at ``path`` and return the resulting object."""
    with open(path, 'r', encoding='utf8') as fp:
        return json.load(fp)


def json_dump(path, dict_obj, *, mode='a+'):
    """Write ``dict_obj`` as 4-space-indented UTF-8 JSON (non-ASCII kept as-is).

    :param path: destination file path
    :param dict_obj: JSON-serializable object
    :param mode: file open mode. The default 'a+' appends to an existing file,
        which keeps the historical behaviour but means a second dump to the same
        path leaves two JSON documents back to back that ``json.load`` cannot
        parse. Pass ``mode='w'`` to overwrite the file instead.
    """
    with open(path, mode, encoding='utf-8') as f:
        json.dump(dict_obj, f, indent=4, ensure_ascii=False)


def args_to_dict(args, keys):
    """Return a dict mapping each name in ``keys`` (in order) to ``getattr(args, name)``."""
    result = {}
    for key in keys:
        result[key] = getattr(args, key)
    return result


def create_argparser(default_args):
    """Return a new ArgumentParser with one ``--<key>`` option per entry of ``default_args``.

    Option types and defaults are derived by ``add_dict_to_argparser``.
    """
    argparser = ArgumentParser()
    add_dict_to_argparser(argparser, default_args)
    return argparser


def new_dir(father_dir, mk_dir=None):
    """Create (if needed) and return a sub-directory of ``father_dir``.

    :param father_dir: parent directory; missing intermediate directories are
        created as well
    :param mk_dir: name (or relative path) of the sub-directory; when None, a
        local-time timestamp 'YYYYmmddHHMMSS' is used
    :return: path of the (now existing) directory
    :raises FileExistsError: if the path already exists but is not a directory
    """
    name = time.strftime('%Y%m%d%H%M%S') if mk_dir is None else mk_dir
    new_path = osp.join(father_dir, name)
    # exist_ok avoids the check-then-create race when, e.g., several runs
    # start within the same second and pick the same timestamp directory.
    os.makedirs(new_path, exist_ok=True)
    return new_path
    

def add_dict_to_argparser(parser, default_dict):
    """Register a ``--<key>`` option on ``parser`` for every entry of ``default_dict``.

    The option's default is the dict value and its ``type=`` converter is:
    ``str`` for a None default, ``str2bool`` for a bool default, and otherwise
    the type of the default value itself.
    """
    for name, default in default_dict.items():
        if default is None:
            converter = str
        elif isinstance(default, bool):
            converter = str2bool
        else:
            converter = type(default)
        parser.add_argument(f"--{name}", default=default, type=converter)


def str2bool(v):
    """
    Convert a command-line string such as 'yes'/'no' to a bool (argparse ``type=``).

    Bool inputs are returned unchanged; matching is case-insensitive.
    Raises ArgumentTypeError for any other string.

    https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse
    """
    if isinstance(v, bool):
        return v
    lowered = v.lower()
    if lowered in {"yes", "true", "t", "y", "1"}:
        return True
    if lowered in {"no", "false", "f", "n", "0"}:
        return False
    raise ArgumentTypeError("boolean value expected")



def get_missing_data(data, missing_rate, *, rng=None):
    """Randomly zero out entries of ``data`` to simulate missing values.

    Each entry gets an independent uniform draw in [0, 1) and is dropped when
    the draw is below ``missing_rate`` (so 0 drops nothing and 1 drops everything).

    NOTE: ``data`` is modified in place (dropped entries are set to 0.0) and the
    same object is returned; pass a copy if the original values must be kept.

    :param data: array supporting ``.shape`` and boolean-mask assignment
        (e.g. a numpy array)
    :param missing_rate: probability that an entry is dropped
    :param rng: optional ``numpy.random.Generator`` for reproducible masks;
        when omitted the global ``np.random`` state is used, as before
    :return: ``(data, observed)`` where ``observed`` is an integer array that is
        1 for kept entries and 0 for dropped ones
    """
    if rng is None:
        draws = np.random.rand(*data.shape)
    else:
        draws = rng.random(data.shape)
    mask = draws < missing_rate

    data[mask] = 0.0

    # Invert the boolean "missing" mask into an integer "observed" indicator.
    return data, 1 - 1 * mask