from collections import defaultdict
import torch as th
import torch.nn.functional as f
import numpy as np

th.autograd.set_detect_anomaly(True)

# Construct the environment.
## The environment supports three configurations (modes): "unattacked",
## "attacked_adversary" and "attacked_victim".

class IlluToyMDP():
    """One-step toy MDP used to illustrate observation-perturbation attacks.

    ``reset`` samples a state from ``prob_s0``; every ``step`` reports
    ``done=True``, so each episode is a single decision. Rewards are read from
    ``reward_matrix`` indexed as ``[state, action]``.

    Modes:
      * ``"unattacked"``: the agent observes the true state and receives
        ``reward_matrix[state, action]``.
      * ``"attacked_adversary"``: the agent is the adversary. Its action is
        passed as the observation to ``victim_policy``, the victim's action is
        sampled from the returned probabilities, and the adversary receives the
        negated victim reward.
      * ``"attacked_victim"``: the agent is the victim. Its observation is
        ``adversary_policy(state)`` and it is rewarded as in ``"unattacked"``.
    """

    MODES = ("unattacked", "attacked_adversary", "attacked_victim")
    KL_ORDERS = ("attacked|unattacked", "unattacked|attacked")

    def __init__(self, adversary_policy=None, victim_policy=None, mode="unattacked"):
        # Step counter; not incremented anywhere in this class.
        self.step_ctr = 0
        self.set_mode(mode, adversary_policy, victim_policy)
        # Initial-state distribution p(s) over the two states.
        self.prob_s0 = th.tensor([1/3., 2/3.], dtype=th.float)
        # Rewards indexed [state, action].
        self.reward_matrix = th.FloatTensor([[-1, 1], [1, 0.5]])

    def set_mode(self, mode, adversary_policy=None, victim_policy=None):
        """Switch the environment mode and store the policies it uses.

        Raises:
            ValueError: if ``mode`` is not one of ``MODES``, or if the policy the
                mode depends on is missing.
        """
        if mode not in self.MODES:
            raise ValueError(f"Unknown mode {mode!r}; expected one of {self.MODES}")
        if mode == "attacked_adversary" and victim_policy is None:
            raise ValueError("Need a valid victim policy!")
        if mode == "attacked_victim" and adversary_policy is None:
            raise ValueError("Need a valid adversary policy!")
        self.mode = mode
        self.adversary_policy = adversary_policy
        self.victim_policy = victim_policy

    def step(self, action):
        """Apply ``action`` in the current state.

        Returns ``(obs, reward, done, terminated, info)`` where ``done`` is always
        True and ``terminated`` always False. In ``"attacked_adversary"`` mode
        ``info["victim_action"]`` holds the sampled victim action.
        """
        info = {}
        if self.mode == "attacked_adversary":
            # The adversary's action is the victim's observation; the adversary
            # is rewarded with the negated victim reward.
            action_v = th.multinomial(self.victim_policy(action), 1)
            reward = -self.reward_matrix[self.state, action_v].item()
            info["victim_action"] = action_v
        elif self.mode in ("unattacked", "attacked_victim"):
            reward = self.reward_matrix[self.state, action].item()
        else:
            raise ValueError(f"Unknown mode {self.mode!r}")
        return self.obs, reward, True, False, info

    def reset(self):
        """Sample a fresh state from ``prob_s0`` and return the agent's observation."""
        self.state = th.multinomial(self.prob_s0, 1)
        if self.mode in ("unattacked", "attacked_adversary"):
            self.obs = self.state
        elif self.mode == "attacked_victim":
            # NOTE(review): the observation is whatever adversary_policy returns;
            # confirm it yields an observation here rather than probabilities.
            self.obs = self.adversary_policy(self.state)
        else:
            raise ValueError(f"Unknown mode {self.mode!r}")
        return self.obs

    def exact_kl(self, order, adversary_policy, victim_policy, atol=1e-6):
        """Closed-form KL between attacked and unattacked (obs, action) distributions.

        Attacked:   P_a(o, a) = sum_s p(s) * pi_adv(o | s) * pi_v(a | o)
        Unattacked: P(o, a)   = p(o) * pi_v(a | o)   (observation == state)

        Args:
            order: ``"attacked|unattacked"`` for KL(P_a || P) or
                ``"unattacked|attacked"`` for KL(P || P_a).
            adversary_policy: called with a 0-dim state tensor; its result is
                indexed by observation and used as pi_adv(o | s).
            victim_policy: called with a 0-dim observation tensor; its result is
                indexed by action and used as pi_v(a | o).
            atol: tolerance for slightly negative results caused by floating
                point rounding (the true value is never negative).

        Returns:
            0-dim tensor holding the MEAN of the per-(o, a) contributions, i.e.
            the KL divergence divided by the number of pairs where at least one
            probability is non-zero (not the summed KL). It is +inf when the
            first distribution puts mass where the second has none.

        Raises:
            ValueError: if ``order`` is not one of ``KL_ORDERS``.
            RuntimeError: if the result is NaN or below ``-atol``.
        """
        if order not in self.KL_ORDERS:
            raise ValueError(f"Unknown order {order!r}; expected one of {self.KL_ORDERS}")
        n_states, n_actions = self.reward_matrix.shape
        # Observation space == state space; policy outputs are loop-invariant.
        adversary_probs = [adversary_policy(th.tensor(s)) for s in range(n_states)]

        contributions = []
        for o in range(n_states):
            victim_probs = victim_policy(th.tensor(o))
            for a in range(n_actions):
                p_attacked = sum(self.prob_s0[s] * adversary_probs[s][o] * victim_probs[a]
                                 for s in range(n_states))
                p_clean = self.prob_s0[o] * victim_probs[a]
                if order == "attacked|unattacked":
                    lead, other = p_attacked, p_clean
                else:
                    lead, other = p_clean, p_attacked
                if lead.item() == 0.0:
                    if other.item() != 0.0:
                        # 0 * log(0 / q) == 0 by convention; evaluating it directly
                        # gives 0 * -inf == NaN. Still counted in the mean.
                        dtype = th.promote_types(lead.dtype, other.dtype)
                        contributions.append(th.zeros((), dtype=dtype))
                    continue  # both zero: pair is outside both supports
                contributions.append(lead * (th.log(lead) - th.log(other)))

        mean_kl = th.mean(th.cat([c.reshape(1) for c in contributions]))
        if th.isnan(mean_kl).item() or mean_kl.item() < -atol:
            raise RuntimeError("FATAL")
        return mean_kl


def _collect_trajectories(env, policy, obs, n_steps):
    """Roll out ``policy`` in ``env`` for ``n_steps`` environment steps.

    Returns ``(trajectories, obs)``: one ``{"obs", "action", "reward"}`` dict per
    episode that finished during the rollout, and the observation to continue
    from. An episode still running when the step budget runs out is dropped.
    """
    trajectories = []
    for _ in range(n_steps):
        action = th.multinomial(policy(obs), 1)
        _, reward, done, _, _ = env.step(action)
        if done:
            trajectories.append({"obs": obs.item(), "action": action.item(), "reward": reward})
            obs = env.reset()
    return trajectories, obs


def _evaluate(env, policy, obs, n_steps):
    """Run ``policy`` for ``n_steps`` steps without learning; return ``(mean reward, obs)``."""
    rewards = []
    for _ in range(n_steps):
        action = th.multinomial(policy(obs), 1)
        _, reward, done, _, _ = env.step(action)
        rewards.append(reward)
        if done:
            obs = env.reset()
    return np.mean(rewards), obs


def exact_dist(*, n_rollout_steps=100, n_rollout_steps_eval=1000, n_updates=10,
               n_episodes=1000, alpha=5*10E-3, alpha_lambda=10.0, lmbda=1.0):
    """Train an adversary policy with REINFORCE under an exact KL penalty.

    The adversary's softmax-logit table ``params`` is trained in
    ``"attacked_adversary"`` mode. In that mode its reward is the negated victim
    reward. The objective is the mean ``log pi(a|o) * reward`` over the
    completed rollout trajectories minus ``lmbda * KL``, where ``KL`` is
    ``env.exact_kl("attacked|unattacked", ...)``.

    After every gradient-ascent step, ``lmbda`` is updated as
    ``max(lmbda + alpha_lambda * KL, 0)``. The constraint has no slack term, so
    ``lmbda`` keeps growing while the KL is positive.

    Per episode, the function prints the last loss, the adversary policy, lambda,
    the mean KL and the evaluation reward.

    Args:
        n_rollout_steps: env steps collected per episode for training.
        n_rollout_steps_eval: env steps used for evaluation per episode.
        n_updates: gradient/dual steps per episode.
        n_episodes: number of collect/train/evaluate cycles.
        alpha: policy learning rate (the default ``5*10E-3`` equals 0.05).
        alpha_lambda: dual (lambda) learning rate.
        lmbda: initial penalty weight (arbitrary).
    """
    env = IlluToyMDP()
    n_states = env.reward_matrix.shape[0]
    # Adversary policy logits: row = true state, column = presented observation.
    params = th.nn.Parameter(th.zeros(n_states, n_states).uniform_().double(), requires_grad=True)

    def adversary_policy(obs):
        return f.softmax(params[obs], dim=-1)

    def victim_policy(obs):
        # Fixed deterministic victim: obs 0 -> action 1, anything else -> action 0.
        if obs.item() == 0:
            return th.tensor([0.0, 1.0], dtype=th.float)
        return th.tensor([1.0, 0.0], dtype=th.float)

    env.set_mode("attacked_adversary", victim_policy=victim_policy)
    obs = env.reset()

    for _ in range(n_episodes):
        # rollout phase
        trajectories, obs = _collect_trajectories(env, adversary_policy, obs, n_rollout_steps)

        # train phase
        loss_value = float("nan")
        kl_values = []
        for _ in range(n_updates):
            kl_exact = env.exact_kl("attacked|unattacked",
                                    adversary_policy=adversary_policy,
                                    victim_policy=victim_policy)
            if trajectories:
                log_probs = th.stack([th.log(adversary_policy(t["obs"])[t["action"]])
                                      for t in trajectories])
                rewards = th.tensor([t["reward"] for t in trajectories], dtype=log_probs.dtype)
                pg_objective = (log_probs * rewards).mean()
            else:
                pg_objective = 0.0
            loss = pg_objective - lmbda * kl_exact
            kl_values.append(kl_exact.item())

            loss.backward()
            with th.no_grad():
                params += alpha * params.grad  # gradient ascent on the objective
            params.grad.zero_()
            loss_value = loss.item()

            # Now we do the dual ascent step (KL re-evaluated with updated params).
            with th.no_grad():
                kl_exact = env.exact_kl("attacked|unattacked",
                                        adversary_policy=adversary_policy,
                                        victim_policy=victim_policy)
            lmbda = max(lmbda + alpha_lambda * kl_exact.item(), 0.0)

        print("LOSS:", loss_value)
        print(f.softmax(params, -1))
        print("lambda: ", lmbda)
        print("kl: ", np.mean(kl_values))

        # evaluate
        mean_reward, obs = _evaluate(env, adversary_policy, obs, n_rollout_steps_eval)
        print("TEST REWARD MEAN: ", mean_reward)

if __name__ == "__main__":
    # Script entry point: run the constrained-adversary training loop.
    exact_dist()

