import numpy as np
from utils.ogd import OGD


class AdversarialDataGenerator:
    """Produce per-state reward and constraint vectors, optionally adversarially.

    For each state index ``s`` in ``range(X)`` the generator holds one fixed
    "true" reward vector of length ``A`` (drawn from U[0, 1)) and ``m`` fixed
    "true" constraint vectors of length ``A`` (drawn from U[-1, 1)).  When the
    reward and/or constraints are adversarial, the vectors handed out on each
    call come from per-state ``OGD`` learners instead of the true vectors.

    Args:
        X: Number of states; states are indexed ``0 .. X - 1``.
        A: Number of actions; every vector has length ``A``.
        m: Number of constraint vectors per state.
        eta: Step size passed to every ``OGD`` learner.
        adv_reward: If True, reward vectors come from OGD learners.
        adv_constraints: If True, constraint vectors come from OGD learners.
        seed: Seed forwarded to ``np.random.default_rng`` for drawing the true
            vectors. ``None`` (the default) gives fresh, non-reproducible draws.
    """

    def __init__(self, X, A, m, eta=0.01, adv_reward=True, adv_constraints=True,
                 seed=None):
        self.X = X
        self.A = A
        self.m = m
        self.eta = eta
        self.adv_reward = adv_reward
        self.adv_constraints = adv_constraints
        self.rng = np.random.default_rng(seed)

        # Build learners only for the components that are adversarial. The
        # attributes are always present (None when unused) so callers can
        # inspect them without hasattr checks.
        self.constraint_learners = (
            [[OGD(A, eta) for _ in range(m)] for _ in range(X)]
            if adv_constraints
            else None
        )
        self.reward_learners = (
            [OGD(A, eta) for _ in range(X)] if adv_reward else None
        )

        # Fixed ground-truth vectors. Rewards are drawn before constraints so
        # that a given seed always yields the same vectors.
        self.true_reward_vectors = {
            s: self.rng.uniform(0, 1, size=A) for s in range(X)
        }
        self.true_constraint_vectors = {
            s: [self.rng.uniform(-1, 1, size=A) for _ in range(m)] for s in range(X)
        }

        # Most recent output of get_adversarial_data (None before first call).
        self.current_reward_vectors = None
        self.current_constraint_vectors = None

    def get_adversarial_data(self, policy, t=None):
        """Return the reward and constraint vectors for the current round.

        Args:
            policy: Mapping from ``(state, action)`` pairs to weights; pairs
                missing from the mapping are treated as 0.0.
            t: Round index. Currently unused; kept for interface compatibility.

        Returns:
            ``(reward_vectors, constraint_vectors)``. ``reward_vectors`` maps
            each state to a length-``A`` vector. ``constraint_vectors`` maps each
            state to a list of ``m`` such vectors. The same dict objects are also
            stored on ``current_reward_vectors`` / ``current_constraint_vectors``.
        """
        reward_vectors = {}
        constraint_vectors = {}

        for s in range(self.X):
            pi_s = np.array([policy.get((s, a), 0.0) for a in range(self.A)])

            if self.adv_reward:
                # Feed the learner the elementwise product of the negated true
                # reward and the policy row at s, then read its new prediction.
                # NOTE(review): sign/step semantics live in utils.ogd — confirm.
                learner = self.reward_learners[s]
                learner.update(-self.true_reward_vectors[s] * pi_s)
                reward_vectors[s] = learner.predict()
            else:
                # Copy so callers cannot mutate the generator's true vectors.
                reward_vectors[s] = self.true_reward_vectors[s].copy()

            if self.adv_constraints:
                row = []
                for learner, true_c in zip(
                    self.constraint_learners[s], self.true_constraint_vectors[s]
                ):
                    learner.update(-true_c * pi_s)
                    row.append(learner.predict())
                constraint_vectors[s] = row
            else:
                constraint_vectors[s] = [
                    c.copy() for c in self.true_constraint_vectors[s]
                ]

        self.current_reward_vectors = reward_vectors
        self.current_constraint_vectors = constraint_vectors

        return reward_vectors, constraint_vectors


