import torch
import numpy as np
import random
from alg.utils import SArunexp

#########################################################
# This code runs the meta-algorithm and SGD+test-set    #
# on *one* dataset and *one* pre-train fraction.        #
#########################################################

# Which dataset to run on, and what fraction of it is used for pre-training.
DATASET_ID = 0       # Dataset number, in {0,...,59}
PRETRAIN_FRAC = 0.5  # Data fraction to train initial hypothesis (or SGD), in [0,1)

# Number of datapoints in each dataset.
N_DATAPOINTS = 1000

# Data and model.
# NOTE(review): MODEL_TYPE is not passed to SArunexp in this script;
# confirm whether it is still needed.
MODEL_TYPE = 'fcn'
NAME_DATA = 'binarymnist'

# Training hyper-parameters.
TRAIN_EPOCHS = 200
LEARNING_RATE = 0.01
MOMENTUM = 0.95
DROPOUT_PROB = 0.2
# Chosen larger than any dataset used here, so every epoch is one full
# batch (no minibatches). Validated against N_DATAPOINTS below.
BATCH_SIZE = 60000

# Confidence and threshold.
DELTA = 0.035
C = 0.69314718  # -ln(0.5) = ln(2), truncated to 8 decimals

# Fail fast on an out-of-range configuration rather than deep inside the run.
if DATASET_ID not in range(60):
    raise ValueError(f"DATASET_ID must be in {{0,...,59}}, got {DATASET_ID!r}")
if not 0 <= PRETRAIN_FRAC < 1:
    raise ValueError(f"PRETRAIN_FRAC must be in [0, 1), got {PRETRAIN_FRAC!r}")
if BATCH_SIZE < N_DATAPOINTS:
    # Otherwise the "no minibatches" assumption above would silently break.
    raise ValueError(
        f"BATCH_SIZE ({BATCH_SIZE}) must be >= N_DATAPOINTS ({N_DATAPOINTS}) "
        "so that training uses no minibatches"
    )

# Pin the compute device, then seed every random number generator the run
# may touch (Python, NumPy, PyTorch CPU and CUDA) for reproducibility.
DEVICE = torch.device("cpu")
random_seed = 10

random.seed(random_seed)                   # Python stdlib RNG
np.random.seed(random_seed)                # NumPy global RNG
torch.manual_seed(random_seed)             # PyTorch default generator
torch.cuda.manual_seed_all(random_seed)    # all CUDA devices; ignored without CUDA
torch.backends.cudnn.deterministic = True  # request deterministic cuDNN kernels

# Dedicated, explicitly seeded CPU generator that is passed to SArunexp.
G = torch.Generator(device='cpu')
G.manual_seed(random_seed)

# Jointly run the meta-algorithm (called SA) and SGD+test-set. The arguments
# are positional, so their order must match SArunexp's signature exactly.
# NOTE(review): MODEL_TYPE is not forwarded here; confirm SArunexp does not need it.
(
    SGD_ub,
    SA_ub,
    # NOTE(review): the misclassification results come back SA-first, unlike
    # the bounds (SGD-first); verify against SArunexp's return order.
    SA_p_misclass,
    SGD_p_misclass,
) = SArunexp(
    DATASET_ID,
    G,
    C,
    NAME_DATA,
    DELTA,
    LEARNING_RATE,
    MOMENTUM,
    BATCH_SIZE,
    TRAIN_EPOCHS,
    DROPOUT_PROB,
    N_DATAPOINTS,
    PRETRAIN_FRAC,
    DEVICE,
)

# Report the upper bounds followed by the test-set misclassification rates,
# SGD before SA in each pair.
print(SGD_ub, SA_ub, SGD_p_misclass, SA_p_misclass)
   
