"""
Calibrated Q-Learning (Cal-QL) agent training script (offline training with optional online fine-tuning).

Does not support image observations right now.
"""

import os
import pickle
import numpy as np
import torch
import logging
import wandb
import hydra
from collections import deque

log = logging.getLogger(__name__)
from util.timer import Timer
from agent.finetune.train_agent import TrainAgent
from util.scheduler import CosineAnnealingWarmupRestarts


class TrainCalQLAgent(TrainAgent):
    """Cal-QL (Calibrated Q-Learning) trainer for state-based policies.

    Trains from an offline dataset and, when ``cfg.train.train_online`` is set,
    fine-tunes online: each train iteration collects ``n_episode_per_epoch``
    episodes into bounded FIFO replay buffers, and every update batch then
    mixes half offline and half online transitions. Only a single environment
    (``n_envs == 1``) is supported.
    """

    def __init__(self, cfg):
        super().__init__(cfg)
        assert self.n_envs == 1, "Cal-QL only supports single env for now"

        # Train mode (offline or online)
        self.train_online = cfg.train.train_online

        # Build offline dataset
        self.dataset_offline = hydra.utils.instantiate(cfg.offline_dataset)

        # note the discount factor gamma here is applied to reward every act_steps, instead of every env step
        self.gamma = cfg.train.gamma

        # Optimizers, each with a cosine-annealing-with-warmup LR schedule
        self.actor_optimizer = torch.optim.AdamW(
            self.model.network.parameters(),
            lr=cfg.train.actor_lr,
            weight_decay=cfg.train.actor_weight_decay,
        )
        self.actor_lr_scheduler = CosineAnnealingWarmupRestarts(
            self.actor_optimizer,
            first_cycle_steps=cfg.train.actor_lr_scheduler.first_cycle_steps,
            cycle_mult=1.0,
            max_lr=cfg.train.actor_lr,
            min_lr=cfg.train.actor_lr_scheduler.min_lr,
            warmup_steps=cfg.train.actor_lr_scheduler.warmup_steps,
            gamma=1.0,
        )
        self.critic_optimizer = torch.optim.AdamW(
            self.model.critic.parameters(),
            lr=cfg.train.critic_lr,
            weight_decay=cfg.train.critic_weight_decay,
        )
        self.critic_lr_scheduler = CosineAnnealingWarmupRestarts(
            self.critic_optimizer,
            first_cycle_steps=cfg.train.critic_lr_scheduler.first_cycle_steps,
            cycle_mult=1.0,
            max_lr=cfg.train.critic_lr,
            min_lr=cfg.train.critic_lr_scheduler.min_lr,
            warmup_steps=cfg.train.critic_lr_scheduler.warmup_steps,
            gamma=1.0,
        )

        # EMA rate passed to model.update_target_critic after every critic step
        self.target_ema_rate = cfg.train.target_ema_rate

        # Number of uniformly random actions per state fed to the Cal-QL critic loss
        self.n_random_actions = cfg.train.n_random_actions

        # Multiplier applied to online rewards (and their reward-to-go) before storage
        self.scale_reward_factor = cfg.train.scale_reward_factor

        # Number of gradient updates per train iteration in offline mode
        self.num_update = cfg.train.num_update

        # Capacity of each online FIFO replay buffer
        self.buffer_size = cfg.train.buffer_size

        # Online only configs
        if self.train_online:
            # number of episodes to collect per train iteration
            self.n_episode_per_epoch = cfg.train.n_episode_per_epoch

            # Update-to-data ratio: gradient updates per collected env step
            # (one env step = one multi-step action chunk)
            self.online_utd_ratio = cfg.train.online_utd_ratio

        # Number of episodes rolled out in each eval iteration
        self.n_eval_episode = cfg.train.n_eval_episode

        # Number of initial iterations that act with randomly sampled actions
        # and perform no gradient updates or evals
        self.n_explore_steps = cfg.train.n_explore_steps

        # Learnable log of the entropy temperature alpha
        init_temperature = cfg.train.init_temperature
        self.log_alpha = torch.tensor(np.log(init_temperature)).to(self.device)
        self.log_alpha.requires_grad = True
        self.automatic_entropy_tuning = cfg.train.automatic_entropy_tuning
        self.target_entropy = cfg.train.target_entropy
        # NOTE: the temperature optimizer reuses the critic learning rate
        self.log_alpha_optimizer = torch.optim.Adam(
            [self.log_alpha],
            lr=cfg.train.critic_lr,
        )

    def _load_offline_data(self):
        """Load the whole offline dataset into numpy arrays.

        Returns:
            dict with keys "obs", "next_obs", "actions", "rewards",
            "terminated" and "reward_to_go"; "rewards", "terminated" and
            "reward_to_go" are flattened to 1D.

        Raises:
            ValueError: if the offline dataset is empty.
        """
        n_offline = len(self.dataset_offline)
        if n_offline == 0:
            raise ValueError("Offline dataset is empty; Cal-QL needs offline data")
        dataloader_offline = torch.utils.data.DataLoader(
            self.dataset_offline,
            batch_size=n_offline,
            drop_last=False,
        )
        # batch_size equals the dataset length, so the loader yields exactly one batch
        actions, states_and_next, rewards, terminated, reward_to_go = next(
            iter(dataloader_offline)
        )
        return {
            "obs": states_and_next["state"].cpu().numpy(),
            "next_obs": states_and_next["next_state"].cpu().numpy(),
            "actions": actions.cpu().numpy(),
            "rewards": rewards.cpu().numpy().flatten(),
            "terminated": terminated.cpu().numpy().flatten(),
            "reward_to_go": reward_to_go.cpu().numpy().flatten(),
        }

    def _sample_batch(self, data, n):
        """Sample ``n`` transitions uniformly with replacement from ``data``.

        Args:
            data: dict of equally long numpy arrays (same keys as
                ``_load_offline_data`` returns).
            n: number of transitions to draw.

        Returns:
            dict with the same keys whose values are float32 tensors on
            ``self.device``.
        """
        inds = np.random.choice(len(data["obs"]), n)
        return {
            key: torch.from_numpy(array[inds]).float().to(self.device)
            for key, array in data.items()
        }

    @staticmethod
    def _split_finished_episodes(firsts_trajs, reward_trajs):
        """Return the per-step rewards of every episode finished in this iteration.

        Args:
            firsts_trajs: (n_steps + 1, n_envs) array. Row 0 flags whether the
                iteration began at an episode start and row k + 1 is the done
                flag of step k, so a 1 at row k marks step k as the first step
                of an episode.
            reward_trajs: (n_steps, n_envs) array of per-step rewards.

        Returns:
            list of 1D reward arrays, ordered by env and then by time. An
            episode is included only if both its first step and its final done
            flag fall inside this iteration. Single-step episodes are included
            too, so every stored train transition of a finished episode gets a
            reward-to-go entry.
        """
        splits = []
        for env_ind in range(firsts_trajs.shape[1]):
            boundaries = np.flatnonzero(firsts_trajs[:, env_ind] == 1)
            for start, end in zip(boundaries[:-1], boundaries[1:]):
                splits.append(reward_trajs[start:end, env_ind])
        return splits

    @staticmethod
    def _discounted_returns(rewards, gamma):
        """Return the discounted reward-to-go of one episode.

        ``returns[t] = rewards[t] + gamma * returns[t + 1]``, with a value of 0
        after the last step (no bootstrapping).
        """
        returns = np.zeros_like(rewards)
        next_return = 0
        for t in reversed(range(len(rewards))):
            returns[t] = rewards[t] + gamma * next_return
            next_return = returns[t]
        return returns

    def run(self):
        """Run the Cal-QL training loop until ``self.n_train_itr`` iterations.

        An iteration is an eval iteration every ``val_freq`` iterations once
        ``itr >= n_explore_steps`` (unless ``force_train``). Eval iterations roll
        out ``n_eval_episode`` episodes. In online mode a train iteration
        first collects ``n_episode_per_epoch`` episodes into the replay
        buffers; in offline mode it collects nothing. After the exploration
        phase, train iterations run critic / actor / temperature updates.
        Metrics are pickled to ``self.result_path`` on logging iterations.
        """
        # FIFO replay buffers for online transitions. Entry k of every buffer
        # describes the same transition, so they must always grow together.
        obs_buffer = deque(maxlen=self.buffer_size)
        next_obs_buffer = deque(maxlen=self.buffer_size)
        action_buffer = deque(maxlen=self.buffer_size)
        reward_buffer = deque(maxlen=self.buffer_size)
        reward_to_go_buffer = deque(maxlen=self.buffer_size)
        terminated_buffer = deque(maxlen=self.buffer_size)

        # Whole offline dataset as numpy arrays (fixed for the entire run)
        offline_data = self._load_offline_data()

        # Start training loop
        timer = Timer()
        run_results = []
        done_venv = np.zeros((1, self.n_envs))
        # Latest losses / temperature, carried over iterations without updates.
        # Initialized so that logging before the first update cannot NameError.
        loss_actor = loss_critic = float("nan")
        alpha = self.log_alpha.exp().item()
        while self.itr < self.n_train_itr:
            if self.itr % 1000 == 0:
                print(f"Finished training iteration {self.itr} of {self.n_train_itr}")

            # Prepare video paths for each envs --- only applies for the first set of episodes if allowing reset within iteration and each iteration has multiple episodes from one env
            options_venv = [{} for _ in range(self.n_envs)]
            if self.itr % self.render_freq == 0 and self.render_video:
                for env_ind in range(self.n_render):
                    options_venv[env_ind]["video_path"] = os.path.join(
                        self.render_dir, f"itr-{self.itr}_trial-{env_ind}.mp4"
                    )

            # Define train or eval - all envs restart
            eval_mode = (
                self.itr % self.val_freq == 0
                and self.itr >= self.n_explore_steps
                and not self.force_train
            )
            # Eval and online-train iterations step until enough episodes have
            # finished, so n_steps is only a large upper bound; offline-train
            # iterations do not interact with the environment.
            n_steps = int(1e5) if (eval_mode or self.train_online) else 0
            if eval_mode:
                self.model.eval()
            else:
                self.model.train()

            # Reset env before iteration starts (1) if specified, (2) at eval mode, or (3) at the beginning
            firsts_trajs = np.empty((0, self.n_envs))
            if self.reset_at_iteration or eval_mode or self.itr == 0:
                prev_obs_venv = self.reset_env_all(options_venv=options_venv)
                firsts_trajs = np.vstack((firsts_trajs, np.ones((1, self.n_envs))))
            else:
                # if done at the end of last iteration, then the envs are just reset
                firsts_trajs = np.vstack((firsts_trajs, done_venv))
            reward_trajs = np.empty((0, self.n_envs))

            # Collect a set of trajectories from env
            cnt_episode = 0
            for env_step in range(n_steps):
                if env_step % 100 == 0:
                    print(f"Completed environment step {env_step}")

                # Select action (random during the exploration iterations)
                if self.itr < self.n_explore_steps:
                    action_venv = self.venv.action_space.sample()
                else:
                    with torch.no_grad():
                        cond = {
                            "state": torch.from_numpy(prev_obs_venv["state"])
                            .float()
                            .to(self.device)
                        }
                        samples = (
                            self.model(
                                cond=cond,
                                deterministic=eval_mode,
                            )
                            .cpu()
                            .numpy()
                        )  # n_env x horizon x act
                    action_venv = samples[:, : self.act_steps]

                # Apply multi-step action
                obs_venv, reward_venv, done_venv, info_venv = self.venv.step(
                    action_venv
                )
                reward_trajs = np.vstack((reward_trajs, reward_venv))
                firsts_trajs = np.vstack((firsts_trajs, done_venv))
                terminated_venv = done_venv.copy()

                # add to buffer in train mode
                if not eval_mode:
                    for i in range(self.n_envs):
                        obs_buffer.append(prev_obs_venv["state"][i])
                        if "final_obs" in info_venv[i]:
                            # Episode ended: presumably obs_venv already holds the
                            # next episode's first obs, so store the final obs.
                            # NOTE(review): terminated is forced False, i.e. such
                            # transitions are treated as truncations and still
                            # bootstrapped; confirm final_obs is only reported on
                            # truncation.
                            next_obs_buffer.append(info_venv[i]["final_obs"]["state"])
                            terminated_venv[i] = False
                        else:
                            next_obs_buffer.append(obs_venv["state"][i])
                        action_buffer.append(action_venv[i])
                        reward_buffer.append(reward_venv[i] * self.scale_reward_factor)
                        terminated_buffer.append(terminated_venv[i])
                prev_obs_venv = obs_venv

                # check if enough eval / train episodes are done
                cnt_episode += np.sum(done_venv)
                if eval_mode and cnt_episode >= self.n_eval_episode:
                    break
                if not eval_mode and cnt_episode >= self.n_episode_per_epoch:
                    break

            # Summarize episode reward --- only count episodes that start and finish within the iteration.
            reward_trajs_split = self._split_finished_episodes(
                firsts_trajs, reward_trajs
            )
            if reward_trajs_split:
                if not eval_mode:
                    # One reward-to-go entry per stored train transition, in the
                    # same order, keeping reward_to_go_buffer aligned with the
                    # other buffers (note: only works for single env!). Eval
                    # rollouts are not stored, so their returns must not be
                    # either. Scaled like reward_buffer so the Cal-QL lower
                    # bound is in the same units as the critic's targets.
                    # NOTE(review): offline rewards are not scaled by
                    # scale_reward_factor here; confirm the dataset matches.
                    for traj_rewards in reward_trajs_split:
                        reward_to_go_buffer.extend(
                            self._discounted_returns(traj_rewards, self.gamma)
                            * self.scale_reward_factor
                        )

                num_episode_finished = len(reward_trajs_split)
                episode_reward = np.array(
                    [np.sum(traj_rewards) for traj_rewards in reward_trajs_split]
                )
                episode_best_reward = np.array(
                    [
                        np.max(traj_rewards) / self.act_steps
                        for traj_rewards in reward_trajs_split
                    ]
                )
                avg_episode_reward = np.mean(episode_reward)
                avg_best_reward = np.mean(episode_best_reward)
                success_rate = np.mean(
                    episode_best_reward >= self.best_reward_threshold_for_success
                )
            else:
                episode_reward = np.array([])
                num_episode_finished = 0
                avg_episode_reward = 0
                avg_best_reward = 0
                success_rate = 0

            # Update models
            if not eval_mode and self.itr >= self.n_explore_steps:
                if self.train_online:
                    # TODO: is this slow in online? (copies every buffer each iteration)
                    online_data = {
                        "obs": np.array(obs_buffer),
                        "next_obs": np.array(next_obs_buffer),
                        "actions": np.array(action_buffer),
                        "rewards": np.array(reward_buffer),
                        "terminated": np.array(terminated_buffer),
                        "reward_to_go": np.array(reward_to_go_buffer),
                    }
                    # updates = UTD ratio x env steps collected this iteration (assume one env!)
                    num_update = int(len(reward_trajs) * self.online_utd_ratio)
                    n_offline = self.batch_size // 2
                else:
                    num_update = self.num_update
                    n_offline = self.batch_size
                for _ in range(num_update):
                    # Sample from OFFLINE data (half the batch when training online)
                    batch = self._sample_batch(offline_data, n_offline)

                    # Sample from ONLINE buffer and merge with the offline half
                    if self.train_online:
                        online_batch = self._sample_batch(
                            online_data, self.batch_size // 2
                        )
                        batch = {
                            key: torch.cat([batch[key], online_batch[key]], dim=0)
                            for key in batch
                        }
                    # Actual batch size (2 * (batch_size // 2) online, which
                    # differs from batch_size when batch_size is odd)
                    n_batch = batch["obs"].shape[0]

                    # Uniform random actions in [-1, 1] for the Cal-QL regularizer
                    random_actions = (
                        torch.rand(
                            (
                                n_batch,
                                self.n_random_actions,
                                self.horizon_steps,
                                self.action_dim,
                            )
                        ).to(self.device)
                        * 2
                        - 1
                    )

                    # Update critic
                    alpha = self.log_alpha.exp().item()
                    loss_critic = self.model.loss_critic(
                        {"state": batch["obs"]},
                        {"state": batch["next_obs"]},
                        batch["actions"],
                        random_actions,
                        batch["rewards"],
                        batch["reward_to_go"],
                        batch["terminated"],
                        self.gamma,
                        alpha,
                    )
                    self.critic_optimizer.zero_grad()
                    loss_critic.backward()
                    self.critic_optimizer.step()

                    # Update target critic
                    self.model.update_target_critic(self.target_ema_rate)

                    # Update actor
                    loss_actor = self.model.loss_actor(
                        {"state": batch["obs"]},
                        alpha,
                    )
                    self.actor_optimizer.zero_grad()
                    loss_actor.backward()
                    self.actor_optimizer.step()

                    # Update temperature parameter
                    if self.automatic_entropy_tuning:
                        self.log_alpha_optimizer.zero_grad()
                        loss_alpha = self.model.loss_temperature(
                            {"state": batch["obs"]},
                            self.log_alpha.exp(),  # with grad
                            self.target_entropy,
                        )
                        loss_alpha.backward()
                        self.log_alpha_optimizer.step()

            # Update lr
            self.actor_lr_scheduler.step()
            self.critic_lr_scheduler.step()

            # Save model
            if self.itr % self.save_model_freq == 0 or self.itr == self.n_train_itr - 1:
                self.save_model()

            # Log loss and save metrics
            run_results.append({"itr": self.itr})
            if self.itr % self.log_freq == 0 and self.itr >= self.n_explore_steps:
                time = timer()
                if eval_mode:
                    log.info(
                        f"eval: success rate {success_rate:8.4f} | avg episode reward {avg_episode_reward:8.4f} | avg best reward {avg_best_reward:8.4f}"
                    )
                    if self.use_wandb:
                        wandb.log(
                            {
                                "success rate - eval": success_rate,
                                "avg episode reward - eval": avg_episode_reward,
                                "avg best reward - eval": avg_best_reward,
                                "num episode - eval": num_episode_finished,
                            },
                            step=self.itr,
                            commit=False,
                        )
                    run_results[-1]["eval_success_rate"] = success_rate
                    run_results[-1]["eval_episode_reward"] = avg_episode_reward
                    run_results[-1]["eval_best_reward"] = avg_best_reward
                else:
                    # Plain floats keep autograd/device tensors out of
                    # run_results and the pickled results file.
                    loss_actor_value = float(loss_actor)
                    loss_critic_value = float(loss_critic)
                    log.info(
                        f"{self.itr}: loss actor {loss_actor_value:8.4f} | loss critic {loss_critic_value:8.4f} | reward {avg_episode_reward:8.4f} | alpha {alpha:8.4f} | t:{time:8.4f}"
                    )
                    if self.use_wandb:
                        wandb.log(
                            {
                                "loss - actor": loss_actor_value,
                                "loss - critic": loss_critic_value,
                                "entropy coeff": alpha,
                                "avg episode reward - train": avg_episode_reward,
                                "num episode - train": num_episode_finished,
                            },
                            step=self.itr,
                            commit=True,
                        )
                    run_results[-1]["loss_actor"] = loss_actor_value
                    run_results[-1]["loss_critic"] = loss_critic_value
                    run_results[-1]["train_episode_reward"] = avg_episode_reward
                run_results[-1]["time"] = time
                with open(self.result_path, "wb") as f:
                    pickle.dump(run_results, f)
            self.itr += 1
