from collections import deque, defaultdict
from typing import Tuple, Optional, List

import os
import csv
import numpy as np
import torch
from baselines.common.running_mean_std import RunningMeanStd

from level_replay import LevelSamplerES as LevelSampler, LevelStore
from util import (
    # array_to_csv,
    # array_to_csv_append,
    is_discrete_actions,
    get_obs_at_index,
    set_obs_at_index,
)

from teachDeepRL.teachers.teacher_controller import TeacherController

import matplotlib as mpl
import matplotlib.pyplot as plt


class AdversarialRunner(object):
    """
    Performs rollouts of an adversarial environment, given
    protagonist (agent), antagonist (adversary_agent), and
    environment adversary (adversary_env)
    """

    def __init__(
        self,
        args,
        venv,
        agent,
        ued_venv=None,
        adversary_agent=None,
        adversary_env=None,
        flexible_protagonist=False,
        train=False,
        plr_args=None,
        device="cpu",
    ):
        """
        Build the runner and, when ``plr_args`` is given, its PLR level
        samplers and level store.

        args: Experiment configuration namespace. This constructor reads e.g.
            num_steps, ued_algo, use_editor, base_levels, use_accel_paired,
            adv_normalize_returns, protagonist_plr, antagonist_plr,
            num_processes, level_editor_prob, env_name and the optional
            added_levels_csv_path / dropped_levels_csv_path.
        venv: Vectorized, adversarial gym env with agent-specific wrappers.
        agent: Protagonist trainer.
        ued_venv: Vectorized, adversarial gym env with adversary-env-specific wrappers.
            Falls back to ``venv`` when None.
        adversary_agent: Antagonist trainer.
        adversary_env: Environment adversary trainer.

        flexible_protagonist: Which agent plays the role of protagonist in
            calculating the regret depends on which has the lowest score.
            NOTE(review): this argument is not read anywhere in __init__.
        train: If True the runner starts in training mode (self.train()),
            otherwise in evaluation mode (self.eval()).
        plr_args: If truthy, keyword arguments forwarded to every LevelSampler;
            also triggers creation of a LevelStore.
        device: Stored as ``self.device``.
        """
        self.args = args

        self.venv = venv
        if ued_venv is None:
            self.ued_venv = venv
        else:
            self.ued_venv = ued_venv

        self.is_discrete_actions = is_discrete_actions(self.venv)
        self.is_discrete_adversary_env_actions = is_discrete_actions(
            self.venv, adversary=True
        )

        self.agents = {
            "agent": agent,
            "adversary_agent": adversary_agent,
            "adversary_env": adversary_env,
        }

        self.agent_rollout_steps = args.num_steps
        # The environment adversary's rollout length is the upper bound of the
        # "time_step" entry of its observation space.
        self.adversary_env_rollout_steps = self.venv.adversary_observation_space[
            "time_step"
        ].high[0]

        self.is_dr = args.ued_algo == "domain_randomization"
        self.is_training_env = args.ued_algo in ["paired", "flexible_paired", "minimax"]
        self.is_paired = args.ued_algo in ["paired", "flexible_paired"]
        # Editor runs on "easy" base levels without ACCEL-PAIRED. Note the
        # `== False` comparison: a None use_accel_paired does NOT satisfy it.
        self.requires_batched_vloss = (
            args.use_editor
            and args.base_levels == "easy"
            and args.use_accel_paired == False
        )

        self.is_alp_gmm = args.ued_algo == "alp_gmm"

        # Track running mean and std of env returns for return normalization
        if args.adv_normalize_returns:
            self.env_return_rms = RunningMeanStd(shape=())

        self.device = device

        if train:
            self.train()
        else:
            self.eval()

        # Initialise progress counters and return queues.
        self.reset()
        self.use_accel_paired = args.use_accel_paired

        # PLR state; stays None/empty unless plr_args is given.
        self.level_store = None
        self.level_samplers = {}
        self.current_level_seeds = None
        self.weighted_num_edits = 0.0
        self.latest_env_stats = defaultdict(float)

        if plr_args:
            if self.is_paired:
                # PAIRED: samplers for both students by default. If a *_plr
                # flag is set, only that student gets one; the elif chain means
                # the protagonist wins when both flags are set.
                if not args.protagonist_plr and not args.antagonist_plr:
                    self.level_samplers.update(
                        {
                            "agent": LevelSampler(**plr_args),
                            "adversary_agent": LevelSampler(**plr_args),
                        }
                    )
                elif args.protagonist_plr:
                    self.level_samplers["agent"] = LevelSampler(**plr_args)
                elif args.antagonist_plr:
                    self.level_samplers["adversary_agent"] = LevelSampler(**plr_args)
            else:
                self.level_samplers["agent"] = LevelSampler(**plr_args)

            # With byte-encoded levels, the store is told the numpy dtype/shape
            # of the first encoding reported by ued_venv.
            if self.use_byte_encoding:
                example = self.ued_venv.get_encodings()[0]
                data_info = {
                    "numpy": True,
                    "dtype": example.dtype,
                    "shape": example.shape,
                }
                self.level_store = LevelStore(data_info=data_info)
            else:
                self.level_store = LevelStore()

            # One -1 placeholder per process until real level seeds are assigned.
            self.current_level_seeds = [-1 for _ in range(args.num_processes)]
            # First non-None sampler in self.level_samplers insertion order.
            self._default_level_sampler = self.all_level_samplers[0]

            self.use_editor = args.use_editor
            self.edit_prob = args.level_editor_prob
            self.base_levels = args.base_levels
        else:
            self.use_editor = False
            self.edit_prob = 0.0
            self.base_levels = None
            self._default_level_sampler = None

        # ------------------- ALP-GMM Initialization -------------------
        if self.is_alp_gmm:
            self._init_alp_gmm()

        # ------------------- Early-stop & drop mechanism -------------------
        # NOTE(review): threshold / early_stop_patience / dropped_seeds are not
        # read in the code visible here. Presumably a return threshold (150 by
        # default, 0.7 for MultiGrid) and a patience count used to drop replay
        # levels; confirm against their consumers.
        self.threshold = 150
        env_name = self.args.env_name
        if env_name.startswith("MultiGrid"):
            self.threshold = 0.7

        self.early_stop_patience = 20
        self.level_sample_counts = defaultdict(int)  # seed -> #times replay-sampled
        self.dropped_seeds = set()  # seeds that should no longer be replay-sampled

        self._last_buffer_seeds_for_stats = None

        # Pending add/drop log records. Each record is a dict with keys
        #   {"seed", "level", "edit_count", "times_sampled", "reason"}
        # (a "partition" field is currently commented out by the writers).
        self.added_levels_records = []  # Levels newly added to replay buffer / LevelStore
        self.dropped_levels_records = []  # Levels dropped from replay buffer

        # Optional CSV destinations for the records above; None disables writing.
        self.added_levels_csv = getattr(args, "added_levels_csv_path", None)
        self.dropped_levels_csv = getattr(args, "dropped_levels_csv_path", None)

    @property
    def use_byte_encoding(self):
        env_name = self.args.env_name
        if (
            self.args.use_editor
            or env_name.startswith("BipedalWalker")
            or (env_name.startswith("MultiGrid") and self.args.use_reset_random_dr)
        ):
            return True
        else:
            return False

    def _init_alp_gmm(self):
        args = self.args
        param_env_bounds = []
        if args.env_name.startswith("MultiGrid"):
            param_env_bounds = {"actions": [0, 168, 26]}
            reward_bounds = None
        elif args.env_name.startswith("Bipedal"):
            if "POET" in args.env_name:
                param_env_bounds = {"actions": [0, 2, 5]}
            else:
                param_env_bounds = {"actions": [0, 2, 8]}
            reward_bounds = (-200, 350)
        else:
            raise ValueError(f"Environment {args.env_name} not supported for ALP-GMM")

        self.alp_gmm_teacher = TeacherController(
            teacher="ALP-GMM",
            nb_test_episodes=0,
            param_env_bounds=param_env_bounds,
            reward_bounds=reward_bounds,
            seed=args.seed,
            teacher_params={},
        )  # Use defaults

    def reset(self):
        self.num_updates = 0
        self.total_num_edits = 0
        self.total_episodes_collected = 0
        self.total_seeds_collected = 0
        self.student_grad_updates = 0
        self.sampled_level_info = None

        max_return_queue_size = 10
        self.agent_returns = deque(maxlen=max_return_queue_size)
        self.adversary_agent_returns = deque(maxlen=max_return_queue_size)

    def train(self):
        self.is_training = True
        [agent.train() if agent else agent for _, agent in self.agents.items()]

    def eval(self):
        self.is_training = False
        [agent.eval() if agent else agent for _, agent in self.agents.items()]

    # -------------------------------------------------------------------
    # Checkpoint
    # -------------------------------------------------------------------
    def state_dict(self):
        agent_state_dict = {}
        optimizer_state_dict = {}
        for k, agent in self.agents.items():
            if agent:
                agent_state_dict[k] = agent.algo.actor_critic.state_dict()
                optimizer_state_dict[k] = agent.algo.optimizer.state_dict()

        return {
            "agent_state_dict": agent_state_dict,
            "optimizer_state_dict": optimizer_state_dict,
            "agent_returns": self.agent_returns,
            "adversary_agent_returns": self.adversary_agent_returns,
            "num_updates": self.num_updates,
            "total_episodes_collected": self.total_episodes_collected,
            "total_seeds_collected": self.total_seeds_collected,
            "total_num_edits": self.total_num_edits,
            "student_grad_updates": self.student_grad_updates,
            "latest_env_stats": self.latest_env_stats,
            "level_store": self.level_store,
            "level_samplers": self.level_samplers,
        }

    def load_state_dict(self, state_dict):
        agent_state_dict = state_dict.get("agent_state_dict")

        for k, state in agent_state_dict.items():
            self.agents[k].algo.actor_critic.load_state_dict(state)

        optimizer_state_dict = state_dict.get("optimizer_state_dict")

        for k, state in optimizer_state_dict.items():
            self.agents[k].algo.optimizer.load_state_dict(state)

        self.agent_returns = state_dict.get("agent_returns")
        self.adversary_agent_returns = state_dict.get("adversary_agent_returns")
        self.num_updates = state_dict.get("num_updates")
        self.total_episodes_collected = state_dict.get("total_episodes_collected")
        self.total_seeds_collected = state_dict.get("total_seeds_collected")
        self.total_num_edits = state_dict.get("total_num_edits")
        self.student_grad_updates = state_dict.get("student_grad_updates")
        self.latest_env_stats = state_dict.get("latest_env_stats")

        self.level_store = state_dict.get("level_store")
        self.level_samplers = state_dict.get("level_samplers")

        if self.args.use_plr and self.level_samplers is not None:
            self._default_level_sampler = self.all_level_samplers[0]

    # -------------------------------------------------------------------
    def _ensure_csv_header(self, path: str, fieldnames: List[str]):
        if path is None:
            return
        dirname = os.path.dirname(path)
        if dirname:
            os.makedirs(dirname, exist_ok=True)

        need_header = (not os.path.exists(path)) or (os.path.getsize(path) == 0)
        if not need_header:
            return

        with open(path, "w", newline="") as f:
            writer = csv.DictWriter(f, fieldnames=fieldnames)
            writer.writeheader()

    def _flush_add_drop_records_to_csv(self, current_steps: int):

        if (not getattr(self.args, "use_plr", False)) or (self.level_store is None):
            self.added_levels_records.clear()
            self.dropped_levels_records.clear()
            return

        if self.added_levels_csv is not None and self.added_levels_records:
            fieldnames_add = [
                "steps",
                "seed",
                "level",
                "edit_count",
                "times_sampled",
                "reason",
                # "partition",
            ]
            self._ensure_csv_header(self.added_levels_csv, fieldnames_add)

            with open(self.added_levels_csv, "a", newline="") as f:
                writer = csv.DictWriter(f, fieldnames=fieldnames_add)
                for rec in self.added_levels_records:
                    row = {
                        "steps": int(current_steps),
                        "seed": int(rec.get("seed", -1)),
                        "level": str(rec.get("level", None)),
                        "edit_count": int(rec.get("edit_count", 0)),
                        "times_sampled": int(rec.get("times_sampled", 0)),
                        "reason": rec.get("reason", ""),
                        # "partition": rec.get("partition", None) or "",
                    }
                    writer.writerow(row)

        # ---------------- Write dropped_levels ----------------
        if self.dropped_levels_csv is not None and self.dropped_levels_records:
            fieldnames_drop = [
                "steps",
                "seed",
                "level",
                "edit_count",
                "times_sampled",
                "reason",
                # "partition",
            ]
            self._ensure_csv_header(self.dropped_levels_csv, fieldnames_drop)

            with open(self.dropped_levels_csv, "a", newline="") as f:
                writer = csv.DictWriter(f, fieldnames=fieldnames_drop)
                for rec in self.dropped_levels_records:
                    row = {
                        "steps": int(current_steps),
                        "seed": int(rec.get("seed", -1)),
                        "level": str(rec.get("level", None)),
                        "edit_count": int(rec.get("edit_count", 0)),
                        "times_sampled": int(rec.get("times_sampled", 0)),
                        "reason": rec.get("reason", ""),
                        # "partition": rec.get("partition", None) or "",
                    }
                    writer.writerow(row)

        self.added_levels_records.clear()
        self.dropped_levels_records.clear()

    def _get_batched_value_loss(self, agent, clipped=True, batched=True):
        batched_value_loss = agent.storage.get_batched_value_loss(
            signed=False, positive_only=False, clipped=clipped, batched=batched
        )
        return batched_value_loss

    def _get_rollout_return_stats(self, rollout_returns):
        mean_return = torch.zeros(self.args.num_processes, 1)
        max_return = torch.zeros(self.args.num_processes, 1)
        for b, returns in enumerate(rollout_returns):
            if len(returns) > 0:
                mean_return[b] = float(np.mean(returns))
                max_return[b] = float(np.max(returns))

        stats = {
            "mean_return": mean_return,
            "max_return": max_return,
            "returns": rollout_returns,
        }
        return stats

    def _calculate_paired_regret_scores(
        self, agent_rollout_info, adversary_agent_rollout_info, type="paired"
    ):
        if type == "paired":
            external_scores = torch.max(
                adversary_agent_rollout_info["max_return"]
                - agent_rollout_info["mean_return"],
                torch.zeros_like(agent_rollout_info["mean_return"]),
            )
        elif type == "flex_paired":
            env_return = torch.zeros_like(
                agent_rollout_info["max_return"], dtype=torch.float
            )
            adversary_agent_max_idx = (
                adversary_agent_rollout_info["max_return"]
                > agent_rollout_info["max_return"]
            )
            agent_max_idx = ~adversary_agent_max_idx

            env_return[adversary_agent_max_idx] = adversary_agent_rollout_info[
                "max_return"
            ][adversary_agent_max_idx]
            env_return[agent_max_idx] = agent_rollout_info["max_return"][agent_max_idx]

            env_mean_return = torch.zeros_like(env_return, dtype=torch.float)
            env_mean_return[adversary_agent_max_idx] = agent_rollout_info["mean_return"][
                adversary_agent_max_idx
            ]
            env_mean_return[agent_max_idx] = adversary_agent_rollout_info["mean_return"][
                agent_max_idx
            ]

            env_return = torch.max(
                env_return - env_mean_return, torch.zeros_like(env_return)
            )
        else:
            raise NotImplementedError

        return external_scores

    def _get_env_stats_multigrid(self, agent_info, adversary_agent_info):
        num_blocks = np.mean(agent_info.get("num_blocks", self.venv.get_num_blocks()))
        passable_ratio = np.mean(agent_info.get("passable_ratio", self.venv.get_passable()))

        shortest_path_lengths = agent_info.get(
            "shortest_path_lengths", self.venv.get_shortest_path_length()
        )
        shortest_path_length = np.mean(shortest_path_lengths)

        solved_idx = agent_info.get("solved_idx", None)
        if solved_idx is None:
            if "max_returns" in adversary_agent_info:
                solved_idx = (
                    (
                        torch.max(
                            agent_info["max_return"],
                            adversary_agent_info["max_return"],
                        )
                        > 0
                    )
                    .numpy()
                    .squeeze()
                )
            else:
                solved_idx = (agent_info["max_return"] > 0).numpy().squeeze()

        solved_path_lengths = np.array(shortest_path_lengths)[solved_idx]
        solved_path_length = (
            np.mean(solved_path_lengths) if len(solved_path_lengths) > 0 else 0
        )

        stats = {
            "num_blocks": num_blocks,
            "passable_ratio": passable_ratio,
            "shortest_path_length": shortest_path_length,
            "solved_path_length": solved_path_length,
        }
        return stats

    def _get_env_stats_car_racing(self, agent_info, adversary_agent_info):
        infos = self.venv.get_complexity_info()
        num_envs = len(infos)

        sums = defaultdict(float)
        for info in infos:
            for k, v in info.items():
                sums[k] += v

        stats = {}
        for k, v in sums.items():
            stats["track_" + k] = sums[k] / num_envs
        return stats

    def _get_env_stats_bipedalwalker(self, agent_info, adversary_agent_info):
        infos = self.venv.get_complexity_info()
        num_envs = len(infos)

        sums = defaultdict(float)
        for info in infos:
            for k, v in info.items():
                sums[k] += v

        stats = {}
        for k, v in sums.items():
            stats["track_" + k] = sums[k] / num_envs
        return stats

    def _get_env_stats(self, agent_info, adversary_agent_info, log_replay_complexity=False):
        env_name = self.args.env_name
        if env_name.startswith("MultiGrid"):
            stats = self._get_env_stats_multigrid(agent_info, adversary_agent_info)
        elif env_name.startswith("CarRacing"):
            stats = self._get_env_stats_car_racing(agent_info, adversary_agent_info)
        elif env_name.startswith("BipedalWalker"):
            stats = self._get_env_stats_bipedalwalker(agent_info, adversary_agent_info)
        else:
            raise ValueError(f"Unsupported environment, {self.args.env_name}")

        stats_ = {}
        for k, v in stats.items():
            stats_["plr_" + k] = v if log_replay_complexity else None
            stats_[k] = v if not log_replay_complexity else None

        return stats_

    def _get_plr_buffer_stats(self):
        stats = {}
        return stats

    def _get_active_levels(self):
        assert self.args.use_plr, "Only call _get_active_levels when using PLR."

        env_name = self.args.env_name

        is_multigrid = env_name.startswith("MultiGrid")
        is_car_racing = env_name.startswith("CarRacing")
        is_bipedal_walker = env_name.startswith("BipedalWalker")

        if self.use_byte_encoding:
            return [x.tobytes() for x in self.ued_venv.get_encodings()]
        elif is_multigrid:
            return self.agents["adversary_env"].storage.get_action_traj(as_string=True)
        else:
            return self.ued_venv.get_level()

    def _get_level_sampler(self, name):
        other = "adversary_agent"
        if name == "adversary_agent":
            other = "agent"

        level_sampler = self.level_samplers.get(name) or self.level_samplers.get(other)
        updateable = name in self.level_samplers
        return level_sampler, updateable

    @property
    def all_level_samplers(self):
        if len(self.level_samplers) == 0:
            return []
        return list(
            filter(lambda x: x is not None, [v for _, v in self.level_samplers.items()])
        )

    def _record_added_level(self, seed, reason: str):

        if (not self.args.use_plr) or self.level_store is None:
            return

        seed = int(seed)

        level = self.level_store.get_level(seed)

        seed2parent = getattr(self.level_store, "seed2parent", {})
        parent_list = seed2parent.get(seed, [])
        edit_count = len(parent_list)

        times_sampled = int(self.level_sample_counts.get(seed, 0))

        self.added_levels_records.append(
            {
                "seed": seed,
                "level": level,
                "edit_count": edit_count,
                "times_sampled": times_sampled,
                "reason": reason,
                # "partition": partition_str,
            }
        )

    def _record_dropped_level(self, seed, reason: str):

        if (not self.args.use_plr) or self.level_store is None:
            return

        seed = int(seed)
        level = self.level_store.get_level(seed)

        seed2parent = getattr(self.level_store, "seed2parent", {})
        parent_list = seed2parent.get(seed, [])
        edit_count = len(parent_list)

        times_sampled = int(self.level_sample_counts.get(seed, 0))

        self.dropped_levels_records.append(
            {
                "seed": seed,
                "level": level,
                "edit_count": edit_count,
                "times_sampled": times_sampled,
                "reason": reason,
            }
        )

    def _get_buffer_seeds(self):

        ls = getattr(self, "_default_level_sampler", None)
        if ls is None:
            return []

        if getattr(ls, "sample_full_distribution", False):
            seeds = [int(s) for s in ls.seeds[: ls.working_seed_buffer_size] if s != -1]
        else:
            seeds = []
            for i, s in enumerate(ls.seeds):
                if s == -1:
                    continue
                if ls.unseen_seed_weights[i] == 0.0:
                    seeds.append(int(s))

        return seeds

    def _get_weighted_num_edits(self):

        if not self.args.use_plr or self.level_store is None:
            return 0.0

        ls = getattr(self, "_default_level_sampler", None)
        if ls is None:
            return 0.0

        buffer_seeds = self._get_buffer_seeds()
        if len(buffer_seeds) == 0:
            return 0.0

        seed2parent = getattr(self.level_store, "seed2parent", {})
        seed_num_edits = {}
        for seed in buffer_seeds:
            parent_list = seed2parent.get(seed, [])
            seed_num_edits[seed] = len(parent_list)

        all_weights = ls.sample_weights()  # shape: [seed_buffer_size]

        weights = []
        edits = []
        for seed in buffer_seeds:
            idx = ls.seed2index.get(seed, None)
            if idx is None:
                continue
            w = float(all_weights[idx])
            weights.append(w)
            edits.append(float(seed_num_edits[seed]))

        if len(weights) == 0:
            return 0.0

        weights = np.asarray(weights, dtype=np.float32)
        edits = np.asarray(edits, dtype=np.float32)

        Z = weights.sum()
        if Z <= 0:
            return float(edits.mean())

        weighted = float((weights * edits).sum() / Z)
        return weighted

    def _reconcile_level_store_and_samplers(self):

        all_replay_seeds = set()
        for level_sampler in self.all_level_samplers:
            all_replay_seeds.update([x for x in level_sampler.seeds if x >= 0])
        self.level_store.reconcile_seeds(all_replay_seeds)

    def _sample_replay_level_filtered(self, level_sampler, update_staleness=True):

        if level_sampler is None:
            return None

        return level_sampler._sample_replay_level_batch()

    def _should_edit_level(self):
        # if self.use_editor:
        #     return np.random.rand() < self.edit_prob
        # else:
        #     return False
        return True

    def _update_plr_with_current_unseen_levels(
        self, parent_seeds=None, valid_indices=None, child_partitions=None
    ):

        if not self.args.use_plr or self.level_store is None:
            return

        args = self.args

        all_levels = self._get_active_levels()

        filtered_parent_seeds = parent_seeds

        if parent_seeds is None and valid_indices is None:
            levels = all_levels

        elif parent_seeds is None and valid_indices is not None:
            # print("complement init levels:", valid_indices)
            levels = [all_levels[i] for i in valid_indices]
            # child_partitions = [zero_idx] * len(levels)

        else:
            levels = [all_levels[i] for i in valid_indices]
            if parent_seeds is not None:
                filtered_parent_seeds = [parent_seeds[i] for i in valid_indices]

        current_level_seeds = self.level_store.insert(
            levels,
            parent_seeds=filtered_parent_seeds,
        )

        self.current_level_seeds = current_level_seeds

        if args.log_plr_buffer_stats or args.reject_unsolvable_seeds:
            passable = self.venv.get_passable()
        else:
            passable = None

        self._update_level_samplers_with_external_unseen_sample(
            self.current_level_seeds,
            solvable=passable,
        )

        new_seeds_int = set(int(s) for s in current_level_seeds)
        if len(new_seeds_int) > 0:
            for seed in self.current_level_seeds:
                seed = int(seed)
                if seed not in new_seeds_int:
                    continue

                if parent_seeds is not None:
                    reason = "mutation_child"

                elif valid_indices is not None:
                    reason = "complement_init"
                else:
                    reason = "random_init"

                self._record_added_level(seed, reason)

    def _update_level_samplers_with_external_unseen_sample(self, seeds, solvable=None):
        level_samplers = self.all_level_samplers

        if self.args.reject_unsolvable_seeds:
            solvable = np.array(solvable, dtype=np.bool_)
            seeds = np.array(seeds, dtype=np.int64)[solvable]
            solvable = solvable[solvable]

        for level_sampler in level_samplers:
            level_sampler.observe_external_unseen_sample(seeds, solvable)

    def _sample_replay_decision(self):

        buffer_seeds = self._get_buffer_seeds()
        if len(buffer_seeds) < 960:
            return False

        return True
        # return np.random.rand() <= 0.9

    def agent_rollout(
        self,
        agent,
        num_steps,
        update=False,
        is_env=False,
        level_replay=False,
        level_sampler=None,
        update_level_sampler=False,
        discard_grad=False,
        edit_level=False,
        num_edits=0,
        fixed_seeds=None,
        kl_dict=None,
        update_agent_separately=False,
        original_partitions=None,
        target_partitions=None,
        eligible_idx=None,
    ):
        args = self.args
        if is_env:
            if edit_level:

                levels = [self.level_store.get_level(seed) for seed in fixed_seeds]
                self.ued_venv.reset_to_level_batch(levels)

                valid_indices = eligible_idx

                self.ued_venv.mutate_level_forward_in_batch(
                    eligible_idx=valid_indices,
                    num_edits=num_edits,
                )

                self._update_plr_with_current_unseen_levels(
                    parent_seeds=fixed_seeds,
                    valid_indices=valid_indices,
                )
                return

            if level_replay:
                self.current_level_seeds = self._sample_replay_level_filtered(level_sampler)

                levels = [
                    self.level_store.get_level(seed) for seed in self.current_level_seeds
                ]
                self.ued_venv.reset_to_level_batch(levels)
                return self.current_level_seeds

            else:
                obs = self.ued_venv.reset_random()
                self._update_plr_with_current_unseen_levels(parent_seeds=None)
                self.total_seeds_collected += args.num_processes
                return
        else:
            obs = self.venv.reset_agent()

        if agent is not None:
            agent.storage.copy_obs_to_index(obs, 0)

        rollout_info = {}
        rollout_returns = [[] for _ in range(args.num_processes)]

        if self.use_accel_paired:
            actor_seeds = {i: [] for i in range(args.num_processes)}

        if level_sampler and level_replay and agent is not None and not is_env:
            rollout_info.update(
                {"solved_idx": np.zeros(args.num_processes, dtype=np.bool_)}
            )

        for step in range(num_steps if num_steps is not None else 0):
            if agent is None:
                break

            if args.render and not is_env:
                self.venv.render_to_screen()

            # Sample actions
            with torch.no_grad():
                obs_id = agent.storage.get_obs(step)
                value, action, action_log_dist, recurrent_hidden_states = agent.act(
                    obs_id,
                    agent.storage.get_recurrent_hidden_state(step),
                    agent.storage.masks[step],
                )
                if is_env:
                    is_disc = self.is_discrete_adversary_env_actions
                else:
                    is_disc = self.is_discrete_actions

                if is_disc:
                    action_log_prob = action_log_dist.gather(-1, action)
                else:
                    action_log_prob = action_log_dist

            reset_random = self.is_dr and not args.use_plr
            _action = agent.process_action(action.cpu())

            if is_env:
                obs, reward, done, infos = self.ued_venv.step_adversary(_action)
            else:
                obs, reward, done, infos = self.venv.step_env(
                    _action, reset_random=reset_random
                )
                if args.clip_reward:
                    reward = torch.clamp(reward, -args.clip_reward, args.clip_reward)

            if not is_env and step >= num_steps - 1:
                if agent.storage.use_proper_time_limits:
                    for i, done_ in enumerate(done):
                        if not done_:
                            infos[i]["cliffhanger"] = True
                            infos[i]["truncated"] = True
                            infos[i]["truncated_obs"] = get_obs_at_index(obs, i)

                done = np.ones_like(done, dtype=np.float32)

            if level_sampler and level_replay and not is_env:
                next_level_seeds = [s for s in self.current_level_seeds]

            for i, info in enumerate(infos):
                if "episode" in info.keys():
                    rollout_returns[i].append(info["episode"]["r"])

                    if self.use_accel_paired and not is_env:
                        actor_seeds[i].append(self.current_level_seeds[i])

                    if reset_random:
                        self.total_seeds_collected += 1

                    if not is_env:
                        self.total_episodes_collected += 1

                        if agent.storage.use_proper_time_limits:
                            if "truncated_obs" in info.keys():
                                truncated_obs = info["truncated_obs"]
                                agent.storage.insert_truncated_obs(truncated_obs, index=i)

                        if level_sampler and level_replay and not self.use_accel_paired:
                            level_seed = self._sample_replay_level_filtered(level_sampler)[
                                0
                            ]
                            level = self.level_store.get_level(level_seed)
                            obs_i = self.venv.reset_to_level(level, i)
                            set_obs_at_index(obs, obs_i, i)
                            next_level_seeds[i] = level_seed
                            rollout_info["solved_idx"][i] = True

                        if level_sampler and level_replay and self.use_accel_paired:
                            level_seed = self.current_level_seeds[i]
                            level = self.level_store.get_level(level_seed)
                            obs_i = self.venv.reset_to_level(level, i)
                            set_obs_at_index(obs, obs_i, i)
                            next_level_seeds[i] = level_seed
                            rollout_info["solved_idx"][i] = True
                            self.current_level_seeds[i] = level_seed

                        if self.is_alp_gmm and not is_env:
                            self.alp_gmm_teacher.record_train_episode(
                                rollout_returns[i][-1], index=i
                            )
                            self.alp_gmm_teacher.set_env_params(self.venv)

            if agent is None:
                break

            masks = torch.FloatTensor([[0.0] if d else [1.0] for d in done])
            bad_masks = torch.FloatTensor(
                [[0.0] if "truncated" in info.keys() else [1.0] for info in infos]
            )
            cliffhanger_masks = torch.FloatTensor(
                [[0.0] if "cliffhanger" in info.keys() else [1.0] for info in infos]
            )

            current_level_seeds = None
            if (not is_env) and level_sampler:
                current_level_seeds = torch.tensor(
                    self.current_level_seeds, dtype=torch.int64
                ).view(-1, 1)

            agent.insert(
                obs,
                recurrent_hidden_states,
                action,
                action_log_prob,
                action_log_dist,
                value,
                reward,
                masks,
                bad_masks,
                level_seeds=current_level_seeds,
                cliffhanger_masks=cliffhanger_masks,
            )

            if level_sampler and level_replay and not is_env:
                self.current_level_seeds = next_level_seeds

        if is_env and args.use_plr and not level_replay and not edit_level:
            self._update_plr_with_current_unseen_levels()

        if agent is None:
            return

        rollout_info.update(self._get_rollout_return_stats(rollout_returns))
        if self.use_accel_paired and not is_env:
            rollout_info["actor_seeds"] = actor_seeds

        if not is_env and update:
            with torch.no_grad():
                obs_id = agent.storage.get_obs(-1)
                next_value = agent.get_value(
                    obs_id,
                    agent.storage.get_recurrent_hidden_state(-1),
                    agent.storage.masks[-1],
                ).detach()

            agent.storage.compute_returns(
                next_value, args.use_gae, args.gamma, args.gae_lambda
            )

            if self.requires_batched_vloss:
                clipped = not args.adv_use_popart and not args.adv_normalize_returns
                batched_value_loss = self._get_batched_value_loss(
                    agent, clipped=clipped, batched=True
                )
                rollout_info.update({"batched_value_loss": batched_value_loss})

            if not update_agent_separately:
                if level_sampler and update_level_sampler:
                    level_sampler.update_with_rollouts(agent.storage)

                value_loss, action_loss, dist_entropy, info = agent.update(
                    discard_grad=discard_grad, kl_dict=kl_dict
                )

                if level_sampler and update_level_sampler:
                    level_sampler.after_update()

                if "grad_stats" in info:
                    gs = info["grad_stats"]
                    rollout_info.update(
                        {
                            "grad_var": gs.get("grad_var", 0.0),
                            "grad_dir_cos": gs.get("grad_dir_cos", None),
                            "clip_fraction": gs.get("clip_fraction", 0.0),
                            "pre_clip_norm_mean": gs.get("pre_clip_norm_mean", 0.0),
                            "post_clip_norm_mean": gs.get("post_clip_norm_mean", 0.0),
                        }
                    )

                if "kl_loss" in info.keys():
                    kl_loss = info.pop("kl_loss")
                    rollout_info.update({"kl_loss": kl_loss})

                rollout_info.update(
                    {
                        "value_loss": value_loss,
                        "action_loss": action_loss,
                        "dist_entropy": dist_entropy,
                        "update_info": {k: v for k, v in info.items() if k != "grad_stats"},
                    }
                )

                if args.log_action_complexity:
                    rollout_info.update(
                        {"action_complexity": agent.storage.get_action_complexity()}
                    )

        return rollout_info

    def _update_agent_separately(
        self,
        agent,
        level_sampler=None,
        update_level_sampler=False,
        discard_grad=False,
        kl_dict=None,
        external_scores=None,
    ):
        if level_sampler and update_level_sampler:
            level_sampler.update_with_rollouts(
                agent.storage, external_scores=external_scores
            )

        value_loss, action_loss, dist_entropy, info = agent.update(
            discard_grad=discard_grad, kl_dict=kl_dict
        )

        if level_sampler and update_level_sampler:
            level_sampler.after_update()

        rollout_info = {
            "value_loss": value_loss,
            "action_loss": action_loss,
            "dist_entropy": dist_entropy,
            "update_info": {k: v for k, v in info.items() if k != "grad_stats"},
        }

        if "grad_stats" in info:
            gs = info["grad_stats"]
            rollout_info.update(
                {
                    "grad_var": gs.get("grad_var", 0.0),
                    "grad_dir_cos": gs.get("grad_dir_cos", None),
                    "clip_fraction": gs.get("clip_fraction", 0.0),
                    "pre_clip_norm_mean": gs.get("pre_clip_norm_mean", 0.0),
                    "post_clip_norm_mean": gs.get("post_clip_norm_mean", 0.0),
                }
            )

        if "kl_loss" in info.keys():
            kl_loss = info.pop("kl_loss")
            rollout_info.update({"kl_loss": kl_loss})

        if self.args.log_action_complexity:
            rollout_info.update(
                {"action_complexity": agent.storage.get("action_complexity", None)}
            )

        return rollout_info

    def _compute_env_return(self, agent_info, adversary_agent_info):
        args = self.args
        if args.ued_algo == "paired":
            env_return = torch.max(
                adversary_agent_info["max_return"] - agent_info["mean_return"],
                torch.zeros_like(agent_info["mean_return"]),
            )

        elif args.ued_algo == "flexible_paired":
            env_return = torch.zeros_like(
                agent_info["max_return"], dtype=torch.float, device=self.device
            )
            adversary_agent_max_idx = (
                adversary_agent_info["max_return"] > agent_info["max_return"]
            )
            agent_max_idx = ~adversary_agent_max_idx

            env_return[adversary_agent_max_idx] = adversary_agent_info["max_return"][
                adversary_agent_max_idx
            ]
            env_return[agent_max_idx] = agent_info["max_return"][agent_max_idx]

            env_mean_return = torch.zeros_like(env_return, dtype=torch.float)
            env_mean_return[adversary_agent_max_idx] = agent_info["mean_return"][
                adversary_agent_max_idx
            ]
            env_mean_return[agent_max_idx] = adversary_agent_info["mean_return"][
                agent_max_idx
            ]

            env_return = torch.max(
                env_return - env_mean_return, torch.zeros_like(env_return)
            )

        elif args.ued_algo == "minimax":
            env_return = -agent_info["max_return"]

        else:
            env_return = torch.zeros_like(agent_info["mean_return"])

        if args.adv_normalize_returns:
            self.env_return_rms.update(env_return.flatten().cpu().numpy())
            env_return /= np.sqrt(self.env_return_rms.var + 1e-8)

        if args.adv_clip_reward is not None:
            clip_max_abs = args.adv_clip_reward
            env_return = env_return.clamp(-clip_max_abs, clip_max_abs)

        return env_return

    def run(self):
        args = self.args

        adversary_env = self.agents["adversary_env"]
        agent = self.agents["agent"]
        adversary_agent = self.agents["adversary_agent"]

        level_replay = False
        if args.use_plr and self.is_training:
            level_replay = self._sample_replay_decision()

        student_discard_grad = False
        no_exploratory_grad_updates = vars(args).get("no_exploratory_grad_updates", False)
        if args.use_plr and (not level_replay) and no_exploratory_grad_updates:
            student_discard_grad = True

        if self.is_training and not student_discard_grad:
            self.student_grad_updates += 1

        env_info = self.agent_rollout(
            agent=adversary_env,
            num_steps=self.adversary_env_rollout_steps,
            update=False,
            is_env=True,
            level_replay=level_replay,
            level_sampler=self._get_level_sampler("agent")[0],
            update_level_sampler=False,
        )

        level_sampler, is_updateable = self._get_level_sampler("agent")

        kl_dict_agent = None
        if self.use_accel_paired:
            kl_dict_agent = None
        elif self.is_training and self.args.use_behavioural_cloning:
            if (self.student_grad_updates) % self.args.kl_update_step == 0:
                kl_dict_agent = {}
                adversary_agent.eval()
                kl_dict_agent["antagonist_model"] = adversary_agent.algo.actor_critic

        agent_info = self.agent_rollout(
            agent=agent,
            num_steps=self.agent_rollout_steps,
            update=self.is_training,
            level_replay=level_replay,
            level_sampler=level_sampler,
            update_level_sampler=is_updateable,
            discard_grad=student_discard_grad,
            kl_dict=kl_dict_agent,
            update_agent_separately=self.use_accel_paired,
        )

        if kl_dict_agent is not None:
            adversary_agent.train()

        adversary_agent_info = defaultdict(float)
        if self.is_paired:
            level_sampler_adv, is_updateable_adv = self._get_level_sampler(
                "adversary_agent"
            )

            kl_dict_adv_agent = None
            if not self.args.use_kl_only_agent:
                if self.is_training and self.args.use_behavioural_cloning:
                    if (self.student_grad_updates) % self.args.kl_update_step == 0:
                        kl_dict_adv_agent = {}
                        agent.eval()
                        kl_dict_adv_agent["antagonist_model"] = agent.algo.actor_critic

            adversary_agent_info = self.agent_rollout(
                agent=adversary_agent,
                num_steps=self.agent_rollout_steps,
                update=self.is_training,
                level_replay=level_replay,
                level_sampler=level_sampler_adv,
                update_level_sampler=is_updateable_adv,
                discard_grad=student_discard_grad,
                kl_dict=kl_dict_adv_agent,
            )

            if kl_dict_adv_agent is not None:
                agent.train()

        elif self.use_accel_paired:
            adversary_agent_info = self.agent_rollout(
                agent=adversary_agent,
                num_steps=self.agent_rollout_steps,
                update=self.is_training,
                level_replay=False,
                level_sampler=None,
                update_level_sampler=False,
                discard_grad=student_discard_grad,
                kl_dict=None,
                update_agent_separately=self.use_accel_paired,
            )

            external_scores = self._calculate_paired_regret_scores(
                agent_info,
                adversary_agent_info,
                type=args.accel_paired_score_function,
            )

            level_sampler_agent, is_updateable_agent = self._get_level_sampler("agent")

            kl_dict_agent = None
            if self.is_training and self.args.use_behavioural_cloning:
                if (self.student_grad_updates) % self.args.kl_update_step == 0:
                    kl_dict_agent = {}
                    adversary_agent.eval()
                    kl_dict_agent["antagonist_model"] = adversary_agent.algo.actor_critic

            agent_update_rollout_info = self._update_agent_separately(
                agent,
                level_sampler=level_sampler_agent,
                update_level_sampler=is_updateable_agent,
                discard_grad=student_discard_grad,
                kl_dict=kl_dict_agent,
                external_scores=external_scores,
            )

            if kl_dict_agent is not None:
                adversary_agent.train()

            agent_info.update(agent_update_rollout_info)

            kl_dict_adv_agent = None
            if not self.args.use_kl_only_agent:
                if self.is_training and self.args.use_behavioural_cloning:
                    if (self.student_grad_updates) % self.args.kl_update_step == 0:
                        kl_dict_adv_agent = {}
                        agent.eval()
                        kl_dict_adv_agent["antagonist_model"] = agent.algo.actor_critic

            adversary_agent_update_rollout_info = self._update_agent_separately(
                adversary_agent,
                level_sampler=level_sampler_agent,
                update_level_sampler=is_updateable_agent,
                discard_grad=student_discard_grad,
                kl_dict=kl_dict_adv_agent,
                external_scores=external_scores,
            )

            if kl_dict_adv_agent is not None:
                agent.train()

            adversary_agent_info.update(adversary_agent_update_rollout_info)

        per_env_return = agent_info["mean_return"].detach().cpu().numpy().reshape(-1)

        num_early_stop = 0
        num_dropped = 0

        stuck_seed_index = [0] * args.num_processes

        if level_replay:
            early_stop_idx = np.where(per_env_return >= self.threshold)[0]
            num_early_stop = int(len(early_stop_idx))
            early_stop_idx_set = set(early_stop_idx.tolist())

            dropped_solved = set()
            dropped_stuck = set()

            if args.use_plr and env_info is not None:
                for i, seed in enumerate(env_info):
                    if seed is None:
                        continue
                    seed = int(seed)
                    if seed < 0:
                        continue

                    self.level_sample_counts[seed] += 1

                    count_i = self.level_sample_counts.get(seed)

                    if i in early_stop_idx_set:
                        if seed not in self.dropped_seeds:
                            self.dropped_seeds.add(seed)
                            dropped_solved.add(seed)

                    else:
                        if count_i >= self.early_stop_patience:
                            if seed not in self.dropped_seeds:

                                stuck_seed_index[i] = 1

                                self.dropped_seeds.add(seed)
                                dropped_stuck.add(seed)

                for s in dropped_solved:
                    self._record_dropped_level(s, reason="early_stop")
                for s in dropped_stuck:
                    self._record_dropped_level(s, reason="patience")

                if (dropped_solved or dropped_stuck) and args.use_plr:
                    for sampler in self.all_level_samplers:
                        for s in dropped_solved | dropped_stuck:
                            sampler.drop_seed(int(s))

                # num_dropped = len(dropped_solved) + len(dropped_stuck)
                num_dropped = len(dropped_stuck)

            eligible_idx = early_stop_idx
        else:
            early_stop_idx = np.array([], dtype=np.int64)
            eligible_idx = 0

        edit_level = self._should_edit_level() and level_replay

        if edit_level and isinstance(eligible_idx, np.ndarray) and eligible_idx.size == 0:
            edit_level = False

        if edit_level:
            fixed_seeds = env_info

            level_sampler_agent, is_updateable_agent = self._get_level_sampler("agent")

            self.agent_rollout(
                agent=None,
                num_steps=None,
                is_env=True,
                edit_level=True,
                num_edits=args.num_edits,
                fixed_seeds=fixed_seeds,
                eligible_idx=eligible_idx,
                # target_partitions=target_partitions,
                # original_partitions=original_partitions,
            )

        ##################

        # now for each dropped stucked seed, add a complement new seed

        if level_replay:
            if (np.array(stuck_seed_index) != 0).any():
                stuck_index = np.where(np.array(stuck_seed_index) != 0)[0]

                self.ued_venv.reset_random()
                self._update_plr_with_current_unseen_levels(valid_indices=stuck_index)

        if edit_level:
            self.total_num_edits += 1
        ##################

        # Align LevelStore and replay buffer, and update weighted_num_edits
        self._reconcile_level_store_and_samplers()

        # if self._partition_steps_since_flush is None:
        #     self._partition_steps_since_flush = 0
        # self._partition_steps_since_flush += 1
        # self._flush_partition_events(force=False)

        env_return = self._compute_env_return(agent_info, adversary_agent_info)

        adversary_env_info = defaultdict(float)
        if self.is_training and self.is_training_env:
            with torch.no_grad():
                obs_id = adversary_env.storage.get_obs(-1)
                next_value = adversary_env.get_value(
                    obs_id,
                    adversary_env.storage.get_recurrent_hidden_state(-1),
                    adversary_env.storage.masks[-1],
                ).detach()
            adversary_env.storage.replace_final_return(env_return)
            adversary_env.storage.compute_returns(
                next_value, args.use_gae, args.gamma, args.gae_lambda
            )
            env_value_loss, env_action_loss, env_dist_entropy, info = adversary_env.update()
            adversary_env_info.update(
                {
                    "action_loss": env_action_loss,
                    "value_loss": env_value_loss,
                    "dist_entropy": env_dist_entropy,
                    "update_info": info,
                }
            )

        if self.is_training:
            self.num_updates += 1

        if args.use_plr and self.level_store is not None:
            self._flush_add_drop_records_to_csv(current_steps=self.num_updates)

        # -------------------------------------------------------------------
        # LOGGING
        # -------------------------------------------------------------------
        log_replay_complexity = level_replay and args.log_replay_complexity
        if (not level_replay) or log_replay_complexity:
            stats = self._get_env_stats(
                agent_info,
                adversary_agent_info,
                log_replay_complexity=log_replay_complexity,
            )
            stats.update(
                {
                    "mean_env_return": env_return.mean().item(),
                    "adversary_env_pg_loss": adversary_env_info["action_loss"],
                    "adversary_env_value_loss": adversary_env_info["value_loss"],
                    "adversary_env_dist_entropy": adversary_env_info["dist_entropy"],
                }
            )
            if args.use_plr:
                self.latest_env_stats.update(stats)
        else:
            stats = self.latest_env_stats.copy()

        if args.use_plr and args.log_plr_buffer_stats:
            stats.update(self._get_plr_buffer_stats())

        [self.agent_returns.append(r) for b in agent_info["returns"] for r in reversed(b)]
        mean_agent_return = 0.0
        if len(self.agent_returns) > 0:
            mean_agent_return = float(np.mean(self.agent_returns))

        mean_adversary_agent_return = 0.0
        if self.is_paired or self.use_accel_paired:
            [
                self.adversary_agent_returns.append(r)
                for b in adversary_agent_info["returns"]
                for r in reversed(b)
            ]
            if len(self.adversary_agent_returns) > 0:
                mean_adversary_agent_return = float(np.mean(self.adversary_agent_returns))

        replay_buffer_size = 0
        level_store_size = 0
        if args.use_plr and self.level_store is not None:
            buffer_seeds = self._last_buffer_seeds_for_stats
            if buffer_seeds is None:
                buffer_seeds = self._get_buffer_seeds()
            replay_buffer_size = len(buffer_seeds)
            level_store_size = len(self.level_store)

        stats.update(
            {
                "steps": (self.num_updates + self.total_num_edits)
                * args.num_processes
                * args.num_steps,
                "total_episodes": self.total_episodes_collected,
                "total_seeds": self.total_seeds_collected,
                "total_student_grad_updates": self.student_grad_updates,
                "mean_agent_return": mean_agent_return,
                "agent_value_loss": agent_info["value_loss"],
                "agent_pg_loss": agent_info["action_loss"],
                "agent_dist_entropy": agent_info["dist_entropy"],
                "mean_adversary_agent_return": mean_adversary_agent_return,
                "adversary_value_loss": adversary_agent_info["value_loss"],
                "adversary_pg_loss": adversary_agent_info["action_loss"],
                "adversary_dist_entropy": adversary_agent_info["dist_entropy"],
                "kl_loss_advagent_agent": agent_info.get("kl_loss", None),
                "kl_loss_agent_advagent": adversary_agent_info.get("kl_loss", None),
                "grad_var": agent_info.get("grad_var", 0.0),
                "grad_dir_cos": agent_info.get("grad_dir_cos", None),
                "clip_fraction": agent_info.get("clip_fraction", 0.0),
                "pre_clip_norm_mean": agent_info.get("pre_clip_norm_mean", 0.0),
                "post_clip_norm_mean": agent_info.get("post_clip_norm_mean", 0.0),
                "num_early_stop": num_early_stop,
                "num_dropped": num_dropped,
                "replay_buffer_size": replay_buffer_size,
                "level_store_size": level_store_size,
            }
        )

        if args.log_grad_norm:
            agent_grad_norm = 0.0
            if "update_info" in agent_info and "grad_norms" in agent_info["update_info"]:
                if len(agent_info["update_info"]["grad_norms"]) > 0:
                    agent_grad_norm = float(
                        np.mean(agent_info["update_info"]["grad_norms"])
                    )

            adversary_grad_norm = 0.0
            adversary_env_grad_norm = 0.0
            if (
                self.is_paired
                and "update_info" in adversary_agent_info
                and "grad_norms" in adversary_agent_info["update_info"]
            ):
                if len(adversary_agent_info["update_info"]["grad_norms"]) > 0:
                    adversary_grad_norm = float(
                        np.mean(adversary_agent_info["update_info"]["grad_norms"])
                    )
            if (
                self.is_training_env
                and "update_info" in adversary_env_info
                and "grad_norms" in adversary_env_info["update_info"]
            ):
                if len(adversary_env_info["update_info"]["grad_norms"]) > 0:
                    adversary_env_grad_norm = float(
                        np.mean(adversary_env_info["update_info"]["grad_norms"])
                    )

            stats.update(
                {
                    "agent_grad_norm": agent_grad_norm,
                    "adversary_grad_norm": adversary_grad_norm,
                    "adversary_env_grad_norm": adversary_env_grad_norm,
                }
            )

        if args.log_action_complexity:
            stats.update(
                {
                    "agent_action_complexity": agent_info.get("action_complexity", None),
                    "adversary_action_complexity": adversary_agent_info.get(
                        "action_complexity", None
                    ),
                }
            )

        return stats
