import torch
import sys
sys.path.append('..')
sys.path.append('./')
from model.vgg import VGG11, BiF_VGG11
from model.resnet import resnet34, resnet18

def get_device(gpu):
    """Pick the torch device to run on and report the choice.

    Args:
        gpu: CUDA device index. Any negative value forces the CPU.

    Returns:
        ``torch.device('cuda:<gpu>')`` when ``gpu >= 0`` and CUDA is available,
        otherwise ``torch.device('cpu')``.
    """
    wants_cuda = gpu >= 0 and torch.cuda.is_available()
    if wants_cuda:
        chosen = torch.device("cuda:" + str(gpu))
    else:
        chosen = torch.device("cpu")
    print("training on", chosen)
    return chosen

def get_model(model, train_mask=None):
    """Construct a network from its architecture name.

    Args:
        model: Architecture name. Exact names handled are ``"vgg11"``,
            ``"bif-vgg11"``, ``"bif-resnet18"``, ``"resnet34"`` and
            ``"bif-resnet34"``. Any other name that contains the substring
            ``"resnet18"`` builds a plain ResNet-18.
        train_mask: Forwarded only to the ``bif-*`` constructors; ignored for
            the plain architectures.

    Returns:
        The constructed network instance.

    Raises:
        ValueError: If ``model`` matches none of the names above.
    """
    if model == "vgg11":
        net = VGG11()
    elif model == "bif-vgg11":
        net = BiF_VGG11(train_mask=train_mask)

    # The exact "bif-resnet18" check must come before the substring match
    # below. "resnet18" in "bif-resnet18" is True, so testing the substring
    # first would make this branch unreachable and silently drop train_mask.
    elif model == "bif-resnet18":
        net = resnet18(train_mask=train_mask)
    elif "resnet18" in model:
        net = resnet18()

    elif model == "resnet34":
        net = resnet34()
    elif model == "bif-resnet34":
        net = resnet34(train_mask=train_mask)

    else:
        raise ValueError("unavaliable model.")
    return net

def get_round_mask(tune_mask=0, round=0):
    """Return the binary mask string used for a given tuning round.

    Each ``tune_mask`` scheme is a cyclic schedule of masks, and round ``r``
    (1-based) selects entry ``(r - 1) % len(schedule)``:

    * 1: one-hot over 4 positions   ("1000", "0100", "0010", "0001")
    * 2: one-hot over 8 positions   ("10000000", ..., "00000001")
    * 3: pairs over 4 positions     ("1100", "0011")
    * 4: pairs over 8 positions     ("11000000", "00110000", "00001100",
      "00000011")

    Args:
        tune_mask: Schedule selector. 0 disables masking.
        round: Round number. 0 disables masking.

    Returns:
        The mask string for this round, or ``None`` when ``tune_mask`` or
        ``round`` is 0.

    Raises:
        ValueError: If ``tune_mask`` is not one of 0-4 and ``round`` is nonzero.
            Previously this case failed with ``UnboundLocalError``.
    """
    if tune_mask == 0 or round == 0:
        return None

    if tune_mask in (1, 2):
        width = 4 if tune_mask == 1 else 8
        # One-hot masks: position i is "1", everything else "0".
        schedule = ["0" * i + "1" + "0" * (width - i - 1) for i in range(width)]
    elif tune_mask == 3:
        schedule = ["1100", "0011"]
    elif tune_mask == 4:
        schedule = ["11000000", "00110000", "00001100", "00000011"]
    else:
        raise ValueError(f"unsupported tune_mask: {tune_mask!r}")

    # Python's % keeps the index non-negative, so the schedule simply cycles.
    return schedule[(round - 1) % len(schedule)]
   
if __name__ == "__main__":
    # Demo: print the tune_mask=1 schedule for rounds 1 through 9.
    for round_no in range(1, 10):
        demo_mask = get_round_mask(1, round_no)
        print(demo_mask)