import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import numpy as np
import os

import torch
import torch.nn.functional as F

from pytorch_rl.utils import ImgToTensor
from pytorch_rl.policies import MultinomialPolicy
import pytorch_rl.networks as networks
from phys_env.phys_env import PhysEnv

from src.dynamicmap import DynamicMap
from src.rl import GlimpseAgent, AttentionConstrainedEnvironment
from src.networks import *
from src.goalsearch import GoalSearchSimple

# Directory the actor-critic checkpoints are read from; the evaluation
# results below are written into the same directory.
d = 'GoalSearch-v2/train_rl_full_1/'

# Run on the first CUDA device when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# NOTE(review): the argument 10 presumably is the grid size, matching the
# 10x10 spatial dims given to the network below -- confirm in GoalSearchSimple.
env = GoalSearchSimple(10)
policy = MultinomialPolicy()

# NOTE(review): (4, 10, 10) looks like the observation shape and the trailing
# 4 the number of actions -- confirm against networks.ActorCritic.
trunk = FlattenTrunk
ac = networks.ActorCritic(trunk, (4, 10, 10), 4)
ac.to(device)

def preprocess(x):
    """Convert an environment observation and move it to ``device``.

    The observation is passed through a freshly constructed ``ImgToTensor``
    and the result is transferred to the module-level ``device``. No value
    rescaling is applied; an earlier variant that rescaled the result with
    ``/127.5 - 1`` is not used.

    NOTE(review): ``ImgToTensor`` presumably returns a torch tensor -- confirm
    in pytorch_rl.utils.
    """
    to_tensor = ImgToTensor()
    converted = to_tensor(x)
    return converted.to(device)

# Checkpoint sweep: for every checkpoint index `step`, load the saved
# actor-critic weights, roll out 5 evaluation episodes in `env`, and save the
# list of per-episode total rewards as 'rltest{step}.npy' inside `d`.
for step in range(100, 157000, 100):
    path = os.path.join(d, 'actor_critic_{}.pth'.format(step))
    print("loading rl agent {}".format(path))
    # map_location remaps the stored tensors onto the device selected above.
    # A missing checkpoint file makes torch.load raise, which ends the sweep.
    ac.load_state_dict(torch.load(path, map_location=device))
    data = []  # total reward of each evaluation episode for this checkpoint
    # Evaluation is pure inference: disable autograd so the forward passes do
    # not record graph state that is never used for a backward pass.
    with torch.no_grad():
        for ep in range(5):
            obs = preprocess(env.reset())
            done = False
            reward = 0  # running sum of step rewards for this episode
            while not done:
                # Add a leading batch dimension of size 1 for the network.
                logits = ac.pi(obs.unsqueeze(dim=0))
                # NOTE(review): test=True presumably selects the action
                # deterministically -- confirm in MultinomialPolicy.
                action = policy(logits, test=True)
                next_obs, r, done, _ = env.step(action.cpu().numpy())
                obs = preprocess(next_obs)
                reward += r
            data.append(reward)
            print("episode {} reward {}".format(ep, reward))
    np.save(os.path.join(d, 'rltest{}.npy'.format(step)), data)
