# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math

from torch._C import device
from mbrl.third_party.unrolled_actor_soft_critic.agent.unrolling import debug_Qvalues
from mbrl.util.replay_buffer import ReplayBuffer
from mbrl.algorithms.tensorboard_logger import TensorboardAdapter
from mbrl.third_party.unrolled_actor_soft_critic.agent.uasc_agent import UASCAgent
import os
import time
from pathlib import Path
from typing import Optional, Tuple, cast

import gym
import hydra.utils
import numpy as np
import omegaconf
import torch

import mbrl.constants
import mbrl.models
import mbrl.planning
import mbrl.third_party.unrolled_actor_soft_critic as uasc
import mbrl.types
import mbrl.util
import mbrl.util.common
import mbrl.util.math
from mbrl.third_party.unrolled_actor_soft_critic.buffer import CircularReplayBuffer

# Log format for the results group: the shared evaluation columns followed by
# two extra entries for the epoch index and the rollout length.
# NOTE(review): each entry looks like a (key, short label, type) triple -- confirm against the logger.
MBPO_LOG_FORMAT = mbrl.constants.EVAL_LOG_FORMAT + [
    ("epoch", "E", "int"),
    ("rollout_length", "RL", "int"),
]

class ShiftingGaussian(object):
    """Gaussian with continuous parameter updates, used to figure out how long the unrolling should be.

    The distribution is described by a running mean (`loc`), a running sum of
    squared deviations from the mean (`ssq`) and the number of samples those
    two statistics stand for (`num`). The defaults (`ssq=1`, `num=10000`) act
    as a prior that the first real samples only shift slowly.
    """

    def __init__(self, loc=0., ssq=1., num=10000):
        self.loc = loc
        self.ssq = ssq
        self.num = num

    def std(self):
        """Return the sample standard deviation, `sqrt(ssq / (num - 1))`.

        `num` must not be 1, otherwise this divides by zero.
        """
        return (self.ssq / (self.num - 1)) ** .5

    def zvalue(self, sample):
        """Return the z-score(s) of `sample` under the current mean and standard deviation."""
        return (sample - self.loc) / self.std()

    def to_steps(self, sample, stddev_span=1, max_steps=4):
        """Map sample(s) to an integer number of unrolling steps in `[0, max_steps]`.

        The z-score of each sample is mapped linearly:

        * z <= -stddev_span  ->  `max_steps` steps
        * z >= +stddev_span  ->  0 steps
        * in between         ->  linear interpolation, rounded to the nearest integer

        Args:
            sample: scalar, array-like or torch tensor of values to score.
            stddev_span: half-width (in standard deviations) of the interpolation
                window. Must be non-zero.
            max_steps: largest number of steps that can be returned.

        Returns:
            NumPy int32 value(s) with the same shape as `sample`.
        """
        zv = self.zvalue(sample)
        if torch.is_tensor(zv):
            # Leave autograd and the accelerator behind; only the values matter here.
            zv = zv.detach().cpu().numpy()
        zv = np.array(zv)

        top = float(max_steps)
        steps = top - top * (zv + stddev_span) / (2 * stddev_span)
        steps = np.rint(steps).astype(np.int32)
        # Clamp with *integer* bounds: clamping against a float bound would
        # silently promote the result back to a floating-point dtype.
        steps = np.minimum(np.maximum(steps, 0), int(max_steps))
        return steps.astype(np.int32)

    def update(self, sample):
        """Update the distribution parameters with the samples. Uses Welford's method.

        Args:
            sample: iterable of values; each element is folded into the running
                statistics one at a time.
        """
        for x in sample:
            old_loc = self.loc
            self.num += 1
            self.loc += (x - self.loc) / self.num
            # Welford: uses the deviation from both the new and the old mean.
            self.ssq += (x - self.loc) * (x - old_loc)

def rollout_model_and_populate_imagined_buffer(
        model_env: mbrl.models.ModelEnv, replay_buffer: mbrl.util.ReplayBuffer, agent: UASCAgent,
        imagined_buffer: uasc.CircularReplayBuffer, rollout_samples_action: bool, rollout_horizon: int, batch_size: int):
    """Roll the learned model forward from sampled real observations and store the imagined transitions.

    A batch of `batch_size` transitions is sampled from `replay_buffer`; their
    observations seed `model_env`. For `rollout_horizon` steps the agent picks
    actions, the model predicts the outcome, and every transition whose
    trajectory had not already terminated on an earlier step is appended to
    `imagined_buffer`.

    Args:
        model_env: model-based environment that is reset and stepped.
        replay_buffer: source of the starting observations.
        agent: policy queried (batched) for actions.
        imagined_buffer: destination for the generated transitions.
        rollout_samples_action: forwarded to `agent.act` as `sample`.
        rollout_horizon: number of model steps to take.
        batch_size: number of starting observations to sample.
    """
    start_batch = cast(mbrl.types.TransitionBatch, replay_buffer.sample(batch_size))
    start_obs = start_batch.astuple()[0]
    current_obs = model_env.reset(initial_obs_batch=cast(np.ndarray, start_obs), return_as_np=True)

    # True for every rollout that has already produced a `done` on an earlier step.
    finished = np.zeros(current_obs.shape[0], dtype=bool)
    for _ in range(rollout_horizon):
        chosen_action = agent.act(current_obs, sample=rollout_samples_action, batched=True)
        model_next_obs, model_rewards, model_dones, _ = model_env.step(chosen_action, sample=True)

        # Only rollouts that were still running before this step get stored.
        alive = ~finished
        imagined_buffer.add_batch(
            current_obs[alive],
            chosen_action[alive],
            model_rewards[alive],
            model_next_obs[alive],
            model_dones[alive],
            model_dones[alive],  # the predicted done flag is passed for both trailing arguments
        )

        current_obs = model_next_obs
        finished |= model_dones.squeeze()


def evaluate(env: gym.Env, agent: uasc.Agent, num_episodes: int, video_recorder: uasc.VideoRecorder) -> list:
    """Run `num_episodes` full episodes with `agent` and return each episode's total reward.

    Video recording is enabled for the first episode only; `video_recorder.record`
    is still called after every step and is expected to ignore the call when disabled.

    Args:
        env: environment to evaluate in; reset at the start of every episode.
        agent: policy queried with `agent.act(obs)` (default arguments).
        num_episodes: number of episodes to run.
        video_recorder: recorder that receives the environment after every step.

    Returns:
        A list with one entry per episode holding the sum of rewards of that
        episode (not a single float: callers aggregate the list themselves).
    """
    episode_rewards = []
    for episode in range(num_episodes):
        sum_reward = 0.
        obs = env.reset()
        video_recorder.init(enabled=(episode == 0))
        done = False
        while not done:
            action = agent.act(obs)
            obs, reward, done, _ = env.step(action)
            video_recorder.record(env)
            sum_reward += reward
        episode_rewards.append(sum_reward)
    return episode_rewards


def train(env: gym.Env, test_env: gym.Env, termination_fn: mbrl.types.TermFnType, cfg: omegaconf.DictConfig,
    silent: bool = False, work_dir: Optional[os.PathLike] = None) -> np.float32:
    """Train the agent on imagined transitions produced by a learned dynamics model.

    Real transitions are collected from `env` and stored in a replay buffer; a
    dynamics model is periodically trained on them and rolled out to fill an
    imagined buffer; the agent is updated from the imagined buffer; and once
    per epoch the agent is evaluated on `test_env`.

    Args:
        env: environment used for data collection.
        test_env: environment used for the per-epoch evaluation.
        termination_fn: termination function handed to the model environment.
        cfg: configuration; this function reads `cfg.seed`, `cfg.device`,
            `cfg.save_video`, `cfg.log_frequency_agent`, the optional
            `debug_mode`, and entries under `cfg.algorithm` and `cfg.overrides`.
        silent: when True the model trainer gets no logger and the periodic
            logger dump after agent updates is skipped. Evaluation results are
            logged regardless.
        work_dir: output directory; defaults to the current working directory.
            The config (`config.yml`), the best and the final actor/critic
            weights and their result summaries are written here.

    Returns:
        The best value of `mean - std` of the evaluation episode rewards seen
        over all epochs, as `np.float32` (`-inf` if no evaluation ran).
    """
    # ------------------- Initialization -------------------
    debug_mode = cfg.get("debug_mode", False)

    obs_shape = env.observation_space.shape
    act_shape = env.action_space.shape

    mbrl.planning.complete_agent_cfg(env, cfg.algorithm.agent)
    agent : UASCAgent = hydra.utils.instantiate(cfg.algorithm.agent)

    work_dir = Path(work_dir or os.getcwd())
    omegaconf.OmegaConf.save(config=cfg, f=(work_dir / "config.yml"))

    # We inherit enable_back_compatible=True from sac_agent:
    logger = TensorboardAdapter(work_dir)
    logger.register_group(mbrl.constants.RESULTS_LOG_NAME, MBPO_LOG_FORMAT, color="green", dump_frequency=1)
    video_recorder = uasc.VideoRecorder(work_dir if cfg.save_video else None)

    # Both random sources are driven by cfg.seed (the torch generator only when a seed is given).
    rng = np.random.default_rng(seed=cfg.seed)
    torch_generator = torch.Generator(device=cfg.device)
    if cfg.seed is not None:
        torch_generator.manual_seed(cfg.seed)

    # -------------- Create initial env dataset --------------
    dynamics_model = mbrl.util.common.create_one_dim_tr_model(cfg, obs_shape, act_shape)

    # Stores real transitions from the environment, used to train the model:
    replay_buffer = mbrl.util.common.create_replay_buffer(cfg, obs_shape, act_shape, rng=rng)
    # Initial exploration: either a random agent or the (untrained) agent acting with sampling enabled.
    mbrl.util.common.rollout_agent_trajectories(env,
        cfg.algorithm.initial_exploration_steps, mbrl.planning.RandomAgent(env) if cfg.algorithm.random_initial_explore else agent,
        {} if cfg.algorithm.random_initial_explore else {"sample": True, "batched": False}, replay_buffer=replay_buffer)

    # --------------------- Training Loop ---------------------
    # NOTE(review): this reads `cfg.algorithm.freq_train_model` while every other use below reads
    # `cfg.overrides.freq_train_model` -- confirm both resolve to the same value.
    rollout_batch_size = (cfg.overrides.effective_model_rollouts_per_step * cfg.algorithm.freq_train_model)
    trains_per_epoch = int(np.ceil(cfg.overrides.epoch_length / cfg.overrides.freq_train_model))
    updates_made = 0
    env_steps = 0
    # model_env is the Gym environment based on the learned dynamics:
    model_env = mbrl.models.ModelEnv(env, dynamics_model, termination_fn, None, generator=torch_generator)
    # This trains the model from environment data:
    model_trainer = mbrl.models.ModelTrainer(dynamics_model, optim_lr=cfg.overrides.model_lr, weight_decay=cfg.overrides.model_wd, logger=None if silent else logger)
    best_eval_reward = -np.inf
    epoch = 0

    # Holds the model-generated transitions the agent is trained on; resized at the start of every epoch.
    imagined_buffer = CircularReplayBuffer(obs_shape, act_shape, device=torch.device(cfg.device))

    # Number of steps to unroll, modeled as a shifting gaussian:
    # This keeps track of the distribution of previously-seen log-PDF
    # We use the first encountered log-PDF to figure out how many steps to unroll.
    steps_unrolling = ShiftingGaussian()

    # Prepare the states for which we want to log Q-values:
    # (`log_q_states` is only defined when `from_states` is set; it is only read under the same condition below.)
    if cfg.algorithm.log_q_values.from_states:
        # Load the states from the file
        log_q_rp_load = ReplayBuffer(replay_buffer.capacity, obs_shape, act_shape, rng=None, max_trajectory_length=0)
        log_q_rp_load.load(cfg.algorithm.log_q_values.from_states)
        log_q_states = log_q_rp_load.obs
        del log_q_rp_load

        # NOTE(review): this shuffles in place with NumPy's global RNG rather than `rng`, so the
        # selection is not controlled by cfg.seed unless the global seed is set elsewhere.
        # NOTE(review): presumably `obs` spans the whole buffer capacity; verify that unfilled rows
        # cannot be selected here.
        np.random.shuffle(log_q_states)
        log_q_states = log_q_states[:min(len(log_q_states), cfg.algorithm.log_q_values.num_elements),...]
        print(f"Logging Q values for {len(log_q_states)} states every {cfg.algorithm.log_q_values.every} steps.")

    # The training loop is structured a little strangely, so here's a description of how it works.
    #
    # Training happens for a number of environment steps (each step is a single transition in the true environment),
    # and is grouped into epochs (each epoch is `cfg.overrides.epoch_length` ~ 200 env steps, without regard to
    # the length of actual trajectories.)
    #
    # Within each epoch, there are four steps:
    #  1. Advance the true environment using the current agent.
    #    a. Store the new transition in `replay_buffer`, evicting the oldest data if full.
    #  2. Update the imagined environment using data from the true environment
    #    a. This is done only once every `cfg.overrides.freq_train_model` ~ 200 env steps.
    #    b. First the dynamics model is trained from `replay_buffer`, then it is used to generate new data for `imagined_buffer`.
    #  3. Update the agent using the imagined environment.
    #    a. This is done every `cfg.overrides.agent_updates_every_steps` ~ 1 env steps
    #    b. Each time it is done, it is repeated `cfg.overrides.num_agent_updates_per_step` ~ 20 times
    #    c. Here's where the unrolling magic happens.
    #  4. Log the performance every `cfg.overrides.epoch_length` ~ 200 steps, and save the best video and weights.

    while env_steps < cfg.overrides.num_steps:
        # Resize the imagined transition buffer if necessary. Copying is handled by the resize method.
        # The rollout length follows `cfg.overrides.rollout_schedule`, evaluated at the 1-based epoch index.
        env_rollout_length = int(mbrl.util.math.truncated_linear(*(cfg.overrides.rollout_schedule + [epoch + 1])))
        imagined_buffer.resize(env_rollout_length * rollout_batch_size * trains_per_epoch * cfg.overrides.num_epochs_to_retain_imagined_buffer)

        if debug_mode:
            print(f"Epoch: {epoch}.\tSAC buffer size: {len(imagined_buffer)}.\tRollout length: {env_rollout_length}. Steps: {env_steps}")

        # Starting with done=True forces an environment reset at the beginning of every epoch,
        # even if the previous epoch ended in the middle of an episode.
        obs, done = None, True
        for steps_epoch in range(cfg.overrides.epoch_length):
            if done:
                obs, done = env.reset(), False
            # --- Doing env step and adding to model dataset ---
            next_obs, _, done, _ = mbrl.util.common.step_env_and_add_to_buffer(env, obs, agent, {}, replay_buffer)

            # --------------- Model Training -----------------
            if (env_steps + 1) % cfg.overrides.freq_train_model == 0:
                # Update the dynamics model, then use that to imagine new trajectories:
                mbrl.util.common.train_model_and_save_model_and_data(dynamics_model, model_trainer, cfg.overrides, replay_buffer, work_dir=work_dir)
                rollout_model_and_populate_imagined_buffer(model_env, replay_buffer, agent, imagined_buffer,
                    cfg.algorithm.rollout_samples_action, env_rollout_length, rollout_batch_size)

            # --------------- Agent Training -----------------
            for _ in range(cfg.overrides.num_agent_updates_per_step):
                # Skip the whole batch of updates when this is not an update step or when the
                # imagined buffer holds fewer than `rollout_batch_size` transitions.
                if (env_steps + 1) % cfg.overrides.agent_updates_every_steps != 0 or len(imagined_buffer) < rollout_batch_size:
                    break  # only update every once in a while
                agent.update(cfg.algorithm.agent, imagined_buffer, dynamics_model, logger, updates_made, rng=rng, device=torch.device(cfg.device), steps_unrolling=steps_unrolling)
                updates_made += 1
                if not silent and updates_made % cfg.log_frequency_agent == 0:
                    logger.dump(updates_made, save=True)

            # ------ Epoch ended (evaluate and save model) ------
            if (env_steps + 1) % cfg.overrides.epoch_length == 0:
                episode_rewards = evaluate(test_env, agent, cfg.algorithm.num_eval_episodes, video_recorder)

                # Also log raw episode rewards
                logger.log_aux("episode_rewards", [env_steps] + episode_rewards, flush=True)

                reward_mean, reward_std = np.mean(episode_rewards), np.std(episode_rewards)
                # NOTE(review): the rollout length is stored under "env_rollout_length" while
                # MBPO_LOG_FORMAT declares a "rollout_length" column -- confirm the logger matches them up.
                results = {"epoch": epoch, "env_step": env_steps, "episode_reward": reward_mean, "episode_reward_std": reward_std, "env_rollout_length": env_rollout_length}
                logger.log_data(mbrl.constants.RESULTS_LOG_NAME, results, step=env_steps)
                logger.log_distrib("episode_reward_dist", episode_rewards, step=env_steps)

                # Model selection uses mean minus one standard deviation of the evaluation rewards,
                # so a checkpoint has to be both good and consistent to replace the saved one.
                if reward_mean - reward_std > best_eval_reward:
                    video_recorder.save(f"{epoch}.mp4")
                    best_eval_reward = reward_mean - reward_std
                    (work_dir / "saved_version.txt").write_text(str(results))
                    torch.save(agent.critic.state_dict(), work_dir / "critic.pth")
                    torch.save(agent.actor.state_dict(), work_dir / "actor.pth")

                # If we have reached the end, then save the agent.
                if (env_steps + 1) >= cfg.overrides.num_steps:
                    video_recorder.save(f"ending_{epoch}.mp4")
                    (work_dir / "ending_results.txt").write_text(str(results))
                    torch.save(agent.critic.state_dict(), work_dir / "ending_critic.pth")
                    torch.save(agent.actor.state_dict(), work_dir / "ending_actor.pth")

                epoch += 1

            # Trace q values as training progresses:
            if cfg.algorithm.log_q_values.from_states:
                # Unlike the checks above, this one tests `env_steps` (not `env_steps + 1`), so it also fires on step 0.
                if env_steps % cfg.algorithm.log_q_values.every == 0:
                    # Log Q-values
                    debug_V0, debug_VH = debug_Qvalues(cfg.algorithm.agent, dynamics_model, agent, log_q_states, max_rollout_length=cfg.algorithm.log_q_values.unroll, device=cfg.device)
                    # One auxiliary row per entry along the first axis of debug_VH, plus one "simple" row for debug_V0.
                    for k in range(len(debug_VH)):
                        logger.log_aux("debug_Q", [env_steps, k] + debug_VH[k,:].tolist(), flush=False)
                    logger.log_aux("debug_Q", [env_steps, "simple"] + debug_V0.ravel().tolist(), flush=True)

            env_steps += 1
            obs = next_obs
    return np.float32(best_eval_reward)
