import torch
from deepKT import DKT,net as backbone
import numpy as np
import datetime
import os
import json
import argparse
# Command-line interface. The values that used to be hard-coded after
# parse_args() (which silently discarded whatever the user passed) are now
# the argparse defaults, so running without flags behaves exactly as before
# while --space / --scenario given on the command line are actually honoured.
parser = argparse.ArgumentParser()
parser.add_argument('--space', help='Search Space Id', type=str, default="6766")
# NOTE(review): --scenario is parsed but not read anywhere in this script.
parser.add_argument('--scenario', help='Which meta-training scenario', type=str, default="2")
args = parser.parse_args()
# Directory containing this script; the dataset, config and checkpoint
# paths below are all resolved relative to it.
rootdir = os.path.split(os.path.realpath(__file__))[0]

# Seed NumPy and PyTorch and disable cuDNN autotuning / non-deterministic
# kernels so that repeated runs are as reproducible as possible.
SEED = 123
torch.manual_seed(SEED)
np.random.seed(SEED)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

def _load_space(filename):
    """Load a dataset JSON file and keep only the selected search space.

    Reads ``<rootdir>/preprocessing/datasets/<filename>`` and returns a
    one-entry dict ``{args.space: <that space's data>}``. A missing space id
    raises ``KeyError``.
    """
    path = os.path.join(rootdir, "preprocessing", "datasets", filename)
    with open(path, "rb") as handle:
        everything = json.load(handle)
    return {args.space: everything[args.space]}


# Validation data first, then meta-training data (same order as before).
valid_data = _load_space("meta_validation_dataset_open_ml.json")
hpo_data = _load_space("meta_train_dataset_open_ml.json")

# Use the first task of the training space to read off the input width D.
# Assumes every task is a dict with an "X" entry that converts to a 2-D array
# (presumably rows = evaluated configurations) -- TODO confirm schema.
# `c` is not used further in this script.
task_ids = list(hpo_data[args.space].keys())
first_task = task_ids[0]
c, D = np.array(hpo_data[args.space][first_task]["X"]).shape

# Base backbone configuration. A `with` block closes the file deterministically;
# the previous json.load(open(...)) left the handle open until garbage collection.
with open(os.path.join(rootdir, "Setconfig90.json"), "rb") as f:
    backbone_params = json.load(f)

# Input dimensionality taken from the training data's "X" array.
backbone_params.update({"dim": D})

# The lambda looks up the global `backbone_params` only when it is called, so
# the backbone sees the dict with every key added below (the updates mutate
# the same object). NOTE(review): presumably DKT calls this after these updates
# -- confirm if the call timing ever changes.
backbone_fn = lambda: backbone(backbone_params)

# Run-specific training settings, inserted in the same order as before.
backbone_params.update({
    "fixed_context_size": 5,
    "minibatch_size": 64,
    "n_inner_steps": 1,
})

# Timestamped checkpoint directory per search space, so runs do not collide.
checkpoint_path = os.path.join(
    rootdir,
    "checkpoints",
    "DKLM",
    f"{args.space}",
    datetime.datetime.now().strftime('meta-%Y-%m-%d-%H-%M-%S-%f'),
)
backbone_params.update({"checkpoint_path": checkpoint_path})

# Build the DKT model from the filtered train/validation data; the kernel name
# is read from the loaded backbone configuration.
model = DKT(
    train_data=hpo_data,
    valid_data=valid_data,
    kernel=backbone_params["kernel"],
    backbone_fn=backbone_fn,
    config=backbone_params,
)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)


def scheduler_fn(x, y):
    """Return a cosine-annealing LR scheduler for optimizer ``x``.

    ``y`` is forwarded as ``T_max``; the learning rate anneals down to 1e-7.
    This factory is handed to ``model.train_loop``, which decides when to call it.
    """
    return torch.optim.lr_scheduler.CosineAnnealingLR(x, y, eta_min=1e-7)


# Very large epoch budget: train_loop is invoked once per epoch index.
MAX_EPOCHS = 100_000
for epoch in range(MAX_EPOCHS):
    model.train_loop(epoch, optimizer, scheduler_fn)
