# pylint: disable=too-many-locals, too-many-statements, too-many-instance-attributes, too-many-arguments
"""Experiments wrapper for Training and evaluating algorithms in MDP and Monitor MDP """
import os
import gymnasium as gym
import numpy as np
from src.actor import Actor
from src.critic import Critic
from src.utils import set_rng_seed, cantor_pairing
from src.replay_buffer import TorchReplayMemory
from src.critic import MonRoomCNN


class Experiment:
    """Run experiments for training and testing in MDP env.

    Holds the environment, actor and critic and drives the training loop
    (``train``) and the evaluation loop (``test``).
    """

    def __init__(
        self,
        env: gym.Env,
        actor: Actor,
        critic: Critic,
        training_timesteps,
        testing_episodes,
        testing_frequency,
        rng_seed,
        log_dir: str,
        replay_buffer_size: int,
        save_log: bool = False,
        replay_buffer: bool = False,
        start_train_timestep: int = int(1e4),
        batch_size: int = 128,
        n_itr_episode: int = 5000,
        update_target_freq: int = 1000,
    ):
        """
        Args:
            env: environment to interact with.
            actor: policy, called as ``actor(obs)`` to pick an action.
            critic: learner updated from observed transitions.
            training_timesteps: total number of environment steps to train for.
            testing_episodes: number of episodes run by each call to ``test``.
            testing_frequency: evaluate each time the training step count passes
                another multiple of this value.
            rng_seed: base seed; per-episode seeds are derived from it.
            log_dir: directory where ``.npy`` logs are written.
            replay_buffer_size: capacity passed to ``TorchReplayMemory``.
            save_log: whether ``train`` saves the per-episode rewards.
            replay_buffer: create a ``TorchReplayMemory`` when True; otherwise
                ``self.buffer`` is None.
            start_train_timestep: buffer-based optimisation only starts once
                ``buffer.buffer_size`` exceeds this value.
            batch_size: number of samples per mini-batch.
            n_itr_episode: optimisation iterations per buffer-training phase.
            update_target_freq: the target-update flag is raised every this many
                optimisation iterations.
        """
        self._env = env
        self._actor = actor
        self._critic = critic
        self._training_timesteps = training_timesteps
        self._testing_episodes = testing_episodes
        self._testing_frequency = testing_frequency
        self._rng_seed = rng_seed
        self._log_dir = log_dir
        self._save_train_log = save_log
        self._checkpoint_count = 0  # number of evaluation/checkpoint thresholds passed so far
        self._start_train_timestep = start_train_timestep  # min buffer size before buffer-based optimisation
        self._batch_size = batch_size  # mini batch size, number of samples per batch
        self._n_itr_episode = n_itr_episode  # optimisation iterations per buffer-training phase
        self._update_target_freq = update_target_freq  # target update every N optimisation iterations
        self.buffer = TorchReplayMemory(max_size=int(replay_buffer_size)) if replay_buffer else None

    def train(self):
        """Train an algorithm in MDP env, logs and save results.

        Runs episodes until ``training_timesteps`` environment steps have been
        taken. The critic is updated after every step and the actor after every
        non-terminal step. ``test`` runs each time the step count passes another
        multiple of ``testing_frequency``. Afterwards the per-episode reward
        lists are optionally saved, the critic is saved and the env is closed.
        """
        set_rng_seed(self._rng_seed)
        self._actor.reset()
        self._critic.reset()
        joint_reward = {}
        total_timesteps = 0
        episode = 0
        while total_timesteps < self._training_timesteps:
            if total_timesteps > self._testing_frequency * self._checkpoint_count:
                self._actor.eval()
                self.test()
                self._actor.train()
                # Advance the threshold. Without this, every episode after the
                # first would trigger a full evaluation.
                self._checkpoint_count += 1

            ep_seed = cantor_pairing(self._rng_seed, episode)
            obs, _ = self._env.reset(seed=ep_seed)
            next_action = None
            ep_joint_reward = []
            episode_timesteps = 0
            while True:
                episode_timesteps += 1
                action = self._actor(obs) if next_action is None else next_action
                next_obs, reward, term, trunc, _ = self._env.step(action)
                if self._critic._on_policy:
                    # Pick the next action now so the update can use it;
                    # it is then reused as the action for the next step.
                    next_action = self._actor(next_obs)
                self._critic.update(obs, action, reward, term, next_obs, next_action)
                ep_joint_reward.append(reward)
                if term or trunc:
                    break
                obs = next_obs
                self._actor.update()
            # Key by episode index (not episode length) so episodes of equal
            # length do not overwrite each other.
            joint_reward[episode] = ep_joint_reward
            total_timesteps += episode_timesteps
            episode += 1
        if self._save_train_log:
            np.save(self._log_dir + "/training_joint_reward_{}.npy".format(self._rng_seed), joint_reward)
        # persist the critic
        self._critic.save(seed=self._rng_seed)
        self._env.close()

    def test(self, render: bool = False) -> dict:
        """Evaluate the current actor in MDP env.

        Args:
            render: call ``env.render()`` before every step.

        Returns:
            dict mapping episode index to a numpy array of that episode's
            per-step rewards.
        """
        test_return = {}
        episode = 0
        while episode < self._testing_episodes:
            episode_return = []
            ep_seed = cantor_pairing(self._rng_seed, episode)
            obs, _ = self._env.reset(seed=ep_seed)
            while True:
                if render:
                    self._env.render()
                action = self._actor(obs)
                next_obs, reward, term, trunc, _ = self._env.step(action)
                episode_return.append(reward)
                if term or trunc:
                    break
                obs = next_obs
            test_return[episode] = np.array(episode_return)
            episode += 1
        return test_return


class MonExperiment(Experiment):
    """Run experiments for training and testing in Monitor MDP env.

    Actions and rewards are dicts with ``"mdp"`` and ``"monitor"`` entries.
    The environment's true MDP reward is read from ``info["mdp_reward"]``,
    while ``reward["mdp"]`` is the reward the agent receives and may be NaN
    (counted as "no reward received").
    """

    def train(self, checkpoint: bool = True):
        """Train an algorithm in Monitor MDP env, logs and save results.

        Runs episodes until ``training_timesteps`` environment steps have been
        taken. Each time the step count passes another multiple of
        ``testing_frequency``, the statistics gathered so far are checkpointed
        and an evaluation is run.

        Without a replay buffer, the critic is updated online after every step.
        With one, transitions are pushed to the buffer. Every 20 episodes, once
        the buffer is large enough, the critic is optimised on
        ``n_itr_episode`` sampled mini-batches.

        Args:
            checkpoint: unused; checkpoints are always taken. Kept for
                backward compatibility.
        """
        set_rng_seed(self._rng_seed)
        self._actor.reset()
        self._critic.reset()
        agent_locations = []
        joint_reward = {}
        eval_joint_reward = {}
        total_timesteps = 0
        episode = 0
        # Queried once; assumes the critic's device is fixed after reset() -- TODO confirm.
        current_device = self._critic.get_device()
        monitor_state = isinstance(self._critic, MonRoomCNN)
        while total_timesteps < self._training_timesteps:
            if total_timesteps > self._testing_frequency * self._checkpoint_count:
                # Flush the statistics gathered since the last checkpoint, then start fresh.
                self.checkpoint(joint_reward, eval_joint_reward, agent_locations)
                joint_reward = {}
                agent_locations = []
                self._evaluate(episode, eval_joint_reward)

            ep_seed = cantor_pairing(self._rng_seed, episode)
            obs, _ = self._env.reset(seed=ep_seed)
            episode_return_true = 0.0
            episode_return_proxy = 0.0
            episode_return_cost = 0.0
            episode_reward_model_loss = 0.0
            reward_seen = False
            episode_timesteps = 0
            next_action = None
            agent_location_ep = []
            ep_joint_reward = []
            while True:
                episode_timesteps += 1
                agent_location_ep.append(self._env.get_agent_pos())
                action = self._actor(obs) if next_action is None else next_action

                next_obs, reward, term, trunc, info = self._env.step(action)
                if self.buffer is not None:
                    # On termination a placeholder next state is stored instead of next_obs
                    # (presumably so no bootstrapping happens); truncation keeps next_obs.
                    self.buffer.push(obs, action, {"mdp": None, "monitor": None} if term else next_obs, reward)

                if self._critic._on_policy:
                    # Pick the next action now so the update can use it; it is reused to act.
                    next_action = self._actor(next_obs)
                if self.buffer is None:
                    # No replay buffer: learn online from every transition.
                    self._critic.update(obs, action, reward, term, next_obs, next_action)

                episode_return_true += info["mdp_reward"]
                episode_return_cost += reward["monitor"]
                if not np.isnan(reward["mdp"]):
                    reward_seen = True
                    episode_return_proxy += reward["mdp"]

                ep_joint_reward.append(info["mdp_reward"] + reward["monitor"])
                self._actor.update()
                if term or trunc:
                    if not reward_seen:
                        # Distinguish "never received a reward" from "rewards summed to 0".
                        episode_return_proxy = np.nan
                    break

                obs = next_obs

            # Update Q-network and reward network from the replay buffer.
            # The None check matters: with replay_buffer=False, self.buffer is None.
            if (
                self.buffer is not None
                and self.buffer.buffer_size > self._start_train_timestep
                and episode % 20 == 0
            ):
                for epoch in range(self._n_itr_episode):
                    batch = self.buffer.process_batch(self.buffer.sample(self._batch_size), device=current_device)
                    _, r_model_loss = self._critic.optimize_policy_model(
                        batch,
                        # True every update_target_freq iterations (presumably the target-update flag)
                        epoch % self._update_target_freq == 0,
                        monitor_state=monitor_state,
                    )
                    episode_reward_model_loss += r_model_loss
            agent_locations.append(agent_location_ep)
            joint_reward.update({episode: ep_joint_reward})
            total_timesteps += episode_timesteps
            logs = {
                "environment_reward": episode_return_true,
                "received_reward": episode_return_proxy,
                "monitor_reward": episode_return_cost,
                "episode_reward_model_loss": episode_reward_model_loss,
                "joint_reward": episode_return_true + episode_return_cost,
            }
            self.log_save_logs(train=True, logs=logs, episode=episode, save_logs=True)
            episode += 1

        # Works whether or not a checkpoint was taken during this run.
        self._save_agent_locations(agent_locations)
        self._critic.save(seed=self._rng_seed)
        if self._save_train_log:
            self._save_training_joint_reward(joint_reward)
            np.save(
                self._log_dir + "/evaluation_joint_reward_{}.npy".format(self._rng_seed),
                eval_joint_reward,
            )
        self._env.close()

    def _evaluate(self, episode, eval_joint_reward):
        """Run ``test`` in eval mode, record discounted returns and save summary logs.

        ``eval_joint_reward`` is updated in place under key ``episode``.
        """
        self._actor.eval()
        (
            ep_return_true,
            ep_return_proxy,
            ep_return_cost,
            _,
            _,
            ep_discount_reward,
            _,
        ) = self.test(save_results=True)
        eval_joint_reward.update({episode: ep_discount_reward})
        episode_return_true = ep_return_true.mean()
        # nanmean: episodes in which no reward was received report NaN.
        episode_return_proxy = np.nanmean(ep_return_proxy)
        episode_return_cost = ep_return_cost.mean()
        self._actor.train()
        logs = {
            "environment_reward": episode_return_true,
            "received_reward": episode_return_proxy,
            "monitor_reward": episode_return_cost,
            "joint_reward": episode_return_true + episode_return_cost,
        }
        self.log_save_logs(train=False, logs=logs, episode=episode, save_logs=True)

    def _save_agent_locations(self, agent_locations):
        """Write ``agent_locations`` to the per-seed file, or append to it after a checkpoint."""
        path = self._log_dir + "/agent_locations_{}.npy".format(self._rng_seed)
        if self._checkpoint_count == 0:
            # No checkpoint has written this file yet: start a fresh one.
            np.save(path, np.array(agent_locations, dtype=np.int8))
        else:
            np.save(
                path,
                np.concatenate(
                    (np.load(path), agent_locations),
                    0,
                    dtype=np.int8,
                ),
            )

    def _save_training_joint_reward(self, joint_reward):
        """Write ``joint_reward`` to the per-seed file, or merge into it after a checkpoint."""
        path = self._log_dir + "/training_joint_reward_{}.npy".format(self._rng_seed)
        if self._checkpoint_count == 0:
            np.save(path, joint_reward)
        else:
            # The file holds a pickled dict; [()] unwraps the 0-d object array np.load returns.
            stored = np.load(path, allow_pickle=True)[()]
            stored.update(joint_reward)
            np.save(path, stored)

    def test(self, render: bool = False, seed: int = 1, save_results: bool = False):
        """Evaluate an algorithm in Monitor MDP env.

        Args:
            render: call ``env.render()`` before every step.
            seed: unused; episode seeds are derived from ``rng_seed``.
            save_results: unused.

        Returns:
            7-tuple of
            (true returns, received returns [NaN where none was received],
            monitor returns, empty placeholder array, empty placeholder array,
            discounted joint returns, trajectories dict keyed by episode index).
        """
        episode_return_true = []
        episode_return_proxy = []
        episode_return_cost = []
        episode_agent_locations = []
        episode_discount_reward = []
        trajectories = {}
        gamma = self._critic._gamma  # no critic updates happen during evaluation
        episode = 0
        while episode < self._testing_episodes:
            reward_seen = False
            ep_seed = cantor_pairing(self._rng_seed, episode)
            obs, _ = self._env.reset(seed=ep_seed)
            episode_timesteps = 0
            ep_joint_reward = []
            return_true, return_proxy, return_cost, ep_discount_reward = 0, 0, 0, 0
            while True:
                if render:
                    self._env.render()
                episode_agent_locations.append(self._env.get_agent_pos())
                action = self._actor(obs)
                next_obs, reward, term, trunc, info = self._env.step(action)
                step_joint = info["mdp_reward"] + reward["monitor"]
                return_true += info["mdp_reward"]
                return_cost += reward["monitor"]
                ep_joint_reward.append(step_joint)
                # The first step is discounted by gamma**0.
                ep_discount_reward += (gamma**episode_timesteps) * step_joint
                if not np.isnan(reward["mdp"]):
                    reward_seen = True
                    return_proxy += reward["mdp"]
                if term or trunc:
                    if not reward_seen:
                        return_proxy = np.nan
                    break
                obs = next_obs
                episode_timesteps += 1

            episode_return_true.append(return_true)
            episode_return_cost.append(return_cost)
            episode_return_proxy.append(return_proxy)
            episode_discount_reward.append(ep_discount_reward)
            # NOTE(review): the reward lists and episode_agent_locations accumulate across
            # episodes, so each trajectories[episode] entry holds data for episodes
            # 0..episode, not just this one -- confirm this is intended.
            trajectories[episode] = {
                "environment_reward": np.array(episode_return_true),
                "received_reward": np.array(episode_return_proxy),
                "monitor_reward": np.array(episode_return_cost),
                "joint_reward": np.array(episode_return_true) + np.array(episode_return_cost),
                "undiscounted_joint_reward": np.array(ep_joint_reward),
                "episode_agent_locations": np.squeeze(episode_agent_locations),
            }
            episode += 1
        return (
            np.array(episode_return_true),
            np.array(episode_return_proxy),
            np.array(episode_return_cost),
            np.array([]),
            np.array([]),
            np.array(episode_discount_reward),
            trajectories,
        )

    def checkpoint(self, joint_reward, eval_joint_reward, agent_locations):
        """Save the model and the statistics gathered so far during training.

        The first checkpoint writes fresh statistic files; later ones merge into
        or append to them. Increments ``_checkpoint_count``.
        """
        # NOTE(review): unlike the other paths, no "/" is inserted here, so this
        # assumes log_dir ends with a path separator -- confirm against Critic.save.
        checkpoint_dir = self._log_dir + "checkpoints_{}/".format(self._checkpoint_count)
        os.makedirs(checkpoint_dir, exist_ok=True)
        self._critic.save(file_name="checkpoints_{}/".format(self._checkpoint_count), seed=self._rng_seed)
        self._critic.save(seed=self._rng_seed)  # save critic
        self._save_training_joint_reward(joint_reward)
        self._save_agent_locations(agent_locations)
        np.save(
            self._log_dir + "/evaluation_joint_reward_{}.npy".format(self._rng_seed),
            eval_joint_reward,
        )

        self._checkpoint_count += 1

    @staticmethod
    def get_obs_from_agent_pos(self, grid_size: (int, int), agent_pos: (int, int) = None) -> int:
        """Flatten a grid position into an integer index ``pos[0] * grid_size[0] + pos[1]``.

        The first parameter is a legacy placeholder: this is a staticmethod, so
        both ``get_obs_from_agent_pos(grid_size, agent_pos)`` and the legacy
        ``get_obs_from_agent_pos(_, grid_size, agent_pos)`` are accepted.
        """
        if agent_pos is None:
            # Two-argument call: shift the arguments into place.
            grid_size, agent_pos = self, grid_size
        # NOTE(review): row-major flattening is usually row * n_cols + col; using
        # grid_size[0] is only equivalent for square grids -- confirm the convention.
        return int(agent_pos[0] * grid_size[0] + agent_pos[1])

    def log_save_logs(self, train: bool, logs: dict, episode: int, save_logs: bool) -> None:
        """Save each entry of ``logs`` to ``<log_dir>/<key>_<rng_seed>.npy`` when ``save_logs``.

        NOTE(review): ``train`` and ``episode`` are ignored. Training and
        evaluation logs with the same key therefore write the same file, and
        each call overwrites the previous value. Nothing is sent to wandb here.
        """
        if not save_logs:
            return
        for key, value in logs.items():
            np.save(self._log_dir + "/{}_{}.npy".format(key, self._rng_seed), value)
