#!/usr/bin/env python3
import sys
import logging
import json
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter
from collections import deque
import time
import copy
from train.behavioral_cloning.datasets.experience import Experience
from train.envs.minerl_env import make_minerl
from train.common.config import Config
from train.reinforcment_learning.utils.utils import safemean, get_device, import_module
from train.reinforcment_learning.runner import Runner
from train.reinforcment_learning.ppo import PPO
from train.reinforcment_learning.save_video import save_video
from train.pytorch_wrapper.utils import import_from_package_by_name
from train.reinforcment_learning.eval_hook import EvalEnv
import traceback

import coloredlogs
# Import-time side effect: install coloredlogs with logging.INFO as the level.
coloredlogs.install(logging.INFO)


class ReinforcementLearning:
    """PPO training driver for an actor-critic model.

    ``train`` optionally runs a value-head burn-in phase, then optimises policy
    and value function jointly, logging to TensorBoard and writing checkpoints
    into a per-run log directory.
    """

    def __init__(self):
        # model under training; assigned by train()
        self.model = None

    @staticmethod
    def _recent_return_stats(episodic_returns, window):
        """Summarise the trailing ``window`` episodic returns.

        Args:
            episodic_returns: sequence of per-episode returns; the trailing
                entries are the ones summarised.
            window: number of trailing entries to use.

        Returns:
            ``(mean, max, min)`` of the last ``window`` entries, or ``None``
            when fewer than ``window`` entries are available.
        """
        if len(episodic_returns) < window:
            return None
        recent = episodic_returns[-window:]
        return np.mean(recent), np.max(recent), np.min(recent)

    @staticmethod
    def _log_run_config(args, writer, log_dir):
        """Write the run settings to TensorBoard and copy the config file to ``log_dir``.

        Also sets ``args.record_dir`` to the run's log directory.
        """
        # one markdown line per config attribute, in sorted order
        text = "".join("**{}:** {}<br>".format(arg, getattr(args, arg)) for arg in sorted(vars(args)))
        writer.add_text("run_config", text)
        writer.add_text("cmd", " ".join(sys.argv))

        # store the config to the run folder
        args.record_dir = log_dir + '/'

        with open(args.config_file) as json_file:
            data = json.load(json_file)

        with open(log_dir + '/config.json', 'w') as outfile:
            json.dump(data, outfile)

    def _pretrain_value_function(self, args, ppo, runner, writer, log_dir, device, nbatch, nbatch_train,
                                 nbiupdates):
        """Burn-in phase: update only the value head for ``nbiupdates`` rollouts.

        Exceptions raised inside a single update are printed and the loop
        continues with the next update (best-effort training).

        Returns:
            A CPU deep copy of the model taken at the update with the highest
            windowed mean return, or ``None`` if the window was never filled.
        """
        print("Pre-training Value Function!")
        window = 3 * args.num_env  # return statistics use the last 3 episodes per env
        best_model = None
        best_model_mean_return = -np.inf

        for update in range(nbiupdates):
            try:
                # perform rollout (advantages, log-probs and values are not needed here)
                b_returns, _, _, pov, bin_actions, camera_actions, b_actions, _, epinfo = runner.run()
                print("epinfo", epinfo)

                value_loss, camera_entropy, action_entropy = [], [], []
                inds = np.arange(nbatch)

                # re-use the current rollout for "val_epochs" shuffled passes
                for _ in range(args.val_epochs):
                    np.random.shuffle(inds)
                    for start in range(0, nbatch, nbatch_train):
                        mbinds = inds[start:start + nbatch_train]
                        slices = (arr[mbinds] for arr in (b_returns, pov, bin_actions, camera_actions, b_actions))

                        # perform update of value head
                        _value_loss, _camera_entropy, _action_entropy = ppo.update_value_head(*slices)

                        # book keeping
                        value_loss.append(_value_loss)
                        camera_entropy.append(_camera_entropy)
                        action_entropy.append(_action_entropy)

                stats = self._recent_return_stats(runner.episodic_returns, window)
                if stats is not None:
                    mean_return, max_return, min_return = stats

                    if best_model is None or best_model_mean_return < mean_return:
                        try:
                            best_model = copy.deepcopy(self.model.to('cpu'))
                        finally:
                            # always move the live model back, even if the copy failed
                            self.model.to(device)
                        best_model_mean_return = mean_return

                    writer.add_scalar('pretrain/reward/eprewmean', mean_return, update)
                    writer.add_scalar('pretrain/reward/eprewmax', max_return, update)
                    writer.add_scalar('pretrain/reward/eprewmin', min_return, update)
                    print("Return mean|max|min", mean_return, max_return, min_return)

                writer.add_scalar('pretrain/ent/camera_entropy', np.mean(camera_entropy), update)
                writer.add_scalar('pretrain/ent/action_entropy', np.mean(action_entropy), update)
                writer.add_scalar('pretrain/loss/value_loss', np.mean(value_loss), update)
                print("Pre-training Update:", update, " Updates remaining: ", nbiupdates - update)

                if update % 10 == 0:
                    if best_model is not None:
                        torch.save(best_model.state_dict(), log_dir + '/checkpoint_value_pretraining_best')
                    torch.save(self.model.state_dict(), log_dir + f'/checkpoint_value_pretraining_cp{update}')

            except Exception as e:
                traceback.print_exc()
                print(f"RL-ValuePretraining: Caught exception during value function pretraining: {e}")

        torch.save(self.model.state_dict(), log_dir + '/checkpoint_value_pretraining')
        print("Completed training the Value function!! Now training both, value and policy!!")
        return best_model

    def _optimize_policy(self, args, env, ppo, runner, video_runner, writer, log_dir, nbatch, nbatch_train,
                         nupdates):
        """Joint policy/value optimisation for ``nupdates`` PPO updates.

        Exceptions raised inside a single update are printed and the loop
        continues with the next update (best-effort training).

        Returns:
            Path of the best-model checkpoint, or ``None`` if none was written.
        """
        runner.state = env.reset()
        video_runner.run(self.model, env=env)
        # reset environment here and put the state into the runner
        runner.state = env.reset()
        runner.episodic_returns = []
        runner.returns = np.zeros(runner.nenv)

        print("Starting with Policy Optimization!")
        window = 3 * args.num_env  # return statistics use the last 3 episodes per env
        epinfobuf = deque(maxlen=args.log_num_epi_avg)  # episode infos of the last _ episodes
        best_checkpoint = log_dir + '/checkpoint_best'
        best_checkpoint_saved = False
        best_mean_return = -np.inf

        for update in range(nupdates):
            try:
                # run a couple of steps in the environment
                b_returns, b_advs, b_logprobs, pov, bin_actions, camera_actions, b_actions, b_values, epinfo = \
                    runner.run()
                print("epinfo", epinfo)

                loss, policy_loss, policy_loss_mean, entropy_loss = [], [], [], []
                value_loss, surrogate_ratio_loss, surrogate_clamp, camera_entropy = [], [], [], []
                action_entropy, gradnorm, gradnorm_actor, gradnorm_critic = [], [], [], []

                # shuffle data, create mini-batches and run few epochs to complete a single update
                inds = np.arange(nbatch)
                for _ in range(args.nepochs):
                    np.random.shuffle(inds)
                    for start in range(0, nbatch, nbatch_train):
                        mbinds = inds[start:start + nbatch_train]
                        slices = (arr[mbinds] for arr in (b_returns, b_advs, b_logprobs, pov, bin_actions,
                                                          camera_actions, b_actions, b_values))

                        # PPO update
                        _loss, _policy_loss, _policy_loss_mean, _entropy_loss, _value_loss, \
                            _surrogate_ratio_loss, _surrogate_clamp, _camera_entropy, _action_entropy, _gradnorm, \
                            _gradnorm_actor, _gradnorm_critic = ppo.update(*slices)

                        # book keeping
                        loss.append(_loss)
                        policy_loss.append(_policy_loss)
                        policy_loss_mean.append(_policy_loss_mean)
                        entropy_loss.append(_entropy_loss)
                        value_loss.append(_value_loss)
                        surrogate_ratio_loss.append(_surrogate_ratio_loss)
                        surrogate_clamp.append(_surrogate_clamp)
                        camera_entropy.append(_camera_entropy)
                        action_entropy.append(_action_entropy)
                        gradnorm.append(_gradnorm)
                        gradnorm_actor.append(_gradnorm_actor)
                        gradnorm_critic.append(_gradnorm_critic)

                # tensor-board stuff
                epinfobuf.extend(epinfo)
                stats = self._recent_return_stats(runner.episodic_returns, window)
                if stats is not None:
                    mean_return, max_return, min_return = stats

                    writer.add_scalar('train/reward/eprewmean', mean_return, update)
                    writer.add_scalar('train/reward/eprewmax', max_return, update)
                    writer.add_scalar('train/reward/eprewmin', min_return, update)
                    print("Return mean|max|min", mean_return, max_return, min_return)

                    if mean_return > best_mean_return:
                        print("Found new best model ...")
                        best_mean_return = mean_return
                        torch.save(self.model.state_dict(), best_checkpoint)
                        best_checkpoint_saved = True

                writer.add_scalar('train/eplenmean', safemean([info['l'] for info in epinfobuf]), update)
                writer.add_scalar('train/loss/loss', np.mean(loss), update)
                writer.add_scalar('train/loss/policy_loss_mean', np.mean(policy_loss_mean), update)
                writer.add_scalar('train/loss/gradient_norm', np.mean(gradnorm), update)
                writer.add_scalar('train/loss/gradient_norm_actor', np.mean(gradnorm_actor), update)
                writer.add_scalar('train/loss/gradient_norm_critic', np.mean(gradnorm_critic), update)

                writer.add_scalar('train/loss/ratio_actions_loss', np.mean(surrogate_ratio_loss), update)
                writer.add_scalar('train/loss/clamp_actions_loss', np.mean(surrogate_clamp), update)
                writer.add_scalar('train/lr/learning_rate', ppo.lrs, update)

                # entropy_stuff
                writer.add_scalar('train/loss/entropy_loss', np.mean(entropy_loss), update)
                writer.add_scalar('train/loss/camera_entropy', np.mean(camera_entropy), update)
                writer.add_scalar('train/ent/action_entropy', np.mean(action_entropy), update)
                writer.add_scalar('train/lr/entropy_coef_cam', ppo.ent_coef_cam, update)
                writer.add_scalar('train/lr/entropy_coef_bin', ppo.ent_coef_bin, update)

                # value stuff
                writer.add_scalar('train/loss/value_loss', np.mean(value_loss), update)

                # log every k epochs
                if update % args.log_interval == 0 or update == 1:
                    print(safemean([info['r'] for info in epinfobuf]), ppo.eps_clip)

                # linear annealing factor shared by all schedules below
                # NOTE(review): with `update` starting at 0 this is > 1 after the first
                # update (looks like an off-by-one) -- kept as-is to preserve the schedule.
                frac = 1.0 - (update - 1.0) / nupdates

                # anneal epsilon
                if args.anneal_eps:
                    ppo.eps_clip = ppo.eps_clip_initial * frac

                # anneal actor learning rate
                if args.anneal_lrs:
                    ppo.lrs = ppo.lrs_initial * frac

                # anneal entropy coef bin
                if args.anneal_ent_coef_bin:
                    ppo.ent_coef_bin = ppo.ent_coef_bin_initial * frac

                # anneal entropy coef cam: use the dedicated flag when the config
                # defines one, otherwise fall back to the bin flag as before
                if vars(args).get('anneal_ent_coef_cam', args.anneal_ent_coef_bin):
                    ppo.ent_coef_cam = ppo.ent_coef_cam_initial * frac

                print("Training Update: ", update, " Updates remaining: ", nupdates - update)
                if args.save_model_weights and update % 5 == 0:
                    torch.save(self.model.state_dict(), log_dir + '/checkpoint' + str(update))
                    if update % 20 == 0 and update != 0:
                        video_runner.run(self.model, env=env)
            except Exception as e:
                traceback.print_exc()
                print(f"RL-Training: Caught exception during training: {e}")

        return best_checkpoint if best_checkpoint_saved else None

    def train(self, args, env, model, dataset):
        """Train ``model`` with PPO.

        Args:
            args: configuration object; hyper-parameters, paths and flags are
                read from its attributes, and ``args.record_dir`` is set here.
            env: environment handed to the runner; ``env.reset()`` is called here.
            model: network to train; it is not moved to a device by this method.
            dataset: provides ``SEQ_LENGTH``, ``INPUT_SPACE``, ``ACTION_SPACE``
                and ``DATA_TRANSFORM``.

        Side effects:
            Writes TensorBoard logs, a copy of the config file and checkpoints
            to ``<args.log_dir>/<args.experiment_name>_<timestamp>``. If a best
            checkpoint was written during policy optimisation, its parameters
            are loaded into ``self.model`` before returning.
        """
        print("Starting with PPO training!")

        # resolve the device used by PPO and the runner
        device = get_device(args.gpu)
        self.model = model

        # some optimization hyper parameters
        betas = (0.9, 0.999)  # adam momentum parameters
        nbatch = args.num_env * args.nsteps  # size of the entire batch
        nbatch_train = nbatch // args.nminibatches  # size of the mini-batch
        nupdates = args.total_timesteps // nbatch  # number of policy updates
        nbiupdates = args.train_value_timesteps // nbatch  # number of value burn-in updates

        # initialize PPO
        print("Initializing PPO!")
        ppo = PPO(model=self.model, dataset=dataset, gamma=args.gamma, gae_lambda=args.lam, eps_clip=args.eps_clip,
                  pg_coef=args.pg_coef, vf_coef=args.vf_coef, ent_coef_bin=args.ent_coef_bin, ent_coef_cam=args.ent_coef_cam,
                  lrs=args.lrs, betas=betas, clamping=args.clamping, clip_grad_norm=args.clip_grad_norm,
                  bound_value_loss=args.bound_value_loss, value_clip=args.value_clip, lrs_critic_head=args.lrs_critic_head,
                  scale_critic_norm=args.scale_critic_norm, device=device, envname=args.env)

        # initialize runner
        print("Initializing Runner!")
        runner = Runner(model=self.model, dataset=dataset, nenv=args.num_env, env=env, nsteps=args.nsteps,
                        seq_len=dataset.SEQ_LENGTH, gamma=args.gamma, lam=args.lam, episode_length=args.episode_length,
                        reward_scale=args.reward_scale, episodic_interactions=args.episodic_interactions,
                        safe_exploration=args.safe_exploration, task_based_reward=args.task_based_reward,
                        task_to_check=args.task_to_check, num_tasks_to_check=args.num_tasks_to_check,
                        incremental=args.incremental, incremental_threshold=args.incremental_threshold,
                        resource_penalization_scale=args.resource_penalization_scale,
                        timestep_based_reward=args.timestep_based_reward,
                        step_counter_worker_envs_reset=args.step_counter_worker_envs_reset,
                        envname=args.env,
                        device=device)

        timestr = time.strftime("%Y%m%d-%H%M%S")
        writer = SummaryWriter(log_dir=args.log_dir + "/" + args.experiment_name + "_" + timestr)
        log_dir = writer.get_logdir()

        try:
            # log run settings and store the config in the run folder
            self._log_run_config(args, writer, log_dir)

            print("Initializing Video Runner!")
            video_runner = EvalEnv(env_name=args.env, input_space=dataset.INPUT_SPACE,
                                   action_space=dataset.ACTION_SPACE, transforms=dataset.DATA_TRANSFORM,
                                   seq_len=dataset.SEQ_LENGTH, trials=args.num_env, max_steps=6000,
                                   record_dir=log_dir, n_workers=args.num_env, watch_item=args.task_to_check,
                                   task_based_reward=args.task_based_reward, task_to_check=args.task_to_check)

            # burn in training of value function
            best_model = None
            if args.train_value_first:
                best_model = self._pretrain_value_function(args, ppo, runner, writer, log_dir, device,
                                                           nbatch, nbatch_train, nbiupdates)

            if args.use_best_value_function and best_model is not None:
                # Copy the weights in place: ppo and runner were constructed with this
                # model object, so rebinding self.model would leave them on a different
                # model than the one that is evaluated and checkpointed.
                self.model.load_state_dict(best_model.state_dict())

            # optimize policy and the value function together
            best_checkpoint = self._optimize_policy(args, env, ppo, runner, video_runner, writer, log_dir,
                                                    nbatch, nbatch_train, nupdates)

            # load previously best model parameters (only if a best checkpoint exists)
            if best_checkpoint is not None:
                best_params = torch.load(best_checkpoint)
                self.model.load_state_dict(best_params)
            else:
                print("RL-Training: No best checkpoint was written; keeping the final model parameters.")
        finally:
            # flush pending events and release the event file
            writer.close()


def train_model_with_config(config_file: str, env, model, dataset):
    """Train ``model`` with the settings in ``config_file`` and close ``env``.

    Args:
        config_file: path handed to ``Config``.
        env: environment used for training; it is closed before this
            function returns, including when config parsing or training raises.
        model: network passed on to ``ReinforcementLearning.train``.
        dataset: dataset object/module passed on to ``ReinforcementLearning.train``.
    """
    try:
        args = Config(config_file)
        rl = ReinforcementLearning()
        rl.train(args, env, model, dataset)
    finally:
        # release the environment even if something above failed
        env.close()


if __name__ == "__main__":
    # Script entry point: build the environment and the network from the
    # config, optionally load pre-trained parameters, then run PPO training.

    # parse config
    args = Config()

    # import model and dataset
    Network = import_module(args.model).Network
    dataset = import_module(args.dataset)

    if args.experience_restart_folder:

        env = make_minerl(
            env_id=args.env,
            n_cpu=args.num_env,
            seq_len=dataset.SEQ_LENGTH,
            transforms=dataset.DATA_TRANSFORM,
            input_space=dataset.INPUT_SPACE,
            env_server=True,
            seed=None,  # seed needs to be None when playing back a recording
            experience_recording=False,  # no in-memory experience here; recordings come from experience_folder
            replay_until=args.replay_until,  # use 'old_version' if you want the old ExperienceSamplingWrapper
            checkpoint=None,
            frame_skip=args.frame_skip,
            make_new_recording=False,  # do not record while playing back an experience
            experience_folder=args.exp_folder, craft_equip=True, port=args.port)

    elif args.experience_restart:

        exp = Experience.load("train/deterministic_envs/recordings/diamond4.p")
        env = make_minerl(args.env, n_cpu=args.num_env, seq_len=dataset.SEQ_LENGTH,
                          transforms=dataset.DATA_TRANSFORM, input_space=dataset.INPUT_SPACE,
                          env_server=True,
                          experience_recording=exp,
                          checkpoint=args.experience_startpoint, frame_skip=args.frame_skip, craft_equip=True)
    else:
        # instantiate environment
        env = make_minerl(args.env, port=args.port, n_cpu=args.num_env, seq_len=dataset.SEQ_LENGTH,
                          transforms=dataset.DATA_TRANSFORM, input_space=dataset.INPUT_SPACE,
                          env_server=True, frame_skip=args.frame_skip, craft_equip=True)

    # from here on the environment exists, so make sure it is always closed
    try:
        # initialize actor-critic model
        print("Initializing network!")
        model = Network()
        device = get_device(args.gpu)
        model = model.to(device)

        # load pre-trained model parameters
        if args.load_model:
            print("Loading behaviour cloned model!")
            # start from the fresh parameters and overwrite those present in the file
            state_dict = model.state_dict()
            state_dict.update(torch.load(args.pre_trained_params, map_location=device))
            model.load_state_dict(state_dict)

            # not useful to use action statistics and resource penalization at the same time
            # (explicit raise instead of assert so the check survives `python -O`)
            if args.use_action_statistics and args.resource_penalization_scale > 0:
                raise ValueError("use_action_statistics and resource_penalization_scale > 0 "
                                 "must not be combined")
            if args.use_action_statistics:
                path = args.action_statistics_path
                # compile dataset
                train_data = dataset.compile_dataset(root=path, index_file="train", subset=1)
                model.set_action_statistics(train_data.action_statistics)

        # start training
        rl = ReinforcementLearning()
        rl.train(args, env, model, dataset)
    finally:
        env.close()
