import torch.nn as nn
import torch
from transformers import AutoModel
import torch.nn.functional as F
import os
import pathlib
import pickle as pkl
import time
import fasteners
import sys
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '.'))
sys.path.append(project_root)
from Craftax.craftax.craftax_classic.model import CREATE_SBERT_MODEL
from Craftax.craftax.craftax_classic.cache.cache import Cache

class SbertEncoder(nn.Module):
    """Encode padded token-id sequences with a frozen SBERT model, memoising results.

    Each row of an input batch is keyed by its non-zero token ids. The first
    time a key is seen, its mean-pooled SBERT embedding is computed under
    ``torch.no_grad()`` and stored in an in-memory dict; later occurrences are
    served from that dict. Only an in-memory cache exists here (an on-disk
    ``Cache``-backed variant was only ever sketched in commented-out code).

    NOTE(review): ``output_head`` (a projection to ``embedding_dim``) is created
    but never applied in ``forward``; presumably a caller applies it to the
    returned embeddings -- confirm against callers.
    """

    # Factory shared by every instance that does not pass its own
    # ``create_sbert_model``. Previously ``CREATE_SBERT_MODEL()`` was evaluated
    # in the signature, i.e. once at import time; building it on first use keeps
    # the single shared instance but removes the import-time side effect.
    _default_sbert_factory = None

    def __init__(self, embedding_dim, device='cuda', debug=True, create_sbert_model=None):
        """Build the encoder.

        Args:
            embedding_dim: output width of ``output_head``.
            device: device the attention mask and empty results are placed on.
            debug: accepted for compatibility; currently unused.
            create_sbert_model: object exposing ``get_sbert_model()``. When
                omitted (or None), a single process-wide ``CREATE_SBERT_MODEL()``
                instance is created on first use and shared.
        """
        super().__init__()
        if create_sbert_model is None:
            create_sbert_model = self._shared_default_factory()
        self.model = create_sbert_model.get_sbert_model()
        self.output_head = nn.Sequential(
            nn.Linear(384, embedding_dim)  # for minilm (384-d hidden size)
        )
        # tuple(non-zero token ids) -> embedding tensor; there is no eviction,
        # so this grows with the number of distinct sequences seen.
        self.cache = {}
        self.device = device
        self.model.eval()
        self.num_cache_load_errors = 0

    @staticmethod
    def _shared_default_factory():
        """Return the shared default SBERT factory, creating it on first use."""
        if SbertEncoder._default_sbert_factory is None:
            SbertEncoder._default_sbert_factory = CREATE_SBERT_MODEL()
        return SbertEncoder._default_sbert_factory

    def mean_pooling(self, model_output, attention_mask):
        """Average token embeddings over positions where ``attention_mask`` is non-zero.

        ``model_output[0]`` is taken as the per-token embeddings. The mask is
        broadcast over the last axis, and the denominator is clamped so rows with
        no unmasked tokens yield zeros instead of NaN.
        """
        token_embeddings = model_output[0]  # First element of model_output contains all token embeddings
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

    def forward(self, input):
        """Return one embedding per row of ``input``, using the cache where possible.

        Args:
            input: 2-D batch of token-id rows; every 0 id is treated as padding,
                both in the cache key and in the attention mask.

        Returns:
            Tensor with one embedding per input row, in input order. It is a
            fresh tensor, not an alias of the cached entries. An empty batch
            yields an empty 1-D float32 tensor on ``self.device``.
        """
        input = input.long()
        self.model.eval()
        # One host transfer for the whole batch; keys drop every 0 id.
        keys = [tuple(tok for tok in row if tok != 0) for row in input.tolist()]
        if not keys:
            return torch.empty(0, dtype=torch.float32, device=self.device)

        # First row index of each key not yet cached. The dict dedupes repeated
        # sequences within the batch, and its insertion order lines up with the
        # rows handed to ``embed`` for the zip below.
        first_missing_row = {}
        for row_idx, key in enumerate(keys):
            if key not in self.cache and key not in first_missing_row:
                first_missing_row[key] = row_idx

        if first_missing_row:
            new_embeddings = self.embed(input[list(first_missing_row.values())])
            for key, embedding in zip(first_missing_row, new_embeddings):
                self.cache[key] = embedding

        # torch.stack copies, so callers never mutate cached tensors in place.
        return torch.stack([self.cache[key] for key in keys])

    def embed(self, input):
        """Run the SBERT model on ``input`` without gradients and mean-pool the output.

        Positions holding id 0 are excluded from both attention and pooling.
        The mask has the same dtype as ``input`` and lives on ``self.device``.
        """
        mask = (input != 0).to(device=self.device, dtype=input.dtype)
        with torch.no_grad():
            outputs = self.model(input_ids=input, attention_mask=mask)
        return self.mean_pooling(outputs, mask)
