import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import numpy as np
import os

import torch
import torch.nn.functional as F

from pytorch_rl.utils import ImgToTensor
from pytorch_rl.policies import MultinomialPolicy
import pytorch_rl.networks as networks
from phys_env.phys_env import PhysEnv

from src.dynamicmap import DynamicMap
from src.rl import GlimpseAgent, AttentionConstrainedEnvironment
from src.networks import *
from src.goalsearch import GoalSearchSimple

# Glimpse-selection mode used during evaluation. With 'follow' the glimpse
# location is taken from the agent's position; any other value lets the
# glimpse agent's policy choose where to look.
mode = 'follow'

# Run on the first CUDA device when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Alternative configuration for the PhysEnv task (disabled). Swap it with the
# goalsearch configuration below to evaluate PhysEnv checkpoints instead.
# # physenv
# ATTN_SIZE = 21
# ENV_SIZE = 84
# CHANNELS = 3
# NB_ACTIONS = 4
# MAP_SIZE = 21
# MAP_CHANNELS = 48
# trunk = ConvTrunk21
# env_name = 'PhysEnv-v1'
# d = env_name + '/train_rl_{}_2/'.format(mode)
# from phys_env import phys_env
# env = phys_env.PhysEnv()

# goalsearch (active configuration)
ATTN_SIZE = 3        # passed to the attention-constrained env and glimpse agent
ENV_SIZE = 10        # environment size handed to the map / env wrapper
CHANNELS = 4         # environment channels handed to the map
NB_ACTIONS = 4       # number of discrete agent actions
MAP_SIZE = 10        # spatial size of the dynamic map
MAP_CHANNELS = 16    # channel count of the dynamic map
trunk = FlattenTrunk
env_name = 'GoalSearch-v2'
# Directory that holds the checkpoints to evaluate and receives the results.
d = env_name + '/train_rl_follow_1/'
env = GoalSearchSimple(10)

# NOTE: `map` shadows the builtin of the same name; the rest of the script
# refers to the dynamic map through this name.
map = DynamicMap(
    size=MAP_SIZE, channels=MAP_CHANNELS,
    env_size=ENV_SIZE, env_channels=CHANNELS,
    batchsize=1, nb_actions=NB_ACTIONS,
    device=device, mode=mode,
)

# Wrap the raw environment; the wrapper's reset/step take a glimpse location
# and return (obs, unmasked_obs, mask) triples.
env = AttentionConstrainedEnvironment(ENV_SIZE, ATTN_SIZE, device, env)

policy = MultinomialPolicy()

# The actor-critic consumes the map content, so its input shape is the map's.
state_shape = (MAP_CHANNELS, MAP_SIZE, MAP_SIZE)
ac = networks.ActorCritic(trunk, state_shape, NB_ACTIONS)
ac.to(device)

def _follow_glimpse_location():
    """Return the agent's current position as an unraveled glimpse location.

    Used in 'follow' mode, where the glimpse is centred on the agent instead
    of being chosen by the glimpse policy. The position is read from the raw
    environment wrapped by the attention-constrained environment (`env.env`).
    """
    inner_env = env.env
    if env_name == 'PhysEnv-v1':
        agent_loc = inner_env.player_body.position
        # NOTE(review): 84 presumably is the PhysEnv frame size (ENV_SIZE in
        # the physenv configuration); the flip converts the vertical axis.
        return (agent_loc[1], 84 - agent_loc[0])
    # NOTE(review): assumes the GoalSearch environment exposes `agent_x` /
    # `agent_y`, including before the first reset - confirm against the env.
    return (inner_env.agent_x, inner_env.agent_y)


def _choose_glimpse(glimpse_agent):
    """Return the normalised, clipped glimpse location for the next observation.

    In 'follow' mode the location is the agent's position; otherwise the
    glimpse agent's policy picks it from the current (detached) map content.
    """
    if mode == 'follow':
        return glimpse_agent.norm_and_clip(_follow_glimpse_location(), unraveled=True)
    glimpse_logits = glimpse_agent.pi(map.content().detach())
    glimpse_action = glimpse_agent.policy(glimpse_logits).detach()
    return glimpse_agent.norm_and_clip(glimpse_action.cpu().numpy())


# Evaluate every saved checkpoint: load the map, glimpse agent and actor-critic
# for that training step, run 5 test episodes and save the episode rewards.
for step in range(100, 70000, 100):
    map_path = os.path.join(d, 'map_{}.pth'.format(step))
    print("loading map {}".format(map_path))
    map.load(map_path)
    map.to(device)

    glimpse_path = os.path.join(d, 'glimpse_{}.pth'.format(step))
    print("loading glimpse agent {}".format(glimpse_path))
    glimpsenet = torch.load(glimpse_path, map_location='cpu')
    if env_name == 'PhysEnv-v1':
        glimpse_pi = PolicyFunction_21_84(channels=state_shape[0])
    else:
        glimpse_pi = PolicyFunction_x_x(channels=state_shape[0])
    glimpse_V = ValueFunction(channels=state_shape[0], input_size=state_shape[1])
    # The checkpoint stores state dicts under these two keys.
    glimpse_pi.load_state_dict(glimpsenet['policy_network'])
    glimpse_V.load_state_dict(glimpsenet['value_network'])
    glimpse_agent = GlimpseAgent(
        output_size=ENV_SIZE,
        attn_size=ATTN_SIZE,
        batchsize=1,
        policy_network=glimpse_pi,
        value_network=glimpse_V,
        device=device)

    ac_path = os.path.join(d, 'actor_critic_{}.pth'.format(step))
    print("loading rl agent {}".format(ac_path))
    ac.load_state_dict(torch.load(ac_path, map_location=device)['actor_critic'])

    data = []
    for ep in range(5):
        map.reset()
        # starting glimpse location
        glimpse_action_clipped = _choose_glimpse(glimpse_agent)
        obs, _, mask = env.reset(loc=glimpse_action_clipped)
        done = False
        reward = 0
        while not done:
            # write observation to map
            map.write(obs.unsqueeze(dim=0), mask, 1 - mask)
            # take a step in the environment!
            state = map.content().detach()
            logits = ac.pi(state)
            action = policy(logits, test=True)
            # step the map forward according to agent action
            onehot_action = torch.zeros((1, NB_ACTIONS), device=device)
            onehot_action[0, action] = 1
            map.step(onehot_action)
            # no need to store gradient information for rollouts
            map.detach()
            # glimpse agent decides where to look after map has stepped
            glimpse_action_clipped = _choose_glimpse(glimpse_agent)
            (obs, _, mask), r, done, _ = env.step(
                action.cpu().numpy(), loc=glimpse_action_clipped)
            reward += r
        data.append(reward)
        print("episode {} reward {}".format(ep, reward))
    np.save(os.path.join(d, 'rltest{}.npy'.format(step)), data)
