import os

import gym
import numpy as np
import torch

## torch settings

# Device and floating-point type used whenever this script creates tensors
# explicitly (weights, observations).
device = torch.device('cpu')
dtype = torch.float32

## env

env_name = 'CartPole-v0'
env = gym.make(env_name)

# Optional override of the per-episode step cap. None keeps whatever limit
# gym.make configured; set an int (e.g. 5000) to change it.
# NOTE(review): this writes the private ``_max_episode_steps`` attribute,
# presumably of gym's TimeLimit wrapper -- verify for the installed gym version.
max_episode_steps = None
if max_episode_steps is not None:
  env._max_episode_steps = max_episode_steps

## feature/action dims
n_φ = env.observation_space.shape[0]  # length of the observation vector (policy input width)
n_a = env.action_space.n  # number of discrete actions (policy output width)

## render settings

render = False
render_after = 100000  # rendering starts once frame + 1 reaches this value

## agent params

γ = 0.99  # decay factor of the eligibility traces
# step sizes 10^-3, 10^-3.5, ..., 10^-7.5: ten values half a decade apart
αs = [10 ** -(3 + k / 2) for k in range(10)]
ωs = [0.0]  # online-update weights (0.0 = apply updates only at episode end)
nn_widths = [256]  # hidden-layer widths to sweep

## nn stuff

def πn_init(layers):
  """Build the weight matrices of a bias-free MLP policy.

  ``layers`` lists the layer widths, input first and output last. One
  ``(layers[k], layers[k + 1])`` matrix is returned per consecutive pair. Each
  is allocated on the module-level ``device``/``dtype`` and marked as requiring
  gradients. Hidden matrices are Gaussian with standard deviation
  sqrt(1 / fan_in). The output matrix is all zeros, so the policy's final
  softmax starts out uniform. Its random draw is still taken, so the torch RNG
  advances the same way for every layer.
  """
  n_mats = len(layers) - 1
  weights = []
  for idx, (n_in, n_out) in enumerate(zip(layers[:-1], layers[1:])):
    is_output = idx == n_mats - 1
    scale = 0 if is_output else np.sqrt(1 / n_in)
    # draw first, then scale: the output layer still consumes its random draw
    w = scale * torch.randn(n_in, n_out, device=device, dtype=dtype)
    w.requires_grad_(True)
    weights.append(w)
  return weights

def πn_f(θs, x, activation=torch.tanh):
  """Forward pass of the bias-free MLP policy.

  Args:
    θs: weight matrices as returned by ``πn_init``; ``θs[i]`` is applied as
      ``x @ θs[i]``.
    x: input features, either one feature vector of shape (n_in,) or a batch
      of them with shape (..., n_in).
    activation: nonlinearity applied after every layer except the last.

  Returns:
    Action probabilities from a softmax over the last axis. The shape is
    (n_out,) for a single state or (..., n_out) for a batch. With an empty
    ``θs`` the input is returned unchanged.
  """
  n_layers = len(θs)
  for i, θ in enumerate(θs):
    x = torch.matmul(x, θ)
    # Hidden layers get the nonlinearity; the output layer becomes a
    # distribution. dim=-1 normalises over actions both for a single state
    # (identical to dim=0 there) and for each row of a batch of states.
    x = activation(x) if i < n_layers - 1 else torch.softmax(x, dim=-1)
  return x

## experiment

# run budget
n_runs = 30
n_frames = 100_000
n_eval_eps = 10  # number of most recent episodes averaged into each logged score

# output
saving = True
results_dir = 'DeepResults'
# per-frame score buffer; allocated once and overwritten frame by frame in every run
results = [0.0 for _ in range(n_frames)]

## run exp

def _new_episode(θs):
  """Reset ``env`` and return fresh per-episode state ``(s, e, Δθ, G)``.

  - ``s`` is the initial observation as a tensor on the configured
    device/dtype.
  - ``e`` (eligibility traces) and ``Δθ`` (pending updates) are zero tensors
    shaped like each weight matrix in ``θs``.
  - ``G`` is the running episode return, reset to 0.0.
  """
  s = torch.as_tensor(env.reset(), device=device, dtype=dtype)
  # zeros_like keeps the buffers on the weights' device/dtype instead of
  # silently falling back to torch's global defaults.
  e = [torch.zeros_like(θ) for θ in θs]
  Δθ = [torch.zeros_like(θ) for θ in θs]
  return s, e, Δθ, 0.0


# Create the output directory before any training, so a missing folder does
# not make np.save fail only after a full run has already been computed.
if saving:
  os.makedirs(results_dir, exist_ok=True)

for α in αs:
  for ω in ωs:
    for w in nn_widths:
      # nn architecture: two hidden layers of width w
      nn = [n_φ, w, w, n_a]

      # set exp name
      exp_name = f"{env_name}_ipg_{ω:.2f}_{α:.8f}_{w}"

      for run in range(n_runs):
        # seed rng (seeded by run index so each run is reproducible)
        torch.manual_seed(run)
        env.seed(run)

        # init policy
        θs = πn_init(nn)

        # episodic returns buffer
        Gs = []

        # episode start
        s, e, Δθ, G = _new_episode(θs)

        # do run
        for frame in range(n_frames):
          # render
          if render and frame + 1 >= render_after: env.render()

          # choose action by sampling from the policy's action probabilities
          π = πn_f(θs, s)
          a = torch.distributions.Categorical(π).sample().item()

          # env step (4-tuple step API; T is the done flag)
          # NOTE(review): T may also signal time-limit truncation, depending
          # on how gym.make wrapped the env -- it is treated as terminal here.
          sp, r, T, _ = env.step(a)
          sp = torch.as_tensor(sp, device=device, dtype=dtype)
          G += r

          # update: backward() leaves ∇ log π(a|s) in θs[i].grad
          J = torch.log(π[a])
          J.backward()
          with torch.no_grad():
            for i in range(len(θs)):
              # γ-decayed eligibility trace of ∇ log π
              e[i] = γ * e[i] + θs[i].grad
              # accumulate α·r·e; with ω == 0 this sums over the episode to
              # α Σ_k ∇log π(a_k|s_k) Σ_{t≥k} γ^(t-k) r_t
              Δθ[i] = (1 - ω) * Δθ[i] + α * r * e[i]
              # ω > 0: apply a share of the pending update online, every step
              if ω != 0.0: θs[i] += ω * Δθ[i]
              # at episode end apply a further (1 - ω)·Δθ (all of Δθ when ω == 0)
              if T: θs[i] += (1 - ω) * Δθ[i]
              θs[i].grad.zero_()

          # store results: mean return of the last (up to) n_eval_eps episodes
          if len(Gs) > 0:
            avg_len = min(len(Gs), n_eval_eps)
            results[frame] = sum(Gs[-avg_len:]) / avg_len
          else:
            results[frame] = 0.0
          # (frame + 1) % n_frames == 0 only holds on the final frame, so each
          # run is written once, after its last frame.
          if saving and (frame + 1) % n_frames == 0:
            np.save(f"{results_dir}/{exp_name}_{frame + 1}_{run}.npy", results[:frame + 1])

          # next time step
          if T:
            # output stuff: H is the entropy of the policy at the last state
            # visited; G_avg (results[frame]) does not yet include this episode
            H = -(π * torch.log(π + 1e-6)).sum().item()
            print(f"{exp_name} - run: {run + 1:>2}, G_avg: {results[frame]:>9.3f}, G: {G:>9.3f}, steps: {frame + 1:>7}, H: {H:>5.3f}")
            Gs.append(G)
            # restart episode
            s, e, Δθ, G = _new_episode(θs)
          else:
            # update state
            s = sp