"""
These functions preprocess the observations.
When trying more sophisticated encoding for LTL, we might have to modify this code.
"""

import os
import json
import re
import torch
from envs.minigrid.minigrid_env import AdversarialMinigridEnv
import torch_ac
import gym
import numpy as np
import utils

from envs import *

def get_obss_preprocessor(env):
    obs_space = env.observation_space

    if "events" in obs_space.spaces:
        obs_space = {
            "image": obs_space.spaces["features"].shape,
            "rm_state": obs_space.spaces["rm-state"].shape[0],
            "events": obs_space.spaces["events"].shape[0]
        }

        def preprocess_obss(obss, device=None):
            return torch_ac.DictList({
                "image": preprocess_images([obs["features"] for obs in obss], device=device),
                "rm_state": preprocess_rm_states([obs["rm-state"] for obs in obss], device=device),
                "events": preprocess_events([obs["events"] for obs in obss], device=device)
            })
    else:
        obs_space = {
            "image": obs_space.spaces["features"].shape,
            "rm_state": obs_space.spaces["rm-state"].shape[0],
        }

        def preprocess_obss(obss, device=None):
            return torch_ac.DictList({
                "image": preprocess_images([obs["features"] for obs in obss], device=device),
                "rm_state": preprocess_rm_states([obs["rm-state"] for obs in obss], device=device),
            })
    return obs_space, preprocess_obss

def preprocess_events(events, device=None):
    events = np.array(events)
    return torch.tensor(events, device=device, dtype=torch.float)

def preprocess_rm_states(rm_states, device=None):
    rm_states = np.array(rm_states)
    return torch.tensor(rm_states, device=device, dtype=torch.float)

def preprocess_images(images, device=None):
    # Bug of Pytorch: very slow if not first converted to numpy array
    images = np.array(images)
    return torch.tensor(images, device=device, dtype=torch.float)


def preprocess_texts(texts, vocab, vocab_space, gnn=False, device=None, **kwargs):
    if (gnn):
        return preprocess4gnn(texts, kwargs["ast"], device)

    return preprocess4rnn(texts, vocab, device)


def preprocess4rnn(texts, vocab, device=None):
    """
    This function receives the LTL formulas and convert them into inputs for an RNN
    """
    var_indexed_texts = []
    max_text_len = 0

    for text in texts:
        text = str(text) # transforming the ltl formula into a string
        tokens = re.findall("([a-z]+)", text.lower())
        var_indexed_text = np.array([vocab[token] for token in tokens])
        var_indexed_texts.append(var_indexed_text)
        max_text_len = max(len(var_indexed_text), max_text_len)

    indexed_texts = np.zeros((len(texts), max_text_len))

    for i, indexed_text in enumerate(var_indexed_texts):
        indexed_texts[i, :len(indexed_text)] = indexed_text

    return torch.tensor(indexed_texts, device=device, dtype=torch.long)

def preprocess4gnn(texts, ast, device=None):
    """
    This function receives the LTL formulas and convert them into inputs for a GNN
    """
    return np.array([[ast(text).to(device)] for text in texts])


class Vocabulary:
    """A mapping from tokens to ids with a capacity of `max_size` words.
    It can be saved in a `vocab.json` file."""

    def __init__(self, vocab_space):
        self.max_size = vocab_space["max_size"]
        self.vocab = {}

        # populate the vocab with the LTL operators
        for item in ['next', 'until', 'and', 'or', 'eventually', 'always', 'not', 'True', 'False']:
            self.__getitem__(item)

        for item in vocab_space["tokens"]:
            self.__getitem__(item)

    def load_vocab(self, vocab):
        self.vocab = vocab

    def __getitem__(self, token):
        if not token in self.vocab.keys():
            if len(self.vocab) >= self.max_size:
                raise ValueError("Maximum vocabulary capacity reached")
            self.vocab[token] = len(self.vocab) + 1
        return self.vocab[token]
