import numpy as np
import argparse
import os
from helpers import regret,EI
import json
from sklearn.preprocessing import MinMaxScaler
from deepKT import DeepKernelGP,net as backbone
import torch

# Command-line interface. Every identifier that selects the experiment
# (search space, task, seed, fold) is mandatory: without it the script cannot
# locate its data, so argparse should reject the call with a usage message
# instead of failing later with an opaque TypeError/KeyError.
parser = argparse.ArgumentParser()
# ``const`` makes a bare ``--n_iters`` (no value) fall back to the default
# budget instead of yielding None, which would crash ``range(args.n_iters)``.
parser.add_argument('--n_iters', default=100, const=100, type=int, nargs='?', help='number of iterations for optimization method')
parser.add_argument('--space', required=True, help='Search Space Id', type=str)
parser.add_argument('--task', required=True, help='Selected Task ID', type=str)
parser.add_argument('--seed', required=True, help='Which seed', type=str, choices=["0", "1", "2", "3", "4"])
parser.add_argument('--fold', required=True, help='Select which fold to use', type=str, choices=["validation", "test"])
args = parser.parse_args()

# The seed arrives as a string (restricted by argparse choices); downstream
# RNGs and paths need it as an int.
args.seed = int(args.seed)
rootdir = os.path.dirname(os.path.realpath(__file__))
# Per-task regret CSVs are written under results/seed-<s>/DKLM/<fold>/<space>/.
savedir = os.path.join(rootdir, "results", f"seed-{args.seed}", "DKLM", args.fold, args.space)
os.makedirs(savedir, exist_ok=True)

# Meta-dataset for the requested fold, indexed as hpo_data[space][task]["X"/"y"].
data_file = os.path.join(rootdir, "preprocessing", "datasets", f"meta_{args.fold}_dataset_open_ml.json")
with open(data_file, "rb") as f:
    hpo_data = json.load(f)

# Pre-computed initial designs, used for the test fold.
initialization_ids = os.path.join(rootdir, "preprocessing", "bo_initialization_ids.json")
with open(initialization_ids, "rb") as f:
    init_ids = json.load(f)

task_data = hpo_data[args.space][args.task]
# Lambda: candidate configurations, one row per configuration (c rows, D columns).
Lambda = np.array(task_data["X"])
c, D = Lambda.shape
# MinMaxScaler only accepts 2-D input. Reshaping to (c, -1) accepts targets
# stored either flat (length c) or as a column (c, 1); both yield a (c, 1)
# response, which is what the rest of the script assumes (it flattens it with
# reshape(-1,) when building the GP).
response = MinMaxScaler().fit_transform(np.array(task_data["y"]).reshape(c, -1))

# NOTE(review): Z is marked as unused but is still passed to DeepKernelGP.
# It is drawn from the *unseeded* global NumPy RNG, so if the model ever
# starts using it, runs will no longer be reproducible.
Z = np.random.rand(32)

random = np.random.RandomState(301)
randomInitializer = np.random.RandomState(args.seed)  # used to draw the model seed ("random restarts")
log_dir = os.path.join(rootdir, "logs", f"seed-{args.seed}", "DKLM", args.fold, args.space)
os.makedirs(log_dir, exist_ok=True)
logger = os.path.join(log_dir, f"{args.task}.txt")
    

# Backbone/GP hyper-parameters. A context manager closes the file promptly
# instead of leaking the handle until garbage collection.
with open(os.path.join(rootdir, "Setconfig90.json"), "rb") as config_file:
    backbone_params = json.load(config_file)
# The feature extractor's input width must match the configuration dimensionality.
backbone_params.update({"dim": D})
backbone_fn = lambda: backbone(backbone_params)
load_model = False
checkpoint_path = os.path.join(rootdir, "checkpoints", "DKLM", args.space, "meta-v1")

# Initial design x: a list of row indices into Lambda/response.
if args.fold == "test":
    # Fixed initial design for the test fold. Copy it so that x.append(...)
    # in the optimisation loop does not mutate the loaded init_ids structure.
    x = list(init_ids[args.space][args.task][f"test{args.seed}"])
else:
    # Validation fold: 5 distinct random configurations, seeded by --seed.
    random = np.random.RandomState(seed=int(args.seed))
    x = random.choice(np.arange(c), size=5, replace=False).tolist()
y = response[x]
q = Lambda[x]
# Seed handed to DeepKernelGP; drawn once, so every BO step uses the same seed.
random_seed = randomInitializer.randint(0, 100000)
# Dedicated RNG for the jitter that is injected when a proposal step fails.
whiteNoise = np.random.RandomState(314)
# Bayesian-optimisation loop: refit the deep-kernel GP on the evaluated
# configurations x, pick the next configuration by Expected Improvement, and
# append it to x.
for _ in range(args.n_iters):
    # Stop early once the best configuration of this task has been evaluated.
    if max(response) in y:
        break
    # A fresh model is built and trained at every step on the current support set.
    model = DeepKernelGP(Lambda, response.reshape(-1,), Z, log_dir=logger, kernel=backbone_params["kernel"],
                         support=x, backbone_fn=backbone_fn,
                         config=backbone_params, seed=random_seed)

    optimizer = torch.optim.Adam([{'params': model.model.parameters(), 'lr': backbone_params["lr"]},
                                  {'params': model.feature_extractor.parameters(), 'lr': backbone_params["lr"]}])

    losses, weights, initial_weights = model.train(x, load_model=load_model,
                                                   checkpoint=checkpoint_path, epochs=backbone_params["epochs"],
                                                   optimizer=optimizer, verbose=False)
    noise_fn = lambda x: x
    done = False
    retries = 0
    while not done and retries < 5:
        try:
            if c > 100:
                # Score the candidate pool in chunks of 100 configurations.
                # predict_fn is consumed immediately by EI, so capturing the
                # loop variable i by closure is safe here.
                scores = []
                for i in range(100, c + 100, 100):
                    predict_fn = lambda queries: model.predict(x, range(i - 100, min(i, c)), noise_fn)
                    score = EI(max(y), predict_fn, support=np.where(np.array(x) < i)[0].tolist(), queries=Lambda,
                               return_variance=False, return_score=True)
                    scores += score.tolist()
                scores = np.array(scores, dtype=float)
                # Exclude already-evaluated configurations. -inf (not 0) is
                # required: EI scores are non-negative, so if every remaining
                # score is 0, a 0 here would let argmax re-select an observed
                # index.
                scores[x] = -np.inf
                candidate = np.argmax(scores)
            else:
                # NOTE(review): noise_fn is not used on this path, so a retry
                # repeats the same computation; confirm whether
                # model.predict should receive it here too.
                predict_fn = lambda queries: model.predict(x)
                candidate = EI(max(y), predict_fn, support=x, queries=Lambda, return_variance=False)
            done = True
        except Exception as err:
            # Best-effort recovery: report the failure and retry with a small
            # random jitter injected into the predictions.
            print(err)
            retries += 1
            noise_fn = lambda x: x + whiteNoise.randn() * 0.01
    if not done:
        # Without a fresh proposal, candidate is either undefined or left over
        # from the previous step (which would re-evaluate a known point).
        # Stop here so the results gathered so far are still saved.
        print(f"Candidate proposal failed after {retries} attempts; stopping optimization early.")
        break

    x.append(candidate)
    y = response[x]
    q = Lambda[x]
# Turn the trace of observed responses into regret values and store the
# evaluated configuration indices next to them, one CSV per task.
evaluated_indices = np.asarray(x).ravel()
results = regret(y, response)
results["indices"] = evaluated_indices
results.to_csv(os.path.join(savedir, f"{args.task}.csv"))