import torch
import os
import numpy as np
import gym

## torch settings

# all tensors are created on this device with this dtype
device, dtype = torch.device('cpu'), torch.float

## env

env_name = 'CartPole-v0'
env = gym.make(env_name)
# Optional: override the environment's episode step limit.
# env._max_episode_steps = 5000

## feature/action dims

# length of the flat observation vector and number of discrete actions
n_φ, n_a = env.observation_space.shape[0], env.action_space.n

## render settings

# when render is True, frames are drawn once frame + 1 >= render_after
render = False
render_after = 100000

## agent params

γ = 0.99  # discount factor
# Step sizes to sweep for the policy net (α_πs) and the value net (α_vs).
# Earlier sweeps used α_πs in {10**-3, 10**-3.5, 10**-4, 10**-4.5, 10**-5, 10**-5.5}
# and α_vs in {10**-3, 10**-4, 10**-5}.
α_πs = [10**-4]
α_vs = [10**-4]
nn_widths = [128]  # hidden-layer width(s) to sweep

## nn stuff

def πn_init(layers):
  """Create the policy network's weight matrices.

  Args:
    layers: layer widths [n_in, h_1, ..., n_out]; one (fan_in, fan_out)
      matrix is created for each consecutive pair of widths.

  Returns:
    List of leaf tensors with requires_grad=True. Every matrix except the
    last is drawn from a normal distribution with std sqrt(1 / fan_in); the
    last (output) matrix is all zeros.
  """
  n_mats = len(layers) - 1
  params = []
  for k, (fan_in, fan_out) in enumerate(zip(layers[:-1], layers[1:]), start=1):
    # only the output layer is zero-initialised
    scale = 0 if k == n_mats else np.sqrt(1 / fan_in)
    weight = scale * torch.randn(fan_in, fan_out, device=device, dtype=dtype)
    weight.requires_grad = True
    params.append(weight)
  return params

def πn_f(θs, x, activation=torch.tanh):
  """Policy network forward pass: action probabilities for state(s) x.

  Each weight matrix is applied with a matmul. Hidden layers are followed by
  `activation`. The last layer is followed by a softmax over the final
  (action) axis, so x may be a single feature vector (n_features,) or a
  batch (..., n_features).

  Args:
    θs: weight matrices chained by matmul (as built by πn_init, each of
      shape (fan_in, fan_out)).
    x: input features.
    activation: nonlinearity for the hidden layers (default tanh).

  Returns:
    Probabilities with the same leading dims as x and a last dim equal to
    the output width of θs[-1]. Returns x unchanged when θs is empty.
  """
  last = len(θs) - 1
  for i, θ in enumerate(θs):
    x = torch.matmul(x, θ)
    # Normalise over the last (action) axis. dim=0 is only right for an
    # unbatched 1-D input; for a batch it would normalise across samples.
    x = activation(x) if i < last else torch.softmax(x, dim=-1)
  return x

def vn_init(layers):
  """Create the value network's weight matrices.

  Args:
    layers: layer widths [n_in, h_1, ..., n_out]; matrix idx has shape
      (layers[idx], layers[idx + 1]).

  Returns:
    List of leaf tensors with requires_grad=True. Hidden matrices use std
    sqrt(1 / fan_in); the final matrix is all zeros.
  """
  depth = len(layers) - 1

  def _layer(idx):
    # build matrix idx; the last one (idx == depth - 1) is zeroed
    fan_in, fan_out = layers[idx], layers[idx + 1]
    scale = np.sqrt(1 / fan_in) if idx < depth - 1 else 0
    return (scale * torch.randn(fan_in, fan_out, device=device, dtype=dtype)).requires_grad_()

  return [_layer(idx) for idx in range(depth)]

def vn_f(ws, x, activation=torch.tanh):
  """Value network forward pass.

  Every matrix in ws except the last is followed by `activation`; the last
  layer is linear, so the output is an unsquashed value estimate.

  Args:
    ws: weight matrices chained by matmul (as built by vn_init).
    x: input features.
    activation: hidden-layer nonlinearity (default tanh).

  Returns:
    activation(...activation(x @ ws[0])...) @ ws[-1], or x itself when ws
    is empty.
  """
  layers = list(ws)
  if not layers:
    return x
  h = x
  for w in layers[:-1]:
    h = activation(torch.matmul(h, w))
  return torch.matmul(h, layers[-1])

## experiment

n_runs = 30
n_frames = 100000
n_eval_eps = 10
results = [0.0] * n_frames
results_dir = 'DeepResults'
saving = False
skip_existing = False  # skip runs whose final results file already exists
debug = False  # print initial state values and per-step δ - r (very verbose)

## run exp

def _zero_traces(params):
  """Return fresh all-zero eligibility traces matching each tensor in params."""
  return [torch.zeros_like(p) for p in params]

# np.save does not create missing directories
if saving:
  os.makedirs(results_dir, exist_ok=True)

for α_v in α_vs:
  for α_π in α_πs:
    for nn_width in nn_widths:
      # nn architecture
      πn = [n_φ, nn_width, nn_width, n_a]
      vn = [n_φ, nn_width, nn_width, 1]

      # set exp name
      exp_name = f"{env_name}_ac_{α_π:.8f}_{α_v:.8f}_{nn_width}"

      for run in range(n_runs):
        result_file = f"{results_dir}/{exp_name}_{n_frames}_{run}.npy"
        if skip_existing and os.path.isfile(result_file):
          print(f"{result_file} exists!")
          continue

        # seed rng
        torch.manual_seed(run)
        env.seed(run)

        # init policy (θs) and value (ws) weights
        θs = πn_init(πn)
        ws = vn_init(vn)

        # episodic returns buffer
        Gs = []

        # episode start
        s = torch.as_tensor(env.reset(), device=device, dtype=dtype)
        eπ, ev = _zero_traces(θs), _zero_traces(ws)
        G = 0.0
        if debug: print(vn_f(ws, s).item())

        # divergence check: v holds the most recent value estimate v(s)
        v = torch.tensor(0.0)
        diverged = False

        # do run
        for frame in range(n_frames):
          # render
          if render and frame + 1 >= render_after: env.render()

          # choose action; once diverged, learning stops and action 0 is always taken
          π = πn_f(θs, s)
          if not diverged and (torch.any(torch.isnan(π)) or torch.abs(v) > 1e10):
            diverged = True
          a = torch.distributions.Categorical(π).sample().item() if not diverged else 0

          # env step
          sp, r, T, info = env.step(a)
          sp = torch.as_tensor(sp, device=device, dtype=dtype)
          G += r

          if not diverged:
            # A true terminal state has value 0, so do not bootstrap from it.
            # An episode cut off by gym's TimeLimit wrapper is flagged with
            # info['TimeLimit.truncated'] and is not terminal, so keep
            # bootstrapping there.
            terminal = T and not info.get('TimeLimit.truncated', False)
            v = vn_f(ws, s)
            with torch.no_grad():
              vp = torch.zeros_like(v) if terminal else vn_f(ws, sp)
            δ = r + γ * vp - v.detach()
            if debug: print(δ - r)
            # update: accumulate ∇log π(a|s) and ∇v(s) into .grad, then apply
            # accumulating eligibility traces (decay γ, i.e. λ = 1)
            torch.log(π[a]).backward()
            v.backward()
            with torch.no_grad():
              for θ, e in zip(θs, eπ):
                e.mul_(γ).add_(θ.grad)
                θ.add_(α_π * δ * e)
                θ.grad.zero_()
              for w, e in zip(ws, ev):
                e.mul_(γ).add_(w.grad)
                w.add_(α_v * δ * e)
                w.grad.zero_()

          # store results: mean return over the last n_eval_eps finished episodes
          recent = Gs[-n_eval_eps:]
          results[frame] = sum(recent) / len(recent) if recent else 0.0
          if saving and frame + 1 == n_frames:
            np.save(result_file, results[:frame + 1])

          # next time step
          if T:
            # output stuff (H: policy entropy at the episode's last decision state)
            if diverged:
              H = 0.0
            else:
              p = π.detach()
              H = -(p * torch.log(p)).sum().item()
            print(f"{exp_name} - run: {run + 1:>2}, G_avg: {results[frame]:>9.3f}, G: {G:>9.3f}, steps: {frame + 1:>7}, H: {H:>5.3f}")
            Gs.append(G)
            # restart episode
            s = torch.as_tensor(env.reset(), device=device, dtype=dtype)
            eπ, ev = _zero_traces(θs), _zero_traces(ws)
            G = 0.0
            if debug: print(vn_f(ws, s).item())
          else:
            # update state
            s = sp
