# Copyright (c) Facebook, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections
from random import sample
from sklearn import cluster
import torch
from torch import nn
from torch.nn import functional as F
from einops import rearrange


from nle import nethack

from ..env import get_env
from .util import id_pairs_table
import numpy as np

NUM_GLYPHS = nethack.MAX_GLYPH
NUM_FEATURES = 25
PAD_CHAR = 0
NUM_CHARS = 256


def conv_outdim(i_dim, k, padding=0, stride=1, dilation=1):
    """Compute the output length of a convolution along a single axis.

    Args:
        i_dim: input length along the axis.
        k: kernel size.
        padding: zero-padding added on each side.
        stride: convolution stride.
        dilation: spacing between kernel elements.

    Returns:
        The (truncated) integer output length.
    """
    numerator = i_dim + 2 * padding - dilation * (k - 1) - 1
    return int(1 + numerator / stride)


def select(embedding_layer, x, use_index_select):
    """Look up embeddings for the integer indices in ``x``.

    With ``use_index_select`` the weight matrix is indexed directly via
    ``index_select`` (possibly faster than the module's default forward);
    otherwise ``embedding_layer(x)`` is called. The index-select path returns a
    tensor of shape ``x.shape + (weight.shape[1],)``.

    Args:
        embedding_layer: module exposing a 2-D ``weight`` (e.g. ``nn.Embedding``).
        x: integer index tensor of any shape (contiguous or not).
        use_index_select: choose the direct ``index_select`` path.
    """
    if not use_index_select:
        return embedding_layer(x)
    # ``reshape`` (unlike ``view``) also accepts non-contiguous index tensors,
    # which ``embedding_layer(x)`` accepts as well.
    flat = embedding_layer.weight.index_select(0, x.reshape(-1))
    # Pass the width explicitly: a ``-1`` cannot be inferred for empty ``x``.
    return flat.view(*x.shape, flat.shape[-1])


def get_torso(flags):
    """Return the torso class for the environment selected by ``flags``.

    ``NethackTorso`` when ``get_env(flags)`` is ``'nethack'``, otherwise
    ``AtariTorso``. The class itself is returned, not an instance.
    """
    return NethackTorso if get_env(flags) == 'nethack' else AtariTorso


class NetHackNet(nn.Module):
    """This base class simply provides a skeleton for running with torchbeast.

    Besides the interface stubs it keeps running reward statistics in buffers
    (``reward_sum``, ``reward_m2``, ``reward_count``), one entry per reward
    channel, so callers can normalise rewards by ``get_running_std()``.
    """

    AgentOutput = collections.namedtuple("AgentOutput", "action policy_logits baseline")
    ExtendedAgentOutput = collections.namedtuple("AgentOutput", "action policy_logits baseline objective")

    def __init__(self, flags):
        super(NetHackNet, self).__init__()
        # ``num_objectives + 1`` reward channels in multi-objective mode, else 1.
        self.num_rewards = flags.num_objectives + 1 if flags.multi_objective else 1
        self.register_buffer("reward_sum", torch.zeros((self.num_rewards,)))
        self.register_buffer("reward_m2", torch.zeros((self.num_rewards,)))
        # Tiny non-zero initial count avoids dividing by zero before the first update.
        self.register_buffer("reward_count", torch.zeros((self.num_rewards,)).fill_(1e-8))

    def forward(self, inputs, core_state):
        raise NotImplementedError

    def initial_state(self, batch_size=1):
        """Return the initial recurrent state (empty for this base class)."""
        return ()

    @torch.no_grad()
    def update_running_moments(self, reward_batch):
        """Fold a batch of rewards into the running per-channel statistics.

        Args:
            reward_batch: tensor of shape ``(*batch_shape, num_rewards)``. Every
                leading dimension is treated as a batch dimension, so both
                ``(N, R)`` and ``(T, B, R)`` inputs are supported.

        Uses the parallel (Chan et al.) combination rule for the sum of squared
        deviations. An empty batch leaves the statistics unchanged.
        """
        batch_shape = reward_batch.shape[:-1]
        batch_dim = tuple(range(len(batch_shape)))
        # Samples per reward channel = product of *all* leading dimensions;
        # ``len(reward_batch)`` would only count the first one.
        new_count = batch_shape.numel()
        if new_count == 0:
            # Nothing to add; dividing by zero here would poison reward_m2 with NaN.
            return
        new_sum = torch.sum(reward_batch, batch_dim)
        new_mean = new_sum / new_count

        curr_mean = self.reward_sum / self.reward_count
        new_m2 = torch.sum((reward_batch - new_mean) ** 2, batch_dim) + (
            (self.reward_count * new_count)
            / (self.reward_count + new_count)
            * (new_mean - curr_mean) ** 2
        )

        self.reward_count += new_count
        self.reward_sum += new_sum
        self.reward_m2 += new_m2

    @torch.no_grad()
    def get_running_std(self):
        """Return the running (population) standard deviation of the rewards.

        One value per reward channel; ``1e-8`` is added so the result can be
        used safely as a divisor.
        """
        return torch.sqrt(self.reward_m2 / self.reward_count) + 1e-8


# class MultiBaselineNet(NetHackNet):
#     def __init__(self, observation_space, action_space, flags, device, logits_mask=None):
#         super(MultiBaselineNet, self).__init__(flags)
        
#         nets = []
#         for i in range(flags.num_objectives + 1):
#             nets.append(BaselineNet(observation_space, action_space, flags, device, logits_mask))
#         self.nets = nn.ModuleList(nets)

#     def initial_state(self):
#         return self.nets[0].initial_state()

#     def forward(self, inputs, core_state, learning=False):


class RNDNet(NetHackNet):
    """Observation embedding network: torso features plus a linear projection.

    Observations are encoded by the environment-specific torso returned by
    ``get_torso(flags)``, and the features are projected to
    ``flags.rnd_output_dim`` dimensions by a single linear layer. Presumably
    used as a random-network-distillation target/predictor (from the name);
    confirm at the call site.

    The torso is taken from the NLE 'neurips2020release' baseline (Kuttler et
    al, 2020, The NetHack Learning Environment, https://arxiv.org/abs/2006.13760).
    """

    def __init__(self, observation_space, action_space, flags, device, logits_mask=None):
        super().__init__(flags)

        self.flags = flags

        torso_cls = get_torso(flags)
        self.torso = torso_cls(observation_space, action_space, flags, device)

        self.output_dim = flags.rnd_output_dim
        self.net = nn.Linear(flags.hidden_dim, self.output_dim)

        # The mask is stored as a frozen parameter but not used in ``forward``.
        has_mask = logits_mask is not None
        self.logits_mask = has_mask
        if has_mask:
            self.policy_logits_mask = nn.parameter.Parameter(logits_mask, requires_grad=False)

    def forward(self, inputs):
        """Return projected features of shape ``[T, B, output_dim]``.

        ``T`` and ``B`` are taken from ``inputs["done"].shape``.
        """
        steps, batch = inputs["done"].shape
        features = self.torso(inputs).view(steps, batch, -1)
        return self.net(features)


class BaselineNet(NetHackNet):
    """Actor-critic policy network for an IMPALA-like learner.

    Observations are encoded by the environment-specific torso from
    ``get_torso``. The encoding is optionally combined with objective
    information, optionally passed through a single-layer LSTM core, and then
    fed to the policy-logit and baseline heads.

    This model was based on 'neurips2020release' tag on the NLE repo, itself
    based on Kuttler et al, 2020
    The NetHack Learning Environment
    https://arxiv.org/abs/2006.13760
    """

    def __init__(self, observation_space, action_space, flags, device, logits_mask=None):
        super(BaselineNet, self).__init__(flags)

        self.flags = flags

        self.num_actions = action_space.n
        self.use_lstm = flags.use_lstm
        self.h_dim = flags.hidden_dim

        self.torso = get_torso(flags)(observation_space, action_space, flags, device)

        # Configs created before the ``objective_as_input`` option existed.
        old_version = "objective_as_input" not in flags

        self.multi_objective = flags.get("multi_objective", flags.num_objectives > 1)

        # Only old-style multi-objective configs concatenate the "completed"
        # observation to the torso features.
        self.encode_completed_objectives = old_version and self.multi_objective and "completed" in observation_space.spaces
        if self.encode_completed_objectives:
            print("Encoding completed objectives!")
            self.completed_objectives_encoder = nn.Sequential(
                nn.Linear(flags.num_objectives + self.h_dim, self.h_dim),
                nn.ReLU(),
                nn.Linear(self.h_dim, self.h_dim),
                nn.ReLU(),
            )

        self.objective_as_input = flags.get("objective_as_input", False) and self.multi_objective
        # Only used when self.multi_objective.
        self.num_objectives = flags.num_objectives if old_version else flags.num_objectives + 1

        if self.objective_as_input:
            print("Objective as input!!", flush=True)
            self.objective_encoder = nn.Sequential(
                nn.Linear(self.num_objectives + self.h_dim, self.h_dim),
                nn.ReLU(),
                nn.Linear(self.h_dim, self.h_dim),
                nn.ReLU(),
            )

        # ``self.core`` exists only when the LSTM is enabled.
        if self.use_lstm:
            self.core = nn.LSTM(self.h_dim, self.h_dim, num_layers=1)

        # Multi-objective without objective input: one policy/baseline head per
        # objective; otherwise a single head.
        self.num_policies = 1 if self.objective_as_input or not self.multi_objective else self.num_objectives
        self.policy = nn.Linear(self.h_dim, self.num_actions * self.num_policies)
        self.baseline = nn.Linear(self.h_dim, self.num_policies)

        self.logits_mask = logits_mask is not None
        if self.logits_mask:
            # Frozen; masked-out actions get logits of -1e10 in ``forward``.
            self.policy_logits_mask = nn.parameter.Parameter(
                logits_mask, requires_grad=False
            )

    def initial_state(self, batch_size=1):
        """Return the initial recurrent state.

        An empty tuple when the LSTM core is disabled (``self.core`` is not
        created in that case). Otherwise a pair of zero tensors ``(h, c)``,
        each shaped ``(num_layers, batch_size, hidden_size)``.
        """
        if not self.use_lstm:
            return ()
        return tuple(
            torch.zeros(self.core.num_layers, batch_size, self.core.hidden_size)
            for _ in range(2)
        )

    def forward(self, inputs, core_state, learning=False):
        """Compute policy logits, baseline values and actions.

        Args:
            inputs: dict of observation tensors; ``inputs["done"]`` has shape
                ``[T, B]``.
            core_state: recurrent state as returned by ``initial_state``;
                returned unchanged when the LSTM core is disabled.
            learning: unused.

        Returns:
            ``(output, core_state)``. ``output`` has the keys ``policy_logits``,
            ``baseline``, ``action``, ``core_output`` and ``obs_st``. With a
            single head the first three are ``[T, B, A]``, ``[T, B]`` and
            ``[T, B]``; with several heads an extra ``num_policies`` axis
            follows ``B``. Actions are sampled in training mode and arg-maxed
            otherwise.
        """
        T, B = inputs["done"].shape

        st = self.torso(inputs)

        if self.encode_completed_objectives:
            st = self.completed_objectives_encoder(torch.cat((st, inputs["completed"].view(T * B, -1)), 1))

        if self.objective_as_input:
            st = self.objective_encoder(torch.cat((st, F.one_hot(inputs["objective"].view(T * B).to(torch.int64), self.num_objectives)), 1))

        if self.use_lstm:
            core_input = st.view(T, B, -1)
            core_output_list = []
            notdone = (~inputs["done"]).float()
            for step_input, step_notdone in zip(core_input.unbind(), notdone.unbind()):
                # Reset core state to zero whenever an episode ended.
                # Make `done` broadcastable with (num_layers, B, hidden_size)
                # states:
                step_notdone = step_notdone.view(1, -1, 1)
                core_state = tuple(step_notdone * t for t in core_state)
                output, core_state = self.core(step_input.unsqueeze(0), core_state)
                core_output_list.append(output)
            core_output = torch.flatten(torch.cat(core_output_list), 0, 1)
        else:
            core_output = st

        # -- [B' x A]
        policy_logits = self.policy(core_output).view(T * B * self.num_policies, -1)

        # -- [B' x 1]
        if self.flags.get("baseline_only", False):
            # Baseline-only training: keep value gradients out of the shared trunk.
            baseline = self.baseline(core_output.detach())
        else:
            baseline = self.baseline(core_output)

        if self.logits_mask:
            policy_logits = policy_logits * self.policy_logits_mask + (
                (1 - self.policy_logits_mask) * -1e10
            )

        if self.training:
            action = torch.multinomial(F.softmax(policy_logits, dim=1), num_samples=1)
        else:
            # Don't sample when testing.
            action = torch.argmax(policy_logits, dim=1)

        if self.objective_as_input or not self.multi_objective:
            policy_logits = policy_logits.view(T, B, -1)
            baseline = baseline.view(T, B)
            action = action.view(T, B)
        else:
            policy_logits = policy_logits.view(T, B, self.num_policies, -1)
            baseline = baseline.view(T, B, self.num_policies)
            action = action.view(T, B, self.num_policies)

        output = dict(policy_logits=policy_logits, baseline=baseline, action=action, core_output=core_output.view(T, B, -1), obs_st=st.view(T, B, -1))
        return (output, core_state)


class DynamicsModel(nn.Module):
    """Frame-prediction dynamics model with a latent transition code.

    ``torso1`` encodes the stacked (current, previous) frames. Together with
    the one-hot action and the ``done`` flag, this is mapped by ``fc`` to ``z``
    groups of ``k`` logits. The resulting code ``dyn_k`` (Gumbel-softmax one-hot
    when ``discrete``, raw logits otherwise) and the ``torso2`` encoding of the
    previous frame are decoded by ``ffc`` + ``convt`` into a frame prediction.
    Optional auxiliary heads and losses are enabled through ``flags``.
    """

    def __init__(self, observation_space, action_space, flags, device):
        super(DynamicsModel, self).__init__()

        self.flags = flags
        self.num_actions = action_space.n
        self.h_dim = flags.hidden_dim
        self.k = flags.dynamics_k
        self.z = flags.get("dynamics_z", 1)
        self.discrete = flags.get("dynamics_discrete", True)

        #   s_{t} + s_{t+1}
        #    |    | torso1
        #    |    V
        #    |  dyn_obs_st + action + done
        #    | torso2        | fc, softmax
        #    v               v
        #  obs_st    +     dyn_k
        #            | ffc + convt
        #            V
        #         s'_{t+1}

        self.torso1 = get_torso(flags)(observation_space, action_space, flags, device)

        self.fc = nn.Sequential(
            nn.Linear(self.h_dim + self.num_actions + 1, self.h_dim),
            nn.ReLU(),
            nn.Linear(self.h_dim, self.h_dim),
            nn.ReLU(),
            nn.Linear(self.h_dim, self.k * self.z)
        )

        self.torso2 = get_torso(flags)(observation_space, action_space, flags, device)
        self.ffc = nn.Sequential(
            nn.Linear(self.h_dim + self.k * self.z, self.h_dim),
            nn.ReLU(),
            nn.Linear(self.h_dim, self.h_dim),
            nn.ReLU()
        )

        self.use_causal_predictor = flags.get("frame_pred_causal_predictor", False)
        if self.use_causal_predictor:
            self.causal_predictor = nn.Sequential(
                nn.Linear(self.h_dim * 2, self.h_dim),
                nn.ReLU(),
                nn.Linear(self.h_dim, 1)
            )

        self.rnd = flags.get("frame_pred_rnd", False)
        if self.rnd:
            self.rnd_pred_fc = nn.Sequential(
                nn.Linear(self.h_dim, self.h_dim),
                nn.ReLU(),
                nn.Linear(self.h_dim, self.h_dim)
            )
            self.rnd_truth_fc = nn.Sequential(
                nn.Linear(self.h_dim, self.h_dim),
                nn.ReLU(),
                nn.Linear(self.h_dim, self.h_dim)
            )

        self.causal_clustering = flags.get("frame_pred_causal_clustering", False)
        if self.causal_clustering:
            self.causal_clustering_baseline = nn.Sequential(
                nn.Linear(self.h_dim + self.num_actions + 1, self.h_dim),
                nn.ReLU(),
                nn.Linear(self.h_dim, self.h_dim),
                nn.ReLU(),
                nn.Linear(self.h_dim, 1)
            )

        self.discrete_exploration = flags.get("frame_pred_discrete_exploration", False)
        if self.discrete_exploration:
            self.prior_param = nn.Parameter(torch.randn((self.z, self.k)))

        self.goal_generation = flags.get("goal_generation", False)
        if self.goal_generation:
            self.goal_generator = nn.Sequential(
                nn.Linear(self.z * self.k, self.h_dim),
                nn.ReLU(),
                nn.Linear(self.h_dim, 1)
            )

        self.error_prediction = flags.get("frame_pred_error_prediction", False)
        if self.error_prediction:
            self.error_predictor = nn.Sequential(
                nn.Linear(self.z * self.k, self.h_dim),
                nn.ReLU(),
                nn.Linear(self.h_dim, 1)
            )

        self.convt = TransposeAtariTorso(observation_space, self.h_dim, flags, device)

    def forward(self, inputs, action, done, last_obs=None, training=True):
        """Run the dynamics model and its enabled auxiliary objectives.

        Args:
            inputs: dict of observations; only ``"frame"`` is read. The decoder
                output is viewed as ``[T, B, 3, 84, 84]``, so frames presumably
                have that shape.
            action: integer actions of shape ``[T, B]``.
            done: episode-end flags of shape ``[T, B]``.
            last_obs: previous observations. Defaults to ``inputs`` shifted by
                one step along time, with zeros at ``t = 0``.
            training: when True, Gaussian noise scaled by
                ``frame_pred_embed_eps`` times the logit std is added to the
                logits.

        Returns:
            ``(predictions, model_loss, None, None)``. Both dicts have keys that
            depend on the enabled options.

        Raises:
            ValueError: if causal clustering is enabled with an unknown
                ``frame_pred_causal_clustering_criterion``.
        """
        obs_keys = ("frame",)
        T, B, *_ = action.shape
        if last_obs is None:
            last_obs = dict()
            for key in obs_keys:
                t = torch.zeros_like(inputs[key])
                t[1:].copy_(inputs[key][:-1])
                last_obs[key] = t

        new_obs = { key: torch.cat((inputs[key], last_obs[key]), dim=1) for key in obs_keys }
        if self.flags.get("frame_pred_mask_inventory", False):
            new_obs["frame"][:, :, :, 63:, :] = 0

        dyn_obs_st = self.torso1(new_obs)
        dyn_obs_st = dyn_obs_st.view(T, B * 2, -1)
        hh_dim = self.h_dim // 2
        # First half of the features comes from the current frame (zeroed where
        # done), the second half from the previous frame.
        dyn_obs_st = torch.cat(((1 - done.float().unsqueeze(-1)) * dyn_obs_st[:, :B, :hh_dim], dyn_obs_st[:, B:, hh_dim:]), dim=2)
        dyn_obs_st = dyn_obs_st.view(T * B, -1)

        # -- [B' x num_actions]
        one_hot_action = F.one_hot(
            action.view(T * B), self.num_actions
        ).float()

        st = torch.cat((dyn_obs_st, one_hot_action, done.float().view(T * B, 1)), 1)

        logits = self.fc(st).view(T * B, self.z, self.k)

        if self.flags.get("frame_pred_embed_scale", False):
            logits = logits / torch.std(logits)

        if training:
            if self.flags.get("frame_pred_embed_std_detach", True):
                std = torch.std(logits).detach()
            else:
                std = torch.std(logits)
            logits = logits + torch.randn_like(logits) * self.flags.get("frame_pred_embed_eps", 0.0) * std

        if self.discrete:
            dyn_k = F.gumbel_softmax(logits, tau=1, hard=True).view(T * B, self.z * self.k)  # [T * B, z * k], one_hot
        else:
            dyn_k = logits.view(T * B, self.z * self.k)

        obs_st = self.torso2(last_obs)

        fst = torch.cat((obs_st, dyn_k), dim=1)
        fst = self.ffc(fst)

        # Read with ``.get`` like every other flag in this class.
        frame_pred_delta = self.flags.get("frame_pred_delta", False)

        final_pred = self.convt(fst).view(T, B, 3, 84, 84)
        if frame_pred_delta:
            final_pred = final_pred * 2 - 255.0 + last_obs["frame"]

        predictions = {
            "frame": final_pred,
            "dyn_k": dyn_k.view(T, B, self.z, self.k),
            "dyn_logits": logits.view(T, B, self.z, self.k)
        }

        model_loss = {}

        if self.flags.get("dynamics_contrast", True) and self.discrete:
            # Decode a few candidate codes for a random subset of transitions and
            # reward diversity (determinant of a similarity kernel) between them.
            # Candidate codes are one-hot over ``k`` only, so this assumes z == 1.
            contrast_indices = np.random.choice(T * B, min(T * B, 512), replace=False)
            n_contrast = len(contrast_indices)
            sampled_ks = np.random.choice(self.k, min(self.k, 8), replace=False)
            n_sampled_k = len(sampled_ks)
            contrast_preds = []
            for i in sampled_ks:
                dyn_k_i = torch.zeros((n_contrast, self.k), dtype=torch.float, device=obs_st.device)
                dyn_k_i[:, i] = 1
                contrast_preds.append(torch.cat((obs_st[contrast_indices], dyn_k_i), dim=1))
            contrast_preds = torch.stack(contrast_preds, 1).view(n_contrast * n_sampled_k, -1)

            cst = self.ffc(contrast_preds)

            to_contrast = self.convt(cst).view(n_contrast, n_sampled_k, 3, 84, 84)
            if frame_pred_delta:
                to_contrast = to_contrast * 2 - 255.0 + last_obs["frame"].view(T * B, 1, 3, 84, 84)[contrast_indices]

            if self.flags.get("dynamics_downsample", False):
                to_contrast = F.interpolate(to_contrast.view(n_contrast * n_sampled_k, 3, 84, 84), (8, 8)).view(n_contrast, n_sampled_k, 3, 8, 8)

            in_dis = [[0 for _ in range(n_sampled_k)] for _ in range(n_sampled_k)]
            for i in range(n_sampled_k):
                for j in range(n_sampled_k):
                    dis = ((to_contrast[:, i] - to_contrast[:, j]) / 255.0).square().sum((-1, -2, -3))
                    in_dis[i][j] = in_dis[j][i] = dis
            in_dis = torch.stack([torch.stack(row, 1) for row in in_dis], dim=1)
            d = torch.exp(-in_dis / in_dis.max() * 2)
            c = torch.det(d)
            model_loss['contrast_loss'] = -c.sum() * (T * B / n_contrast) * self.flags.get("dynamics_contrast_coef", 0.1)

        if self.use_causal_predictor and T > 1:
            embeds = dyn_k.view(T, B, -1)
            CAUSAL_BATCH_SIZE = 1024
            causal_inputs = []
            causal_targets = []
            for _ in range(CAUSAL_BATCH_SIZE):
                # Rejection-sample (b, x, y) with no action in [1, 4] at either
                # step and no episode end in between.
                # NOTE(review): never terminates if no such pair exists in the batch.
                while True:
                    b = np.random.randint(B)
                    x, y = np.random.choice(T, 2, replace=False)
                    if 1 <= action[x, b].item() <= 4 or 1 <= action[y, b].item() <= 4:
                        continue
                    if not torch.any(done[min(x, y): max(x, y), b]).item():
                        break
                causal_inputs.append(torch.cat((embeds[x, b], embeds[y, b])))
                causal_targets.append([int(x < y)])
            causal_inputs = torch.stack(causal_inputs, 0)
            causal_targets = torch.tensor(causal_targets, dtype=torch.float32, device=causal_inputs.device)
            causal_preds = self.causal_predictor(causal_inputs)
            loss_fn = nn.BCEWithLogitsLoss(reduction='sum')
            causal_loss = loss_fn(causal_preds, causal_targets)
            model_loss['causal_loss'] = causal_loss
            predictions = {}

        if self.flags.get("frame_pred_chain_contrast", False):
            embeds = dyn_k.view(T, B, -1)
            chain_contrast_loss = 0
            for j in range(B):
                last = None
                for i in range(T):
                    if done[i, j].item():
                        last = None
                    if not (1 <= action[i, j].item() <= 4):
                        if last is not None:
                            chain_contrast_loss -= (embeds[i, j] - last).square().sum() / max(embeds[i, j].square().sum(), last.square().sum())
                        last = embeds[i, j]
            model_loss["chain_contrast_loss"] = chain_contrast_loss

        if self.causal_clustering:
            # ``predictions`` may already have been cleared by the causal predictor.
            predictions.pop("frame", None)
            assert self.z == 1
            _dyn_k = dyn_k.detach().view(T, B, self.k).argmax(-1)
            criterion = self.flags.get("frame_pred_causal_clustering_criterion", "causal")
            if criterion == "causal":
                happen = np.zeros(self.k, dtype=int)
                before = np.zeros((self.k, self.k), dtype=int)
                for j in range(B):
                    appeared = np.zeros(self.k, dtype=bool)
                    for i in range(T):
                        if done[i, j].item():
                            appeared[:] = False
                        k = int(_dyn_k[i, j].item())
                        if not appeared[k]:
                            for l in range(self.k):
                                if appeared[l]:
                                    before[l][k] += 1
                            appeared[k] = True
                            happen[k] += 1
                score = np.zeros(self.k, dtype=float)
                for k in range(self.k):
                    for l in range(self.k):
                        if k != l and happen[k] > 0 and happen[l] > 0:
                            ratio = before[k][l] / happen[l]
                            if ratio > 0.9:
                                s = 1
                            elif ratio > 0.8:
                                s = 1 / 0.1 * (ratio - 0.8)
                            else:
                                s = 0
                            score[k] += happen[l] * s
                            score[l] += happen[l] * s
            elif criterion == "scarcity":
                appearance = []
                for j in range(B):
                    length = 0
                    _appearance = np.zeros(self.k, dtype=int)
                    for i in range(T):
                        if done[i, j].item():
                            if length > 0:
                                appearance.append(_appearance.copy() / length)
                            length = 0
                            _appearance[:] = 0
                        length += 1
                        k = int(_dyn_k[i, j].item())
                        _appearance[k] += 1
                    if length > 0:
                        appearance.append(_appearance.copy() / length)
                appearance = np.stack(appearance, 0)
                nonzero = (appearance > 0).astype(int).mean(0)
                freq = np.maximum(appearance - 0.1, 0).mean(0)
                score = np.minimum(nonzero, 0.3) * 1.0 / 0.3 - freq
                score[0] = 0
                # ``nonzero``/``freq`` only exist for this criterion; printing
                # them unconditionally raised NameError for "causal".
                print(nonzero, freq)
            else:
                raise ValueError(f"Unknown frame_pred_causal_clustering_criterion: {criterion!r}")
            print(score)
            score = torch.tensor(score, dtype=torch.float32, device=action.device)
            causal_clustering_baseline = self.causal_clustering_baseline(st)  # [T * B, 1]
            causal_clustering_return = (dyn_k * score[None, :]).sum(-1)  # [T * B]
            causal_clustering_loss = -(causal_clustering_return - (dyn_k * causal_clustering_baseline).sum(-1)).sum()
            model_loss["causal_clustering_loss"] = causal_clustering_loss.sum()
            # Squeeze to [T * B]: a [T * B, 1] minus [T * B] difference would
            # broadcast to a [T * B, T * B] matrix.
            model_loss["causal_clustering_baseline_loss"] = (causal_clustering_baseline.squeeze(-1) - causal_clustering_return).square().sum()

        if self.rnd:
            rnd_truth = self.rnd_truth_fc(dyn_k).detach()
            rnd_pred = self.rnd_pred_fc(dyn_k)
            rnd_diff = (rnd_pred - rnd_truth).square()
            predictions["dyn_rnd_diff"] = rnd_diff.sum(-1).detach().view(T, B)
            model_loss["rnd_loss"] = rnd_diff.sum()

        if self.discrete_exploration:
            _dyn_k = dyn_k.detach().view(T * B, self.z, self.k).argmax(-1)
            # Explicit dim (previously implicit, which resolves to 1 for 2-D input).
            p = F.softmax(self.prior_param, dim=-1)
            probs = torch.stack([p[i][_dyn_k[:, i]] for i in range(self.z)], 1)  # [T * B, z]
            prior_loss = -torch.log(probs)
            model_loss['prior_loss'] = prior_loss.sum()
            predictions['dyn_exploration_reward'] = prior_loss.sum(1).view(T, B)

        if self.flags.get("frame_pred_variational_loss", False):
            if self.discrete:
                prob = F.softmax(logits, dim=-1)
                prob = (prob * dyn_k.detach().view(T * B, self.z, self.k)).sum(-1)
                variational_loss = torch.log(prob).sum()
                model_loss['variational_loss'] = variational_loss
            else:
                variational_loss = logits.square().sum()
                model_loss['variational_loss'] = variational_loss

        if self.goal_generation:
            # May already have been removed by causal clustering / causal predictor.
            predictions.pop("frame", None)
            goal = self.goal_generator(dyn_k)
            predictions["dyn_k_goal"] = goal.view(T, B)

        if self.error_prediction:
            error_pred = self.error_predictor(dyn_k)
            predictions["error"] = error_pred.view(T, B)

        return (
            predictions,
            model_loss,
            None,
            None
        )


class PredictionModel(nn.Module):
    """Predict quantities of a transition from (previous obs, current obs, action, done).

    Both observations are encoded by one torso. The current observation is
    zeroed out when predicting frames (except in version 3) or when
    ``pred_no_next_frame`` is set. The encodings are fused with the one-hot
    action and the ``done`` flag, then passed to the prediction heads.

    ``version == 1`` has a single reward head. Any other version builds one head
    per supported key of ``predict_items`` ("reward", "frame", "event"), which
    matches the dispatch in ``forward``.
    """

    def __init__(self, observation_space, action_space, flags, device, predict_items=None, version=1):
        super(PredictionModel, self).__init__()

        # ``None`` default instead of a dict literal, so instances never share
        # one mutable default dict.
        if predict_items is None:
            predict_items = {"reward": ()}

        self.flags = flags
        self.predict_items = predict_items
        self.version = version

        self.num_actions = action_space.n
        self.h_dim = flags.hidden_dim

        env = get_env(flags)
        if env == 'nethack':
            self.obs_keys = ("glyphs", "message", "blstats", "chars", "colors", "specials")
        else:
            self.obs_keys = ("frame",)

        self.torso = get_torso(flags)(observation_space, action_space, flags, device)

        self.fc = nn.Sequential(
            nn.Linear(self.h_dim + self.num_actions + 1, self.h_dim),
            nn.ReLU(),
            nn.Linear(self.h_dim, self.h_dim),
            nn.ReLU(),
        )

        if version == 1:
            self.reward_prediction = nn.Linear(self.h_dim, 1)
        else:
            # Every non-1 version uses the per-item heads, as in ``forward``.
            # Previously only version 2 built them, so version 3 (see
            # ``forward``) failed with AttributeError.
            predict_heads = {}
            if "reward" in predict_items:
                predict_heads["reward"] = nn.Linear(self.h_dim, 1)
            if "frame" in predict_items:
                dim = self.h_dim
                predict_heads["frame"] = TransposeAtariTorso(observation_space, dim, flags, device)
            if "event" in predict_items:
                predict_heads["event"] = nn.Linear(self.h_dim, flags.num_events)

            assert len(predict_heads) > 0
            self.predict_heads = nn.ModuleDict(predict_heads)

        self.frame_pred_delta = flags.get("frame_pred_delta", False)
        self.frame_pred_alpha = flags.get("frame_pred_alpha", False)

    def forward(self, inputs, action, done, last_obs=None, training=True):
        """Predict the configured items for each transition.

        Args:
            inputs: dict of current observations (keys ``self.obs_keys``).
            action: integer actions of shape ``[T, B]``.
            done: episode-end flags of shape ``[T, B]``.
            last_obs: previous observations. Defaults to ``inputs`` shifted by
                one step along time, with zeros at ``t = 0``.
            training: when True, Gaussian noise scaled by
                ``frame_pred_embed_eps`` is added to the fused embedding.

        Returns:
            ``(predictions, {}, obs_st, st)``. ``predictions`` maps each item to
            a ``[T, B, *predict_items[item]]`` tensor (frames may carry one
            extra alpha channel internally). ``obs_st`` and ``st`` are the
            observation and fused embeddings, each viewed as ``[T, B, -1]``.
        """
        obs_keys = self.obs_keys
        T, B, *_ = action.shape
        if last_obs is None:
            last_obs = dict()
            for key in obs_keys:
                t = torch.zeros_like(inputs[key])
                t[1:].copy_(inputs[key][:-1])
                last_obs[key] = t

        # Hide the current observation when it is (part of) the prediction target.
        if ("frame" in self.predict_items and self.version != 3) or self.flags.get("pred_no_next_frame", False):
            next_obs_mask = 0
        else:
            next_obs_mask = 1
        new_obs = { key: torch.cat((inputs[key] * next_obs_mask, last_obs[key]), dim=1) for key in obs_keys }

        obs_st = self.torso(new_obs)

        obs_st = obs_st.view(T, B * 2, -1)
        hh_dim = self.h_dim // 2
        # First half of the features from the current observation (zeroed where
        # done), second half from the previous observation.
        obs_st = torch.cat(((1 - done.float().unsqueeze(-1)) * obs_st[:, :B, :hh_dim], obs_st[:, B:, hh_dim:]), dim=2)
        obs_st = obs_st.view(T * B, -1)

        # -- [B' x num_actions]
        one_hot_action = F.one_hot(
            action.view(T * B), self.num_actions
        ).float()

        st = torch.cat((obs_st, one_hot_action, done.float().view(T * B, 1)), 1)

        st = self.fc(st) + obs_st

        if training:
            st += torch.randn_like(st) * self.flags.get("frame_pred_embed_eps", 0.0)

        predictions = dict()
        if self.version == 1:
            predictions["reward"] = self.reward_prediction(st).view(T, B)
        else:
            for key, head in self.predict_heads.items():
                if key == "frame":
                    frame_depth, frame_shape = self.predict_items[key][0], self.predict_items[key][1:]
                    if self.frame_pred_alpha:
                        frame_depth += 1

                    output = head(st).view(T, B, frame_depth, *frame_shape)
                    if self.frame_pred_delta:
                        predictions[key] = output * 2 - 255.0 + last_obs[key]
                    elif self.frame_pred_alpha:
                        # Extra channel is an alpha mask blending the predicted
                        # RGB with the previous frame.
                        mask = output[:, :, 3:] / 255.0
                        predictions["mask"] = mask * ~done.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
                        predictions[key] = mask.detach() * output[:, :, :3] + (1 - mask.detach()) * last_obs[key]
                    else:
                        predictions[key] = output
                    predictions[key] *= ~done.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
                else:
                    predictions[key] = head(st).view(T, B, *self.predict_items[key])

        return (
            predictions,
            {},
            obs_st.clone().view(T, B, -1),
            st.clone().view(T, B, -1)
        )


class RewardClassifier(nn.Module):
    """Assign transitions to the nearest centroid in a prediction-model embedding space.

    Embeddings are the fourth element returned by ``pred_net`` (for
    ``PredictionModel``, the fused embedding ``st``). An embedding whose
    smallest squared Euclidean distance to the centroids is not below
    ``threshold`` is assigned the extra class ``num_clusters`` ("new task").
    """

    def __init__(self, pred_net: "PredictionModel", centroids: np.ndarray, threshold: float, device, cluster_names=None):
        super(RewardClassifier, self).__init__()

        self.pred_net = pred_net
        self.num_clusters = centroids.shape[0]
        # Non-persistent buffer: follows ``.to(device)`` on this module (unlike a
        # plain attribute) without adding a key to the state_dict.
        self.register_buffer(
            "centroids",
            torch.from_numpy(centroids).to(device=device, dtype=torch.float)[None, None, :, :],  # [1, 1, K, D]
            persistent=False,
        )
        self.threshold = threshold
        if cluster_names is None:
            # Placeholder names keep ``cluster_names[class_id]`` aligned with the
            # class ids (the default of None previously raised TypeError).
            cluster_names = [f"{i:2}:cluster" for i in range(self.num_clusters)]
        self.cluster_names = list(cluster_names) + [f"{self.num_clusters:2}:new_tasks"]

    def forward(self, inputs, action, done, last_obs=None, one_hot=False, return_embeds=False):
        """Classify each transition.

        Args:
            inputs, action, done, last_obs: forwarded to ``pred_net`` (called
                with ``training=False``).
            one_hot: return one-hot classes of width ``num_clusters + 1``
                instead of integer class ids.
            return_embeds: also return the (detached) embeddings.

        Returns:
            Class ids of shape ``[T, B]`` (or one-hot ``[T, B, K + 1]``),
            optionally paired with the ``[T, B, D]`` embeddings.
        """
        embeds = self.pred_net(inputs, action, done, last_obs=last_obs, training=False)[3].detach()
        dis = (embeds[:, :, None, :] - self.centroids).square().sum(-1)  # [T, B, K]
        min_dis = torch.min(dis, dim=-1)
        classes = torch.where(min_dis.values < self.threshold, min_dis.indices, self.num_clusters)
        if one_hot:
            result = F.one_hot(classes, self.num_clusters + 1)
        else:
            result = classes
        if return_embeds:
            return result, embeds
        else:
            return result


class AtariTorso(nn.Module):
    """Nature-DQN style convolutional encoder for channels-first frames.

    Three conv layers (8x8 stride 4, 4x4 stride 2, 3x3 stride 1) followed by a
    two-layer ReLU MLP of width ``flags.hidden_dim``.
    """

    def __init__(self, observation_space, action_space, flags, device):
        super(AtariTorso, self).__init__()
        self.observation_shape = observation_space['frame'].shape
        self.num_actions = action_space.n

        self.h_dim = flags.hidden_dim

        in_channels = self.observation_shape[0]
        height = self.observation_shape[1]
        width = self.observation_shape[2]

        # Feature extraction. For an 84x84 frame the spatial size goes
        # 84 -> 20 -> 9 -> 7, i.e. 64 * 7 * 7 = 3136 flattened features.
        self.conv1 = nn.Conv2d(
            in_channels=in_channels,
            out_channels=32,
            kernel_size=8,
            stride=4,
        )
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)

        def _conv_stack_out(size):
            # (kernel, stride) of conv1..conv3, no padding.
            for kernel, stride in ((8, 4), (4, 2), (3, 1)):
                size = (size - kernel) // stride + 1
            return size

        # Computed from the frame size instead of hard-coding 3136, so that
        # frames other than 84x84 also work.
        out_dim = 64 * _conv_stack_out(height) * _conv_stack_out(width)

        self.fc = nn.Sequential(
            nn.Linear(out_dim, self.h_dim),
            nn.ReLU(),
            nn.Linear(self.h_dim, self.h_dim),
            nn.ReLU(),
        )

    def forward(self, inputs):
        """Encode ``inputs["frame"]`` ([T x B x C x H x W], 0-255) into [T*B x hidden_dim]."""
        x = inputs["frame"]  # [T, B, C, H, W].
        T, B, *_ = x.shape
        x = torch.flatten(x, 0, 1)  # Merge time and batch.
        x = x.float() / 255.0
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = x.view(T * B, -1)
        x = self.fc(x)

        return x


class TransposeAtariTorso(nn.Module):
    """Decode a feature vector into an 84x84 frame with values in [0, 255].

    An MLP maps the input to a 64x7x7 feature map, then three transposed
    convolutions upsample it to 84x84. ``flags.transpose_cnn_version`` selects
    the layout:

    * 1: kernels/strides mirror ``AtariTorso`` (7 -> 9 -> 20 -> 84).
    * 2: three 6x6 stride-2 layers (7 -> 18 -> 40 -> 84).

    With ``flags.frame_pred_alpha`` set, one extra output channel is added on
    top of the frame's channel count.
    """

    def __init__(self, observation_space, input_dim, flags, device):
        super(TransposeAtariTorso, self).__init__()

        self.h_dim = flags.hidden_dim

        self.alpha = flags.get("frame_pred_alpha", False)

        self.version = flags.get("transpose_cnn_version", 1)

        out_channels = observation_space['frame'].shape[0] + int(self.alpha)

        if self.version == 1:

            self.fc = nn.Sequential(
                nn.Linear(input_dim, self.h_dim),
                nn.ReLU(),
                nn.Linear(self.h_dim, 64 * 7 * 7),
                nn.ReLU(),
            )

            self.tconv1 = nn.ConvTranspose2d(64, 64, kernel_size=3, stride=1)
            self.tconv2 = nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2)
            self.tconv3 = nn.ConvTranspose2d(32, out_channels, kernel_size=8, stride=4)

        elif self.version == 2:

            # Takes ``input_dim`` like version 1; using ``hidden_dim`` here
            # broke whenever the two differed.
            self.fc = nn.Sequential(
                nn.Linear(input_dim, 64 * 7 * 7),
                nn.ReLU(),
            )

            self.tconv1 = nn.ConvTranspose2d(64, 64, kernel_size=6, stride=2)
            self.tconv2 = nn.ConvTranspose2d(64, 32, kernel_size=6, stride=2)
            self.tconv3 = nn.ConvTranspose2d(32, out_channels, kernel_size=6, stride=2)

        else:
            raise ValueError(
                f"Unsupported transpose_cnn_version: {self.version!r} (expected 1 or 2)"
            )

    def forward(self, inputs):
        """Map ``inputs`` [..., input_dim] to frames [N x C x 84 x 84].

        All leading dimensions of ``inputs`` are flattened into N.
        """
        # Both versions share the same 64x7x7 start and layer sequence.
        x = self.fc(inputs).view(-1, 64, 7, 7)
        x = F.relu(self.tconv1(x))
        x = F.relu(self.tconv2(x))
        x = self.tconv3(x)

        # Squash into pixel range [0, 255].
        x = torch.sigmoid(x) * 255.0

        return x


class NethackTorso(nn.Module):
    """Encode a NetHack observation dict into one hidden vector per step.

    The glyph map (full map plus crop), the top-line message and the
    bottom-line stats each go through their own encoder. The three outputs
    are concatenated and fed to a two-layer ReLU MLP of width
    ``flags.hidden_dim``.
    """

    def __init__(self, observation_space, action_space, flags, device):
        super(NethackTorso, self).__init__()

        self.flags = flags

        glyph_shape = observation_space["glyphs"].shape
        self.observation_shape = glyph_shape
        self.num_actions = action_space.n

        self.H, self.W = glyph_shape[0], glyph_shape[1]

        self.h_dim = flags.hidden_dim

        # Sub-encoders. Keep this construction order: it fixes the order of
        # parameters() (and hence optimizer state) as well as RNG-driven init.
        self.glyph_model = GlyphEncoder(flags, self.H, self.W, flags.crop_dim, device)
        self.msg_model = MessageEncoder(
            flags.msg.hidden_dim, flags.msg.embedding_dim, device
        )
        self.blstats_model = BLStatsEncoder(NUM_FEATURES, flags.embedding_dim)

        concat_dim = sum(
            encoder.hidden_dim
            for encoder in (self.blstats_model, self.glyph_model, self.msg_model)
        )

        self.fc = nn.Sequential(
            nn.Linear(concat_dim, self.h_dim),
            nn.ReLU(),
            nn.Linear(self.h_dim, self.h_dim),
            nn.ReLU(),
        )

    def forward(self, inputs):
        """Return a [T*B x hidden_dim] encoding of ``inputs``."""
        # Each encoder yields [T*B x K_i]; concatenation order is
        # glyphs | message | blstats, matching what fc was built for.
        encoders = (self.glyph_model, self.msg_model, self.blstats_model)
        joint = torch.cat([encoder(inputs) for encoder in encoders], dim=1)
        return self.fc(joint)


class GlyphEncoder(nn.Module):
    """This glyph encoder first breaks the glyphs (integers up to 6000) into a
    more structured representation based on the qualities of the glyph: chars,
    colors, specials, groups and subgroup ids.
       Eg: invisible hell-hound: char (d), color (red), specials (invisible),
                                 group (monster) subgroup id (type of monster)
       Eg: lit dungeon floor: char (.), color (white), specials (none),
                              group (dungeon) subgroup id (type of dungeon)

    An embedding is provided for each of these, and the embeddings are
    concatenated, before encoding with a number of CNN layers.  This operation
    is repeated with a crop of the structured representations taken around the
    character's position, and the two representations are concatenated
    before returning.
    """

    def __init__(self, flags, rows, cols, crop_dim, device=None):
        """Build the embedding tables and the two parallel conv stacks.

        Args:
            flags: config object; reads ``embedding_dim`` (total embedding
                width K, must be a multiple of 8), ``layers`` (number of conv
                layers L) and ``use_index_select``.
            rows: height of the full glyph map.
            cols: width of the full glyph map.
            crop_dim: side length of the square crop around the agent.
            device: forwarded to ``Crop``.
        """
        super(GlyphEncoder, self).__init__()

        self.crop = Crop(rows, cols, crop_dim, crop_dim, device)
        K = flags.embedding_dim  # number of input filters
        L = flags.layers  # number of convnet layers

        assert (
            K % 8 == 0
        ), "This glyph embedding format needs embedding dim to be multiple of 8"
        # K is split as chars 2/8 + colors 1/8 + specials 1/8 + groups 1/8
        # + ids 3/8, so the concatenated embedding is exactly K wide.
        unit = K // 8
        self.chars_embedding = nn.Embedding(256, 2 * unit)
        self.colors_embedding = nn.Embedding(16, unit)
        self.specials_embedding = nn.Embedding(256, unit)

        # Lookup table glyph -> (id, group). Frozen, but registered as a
        # Parameter so it follows the module's device and is in the state_dict.
        self.id_pairs_table = nn.parameter.Parameter(
            torch.from_numpy(id_pairs_table()), requires_grad=False
        )
        num_groups = self.id_pairs_table.select(1, 1).max().item() + 1
        num_ids = self.id_pairs_table.select(1, 0).max().item() + 1

        self.groups_embedding = nn.Embedding(num_groups, unit)
        self.ids_embedding = nn.Embedding(num_ids, 3 * unit)

        # Conv hyper-parameters. The kernel size is deliberately not called
        # ``F`` so it does not shadow the module-level torch.nn.functional.
        ksize = 3  # filter dimensions
        S = 1  # stride
        P = 1  # padding
        M = 16  # number of intermediate filters
        self.output_filters = 8

        in_channels = [K] + [M] * (L - 1)
        out_channels = [M] * (L - 1) + [self.output_filters]

        h, w, c = rows, cols, crop_dim
        conv_extract, conv_extract_crop = [], []
        for i in range(L):
            conv_extract.append(
                nn.Conv2d(
                    in_channels=in_channels[i],
                    out_channels=out_channels[i],
                    kernel_size=(ksize, ksize),
                    stride=S,
                    padding=P,
                )
            )
            conv_extract.append(nn.ELU())

            conv_extract_crop.append(
                nn.Conv2d(
                    in_channels=in_channels[i],
                    out_channels=out_channels[i],
                    kernel_size=(ksize, ksize),
                    stride=S,
                    padding=P,
                )
            )
            conv_extract_crop.append(nn.ELU())

            # Keep track of output shapes
            h = conv_outdim(h, ksize, P, S)
            w = conv_outdim(w, ksize, P, S)
            c = conv_outdim(c, ksize, P, S)

        # Flattened size of the full-map features plus the crop features.
        self.hidden_dim = (h * w + c * c) * self.output_filters
        self.extract_representation = nn.Sequential(*conv_extract)
        self.extract_crop_representation = nn.Sequential(*conv_extract_crop)
        self.select = lambda emb, x: select(emb, x, flags.use_index_select)

    def glyphs_to_ids_groups(self, glyphs):
        """Map [T x B x H x W] glyphs to ``[ids, groups]`` long tensors of the same shape."""
        T, B, H, W = glyphs.shape
        ids_groups = self.id_pairs_table.index_select(0, glyphs.view(-1).long())
        ids = ids_groups.select(1, 0).view(T, B, H, W).long()
        groups = ids_groups.select(1, 1).view(T, B, H, W).long()
        return [ids, groups]

    def forward(self, inputs):
        """Encode the glyph map and an agent-centred crop of it.

        Args:
            inputs: dict with ``glyphs``, ``chars``, ``colors`` and
                ``specials`` of shape [T x B x H x W], and ``blstats`` whose
                first two entries are used as the (x, y) crop centre.

        Returns:
            [T*B x hidden_dim] tensor: flattened full-map features followed by
            flattened crop features.
        """
        T, B, H, W = inputs["glyphs"].shape
        ids, groups = self.glyphs_to_ids_groups(inputs["glyphs"])

        glyph_tensors = [
            self.select(self.chars_embedding, inputs["chars"].long()),
            self.select(self.colors_embedding, inputs["colors"].long()),
            self.select(self.specials_embedding, inputs["specials"].long()),
            self.select(self.groups_embedding, groups),
            self.select(self.ids_embedding, ids),
        ]

        glyphs_emb = torch.cat(glyph_tensors, dim=-1)
        glyphs_emb = rearrange(glyphs_emb, "T B H W K -> (T B) K H W")

        coordinates = inputs["blstats"].view(T * B, -1).float()[:, :2]
        # NOTE(review): Crop.forward applies torch.round to its output, and here
        # it receives float embeddings, so the cropped embeddings get rounded
        # to integers. Confirm this is intended.
        crop_emb = self.crop(glyphs_emb, coordinates)

        glyphs_rep = self.extract_representation(glyphs_emb)
        glyphs_rep = rearrange(glyphs_rep, "B C H W -> B (C H W)")
        assert glyphs_rep.shape[0] == T * B

        crop_rep = self.extract_crop_representation(crop_emb)
        crop_rep = rearrange(crop_rep, "B C H W -> B (C H W)")
        assert crop_rep.shape[0] == T * B

        st = torch.cat([glyphs_rep, crop_rep], dim=1)
        return st


class MessageEncoder(nn.Module):
    """Encode the top-line message into a fixed-size vector.

    Each character gets a learnt embedding (``PAD_CHAR`` is the padding
    index). The embedded sequence then passes through six 1-D convolutions
    interleaved with max-pooling, followed by a two-layer MLP.

    Inspired by Zhang et al, 2016
    Character-level Convolutional Networks for Text Classification
    https://arxiv.org/abs/1509.01626
    """

    def __init__(self, hidden_dim, embedding_dim, device=None):
        super(MessageEncoder, self).__init__()

        self.hidden_dim = hidden_dim
        self.msg_edim = embedding_dim

        hid = self.hidden_dim
        self.char_lt = nn.Embedding(NUM_CHARS, self.msg_edim, padding_idx=PAD_CHAR)
        self.conv1 = nn.Conv1d(self.msg_edim, hid, kernel_size=7)

        def pool3():
            return nn.MaxPool1d(kernel_size=3, stride=3)

        # The layer order (and so the Sequential indices / state_dict keys)
        # is: ReLU, pool, conv2, ReLU, pool, conv3..conv6 each + ReLU, pool,
        # Flatten, Linear, ReLU, Linear.
        stack = [nn.ReLU(), pool3()]
        stack += [nn.Conv1d(hid, hid, kernel_size=7), nn.ReLU(), pool3()]  # conv2
        for _ in range(4):  # conv3 .. conv6
            stack += [nn.Conv1d(hid, hid, kernel_size=3), nn.ReLU()]
        stack.append(pool3())
        # The head expects a pooled length of 5 (what a 256-char message
        # yields: 256 -> 250 -> 83 -> 77 -> 25 -> 17 -> 5).
        stack += [
            Flatten(),
            nn.Linear(5 * hid, 2 * hid),
            nn.ReLU(),
            nn.Linear(2 * hid, hid),
        ]
        # Final output -- [ B x hidden_dim ]
        self.conv2_6_fc = nn.Sequential(*stack)

    def forward(self, inputs):
        """Return a [T*B x hidden_dim] encoding of ``inputs["message"]``."""
        message = inputs["message"]
        T, B, *_ = message.shape
        tokens = message.long().view(T * B, -1)
        # Conv1d wants channels first: [ T*B x E x L ]
        embedded = self.char_lt(tokens).transpose(1, 2)
        return self.conv2_6_fc(self.conv1(embedded))


class BLStatsEncoder(nn.Module):
    """This model encodes the bottom line stats into a fixed size representation.

    The first ``num_features`` stats are passed through two fully-connected
    layers with ReLU activations.
    """

    def __init__(self, num_features, hidden_dim):
        """
        Args:
            num_features: number of leading blstats entries to encode.
            hidden_dim: width of both linear layers and of the output.
        """
        super(BLStatsEncoder, self).__init__()
        self.num_features = num_features
        self.hidden_dim = hidden_dim
        self.embed_features = nn.Sequential(
            nn.Linear(self.num_features, self.hidden_dim),
            nn.ReLU(),
            nn.Linear(self.hidden_dim, self.hidden_dim),
            nn.ReLU(),
        )

    def forward(self, inputs):
        """Encode ``inputs["blstats"]`` ([T x B x F']) into [T*B x hidden_dim]."""
        blstats = inputs["blstats"]
        T, B, *_ = blstats.shape

        # Slice with the configured width (not the module-level constant) so
        # it always matches the first Linear layer's input size.
        features = blstats[:, :, : self.num_features]
        # -- [B' x F]
        features = features.reshape(T * B, -1).float()
        # -- [B' x K]
        features_emb = self.embed_features(features)

        assert features_emb.shape[0] == T * B
        return features_emb


class Crop(nn.Module):
    """Extract a fixed-size window centred on per-sample (x, y) coordinates.

    Sampling uses ``F.grid_sample`` with ``align_corners=True`` and zero
    padding, so parts of the window outside the input read as 0.
    """

    def __init__(self, height, width, height_target, width_target, device=None):
        """
        Args:
            height, width: spatial size of the inputs to be cropped.
            height_target, width_target: size of the output window.
            device: optional device to create the sampling grids on.
        """
        super(Crop, self).__init__()
        self.width = width
        self.height = height
        self.width_target = width_target
        self.height_target = height_target

        width_grid = self._step_to_range(2 / (self.width - 1), self.width_target)
        width_grid = width_grid[None, :].expand(self.height_target, -1)

        height_grid = self._step_to_range(2 / (self.height - 1), height_target)
        height_grid = height_grid[:, None].expand(-1, self.width_target)

        if device is not None:
            width_grid = width_grid.to(device)
            height_grid = height_grid.to(device)

        # Registered as non-persistent buffers so they follow the module
        # through .to()/.cuda()/.half() without adding state_dict keys
        # (existing checkpoints keep loading).
        self.register_buffer("width_grid", width_grid, persistent=False)
        self.register_buffer("height_grid", height_grid, persistent=False)

    def _step_to_range(self, step, num_steps):
        """Return ``num_steps`` offsets spaced by ``step`` and centred on index ``num_steps // 2``."""
        return torch.tensor([step * (i - num_steps // 2) for i in range(num_steps)])

    def forward(self, inputs, coordinates):
        """Calculates centered crop around given x,y coordinates.

        Args:
           inputs [B x H x W] or [B x C x H x W]; 3-D inputs are cast to float,
               4-D inputs are passed to grid_sample as-is.
           coordinates [B x 2] x,y coordinates

        Returns:
           [B x C x H' x W'] inputs cropped and centered around x,y
           coordinates, rounded with torch.round. The channel dim is squeezed
           away when C == 1 (always the case for 3-D inputs).
        """
        if inputs.dim() == 3:
            inputs = inputs.unsqueeze(1).float()

        assert inputs.shape[2] == self.height, "expected %d but found %d" % (
            self.height,
            inputs.shape[2],
        )
        assert inputs.shape[3] == self.width, "expected %d but found %d" % (
            self.width,
            inputs.shape[3],
        )

        x = coordinates[:, 0]
        y = coordinates[:, 1]

        # Shift the centred window to (x, y) in normalised [-1, 1] coordinates.
        x_shift = 2 / (self.width - 1) * (x.float() - self.width // 2)
        y_shift = 2 / (self.height - 1) * (y.float() - self.height // 2)

        grid = torch.stack(
            [
                self.width_grid[None, :, :] + x_shift[:, None, None],
                self.height_grid[None, :, :] + y_shift[:, None, None],
            ],
            dim=3,
        )

        crop = torch.round(F.grid_sample(inputs, grid, align_corners=True)).squeeze(1)
        return crop


class Flatten(nn.Module):
    """Collapse every dimension after the first: [N x ...] -> [N x prod(...)]."""

    def forward(self, input):
        batch_size = input.size(0)
        flat = input.view(batch_size, -1)
        return flat
