"""
Language utilities
"""

import torch
from babyai.levels.verifier import INSTRS
from gym_minigrid.minigrid import COLOR_NAMES
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_sequence

# Reserved vocabulary entries: padding, start-of-sequence and end-of-sequence.
PAD_TOKEN, SOS_TOKEN, EOS_TOKEN = "<PAD>", "<SOS>", "<EOS>"

# Fixed indices of the reserved entries. PAD_INDEX is 0, which is also the
# value ``pad_sequence`` pads with by default.
PAD_INDEX, SOS_INDEX, EOS_INDEX = 0, 1, 2


# Seed mapping copied by every new ``W2I`` so the reserved tokens always keep
# the indices above.
_W2I_STARTER = {PAD_TOKEN: PAD_INDEX, SOS_TOKEN: SOS_INDEX, EOS_TOKEN: EOS_INDEX}


def get_lang(env, FLAGS):
    """Bundle the module-level instruction data into one dictionary.

    Args:
        env: Accepted for interface compatibility; not used here.
        FLAGS: Accepted for interface compatibility; not used here.

    Returns:
        dict with the keys ``"vocab"``, ``"lang"``, ``"lang_len"`` and
        ``"lang_templates"``, holding the module globals ``VOCAB``, ``LANG``,
        ``LANG_LEN`` and ``INSTR_TEMPLATES`` respectively.
    """
    return {
        "vocab": VOCAB,
        "lang": LANG,
        "lang_len": LANG_LEN,
        "lang_templates": INSTR_TEMPLATES,
    }


class W2I:
    """Token-to-index mapping that grows as unseen tokens are looked up.

    The mapping starts from a copy of ``_W2I_STARTER`` (the reserved
    PAD/SOS/EOS entries). Every token that is looked up for the first time
    receives the next free index, i.e. the current size of the mapping.
    """

    def __init__(self):
        # Copy so that instances never mutate the shared starter mapping.
        self._w2i = _W2I_STARTER.copy()

    def __getitem__(self, i):
        """Return the index of token ``i``, registering it first if unseen.

        Note that the lookup has a side effect: an unknown token is added to
        the mapping instead of raising ``KeyError``.
        """
        # The default is evaluated before insertion, so a new token gets the
        # size of the mapping *before* it was added.
        return self._w2i.setdefault(i, len(self._w2i))

    def __len__(self):
        """Return the number of known tokens, reserved entries included."""
        return len(self._w2i)

    def __dict__(self):
        # NOTE(review): defining ``__dict__`` as a method shadows the usual
        # instance-attribute mapping, so ``obj.__dict__`` is this bound method
        # rather than a dict. Kept unchanged for backward compatibility;
        # prefer ``to_vocab()["w2i"]`` in new code.
        return self._w2i

    def to_vocab(self):
        """Return a snapshot of the vocabulary.

        Returns:
            dict with ``"w2i"`` (token -> index), ``"i2w"`` (index -> token)
            and ``"size"`` (number of entries).
        """
        # Snapshot the mapping: returning the live dict would let later
        # lookups of new tokens grow "w2i" while "i2w" and "size" stay stale.
        w2i = dict(self._w2i)
        return {
            "w2i": w2i,
            "i2w": {index: token for token, index in w2i.items()},
            "size": len(w2i),
        }


def preprocess_instrs(instrs):
    """Tokenize instructions and encode them as a padded index tensor.

    The string returned by each instruction's ``surface(None)`` is split on
    single spaces. Tokens are mapped to indices through a fresh ``W2I`` and
    every sequence is wrapped in ``SOS_INDEX`` ... ``EOS_INDEX``.

    Args:
        instrs: Iterable of objects exposing ``surface(None)`` that returns
            the instruction text.

    Returns:
        Tuple ``(lang, lang_len, vocab)``:
        ``lang`` is the batch-first tensor of index sequences, padded by
        ``pad_sequence`` with its default value 0 (equal to ``PAD_INDEX``);
        ``lang_len`` is a tensor with the unpadded length of each sequence,
        SOS and EOS included; ``vocab`` is the result of ``W2I.to_vocab()``.
    """
    word_index = W2I()
    token_lists = [instr.surface(None).split(" ") for instr in instrs]

    # Indices are assigned in first-seen order, instruction by instruction.
    encoded = [
        [SOS_INDEX, *(word_index[token] for token in tokens), EOS_INDEX]
        for tokens in token_lists
    ]

    lang = pad_sequence([torch.tensor(seq) for seq in encoded], batch_first=True)
    lang_len = torch.tensor([len(seq) for seq in encoded])
    return lang, lang_len, word_index.to_vocab()


def preprocess_instrs_onehot(instrs):
    """Encode each instruction as a single unique token, ignoring its text.

    The instruction at position ``k`` becomes ``[SOS_INDEX, id_k, EOS_INDEX]``
    where ``id_k`` is the index a fresh ``W2I`` assigns to the string
    ``str(k)``.

    Args:
        instrs: Iterable of instructions; only their positions are used.

    Returns:
        Tuple ``(lang, lang_len, vocab)`` in the same layout as
        ``preprocess_instrs``: the batch-first index tensor, a tensor of
        sequence lengths, and the result of ``W2I.to_vocab()``.
    """
    word_index = W2I()
    encoded = [
        [SOS_INDEX, word_index[str(position)], EOS_INDEX]
        for position, _ in enumerate(instrs)
    ]

    lang = pad_sequence([torch.tensor(seq) for seq in encoded], batch_first=True)
    lang_len = torch.tensor([len(seq) for seq in encoded])
    return lang, lang_len, word_index.to_vocab()


def get_instr_templates(instrs):
    """Return each instruction's text with color words replaced by ``"C"``.

    The string returned by ``surface(None)`` is split on single spaces; every
    token found in ``COLOR_NAMES`` is replaced by the placeholder ``"C"`` and
    the tokens are joined back with single spaces.

    Args:
        instrs: Iterable of objects exposing ``surface(None)`` that returns
            the instruction text.

    Returns:
        List of template strings, one per instruction, in input order.
    """
    surfaces = [instr.surface(None) for instr in instrs]
    return [
        " ".join(
            "C" if word in COLOR_NAMES else word for word in surface.split(" ")
        )
        for surface in surfaces
    ]

def preprocess_instrs_bert(lang, lang_len):
    """Load and return the object saved at ``./data/babyai/INSTRS.pt``.

    Args:
        lang: Accepted for interface compatibility; not used here.
        lang_len: Accepted for interface compatibility; not used here.

    Returns:
        Whatever ``torch.load`` reads from the file. The path is relative to
        the current working directory.

    Note:
        ``torch.load`` unpickles the file, so it must come from a trusted
        source.
    """
    with torch.no_grad():
        return torch.load("./data/babyai/{}.pt".format("INSTRS"))

# Instruction data computed once, at import time, from the BabyAI ``INSTRS``.
LANG, LANG_LEN, VOCAB = preprocess_instrs(INSTRS)

# Loaded from disk at import time (see ``preprocess_instrs_bert``); presumably
# precomputed BERT embeddings of the instructions -- TODO confirm.
LANG_BERT_EMB = preprocess_instrs_bert(LANG, LANG_LEN)

# Encoding that gives every instruction its own single token.
LANG_ONEHOT, LANG_LEN_ONEHOT, VOCAB_ONEHOT = preprocess_instrs_onehot(INSTRS)
# Instruction texts with color words replaced by the placeholder "C".
INSTR_TEMPLATES = get_instr_templates(INSTRS)
# ``sorted`` accepts any iterable, so the set needs no intermediate list.
INSTR_TEMPLATES_UNIQUE = sorted(set(INSTR_TEMPLATES))

class PlannerLangugaeEncoder(nn.Module):
    """Map a raw goal embedding to ``output_dim`` features via Linear + ReLU.

    NOTE: the class name is misspelled ("Langugae"); it is left as is because
    it is this class's public name.

    Args:
        input_dim: Size of the last dimension of the incoming embedding.
        output_dim: Size of the produced embedding (default 256).
    """

    def __init__(self, input_dim, output_dim=256):
        super().__init__()
        self.output_dim, self.input_dim = output_dim, input_dim
        # Keep the (Linear, ReLU) layout inside ``map_head`` so parameter
        # names in ``state_dict`` stay ``map_head.0.*``.
        projection = nn.Linear(self.input_dim, self.output_dim)
        self.map_head = nn.Sequential(projection, nn.ReLU())

    def forward(self, raw_goal_emb):
        """Return the projected, ReLU-activated embedding of ``raw_goal_emb``."""
        return self.map_head(raw_goal_emb)

class ActorLanguageEncoder(nn.Module):
    """Map a raw goal embedding to ``output_dim`` features via Linear + ReLU.

    Args:
        input_dim: Size of the last dimension of the incoming embedding.
        output_dim: Size of the produced embedding (default 64).
    """

    def __init__(self, input_dim, output_dim=64):
        super().__init__()
        self.output_dim, self.input_dim = output_dim, input_dim
        # Keep the (Linear, ReLU) layout inside ``map_head`` so parameter
        # names in ``state_dict`` stay ``map_head.0.*``.
        layers = (nn.Linear(input_dim, output_dim), nn.ReLU())
        self.map_head = nn.Sequential(*layers)

    def forward(self, raw_goal_emb):
        """Return the projected, ReLU-activated embedding of ``raw_goal_emb``."""
        return self.map_head(raw_goal_emb)

def preprocess_xy_instrs(height, width):
    """Build one ``[SOS, x, y, EOS]`` index sequence per grid coordinate.

    Coordinates are enumerated with ``x`` over ``range(height)`` (outer loop)
    and ``y`` over ``range(width)`` (inner loop). The strings ``str(x)`` and
    ``str(y)`` are mapped to indices through a fresh ``W2I``.

    Args:
        height: Number of ``x`` values.
        width: Number of ``y`` values.

    Returns:
        Tuple ``(lang, lang_len, vocab)``: a tensor with one 4-token row per
        coordinate, a tensor with the length of each row, and the result of
        ``W2I.to_vocab()``.
    """
    word_index = W2I()
    rows = [
        [SOS_INDEX, word_index[str(x)], word_index[str(y)], EOS_INDEX]
        for x in range(height)
        for y in range(width)
    ]

    lang = torch.tensor(rows)
    lang_len = torch.tensor([len(row) for row in rows])
    return lang, lang_len, word_index.to_vocab()


def xy_instr_to_goal_number(instr, height, vocab):
    """Convert an ``[SOS, x, y, EOS]`` instruction into a single goal number.

    Args:
        instr: Tensor of token indices; positions 1 and 2 hold the tokens
            for ``x`` and ``y``.
        height: Multiplier applied to ``x``.
        vocab: Vocabulary dict whose ``"i2w"`` entry maps an index back to
            its token string (here: the decimal text of a coordinate).

    Returns:
        ``x * height + y`` as an int.
    """
    tokens = instr.cpu().numpy()
    index_to_word = vocab["i2w"]
    x, y = (int(index_to_word[tokens[position]]) for position in (1, 2))
    # NOTE(review): with y ranging over ``width`` values this numbering is
    # only collision-free when width <= height -- confirm against callers.
    return x * height + y


def to_text(langs, vocab=VOCAB):
    """Decode a batch of index sequences into space-joined strings.

    Args:
        langs: 2-D batch of token indices, either a ``torch.Tensor`` or an
            object with ``ndim`` that iterates row by row (e.g. a numpy
            array).
        vocab: Vocabulary dict whose ``"i2w"`` entry maps an index to its
            token. Defaults to the module-level ``VOCAB``.

    Returns:
        List with one string per row; SOS, EOS and PAD indices are skipped.

    Raises:
        ValueError: If ``langs`` is one-dimensional.
    """
    if isinstance(langs, torch.Tensor):
        langs = langs.detach().cpu().numpy()

    if langs.ndim == 1:
        raise ValueError("This function operates on batches")

    # Decode with the vocabulary that was passed in; previously the global
    # VOCAB was always used and the ``vocab`` argument was silently ignored.
    i2w = vocab["i2w"]
    special = {SOS_INDEX, EOS_INDEX, PAD_INDEX}
    return [
        " ".join(i2w[tok] for tok in lang if tok not in special)
        for lang in langs
    ]
