import numpy as np
from scipy.stats import norm
from util import feasibility_check, optimal_solution

class Experiment:
    """Fixed problem instances (tasks x arms x constraints) used by the experiments.

    Each instance has ``s`` tasks with ``k`` arms each. Sampling an arm yields
    an objective observation X and ``m`` Gaussian constraint observations Y.

    Attributes:
        experiment_id: The configuration selected in ``__init__``.
        dist: Distribution of the objective X ("Gaussian" or "Bernoulli").
        s, k, m: Number of tasks, arms per task, and constraints.
        phi: Per-constraint threshold passed to ``feasibility_check``, shape ``(m,)``.
        b: Per-constraint bound, shape ``(m,)``.
        Y_mean, Y_variance, Y_std: Gaussian parameters of the constraint
            observations, shape ``(s, k, m)``.
        Y_quantile: 0.9-quantile of each constraint distribution, shape ``(s, k, m)``.
        F: ``P(Y <= b)`` for every (task, arm, constraint), shape ``(s, k, m)``.
        X_mean, X_variance: Objective mean and variance, shape ``(s, k)``.
        feasibility: Result of ``feasibility_check`` for this instance.
        opt_solution: Result of ``optimal_solution`` for this instance.
    """

    def __init__(self, experiment_id):
        """Build the problem instance identified by ``experiment_id``.

        Args:
            experiment_id: One of ``1``, ``2`` or ``3``.

        Raises:
            ValueError: If ``experiment_id`` is not a known configuration.
        """
        self.experiment_id = experiment_id
        if self.experiment_id == 1:
            self.dist = "Gaussian"
            self.s = 2
            self.k = 4
            self.m = 1
            self.phi = np.full((self.m,), 0.9)
            self.Y_mean = np.array([[[0.4], [0.4], [1.203232393378558], [2.0448536269514723]],
                                    [[0.4], [0.4], [1.203232393378558], [2.0448536269514723]]])
            self._init_constraint_noise()
            self.b = np.full((self.m,), 2.0448536269514723)
            # Hard-coded, but equal to norm.cdf(b, loc=Y_mean, scale=Y_std):
            # b - Y_mean are the standard-normal 0.95 / 0.80 / 0.50 quantiles.
            self.F = np.array([[[0.95], [0.95], [0.80], [0.50]],
                               [[0.95], [0.95], [0.80], [0.50]]])
            self.X_mean = np.array([[1, 0.8, 1.2, 0.9], [1, 0.8, 1.2, 0.9]])
            self.X_variance = np.ones_like(self.X_mean)

        elif self.experiment_id == 2:
            self.dist = "Bernoulli"
            self.s = 2
            self.k = 4
            self.m = 1
            self.phi = np.full((self.m,), 0.9)
            self.Y_mean = np.array([[[0.35], [0.45], [1.5], [1.8]],
                                    [[0.35], [0.45], [1.5], [1.8]]])
            self._init_constraint_noise()
            self.b = np.full((self.m,), 2.0)
            # b has shape (m,) and broadcasts over the last (constraint) axis.
            self.F = norm.cdf(self.b, loc=self.Y_mean, scale=self.Y_std)
            self.X_mean = np.array([[0.8, 0.6, 0.9, 0.4], [0.8, 0.6, 0.9, 0.4]])
            # Bernoulli variance p * (1 - p).
            self.X_variance = self.X_mean * (1 - self.X_mean)

        elif self.experiment_id == 3:
            self.dist = "Gaussian"
            self.s = 2
            self.k = 7
            self.m = 1
            self.phi = np.full((self.m,), 0.9)
            self.Y_mean = np.array([[[0.35], [0.45], [0.45], [1.65], [1.65], [1.75], [1.75]],
                                    [[0.35], [0.45], [0.45], [1.65], [1.65], [1.75], [1.75]]])
            self._init_constraint_noise()
            self.b = np.full((self.m,), 2.0)
            self.F = norm.cdf(self.b, loc=self.Y_mean, scale=self.Y_std)
            self.X_mean = np.array([[1, 0.7, 0.6, 1.3, 1.4, 0.7, 0.8],
                                    [1, 0.7, 0.6, 1.3, 1.4, 0.7, 0.8]])
            self.X_variance = np.ones_like(self.X_mean)

        else:
            # Previously fell through silently, leaving an object with no
            # attributes other than experiment_id.
            raise ValueError(f"Unknown experiment_id: {experiment_id!r}")

        self.feasibility = feasibility_check(self.s, self.k, self.F, self.phi)
        self.opt_solution = optimal_solution(self.s, self.X_mean, self.feasibility)

    def _init_constraint_noise(self):
        """Set unit-variance constraint noise and its 0.9-quantiles from ``Y_mean``."""
        self.Y_variance = np.ones_like(self.Y_mean)
        self.Y_std = np.sqrt(self.Y_variance)
        self.Y_quantile = norm.ppf(0.9, loc=self.Y_mean, scale=self.Y_std)

    def generate_samples(self, dist, task, arm, cons):
        """Draw a single noisy observation for ``(task, arm)``.

        Args:
            dist: Objective distribution, "Gaussian" or "Bernoulli".
            task: Task index (first axis of ``X_mean`` / ``Y_mean``).
            arm: Arm index (second axis).
            cons: ``self.m`` selects the objective X; ``0 <= cons < self.m``
                selects constraint ``cons`` (always Gaussian).

        Returns:
            A scalar sample.
        """
        # Reseeds from OS entropy on every call (kept from the original).
        # Presumably this gives forked worker processes independent streams;
        # note it makes runs non-reproducible.
        np.random.seed()
        if cons == self.m:
            if dist == "Gaussian":
                # Objective noise uses the objective's own variance (was Y_std,
                # which returned a shape-(m,) array with the constraint std).
                return np.random.normal(self.X_mean[task, arm],
                                        np.sqrt(self.X_variance[task, arm]))
            if dist == "Bernoulli":
                return np.random.binomial(n=1, p=self.X_mean[task, arm])
        return np.random.normal(self.Y_mean[task, arm, cons], self.Y_std[task, arm, cons])
