"""
Goal-generating teachers.
"""


import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch.distributions import Categorical

from .. import utils
from . import language


def init_(m):
    """Run ``utils.init`` on module ``m`` and return whatever it returns.

    ``nn.init.orthogonal_`` is passed as the first initialiser and a function
    that fills its argument with the constant 0 as the second (presumably the
    weight and bias initialisers respectively -- confirm against
    ``utils.init``).
    """

    def _fill_with_zero(tensor):
        return nn.init.constant_(tensor, 0)

    return utils.init(m, nn.init.orthogonal_, _fill_with_zero)


class Teacher(nn.Module):
    """Teacher policy which takes an initial observation and produces a goal.

    The frame is turned into per-cell embeddings (object / color / contains
    lookups), passed through a stack of 3x3, stride-1, padding-1 convolutions
    (which keep the spatial size), and the final single-channel map is
    flattened into one logit per location. A goal index is sampled from the
    softmax over those logits.
    """

    def __init__(
        self,
        observation_shape,
        width,
        height,
        FLAGS,
        hidden_dim=256,
        vocab=language.VOCAB,
        lang=language.LANG,
        lang_len=language.LANG_LEN,
        lang_templates=language.INSTR_TEMPLATES,
    ):
        """Build the embedding tables, the ConvNet and the baseline head.

        Args:
            observation_shape: Stored as-is; not otherwise used by this class.
            width: Grid width; ``width * height`` sizes the baseline input.
            height: Grid height.
            FLAGS: Configuration object, stored on the instance.
            hidden_dim: Accepted for signature compatibility with subclasses;
                unused here.
            vocab, lang, lang_len, lang_templates: Language resources, stored
                on the instance for use by subclasses.
        """
        super().__init__()
        self.observation_shape = observation_shape
        self.height = height
        self.width = width
        self.env_dim = self.width * self.height
        self.state_embedding_dim = 256
        self.FLAGS = FLAGS

        self.vocab = vocab
        self.lang = lang
        self.lang_len = lang_len
        self.lang_templates = lang_templates

        self.use_index_select = True
        self.obj_dim = 5
        self.col_dim = 3
        self.con_dim = 2
        self.num_channels = self.obj_dim + self.col_dim + self.con_dim

        self.embed_object = nn.Embedding(11, self.obj_dim)
        self.embed_color = nn.Embedding(6, self.col_dim)
        self.embed_contains = nn.Embedding(4, self.con_dim)

        # ConvNet hyper-parameters. Descriptive names are used so that the
        # module-level alias ``F`` (torch.nn.functional) is not shadowed.
        kernel_size = 3
        stride = 1
        padding = 1
        hidden_filters = 16  # number of intermediate filters
        num_layers = 4  # number of convnet layers
        logit_filters = 1  # output channels of the last layer

        in_channels = [self.num_channels] + [hidden_filters] * (num_layers - 1)
        out_channels = [hidden_filters] * (num_layers - 1) + [logit_filters]

        # Conv -> ELU pairs; each position gets its own ELU instance instead of
        # one shared module object repeated through the Sequential.
        layers = []
        for c_in, c_out in zip(in_channels, out_channels):
            layers.append(
                nn.Conv2d(
                    in_channels=c_in,
                    out_channels=c_out,
                    kernel_size=(kernel_size, kernel_size),
                    stride=stride,
                    padding=padding,
                )
            )
            layers.append(nn.ELU())
        extract_representation = nn.Sequential(*layers)

        # One Sequential is built and then sliced, exactly as before, so the
        # two halves are constructed the same way as in earlier versions.
        split = 2 * (num_layers - 1)
        # Results in representations of size 16
        self.extract_representation = extract_representation[:split]
        # Produces logits
        self.repr2logits = extract_representation[split:]

        self.out_dim = self.env_dim * hidden_filters + self.obj_dim + self.col_dim

        self.baseline_teacher = init_(nn.Linear(hidden_filters * self.env_dim, 1))
        self.logits_size = self.env_dim
        self.raw_goal_size = 1

    def _select(self, embed, x):
        """Look up rows of ``embed.weight`` for every index in ``x``.

        When ``self.use_index_select`` is set the lookup is done with
        ``index_select`` on the flattened indices and reshaped back to
        ``x.shape + (embedding_dim,)``; otherwise ``embed(x)`` is used.
        """
        if self.use_index_select:
            out = embed.weight.index_select(0, x.reshape(-1))
            # handle reshaping x to 1-d and output back to N-d
            return out.reshape(x.shape + (-1,))
        else:
            return embed(x)

    def create_embeddings(self, x, id):
        """Generate compositional embeddings for one channel family.

        Args:
            x: Integer tensor indexed as ``x[:, :, :, id::3]``, i.e. 4-D with
                every third entry of the last axis belonging to one family.
            id: 0 selects the object table, 1 the color table and 2 the
                contains table.

        Returns:
            The looked-up embeddings with the last two axes merged.

        Raises:
            ValueError: If ``id`` is not 0, 1 or 2.
        """
        if id == 0:
            objects_emb = self._select(self.embed_object, x[:, :, :, id::3])
        elif id == 1:
            objects_emb = self._select(self.embed_color, x[:, :, :, id::3])
        elif id == 2:
            objects_emb = self._select(self.embed_contains, x[:, :, :, id::3])
        else:
            # Previously fell through to an UnboundLocalError on objects_emb.
            raise ValueError(f"id must be 0, 1 or 2, got {id!r}")
        embeddings = torch.flatten(objects_emb, 3, 4)
        return embeddings

    def embed_environment(self, inputs):
        """
        Create environment representations for each (x, y) location with a
        dimensionality-preserving ConvNet.
        These are used to produce logits in the forward method of the standard
        (x, y)-generating teacher.

        Args:
            inputs: Mapping with a ``"frame"`` tensor whose first two axes
                (time, batch) are merged before embedding.

        Returns:
            The output of ``self.extract_representation`` for the merged batch.
        """
        x = inputs["frame"]

        x = torch.flatten(x, 0, 1)  # Merge time and batch.
        x = x.long()
        x = torch.cat(
            [
                self.create_embeddings(x, 0),
                self.create_embeddings(x, 1),
                self.create_embeddings(x, 2),
            ],
            dim=3,
        )

        # Move the embedding axis (3) into the channel position (1) for Conv2d;
        # this also swaps the two spatial axes.
        x = x.transpose(1, 3)
        x = self.extract_representation(x)
        return x

    def forward(self, inputs, prev_goals=None):
        """Main Function, takes an observation and returns a goal.

        Args:
            inputs: Mapping with a ``"frame"`` tensor whose leading axes are
                (T, B).
            prev_goals: Unused.

        Returns:
            dict with ``goal`` and ``raw_goal`` (the same sampled index tensor,
            shaped (T, B)), ``generator_logits`` (T, B, -1) and
            ``generator_baseline`` (T, B).
        """
        T, B, *_ = inputs["frame"].shape
        x = self.embed_environment(inputs)
        # Flattened pre-logit representation feeds the baseline head.
        full_env_repr = x.view(T * B, -1)
        x = self.repr2logits(x)

        generator_logits = x.view(T * B, -1)

        generator_baseline = self.baseline_teacher(full_env_repr)

        goal = torch.multinomial(F.softmax(generator_logits, dim=1), num_samples=1)

        generator_logits = generator_logits.view(T, B, -1)
        generator_baseline = generator_baseline.view(T, B)
        goal = goal.view(T, B)

        return dict(
            goal=goal,
            raw_goal=goal,
            generator_logits=generator_logits,
            generator_baseline=generator_baseline,
        )


class LanguageTeacher(Teacher):
    """Teacher that proposes goals from a fixed set of language instructions.

    Every instruction in ``self.lang`` is encoded, projected to the size of
    the spatially-averaged environment embedding, and scored against it with a
    dot product. The resulting per-goal logits are masked according to
    ``FLAGS.language_goals`` before a goal index is sampled.
    """

    def __init__(
        self,
        *args,
        hidden_dim=256,
        **kwargs,
    ):
        """Build the language encoder, projection heads and goal mask buffers.

        Args:
            *args, **kwargs: Forwarded to ``Teacher.__init__``.
            hidden_dim: Hidden size of the language encoder and input size of
                the projection heads.
        """
        super().__init__(*args, **kwargs, hidden_dim=hidden_dim)
        # Language instr encoder...
        # Self attention over image regions perhaps?
        self.language_encoder = language.LanguageEncoder(
            self.vocab, embedding_dim=64, hidden_dim=hidden_dim
        )
        self.instr_size = 16
        self.bilinear = nn.Linear(hidden_dim, self.instr_size, bias=False)
        # Input to baseline is (1) full final repr or environment and (2) logits over goals
        self.baseline_teacher = init_(
            nn.Linear(16 * self.env_dim + self.lang.shape[0], 1)
        )

        self.logits_size = self.lang.shape[0]

        if self.FLAGS.language_goals in {"online_naive", "online_grounding"}:
            # Track current set of possible goals to propose. We "reveal" the
            # goals as the student encounters them. For BabyAI, the max number
            # of goals is given,
            self.register_buffer("goals", self.lang)
            self.register_buffer(
                "goals_mask", torch.zeros(self.goals.shape[0], dtype=torch.bool)
            )
            # Goal 0 is the null instr
            self.goals_mask[0] = True
        else:
            self.register_buffer(
                "goals_mask", torch.ones(self.lang.shape[0], dtype=torch.bool)
            )

        if self.FLAGS.language_goals == "online_grounding":
            self.grounding_head = nn.Linear(hidden_dim, self.instr_size)

    def _encode_env_and_goals(self, inputs):
        """Embed the frame and encode every instruction.

        Returns:
            ``(full_env_emb, env_emb, goal_emb)`` where ``env_emb`` is
            ``full_env_emb`` averaged over its last two axes.
        """
        # move to device (if it's not already on-device)
        device = inputs["frame"].device
        self.lang = self.lang.to(device)
        self.lang_len = self.lang_len.to(device)
        full_env_emb = self.embed_environment(inputs)
        # Average across pixels. Could do attention later, etc.
        env_emb = full_env_emb.mean(-1).mean(-1)
        goal_emb = self.language_encoder(self.lang, self.lang_len)
        return full_env_emb, env_emb, goal_emb

    @staticmethod
    def _score_goals(goal_emb_proj, env_emb):
        """Dot every projected goal embedding with each row of ``env_emb``."""
        # Expand to this batch size (usually 1)
        expanded = goal_emb_proj.unsqueeze(0).expand(env_emb.shape[0], -1, -1)
        return torch.bmm(expanded, env_emb.unsqueeze(-1)).squeeze(-1)

    def predict_achievable(self, inputs=None, env_emb=None, goal_emb=None):
        """Score goals with the grounding head.

        Pass either ``inputs`` alone, or both ``env_emb`` and ``goal_emb``.
        Requires ``self.grounding_head``, which is only created when
        ``FLAGS.language_goals == "online_grounding"``.

        Returns:
            dict with ``"logits"`` and ``"preds"`` (``logits > 0``).
        """
        if inputs is None:
            assert env_emb is not None and goal_emb is not None
        else:
            # Produce the goal emb ourselves.
            assert env_emb is None and goal_emb is None
            _, env_emb, goal_emb = self._encode_env_and_goals(inputs)
        goal_logits = self._score_goals(self.grounding_head(goal_emb), env_emb)
        return {
            "logits": goal_logits,
            "preds": (goal_logits > 0),
        }

    def update_goals(self, batch):
        """Mark as revealed every goal flagged in ``batch["subgoal_done"]``.

        The flags are reduced with ``any`` over the first two axes and OR-ed
        into ``self.goals_mask`` in place.
        """
        goals_seen_this_batch = batch["subgoal_done"].any(0).any(0).bool()
        self.goals_mask |= goals_seen_this_batch

    def forward_grounder(self, inputs):
        """Run only the grounding head on ``inputs``; see ``predict_achievable``."""
        _, env_emb, goal_emb = self._encode_env_and_goals(inputs)
        return self.predict_achievable(env_emb=env_emb, goal_emb=goal_emb)

    def forward(self, inputs, prev_goals=None):
        """Score all goals, mask them, and sample one goal per (T, B) entry.

        Args:
            inputs: Mapping with ``"frame"`` (leading axes (T, B)) and, when
                ``FLAGS.language_goals == "achievable"``,
                ``"subgoal_achievable"``.
            prev_goals: Unused.

        Returns:
            dict with ``goal`` / ``raw_goal`` (T, B), ``generator_logits``
            (T, B, -1, unmasked), ``generator_baseline`` (T, B),
            ``grounding_output`` (``None`` unless online grounding is active)
            and ``grounder_logits``.

        Raises:
            NotImplementedError: For an unrecognised ``FLAGS.language_goals``.
        """
        T, B, *_ = inputs["frame"].shape

        full_env_emb, env_emb, goal_emb = self._encode_env_and_goals(inputs)
        goal_logits = self._score_goals(self.bilinear(goal_emb), env_emb)

        if np.random.random() < self.FLAGS.generator_eps:
            # Sample uniformly.
            goal_logits_masked = torch.zeros_like(goal_logits)
        else:
            # Masking is applied to a detached copy so the returned logits
            # stay untouched.
            goal_logits_masked = goal_logits.detach().clone()

        # Every branch below either assigns subgoal_mask or raises.
        grounding_output = None
        if self.FLAGS.language_goals == "achievable":
            # Filter goals to those that are achievable
            subgoal_mask = inputs["subgoal_achievable"].bool()
        elif self.FLAGS.language_goals == "online_naive":
            # Use current achievable subgoal mask
            subgoal_mask = self.goals_mask
        elif self.FLAGS.language_goals == "online_grounding":
            with torch.no_grad():
                grounding_output = self.predict_achievable(
                    env_emb=env_emb, goal_emb=goal_emb
                )
            if self.FLAGS.hard_grounding:
                subgoal_mask = grounding_output["preds"] & self.goals_mask
                # To avoid nan issues, subgoal mask is always 1 for the null goal
                subgoal_mask[:, 0] = True
            else:
                subgoal_mask = (
                    torch.sigmoid(grounding_output["logits"]) * self.goals_mask
                )
        else:
            raise NotImplementedError(
                f"Unsupported language_goals mode: {self.FLAGS.language_goals!r}"
            )

        if subgoal_mask.ndim == 3:
            # NOTE(review): only the first entry along axis 0 is used, which
            # presumably assumes T == 1 -- confirm against callers.
            subgoal_mask = subgoal_mask[0]
        # A 2-D mask applies per row; anything else is treated as one mask
        # shared by every row.
        per_row_mask = subgoal_mask.ndim == 2

        if self.FLAGS.hard_grounding:
            if per_row_mask:
                goal_logits_masked[~subgoal_mask] = -np.inf
            else:
                goal_logits_masked[:, ~subgoal_mask] = -np.inf

            goal_dist = Categorical(logits=goal_logits_masked)
        else:
            # Actual produce probabilities and do the multiplication
            goal_probs_masked = goal_logits_masked.softmax(-1)
            if per_row_mask:
                goal_probs_masked *= subgoal_mask.float()
            else:
                goal_probs_masked *= subgoal_mask.unsqueeze(0).float()
            # NOTE(review): a row whose mask is all zero divides by zero here.
            goal_probs_masked = goal_probs_masked / goal_probs_masked.sum(
                -1, keepdim=True
            )

            goal_dist = Categorical(probs=goal_probs_masked)

        goal = goal_dist.sample()

        # Baseline
        full_env_emb_flat = full_env_emb.view(T * B, -1)
        baseline_input = torch.cat((full_env_emb_flat, goal_logits), -1)
        goal_baseline = self.baseline_teacher(baseline_input)

        # Reshape for time x batch
        goal_logits = goal_logits.view(T, B, -1)
        goal_baseline = goal_baseline.view(T, B)
        goal = goal.view(T, B)

        if self.FLAGS.language_goals == "online_grounding":
            grounder_logits = subgoal_mask.view(T, B, -1)
        else:
            grounder_logits = torch.ones_like(goal_logits)

        return dict(
            goal=goal,
            raw_goal=goal,
            generator_logits=goal_logits,
            generator_baseline=goal_baseline,
            grounding_output=grounding_output,
            grounder_logits=grounder_logits,
        )


class LangPredictor(nn.Module):
    """Embeds the set of completed subgoals as a mean instruction embedding."""

    def __init__(
        self,
        FLAGS,
        vocab=language.VOCAB,
        lang=language.LANG,
        lang_len=language.LANG_LEN,
        lang_templates=language.INSTR_TEMPLATES,
        hidden_dim=256,
    ):
        """Store the language resources and build the instruction encoder.

        ``lang_templates`` is accepted but not used by this class.
        """
        super().__init__()
        self.FLAGS = FLAGS
        self.vocab = vocab
        self.lang = lang
        self.lang_len = lang_len
        self.language_encoder = language.LanguageEncoder(
            self.vocab, embedding_dim=64, hidden_dim=hidden_dim
        )

    def forward(self, inputs):
        """Average the embeddings of the goals flagged in ``subgoal_done``.

        Args:
            inputs: Mapping with ``"partial_frame"`` (only its leading (T, B)
                sizes are read) and ``"subgoal_done"``, reshaped here to
                (T * B, -1).

        Returns:
            Tensor viewed as (T, B, -1): for each entry, the mean of the
            encoder outputs weighted by ``subgoal_done``; the divisor is
            clamped to at least 1, so entries with no flagged goal give zeros.
        """
        T, B, *_ = inputs["partial_frame"].shape
        n_rows = T * B

        done = inputs["subgoal_done"].float().view(n_rows, -1)

        # Keep the instruction tensors on the configured device.
        target_device = self.FLAGS.device
        self.lang = self.lang.to(target_device)
        self.lang_len = self.lang_len.to(target_device)

        encoded = self.language_encoder(self.lang, self.lang_len)
        # Broadcast the per-goal embeddings to every (time, batch) row.
        per_row = encoded.unsqueeze(0).expand(n_rows, -1, -1)
        # Sum of embeddings of completed goals ...
        weighted_total = (per_row * done.unsqueeze(-1)).sum(-2)
        # ... divided by how many were completed (at least 1).
        n_done = done.sum(-1, keepdim=True)
        averaged = weighted_total / torch.clamp(n_done, min=1)
        return averaged.view(T, B, -1)


class Predictor(nn.Module):
    """Convolutional state embedder, optionally extended with a goal embedding.

    Three stride-2 3x3 convolutions (each followed by ELU) encode the partial
    frame. When ``flags.noveld_combined`` is set, the mean embedding of the
    completed subgoals is concatenated to the state embedding.
    """

    def __init__(self, observation_shape, flags):
        """Build the feature extractor and, if requested, the language encoder.

        Args:
            observation_shape: Its element at index 2 gives the number of
                input channels of the first convolution.
            flags: Configuration; ``noveld_combined`` toggles the language
                branch.
        """
        super().__init__()
        self.observation_shape = observation_shape
        self.flags = flags

        def _zero_fill(tensor):
            return nn.init.constant_(tensor, 0)

        def _conv_stage(c_in, c_out):
            # One stride-2 3x3 convolution, initialised through utils.init
            # with orthogonal_, a zero fill and the ReLU gain.
            conv = nn.Conv2d(
                in_channels=c_in,
                out_channels=c_out,
                kernel_size=(3, 3),
                stride=2,
                padding=1,
            )
            return utils.init(
                conv,
                nn.init.orthogonal_,
                _zero_fill,
                nn.init.calculate_gain("relu"),
            )

        channel_plan = [
            (self.observation_shape[2], 32),
            (32, 32),
            (32, 128),
        ]
        stages = []
        for c_in, c_out in channel_plan:
            stages.append(_conv_stage(c_in, c_out))
            stages.append(nn.ELU())
        self.feat_extract = nn.Sequential(*stages)

        if self.flags.noveld_combined:
            self.lang = language.LANG
            self.lang_len = language.LANG_LEN
            self.lang_templates = language.INSTR_TEMPLATES
            self.vocab = language.VOCAB
            self.language_encoder = language.LanguageEncoder(
                self.vocab, embedding_dim=64, hidden_dim=256
            )

    def forward(self, inputs):
        """Embed ``inputs["partial_frame"]`` (and completed subgoals if enabled).

        Args:
            inputs: Mapping with ``"partial_frame"`` laid out as
                [unroll_length x batch_size x height x width x channels] and,
                when ``flags.noveld_combined`` is set, ``"subgoal_done"``.

        Returns:
            Tensor viewed as (T, B, -1) holding the flattened conv features,
            followed by the mean goal embedding when enabled.
        """
        frames = inputs["partial_frame"]
        T, B, *_ = frames.shape

        # Merge time and batch, then scale to floats by dividing by 255.
        merged = torch.flatten(frames, 0, 1).float() / 255.0

        # -- [unroll_length*batch_size x channels x width x height]
        features = self.feat_extract(merged.transpose(1, 3))
        state_embedding = features.view(T, B, -1)

        if not self.flags.noveld_combined:
            return state_embedding

        n_rows = T * B
        done = inputs["subgoal_done"].float().view(n_rows, -1)

        target_device = self.flags.device
        self.lang = self.lang.to(target_device)
        self.lang_len = self.lang_len.to(target_device)

        encoded = self.language_encoder(self.lang, self.lang_len)
        per_row = encoded.unsqueeze(0).expand(n_rows, -1, -1)
        # Mean over completed goals; the divisor is clamped to at least 1 so
        # rows with nothing completed yield zeros.
        weighted_total = (per_row * done.unsqueeze(-1)).sum(-2)
        n_done = done.sum(-1, keepdim=True)
        goal_emb = (weighted_total / torch.clamp(n_done, min=1)).view(T, B, -1)

        return torch.cat([state_embedding, goal_emb], -1)
