#!/usr/bin/env python3
import logging
import json
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter
from collections import deque

from train.envs.minerl_env import make_minerl
from train.common.config import Config
from train.reinforcment_learning.utils.utils import safemean, get_device, import_module
from train.reinforcment_learning.rudder.rudder_runner import Runner
from train.reinforcment_learning.ppo import PPO
from train.reinforcment_learning.save_video import save_video
from train.reinforcment_learning.rudder.rudder import RR_LSTM_ARCH, RR_LSTM
from train.reinforcment_learning.rudder.buffer import LessonBuffer
from train.behavioral_cloning.datasets.minerl_dataset import MineRLDataset, MineRLBaseDataset
from train.reinforcment_learning.rudder.rudder_test import break_sequence, get_episode_from_traj

import coloredlogs
# Runs at import time: set up coloredlogs, passing logging.INFO as its level argument.
coloredlogs.install(logging.INFO)

def copy_weights(model1, model2):
    """Copy parameters that exist under the same name in both models, from 1 -> 2.

    Every parameter of ``model1`` whose name is also a parameter name of
    ``model2`` and whose shape matches is copied in place into ``model2``.
    Parameters that only exist in one model, parameters with differing
    shapes, and all buffers of ``model2`` are left untouched.

    Args:
        model1: Source ``torch.nn.Module``; it is only read.
        model2: Destination ``torch.nn.Module``; it is modified in place.

    Returns:
        ``model2`` (the same object that was passed in).
    """
    target_params = dict(model2.named_parameters())

    # The copies are plain value assignments and must not be recorded by autograd.
    with torch.no_grad():
        for name, source_param in model1.named_parameters():
            target_param = target_params.get(name)
            if target_param is None:
                continue
            if target_param.shape != source_param.shape:
                # Same name but a different layer size: copying would either
                # fail or silently broadcast, so leave the target as it is.
                print("Skipping (shape mismatch):", name)
                continue
            print("Copying:", name)
            target_param.copy_(source_param)

    # No load_state_dict() afterwards: the in-place copies already updated
    # model2, and a dict built from named_parameters() lacks the buffers, which
    # makes a strict load fail for models that have any.
    return model2

class ReinforcementLearning:
    """PPO training driver that uses a RUDDER LSTM alongside the rollout runner.

    ``train`` is the single entry point. It builds the RUDDER model and its
    lesson buffer, fills the buffer with human tree-chop demonstrations and
    agent rollouts, trains RUDDER from that buffer, optionally pre-trains the
    value head and finally runs the joint PPO optimisation loop.
    """

    # Root directory of the MineRL demonstration data used to fill the lesson buffer.
    TREECHOP_DATA_ROOT = "/publicdata/minerl/official_v1"

    def __init__(self):
        # Model being trained; assigned (already moved to its device) by ``train``.
        self.model = None

    def train(self, args, env, model, dataset):
        """Train ``model`` with PPO; RUDDER is trained from the lesson buffer first.

        Args:
            args: Configuration object; every hyper parameter used here is read
                from it as an attribute. ``args.record_dir`` is overwritten
                with the tensorboard log directory.
            env: Environment handed to the runners; it is reset here but not
                closed.
            model: Actor-critic network to optimise. It is moved to the device
                selected by ``args.gpu`` and kept in ``self.model``.
            dataset: Object exposing ``SEQ_LENGTH``, ``INPUT_SPACE``,
                ``ACTION_SPACE`` and ``DATA_TRANSFORM``.

        Raises:
            ValueError: If no demonstration episode could be loaded.
        """
        print("Starting with PPO training!")

        # The policy and the RUDDER LSTM may be placed on different devices.
        device = get_device(args.gpu)
        device_rudder = get_device(args.rudder_gpu)
        buffer = LessonBuffer(size=1000, max_time=args.rudder_episode_len, state_shape=(3, 48, 48), n_actions=10)
        rr_lstm = RR_LSTM_ARCH(input_lstm=288, lstm_size=128, num_actions=10, duplication=10, device=device_rudder)
        rudder = RR_LSTM(model=rr_lstm, buffer=buffer, loss_weight=args.rudder_aux_loss_weight, lr=args.rudder_lrs,
                         num_actions=10, episode_len=args.rudder_episode_len)

        print("Copying conv layers from behavior cloning agent!")
        rr_lstm = copy_weights(model, rr_lstm)
        self.model = model.to(device)
        # The LSTM was constructed for ``device_rudder``; moving it to the
        # policy device instead breaks as soon as the two devices differ.
        rr_lstm = rr_lstm.to(device_rudder)

        # some optimization hyper parameters
        betas = (0.9, 0.999)  # adam momentum parameters
        nbatch = args.num_env * args.nsteps  # size of the entire batch
        nbatch_train = nbatch // args.nminibatches  # size of the mini-batch
        nupdates = args.total_timesteps // nbatch  # number of policy updates
        nbiupdates = args.train_value_timesteps // nbatch  # number of value burn-in updates

        # initialize PPO
        print("Initializing PPO!")
        ppo = PPO(model=self.model, dataset=dataset, gamma=args.gamma, gae_lambda=args.lam, eps_clip=args.eps_clip,
                  pg_coef=args.pg_coef, vf_coef=args.vf_coef, ent_coef_bin=args.ent_coef_bin,
                  ent_coef_cam=args.ent_coef_cam, lrs=args.lrs, betas=betas, clamping=args.clamping,
                  clip_grad_norm=args.clip_grad_norm, bound_value_loss=args.bound_value_loss,
                  value_clip=args.value_clip, lrs_critic_head=args.lrs_critic_head, device=device)

        # initialize runners: one for training rollouts, one for the periodic videos
        print("Initializing Runner!")
        runner = Runner(model=self.model, rudder=rudder, dataset=dataset, nenv=args.num_env, env=env,
                        nsteps=args.nsteps, seq_len=dataset.SEQ_LENGTH, gamma=args.gamma, lam=args.lam,
                        episode_length=args.episode_length, reward_scale=args.reward_scale,
                        episodic_interactions=args.episodic_interactions, device=device)

        runner_video = Runner(model=self.model, rudder=rudder, dataset=dataset, nenv=args.num_env, env=env,
                              nsteps=args.nsteps, seq_len=dataset.SEQ_LENGTH, gamma=args.gamma, lam=args.lam,
                              episode_length=args.episode_length, reward_scale=args.reward_scale,
                              episodic_interactions=args.episodic_interactions, device=device)

        writer = SummaryWriter(comment=args.experiment_name)
        try:
            # Fill the lesson buffer: human demonstrations first, then as many
            # agent rollouts as are needed to collect a comparable number of steps.
            samples_to_fill = self._add_demonstrations_to_buffer(args, dataset, buffer)
            rollouts = samples_to_fill * args.rudder_episode_len // (args.nsteps * args.num_env)
            self._add_agent_episodes_to_buffer(args, runner, buffer, rollouts)

            # train from lessons buffer
            print("Training Rudder")
            runner.learn_from_buffer(writer, rudder_pretraining=True, stop_loss=0.05,
                                     batch_size=args.rudder_batch_size)

            # store the config to the run folder
            args.record_dir = writer.get_logdir() + '/'

            with open(args.config_file) as json_file:
                data = json.load(json_file)

            with open(writer.get_logdir() + '/config.json', 'w') as outfile:
                json.dump(data, outfile)

            # burn in training of value function
            if args.train_value_first:
                self._pretrain_value_function(args, runner, ppo, writer, nbatch, nbatch_train, nbiupdates)

            runner.state = env.reset()
            save_video(args, runner)
            # reset environment here and put the state into the runner
            runner.state = env.reset()
            runner.episodic_returns = []
            runner.returns = np.zeros(runner.nenv)

            # optimize policy and the value function together
            self._optimize_policy(args, runner, runner_video, ppo, writer, nbatch, nbatch_train, nupdates)
        finally:
            # Make sure pending tensorboard events reach disk even if training aborts.
            writer.close()

    def _add_demonstrations_to_buffer(self, args, dataset, buffer):
        """Add the human tree-chop episodes to ``buffer`` and return how many were added."""
        print("Loading tree chop sequences")
        neg_rew = args.negative_reward_const
        data_chop = MineRLDataset(root=self.TREECHOP_DATA_ROOT, sequence_len=dataset.SEQ_LENGTH, train=True,
                                  prepare=False,
                                  experiment="MineRLTreechop-v0",
                                  input_space=dataset.INPUT_SPACE,
                                  action_space=dataset.ACTION_SPACE,
                                  data_transform=dataset.DATA_TRANSFORM)

        num_seq_chop = data_chop.num_sequences()
        print("number of tree chop sequences", num_seq_chop)

        # TODO: test if the transformation aligns all the arrays in correct manner
        flat_pov, flat_binary_actions, flat_camera_actions, flat_rewards, flat_lengths = [], [], [], [], []
        for seq_idx in range(num_seq_chop):
            # get a sequence and break every stream into episode-sized pieces
            seq_pov, seq_binary, seq_camera, seq_rewards = data_chop._prepare_sequence(seq_idx, step=None)
            pov_parts = break_sequence(args.rudder_episode_len, seq_pov)
            binary_parts = break_sequence(args.rudder_episode_len, seq_binary)
            camera_parts = break_sequence(args.rudder_episode_len, seq_camera)
            part_lengths, reward_parts = break_sequence(args.rudder_episode_len, seq_rewards, True)

            for part_pov, part_binary, part_camera, part_rewards, part_length in zip(
                    pov_parts, binary_parts, camera_parts, reward_parts, part_lengths):
                if args.change_reward_structure:
                    # Overwrite every step before the first reward of 1 with the
                    # negative constant; pieces without such a reward are
                    # overwritten completely.
                    first_reward = np.where(part_rewards == 1)[0]
                    if first_reward.size == 0:
                        part_rewards[:] = neg_rew
                    else:
                        part_rewards[:first_reward[0]] = neg_rew
                flat_pov.append(part_pov)
                flat_binary_actions.append(part_binary)
                flat_camera_actions.append(part_camera)
                flat_rewards.append(part_rewards)
                flat_lengths.append(part_length)

        if not flat_pov:
            # np.stack would fail with an unrelated message on an empty list.
            raise ValueError("No tree chop sequences found in " + self.TREECHOP_DATA_ROOT)

        chop_pov = np.stack(flat_pov)
        chop_rewards = np.stack(flat_rewards)
        chop_lengths = np.stack(flat_lengths)
        chop_actions = np.concatenate((np.stack(flat_binary_actions), np.stack(flat_camera_actions)), axis=-1)

        # append all the sequences to the lessons buffer
        num_episodes = chop_pov.shape[0]
        for idx in range(num_episodes):
            buffer.add(states=chop_pov[idx], actions=chop_actions[idx], rewards=chop_rewards[idx],
                       length=chop_lengths[idx])
        print(buffer.curr_size)

        # The large demonstration arrays are released when this helper returns.
        return num_episodes

    @staticmethod
    def _add_agent_episodes_to_buffer(args, runner, buffer, rollouts):
        """Run ``rollouts`` environment rollouts and add the resulting episodes to ``buffer``."""
        for _rollout in range(rollouts):
            with torch.no_grad():
                b_states, b_rewards, b_dones, _, _, _, _, _ = runner.rollout()
                pov, bin_actions, camera_actions = runner.transform_state_dict(b_states)

            agent_pov, agent_bin_actions, agent_camera_actions, agent_rewards, agent_lengths = \
                get_episode_from_traj(b_rewards, b_dones, pov, bin_actions, camera_actions, args)

            agent_actions = np.concatenate((agent_bin_actions, agent_camera_actions), axis=-1)

            for idx in range(agent_pov.shape[0]):
                buffer.add(states=agent_pov[idx], actions=agent_actions[idx], rewards=agent_rewards[idx],
                           length=agent_lengths[idx])
            print(buffer.curr_size)

    @staticmethod
    def _log_episodic_returns(writer, runner, prefix, step):
        """Log mean/max/min of ``runner.episodic_returns`` once at least three returns exist."""
        if len(runner.episodic_returns) < 3:
            return

        mean_return = np.mean(runner.episodic_returns)
        max_return = np.max(runner.episodic_returns)
        min_return = np.min(runner.episodic_returns)
        # drop the oldest return so the statistics follow the recent episodes
        runner.episodic_returns.pop(0)

        writer.add_scalar(prefix + '/reward/eprewmean', mean_return, step)
        writer.add_scalar(prefix + '/reward/eprewmax', max_return, step)
        writer.add_scalar(prefix + '/reward/eprewmin', min_return, step)
        print("Return mean|max|min", mean_return, max_return, min_return)

    def _pretrain_value_function(self, args, runner, ppo, writer, nbatch, nbatch_train, nbiupdates):
        """Update only the value head for ``nbiupdates`` rollouts, then checkpoint the model."""
        print("Pre-training Value Function!")

        # iterate value function burn-in epochs
        for update in range(nbiupdates):

            # perform rollout
            b_returns, b_advs, b_logprobs, pov, bin_actions, camera_actions, b_actions, b_values, epinfo = \
                runner.run()
            print("epinfo", epinfo)

            # shuffle data, create minibatches and run for a few epochs to complete a single update
            value_loss, camera_entropy, action_entropy = [], [], []
            inds = np.arange(nbatch)

            # update on the current rollout for "val_epochs" epochs
            for _ in range(args.val_epochs):
                # randomize the indexes
                np.random.shuffle(inds)
                # 0 to batch_size with batch_train_size step
                for start in range(0, nbatch, nbatch_train):

                    # slice out training batch
                    mbinds = inds[start:start + nbatch_train]
                    slices = (arr[mbinds] for arr in (b_returns, pov, bin_actions, camera_actions, b_actions))

                    # perform update of value head
                    _value_loss, _camera_entropy, _action_entropy = ppo.update_value_head(*slices)

                    # book keeping
                    value_loss.append(_value_loss)
                    camera_entropy.append(_camera_entropy)
                    action_entropy.append(_action_entropy)

            self._log_episodic_returns(writer, runner, 'pretrain', update)

            writer.add_scalar('pretrain/ent/camera_entropy', np.mean(camera_entropy), update)
            writer.add_scalar('pretrain/ent/action_entropy', np.mean(action_entropy), update)
            writer.add_scalar('pretrain/loss/value_loss', np.mean(value_loss), update)
            print("Pre-training Update:", update, " Updates remaining: ", nbiupdates - update)

        torch.save(self.model.state_dict(), writer.get_logdir() + '/checkpoint_value_pretraining')
        print("Completed training the Value function!! Now training both, value and policy!!")

    def _optimize_policy(self, args, runner, runner_video, ppo, writer, nbatch, nbatch_train, nupdates):
        """Run ``nupdates`` joint policy/value PPO updates, logging and checkpointing along the way."""
        # episode infos of the last ``log_num_epi_avg`` episodes, used for averaged logs
        epinfobuf = deque(maxlen=args.log_num_epi_avg)

        # The camera entropy coefficient has its own switch; configs without it
        # keep the previous behaviour of following the binary-action switch.
        try:
            anneal_ent_coef_cam = args.anneal_ent_coef_cam
        except (AttributeError, KeyError):
            anneal_ent_coef_cam = args.anneal_ent_coef_bin

        print("Starting with Policy Optimization!")
        for update in range(nupdates):

            # run a couple of steps in the environment
            b_returns, b_advs, b_logprobs, pov, bin_actions, camera_actions, b_actions, b_values, epinfo = \
                runner.run()
            print("epinfo", epinfo)

            loss, policy_loss, policy_loss_mean, entropy_loss = [], [], [], []
            value_loss, surrogate_ratio_loss, surrogate_clamp = [], [], []
            camera_entropy, action_entropy, gradnorm = [], [], []

            # shuffle data, create mini-batches and run few epochs to complete a single update
            inds = np.arange(nbatch)
            for _ in range(args.nepochs):
                # randomize the indexes
                np.random.shuffle(inds)
                # 0 to batch_size with batch_train_size step
                for start in range(0, nbatch, nbatch_train):

                    # get training batch
                    mbinds = inds[start:start + nbatch_train]
                    slices = (arr[mbinds] for arr in (b_returns, b_advs, b_logprobs, pov, bin_actions,
                                                      camera_actions, b_actions, b_values))

                    # PPO update
                    _loss, _policy_loss, _policy_loss_mean, _entropy_loss, _value_loss, \
                        _surrogate_ratio_loss, _surrogate_clamp, _camera_entropy, _action_entropy, \
                        _gradnorm, _gradnorm_actor, _gradnorm_critic = ppo.update(*slices)

                    # book keeping
                    loss.append(_loss)
                    policy_loss.append(_policy_loss)
                    policy_loss_mean.append(_policy_loss_mean)
                    entropy_loss.append(_entropy_loss)
                    value_loss.append(_value_loss)
                    surrogate_ratio_loss.append(_surrogate_ratio_loss)
                    surrogate_clamp.append(_surrogate_clamp)
                    camera_entropy.append(_camera_entropy)
                    action_entropy.append(_action_entropy)
                    gradnorm.append(_gradnorm)

            # tensor-board stuff
            epinfobuf.extend(epinfo)
            self._log_episodic_returns(writer, runner, 'train', update)

            writer.add_scalar('train/eplenmean', safemean([info['l'] for info in epinfobuf]), update)
            writer.add_scalar('train/loss/loss', np.mean(loss), update)
            writer.add_scalar('train/loss/policy_loss_mean', np.mean(policy_loss_mean), update)
            writer.add_scalar('train/loss/gradient_norm', np.mean(gradnorm), update)

            writer.add_scalar('train/loss/ratio_actions_loss', np.mean(surrogate_ratio_loss), update)
            writer.add_scalar('train/loss/clamp_actions_loss', np.mean(surrogate_clamp), update)
            writer.add_scalar('train/lr/learning_rate', ppo.lrs, update)

            # entropy stuff
            writer.add_scalar('train/loss/entropy_loss', np.mean(entropy_loss), update)
            writer.add_scalar('train/loss/camera_entropy', np.mean(camera_entropy), update)
            writer.add_scalar('train/ent/action_entropy', np.mean(action_entropy), update)
            writer.add_scalar('train/lr/entropy_coef_cam', ppo.ent_coef_cam, update)
            writer.add_scalar('train/lr/entropy_coef_bin', ppo.ent_coef_bin, update)

            # value stuff
            writer.add_scalar('train/loss/value_loss', np.mean(value_loss), update)

            # log every k epochs
            if update % args.log_interval == 0 or update == 1:
                print(safemean([info['r'] for info in epinfobuf]), ppo.eps_clip)

            # Linear annealing factor for the NEXT update. ``update`` is
            # 0-based and the values set here are first used by update + 1, so
            # the factor starts at 1 - 1/nupdates and never exceeds 1.
            frac = 1.0 - (update + 1.0) / nupdates

            # anneal epsilon
            if args.anneal_eps:
                ppo.eps_clip = ppo.eps_clip_initial * frac

            # anneal actor learning rate
            if args.anneal_lrs:
                ppo.lrs = ppo.lrs_initial * frac

            # anneal entropy coef bin
            if args.anneal_ent_coef_bin:
                ppo.ent_coef_bin = ppo.ent_coef_bin_initial * frac

            # anneal entropy coef cam
            if anneal_ent_coef_cam:
                ppo.ent_coef_cam = ppo.ent_coef_cam_initial * frac

            print("Training Update: ", update, " Updates remaining: ", nupdates - update)
            if args.save_model_weights and update % 5 == 0:
                torch.save(self.model.state_dict(), writer.get_logdir() + '/checkpoint' + str(update))
                save_video(args, runner_video)


def train_model_with_config(config_file: str, env, model, dataset):
    """Train ``model`` with the settings from ``config_file`` and close ``env`` afterwards.

    Args:
        config_file: Path handed to ``Config`` to build the training arguments.
        env: Environment used for training; it is closed before this function
            returns, whether training succeeded or raised.
        model: Network passed on to ``ReinforcementLearning.train``.
        dataset: Dataset object passed on to ``ReinforcementLearning.train``.
    """
    try:
        args = Config(config_file)
        rl = ReinforcementLearning()
        rl.train(args, env, model, dataset)
    finally:
        # Close the environment on failure too, so it is not left open.
        env.close()


if __name__ == "__main__":
    # Script entry point: build config, environment and network, then train.

    # parse config
    args = Config()

    # import model and dataset
    Network = import_module(args.model).Network
    dataset = import_module(args.dataset)

    # instantiate environment
    env = make_minerl(args.env, n_cpu=args.num_env, seq_len=dataset.SEQ_LENGTH,
                      transforms=dataset.DATA_TRANSFORM, input_space=dataset.INPUT_SPACE,
                      env_server=True, frame_skip=args.frame_skip)

    try:
        # initialize actor-critic model
        print("Initializing network!")
        model = Network()

        # load pre-trained model parameters
        if args.load_model:
            print("Loading behaviour cloned model!")
            # Load onto the CPU so a checkpoint saved from another device can
            # still be read; train() moves the model to its device afterwards.
            model.load_state_dict(torch.load(args.pre_trained_params, map_location="cpu"))

        # start training
        rl = ReinforcementLearning()
        rl.train(args, env, model, dataset)
    finally:
        # Close the environment even when setup or training fails.
        env.close()
