import os
import pickle

import gpytorch
import hydra
import torch
from omegaconf import DictConfig

from simulators.simulator import SIR

# Run on CUDA when it is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


@hydra.main(version_base="1.3", config_path="../../config", config_name="data_generation")
def sample_SIR(cfg: DictConfig) -> None:
    """Sample parameter trajectories from a GP prior, run the SIR simulator, and pickle the result.

    ``cfg.data_config.N`` draws are taken from a Gaussian-process prior
    evaluated on a uniform time grid. Each draw is squashed into ``(0, 0.35)``
    with a scaled sigmoid and passed to ``SIR``. The parameters (``"theta"``),
    simulator output (``"x"``) and time grid (``"simulation_grid"``) are
    pickled to ``os.path.join(cfg.data_config.path, cfg.data_config.data_file)``.

    Args:
        cfg: Hydra config. Reads ``data_config.seed``, ``data_config.N``,
            ``data_config.path`` and ``data_config.data_file``.
    """
    # prior hyperparameters
    prior_lengthscale = 7.0
    prior_scale = 2.5
    # simulator: forwarded to SIR as `likelihood_scale`
    simulator_likelihood_scale = 0.05
    # time setting
    seq_len = 500  # number of time steps
    T = 50  # time horizon

    seed = cfg.data_config.seed
    N = cfg.data_config.N

    # set seed (makes the prior draws below reproducible)
    torch.manual_seed(seed)

    # define prior over the time grid
    ts = torch.linspace(0, T, seq_len)
    likelihood = gpytorch.likelihoods.GaussianLikelihood()
    # NOTE(review): ExactGPModel is neither defined nor imported in this file;
    # confirm which module it should be imported from.
    prior = ExactGPModel(train_x=None, train_y=None, likelihood=likelihood)
    prior.covar_module.base_kernel.lengthscale = prior_lengthscale
    # The amplitude is the outer ScaleKernel's `outputscale` (assumes
    # covar_module is a gpytorch ScaleKernel, as its `base_kernel` attribute
    # suggests; confirm against ExactGPModel). Assigning an arbitrary attribute
    # such as `base_kernel.scale` is silently ignored, leaving the prior at the
    # default amplitude.
    prior.covar_module.outputscale = prior_scale
    # With no training data, evaluating the model presumably yields the GP prior.
    prior.eval()
    prior_dist = prior(ts)

    # sample from prior: N independent draws over `ts`
    theta = prior_dist.sample(torch.Size([N])).to(device)
    # squash the unbounded GP draws into (0, 0.35)
    theta = torch.sigmoid(theta) * 0.35

    # run simulator
    x = SIR(theta, ts, likelihood_scale=simulator_likelihood_scale, device=device)

    # save data
    dict_to_save = {
        "theta": theta,
        "x": x,
        "simulation_grid": ts,
    }
    print(cfg.data_config.path, cfg.data_config.data_file)
    # os.path.join works whether or not `path` ends with a separator; plain
    # string concatenation would glue the directory and file names together.
    out_file = os.path.join(cfg.data_config.path, cfg.data_config.data_file)
    out_dir = os.path.dirname(out_file)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    # NOTE(review): theta (and possibly x) live on `device`; pickling CUDA
    # tensors means unpickling needs a GPU. Consider `.cpu()` before saving.
    with open(out_file, "wb") as f:
        pickle.dump(dict_to_save, f)
    print("data saved.")


if __name__ == "__main__":
    print("started sampling....")
    # `@hydra.main` builds `cfg` from the config files and command-line
    # overrides, so the decorated function is called without arguments.
    sample_SIR()
