import json
from itertools import product
from pathlib import Path

import numpy as np


class StochasticSampler:
    """Draw Bernoulli-type random values whose means are given per key.

    Keys are ``(x, a)`` tuples, or ``(x, a, i)`` tuples when an index ``i`` is
    supplied. Two modes are supported:

    * ``"bernoulli"``: values in {0, 1}, drawn with success probability
      ``mean``.
    * ``"bernoulli_sign"``: values in {-1, +1}, with ``P(+1) = (mean + 1) / 2``
      so that the expected value equals ``mean``.

    When ``T`` is given, ``T`` samples per key of ``mean_map`` are drawn up
    front and can be read back by time index ``t``. Otherwise (or when the
    call passes no ``t``) a fresh value is drawn on every call.

    Args:
        mean_map: Mapping from key tuple to mean.
        clip: Stored on the instance as-is; nothing in this class reads it.
        mode: ``"bernoulli"`` or ``"bernoulli_sign"``. Any other value raises
            ``ValueError`` the first time a sample is drawn.
        T: Number of samples to pre-generate per key, or ``None`` to skip
            pre-generation.
        seed: Seed for reproducibility. It governs both the pre-generated
            samples and the on-the-fly draws. With ``seed=None`` the
            on-the-fly draws use NumPy's global random state.
    """

    def __init__(self, mean_map, clip=None, mode="bernoulli", T=None, seed=None):
        self.mean_map = mean_map
        self.clip = clip
        self.mode = mode
        self.T = T
        self.seed = seed
        self.samples = {}
        # A dedicated generator makes `seed` apply to on-the-fly draws too.
        # Without a seed we deliberately keep using NumPy's global state, so
        # callers that seed it through np.random.seed() see no change.
        self._rng = None if seed is None else np.random.default_rng(seed)

        if T is not None:
            self._generate_samples()

    def _generate_samples(self):
        """Fill ``self.samples`` with ``T`` draws for every key of ``mean_map``.

        The generator is always rebuilt from ``self.seed``, so calling this
        again with the same seed reproduces the same samples.

        Raises:
            ValueError: If ``self.mode`` is unsupported (only detected when
                ``mean_map`` is non-empty).
        """
        rng = np.random.default_rng(self.seed)
        if self.seed is not None:
            # Share the generator with on-the-fly draws so they continue this
            # stream instead of replaying the pre-generated values.
            self._rng = rng

        for key, mean in self.mean_map.items():
            if self.mode == "bernoulli":
                self.samples[key] = rng.binomial(1, mean, size=self.T)
            elif self.mode == "bernoulli_sign":
                # Map a mean in [-1, 1] to the probability of drawing +1.
                p = (mean + 1) / 2
                self.samples[key] = 2 * rng.binomial(1, p, size=self.T) - 1
            else:
                raise ValueError(f"Unsupported mode: {self.mode}")

    def __call__(self, x, a, i=None, t=None):
        """Return one sample for ``(x, a)`` or ``(x, a, i)``.

        If samples were pre-generated (``T`` is set) and ``t`` is given, the
        stored sample at index ``t`` is returned; a key absent from
        ``mean_map`` or an out-of-range ``t`` raises ``KeyError`` /
        ``IndexError``. In every other case a fresh value is drawn.
        """
        key = (x, a) if i is None else (x, a, i)
        if self.T is not None and t is not None:
            return self.samples[key][t]
        return self.sample_from_key(key)

    def get_mean(self, x, a, i=None):
        """Return the configured mean for the key, or ``0.0`` if it is absent."""
        key = (x, a) if i is None else (x, a, i)
        return self.mean_map.get(key, 0.0)

    def sample_from_key(self, key):
        """Draw a single fresh sample for ``key``.

        A key missing from ``mean_map`` is treated as having mean ``0.0``.

        Raises:
            ValueError: If ``self.mode`` is unsupported.
        """
        mean = self.mean_map.get(key, 0.0)
        # Seeded instances draw from their own generator; unseeded ones fall
        # back to the global NumPy state (both expose the same binomial API).
        rng = np.random if self._rng is None else self._rng
        if self.mode == "bernoulli":
            return rng.binomial(n=1, p=mean)
        elif self.mode == "bernoulli_sign":
            p = (mean + 1) / 2
            return 1 if rng.binomial(n=1, p=p) else -1
        else:
            raise ValueError(f"Unsupported mode: {self.mode}")


def generate_mean_maps(X, A, m, reward_range=(0.2, 0.8), constraint_range=(-0.5, 0.5),
                       output_dir="data", seed=None):
    """Draw random reward/constraint means and write them to JSON files.

    For every pair ``(x, a)`` with ``x in range(X)`` and ``a in range(A)`` one
    reward mean is drawn uniformly from ``reward_range``, followed by ``m``
    constraint means drawn uniformly from ``constraint_range``. All values are
    rounded to three decimals. The maps are written to
    ``<output_dir>/reward_means.json`` and
    ``<output_dir>/constraint_means.json``.

    Keys are the ``str()`` of the index tuple (e.g. ``"(0, 1)"`` or
    ``"(0, 1, 2)"``), since JSON object keys must be strings.

    Args:
        X: Number of values of the first index.
        A: Number of values of the second index.
        m: Number of constraint means per ``(x, a)`` pair.
        reward_range: ``(low, high)`` bounds for reward means.
        constraint_range: ``(low, high)`` bounds for constraint means.
        output_dir: Directory for the JSON files; created if missing.
        seed: Optional seed for a reproducible draw. With ``None`` the values
            come from NumPy's global random state.

    Returns:
        Tuple ``(reward_means, constraint_means)`` of the generated dicts.
    """
    # The global np.random module and a Generator both provide uniform(), so
    # the unseeded path keeps consuming the global stream in the same order.
    rng = np.random if seed is None else np.random.default_rng(seed)

    reward_means = {}
    constraint_means = {}

    for x, a in product(range(X), range(A)):
        mean_reward = np.round(rng.uniform(*reward_range), 3)
        reward_means[str((x, a))] = float(mean_reward)

        for i in range(m):
            mean_constr = np.round(rng.uniform(*constraint_range), 3)
            constraint_means[str((x, a, i))] = float(mean_constr)

    # Create the target directory up front so the writes below cannot fail
    # merely because it does not exist yet.
    out_dir = Path(output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    outputs = (
        ("reward_means.json", reward_means),
        ("constraint_means.json", constraint_means),
    )
    for filename, means in outputs:
        with open(out_dir / filename, "w") as f:
            json.dump(means, f, indent=2)

    return reward_means, constraint_means
