from tensorflow_datasets import nearest_neighbors
from models.model_utils import hash_tensor
from models.retrieval_wrapper import RetrievalAgent
from nn_agent_torch import NNAgentEuclideanStandardized
import torch.nn as nn
import math
import torch
import pickle
from logging_utils import logger
import torch._dynamo
# Set torch._dynamo's cache size limit to 64. This bounds how many compiled
# variants Dynamo keeps before it gives up recompiling; presumably raised
# because compiled code here sees varying inputs (confirm).
torch._dynamo.config.cache_size_limit = 64

# Make a `profile` decorator name available for @profile annotations:
# prefer the one kernprof (line_profiler) injects, else a no-op fallback.
try:
    # If kernprof is running, profile will be available as builtin
    profile
except NameError:
    # Otherwise import no-op version
    from nn_util import profile

class RetrieveAndPlay(nn.Module):
    """Policy module that replays the actions of retrieved nearest neighbors.

    Wraps a ``RetrievalAgent`` built from ``env_cfg`` / ``policy_cfg``. Each
    forward pass asks the agent for neighbor indices of the input. It then
    returns the matching rows of the agent's
    ``datasets['state'].flattened_act_matrix``.
    """

    def __init__(self, env_cfg, policy_cfg):
        super().__init__()

        self.retrieval_agent = RetrievalAgent(env_cfg, policy_cfg)
        # Device the agent reported at construction time; updated by ``to``.
        self.device = self.retrieval_agent.agent.device
        # Action table indexed by neighbor ids in ``forward``. It is a plain
        # attribute, not a registered buffer, so ``nn.Module.to`` does not
        # move it. ``to`` re-reads it after the agent has been moved.
        self.a_dataset = self._action_table()

    def _action_table(self):
        """Return the agent's current flattened action matrix."""
        return self.retrieval_agent.agent.datasets['state'].flattened_act_matrix

    def forward(self, input):  # noqa: A002 - name kept for keyword callers
        """Return the actions stored for the retrieved neighbors of ``input``.

        Args:
            input: Query passed to ``retrieval_agent.get_neighbors``. When the
                agent's ``lookback`` is 1, a singleton axis is first inserted
                at dim 1 (presumably the history axis; confirm).

        Returns:
            ``a_dataset`` indexed by the neighbor indices, with dim 1 squeezed.
            This assumes one neighbor per query, so that dim has size 1; confirm.
        """
        if self.retrieval_agent.lookback == 1:
            input = input.unsqueeze(dim=1)

        neighbors = self.retrieval_agent.get_neighbors(input)
        actions = self.a_dataset[neighbors]
        return actions.squeeze(1)

    def to(self, *args, **kwargs):
        """Move/cast the module and keep the retrieval agent on the same device.

        Accepts every call form of ``nn.Module.to``: a device, dtype, tensor,
        ``device=`` / ``tensor=`` keywords, or ``memory_format=``. The target
        device is resolved with the same parser ``nn.Module.to`` uses. As a
        result, ``to(some_tensor)`` moves the agent too, and
        ``to(device=None, dtype=...)`` does not fail.

        When a device is targeted, the agent is moved with ``to_device``, and
        ``device`` and ``a_dataset`` are refreshed.

        NOTE(review): in PyTorch, ``Module.cuda()`` / ``cpu()`` go through
        ``_apply`` rather than ``to``, so they bypass this hook. Confirm that
        callers only use ``.to``.

        Returns:
            Whatever ``nn.Module.to`` returns (the module itself).
        """
        result = super().to(*args, **kwargs)

        # Same parsing nn.Module.to performs. Element 0 is the target device,
        # or None when the call only changes dtype / memory format.
        new_device = torch._C._nn._parse_to(*args, **kwargs)[0]

        if new_device is not None:
            self.retrieval_agent.agent.to_device(new_device)
            self.device = new_device
            # to_device may replace the action tensor rather than mutate it
            # in place. Re-read it so forward() indexes the moved copy.
            self.a_dataset = self._action_table()

        return result
