import models
import sys
import numpy as np
import models
import torch
import math
import gpytorch
import matplotlib.pyplot as plt
from pyro.infer.mcmc import NUTS, MCMC, HMC
from gpytorch.likelihoods import GaussianLikelihood
import pyro
import pickle
from gpytorch.priors import LogNormalPrior, NormalPrior, UniformPrior, MultivariateNormalPrior
from gpytorch.likelihoods import GaussianLikelihood, FixedNoiseGaussianLikelihood
import pyro.distributions as dist
from torch.nn import Module
import torch.nn as nn
import torch.optim as optim
import time
import scipy
from scipy.optimize import minimize
from funcs import sin_target, band_gap_target
from surrogates import NPNVSurrogate, SepSurrogate
from experiment import experiment_num, experiment_num_two_step
from acquisition_function import UCB_by_num, UCB_nfc_by_num

# Z: extracted features
# Y: experimental values
# Y0: lowest energy per atom
# Y1: highest energy per atom

def initial_fidelity_indices(sample_num):
    """Build the fidelity index for each of ``sample_num`` initial samples.

    The first ``sample_num // 2`` entries are 0 and the remaining entries are
    1, so with an odd count the extra sample gets index 1.

    Args:
        sample_num: number of initial samples (non-negative int).

    Returns:
        A ``torch.long`` tensor of shape ``(sample_num,)``.
    """
    index_x = torch.ones(sample_num, dtype=torch.long)
    index_x[:sample_num // 2] = 0
    return index_x


def make_npnv_surrogate():
    """Create a fresh ``NPNVSurrogate`` with the benchmark's fixed hyper-parameters.

    New list arguments are built on every call, so separate surrogates never
    share mutable argument objects.
    """
    return NPNVSurrogate(2, 2, n_mean=[1, 1], n_length_scale=[0.5, 0.5],
                         n_scale=[1, 1], m_mean=2, m_l=0.5, m_scale=4)


def run_and_save(banner, experiment_cls, target, surrogate, acqui,
                 indices, index_x, iteration_num, save_prefix):
    """Run one experiment from a given initial design and save its results.

    Args:
        banner: progress message printed before the experiment is built.
        experiment_cls: experiment factory, called as
            ``experiment_cls(target, surrogate, acqui)``.
        target: objective passed to ``experiment_cls``.
        surrogate: surrogate model passed to ``experiment_cls``.
        acqui: acquisition function passed to ``experiment_cls``.
        indices: initial sample indices, passed to ``initialize_given``.
        index_x: fidelity index for each initial sample, passed to
            ``initialize_given``.
        iteration_num: number of iterations passed to ``run_iterations``.
        save_prefix: path prefix passed to ``experiment.save``.

    Returns:
        The finished experiment object.
    """
    print(banner)
    experiment = experiment_cls(target, surrogate, acqui)
    experiment.initialize_given(indices, index_x)
    experiment.run_iterations(iteration_num)
    experiment.save(save_prefix)
    return experiment


if __name__ == "__main__":
    import os

    # True: draw a fresh random initial design per experiment and save it.
    # False: reload the designs saved by a previous run from save_dir.
    GENERATE_NEW = True
    N_EXPERIMENTS = 10
    save_dir = './saves/'
    # torch.save (and presumably experiment.save) fail if the directory is
    # missing, so create it up front.
    os.makedirs(save_dir, exist_ok=True)

    for exp in range(N_EXPERIMENTS):
        print(f"EXP: {exp}")
        target = band_gap_target('./', '_test', cost=[5, 1])
        # One fresh surrogate/acquisition pair per compared method. They are
        # built before the initial design is drawn, as in the original
        # ordering, in case construction consumes the torch RNG.
        surrogate = make_npnv_surrogate()
        acqui = UCB_by_num()
        surrogate_nfc = make_npnv_surrogate()
        acqui_nfc = UCB_nfc_by_num()
        surrogate_ts = make_npnv_surrogate()
        acqui_ts = UCB_by_num()
        surrogate_sep = SepSurrogate(2, 2, mean=torch.tensor([2]),
                                     lengthscale=torch.tensor([[0.5]]),
                                     outputscale=torch.tensor([4]))
        acqui_sep = UCB_by_num()

        init_sample_num = 2
        iteration_num = 10

        init_num_path = os.path.join(save_dir, f'init_num_x_{exp}.ts')
        init_ind_path = os.path.join(save_dir, f'init_ind_x_{exp}.ts')
        if GENERATE_NEW:
            indices = torch.randperm(target.size)[:init_sample_num]
            index_x = initial_fidelity_indices(indices.shape[0])
            torch.save(indices, init_num_path)
            torch.save(index_x, init_ind_path)
        else:
            indices = torch.load(init_num_path)
            index_x = torch.load(init_ind_path)

        # All four methods start from the same initial design.
        run_and_save("Begin_running_model_1", experiment_num, target,
                     surrogate, acqui, indices, index_x, iteration_num,
                     os.path.join(save_dir, f'res_our_{exp}_'))
        run_and_save("Begin_running_model_2", experiment_num, target,
                     surrogate_nfc, acqui_nfc, indices, index_x, iteration_num,
                     os.path.join(save_dir, f'res_nfc_{exp}_'))
        run_and_save("Begin_running_model_3", experiment_num, target,
                     surrogate_sep, acqui_sep, indices, index_x, iteration_num,
                     os.path.join(save_dir, f'res_sep_{exp}_'))
        run_and_save("Begin_running_model_4", experiment_num_two_step, target,
                     surrogate_ts, acqui_ts, indices, index_x, iteration_num,
                     os.path.join(save_dir, f'res_ts_{exp}_'))
