"""
set random seed
"""
import torch 
import random
import numpy as np 

def set_seed(seed=0):
    """Seed Python, NumPy and PyTorch (CPU and CUDA) RNGs and make cuDNN deterministic.

    Args:
        seed: value passed to every random-number generator (default 0).
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    # Disable cuDNN autotuning and request deterministic kernels so repeated
    # runs with the same seed produce the same results.
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True


"""
load arguments
"""
import argparse

def _str2bool(value):
    """Convert a command-line string such as 'True', 'false', '1' or 'no' to a bool.

    ``type=bool`` cannot be used for boolean options because ``bool('False')``
    is ``True``: any non-empty string is truthy.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value (true/false), got {!r}'.format(value))


def parse_arg(argv=None):
    """Parse the experiment's command-line options.

    Unknown options are ignored (``parse_known_args``) rather than rejected.

    Args:
        argv: list of argument strings to parse. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, as before.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument(
        '--dataset_name', default='hymenoptera', help='pmf, 110-classifiers, openml')
    parser.add_argument(
        '--model_name', default='resnet18', help='the model name, e.g. resnet18')
    parser.add_argument(
        '--model_weights', default='IMAGENET1K_V1', help='model_weights')
    # A converter is required here: with type=bool, "--is_fulltrain False"
    # would parse as True. A bare "--is_fulltrain" also means True.
    parser.add_argument(
        '--is_fulltrain', type=_str2bool, nargs='?', const=True, default=False,
        help='full-training flag; accepts true/false (a bare flag means true)')

    parser.add_argument(
        '--save_path', default='default', help='the path to save result')
    parser.add_argument(
        '--save_name', default='default', help='save the results')
    parser.add_argument(
        '--set_path', default='default', help='the path to save result')
    parser.add_argument(
        '--set_name', default='default', help='the path to save result')

    parser.add_argument(
        '--random_seed', type=int, default=0, help='for random seed')
    parser.add_argument(
        '--n_epoch', type=int, default=20, help='number of training epochs')
    parser.add_argument(
        '--batch_size', type=int, default=16, help='the batch size')
    parser.add_argument(
        '--device', default='cpu', help='the device, cuda')

    args, _ = parser.parse_known_args(argv)
    return args



"""
Set Logging
"""
import os
import logging

# Base directory for experiment outputs; setup_logging() builds its log-file
# paths by appending to this string. It is relative to the current working directory.
RESULT_PATH = "../result/"

def setup_logging(dataset_name, save_path, save_name, is_visual=False):
    """Send log records to the console and to a per-run log file.

    The log file path is ``RESULT_PATH + '<save_path>/<dataset_name>'`` joined
    with ``save_name``. When ``is_visual`` is true, it is
    ``RESULT_PATH + '<save_path>/<dataset_name>/<save_name>'`` joined with
    ``'visual'``. The file's parent directory is created if needed, and the
    file is opened in append mode as UTF-8.

    Records are formatted as the bare message. The file receives INFO and
    above; the console (stderr) receives WARNING and above.

    Any handlers already attached to the root logger are removed and closed
    first. A second call therefore switches logging to the new file instead of
    being silently ignored.

    Args:
        dataset_name: directory component under ``save_path``.
        save_path: first directory component under ``RESULT_PATH``.
        save_name: log-file name, or directory name when ``is_visual`` is true.
        is_visual: write to ``<save_name>/visual`` instead of ``<save_name>``.
    """
    formatter = logging.Formatter("%(message)s")

    console_handler = logging.StreamHandler()
    console_handler.setFormatter(formatter)
    console_handler.setLevel(level=logging.WARNING)

    if is_visual:
        logs_path = os.path.join(RESULT_PATH+'{}/{}/{}'.format(save_path, dataset_name, save_name), 'visual')
    else:
        logs_path = os.path.join(RESULT_PATH+'{}/{}'.format(save_path, dataset_name), save_name)

    os.makedirs(os.path.dirname(logs_path), exist_ok=True)
    file_handler = logging.FileHandler(logs_path, mode='a', encoding='utf-8')
    file_handler.setFormatter(formatter)
    file_handler.setLevel(level=logging.INFO)

    # basicConfig() does nothing if the root logger already has handlers.
    # Without this, a second call would keep logging to the first file and leak
    # the file handle opened above. Removing and closing the old handlers
    # mirrors basicConfig(force=True), which only exists on Python 3.8+.
    root_logger = logging.getLogger()
    for old_handler in root_logger.handlers[:]:
        root_logger.removeHandler(old_handler)
        old_handler.close()

    logging.basicConfig(level=logging.INFO, handlers=[console_handler, file_handler])






"""
Load Optimizer
"""
import opt_sw
import torch.optim as optim
import torch_optimizer

def load_optimizer(net, settings=None, device='cpu'):
    """Build the optimizer selected by ``settings['optimizer_type']``.

    Args:
        net: model whose ``parameters()`` are passed to the single and hybrid
            optimizers.
        settings: dict with at least ``'optimizer_type'`` and
            ``'optimizer_params'``. Single and hybrid optimizers are built with
            ``**settings['optimizer_params']``. The fine-grained switchers
            (RandomSW, CyclicalSW, SMACSW) receive ``**settings`` itself.
            ``None`` (the default) selects a fresh
            ``{'optimizer_space': ['SGD', 'SGDM'], 'optimizer_type': 'OptSW',
            'optimizer_params': {'lr': 0.01}}`` on every call.
        device: forwarded only to ``opt_sw.OptSwitcher``.

    Returns:
        The optimizer instance, or ``None`` when ``optimizer_type`` is not
        recognized. In that case a warning is also logged.
    """
    if settings is None:
        # Built per call: a dict literal used as the default argument would be
        # one object shared by every call, and by anything that mutates it.
        settings = {'optimizer_space': ['SGD', 'SGDM'],
                    'optimizer_type': 'OptSW',
                    'optimizer_params': {'lr': 0.01}}
    optimizer_type = settings['optimizer_type']
    optimizer_params = settings['optimizer_params']

    # Single optimizers from torch.optim.
    # NOTE(review): 'SGDM' builds the same optim.SGD as 'SGD'. Presumably the
    # caller supplies momentum through optimizer_params; confirm.
    torch_optimizers = {
        'SGD': optim.SGD,
        'SGDM': optim.SGD,
        'Adagrad': optim.Adagrad,
        'RMSprop': optim.RMSprop,
        'Adam': optim.Adam,
    }
    if optimizer_type in torch_optimizers:
        return torch_optimizers[optimizer_type](net.parameters(), **optimizer_params)

    # Hybrid optimizers.
    if optimizer_type == 'SWATS':
        return torch_optimizer.SWATS(net.parameters(), **optimizer_params)
    if optimizer_type == 'Padam':
        return opt_sw.Padam(net.parameters(), **optimizer_params)
    if optimizer_type == 'AdaBound':
        return torch_optimizer.AdaBound(net.parameters(), **optimizer_params)

    # Fine-grained switchers take the whole settings dict.
    if optimizer_type == 'RandomSW':
        return opt_sw.RandomSwitcher(**settings)
    if optimizer_type == 'CyclicalSW':
        return opt_sw.CyclicalSwitcher(**settings)
    if optimizer_type == 'SMACSW':
        return opt_sw.SMACSwitcher(**settings)

    # Our method.
    if optimizer_type == 'OptSW':
        return opt_sw.OptSwitcher(**optimizer_params, device=device)

    # Keep the historical None return, but make the misconfiguration visible.
    # A named logger is used so this does not implicitly configure the root logger.
    logging.getLogger(__name__).warning(
        'load_optimizer: unknown optimizer_type %r; returning None', optimizer_type)
    return None

