import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch.distributions import Categorical

from .. import utils
from . import language
from einops import rearrange

def init_(m):
    """Initialise module ``m`` through the project helper ``utils.init``.

    Two initialisers are handed over: ``nn.init.orthogonal_`` first, then a
    function that fills its tensor with zeros. Which tensors each is applied
    to (presumably weight and bias, respectively) is decided by
    ``utils.init`` -- confirm there. Returns whatever ``utils.init`` returns.
    """

    def _fill_zero(tensor):
        return nn.init.constant_(tensor, 0)

    return utils.init(m, nn.init.orthogonal_, _fill_zero)


class PlannerBase(nn.Module):
    """Goal generator that scores every cell of a grid observation.

    ``inputs["frame"]`` is embedded per cell from three interleaved integer
    channels (object / color / contains, read with stride 3 on the last
    axis), passed through a stack of 3x3 convolutions that keeps the spatial
    size, and reduced to one logit per cell. A goal index is sampled from the
    softmax over those logits; a linear baseline reads the 16-channel
    features produced before the last conv layer.
    """

    def __init__(
        self,
        observation_shape,
        width,
        height,
        FLAGS,
        hidden_dim=256,
        vocab=language.VOCAB,
        lang=language.LANG,
        lang_len=language.LANG_LEN,
        lang_templates=language.INSTR_TEMPLATES,
        lang_bert_emb=language.LANG_BERT_EMB,
    ):
        """Build the cell embeddings, the conv stack and the baseline head.

        Args:
            observation_shape: stored on ``self``; not otherwise read here.
            width: grid width.
            height: grid height; ``width * height`` sizes the baseline input.
            FLAGS: run configuration object, stored for subclasses.
            hidden_dim: not used by this class; accepted so subclasses can
                pass it through.
            vocab, lang, lang_len, lang_templates, lang_bert_emb: language
                tables, stored as plain attributes (not registered buffers).
        """
        super().__init__()
        self.observation_shape = observation_shape
        self.height = height
        self.width = width
        self.env_dim = self.width * self.height
        self.state_embedding_dim = 256
        self.FLAGS = FLAGS

        self.vocab = vocab
        self.lang = lang
        self.lang_len = lang_len
        self.lang_templates = lang_templates
        self.lang_bert_emb = lang_bert_emb

        self.use_index_select = True
        self.obj_dim = 5
        self.col_dim = 3
        self.con_dim = 2
        self.num_channels = self.obj_dim + self.col_dim + self.con_dim

        self.embed_object = nn.Embedding(11, self.obj_dim)
        self.embed_color = nn.Embedding(6, self.col_dim)
        self.embed_contains = nn.Embedding(4, self.con_dim)

        # Conv-stack hyper-parameters. Named explicitly so nothing shadows the
        # module-level ``F`` (torch.nn.functional) alias.
        kernel_size = 3
        stride = 1
        padding = 1  # with kernel 3 and stride 1, H and W are unchanged
        hidden_channels = 16
        num_layers = 4
        final_channels = 1

        in_channels = [self.num_channels] + [hidden_channels] * (num_layers - 1)
        out_channels = [hidden_channels] * (num_layers - 1) + [final_channels]
        convs = [
            nn.Conv2d(
                in_channels=c_in,
                out_channels=c_out,
                kernel_size=(kernel_size, kernel_size),
                stride=stride,
                padding=padding,
            )
            for c_in, c_out in zip(in_channels, out_channels)
        ]

        # conv, ELU, conv, ELU, ... -- one shared ELU instance (it has no
        # parameters), exactly as before.
        activation = nn.ELU()
        layers = []
        for conv in convs:
            layers.extend((conv, activation))
        extract_representation = nn.Sequential(*layers)

        # Split after the last hidden ELU: ``extract_representation`` yields
        # ``hidden_channels`` feature maps, ``repr2logits`` maps them to one
        # channel. Slicing an nn.Sequential keeps the original child keys, so
        # state_dict names ("repr2logits.6.weight", ...) are unchanged.
        split = 2 * (num_layers - 1)
        self.extract_representation = extract_representation[:split]
        self.repr2logits = extract_representation[split:]

        self.out_dim = self.env_dim * hidden_channels + self.obj_dim + self.col_dim

        self.baseline_planner = init_(nn.Linear(hidden_channels * self.env_dim, 1))
        self.logits_size = self.env_dim
        self.raw_goal_size = 1

    def _select(self, embed, x):
        """Look up rows of ``embed`` for the integer tensor ``x``.

        Returns a tensor of shape ``x.shape + (embedding_dim,)``. With
        ``use_index_select`` the weight matrix is indexed directly; otherwise
        the embedding module is called.
        """
        if self.use_index_select:
            out = embed.weight.index_select(0, x.reshape(-1))
            return out.reshape(x.shape + (-1,))
        else:
            return embed(x)

    def create_embeddings(self, x, id):
        """Embed one of the three interleaved integer channels of ``x``.

        Args:
            x: integer tensor whose last axis interleaves the channels; the
                channel ``id`` is read as ``x[..., id::3]``. Indexing uses four
                axes, so ``x`` must be 4-D.
            id: 0 for object, 1 for color, 2 for contains. (The name shadows
                the builtin but is kept for caller compatibility.)

        Returns:
            ``x.shape[:3] + (k * embedding_dim,)`` where ``k`` is the number
            of selected channel positions.

        Raises:
            ValueError: if ``id`` is not 0, 1 or 2.
        """
        if id not in (0, 1, 2):
            raise ValueError(f"create_embeddings: id must be 0, 1 or 2, got {id!r}")
        table = (self.embed_object, self.embed_color, self.embed_contains)[id]
        objects_emb = self._select(table, x[:, :, :, id::3])
        # Merge the selected-position axis with the embedding axis.
        return torch.flatten(objects_emb, 3, 4)

    def embed_environment(self, inputs):
        """Embed ``inputs["frame"]`` and run the shared conv feature extractor.

        The first two axes of the frame (presumably time and batch) are
        merged into one. Returns ``(x, x_conv)`` where ``x`` is the
        channels-first cell embedding and ``x_conv`` the output of
        ``extract_representation``. Note ``transpose(1, 3)`` swaps the two
        spatial axes as well as moving channels first.
        """
        frame = inputs["frame"]
        x = torch.flatten(frame, 0, 1).long()  # merge time and batch
        x = torch.cat([self.create_embeddings(x, channel) for channel in range(3)], dim=3)
        x = x.transpose(1, 3)
        x_conv = self.extract_representation(x)
        return x, x_conv

    def forward(self, inputs):
        """Sample one cell index per (time, batch) element.

        Returns a dict with ``goal`` / ``raw_goal`` (the same ``(T, B)``
        tensor), ``generator_logits`` ``(T, B, cells)`` and
        ``generator_baseline`` ``(T, B)``.
        """
        T, B, *_ = inputs["frame"].shape
        _, features = self.embed_environment(inputs)
        full_env_repr = features.view(T * B, -1)
        generator_logits = self.repr2logits(features).view(T * B, -1)

        generator_baseline = self.baseline_planner(full_env_repr)

        goal = torch.multinomial(F.softmax(generator_logits, dim=1), num_samples=1)
        goal = goal.view(T, B)

        return dict(
            goal=goal,
            raw_goal=goal,
            generator_logits=generator_logits.view(T, B, -1),
            generator_baseline=generator_baseline.view(T, B),
        )


class Planner(PlannerBase):
    """Language-conditioned goal generator.

    Scores every instruction in ``self.lang`` against the embedded grid and
    samples one instruction index per (time, batch) element. Scoring uses
    either a dot product with the spatially mean-pooled conv features or,
    when ``FLAGS.planner_cross_attn`` is set, cross-attention from the
    instruction embeddings to the per-cell features. Instructions not yet in
    ``goals_mask`` (and, with ``FLAGS.achievable_mask``, those not marked in
    ``inputs["subgoal_achievable"]``) are excluded from sampling.
    """

    def __init__(
        self,
        *args,
        hidden_dim=256,
        **kwargs,
    ):
        super().__init__(*args, **kwargs, hidden_dim=hidden_dim)
        self.bert_encoder = language.PlannerLangugaeEncoder(
            input_dim=768, output_dim=hidden_dim
        )
        self.instr_size = 16
        self.bilinear = nn.Linear(hidden_dim, self.instr_size, bias=False)
        # Replaces the base-class head: the baseline reads the flattened conv
        # features concatenated with one logit per instruction.
        self.baseline_planner = init_(
            nn.Linear(16 * self.env_dim + self.lang.shape[0], 1)
        )

        self.pos_embedding = nn.Parameter(
            torch.randn(1, self.instr_size, self.height, self.width)
        )
        self.cross_attn = nn.MultiheadAttention(
            embed_dim=self.instr_size, num_heads=1, batch_first=True
        )
        self.knowledge_proj = nn.Linear(self.instr_size, 1)

        self.logits_size = self.lang.shape[0]

        self.register_buffer("goals", self.lang)
        self.register_buffer(
            "goals_mask", torch.zeros(self.goals.shape[0], dtype=torch.bool)
        )
        # Goal 0 always starts eligible, so the seen-goal mask never excludes
        # every instruction.
        self.goals_mask[0] = True

    def update_goals(self, batch):
        """Mark as eligible every goal completed anywhere in ``batch``.

        ``batch["subgoal_done"]`` is reduced with ``any`` over its first two
        axes (presumably time and batch) to one flag per goal, which is OR-ed
        into ``goals_mask`` in place. Eligibility is never revoked.
        """
        goals_seen_this_batch = batch["subgoal_done"].any(0).any(0).bool()
        self.goals_mask |= goals_seen_this_batch.to(self.goals_mask.device)

    @staticmethod
    def _apply_goal_mask(logits, mask):
        """Set entries of ``logits`` (``(N, G)``) to -inf in place where ``mask`` is False.

        ``mask`` may be ``(G,)`` (shared by every row), ``(N, G)``, or
        ``(T, B, G)`` with ``T * B == N``; the latter is flattened so each
        time step gets its own row mask.
        """
        if mask.ndim == 3:
            mask = mask.flatten(0, 1)
        if mask.ndim == 2:
            logits[~mask] = -np.inf
        else:
            logits[:, ~mask] = -np.inf

    def _score_goals(self, full_env_emb):
        """Return one logit per instruction for each row of ``full_env_emb``.

        ``full_env_emb`` is the conv feature map from ``embed_environment``,
        ``(T*B, C, H, W)``. Result is ``(T*B, G)``; assumes ``bert_encoder``
        maps ``lang_bert_emb`` to ``(G, hidden_dim)`` -- TODO confirm.
        """
        env_emb = full_env_emb.mean(-1).mean(-1)  # pool both spatial axes

        knowledge_emb = self.bert_encoder(self.lang_bert_emb)
        knowledge_emb_proj = self.bilinear(knowledge_emb)
        knowledge_emb_proj = knowledge_emb_proj.unsqueeze(0).expand(
            env_emb.shape[0], -1, -1
        )

        if self.FLAGS.planner_cross_attn:
            # Add learned positional encodings, then flatten the grid into a
            # sequence of H*W cells that the instructions attend over.
            cells = rearrange(
                full_env_emb + self.pos_embedding, "TB N H W -> TB (H W) N"
            )
            attn_output, _ = self.cross_attn(
                query=knowledge_emb_proj, key=cells, value=cells
            )
            if self.FLAGS.attn_skip:
                attn_output = attn_output + knowledge_emb_proj
            return self.knowledge_proj(attn_output).squeeze(-1)

        return torch.bmm(knowledge_emb_proj, env_emb.unsqueeze(-1)).squeeze(-1)

    def forward(self, inputs):
        """Sample one instruction index per (time, batch) element.

        Returns a dict with ``goal`` / ``raw_goal`` (the same ``(T, B)``
        tensor), ``generator_logits`` -- the unmasked, undetached logits,
        ``(T, B, G)`` -- and ``generator_baseline`` ``(T, B)``.
        """
        T, B, *_ = inputs["frame"].shape
        device = inputs["frame"].device

        # These tables are plain attributes rather than buffers, so
        # ``Module.to()`` does not move them; keep them on the input device.
        self.lang = self.lang.to(device)
        self.lang_len = self.lang_len.to(device)
        self.lang_bert_emb = self.lang_bert_emb.to(device)

        _, full_env_emb = self.embed_environment(inputs)
        goal_logits = self._score_goals(full_env_emb)

        if np.random.random() < self.FLAGS.generator_eps:
            # Exploration: uniform over whichever goals survive masking.
            goal_logits_masked = torch.zeros_like(goal_logits)
        else:
            goal_logits_masked = goal_logits.detach().clone()

        self._apply_goal_mask(goal_logits_masked, self.goals_mask)
        if self.FLAGS.achievable_mask:
            self._apply_goal_mask(
                goal_logits_masked, inputs["subgoal_achievable"].bool()
            )
        # NOTE(review): if a row ends up entirely -inf (achievable mask
        # excludes every seen goal), Categorical cannot build a distribution;
        # confirm the environment rules this out.

        goal = Categorical(logits=goal_logits_masked).sample()

        # Baseline: flattened conv features + raw (undetached) goal logits.
        baseline_input = torch.cat((full_env_emb.view(T * B, -1), goal_logits), -1)
        goal_baseline = self.baseline_planner(baseline_input)

        goal = goal.view(T, B)
        return dict(
            goal=goal,
            raw_goal=goal,
            generator_logits=goal_logits.view(T, B, -1),
            generator_baseline=goal_baseline.view(T, B),
        )