import time
import numpy as np
from collections import OrderedDict
import os
import matplotlib
matplotlib.use('Agg')
np.set_printoptions(precision=3)

import torch
from torch.utils.data import DataLoader
import torch.optim as optim

from pytorch_rl.utils import AverageMeter
from src.utils import SequenceDataset, time_collate
from src.rl import GlimpseAgent
from src.networks import *
from src.dynamicmap import DynamicMap


if __name__ == "__main__":
    # Training script: fits a DynamicMap on recorded sequences loaded from
    # `trainingdata-<env_name>/`, prints running metrics every LOG_INTERVAL
    # batches, and checkpoints the map and glimpse-agent weights every
    # SAVE_INTERVAL batches.

    # args:
    BATCH_SIZE = 8
    SEED = 123
    # MODE = 'environmental'
    MODE = 'follow'
    NUM_EPOCHS = 30000
    LOG_INTERVAL = 10    # batches between metric printouts
    SAVE_INTERVAL = 100  # batches between checkpoints

    ## minigrid args
    # ATTN_SIZE = 5
    # ENV_SIZE = 16
    # CHANNELS = 7
    # NB_ACTIONS = 3
    # MAP_SIZE = 16
    # MAP_CHANNELS = 16
    # START_SEQ_LEN = 50
    # END_SEQ_LEN = 50
    # env_name = 'minigrid-v0'

    # # physenv args
    # ATTN_SIZE = 21
    # ENV_SIZE = 84
    # CHANNELS = 3
    # NB_ACTIONS = 4
    # MAP_SIZE = 21
    # MAP_CHANNELS = 48
    # START_SEQ_LEN = 25
    # END_SEQ_LEN = 25
    # env_name = 'PhysEnv-v1'

    # goalsearch args
    ATTN_SIZE = 3
    ENV_SIZE = 10
    CHANNELS = 4
    NB_ACTIONS = 4
    MAP_SIZE = 10
    MAP_CHANNELS = 16
    START_SEQ_LEN = 50
    END_SEQ_LEN = 50  # not read anywhere in this script
    env_name = 'GoalSearch-v2'

    np.random.seed(SEED)
    torch.manual_seed(SEED)

    # initialize training data
    demo_dir = 'trainingdata-{}/'.format(env_name)
    print('using training data from {}'.format(demo_dir))
    dataset = SequenceDataset(data_dir=demo_dir)
    seq_len = START_SEQ_LEN
    dataset.set_seqlen(seq_len)
    train_loader = DataLoader(dataset, batch_size=BATCH_SIZE,
                              num_workers=8, collate_fn=time_collate,
                              drop_last=True, pin_memory=True)

    # gpu?
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("will run on {} device!".format(device))

    # initialize map (named `dynamic_map` so the builtin `map` is not shadowed)
    dynamic_map = DynamicMap(
        size=MAP_SIZE,
        channels=MAP_CHANNELS,
        env_size=ENV_SIZE,
        env_channels=CHANNELS,
        batchsize=BATCH_SIZE,
        nb_actions=NB_ACTIONS,
        device=device,
        mode=MODE)
    dynamic_map.to(device)

    # the policy network architecture is selected per environment
    if env_name == 'minigrid-v0' or env_name == 'GoalSearch-v2':
        policy_network = PolicyFunction_x_x(channels=MAP_CHANNELS)
    elif env_name == 'PhysEnv-v1':
        policy_network = PolicyFunction_21_84(channels=MAP_CHANNELS)
    else:
        raise ValueError('Unknown env_name')
    value_network = ValueFunction(channels=MAP_CHANNELS, input_size=MAP_SIZE)
    glimpse_agent = GlimpseAgent(
        output_size=ENV_SIZE,
        attn_size=ATTN_SIZE,
        batchsize=BATCH_SIZE,
        policy_network=policy_network,
        value_network=value_network,
        device=device,)

    # this optimizer only covers the map's parameters; the glimpse agent is
    # updated separately through glimpse_agent.update() below
    optimizer = optim.Adam(dynamic_map.parameters(), lr=1e-4)

    # iterate through data and learn!
    training_metrics = OrderedDict([
        ('map/write_cost', AverageMeter()),
        ('map/step_cost', AverageMeter()),
        ('map/post_write', AverageMeter()),
        ('map/post_step', AverageMeter()),
        ('map/overall', AverageMeter()),
        ('map/min_overall', AverageMeter()),
        ('q/loss', AverageMeter()),
        ('glimpse/policy_loss', AverageMeter()),
        ('glimpse/policy_entropy', AverageMeter()),
        ('glimpse/val_loss', AverageMeter()),
        ('glimpse/avg_val', AverageMeter()),
        ('glimpse/avg_reward', AverageMeter()),
        ])

    # Create the checkpoint directory up front: torch.save does not create
    # missing parent directories, and failing here is cheaper than failing
    # at the first checkpoint after SAVE_INTERVAL batches of training.
    agentsavepath = env_name + '/followattn21_1/'
    os.makedirs(agentsavepath, exist_ok=True)

    i_batch = 0
    start = time.time()
    for epoch in range(NUM_EPOCHS):
        for (state_batch, action_batch, reward_batch, agent_locs) in train_loader:
            state_batch = state_batch.to(device)
            action_batch = action_batch.to(device)
            reward_batch = reward_batch.to(device)
            if MODE != 'follow':
                # agent locations are only passed to the map in 'follow' mode
                agent_locs = None
            # get started training!
            optimizer.zero_grad()
            loss = dynamic_map.lossbatch(
                state_batch,
                action_batch,
                reward_batch,
                glimpse_agent,
                training_metrics,
                agent_locs)
            # propagate loss back through entire training sequence
            loss.backward()
            optimizer.step()
            if MODE != 'follow':
                # and update the glimpse agent
                glimpse_agent.update(dynamic_map.content().detach(), None,
                                     training_metrics, None, scope='glimpse')
            glimpse_agent.reset()
            i_batch += 1

            if i_batch % LOG_INTERVAL == 0:
                to_print = 'epoch [{}] batch [{}]'.format(epoch, i_batch)
                for key, value in training_metrics.items():
                    if isinstance(value, AverageMeter):
                        to_print += ", {}: {:.3f}".format(key, value.avg)
                # `start` is reset after every printout, so the elapsed time
                # covers exactly LOG_INTERVAL batches
                elapsed_ms = 1000 * (time.time() - start)
                to_print += ", time/it (ms): {:.3f}".format(elapsed_ms / LOG_INTERVAL)
                print(to_print)
                start = time.time()
            if i_batch % SAVE_INTERVAL == 0:
                print('saving network weights to {} ...'.format(agentsavepath))
                torch.save(dynamic_map.tosave(),
                           os.path.join(agentsavepath, 'map{}.pth'.format(i_batch)))
                # glimpse_net = glimpse_agent.ppo.actor_critic
                # NOTE: a stray trailing comma used to wrap this dict in a
                # 1-tuple. Checkpoints written before that fix must be indexed
                # with [0] after torch.load().
                glimpse_net = {'policy_network': glimpse_agent.policy_network.state_dict(),
                               'value_network': glimpse_agent.value_network.state_dict(),
                               'v_optimizer': glimpse_agent.a2c.V_optimizer.state_dict(),
                               'pi_optimizer': glimpse_agent.a2c.pi_optimizer.state_dict()}
                torch.save(glimpse_net,
                           os.path.join(agentsavepath, 'glimpse{}.pth'.format(i_batch)))
