# -*- coding: utf-8 -*-
import argparse
import os
import glob
import numpy as np
import torch
from torch import jit
from torch import nn, optim
from torch.distributions import Normal
from torch.distributions.kl import kl_divergence
from torch.nn import functional as F
from torchvision.utils import make_grid, save_image
from tqdm import tqdm
from env import CONTROL_SUITE_ENVS, Env, GYM_ENVS, EnvBatcher
from memory import SeqExperienceReplay
from utils import lineplot, write_video
from env import postprocess_observation

# Hyperparameters
# Command-line interface for the comparison (behaviour cloning) experiments.
parser = argparse.ArgumentParser(description='Comparison experiments')
parser.add_argument('--id', type=str, default='comparison', help='Experiment ID')
parser.add_argument('--results-dir', type=str, default='/mnt/ISINAS1/ogishima/proj_icml_exp/results', help='Results Directory. Default: /mnt/ISINAS1/ogishima/proj_icml_exp/results')
parser.add_argument('--seed', type=int, default=1, metavar='S', help='Random seed')
parser.add_argument('--disable-cuda', action='store_true', help='Disable CUDA')
parser.add_argument('--env', type=str, default='cheetah-run', choices=GYM_ENVS + CONTROL_SUITE_ENVS, help='Gym/Control Suite environment')
parser.add_argument('--symbolic-env', action='store_true', help='Symbolic features')
parser.add_argument('--observation-noise', type=float, default=0, metavar='εo', help='Observation noise eg. 0.1 or 0.2')
parser.add_argument('--sparse-env', action='store_true', help='Sparse environment')
# imitation learning
parser.add_argument('--number-of-experts', type=int, default=10000, metavar='E', help='Number of experts')
# network architecture
parser.add_argument('--activation-function', type=str, default='relu', choices=dir(F), help='Model activation function')
parser.add_argument('--embedding-size', type=int, default=1024, metavar='E', help='Observation embedding size')  # Note that the default encoder for visual observations outputs a 1024D vector; for other embedding sizes an additional fully-connected layer is used
parser.add_argument('--hidden-size', type=int, default=200, metavar='H', help='Hidden size')
parser.add_argument('--belief-size', type=int, default=200, metavar='H', help='Belief/hidden size')
parser.add_argument('--state-size', type=int, default=30, metavar='Z', help='State/latent size')
# mujoco env
parser.add_argument('--max-episode-length', type=int, default=1000, metavar='T', help='Max episode length')
parser.add_argument('--action-repeat', type=int, default=4, metavar='R', help='Action repeat')
# training phase
parser.add_argument('--experience-size', type=int, default=1500, metavar='D', help='Experience replay maximum size for seeds + active collection (number of episodes)')
parser.add_argument('--episodes', type=int, default=1000, metavar='E', help='Total number of episode for training')
parser.add_argument('--seed-episodes', type=int, default=10, metavar='S', help='Seed episodes')
parser.add_argument('--collect-interval', type=int, default=100, metavar='C', help='Collect interval')
parser.add_argument('--batch-size', type=int, default=50, metavar='B', help='Batch size')
parser.add_argument('--chunk-size', type=int, default=50, metavar='L', help='Chunk size')
parser.add_argument('--burn-in-size', type=int, default=0, metavar='Burn', help='Burn in size')
parser.add_argument('--on-policy-episodes', type=int, default=1, metavar='e', help='Number of episodes to calculate policy loss (1 to use off policy data)')
# learning
parser.add_argument('--learning-rate', type=float, default=1e-3, metavar='α', help='Learning rate')
parser.add_argument('--learning-rate-schedule', type=int, default=0, metavar='αS', help='Linear learning rate schedule (optimisation steps from 0 to final learning rate; 0 to disable)')
parser.add_argument('--adam-epsilon', type=float, default=1e-4, metavar='ε', help='Adam optimiser epsilon value')
parser.add_argument('--grad-clip-norm', type=float, default=1000, metavar='C', help='Gradient clipping norm')
parser.add_argument('--a-loss-coef', type=float, default=1, metavar='AL', help='action loss coefficient')
# test phase
parser.add_argument('--test', action='store_true', help='Test only')
parser.add_argument('--test-interval', type=int, default=25, metavar='I', help='Test interval (episodes)')
parser.add_argument('--test-episodes', type=int, default=10, metavar='E', help='Number of test episodes')
parser.add_argument('--checkpoint-interval', type=int, default=50, metavar='I', help='Checkpoint interval (episodes)')
parser.add_argument('--checkpoint-experience', action='store_true', help='Checkpoint experience of expert')  # to save experts
# loading
parser.add_argument('--models', type=str, default='', metavar='M', help='Load model checkpoint')
parser.add_argument('--experience-replay', type=str, default='', metavar='ER', help='Load experience replay of experts')
# misc
parser.add_argument('--bit-depth', type=int, default=5, metavar='B', help='Image bit depth (quantisation)')
parser.add_argument('--render', action='store_true', help='Render environment')
# Only policies 0, 1 and 2 are implemented below; reject anything else at parse time
# instead of failing later with an undefined `policy`.
parser.add_argument('--policy-number', type=int, default=0, choices=[0, 1, 2], metavar='p', help='feedforward:0, recurrent policy:1, decoder:2')

args = parser.parse_args()
# Echo the resolved options so each run's configuration is visible in its log
print(' ' * 26 + 'Options')
for k, v in vars(args).items():
  print(' ' * 26 + k + ': ' + str(v))

# Setup
# Results are written to <results-dir>/<id>/<env>; a non-empty directory is
# treated as an earlier run and is never overwritten.
results_dir = os.path.join(args.results_dir, args.id, args.env)
if os.path.exists(results_dir) and os.listdir(results_dir):
  raise FileExistsError(results_dir)
os.makedirs(results_dir, exist_ok=True)

# Seed NumPy and PyTorch, then pick the compute device
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if not torch.cuda.is_available() or args.disable_cuda:
  args.device = torch.device('cpu')
else:
  args.device = torch.device('cuda')
  torch.cuda.manual_seed(args.seed)

# Metric histories; the decoder policy (2) additionally tracks its action and
# observation loss terms.
metrics = {name: [] for name in ('steps', 'episodes', 'train_rewards', 'test_episodes', 'test_rewards', 'loss')}
if args.policy_number == 2:
  metrics['a_loss'] = []
  metrics['o_loss'] = []


# Initialise training environment and expert experience replay
# The path is compared by value: `is not ''` tests object identity, which is not
# guaranteed for equal strings and raises a SyntaxWarning on Python 3.8+.
use_expert_data = args.experience_replay != '' and os.path.exists(args.experience_replay)
if not use_expert_data and not args.test:
  # Training samples every batch from the expert buffer, so fail early with a
  # clear message rather than with an undefined `D_expert` later on.
  raise FileNotFoundError('--experience-replay must be an existing path containing expert episodes (*.npz) for training; got {!r}'.format(args.experience_replay))

env = Env(args.env, args.symbolic_env, args.seed, args.max_episode_length, args.action_repeat, args.bit_depth, args.sparse_env)
# Number of environment steps per episode once action repeat is accounted for
episode_length = args.max_episode_length // args.action_repeat
if use_expert_data:
  # imitation learning
  expert_list = glob.glob(os.path.join(args.experience_replay, '*.npz'))
  # Keep at most --number-of-experts episode files
  expert_list = expert_list[:args.number_of_experts]
  print('Use {} experts for this experiment'.format(len(expert_list)))
  D_expert = SeqExperienceReplay(len(expert_list), args.symbolic_env, env.observation_size, env.action_size, args.bit_depth, args.device, episode_length, args.burn_in_size, args.on_policy_episodes)
  for i, expert_path in enumerate(expert_list):
    if i % 100 == 0:
      print('Loading experts{}'.format(i))
    npepisode = np.load(expert_path)
    # The episode's length is the number of recorded actions
    npepisode_length = npepisode['action'].shape[0]
    D_expert.append_from_expert(npepisode, npepisode_length)

def bottle(f, x_tuple):
  """Apply `f` to tensors whose first two dimensions have been merged.

  Every tensor in `x_tuple` is viewed with its first two dimensions flattened
  into one, `f` is called with those views as positional arguments, and the
  leading dimension of the single tensor it returns is split back into the
  first two dimensions of the first input.
  """
  merged = []
  for tensor in x_tuple:
    shape = tensor.size()
    merged.append(tensor.view(shape[0] * shape[1], *shape[2:]))
  output = f(*merged)
  lead = x_tuple[0].size()
  return output.view(lead[0], lead[1], *output.size()[1:])

def bottle_obs(f, x_tuple):
    """Variant of `bottle` for a function that returns an (action, next_obs) pair.

    Every tensor in `x_tuple` is viewed with its first two dimensions flattened
    into one before `f` is called; both returned tensors then have their
    leading dimension split back into the first two dimensions of the first
    input.
    """
    merged = []
    for tensor in x_tuple:
        shape = tensor.size()
        merged.append(tensor.view(shape[0] * shape[1], *shape[2:]))
    action, next_obs = f(*merged)
    lead = x_tuple[0].size()
    restored_action = action.view(lead[0], lead[1], *action.size()[1:])
    restored_obs = next_obs.view(lead[0], lead[1], *next_obs.size()[1:])
    return restored_action, restored_obs

class Encoder(nn.Module):
    """Convolutional encoder mapping 3-channel images to embedding vectors.

    Four stride-2 convolutions with ReLU activations are followed by a flatten
    to 1024 features per row; a linear projection is appended only when
    `embedding_size` differs from 1024.
    """

    def __init__(self, embedding_size):
        super().__init__()
        self.embedding_size = embedding_size
        self.conv1 = nn.Conv2d(3, 32, 4, stride=2)
        self.conv2 = nn.Conv2d(32, 64, 4, stride=2)
        self.conv3 = nn.Conv2d(64, 128, 4, stride=2)
        self.conv4 = nn.Conv2d(128, 256, 4, stride=2)
        if embedding_size == 1024:
            # The flattened conv features already have the requested width
            self.fc = nn.Identity()
        else:
            self.fc = nn.Linear(1024, embedding_size)

    def forward(self, observation):
        """Return the embedding of `observation`, one row per 1024 conv features."""
        features = observation
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            features = F.relu(conv(features))
        # assumes the conv stack yields 1024 values per sample (e.g. 64x64 input) -- TODO confirm
        flat = features.view(-1, 1024)
        return self.fc(flat)

class FeedforwardPolicy(nn.Module):
    """Three-layer MLP that maps an observation embedding to an action."""

    def __init__(self, embedding_size, hidden_size, action_size):
        super().__init__()
        self.linear1 = nn.Linear(embedding_size, hidden_size)
        self.linear2 = nn.Linear(hidden_size, hidden_size)
        self.linear3 = nn.Linear(hidden_size, action_size)

    def forward(self, obs):
        """Return the action for `obs`; the output layer has no activation."""
        features = obs
        for layer in (self.linear1, self.linear2):
            features = F.relu(layer(features))
        return self.linear3(features)

class RecurrentPolicy(nn.Module):
    """GRU-based policy that maps a sequence of embeddings to a sequence of actions.

    Each step passes the embedding through two ReLU layers, updates the
    recurrent belief with a GRU cell and reads the action off the belief with
    a linear layer.
    """

    def __init__(self, embedding_size, hidden_size, belief_size, action_size):
        super().__init__()

        self.linear1 = nn.Linear(embedding_size, hidden_size)
        self.linear2 = nn.Linear(hidden_size, hidden_size)
        # The GRU consumes the output of linear2, so its input width is
        # hidden_size. Using belief_size here only works when the two sizes
        # happen to be equal (parameter shapes are unchanged in that case).
        self.rnn = nn.GRUCell(hidden_size, belief_size)
        self.linear3 = nn.Linear(belief_size, action_size)

    def forward(self, obs, belief):
        """Unroll the policy over the leading (time) dimension of `obs`.

        Args:
            obs: embeddings indexed by time step along dim 0; each `obs[t]`
                is fed to a linear layer expecting `embedding_size` features.
            belief: initial GRU hidden state.

        Returns:
            A tuple `(actions, belief)`: the per-step actions stacked along
            dim 0 and the hidden state after the last step.
        """
        actions = []
        for step_obs in obs.unbind(dim=0):
            hidden = F.relu(self.linear1(step_obs))
            hidden = F.relu(self.linear2(hidden))
            belief = self.rnn(hidden, belief)
            actions.append(self.linear3(belief))
        return torch.stack(actions, dim=0), belief

class RecurrentDecoder(nn.Module):
    """Recurrent policy that also decodes an image from its belief at each step.

    The action head matches the plain recurrent policy; in addition the belief
    is projected to `embedding_size` channels and upsampled by four transposed
    convolutions into a 3-channel image.
    """

    def __init__(self, embedding_size, hidden_size, belief_size, action_size):
        super().__init__()

        self.embedding_size = embedding_size

        self.linear1 = nn.Linear(embedding_size, hidden_size)
        self.linear2 = nn.Linear(hidden_size, hidden_size)
        # The GRU consumes the output of linear2, so its input width is
        # hidden_size. Using belief_size here only works when the two sizes
        # happen to be equal (parameter shapes are unchanged in that case).
        self.rnn = nn.GRUCell(hidden_size, belief_size)
        self.linear3 = nn.Linear(belief_size, action_size)
        self.linear4 = nn.Linear(belief_size, embedding_size)

        # With a 1x1 input these layers produce 5 -> 13 -> 30 -> 64 pixels per side
        self.conv1 = nn.ConvTranspose2d(embedding_size, 128, 5, stride=2)
        self.conv2 = nn.ConvTranspose2d(128, 64, 5, stride=2)
        self.conv3 = nn.ConvTranspose2d(64, 32, 6, stride=2)
        self.conv4 = nn.ConvTranspose2d(32, 3, 6, stride=2)

    def forward(self, obs, belief):
        """Unroll the model over the leading (time) dimension of `obs`.

        Args:
            obs: embeddings indexed by time step along dim 0; each `obs[t]`
                is fed to a linear layer expecting `embedding_size` features.
            belief: initial GRU hidden state.

        Returns:
            A tuple `(actions, next_obs, belief)`: the per-step actions and
            decoded images, each stacked along dim 0, and the hidden state
            after the last step.
        """
        actions = []
        decoded = []
        for step_obs in obs.unbind(dim=0):
            hidden = F.relu(self.linear1(step_obs))
            hidden = F.relu(self.linear2(hidden))
            belief = self.rnn(hidden, belief)
            actions.append(self.linear3(belief))

            # Decode an image from the updated belief, starting from a 1x1 feature map
            hidden = self.linear4(belief)
            hidden = hidden.view(-1, self.embedding_size, 1, 1)
            hidden = F.relu(self.conv1(hidden))
            hidden = F.relu(self.conv2(hidden))
            hidden = F.relu(self.conv3(hidden))
            decoded.append(self.conv4(hidden))  # no activation on the output image
        return torch.stack(actions, dim=0), torch.stack(decoded, dim=0), belief

# Initialise model parameters randomly
encoder = Encoder(args.embedding_size).to(device=args.device)
# Exactly one policy architecture is built; unknown ids are rejected explicitly
# rather than leaving `policy` undefined.
if args.policy_number == 0:
    policy = FeedforwardPolicy(args.embedding_size, args.hidden_size, env.action_size).to(device=args.device)
elif args.policy_number == 1:
    policy = RecurrentPolicy(args.embedding_size, args.hidden_size, args.belief_size, env.action_size).to(device=args.device)
elif args.policy_number == 2:
    policy = RecurrentDecoder(args.embedding_size, args.hidden_size, args.belief_size, env.action_size).to(device=args.device)
else:
    raise ValueError('Unknown policy number: {} (expected 0, 1 or 2)'.format(args.policy_number))
# Encoder and policy are optimised jointly
param_list = list(encoder.parameters()) + list(policy.parameters())
# When a learning-rate schedule is requested the optimiser starts at lr=0
optimiser = optim.Adam(param_list, lr=0 if args.learning_rate_schedule != 0 else args.learning_rate, eps=args.adam_epsilon)
# The path is compared by value: `is not ''` tests object identity, which is not
# guaranteed for equal strings and raises a SyntaxWarning on Python 3.8+.
if args.models != '' and os.path.exists(args.models):
    print('Loading')
    # Map the checkpoint onto the selected device regardless of where it was saved
    model_dicts = torch.load(args.models, map_location=args.device)
    encoder.load_state_dict(model_dicts['encoder'])
    policy.load_state_dict(model_dicts['policy'])

# Test-only mode: roll out the loaded policy, record a video per episode,
# report the mean return and exit without training.
if args.test:
    encoder.eval()
    policy.eval()
    # Steps per episode once action repeat is accounted for
    episode_length = args.max_episode_length // args.action_repeat
    with torch.no_grad():
        total_reward = 0
        for episode in tqdm(range(args.test_episodes)):
            observation = env.reset()
            # Recurrent state for policies 1 and 2 (unused by the feedforward policy)
            belief = torch.zeros(1, args.belief_size, device=args.device)
            video_frames = []
            for t in range(episode_length):
                embedding = encoder(observation.to(device=args.device))
                if args.policy_number == 0:
                    action = policy(embedding)
                elif args.policy_number == 1:
                    # Recurrent policies expect a leading time dimension
                    action, belief = policy(embedding.unsqueeze(dim=0), belief)
                    action = action.squeeze(dim=0)
                elif args.policy_number == 2:
                    action, generated_obs, belief = policy(embedding.unsqueeze(dim=0), belief)
                    action, generated_obs = action.squeeze(dim=0), generated_obs.squeeze(dim=0)
                next_observation, reward, done = env.step(action[0].cpu())
                if args.observation_noise > 0:
                    next_observation = next_observation + args.observation_noise * torch.randn_like(next_observation)
                total_reward += reward
                if args.policy_number == 2:
                    # Show the current observation next to the decoder's output
                    video_frames.append(make_grid(torch.cat([observation, generated_obs.cpu()], dim=3) + 0.5,
                                                  nrow=5).numpy())  # Decentre
                else:
                    video_frames.append(make_grid(observation + 0.5, nrow=5).numpy())
                observation = next_observation
                if done:
                    break
            episode_str = str(episode).zfill(len(str(args.test_episodes)))
            write_video(video_frames, 'test_episode_%s' % episode_str, results_dir)  # Lossy compression
        print('Average Reward:', total_reward / args.test_episodes)
    env.close()
    # Exit via SystemExit; `quit()` is a `site` helper meant for interactive use
    raise SystemExit

# Training: each "episode" runs --collect-interval optimisation steps on batches
# sampled from the expert buffer, with periodic evaluation and checkpointing.
for episode in tqdm(range(1, args.episodes + 1), total=args.episodes, initial=1):
  # Model fitting
  losses, a_losses, o_losses = [], [], []  # a_/o_ losses are only filled for policy 2

  for s in range(args.collect_interval):
    observations, actions, rewards, nonterminals, _ = D_expert.sample(args.batch_size, args.chunk_size)  # Transitions start at time t = 0
    # Encode every observation of the chunk in one pass
    embeddings = bottle(encoder, (observations, ))
    if args.policy_number == 0:
      generated_actions = bottle(policy, (embeddings, ))
      loss = F.mse_loss(generated_actions, actions, reduction='none').sum(dim=2).mean(dim=(0,1))
    elif args.policy_number == 1:
      init_belief = torch.zeros(args.batch_size, args.belief_size, device=args.device)
      generated_actions, _ = policy(embeddings, init_belief)
      loss = F.mse_loss(generated_actions, actions, reduction='none').sum(dim=2).mean(dim=(0,1))
    elif args.policy_number == 2:
      init_belief = torch.zeros(args.batch_size, args.belief_size, device=args.device)
      generated_actions, generated_obs, _ = policy(embeddings, init_belief)
      a_loss = F.mse_loss(generated_actions, actions, reduction='none').sum(dim=2).mean(dim=(0,1))
      # The image decoded at step t is compared with the observation at step t + 1
      o_loss = F.mse_loss(generated_obs[:-1], observations[1:], reduction='none').sum(dim=(2,3,4)).mean(dim=(0,1))
      loss = args.a_loss_coef * a_loss + o_loss

    # Linear learning-rate warm-up: the optimiser is created with lr=0 when a
    # schedule is requested, so the rate has to be ramped up here or no
    # parameter would ever change.
    if args.learning_rate_schedule != 0:
      for group in optimiser.param_groups:
        group['lr'] = min(group['lr'] + args.learning_rate / args.learning_rate_schedule, args.learning_rate)

    optimiser.zero_grad()
    loss.backward()
    nn.utils.clip_grad_norm_(param_list, args.grad_clip_norm, norm_type=2)
    optimiser.step()

    losses.append(loss.item())
    if args.policy_number == 2:
      a_losses.append(a_loss.item())
      o_losses.append(o_loss.item())

  # One tuple of per-step losses is recorded per episode
  metrics['loss'].append(tuple(losses))
  metrics['episodes'].append(episode)
  lineplot(metrics['episodes'][-len(metrics['loss']):], metrics['loss'], 'loss',
           results_dir)
  if args.policy_number == 2:
    metrics['a_loss'].append(tuple(a_losses))
    metrics['o_loss'].append(tuple(o_losses))
    lineplot(metrics['episodes'][-len(metrics['a_loss']):], metrics['a_loss'], 'a_loss',
             results_dir)
    lineplot(metrics['episodes'][-len(metrics['o_loss']):], metrics['o_loss'], 'o_loss',
             results_dir)

  # Test model
  if episode % args.test_interval == 0:
    # Set models to eval mode
    encoder.eval()
    policy.eval()
    # Initialise parallelised test environments
    test_envs = EnvBatcher(Env, (args.env, args.symbolic_env, args.seed, args.max_episode_length, args.action_repeat,
                                 args.bit_depth, args.sparse_env), {}, args.test_episodes)

    with torch.no_grad():
      observation, total_rewards, video_frames = test_envs.reset(), np.zeros((args.test_episodes,)), []
      belief = torch.zeros(args.test_episodes, args.belief_size, device=args.device)
      episode_length = args.max_episode_length // args.action_repeat
      for t in range(episode_length):
        embedding = encoder(observation.to(device=args.device))
        if args.policy_number == 0:
          action = policy(embedding)
        elif args.policy_number == 1:
          # Recurrent policies expect a leading time dimension
          action, belief = policy(embedding.unsqueeze(dim=0), belief)
          action = action.squeeze(dim=0)
        elif args.policy_number == 2:
          action, generated_obs, belief = policy(embedding.unsqueeze(dim=0), belief)
          action, generated_obs = action.squeeze(dim=0), generated_obs.squeeze(dim=0)
        next_observation, reward, done = test_envs.step(action.cpu())
        if args.observation_noise > 0:
          next_observation = next_observation + args.observation_noise * torch.randn_like(next_observation)
        total_rewards += reward.numpy()
        if args.policy_number == 2:
          # Show the current observation next to the decoder's output
          video_frames.append(make_grid(torch.cat([observation, generated_obs.cpu()], dim=3) + 0.5, nrow=5).numpy())  # Decentre
        else:
          video_frames.append(make_grid(observation + 0.5, nrow=5).numpy())
        observation = next_observation
        # Stop once every parallel test environment has finished
        if done.sum().item() == args.test_episodes:
          break

    # Update and plot reward metrics (and write video if applicable) and save metrics
    metrics['test_episodes'].append(episode)
    metrics['test_rewards'].append(total_rewards.tolist())
    lineplot(metrics['test_episodes'], metrics['test_rewards'], 'test_rewards', results_dir)
    episode_str = str(episode).zfill(len(str(args.episodes)))
    write_video(video_frames, 'test_episode_%s' % episode_str, results_dir)  # Lossy compression
    save_image(torch.as_tensor(video_frames[-1]), os.path.join(results_dir, 'test_episode_%s.png' % episode_str))
    torch.save(metrics, os.path.join(results_dir, 'metrics.pth'))

    # Set models to train mode
    encoder.train()
    policy.train()
    # Close test environments
    test_envs.close()

  # Checkpoint models
  if episode % args.checkpoint_interval == 0:
    torch.save({'policy': policy.state_dict(), 'encoder': encoder.state_dict(),
                'optimiser': optimiser.state_dict()}, os.path.join(results_dir, 'models_%d.pth' % episode))

# Close training environment
env.close()
