# -*- coding: utf-8 -*-
import argparse
import os
import glob
import numpy as np
import torch
from torch import nn, optim
from torch.distributions import Normal
from torch.distributions.kl import kl_divergence
from torch.nn import functional as F
from torchvision.utils import make_grid, save_image
from tqdm import tqdm
from env import CONTROL_SUITE_ENVS, Env, GYM_ENVS, EnvBatcher
from memory import SeqExperienceReplay
from models import bottle, bottle_action, bottle_transition, Encoder, ObservationModel, RewardModel, TransitionModel, GNetwork, DeterministicPolicy, GaussianPolicy
from utils import lineplot, write_video, create_log_gaussian, hard_update, soft_update

from env import postprocess_observation

# Hyperparameters
# All run configuration comes from the command line; the parsed namespace `args` is used
# throughout the rest of this script (and `args.device` is attached to it during setup).
parser = argparse.ArgumentParser(description='Deep Free Energy Network')
parser.add_argument('--id', type=str, default='fenet', help='Experiment ID')
parser.add_argument('--results-dir', type=str, default='results', help='Results Directory. Default: results')
parser.add_argument('--seed', type=int, default=1, metavar='S', help='Random seed')
parser.add_argument('--disable-cuda', action='store_true', help='Disable CUDA')
parser.add_argument('--env', type=str, default='cheetah-run', choices=GYM_ENVS + CONTROL_SUITE_ENVS, help='Gym/Control Suite environment')
parser.add_argument('--symbolic-env', action='store_true', help='Symbolic features')
parser.add_argument('--observation-noise', type=float, default=0, metavar='εo', help='Observation noise eg. 0.1 or 0.2')
parser.add_argument('--sparse-env', action='store_true', help='Sparse environment')
# imitation learning
parser.add_argument('--number-of-experts', type=int, default=10000, metavar='E', help='Number of experts')
# network architecture
parser.add_argument('--activation-function', type=str, default='relu', choices=dir(F), help='Model activation function')
parser.add_argument('--embedding-size', type=int, default=1024, metavar='E', help='Observation embedding size')  # Note that the default encoder for visual observations outputs a 1024D vector; for other embedding sizes an additional fully-connected layer is used
parser.add_argument('--hidden-size', type=int, default=200, metavar='H', help='Hidden size')
parser.add_argument('--belief-size', type=int, default=200, metavar='H', help='Belief/hidden size')
parser.add_argument('--state-size', type=int, default=30, metavar='Z', help='State/latent size')
# mujoco env
parser.add_argument('--max-episode-length', type=int, default=1000, metavar='T', help='Max episode length')
parser.add_argument('--action-repeat', type=int, default=4, metavar='R', help='Action repeat')
# training phase
parser.add_argument('--experience-size', type=int, default=1500, metavar='D', help='Experience replay maximum size for seeds + active collection (number of episodes)')
parser.add_argument('--episodes', type=int, default=1000, metavar='E', help='Total number of episode for training')
parser.add_argument('--seed-episodes', type=int, default=10, metavar='S', help='Seed episodes')
parser.add_argument('--collect-interval', type=int, default=100, metavar='C', help='Collect interval')
parser.add_argument('--batch-size', type=int, default=50, metavar='B', help='Batch size')
parser.add_argument('--chunk-size', type=int, default=50, metavar='L', help='Chunk size')
parser.add_argument('--burn-in-size', type=int, default=20, metavar='Burn', help='Burn in size')
parser.add_argument('--action-noise', type=float, default=0.3, metavar='εa', help='Action noise for training phase')
parser.add_argument('--start-expert', type=int, default=0, metavar='p', help='Start expert imitation calculation')
parser.add_argument('--start-agent', type=int, default=0, metavar='p', help='Start agent RL calculation')
parser.add_argument('--on-policy-episodes', type=int, default=1, metavar='e', help='Number of episodes to calculate policy loss (1 to use off policy data)')
parser.add_argument('--value-period', type=int, default=0, metavar='vp', help='Stop other learnings until this period')
# learning
parser.add_argument('--reward-scale', type=float, default=100, metavar='R0', help='Reward weight')
parser.add_argument('--action-scale', type=float, default=10, metavar='a0', help='Action weight')
parser.add_argument('--policy-scale', type=float, default=100, metavar='p0', help='Policy weight in imitation')
parser.add_argument('--free-nats', type=float, default=3, metavar='F', help='Free nats')
parser.add_argument('--learning-rate', type=float, default=1e-3, metavar='α', help='Learning rate')
parser.add_argument('--learning-rate-schedule', type=int, default=0, metavar='αS', help='Linear learning rate schedule (optimisation steps from 0 to final learning rate; 0 to disable)')
parser.add_argument('--adam-epsilon', type=float, default=1e-4, metavar='ε', help='Adam optimiser epsilon value')
parser.add_argument('--grad-clip-norm', type=float, default=1000, metavar='C', help='Gradient clipping norm')
parser.add_argument('--tau', type=float, default=0.01, metavar='τ', help='Target smoothing coefficient')
parser.add_argument('--discount', type=float, default=0.99, metavar='gamma', help='Discount rate for RL')
parser.add_argument('--imagination-horizon', type=int, default=2, metavar='H', help='Imagination horizon distance')
parser.add_argument('--lambda-value', type=float, default=0.95, metavar='lambda', help='lambda for v')
# planning
parser.add_argument('--gaussian-policy', action='store_true', help='Use gaussian policy')
# test phase
parser.add_argument('--test', action='store_true', help='Test only')
parser.add_argument('--test-interval', type=int, default=25, metavar='I', help='Test interval (episodes)')
parser.add_argument('--test-episodes', type=int, default=10, metavar='E', help='Number of test episodes')
parser.add_argument('--checkpoint-interval', type=int, default=50, metavar='I', help='Checkpoint interval (episodes)')
parser.add_argument('--checkpoint-experience', action='store_true', help='Checkpoint experience of expert')  # to save experts
parser.add_argument('--prior-check', action='store_true', help='Generate movies of prior prediction')
parser.add_argument('--test-run', action='store_true', help='Test run')
# loading
parser.add_argument('--models', type=str, default='', metavar='M', help='Load model checkpoint')
parser.add_argument('--experience-replay', type=str, default='', metavar='ER', help='Load experience replay of experts')
# misc
parser.add_argument('--bit-depth', type=int, default=5, metavar='B', help='Image bit depth (quantisation)')
parser.add_argument('--render', action='store_true', help='Render environment')

args = parser.parse_args()
# Echo every resolved option, indented so it lines up under the "Options" header.
_OPTION_INDENT = ' ' * 26
print(_OPTION_INDENT + 'Options')
for k, v in vars(args).items():
  print(f'{_OPTION_INDENT}{k}: {v}')


# Setup
# Results go to <results-dir>/<id>/<env>; refuse to reuse a directory that already has content
# so an earlier run's outputs are never overwritten.
results_dir = os.path.join(args.results_dir, args.id, args.env)
if os.path.exists(results_dir) and os.listdir(results_dir):
  raise FileExistsError(results_dir)
os.makedirs(results_dir, exist_ok=True)

# Seed NumPy and PyTorch first, then choose the device (CUDA unless unavailable or disabled).
np.random.seed(args.seed)
torch.manual_seed(args.seed)
use_cuda = torch.cuda.is_available() and not args.disable_cuda
args.device = torch.device('cuda' if use_cuda else 'cpu')
if use_cuda:
  torch.cuda.manual_seed(args.seed)

# One independent (initially empty) history list per tracked quantity, in plotting order.
_METRIC_KEYS = ('steps', 'episodes', 'train_rewards', 'test_episodes', 'test_rewards',
                'observation_loss', 'reward_loss', 'action_decoder_loss', 'kl_loss', 'observation_loss_a',
                'reward_loss_a', 'action_decoder_loss_a', 'kl_loss_a', 'g_loss', 'expected_FE',
                'expected_obs_loss', 'expected_kl_loss')
metrics = {metric_name: [] for metric_name in _METRIC_KEYS}


# Initialise training environment and expert experience replay
env = Env(args.env, args.symbolic_env, args.seed, args.max_episode_length, args.action_repeat, args.bit_depth, args.sparse_env)
episode_length = args.max_episode_length // args.action_repeat
# Truthiness test instead of `is not ''`: identity comparison against a str literal relies on
# string interning and raises a SyntaxWarning on Python 3.8+.
if args.experience_replay and os.path.exists(args.experience_replay):
  # For imitation learning. glob.glob returns files in arbitrary (filesystem-dependent) order,
  # so sort before truncating to make the chosen --number-of-experts subset reproducible.
  expert_list = sorted(glob.glob(os.path.join(args.experience_replay, '*.npz')))[:args.number_of_experts]
  print('Use {} experts for this experiment'.format(len(expert_list)))
  D_expert = SeqExperienceReplay(len(expert_list), args.symbolic_env, env.observation_size, env.action_size, args.bit_depth, args.device, episode_length, args.burn_in_size, args.on_policy_episodes)
  for i, expert_path in enumerate(expert_list):
    if i % 100 == 0:
      print('Loading experts{}'.format(i))
    # Use the NpzFile as a context manager so its file handle is closed promptly.
    # NOTE(review): assumes append_from_expert copies the arrays it needs and does not keep
    # the NpzFile for lazy reads later; confirm in memory.SeqExperienceReplay.
    with np.load(expert_path) as npepisode:
      # The episode length is taken from the number of stored actions.
      npepisode_length = npepisode['action'].shape[0]
      D_expert.append_from_expert(npepisode, npepisode_length)


# Initialise model parameters randomly
# (construction order is kept stable so that seeded weight initialisation is reproducible)
transition_model = TransitionModel(args.belief_size, args.state_size, env.action_size, args.hidden_size, args.embedding_size, args.activation_function).to(device=args.device)
observation_model = ObservationModel(args.symbolic_env, env.observation_size, args.belief_size, args.state_size, args.embedding_size, env.action_size, args.activation_function).to(device=args.device)
reward_model = RewardModel(args.belief_size, args.state_size, args.hidden_size, env.action_size, args.activation_function).to(device=args.device)
# Note that the results in the paper are in the case of deterministic policy.
policy_class = GaussianPolicy if args.gaussian_policy else DeterministicPolicy
policy_args = (args.belief_size, args.state_size, args.hidden_size, env.action_size, args.env, args.activation_function, env.action_space)
# policy prior
action_decoder = policy_class(*policy_args).to(device=args.device)
# policy posterior
policy = policy_class(*policy_args).to(device=args.device)
encoder = Encoder(args.symbolic_env, env.observation_size, args.embedding_size, args.activation_function).to(device=args.device)
g_model = GNetwork(args.state_size, env.action_size, args.hidden_size).to(device=args.device)
g_target_model = GNetwork(args.state_size, env.action_size, args.hidden_size).to(device=args.device)

# With a learning-rate schedule the optimisers start at lr=0; the training loop ramps the lr up.
initial_lr = 0 if args.learning_rate_schedule != 0 else args.learning_rate
# Model optimiser: world model (transition/observation/reward/encoder) plus the policy prior.
param_list = list(transition_model.parameters()) + list(observation_model.parameters()) + list(reward_model.parameters()) + list(encoder.parameters()) + list(action_decoder.parameters())
optimiser = optim.Adam(param_list, lr=initial_lr, eps=args.adam_epsilon)
# Agent optimiser: policy posterior only.
param_list_agent = list(policy.parameters())
optimiser_agent = optim.Adam(param_list_agent, lr=initial_lr, eps=args.adam_epsilon)
# Value optimiser: G network only (the target copy is updated via hard/soft updates).
param_list_value = list(g_model.parameters())
optimiser_value = optim.Adam(param_list_value, lr=initial_lr, eps=args.adam_epsilon)
hard_update(g_target_model, g_model)

# Truthiness test instead of `is not ''` (identity comparison against a str literal).
if args.models and os.path.exists(args.models):
  print('Loading')
  # Keep tensors on their saved device when running on CUDA; otherwise remap them to the CPU.
  # NOTE: torch.load unpickles the file, so only load checkpoints from trusted sources.
  model_dicts = torch.load(args.models, map_location=None if args.device.type == 'cuda' else torch.device('cpu'))
  for key, component in (('transition_model', transition_model),
                         ('observation_model', observation_model),
                         ('reward_model', reward_model),
                         ('action_decoder', action_decoder),
                         ('policy', policy),
                         ('encoder', encoder),
                         ('optimiser', optimiser),
                         ('optimiser_agent', optimiser_agent)):
    component.load_state_dict(model_dicts[key])
  # g_model, g_target_model and optimiser_value are not restored from the checkpoint
  # (their 'g_model' / 'g_target_model' / 'optimiser_value' entries are ignored).
free_nats = torch.full((1, ), args.free_nats, device=args.device)  # Allowed deviation in KL divergence


def update_belief_and_act(args, env, transition_model, policy, encoder, belief, posterior_state, action, observation, t, test):
  """Filter the latent state with the newest observation, choose an action and step the env.

  On the first step of an episode (``t == 0``) the belief is burned in by feeding the same
  (previous action, current observation) pair ``args.burn_in_size - 1`` times before the
  regular one-step update.

  Args:
    args: Parsed options; reads ``burn_in_size``, ``gaussian_policy`` and ``action_noise``.
    env: Environment (or ``EnvBatcher``) whose ``step`` is called once.
    transition_model: Latent dynamics model used as a posterior filter.
    policy: Policy posterior that maps (belief, state) to an action.
    encoder: Observation encoder.
    belief, posterior_state: Latent state from the previous step.
    action: Action taken on the previous step.
    observation: Current observation, already on ``args.device``.
    t: Step index within the episode; ``0`` triggers burn-in.
    test: If true, use ``policy.sample``'s third output (presumably the deterministic/mean
      action — confirm in models.GaussianPolicy) and add no exploration noise.

  Returns:
    ``(belief, posterior_state, action, next_observation, reward, done)``.
  """
  # Encode once; the burn-in and the regular update consume the same embedding.
  embedded_obs = encoder(observation)
  # Burn-in needs at least one step; with burn_in_size <= 1 there is nothing to burn in and
  # indexing [-1] on an empty rollout would fail.
  if t == 0 and args.burn_in_size > 1:
    burn_in_steps = args.burn_in_size - 1
    beliefs, _, _, _, posterior_states, _, _ = transition_model(
      posterior_state, belief, burn_in_steps,
      action.unsqueeze(dim=0).expand(burn_in_steps, *action.size()),
      embedded_obs.unsqueeze(dim=0).expand(burn_in_steps, *embedded_obs.size()))
    posterior_state, belief = posterior_states[-1], beliefs[-1]
  # One posterior step; action and observation need an extra leading time dimension.
  belief, _, _, _, posterior_state, _, _ = transition_model(posterior_state, belief, 1, action.unsqueeze(dim=0), embedded_obs.unsqueeze(dim=0))
  belief, posterior_state = belief.squeeze(dim=0), posterior_state.squeeze(dim=0)  # Remove time dimension from belief/state
  if args.gaussian_policy:
    action, _, action_test, _, _ = policy.sample(belief, posterior_state)
    if test:
      action = action_test
  else:
    action = policy(belief, posterior_state)
  if not test and args.action_noise > 0:
    action = action + args.action_noise * torch.randn_like(action)  # Add exploration noise ε ~ p(ε) to the action
  # A batched env takes the whole action batch; a single env takes the first (only) row.
  next_observation, reward, done = env.step(action.cpu() if isinstance(env, EnvBatcher) else action[0].cpu())  # Perform environment step (action repeats handled internally)
  return belief, posterior_state, action, next_observation, reward, done


# Initialize agent experience replay
D_agent = SeqExperienceReplay(args.experience_size, args.symbolic_env, env.observation_size, env.action_size, args.bit_depth, args.device, episode_length, args.burn_in_size, args.on_policy_episodes, args.discount)
print('Initialize agent experience replay with {} episodes'.format(args.seed_episodes))
# Seed episodes are collected by the (test-mode) policy when a checkpoint path was given and
# exists, and by uniformly random actions otherwise. Decided once, not on every step.
# Truthiness test instead of `is not ''` (identity comparison against a str literal).
use_loaded_policy = bool(args.models) and os.path.exists(args.models)
episode_length = args.max_episode_length // args.action_repeat
with torch.no_grad():
  for s in range(args.seed_episodes):
    observation, total_reward, done = env.reset(), 0, False
    belief, posterior_state, action = torch.zeros(1, args.belief_size, device=args.device), torch.zeros(1, args.state_size, device=args.device), torch.zeros(1, env.action_size, device=args.device)
    for t in range(episode_length):
      if use_loaded_policy:
        belief, posterior_state, action, next_observation, reward, done = update_belief_and_act(args, env, transition_model, policy, encoder, belief, posterior_state, action, observation.to(device=args.device), t, test=True)
      else:
        action = env.sample_random_action()
        next_observation, reward, done = env.step(action)
      # Store the observation that the action was taken in, not the next one.
      D_agent.append(observation, action.cpu(), reward, done)
      total_reward += reward
      observation = next_observation
      if done:
        print(total_reward)
        break


# Testing only
if args.test:
  # Set models to eval mode
  for net in (transition_model, observation_model, reward_model, action_decoder, policy, encoder, g_model, g_target_model):
    net.eval()
  with torch.no_grad():
    total_reward = 0  # summed over all test episodes; averaged when printed below
    for episode in tqdm(range(args.test_episodes)):
      observation = env.reset()
      belief, posterior_state, action = torch.zeros(1, args.belief_size, device=args.device), torch.zeros(1, args.state_size, device=args.device), torch.zeros(1, env.action_size, device=args.device)
      episode_length = args.max_episode_length // args.action_repeat
      video_frames = []
      episode_str = str(episode).zfill(len(str(args.test_episodes)))
      if args.prior_check:
        # Burn in on the first observation only, then roll the model forward open-loop with
        # the policy's actions and decode each predicted state into a video frame.
        for t in range(episode_length):
          if t == 0:  # for burn-in
            obs = encoder(observation.to(device=args.device))
            beliefs, _, _, _, posterior_states, _, _ = transition_model(posterior_state, belief, args.burn_in_size - 1, action.unsqueeze(dim=0).expand(args.burn_in_size - 1, *action.size()), obs.unsqueeze(dim=0).expand(args.burn_in_size - 1, *obs.size()))
            prior_state, belief = posterior_states[-1], beliefs[-1]
          if args.gaussian_policy:
            _, _, q_action, _, _ = policy.sample(belief, prior_state)
          else:
            q_action = policy(belief, prior_state)
          belief, prior_state, _, _ = transition_model(prior_state, belief, 1, q_action.unsqueeze(dim=0))
          belief, prior_state = belief.squeeze(dim=0), prior_state.squeeze(dim=0)
          # The decoded prior action is not used below; the call is kept so that any random
          # draws made by action_decoder.sample stay identical.
          if args.gaussian_policy:
            _, _, action, _, _ = action_decoder.sample(belief, prior_state)
          else:
            action = action_decoder(belief, prior_state)
          video_frames.append(make_grid(observation_model(belief, prior_state).cpu() + 0.5, nrow=5).numpy())  # Decentre
        write_video(video_frames, 'test_episode_%s' % episode_str, results_dir)  # Lossy compression
      elif args.test_run:
        for t in range(episode_length):
          belief, posterior_state, action, next_observation, reward, done = update_belief_and_act(args, env, transition_model, policy, encoder, belief, posterior_state, action, observation.to(device=args.device), t, test=True)
          if args.observation_noise > 0:
            next_observation = next_observation + args.observation_noise * torch.randn_like(next_observation)
          total_reward += reward
          # Side-by-side frame: real observation | reconstruction from the posterior state.
          video_frames.append(make_grid(torch.cat([observation, observation_model(belief, posterior_state).cpu()], dim=3) + 0.5, nrow=5).numpy())  # Decentre
          observation = next_observation
          if done:
            break
        write_video(video_frames, 'test_episode_%s' % episode_str, results_dir)  # Lossy compression
      else:
        # for experience checkpoint: only allocate and fill buffers when they will be saved.
        record = args.checkpoint_experience
        if record:
          # Zero-initialised (not np.empty) so no uninitialised memory can ever reach disk.
          # Assumes `observation` carries a leading batch dimension of size 1 (the test-run
          # branch concatenates observations along dim=3) — TODO confirm for symbolic envs.
          tmp_obs = np.zeros((episode_length,) + tuple(observation.shape[1:]), dtype=np.float32)
          tmp_action = np.zeros((episode_length, env.action_size), dtype=np.float32)
          tmp_reward = np.zeros((episode_length,), dtype=np.float32)
          tmp_nonterminals = np.zeros((episode_length, 1), dtype=np.float32)
        steps_taken = 0

        for t in range(episode_length):
          belief, posterior_state, action, next_observation, reward, done = update_belief_and_act(args, env, transition_model, policy, encoder, belief, posterior_state, action, observation.to(device=args.device), t, test=True)
          total_reward += reward
          if record:
            tmp_obs[t] = postprocess_observation(observation.numpy(), args.bit_depth)
            tmp_action[t] = action.cpu().numpy()
            tmp_reward[t] = reward
            tmp_nonterminals[t] = not done
          steps_taken = t + 1

          observation = next_observation
          if args.render:
            env.render()
          if done:
            break
        if record:
          # Save only the steps actually taken: the expert loader uses action.shape[0] as the
          # episode length, so padding rows would be read back as real transitions.
          np.savez(os.path.join(results_dir, 'experience{}'.format(episode)),
                   obs=tmp_obs[:steps_taken], action=tmp_action[:steps_taken],
                   reward=tmp_reward[:steps_taken], nonterminals=tmp_nonterminals[:steps_taken])

  print('Average Reward:', total_reward / args.test_episodes)
  env.close()
  # `quit()` is a site-module convenience for the interactive shell; raise SystemExit directly.
  raise SystemExit

# Training (and testing)
for episode in tqdm(range(1, args.episodes + 1), total=args.episodes, initial=1):
  # Learning
  losses = []
  for s in range(args.collect_interval):
    # Start Imitation Learning
    if episode > args.start_expert:
      # Draw expert sequence chunks {(o_t, a_t, r_t, terminal_t)} ~ D uniformly at random from the expert dataset (including terminal flags)
      observations, actions, rewards, nonterminals, _ = D_expert.sample(args.batch_size, args.chunk_size)  # Transitions start at time t = 0
      
      # Create initial belief and state for time t = 0
      init_belief_b, init_state_b = torch.zeros(args.batch_size, args.belief_size, device=args.device), torch.zeros(args.batch_size, args.state_size, device=args.device)
      # burn in period
      beliefs_b, _, _, _, posterior_states_b, _, _ = transition_model(init_state_b, init_belief_b, actions[:args.burn_in_size-1].size(0), actions[:args.burn_in_size-1], bottle(encoder, (observations[1:args.burn_in_size], )), nonterminals[:args.burn_in_size-1])
      init_belief = beliefs_b[-1].detach()
      init_state = posterior_states_b[-1].detach()
      # Update belief/state using posterior (over entire sequence at once)
      beliefs, prior_states, prior_means, prior_std_devs, posterior_states, posterior_means, posterior_std_devs = transition_model(init_state, init_belief, actions[args.burn_in_size-1:-1].size(0), actions[args.burn_in_size-1:-1], bottle(encoder, (observations[args.burn_in_size:], )), nonterminals[args.burn_in_size-1:-1])

      # Calculate losses related to F; sum over final dims, average over batch and time
      x_tuple = (beliefs, posterior_states)
      if args.gaussian_policy:
        generated_actions, _, _, generated_actions_mean, generated_actions_std = bottle_action(action_decoder.sample, x_tuple)
      else:
        generated_actions = bottle(action_decoder, x_tuple)
      action_decoder_loss = args.action_scale * F.mse_loss(generated_actions, actions[args.burn_in_size:], reduction='none').sum(dim=2).mean(dim=(0, 1))
      observation_loss = F.mse_loss(bottle(observation_model, (beliefs, posterior_states)), observations[args.burn_in_size:], reduction='none').sum(dim=2 if args.symbolic_env else (2, 3, 4)).mean(dim=(0, 1))
      # reward_loss = args.reward_scale * F.mse_loss(bottle(reward_model, (beliefs, posterior_states)), rewards[args.burn_in_size:], reduction='none').mean(dim=(0, 1))
      reward_loss = torch.zeros(1, device=args.device)  # Disable reward model learning for expert data
      kl_loss = torch.max(kl_divergence(Normal(posterior_means, posterior_std_devs), Normal(prior_means, prior_std_devs)).sum(dim=2), free_nats).mean(dim=(0, 1))  # Note that normalisation by overshooting distance and weighting by overshooting distance cancel out

      # Expected Free Energy G calculation
      if args.gaussian_policy:
        q_actions, _, _, q_actions_m, q_actions_s = bottle_action(policy.sample, x_tuple)  # a_t
      else:
        q_actions = bottle(policy, x_tuple)
      beliefs2, prior_states2, _, _ = bottle_transition(transition_model, q_actions, x_tuple)  #s_t -> s_t+1
      x_tuple = (beliefs2, prior_states2)
      if args.gaussian_policy:
        generated_actions2, _, _, generated_actions_m2, generated_actions_s2 = bottle_action(action_decoder.sample, x_tuple)
        q_actions2, _, _, q_actions_m2, q_actions_s2 = bottle_action(policy.sample, x_tuple)  # a_t+1
      else:
        generated_actions2 = bottle(action_decoder, x_tuple)
        q_actions2 = bottle(policy, x_tuple)

      # Calculate losses related to G
      expected_obs_loss = torch.zeros(1, device=args.device)  # equivalent to entropy of observation decoder. When obs likelihood is deterministic, H is 0.
      if args.gaussian_policy:
        expected_kl_loss = args.policy_scale * kl_divergence(Normal(q_actions_m2, q_actions_s2), Normal(generated_actions_m2, generated_actions_s2)).sum(dim=2).mean(dim=(0,1))
      else:
        expected_kl_loss = args.policy_scale * F.mse_loss(q_actions2, generated_actions2, reduction='none').sum(dim=2).mean(dim=(0,1))

    # Start Reinforcement Learning
    if episode > args.start_agent:
      # Draw agent sequence chunks {(o_t, a_t, r_t, terminal_t)} ~ D uniformly at random from the agent dataset (including terminal flags)
      if args.on_policy_episodes > 1:
        observations_a, actions_a, rewards_a, nonterminals_a, returns_a = D_agent.sample_recent(args.batch_size, args.chunk_size, episode)  # Transitions start at time t = 0
      else:
        observations_a, actions_a, rewards_a, nonterminals_a, returns_a = D_agent.sample(args.batch_size, args.chunk_size)
      # Create initial belief and state for time t = 0
      init_belief_ab, init_state_ab = torch.zeros(args.batch_size, args.belief_size, device=args.device), torch.zeros(args.batch_size, args.state_size, device=args.device)
      # burn in period
      beliefs_ab, _, _, _, posterior_states_ab, _, _ = transition_model(init_state_ab, init_belief_ab, actions_a[:args.burn_in_size-1].size(0), actions_a[:args.burn_in_size-1], bottle(encoder, (observations_a[1:args.burn_in_size], )), nonterminals_a[:args.burn_in_size-1])
      init_belief_a = beliefs_ab[-1].detach()
      init_state_a = posterior_states_ab[-1].detach()
      # Update belief/state using posterior (over entire sequence at once)
      beliefs_a, prior_states_a, prior_means_a, prior_std_devs_a, posterior_states_a, posterior_means_a, posterior_std_devs_a = transition_model(init_state_a, init_belief_a, actions_a[args.burn_in_size-1:-1].size(0), actions_a[args.burn_in_size-1:-1], bottle(encoder, (observations_a[args.burn_in_size:], )), nonterminals_a[args.burn_in_size-1:-1])

      # Calculate losses related to F; sum over final dims, average over batch and time
      x_tuple = (beliefs_a, posterior_states_a)
      if args.gaussian_policy:
        generated_actions_a, _, _, generated_actions_a_mean, generated_actions_a_std = bottle_action(action_decoder.sample, x_tuple)
      else:
        generated_actions_a = bottle(action_decoder, x_tuple)
      # action_decoder_loss_a = F.mse_loss(generated_actions_a, actions_a[args.burn_in_size:], reduction='none').sum(dim=2).mean(dim=(0, 1))
      action_decoder_loss_a = torch.zeros(1, device=args.device)  # Disable action decoder learning for agent data
      observation_loss_a = F.mse_loss(bottle(observation_model, (beliefs_a, posterior_states_a)), observations_a[args.burn_in_size:], reduction='none').sum(dim=2 if args.symbolic_env else (2, 3, 4)).mean(dim=(0, 1))
      reward_loss_a = args.reward_scale * F.mse_loss(bottle(reward_model, (beliefs_a, posterior_states_a)), rewards_a[args.burn_in_size:], reduction='none').mean(dim=(0, 1))
      kl_loss_a = torch.max(kl_divergence(Normal(posterior_means_a, posterior_std_devs_a), Normal(prior_means_a, prior_std_devs_a)).sum(dim=2), free_nats).mean(dim=(0, 1))  # Note that normalisation by overshooting distance and weighting by overshooting distance cancel out

      # G calculation
      x_tuple = (beliefs_a.detach(), posterior_states_a.detach())
      G = [torch.empty(0)] * args.imagination_horizon
      v = [torch.empty(0)] * args.imagination_horizon
      expected_FE = torch.zeros((args.chunk_size, args.batch_size), device=args.device)
      for i in range(args.imagination_horizon):
        if args.gaussian_policy:
          q_actions_a, _, _, q_actions_am, q_actions_as = bottle_action(policy.sample, x_tuple)  # a_t
        else:
          q_actions_a = bottle(policy, x_tuple)
        beliefs_a2, _, prior_states_am2, prior_states_as2 = bottle_transition(transition_model, q_actions_a, x_tuple)  # q(s_t) -> q(s_t+1) : prior
        generated_observations_a2 = bottle(observation_model, (beliefs_a2, prior_states_am2))  # q(o_t+1)
        _, prior_states_a2, _, _, posterior_states_a2, posterior_states_am2, posterior_states_as2 = bottle_transition(transition_model, q_actions_a, x_tuple, generated_observations_a2, encoder)  # q(s_t+1|o_t+1) : posterior
        x_tuple = (beliefs_a2, prior_states_a2)
        if args.gaussian_policy:
          generated_actions_a2, _, _, generated_actions_am2, generated_actions_as2 = bottle_action(action_decoder.sample, x_tuple)
          q_actions_a2, _, _, q_actions_am2, q_actions_as2 = bottle_action(policy.sample, x_tuple)  # a_t+1
        else:
          generated_actions_a2 = bottle(action_decoder, x_tuple)
          q_actions_a2 = bottle(policy, x_tuple)

        expected_rewards = args.reward_scale * bottle(reward_model, (beliefs_a2, prior_states_a2))
        epistemic_values = - kl_divergence(Normal(posterior_states_am2, posterior_states_as2), Normal(prior_states_am2, prior_states_as2)).sum(dim=2)
        if args.gaussian_policy:
          expected_kl_loss_a = torch.max(kl_divergence(Normal(q_actions_am2, q_actions_as2), Normal(generated_actions_am2.detach(), generated_actions_as2.detach())).sum(dim=2), free_nats * (env.action_size/args.state_size))
        else:
          expected_kl_loss_a = torch.max(F.mse_loss(q_actions_a2, generated_actions_a2.detach(), reduction='none').sum(dim=2), free_nats * (env.action_size/args.state_size))
        if i == 0:
          G[i] = - expected_rewards + epistemic_values + expected_kl_loss_a
        else:
          G[i] = G[i-1] + np.power(args.discount, i) * (- expected_rewards + epistemic_values + expected_kl_loss_a)
        v[i] = np.power(args.lambda_value, i) * (G[i] + np.power(args.discount, i+1) * bottle(g_target_model, (prior_states_a2, )))
        if i == args.imagination_horizon-1:
          expected_FE = expected_FE + v[i]
        else:
          expected_FE = expected_FE + (1 - args.lambda_value) * v[i]
      g_loss = F.mse_loss(bottle(g_model, (posterior_states_a.detach(), )), expected_FE.detach(), reduction='none').mean(dim=(0, 1))
      expected_FE = expected_FE.mean(dim=(0,1))
      # expected_FE = 10 * expected_FE.mean(dim=(0, 1))

    # Apply linearly ramping learning rate schedule. Default: disabled.
    if args.learning_rate_schedule != 0:
      for group in optimiser.param_groups:
        group['lr'] = min(group['lr'] + args.learning_rate / args.learning_rate_schedule, args.learning_rate)
      for group in optimiser_agent.param_groups:
        group['lr'] = min(group['lr'] + args.learning_rate / args.learning_rate_schedule, args.learning_rate)
      for group in optimiser_value.param_groups:
        group['lr'] = min(group['lr'] + args.learning_rate / args.learning_rate_schedule, args.learning_rate)

    if episode > args.value_period:
      # Update model parameters
      optimiser.zero_grad()
      if episode > args.start_expert:
        (observation_loss + reward_loss + action_decoder_loss + kl_loss).backward(retain_graph=True)
      if episode > args.start_agent:
        (observation_loss_a + reward_loss_a + action_decoder_loss_a + kl_loss_a).backward(retain_graph=True)
      nn.utils.clip_grad_norm_(param_list, args.grad_clip_norm, norm_type=2)
      optimiser.step()

      optimiser_agent.zero_grad()
      if episode > args.start_expert:
        # (expected_obs_loss + expected_kl_loss).backward(retain_graph=True)
        (expected_obs_loss + expected_kl_loss).backward()
      if episode > args.start_agent:
        expected_FE.backward(retain_graph=True)
        # expected_FE.backward()
      nn.utils.clip_grad_norm_(param_list_agent, args.grad_clip_norm, norm_type=2)
      optimiser_agent.step()

    optimiser_value.zero_grad()
    if episode > args.start_agent:
      g_loss.backward()
    nn.utils.clip_grad_norm_(param_list_value, args.grad_clip_norm, norm_type=2)
    optimiser_value.step()
    soft_update(g_target_model, g_model, args.tau)

    # Store losses
    loss_tmp = []
    if episode > args.start_expert:
      loss_tmp += [observation_loss.item(), reward_loss.item(), action_decoder_loss.item(), kl_loss.item(), expected_obs_loss.item(), expected_kl_loss.item()]
    else:
      loss_tmp += [0, 0, 0, 0, 0, 0]
    if episode > args.start_agent:
      loss_tmp += [observation_loss_a.item(), reward_loss_a.item(), action_decoder_loss_a.item(), kl_loss_a.item(), g_loss.item(),
                   expected_FE.item()]
    else:
      loss_tmp += [0, 0, 0, 0, 0, 0]
    losses.append(loss_tmp)

  # Plot loss metrics
  # Transpose the per-iteration 12-value rows into one tuple per loss column,
  # then append each column to its metric history for this episode.
  losses = tuple(zip(*losses))
  expert_loss_keys = ('observation_loss', 'reward_loss', 'action_decoder_loss',
                      'kl_loss', 'expected_obs_loss', 'expected_kl_loss')
  agent_loss_keys = ('observation_loss_a', 'reward_loss_a', 'action_decoder_loss_a',
                     'kl_loss_a', 'g_loss', 'expected_FE')
  for loss_column, loss_key in enumerate(expert_loss_keys + agent_loss_keys):
    metrics[loss_key].append(losses[loss_column])
  metrics['episodes'].append(episode)

  # Plot only the curves of phases that have started. Episode indices are
  # right-aligned to the length of each metric's history.
  plotted_loss_keys = ()
  if episode > args.start_expert:
    plotted_loss_keys += expert_loss_keys
  if episode > args.start_agent:
    plotted_loss_keys += agent_loss_keys
  for loss_key in plotted_loss_keys:
    lineplot(metrics['episodes'][-len(metrics[loss_key]):], metrics[loss_key], loss_key, results_dir)


  # Active data collection
  # Roll out the current policy in the training env for
  # args.on_policy_episodes episodes without tracking gradients, appending
  # every transition to D_agent and recording each episode's total reward.
  with torch.no_grad():
    total_rewards = []
    for s in range(args.on_policy_episodes):
      observation, total_reward, done = env.reset(), 0, False
      # Corrupt the reset frame with the same observation noise as every later
      # frame; otherwise the first observation of each episode is always clean.
      if args.observation_noise > 0:
        observation = observation + args.observation_noise * torch.randn_like(observation)
      # Belief, posterior state and previous action start as zeros each episode
      belief, posterior_state, action = torch.zeros(1, args.belief_size, device=args.device), torch.zeros(1, args.state_size, device=args.device), torch.zeros(1, env.action_size, device=args.device)
      episode_length = args.max_episode_length // args.action_repeat
      for t in range(episode_length):
        belief, posterior_state, action, next_observation, reward, done = update_belief_and_act(args, env, transition_model, policy, encoder, belief, posterior_state, action, observation.to(device=args.device), t, test=False)
        if args.observation_noise > 0:
          next_observation = next_observation + args.observation_noise * torch.randn_like(next_observation)
        # Store the pre-step observation with the returned action, reward and done flag
        D_agent.append(observation, action.cpu(), reward, done)
        total_reward += reward
        observation = next_observation
        if args.render:
          env.render()
        if done:
          break
      total_rewards.append(total_reward)
    # Update and plot train reward metrics
    metrics['train_rewards'].append(total_rewards)
    lineplot(metrics['episodes'][-len(metrics['train_rewards']):], metrics['train_rewards'], 'train_rewards', results_dir)


  # Test model
  if episode % args.test_interval == 0:
    # Every module whose train/eval mode is switched around evaluation
    toggled_models = (transition_model, observation_model, reward_model, action_decoder,
                      policy, encoder, g_model, g_target_model)
    # Set models to eval mode
    for toggled_model in toggled_models:
      toggled_model.eval()
    # Initialise parallelised test environments
    test_envs = EnvBatcher(Env, (args.env, args.symbolic_env, args.seed, args.max_episode_length, args.action_repeat, args.bit_depth, args.sparse_env), {}, args.test_episodes)
    try:
      with torch.no_grad():
        observation, total_rewards, video_frames = test_envs.reset(), np.zeros((args.test_episodes, )), []
        # Noise the reset observations too, matching the per-step noise below
        if args.observation_noise > 0:
          observation = observation + args.observation_noise * torch.randn_like(observation)
        belief, posterior_state, action = torch.zeros(args.test_episodes, args.belief_size, device=args.device), torch.zeros(args.test_episodes, args.state_size, device=args.device), torch.zeros(args.test_episodes, env.action_size, device=args.device)
        episode_length = args.max_episode_length // args.action_repeat
        for t in range(episode_length):
          belief, posterior_state, action, next_observation, reward, done = update_belief_and_act(args, test_envs, transition_model, policy, encoder, belief, posterior_state, action, observation.to(device=args.device), t, test=True)
          if args.observation_noise > 0:
            next_observation = next_observation + args.observation_noise * torch.randn_like(next_observation)
          total_rewards += reward.numpy()
          if not args.symbolic_env:  # Collect real at t vs. predicted frames at t+1 for video
            video_frames.append(make_grid(torch.cat([observation, observation_model(belief, posterior_state).cpu()], dim=3) + 0.5, nrow=5).numpy())  # Decentre
          observation = next_observation
          # Stop once every parallel test episode has finished
          if done.sum().item() == args.test_episodes:
            break

      # Update and plot reward metrics (and write video if applicable) and save metrics
      metrics['test_episodes'].append(episode)
      metrics['test_rewards'].append(total_rewards.tolist())
      lineplot(metrics['test_episodes'], metrics['test_rewards'], 'test_rewards', results_dir)
      # Skip video output when no frames were collected (zero-length episode),
      # which would otherwise fail on video_frames[-1]
      if not args.symbolic_env and video_frames:
        episode_str = str(episode).zfill(len(str(args.episodes)))
        write_video(video_frames, 'test_episode_%s' % episode_str, results_dir)  # Lossy compression
        save_image(torch.as_tensor(video_frames[-1]), os.path.join(results_dir, 'test_episode_%s.png' % episode_str))
      torch.save(metrics, os.path.join(results_dir, 'metrics.pth'))
    finally:
      # Restore train mode and close the test environments even if evaluation raised
      for toggled_model in toggled_models:
        toggled_model.train()
      test_envs.close()


  # Checkpoint models
  if episode % args.checkpoint_interval == 0:
    # Collect the state_dict of every network and optimiser under a fixed key
    # and write them together to models_<episode>.pth in results_dir.
    checkpoint_state = {
      'transition_model': transition_model.state_dict(),
      'observation_model': observation_model.state_dict(),
      'reward_model': reward_model.state_dict(),
      'action_decoder': action_decoder.state_dict(),
      'policy': policy.state_dict(),
      'encoder': encoder.state_dict(),
      'g_model': g_model.state_dict(),
      'g_target_model': g_target_model.state_dict(),
      'optimiser': optimiser.state_dict(),
      'optimiser_agent': optimiser_agent.state_dict(),
      'optimiser_value': optimiser_value.state_dict(),
    }
    checkpoint_path = os.path.join(results_dir, 'models_%d.pth' % episode)
    torch.save(checkpoint_state, checkpoint_path)

# Training loop finished: close the training environment
env.close()
