import torch
import torch.nn.functional as F
from torch import nn

from .. import utils
from . import language

from einops import rearrange

class Actor(nn.Module):
    def __init__(
        self,
        observation_shape,
        num_actions,
        FLAGS,
        vocab=language.VOCAB,
        lang=language.LANG,
        lang_len=language.LANG_LEN,
        lang_templates=language.INSTR_TEMPLATES,
        lang_bert_emb=language.LANG_BERT_EMB,
        use_intrinsic_rewards=None,
    ):
        """Build every sub-module of the actor.

        Args:
            observation_shape: stored on the instance; not otherwise used here.
            num_actions: width of the policy head.
            FLAGS: configuration object. This constructor reads ``generator``
                (only when ``use_intrinsic_rewards`` is None),
                ``knowledge_dim``, ``mutual_information``,
                ``state_embedding_dim``, ``use_lstm``, ``num_lstm_layers``
                and ``int.twoheaded``.
            vocab, lang, lang_len, lang_templates: language resources, stored
                on the instance unchanged.
            lang_bert_emb: table of sentence embeddings, stored unchanged.
                Presumably a (num_sentences, 768) tensor, since it is fed to
                768-input layers -- TODO confirm against ``language``.
            use_intrinsic_rewards: whether ``forward`` should embed the
                knowledge indices; defaults to ``FLAGS.generator``.
        """
        super().__init__()
        self.observation_shape = observation_shape
        self.num_actions = num_actions
        self.FLAGS = FLAGS

        # Fall back to the generator flag when the caller does not decide.
        self.use_intrinsic_rewards = (
            FLAGS.generator if use_intrinsic_rewards is None else use_intrinsic_rewards
        )

        self.use_index_select = True
        self.obj_dim, self.col_dim, self.con_dim = 5, 3, 2
        self.knowledge_dim = self.FLAGS.knowledge_dim
        # Object + colour + contains embeddings plus one extra channel that
        # forward() fills with zeros.
        self.num_channels = self.obj_dim + self.col_dim + self.con_dim + 1

        use_mutual_information = self.FLAGS.mutual_information
        state_dim = self.FLAGS.state_embedding_dim

        def ortho_init(module):
            # Orthogonal weight initialiser, zero bias initialiser.
            return utils.init(
                module,
                nn.init.orthogonal_,
                lambda x: nn.init.constant_(x, 0),
            )

        def conv_stack(in_channels, depth=5):
            # `depth` stride-2 3x3 convolutions with 32 filters, each followed
            # by an ELU. Layers are created strictly in order.
            layers = []
            for layer_idx in range(depth):
                conv = nn.Conv2d(
                    in_channels=in_channels if layer_idx == 0 else 32,
                    out_channels=32,
                    kernel_size=(3, 3),
                    stride=2,
                    padding=1,
                )
                layers.append(ortho_init(conv))
                layers.append(nn.ELU())
            return nn.Sequential(*layers)

        # NOTE: sub-modules are created in a fixed order on purpose; changing
        # it changes both parameter ordering and random initialisation.
        self.embed_object = nn.Embedding(11, self.obj_dim)
        self.embed_color = nn.Embedding(6, self.col_dim)
        self.embed_contains = nn.Embedding(4, self.con_dim)

        if use_mutual_information:
            # The discriminator gets its own, separate embedding tables.
            self.discriminator_embed_object = nn.Embedding(11, self.obj_dim)
            self.discriminator_embed_color = nn.Embedding(6, self.col_dim)
            self.discriminator_embed_contains = nn.Embedding(4, self.con_dim)

        self.vocab = vocab
        self.lang = lang
        self.lang_len = lang_len
        self.lang_templates = lang_templates
        self.lang_bert_emb = lang_bert_emb

        # NOTE(review): the positional grid is fixed at 7x7 -- presumably the
        # partial-observation window size; confirm before using other sizes.
        self.pos_embedding = nn.Parameter(torch.randn(1, self.num_channels, 7, 7))

        # NOTE(review): embed_dim is hard-coded to 64 while the query passed in
        # forward() has knowledge_dim features -- looks like this requires
        # knowledge_dim == 64 when cross attention is enabled; confirm.
        self.cross_attn = nn.MultiheadAttention(
            embed_dim=64,
            kdim=self.num_channels,
            vdim=self.num_channels,
            num_heads=1,
            batch_first=True,
        )

        if use_mutual_information:
            # The discriminator sees the embeddings without the extra channel.
            self.discriminator_conv = conv_stack(self.num_channels - 1)
            self.discriminator_lang_proj = nn.Sequential(
                nn.Linear(768, 32),
                nn.ReLU(),
            )

        self.feat_extract = conv_stack(self.num_channels)

        # Conv features + carried object/colour embeddings + knowledge embedding.
        fc_input_dim = 32 + self.obj_dim + self.col_dim + self.knowledge_dim
        self.fc = nn.Sequential(
            ortho_init(nn.Linear(fc_input_dim, state_dim)),
            nn.ReLU(),
            ortho_init(nn.Linear(state_dim, state_dim)),
            nn.ReLU(),
        )

        if self.FLAGS.use_lstm:
            self.core = nn.LSTM(state_dim, state_dim, self.FLAGS.num_lstm_layers)

        self.bert_encoder = language.ActorLanguageEncoder(
            input_dim=768, output_dim=self.knowledge_dim
        )

        self.policy = ortho_init(nn.Linear(state_dim, self.num_actions))
        self.baseline = ortho_init(nn.Linear(state_dim, 1))
        if self.FLAGS.int.twoheaded:
            # Separate value head for the intrinsic return.
            self.int_baseline = ortho_init(nn.Linear(state_dim, 1))

    def initial_state(self, batch_size):
        if not self.FLAGS.use_lstm:
            return tuple()
        return tuple(
            torch.zeros(self.core.num_layers, batch_size, self.core.hidden_size)
            for _ in range(2)
        )

    def create_embeddings(self, x, id, emb_for_discriminator=False):
        """Generates compositional embeddings."""
        if emb_for_discriminator:
            if id == 0:
                objects_emb = self._select(self.discriminator_embed_object, x[:, :, :, id::3])
            elif id == 1:
                objects_emb = self._select(self.discriminator_embed_color, x[:, :, :, id::3])
            elif id == 2:
                objects_emb = self._select(self.discriminator_embed_contains, x[:, :, :, id::3])
        else:
            if id == 0:
                objects_emb = self._select(self.embed_object, x[:, :, :, id::3])
            elif id == 1:
                objects_emb = self._select(self.embed_color, x[:, :, :, id::3])
            elif id == 2:
                objects_emb = self._select(self.embed_contains, x[:, :, :, id::3])
            
        embeddings = torch.flatten(objects_emb, 3, 4)
        return embeddings

    def _select(self, embed, x):
        """Efficient function to get embedding from an index."""
        if self.use_index_select:
            out = embed.weight.index_select(0, x.reshape(-1))
            return out.reshape(x.shape + (-1,))
        else:
            return embed(x)

    def forward_discriminator(self, inputs):
        """ discriminator for mutual information """
        if self.FLAGS.partial_obs:
            x = inputs["partial_frame"]
        else:
            x = inputs["frame"]
            
        self.lang_bert_emb = self.lang_bert_emb.to(inputs["frame"].device)
        T, B, h, w, *_ = x.shape  
        x = torch.flatten(x, 0, 1)  # Merge time and batch.
        x = x.long()
        x = torch.cat(
            [
                self.create_embeddings(x, 0, emb_for_discriminator=True),
                self.create_embeddings(x, 1, emb_for_discriminator=True),
                self.create_embeddings(x, 2, emb_for_discriminator=True),
            ],
            dim=3,
        )
        x = x.transpose(1, 3) ## TB x 10 x N x N
        x = self.discriminator_conv(x) ## TB x 32 x 1 x 1
        x = x.view(T * B, -1) ## TB x 32
        
        knowledge_emb_proj = self.discriminator_lang_proj(self.lang_bert_emb) # 652 x 32
        knowledge_emb_proj = knowledge_emb_proj.unsqueeze(0).expand(x.shape[0], -1, -1) # TB x 652 x 32
        pred_logits = torch.bmm(knowledge_emb_proj, x.unsqueeze(-1)).squeeze(-1)
        
        return {
            "logits": pred_logits,
            "preds": torch.argmax(pred_logits, dim=1),
        }
        
    def forward(self, inputs, core_state=(), knowledge=None):
        if knowledge is None:
            knowledge = []

        if self.FLAGS.partial_obs:
            x = inputs["partial_frame"]
        else:
            x = inputs["frame"]
        T, B, h, w, *_ = x.shape
        x = torch.flatten(x, 0, 1)  # Merge time and batch.
        knowledge = torch.flatten(knowledge, 0, 1)

        knowledge_channel = torch.zeros_like(x, requires_grad=False)
        if self.FLAGS.partial_obs: 
            knowledge_channel = knowledge_channel[:, :, :, :1]
        else:
            knowledge_channel = torch.flatten(knowledge_channel, 1, 2)[:, :, 0]
            knowledge_channel = knowledge_channel.view(T * B, h, w, 1)

        carried_col = inputs["carried_col"]
        carried_obj = inputs["carried_obj"]

        x = x.long()
        knowledge = knowledge.long()
        carried_obj = carried_obj.long()
        carried_col = carried_col.long()
        x = torch.cat(
            [
                self.create_embeddings(x, 0),
                self.create_embeddings(x, 1),
                self.create_embeddings(x, 2),
                knowledge_channel.float(),
            ],
            dim=3,
        )
        carried_obj_emb = self._select(self.embed_object, carried_obj)
        carried_col_emb = self._select(self.embed_color, carried_col)

        visual_embedding = x.transpose(1, 3)
        
        carried_obj_emb = carried_obj_emb.view(T * B, -1)
        carried_col_emb = carried_col_emb.view(T * B, -1)
        
        x = self.feat_extract(visual_embedding)
        x = x.view(T * B, -1)
        union = torch.cat([x, carried_obj_emb, carried_col_emb], dim=1)

        if self.use_intrinsic_rewards:
            knowledge_emb = self.embed_knowledge(knowledge)
            if self.FLAGS.actor_cross_attn:
                visual_embedding_with_pos = visual_embedding + self.pos_embedding
                knowledge_emb = rearrange(knowledge_emb, "TB N -> TB 1 N")
                visual_embedding_with_pos = rearrange(visual_embedding_with_pos, "TB N H W -> TB (H W) N")
                attn_output, _ = self.cross_attn(query=knowledge_emb, key=visual_embedding_with_pos, value=visual_embedding_with_pos)
                if self.FLAGS.attn_skip:
                    attn_output = attn_output + knowledge_emb
                knowledge_emb = rearrange(attn_output, "TB 1 N -> TB N")
                
        else:
            knowledge_emb = torch.zeros((union.shape[0], self.knowledge_dim)).to(union.device)
        union = torch.cat([union, knowledge_emb], dim=1)

        core_input = self.fc(union)

        if self.FLAGS.use_lstm:
            core_input = core_input.view(T, B, -1)
            core_output_list = []
            notdone = (~inputs["done"]).float()
            for input, nd in zip(core_input.unbind(), notdone.unbind()):
                nd = nd.view(1, -1, 1)
                core_state = tuple(nd * s for s in core_state)
                output, core_state = self.core(input.unsqueeze(0), core_state)
                core_output_list.append(output)
            core_output = torch.flatten(torch.cat(core_output_list), 0, 1)
        else:
            core_output = core_input
            core_state = tuple()

        policy_logits = self.policy(core_output)
        baseline = self.baseline(core_output)

        if self.training:
            action = torch.multinomial(F.softmax(policy_logits, dim=1), num_samples=1)
        else:
            action = torch.argmax(policy_logits, dim=1)

        policy_logits = policy_logits.view(T, B, self.num_actions)
        baseline = baseline.view(T, B)
        action = action.view(T, B)

        output = dict(policy_logits=policy_logits, baseline=baseline, action=action)
        if self.FLAGS.int.twoheaded:
            int_baseline = self.int_baseline(core_output)
            output.update(int_baseline=int_baseline.view(T, B))

        return (
            output,
            core_state,
        )

    def embed_knowledge(self, knowledge):
        self.lang_bert_emb = self.lang_bert_emb.to(knowledge.device)
        raw_knowledge_emb = self.lang_bert_emb[knowledge]
        knowledge_emb = self.bert_encoder(raw_knowledge_emb)
        return knowledge_emb
