import os
import json
import pickle
import numpy as np
import torch

from gp import GP

from .basefunc import BaseFunc
from util import normalize_data


class LakeZurich(BaseFunc):
    """2-D objective derived from the Lake Zurich nitrogen (NO3) dataset.

    A gridded CSV of measurements is mapped onto the unit square [0, 1]^2.
    Its values are standardised and used to build a ``GP`` whose
    hyperparameters are initialised from a JSON file. The objective value
    at an input is the mean of ``GP.predict_f`` at that input.

    Both files are opened through paths relative to the current working
    directory. The process must therefore be started from the directory
    that contains ``func/`` and ``dataset/``.
    """

    # Paths are resolved relative to the current working directory.
    HYPERPARAMETER_FILENAME = (
        "func/gp_hyperparameters/gp_hyperparameters_lake_zurich.json"
    )
    DATA_FILENAME = "dataset/lake_zurich/eur-06-e7-nitrogen-no3.csv"

    def __init__(
        self,
        xsize=100,
        noise_std=0.01,
    ):
        """Build the discrete domain, load the data and construct the GP.

        Parameters
        ----------
        xsize : int
            the size of the discrete input domain
        noise_std : float
            the standard deviation of the noise of the GP model
            to generate the noisy observation

        """
        xsize = int(xsize)
        noise_std = float(noise_std)

        xdim = 2
        super().__init__(xdim, xsize, noise_std=noise_std)

        self.module_name = "lake_zurich"
        self.xsize = xsize
        # Candidate set indexed by get_noiseless_observation_from_input_idxs;
        # presumably shape (xsize, xdim) -- TODO confirm against BaseFunc.
        self.x_domain = BaseFunc.generate_discrete_points(xsize, xdim)

        self.train_x, self.ys = LakeZurich.preprocess_data()

        with open(self.HYPERPARAMETER_FILENAME, "r", encoding="utf-8") as f:
            self.hyperparameters = json.load(f)

        self.data_gp_model = GP(
            self.train_x,
            self.ys,
            initialization=self.hyperparameters,
            prior=None,
            ard=True,
        )

    @staticmethod
    def preprocess_data(data_filename=DATA_FILENAME):
        """Load the gridded CSV and return (inputs, standardised targets).

        Parameters
        ----------
        data_filename : str
            path to the comma-separated grid of measurements. It is
            expected to parse into a 2-D array with one value per grid
            cell. Defaults to ``LakeZurich.DATA_FILENAME``.

        Returns
        -------
        train_x : float32 tensor of shape (n_rows * n_cols, 2)
            grid coordinates in [0, 1]^2. Column 0 varies along CSV
            columns and column 1 varies along CSV rows.
        ys : float32 tensor of shape (n_rows * n_cols,)
            the CSV values flattened row-major and shifted to zero mean.
            They are also scaled to unit (population) standard deviation
            when that deviation is non-zero.

        """
        ys = np.genfromtxt(data_filename, delimiter=",")
        # x0 indexes CSV columns and x1 indexes CSV rows. Both flatten
        # row-major like ys, so coordinates and values stay aligned.
        x0s, x1s = np.meshgrid(
            np.linspace(0.0, 1.0, ys.shape[1]), np.linspace(0.0, 1.0, ys.shape[0])
        )
        xs = np.stack([x0s.flatten(), x1s.flatten()]).T

        ys = ys.flatten()
        std = np.std(ys)
        if std == 0:
            # Constant data: skip scaling rather than produce 0/0 = NaN targets.
            std = 1.0
        # NOTE(review): missing CSV cells become NaN under genfromtxt and
        # would propagate into mean/std -- confirm the file has none.
        ys = torch.from_numpy((ys - ys.mean()) / std)
        train_x = torch.from_numpy(xs)

        return train_x.float(), ys.float()

    def get_noiseless_observation_from_inputs(self, x):
        """Return the objective (GP predictive mean) at the given inputs.

        Parameters
        ----------
        x : tensor array reshapeable to (n, self.xdim)
            inputs to be evaluated

        Returns
        -------
        val : tensor array
            mean of ``GP.predict_f`` at inputs x. No gradient is tracked.

        """
        with torch.no_grad():
            x = x.reshape(-1, self.xdim)
            f_preds = GP.predict_f(self.data_gp_model, x)
            f_means = f_preds.mean

        return f_means

    def get_noiseless_observation_from_input_idxs(self, x_idxs):
        """Return the objective at points of ``self.x_domain`` chosen by index.

        Parameters
        ----------
        x_idxs : tensor array or list of int64 of shape (n,)
            row indices into self.x_domain of the inputs to be evaluated

        Returns
        -------
        val : tensor array
            mean of ``GP.predict_f`` at the selected domain points

        """
        x = self.x_domain[x_idxs, :].reshape(-1, self.xdim)
        return self.get_noiseless_observation_from_inputs(x)
