import torch
import torch.nn.functional as F
from minihack.agent.common.models.embed import GlyphEmbedding
from minihack.agent.common.models.transformer import TransformerEncoder
from minihack.agent.polybeast.models.base import NUM_FEATURES, Crop, NetHackNet
from torch import nn

from . import message as message_models


class Student(NetHackNet):
    """
    This is the BaseNet minihack model, with modifications.

    The state vector fed to the policy/baseline heads is the concatenation of:

    * an MLP embedding of ``inputs["blstats"]``,
    * a CNN or transformer encoding of a ``crop_dim x crop_dim`` crop of the
      glyph map taken at the coordinates in the first two blstats entries,
    * a CNN encoding of the full embedded glyph map,
    * (if ``flags.minihack.msg.model != "none"``) an averaged message encoding,
    * (if ``flags.language_goals is not None``) a goal encoding plus the
      difference between the goal encoding and the message encoding.

    When ``flags.language_goals is None`` the goal is instead injected as an
    extra one-hot channel appended to the glyph embeddings.
    """

    def __init__(
        self, observation_shape, num_actions, flags, device, use_intrinsic_rewards=True
    ):
        """Build all sub-networks.

        Args:
            observation_shape: Indexable whose first two entries are the glyph
                map height and width.
            num_actions: Size of the policy output.
            flags: Config object; this constructor reads ``flags.minihack.*``,
                ``flags.language_goals`` and ``flags.int.twoheaded``.
            device: Passed through to ``Crop``, ``GlyphEmbedding`` and
                ``TransformerEncoder``.
            use_intrinsic_rewards: If False, the goal input is replaced by
                zeros in ``forward``/``create_goal_channel``.

        Raises:
            NotImplementedError: If ``flags.minihack.msg.model`` is neither
                ``"word_gru"`` nor ``"none"``.
        """
        super().__init__()

        self.flags = flags

        self.observation_shape = observation_shape
        self.use_intrinsic_rewards = use_intrinsic_rewards

        self.H = observation_shape[0]
        self.W = observation_shape[1]

        self.num_actions = num_actions
        self.use_lstm = flags.minihack.use_lstm

        self.k_dim = flags.minihack.embedding_dim
        self.h_dim = flags.minihack.hidden_dim

        self.crop_model = flags.minihack.crop_model
        self.crop_dim = flags.minihack.crop_dim

        self.num_features = NUM_FEATURES

        self.crop = Crop(self.H, self.W, self.crop_dim, self.crop_dim, device)

        self.glyph_type = flags.minihack.glyph_type
        self.glyph_embedding = GlyphEmbedding(
            flags.minihack.glyph_type,
            flags.minihack.embedding_dim,
            device,
            flags.minihack.use_index_select,
        )

        if flags.language_goals is None:
            # number of input filters + goal channel
            K = flags.minihack.embedding_dim + 1
        else:
            # number of input filters
            K = flags.minihack.embedding_dim
        FDIM = 3  # filter dimensions
        S = 1  # stride
        P = 1  # padding
        M = 16  # number of intermediate filters
        self.Y = 8  # number of output filters
        L = flags.minihack.layers  # number of convnet layers

        in_channels = [K] + [M] * (L - 1)
        out_channels = [M] * (L - 1) + [self.Y]

        def build_conv_stack():
            # L blocks of (3x3 conv, stride 1, padding 1) + ELU; spatial size
            # is preserved, channels go K -> M -> ... -> Y.
            layers = []
            for i in range(L):
                layers.append(
                    nn.Conv2d(
                        in_channels=in_channels[i],
                        out_channels=out_channels[i],
                        kernel_size=(FDIM, FDIM),
                        stride=S,
                        padding=P,
                    )
                )
                layers.append(nn.ELU())
            return nn.Sequential(*layers)

        # Encoder for the full glyph map.
        self.extract_representation = build_conv_stack()

        # Encoder for the crop. Any other crop_model value leaves this
        # attribute undefined; forward() rejects such configurations.
        if self.crop_model == "transformer":
            self.extract_crop_representation = TransformerEncoder(
                K,
                N=L,
                heads=8,
                height=self.crop_dim,
                width=self.crop_dim,
                device=device,
            )
        elif self.crop_model == "cnn":
            self.extract_crop_representation = build_conv_stack()

        # MESSAGING MODEL
        self.msg_model = flags.minihack.msg.model
        self.msg_hdim = flags.minihack.msg.hidden_dim
        self.msg_edim = flags.minihack.msg.embedding_dim
        if self.msg_model == "word_gru":
            # NOTE(review): 30522 is presumably the tokenizer vocabulary size
            # used upstream -- confirm against the message preprocessing.
            self.msg_net = message_models.WordGRU(
                self.msg_hdim, self.msg_edim, vocab_size=30522
            )
        elif self.msg_model != "none":
            raise NotImplementedError

        self.embed_features = nn.Sequential(
            nn.Linear(self.num_features, self.k_dim),
            nn.ReLU(),
            nn.Linear(self.k_dim, self.k_dim),
            nn.ReLU(),
        )

        # Input width of `fc`: the sum of the output dimensions of the
        # featurizers, in the same order they are concatenated in forward().
        # feature / status dim
        out_dim = self.k_dim
        # CNN over full glyph map
        out_dim += self.H * self.W * self.Y
        # crop encoder
        if self.crop_model == "transformer":
            out_dim += self.crop_dim ** 2 * K
        elif self.crop_model == "cnn":
            out_dim += self.crop_dim ** 2 * self.Y
        # messaging model
        if self.msg_model != "none":
            out_dim += self.msg_hdim
        # ==== INTRINSIC GOAL ====
        if flags.language_goals is not None:
            # intrinsic goal (+ diff repr)
            out_dim += 2 * self.msg_hdim

        self.fc = nn.Sequential(
            nn.Linear(out_dim, self.h_dim),
            nn.ReLU(),
            nn.Linear(self.h_dim, self.h_dim),
            nn.ReLU(),
        )

        if self.use_lstm:
            self.core = nn.LSTM(self.h_dim, self.h_dim, num_layers=1)

        self.policy = nn.Linear(self.h_dim, self.num_actions)
        self.baseline = nn.Linear(self.h_dim, 1)
        if self.flags.int.twoheaded:
            # Separate value head for the intrinsic reward stream.
            self.int_baseline = nn.Linear(self.h_dim, 1)

    def initial_state(self, batch_size=1):
        """Return the initial recurrent state.

        Returns:
            An empty tuple when no LSTM is used; otherwise a 2-tuple of zero
            tensors shaped ``(num_layers, batch_size, hidden_size)``. The
            tensors are created on the default device.
        """
        if not self.use_lstm:
            return tuple()
        state_shape = (self.core.num_layers, batch_size, self.core.hidden_size)
        return tuple(torch.zeros(*state_shape) for _ in range(2))

    def prepare_input(self, inputs):
        """Extract glyph inputs and flattened float blstats from ``inputs``.

        Returns:
            ``(glyphs, features)`` where ``glyphs`` is whatever
            ``self.glyph_embedding.prepare_input`` returns and ``features`` is
            ``inputs["blstats"]`` reshaped to ``[T * B, F]`` as float.
        """
        # -- [T x B x H x W]; only the leading time/batch dims are needed here
        T, B = inputs["glyphs"].shape[:2]

        # take our chosen glyphs and merge the time and batch
        glyphs = self.glyph_embedding.prepare_input(inputs)

        # -- [T x B x F] -> [B' x F]
        features = inputs["blstats"].view(T * B, -1).float()

        return glyphs, features

    def create_goal_channel(self, goal):
        """Turn flat goal indices into a per-cell goal map.

        Args:
            goal: Integer tensor of indices into the flattened ``H * W`` grid
                (any shape; it is flattened first).

        Returns:
            An int64 tensor of shape ``[N, H, W]``: one-hot at the goal cell
            when intrinsic rewards are enabled, all zeros otherwise.
        """
        goal = goal.view(-1)
        if not self.use_intrinsic_rewards:
            # Goal is disabled: keep the channel but make it uninformative.
            return torch.zeros(
                (goal.shape[0], self.H, self.W), device=goal.device, dtype=torch.int64
            )
        one_hot = F.one_hot(goal, num_classes=self.H * self.W)
        return one_hot.view(one_hot.shape[0], self.H, self.W)

    def forward(self, inputs, core_state, goal, learning=False):
        """Compute policy logits, baseline(s) and an action for each step.

        Args:
            inputs: Dict with at least ``"glyphs"`` (``[T, B, ...]``) and
                ``"blstats"``; ``"done"`` is read when the LSTM is used, and
                ``"split_messages"`` / ``"split_messages_len"`` when a message
                model is configured.
            core_state: Recurrent state as returned by ``initial_state``.
            goal: Flat grid indices when ``flags.language_goals is None``,
                otherwise a tokenized goal message for ``embed_message``.
            learning: Unused; kept for interface compatibility.

        Returns:
            ``(output, core_state)`` where ``output`` holds ``policy_logits``
            ``[T, B, A]``, ``baseline`` ``[T, B]``, ``action`` ``[T, B]`` and,
            if ``flags.int.twoheaded``, ``int_baseline`` ``[T, B]``.

        Raises:
            ValueError: If ``crop_model`` is not ``"transformer"`` or ``"cnn"``.
            RuntimeError: If language goals with intrinsic rewards are used
                without a message model (raised by ``embed_message``).
        """
        T, B, *_ = inputs["glyphs"].shape
        goal_as_channel = self.flags.language_goals is None

        # features is already [B' x F] float after prepare_input.
        glyphs, features = self.prepare_input(inputs)

        # -- [B x 2] x,y coordinates
        coordinates = features[:, :2]

        # -- [B x K]
        features_emb = self.embed_features(features)
        assert features_emb.shape[0] == T * B

        reps = [features_emb]

        # -- [B x H' x W']
        crop = self.glyph_embedding.GlyphTuple(
            *[self.crop(g, coordinates) for g in glyphs]
        )
        # -- [B x H' x W' x K]
        crop_emb = self.glyph_embedding(crop)

        if goal_as_channel:
            goal_channel = self.create_goal_channel(goal)
            # Crop the integer goal map exactly like the glyphs, then append
            # it as one extra feature channel.
            goal_channel_crop = self.crop(goal_channel, coordinates).float()
            goal_channel = goal_channel.unsqueeze(-1).float()
            crop_emb = torch.cat([crop_emb, goal_channel_crop.unsqueeze(-1)], -1)

        if self.crop_model == "transformer":
            crop_rep = self.extract_crop_representation(crop_emb, mask=None)
        elif self.crop_model == "cnn":
            # -- [B x K x W' x H'] (channels first for Conv2d)
            crop_emb = crop_emb.transpose(1, 3)
            # -- [B x Y x W' x H']
            crop_rep = self.extract_crop_representation(crop_emb)
        else:
            # No crop encoder was built in __init__ for this setting.
            raise ValueError(f"Unsupported crop_model: {self.crop_model!r}")

        # -- [B x K']
        crop_rep = crop_rep.view(T * B, -1)
        assert crop_rep.shape[0] == T * B
        reps.append(crop_rep)

        # -- [B x H x W x K]
        glyphs_emb = self.glyph_embedding(glyphs)
        if goal_as_channel:
            glyphs_emb = torch.cat([glyphs_emb, goal_channel], -1)
        # -- [B x K x W x H] (channels first for Conv2d)
        glyphs_emb = glyphs_emb.transpose(1, 3)
        # -- [B x Y x W x H]
        glyphs_rep = self.extract_representation(glyphs_emb)

        # -- [B x K'']
        glyphs_rep = glyphs_rep.view(T * B, -1)
        assert glyphs_rep.shape[0] == T * B
        reps.append(glyphs_rep)

        # MESSAGING MODEL
        char_rep = None
        if self.msg_model != "none":
            # Embed split messages, then average
            char_rep = self.embed_split_messages(
                inputs["split_messages"], inputs["split_messages_len"]
            )
            reps.append(char_rep)

        # INTRINSIC GOAL
        if not goal_as_channel:
            if self.use_intrinsic_rewards:
                # embed_message raises if there is no message model, so
                # char_rep is always set when the subtraction is reached.
                goal_rep = self.embed_message(goal)
                goal_diff_rep = goal_rep - char_rep
            else:
                # Zero placeholders sized from the batch and msg_hdim (the
                # width __init__ reserves for them), so this branch no longer
                # depends on char_rep existing.
                goal_rep = features_emb.new_zeros((T * B, self.msg_hdim))
                goal_diff_rep = torch.zeros_like(goal_rep)
            reps.append(goal_rep)
            reps.append(goal_diff_rep)

        st = torch.cat(reps, dim=1)

        # -- [B x K]
        st = self.fc(st)

        if self.use_lstm:
            core_input = st.view(T, B, -1)
            core_output_list = []
            notdone = (~inputs["done"]).float()
            for step_input, nd in zip(core_input.unbind(), notdone.unbind()):
                # Reset core state to zero whenever an episode ended.
                # Make `done` broadcastable with (num_layers, B, hidden_size)
                # states:
                nd = nd.view(1, -1, 1)
                core_state = tuple(nd * t for t in core_state)
                step_output, core_state = self.core(
                    step_input.unsqueeze(0), core_state
                )
                core_output_list.append(step_output)
            core_output = torch.flatten(torch.cat(core_output_list), 0, 1)
        else:
            core_output = st

        # -- [B x A]
        policy_logits = self.policy(core_output)
        # -- [B x 1]
        baseline = self.baseline(core_output)

        if self.training:
            action = torch.multinomial(F.softmax(policy_logits, dim=1), num_samples=1)
        else:
            # Don't sample when testing.
            action = torch.argmax(policy_logits, dim=1)

        policy_logits = policy_logits.view(T, B, self.num_actions)
        baseline = baseline.view(T, B)
        action = action.view(T, B)

        output = dict(policy_logits=policy_logits, baseline=baseline, action=action)
        if self.flags.int.twoheaded:
            int_baseline = self.int_baseline(core_output)
            output.update(int_baseline=int_baseline.view(T, B))

        return (output, core_state)

    def embed_split_messages(self, split_messages, split_messages_len=None):
        """Embed ``N`` messages per step and average them.

        Args:
            split_messages: Tensor shaped ``[T, B, N, ...]``.
            split_messages_len: Optional lengths with ``T * B * N`` elements.

        Returns:
            Tensor of shape ``[T * B, D]``: the mean over the ``N`` message
            embeddings of each step.
        """
        T, B, N, *_ = split_messages.shape
        flat_messages = split_messages.view(T * B * N, -1)
        flat_lengths = split_messages_len
        if flat_lengths is not None:
            flat_lengths = flat_lengths.view(T * B * N)
        rep = self.embed_message(flat_messages, flat_lengths)
        # Average across the N messages
        return rep.view(T * B, N, -1).mean(-2)

    def embed_message(self, messages, message_len=None):
        """Encode tokenized messages with the configured message network.

        Args:
            messages: Token tensor, either ``[N, L]`` or ``[T, B, L]`` (the
                latter is flattened to ``[T * B, L]``).
            message_len: Optional lengths; when omitted they are computed as
                the count of non-zero tokens per message.

        Raises:
            RuntimeError: If no message model is configured.
        """
        if self.msg_model == "none":
            raise RuntimeError("No message model")

        if message_len is None:
            # Create it yourself: token id 0 is treated as padding.
            # NOTE(review): keepdim=True leaves a trailing size-1 dim for
            # 2-D input -- confirm msg_net accepts that shape.
            message_len = (messages != 0).long().sum(-1, keepdim=True)

        if messages.ndim > 2:
            T, B, *_ = messages.shape
            # [T x B x 256] -> [T * B x 256]
            messages = messages.long().view(T * B, -1)
            message_len = message_len.long().view(T * B)

        return self.msg_net(messages, message_len)