"""
set random seed
"""
import torch 
import random
import numpy as np 

def set_seed(seed=0):
    """Seed the Python, NumPy and PyTorch (CPU and CUDA) random generators.

    Also configures cuDNN for reproducibility: autotuning is switched off and
    deterministic algorithms are requested.

    Args:
        seed: integer seed applied to every generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)

    # cuDNN autotuning (benchmark=True) may pick different kernels between runs;
    # disable it and ask for deterministic algorithms instead.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


"""
load arguments
"""
import argparse

def parse_arg():
    parser = argparse.ArgumentParser()

    parser.add_argument(
        '--dataset_name', default='hymenoptera', help='pmf, 110-classifiers, openml')
    parser.add_argument(
        '--model_name', default='resnet18', help='pmf, 110-classifiers, openml')
    parser.add_argument(
        '--model_weights', default='IMAGENET1K_V1', help='model_weights')
    parser.add_argument(
        '--is_fulltrain', type=bool, default=False, help='model_weights')
    parser.add_argument(
        '--is_save', type=bool, default=False, help='model_weights')

    parser.add_argument(
        '--save_path', default='default', help='the path to save result')
    parser.add_argument(
        '--save_name', default='default', help='save the reuslts')
    parser.add_argument(
        '--set_path', default='default', help='the path to save result')
    parser.add_argument(
        '--set_name', default='default', help='the path to save result')

    parser.add_argument(
        '--random_seed', type=int, default=0, help='for random seed')
    parser.add_argument(
        '--n_epoch', type=int, default=20, help='for random seed')
    parser.add_argument(
        '--batch_size', type=int, default=16, help='the batch size')
    parser.add_argument(
        '--device', default='cpu', help='the device, cuda')
    
    args, unparsed = parser.parse_known_args()
    return args



"""
Set Logging
"""
import os
import logging

RESULT_PATH = '../result/'


def setup_logging(dataset_name, save_path, save_name, is_visual=False, *, force=False):
    """Configure the root logger to log to the console and to a per-run file.

    The console shows WARNING and above; the file receives INFO and above. Both
    use the bare message text as the format. The file is
    ``RESULT_PATH + '<save_path>/<dataset_name>/<save_name>'``. When
    ``is_visual`` is True, it is instead a file named ``visual`` inside that
    path. Missing parent directories are created.

    Args:
        dataset_name: path component below ``save_path``.
        save_path: path component directly below ``RESULT_PATH``.
        save_name: log file name (or directory name when ``is_visual``).
        is_visual: write to ``.../<save_name>/visual`` instead of ``.../<save_name>``.
        force: remove and close any handlers already attached to the root
            logger before configuring. Without it, ``logging.basicConfig`` does
            nothing when the root logger already has handlers (e.g. on a second
            call), so the new log file is never written.
    """
    if is_visual:
        logs_path = os.path.join(
            RESULT_PATH + '{}/{}/{}'.format(save_path, dataset_name, save_name), 'visual')
    else:
        logs_path = os.path.join(
            RESULT_PATH + '{}/{}'.format(save_path, dataset_name), save_name)

    formatter = logging.Formatter("%(message)s")

    console_handler = logging.StreamHandler()
    console_handler.setFormatter(formatter)
    console_handler.setLevel(level=logging.WARNING)

    os.makedirs(os.path.dirname(logs_path), exist_ok=True)
    # delay=True opens the file only on the first record. If basicConfig below
    # ignores this handler, no file descriptor is left open.
    file_handler = logging.FileHandler(logs_path, mode='a', encoding='utf-8', delay=True)
    file_handler.setFormatter(formatter)
    file_handler.setLevel(level=logging.INFO)

    if force:
        # Same effect as basicConfig(force=True), which requires Python 3.8+.
        root = logging.getLogger()
        for handler in root.handlers[:]:
            root.removeHandler(handler)
            handler.close()

    logging.basicConfig(level=logging.INFO, handlers=[console_handler, file_handler])



"""
load Dataset
"""
from torchvision import datasets
import torchvision.transforms as transforms

def load_dataset(dataset_name='hymenoptera', model_type='resnet18', batch_size=16, random_seed=0):

    # Data Path
    data_path = 'anonymized'
    if dataset_name in ['hymenoptera', 'mnist', 'aircraft', 'usps', 'stl10', 'cifar10']:
        data_path = f'{data_path}/torchvision/{dataset_name}'
    
    # Load Dataset
    if dataset_name in ['hymenoptera']:
        data_transforms = {
            'train': transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
            ]),
            'test': transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
            ]),
        }

        # Load Dataset
        train_set = datasets.ImageFolder(os.path.join(data_path, 'train'), data_transforms['train'])
        test_set = datasets.ImageFolder(os.path.join(data_path, 'test'), data_transforms['test'])
    elif dataset_name in ['mnist']:
        if 'resnet' in model_type:
            data_transforms = transforms.Compose([
                transforms.ToTensor(),
                transforms.Lambda(lambda x: x.repeat(3,1,1)),
                transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))
            ])
        else:
            data_transforms = transforms.Compose([
                transforms.Resize((224,224)),
                transforms.ToTensor(),
                transforms.Lambda(lambda x: x.repeat(3,1,1)),
                transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))
            ])

        train_set = datasets.MNIST(data_path, train=True, transform=data_transforms, download=True)
        test_set  = datasets.MNIST(data_path, train=False, transform=data_transforms, download=True)
    elif dataset_name in ['aircraft']:
        if 'resnet' in model_type:
            data_transforms = transforms.Compose([
                transforms.Resize((224,224)),
                transforms.ToTensor(),
                #transforms.Lambda(lambda x: x.repeat(3,1,1)),
                transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))
            ])
        else:
            data_transforms = transforms.Compose([
                transforms.Resize((224,224)),
                transforms.ToTensor(),
                transforms.Lambda(lambda x: x.repeat(3,1,1)),
                transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))
            ])

        train_set = datasets.FGVCAircraft(data_path, split="trainval", transform=data_transforms, download=True)
        test_set  = datasets.FGVCAircraft(data_path, split="test", transform=data_transforms, download=True)

    elif dataset_name in ['usps']:
        if 'resnet' in model_type:
            data_transforms = transforms.Compose([
                transforms.ToTensor(),
                transforms.Lambda(lambda x: x.repeat(3,1,1)),
                transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))
            ])
        else:
            data_transforms = transforms.Compose([
                transforms.Resize((224,224)),
                transforms.ToTensor(),
                transforms.Lambda(lambda x: x.repeat(3,1,1)),
                transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))
            ])

        train_set = datasets.USPS(data_path, train=True, transform=data_transforms, download=True)
        test_set  = datasets.USPS(data_path, train=False, transform=data_transforms, download=True)
        n_class = 10

    elif dataset_name=='stl10':
        if 'resnet' in model_type:
            data_transforms = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))
            ])
        else:
            data_transforms = transforms.Compose([
                transforms.Resize((224,224)),
                transforms.ToTensor(),
                transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))
            ])
        
        train_set = datasets.STL10(data_path, split='train', transform=data_transforms, download=True)
        test_set  = datasets.STL10(data_path, split='test', transform=data_transforms, download=True)
        n_class = 10
            
    elif dataset_name=='cifar10':
        
        if 'resnet' in model_type:
            data_transforms = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))
            ])
        else:
            data_transforms = transforms.Compose([
                transforms.Resize((224,224)),
                transforms.ToTensor(),
                transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))
            ])
        
        train_set = datasets.CIFAR10(data_path, train=True , transform=data_transforms, download=True)
        test_set  = datasets.CIFAR10(data_path, train=False, transform=data_transforms, download=True)
        n_class = 10

    n_train = int(len(train_set) * 0.8)
    n_val   = len(train_set) - n_train
    n_test  = len(test_set)
    if dataset_name in ['usps']:
        n_class = 10
    else:
        n_class = len(train_set.classes)

    train_set, val_set = torch.utils.data.random_split(train_set, [n_train, n_val])

    train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=0)
    val_loader   = torch.utils.data.DataLoader(val_set, batch_size=batch_size, shuffle=True, num_workers=0)
    test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=True, num_workers=0)

    #logging.warning(f'- n_train: {n_train}')
    #logging.warning(f'- n_val: {n_val}')
    #logging.warning(f'- n_test: {n_test}')

    return train_loader, val_loader, test_loader, n_class



"""
Load Model
"""
import torch.nn as nn
from torchvision import models

def load_model(model_type='resnet18', n_output=2, weights='IMAGENET1K_V1', is_full=False):
    model = None
    if model_type=='resnet18':
        model = models.resnet18(weights=weights)
    elif model_type=='resnet152':
        model = models.resnet152(weights=weights)
    elif model_type=='densenet121':
        model = models.densenet121(weights=weights)
    elif model_type=='mobilenet2':
        model = models.mobilenet_v2(weights=weights)
    elif model_type=='vit':
        model = models.vit_b_16(weights=weights)

    for param in model.parameters():
        param.requires_grad = is_full
    
    if 'resnet'  in model_type:
        n_features = model.fc.in_features
        model.fc = nn.Linear(n_features, n_output)
    elif 'densenet' in model_type:
        n_features = model.classifier.in_features
        model.classifier = nn.Linear(n_features, n_output)
    elif 'mobilenet' in model_type:
        n_features = model.classifier[1].in_features
        model.classifier[1] = nn.Linear(n_features, n_output)
    elif 'vit' in model_type:
        n_features = model.heads.head.in_features
        model.heads.head = nn.Linear(n_features, n_output)

    return model





"""
Load Optimizer
"""
import opt_sw
import torch.optim as optim
import torch_optimizer

def load_optimizer(net, settings=None, device='cpu'):
    """Instantiate the optimizer selected by ``settings['optimizer_type']``.

    Args:
        net: model whose ``parameters()`` are passed to the single and hybrid
            optimizers. It is not used by the switcher types.
        settings: dict with at least ``'optimizer_type'`` and
            ``'optimizer_params'``. The RandomSW/CyclicalSW/SMACSW switchers
            receive the whole dict as keyword arguments. Defaults to
            ``{'optimizer_space': ['SGD', 'SGDM'], 'optimizer_type': 'OptSW',
            'optimizer_params': {'lr': 0.01}}`` (a fresh copy per call).
        device: forwarded only to ``opt_sw.OptSwitcher``.

    Returns:
        The constructed optimizer.

    Raises:
        KeyError: if ``'optimizer_type'`` or ``'optimizer_params'`` is missing.
        ValueError: if ``optimizer_type`` is not recognised. Previously the
            function silently returned None.
    """
    if settings is None:
        # Built per call: a dict default argument would be one shared object
        # that any callee could mutate.
        settings = {'optimizer_space': ['SGD', 'SGDM'], 'optimizer_type': 'OptSW',
                    'optimizer_params': {'lr': 0.01}}
    optimizer_type = settings['optimizer_type']
    optimizer_params = settings['optimizer_params']

    # Single optimizers.
    # NOTE(review): 'SGDM' builds the same optim.SGD as 'SGD'; momentum presumably
    # comes from optimizer_params - confirm against the callers.
    if optimizer_type in ('SGD', 'SGDM'):
        return optim.SGD(net.parameters(), **optimizer_params)
    if optimizer_type == 'Adagrad':
        return optim.Adagrad(net.parameters(), **optimizer_params)
    if optimizer_type == 'RMSprop':
        return optim.RMSprop(net.parameters(), **optimizer_params)
    if optimizer_type == 'Adam':
        return optim.Adam(net.parameters(), **optimizer_params)

    # Hybrid optimizers.
    if optimizer_type == 'SWATS':
        return torch_optimizer.SWATS(net.parameters(), **optimizer_params)
    if optimizer_type == 'Padam':
        return opt_sw.Padam(net.parameters(), **optimizer_params)
    if optimizer_type == 'AdaBound':
        return torch_optimizer.AdaBound(net.parameters(), **optimizer_params)

    # Fine-grained switchers take the whole settings dict.
    if optimizer_type == 'RandomSW':
        return opt_sw.RandomSwitcher(**settings)
    if optimizer_type == 'CyclicalSW':
        return opt_sw.CyclicalSwitcher(**settings)
    if optimizer_type == 'SMACSW':
        return opt_sw.SMACSwitcher(**settings)

    # Our method.
    if optimizer_type == 'OptSW':
        return opt_sw.OptSwitcher(**optimizer_params, device=device)

    raise ValueError(f'unknown optimizer_type: {optimizer_type!r}')


