import torch
import numpy as np
from promptrl.envs.alfworldviz.base_controller import BaseAgent# import to get action enumeration

class ActionLookup(object):
    """Enumerate the templated text-action space and sample negative actions from it.

    Every action is built from fixed verb templates over the given receptacle and
    object names. Each action also gets an unnormalized prior weight and a binary
    bag-of-words vector, both used by ``negative_samples``.

    Attributes:
        device: ``'cuda'`` if available, else ``'cpu'``; all tensors live there.
        receptacles: sorted, lower-cased receptacle names.
        objects: sorted, lower-cased object names (names containing ``'Sliced'``
            are dropped).
        actions: every action string, in construction order.
        weights: float64 tensor of prior weights aligned with ``actions``.
        action_ids: action string -> index into ``actions``.
        action_set: set of all action strings.
        word_lookup: vocabulary word -> column index in ``vecs``.
        vecs: float32 tensor of shape ``(len(actions), len(word_lookup))`` holding
            a 0/1 word-presence row per action.
    """

    def __init__(self, receptacles=BaseAgent.RECEPTACLES, objects=BaseAgent.OBJECTS):
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.receptacles = sorted(r.lower() for r in receptacles)
        # The 'Sliced' filter is case-sensitive and runs before lower-casing.
        self.objects = sorted(o.lower() for o in objects if 'Sliced' not in o)
        self.actions, self.weights = self._build_action_list()
        self.action_ids = {a: i for i, a in enumerate(self.actions)}
        self.action_set = set(self.actions)
        self.word_lookup = {}
        self.vecs = self._fit_transform()

    def _build_action_list(self):
        """Return ``(actions, weights)`` for every templated action.

        ``weights`` is a float64 tensor on ``self.device``. Two-argument templates
        (take/put/clean/heat/cool) split their weight evenly across receptacles,
        so for each object such a template family sums to 1 (take/put) or 0.1
        (clean/heat/cool).
        """
        recs, objs = self.receptacles, self.objects
        n_rec = len(recs)
        # With no receptacles the pair templates add no actions, so these
        # values are never used; the guard only avoids dividing by zero.
        per_rec = 1. / n_rec if n_rec else 0.
        per_rec_rare = 0.1 / n_rec if n_rec else 0.
        pairs = [(o, r) for o in objs for r in recs]  # object-major order

        actions, weights = [], []

        def add(group, weight):
            # Append every action of one template with a shared prior weight.
            for action in group:
                actions.append(action)
                weights.append(weight)

        add(['inventory', 'look'], 1.)
        add((f'go to {r}' for r in recs), 1.)
        add((f'examine {r}' for r in recs), 1.)
        add((f'examine {o}' for o in objs), 1.)
        add((f'use {o}' for o in objs), 0.05)
        add((f'take {o} from {r}' for o, r in pairs), per_rec)
        add((f'put {o} in/on {r}' for o, r in pairs), per_rec)
        add((f'open {r}' for r in recs), 0.1)
        add((f'close {r}' for r in recs), 0.1)
        for verb in ('clean', 'heat', 'cool'):
            add((f'{verb} {o} with {r}' for o, r in pairs), per_rec_rare)

        weights = torch.from_numpy(np.array(weights, dtype=np.float64)).to(self.device)
        return actions, weights

    def _fit_transform(self):
        """Build ``self.word_lookup`` from ``self.actions`` and return their word vectors.

        Words are indexed in order of first appearance. Returns a float32 tensor
        of shape ``(len(self.actions), len(self.word_lookup))`` on ``self.device``.
        """
        self.word_lookup = {}
        for action in self.actions:
            for word in action.split():
                self.word_lookup.setdefault(word, len(self.word_lookup))
        vecs = np.stack([self._transform(action) for action in self.actions])
        return torch.from_numpy(vecs).to(self.device)

    def _transform(self, action):
        """Return a float32 0/1 word-presence vector for ``action``.

        Words missing from ``self.word_lookup`` are ignored, so actions outside the
        templated space (e.g. externally supplied candidates) do not raise.
        """
        row = np.zeros(len(self.word_lookup), dtype=np.float32)
        for word in action.split():
            idx = self.word_lookup.get(word)
            if idx is not None:
                row[idx] = 1
        return row

    def negative_samples(self, rng, admissible, num_negative=32, candidates=()):
        """Return distinct non-admissible actions to use as negatives.

        Args:
            rng: NumPy random source with ``choice(a, size=, replace=, p=)``
                (``np.random.RandomState`` or ``np.random.Generator`` -- assumed
                from usage; confirm against callers).
            admissible: actions valid in the current state; none are returned.
                Need not be a subset of ``self.action_set``.
            num_negative: target number of negatives.
            candidates: preferred ("hard") negatives; ``'look'`` and
                ``'inventory'`` are ignored. May include actions outside
                ``self.action_set``.

        Returns:
            A list of distinct action strings. Non-admissible candidates come
            first, in the order given. If there are ``num_negative`` or more of
            them, all are returned as-is. Otherwise the list is topped up to
            ``num_negative`` (or as far as the action space allows) by sampling
            without replacement from the remaining actions, with probability
            proportional to ``self.weights`` times a smoothed bag-of-words
            overlap with the candidates.
        """
        admissible = set(admissible)
        # dict.fromkeys dedups while keeping caller order; iterating a set would
        # make the output order depend on the per-process string hash seed.
        candidates = [c for c in dict.fromkeys(candidates)
                      if c not in ('look', 'inventory')]
        sample = [c for c in candidates if c not in admissible]
        if len(sample) >= num_negative:
            return sample

        # Build the pool in a fixed order (action_ids keeps first-occurrence
        # order of self.actions) so a seeded rng gives reproducible draws.
        taken = set(sample)
        pool = [a for a in self.action_ids if a not in admissible and a not in taken]
        n_draw = min(num_negative - len(sample), len(pool))
        if n_draw <= 0:
            return sample
        ids = torch.tensor([self.action_ids[a] for a in pool],
                           dtype=torch.long, device=self.device)

        if candidates:
            sim_key = np.mean([self._transform(c) for c in candidates], axis=0)
        else:
            # No candidates: a zero key leaves only smoothing and prior weights.
            sim_key = np.zeros(len(self.word_lookup), dtype=np.float32)
        sim_key = torch.from_numpy(sim_key).to(self.device)

        scores = self.vecs[ids] @ sim_key
        # Smoothing: +1e-4 keeps every pool action strictly positive (needed for
        # replace=False); adding the mean flattens the overlap signal toward
        # the prior weights.
        scores = scores + (1e-4 + scores.mean())
        scores = scores * self.weights[ids]
        # Normalize in float64 on the host so p sums to 1 within NumPy's tolerance.
        probs = scores.double().cpu().numpy()
        probs /= probs.sum()

        # Without replacement, so exactly n_draw distinct actions are added.
        picks = rng.choice(len(pool), size=n_draw, replace=False, p=probs)
        sample.extend(pool[i] for i in picks)

        assert all(s not in admissible for s in sample)
        return sample
