import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import logging
import random

def setup_seed(seed):
    """Seed all random number generators used here for reproducible runs.

    Seeds PyTorch (CPU and every CUDA device), NumPy's global RNG and
    Python's ``random`` module, then configures cuDNN for deterministic
    behaviour.

    Args:
        seed: Integer seed applied to every generator.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    # ``deterministic`` alone is not sufficient: with benchmark mode on, cuDNN
    # auto-tunes and may pick a different (non-deterministic) algorithm on each run.
    torch.backends.cudnn.benchmark = False


def set_logger(file_path='log.txt', handle=1):
    """Attach a formatted handler to the root logger and return that logger.

    Sets the root logger level to INFO and adds one handler whose format is
    ``'%(asctime)s - %(funcName)s - %(levelname)s - %(message)s'``.

    Repeated calls for the same target do not stack duplicate handlers. Before
    the new handler is added, any handler already on the root logger that writes
    to the same target is removed and closed. For a file, that is a
    ``FileHandler`` on the same absolute path. For the console, it is a plain
    ``StreamHandler`` on the same stderr stream. Handlers for other targets are
    left in place, so calling once with ``handle=1`` and once with ``handle=2``
    still logs to both the file and the console.

    Args:
        file_path: Path of the log file used by the file handler. It is opened
            in ``'w'`` mode, so any existing content is truncated.
        handle: ``2`` selects a ``StreamHandler`` (stderr). ``1`` and any other
            value select a ``FileHandler`` on ``file_path``.

    Returns:
        The root ``logging.Logger``.
    """
    import os
    import sys

    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    use_stream = handle == 2
    target_path = None if use_stream else os.path.abspath(file_path)

    # Remove handlers that write to the same target. Otherwise every message
    # would be emitted once per call, and the same file would be opened twice.
    for existing in list(logger.handlers):
        if use_stream:
            # Exact type check so FileHandler and framework subclasses are untouched.
            duplicate = (type(existing) is logging.StreamHandler
                         and existing.stream is sys.stderr)
        else:
            duplicate = (isinstance(existing, logging.FileHandler)
                         and existing.baseFilename == target_path)
        if duplicate:
            logger.removeHandler(existing)
            existing.close()

    if use_stream:
        log_handler = logging.StreamHandler()
    else:
        log_handler = logging.FileHandler(file_path, mode='w', encoding='UTF-8')

    formatter = logging.Formatter('%(asctime)s - %(funcName)s - %(levelname)s - %(message)s')
    log_handler.setFormatter(formatter)
    logger.addHandler(log_handler)

    return logger

