from src.goalsearch import GoalSearchSimple
from src.utils import *
from src.rl import *
from src.dynamicmap import *

import time
import numpy as np
import os
np.set_printoptions(precision=3)

# ---- run arguments ----
SEED = 123
mode = 'environment'
# run on the first CUDA GPU when one is visible, otherwise on the CPU
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print("will run on {} device!".format(device))

## minigrid args
# ATTN_SIZE = 5
# ENV_SIZE = 16
# CHANNELS = 7
# NB_ACTIONS = 3
# MAP_SIZE = 16
# MAP_CHANNELS = 16
# env_name = 'minigrid-v0'
# from gym_minigrid.wrappers import *
# from src.minigrid import OneHotDynamicObjectsWrapper
# env = gym.make('MiniGrid-Dynamic-Obstacles-16x16-v0')
# env = OneHotDynamicObjectsWrapper(env)
# env = ImgObsWrapper(env) # Get rid of the 'mission' field

# goalsearch (active configuration)
ATTN_SIZE = 3      # side length of the square attention (glimpse) window
ENV_SIZE = 10      # observations are ENV_SIZE x ENV_SIZE grids
CHANNELS = 4       # channels per observation
NB_ACTIONS = 4     # number of discrete actions (size of the one-hot action)
MAP_SIZE = 10      # passed to DynamicMap as `size`
MAP_CHANNELS = 16  # passed to DynamicMap as `channels`
env_name = 'GoalSearch-v2'
# Reuse ENV_SIZE instead of repeating the literal, so the environment and the
# ENV_SIZE-shaped data buffers cannot silently drift apart.
# NOTE(review): assumes GoalSearchSimple's argument is the grid side length.
env = GoalSearchSimple(ENV_SIZE)

# # physenv args
# ATTN_SIZE = 21
# ENV_SIZE = 84
# CHANNELS = 3
# NB_ACTIONS = 4
# MAP_SIZE = 21
# MAP_CHANNELS = 48
# env_name = 'PhysEnv-v1'
# class ENV:
#     def render(self, img):
#         img = (img + 1)/2.
#         img = img * 255
#         img = img.astype(np.uint8)
#         return img
# env = ENV()

# seed the numpy and torch random number generators
np.random.seed(SEED)
torch.manual_seed(SEED)

# initialize testing data: the first `seq_len` steps of `BATCH_SIZE` recorded
# episodes, read from <demo_dir>/<episode>/{actions.pt, agent.pt, <step>.pt}
demo_dir = 'testingdata-{}/'.format(env_name)
print('using testing data from {}'.format(demo_dir))
BATCH_SIZE = 6
seq_len = 25

# Allocate the buffers directly on the target device instead of creating them
# on the CPU and copying them over with `.to(device)`.
dataset = torch.empty(seq_len, BATCH_SIZE, CHANNELS, ENV_SIZE, ENV_SIZE, device=device)
actionset = torch.empty(seq_len, BATCH_SIZE, dtype=torch.long, device=device)
agent_locs = torch.empty(seq_len, BATCH_SIZE, 2, dtype=torch.long, device=device)

for i in range(BATCH_SIZE):
    ep_dir = os.path.join(demo_dir, str(i))
    # per-episode action and agent-location tensors, truncated to seq_len
    actionset[:, i] = torch.load(os.path.join(ep_dir, 'actions.pt'), map_location=device)[:seq_len]
    agent_locs[:, i] = torch.load(os.path.join(ep_dir, 'agent.pt'), map_location=device)[:seq_len]
    # one observation file per time step
    for step in range(seq_len):
        dataset[step, i] = torch.load(os.path.join(ep_dir, str(step) + '.pt'), map_location=device)

# build the dynamic map once; every checkpoint below is loaded into it
map = DynamicMap(size=MAP_SIZE, channels=MAP_CHANNELS,
                 env_size=ENV_SIZE, env_channels=CHANNELS,
                 nb_actions=NB_ACTIONS, batchsize=BATCH_SIZE,
                 device=device, mode=mode)
map.to(device)

# masked MSE (called with the attention mask) and an element-wise MSE
mse = MSEMasked()
mse_unmasked = nn.MSELoss(reduction='none')

# checkpoints map100.pth, map200.pth, ..., map200000.pth
model_dir = '{}/envrewardattn3_1/'.format(env_name)
model_paths = [os.path.join(model_dir, 'map%d.pth' % ckpt)
               for ckpt in range(100, 200001, 100)]

# (row, col) offsets of every cell of the attention window, row-major
attn_span = range(-(ATTN_SIZE // 2), ATTN_SIZE // 2 + 1)
xy = np.array(np.meshgrid(attn_span, attn_span, indexing='ij')).reshape(2, -1)

# batch index of every window cell once the per-batch windows are flattened
idxs_dim_0 = np.arange(BATCH_SIZE).repeat(ATTN_SIZE * ATTN_SIZE)
def create_attn_mask(loc, attn_size=None, env_size=None, mask_device=None):
    """Create a batched mask out of batched attention locations.

    Each row of ``loc`` is the (row, col) centre of one square attention
    window covering offsets ``-(attn_size // 2) .. attn_size // 2`` around the
    centre in both directions. Windows must lie fully inside the grid; the
    callers in this script clip locations to guarantee that.

    Args:
        loc: integer array-like of shape (batch, 2) holding window centres.
        attn_size: window side length; defaults to the module-level ATTN_SIZE.
        env_size: grid side length; defaults to the module-level ENV_SIZE.
        mask_device: device of the returned mask; defaults to the
            module-level ``device``.

    Returns:
        Float tensor of shape (batch, 1, env_size, env_size) that is 1 inside
        each window and 0 elsewhere.
    """
    if attn_size is None:
        attn_size = ATTN_SIZE
    if env_size is None:
        env_size = ENV_SIZE
    if mask_device is None:
        mask_device = device
    loc = np.asarray(loc)
    batch = loc.shape[0]
    half = attn_size // 2
    span = np.arange(-half, half + 1)
    # offsets[:, k] is the (d_row, d_col) of the k-th window cell, row-major
    offsets = np.array(np.meshgrid(span, span, indexing='ij')).reshape(2, -1)
    cells = loc[:, :, np.newaxis] + offsets  # (batch, 2, window_cells)
    # batch index for every flattened window cell; sized from the actual
    # window so it always matches the number of cells
    batch_idx = np.repeat(np.arange(batch), offsets.shape[1])
    # build the mask on the target device directly instead of filling it on
    # the CPU and copying the whole grid over afterwards
    obs_mask = torch.zeros(batch, 1, env_size, env_size, device=mask_device)
    obs_mask[torch.as_tensor(batch_idx, device=mask_device), :,
             torch.as_tensor(cells[:, 0].ravel(), device=mask_device),
             torch.as_tensor(cells[:, 1].ravel(), device=mask_device)] = 1
    return obs_mask

CEloss = torch.nn.CrossEntropyLoss()
logsoftmax = torch.nn.LogSoftmax(dim=1)


def _select_attention(glimpse_agent, t, heatmaps, entropies=None):
    """Return the next batch of attention locations (one per batch element).

    In 'follow' mode the glimpse is placed on the recorded agent location of
    time step ``t``. Otherwise the glimpse agent picks a location from the
    current map content (``random=False``). The softmaxed policy over the
    ENV_SIZE x ENV_SIZE grid for batch element 3 is appended to ``heatmaps``.
    When ``entropies`` is given, the policy entropy is appended to it too.
    Agent-chosen locations are clipped so the whole attention window stays
    inside the grid.
    """
    if mode == 'follow':
        return glimpse_agent.norm_and_clip(agent_locs[t].cpu().numpy(), unraveled=True)
    loc = glimpse_agent.step(map.content().detach(), random=False)
    logits = glimpse_agent.pi(map.content().detach().to(device))
    heatmaps.append(F.softmax(logits[3], dim=-1).view(1, ENV_SIZE, ENV_SIZE).detach().cpu())
    if entropies is not None:
        entropies.append(glimpse_agent.policy.entropy(logits).detach().cpu().numpy())
    # clip to avoid edges
    return np.clip(loc, ATTN_SIZE // 2, ENV_SIZE - 1 - ATTN_SIZE // 2).astype(np.int64)


def _as_float_tensor(values):
    """Convert a (possibly nested) list of numpy values to a float32 tensor.

    Stacking with numpy first avoids torch's slow element-by-element path for
    lists of ndarrays. The shape and dtype match ``torch.FloatTensor`` on the
    same list, and an empty list gives an empty 1-D tensor.
    """
    return torch.from_numpy(np.asarray(values, dtype=np.float32))


start = time.time()
for step, path in enumerate(model_paths):
    # all_locs.append([])
    model_name = os.path.splitext(os.path.basename(path))[0]
    it = int(model_name[3:])  # iteration number encoded as 'map<it>.pth'
    print(it)
    # load the model
    print("loading " + path)
    map.load(path)
    map.to(device)
    # load the glimpse agent stored next to the map checkpoint
    pathdir, pathname = os.path.split(path)
    glimpsenet = torch.load(os.path.join(pathdir, pathname.replace("map", "glimpse")), map_location='cpu')[0]
    if env_name in ('minigrid-v0', 'GoalSearch-v2'):
        glimpse_pi = PolicyFunction_x_x(channels=MAP_CHANNELS)
    elif env_name == 'PhysEnv-v1':
        glimpse_pi = PolicyFunction_21_84(channels=MAP_CHANNELS)
    else:
        raise ValueError('Unknown env_name')
    glimpse_V = ValueFunction(channels=MAP_CHANNELS, input_size=MAP_SIZE)
    glimpse_pi.load_state_dict(glimpsenet['policy_network'])
    glimpse_V.load_state_dict(glimpsenet['value_network'])
    glimpse_agent = GlimpseAgent(
        output_size=ENV_SIZE,
        attn_size=ATTN_SIZE,
        batchsize=BATCH_SIZE,
        policy_network=glimpse_pi,
        value_network=glimpse_V,
        device=device)
    # the test episodes were loaded onto `device` when they were read
    state_batch, action_batch = dataset, actionset
    # per-checkpoint records
    test_maps_prestep = []
    test_maps_heatmaps = []
    test_maps_poststep = []
    test_locs = []
    write_loss = 0
    post_write_loss = []
    post_step_loss = []
    sigma = []
    overall_reconstruction_loss = []
    overall_observed_objects = []
    # the two lists below stay empty while the variational check is disabled
    overall_variational_acc = []
    overall_variational_reward = []
    # start!
    map.reset()
    loc = _select_attention(glimpse_agent, 0, test_maps_heatmaps)
    test_locs.append(loc[3].copy())
    obs_mask = create_attn_mask(loc)
    minus_obs_mask = 1 - obs_mask
    # reconstruction of the freshly reset (empty) map
    post_step_reconstruction = map.reconstruct()
    for t in range(seq_len):
        target = state_batch[t]
        # reconstruction loss inside the attended window, before writing
        post_step_loss.append(mse(post_step_reconstruction, target, obs_mask).detach().cpu().numpy())
        # per-channel loss over the whole image, plus the per-channel sum of
        # the observation inside the attention window
        observation = target * obs_mask
        channel_loss = []
        observed_objects = []
        for ch in range(CHANNELS):
            ch_loss = mse_unmasked(post_step_reconstruction[:, ch].flatten(start_dim=1),
                                   target[:, ch].flatten(start_dim=1))
            channel_loss.append(ch_loss.mean(dim=1).detach().cpu().numpy())
            observed_objects.append(observation[:, ch].sum(dim=(1, 2)).detach().cpu().numpy())
        overall_reconstruction_loss.append(channel_loss)
        overall_observed_objects.append(observed_objects)
        # write the attended observation into the map.
        # NOTE(review): if map.write returns a tensor attached to the autograd
        # graph, this accumulation keeps every step's graph alive until the
        # checkpoint is done; consider detaching - confirm against DynamicMap.
        write_loss += map.write(observation, obs_mask, minus_obs_mask)
        post_write_reconstruction = map.reconstruct()
        post_write_loss.append(mse(post_write_reconstruction, target, obs_mask).detach().cpu().numpy())
        test_maps_prestep.append(post_write_reconstruction[3].detach().cpu())
        # step forward the internal map with the one-hot encoded actions
        onehot_action = torch.zeros(BATCH_SIZE, NB_ACTIONS, device=device)
        onehot_action.scatter_(1, action_batch[t].unsqueeze(dim=1), 1)
        map.step(onehot_action)
        post_step_reconstruction = map.reconstruct()
        test_maps_poststep.append(post_step_reconstruction[3].detach().cpu())
        # # check variational accuracy
        # variational_output = map.q(s, map.map.detach(), onehot_action).detach()
        # variational_loc = variational_output.max(dim=1)[1].cpu().numpy()
        # variational_loc = np.unravel_index(variational_loc, (ENV_SIZE, ENV_SIZE))
        # variational_loc = np.array(list(zip(*variational_loc)))
        # acc = np.abs(variational_loc - loc) < 1
        # acc = np.all(acc, axis=1)
        # acc = np.mean(acc)
        # loc_indices = np.ravel_multi_index(np.transpose(loc), (ENV_SIZE, ENV_SIZE))
        # variational_reward = logsoftmax(variational_output).detach()
        # variational_reward = variational_reward[range(BATCH_SIZE), loc_indices].mean()
        # overall_variational_acc.append(acc)
        # overall_variational_reward.append(variational_reward)
        # Select the attention spot for the NEXT frame (t + 1). In 'follow'
        # mode this must be the agent location at t + 1 (using t observed
        # every frame at the previous agent position). Clamp so that the final,
        # never-observed selection does not index past the recorded episode.
        next_t = min(t + 1, seq_len - 1)
        loc = _select_attention(glimpse_agent, next_t, test_maps_heatmaps, sigma)
        # loc = all_locs[step][t+1]
        # all_locs[-1].append(loc)
        test_locs.append(loc[3].copy())
        obs_mask = create_attn_mask(loc)
        minus_obs_mask = 1 - obs_mask
    overall_reconstruction_loss = _as_float_tensor(overall_reconstruction_loss)
    overall_observed_objects = _as_float_tensor(overall_observed_objects)
    overall_variational_acc = _as_float_tensor(overall_variational_acc)
    overall_variational_reward = _as_float_tensor(overall_variational_reward)
    post_step_loss = _as_float_tensor(post_step_loss)
    post_write_loss = _as_float_tensor(post_write_loss)
    sigma = _as_float_tensor(sigma)
    test_loss = 0.01 * write_loss + post_write_loss.mean() + post_step_loss.mean()
    torch.save({
        'write_loss': 0.01 * write_loss,
        'post write reconstruction loss': post_write_loss,
        'post step reconstruction loss': post_step_loss,
        'sigma': sigma,
        'overall_reconstruction_loss': overall_reconstruction_loss,
        'overall_observed_objects': overall_observed_objects,
        'overall_variational_acc': overall_variational_acc,
        'overall_variational_reward': overall_variational_reward,
    }, os.path.join(model_dir, 'loss_{}.pt'.format(model_name)))
    # # save some generated images
    # save_example_images(
    #    [state_batch[t][3].cpu() for t in range(seq_len)],
    #    test_maps_heatmaps,
    #    test_maps_prestep,
    #    test_maps_poststep,
    #    test_locs,
    #    os.path.join(model_dir, 'predictions_{}.pdf'.format(model_name)),
    #    env,
    #    ATTN_SIZE,
    #    ENV_SIZE,
    #    1)
    to_print = "[{}] test loss: {:.3f}".format(model_name, test_loss)
    to_print += ", overall image loss: {:.3f}".format(overall_reconstruction_loss.mean())
    to_print += ", glimpse entropy: {:.3f}".format(sigma.mean())
    to_print += ", time/iter (ms): {:.3f}".format(1000 * (time.time() - start))
    print(to_print)
    start = time.time()
# all_locs = np.array(all_locs)
# np.save('locs.npy', all_locs)
