import numpy as np
import torch

class Env:
    """Choice environment that samples items in proportion to exp-scaled scores.

    Every item gets the score ``exp(sharpness * <vec, query>)``.  When a set of
    items is "played", one of them is drawn with probability proportional to
    its score among the played items.

    Attributes:
        scores: per-item scores, stored on ``device``.
        max_score: largest score over all items (Python float).
        optimal_item: index of the item holding ``max_score``.
        max_corr: ``log(max_score) / sharpness``, i.e. the inner product of
            the best item with the query, recovered from its score.
        num_valid_items: only set when ``eps`` is given; see
            ``get_num_valid_items``.
        DEVICE: device on which scores and index tensors live.
        sharpness: multiplier applied to the inner products before ``exp``.
    """

    def __init__(self, vecs:torch.FloatTensor,
                 query:torch.FloatTensor,
                 sharpness:float = 3,
                 eps:float = None,
                 device = 'cpu') -> None:
        """Precompute the score of every item against ``query``.

        Args:
            vecs: item vectors; presumably shaped (num_items, dim) -- the code
                only requires that it broadcasts against ``query`` reshaped to
                (1, dim) and that reducing the last axis yields one value per
                item.
            query: query vector, flattened to a single row before use.
            sharpness: scale applied to the inner products before ``exp``.
            eps: if not None, the number of eps-optimal items is computed,
                stored in ``num_valid_items`` and printed.
            device: device used for the scores and for all index tensors.
        """
        self.DEVICE = device
        self.sharpness = sharpness
        corr = (vecs * query.reshape(1, -1)).sum(-1)
        # Keep the scores on the same device as the index tensors built in
        # `_as_index`; otherwise indexing fails whenever `device` differs from
        # the device of the inputs.
        self.scores = torch.exp(sharpness * corr).to(device)
        self.max_score = self.scores.max().item()
        self.optimal_item = self.scores.argmax().item()
        self.max_corr = np.log(self.max_score)/sharpness
        print(f'optimal item = {self.optimal_item}, optimal_corr = {self.max_corr}')
        if eps is not None:
            self.num_valid_items = self.get_num_valid_items(eps)
            print(f'Number of eps-optimal items: {self.num_valid_items}')

    def _as_index(self, itms):
        """Turn an iterable of item ids into a long index tensor on ``DEVICE``."""
        return torch.tensor(list(itms), dtype=torch.long, device=self.DEVICE)

    def play_set(self, itms:set):
        """Draw one item from ``itms`` with probability proportional to its score.

        Returns:
            The id of the drawn item as a Python int.
        """
        idx_tensor = self._as_index(itms)
        # torch.multinomial accepts unnormalised non-negative weights.
        pick = torch.multinomial(self.scores[idx_tensor], 1).item()
        return idx_tensor[pick].item()

    def play_set_m_times(self, itms:set, m = 1):
        """Draw ``m`` items from ``itms`` independently (with replacement).

        Each draw follows the same distribution as ``play_set``.

        Returns:
            A list of ``m`` item ids (Python ints).
        """
        idx_tensor = self._as_index(itms)
        picks = torch.multinomial(self.scores[idx_tensor], m, replacement=True)
        # Tensor.tolist() works on any device, unlike .numpy(), which requires
        # a CPU tensor.
        return idx_tensor[picks].tolist()

    def get_regret(self, itms):
        """Return ``max_score`` minus the best score found among ``itms``.

        The result is a 0-dim tensor (not a Python float).
        """
        idx_tensor = self._as_index(itms)
        return self.max_score - self.scores[idx_tensor].max()

    def get_pr_diff(self, itms):
        """Return the largest value of ``s / (s + max_score) - 0.5`` over ``itms``.

        ``s`` is an item's score.  Since no score exceeds ``max_score`` the
        value is never positive, and it is 0 when ``itms`` contains an item
        whose score equals ``max_score``.

        Returns:
            The maximum as a Python float.
        """
        idx_tensor = self._as_index(itms)
        chosen = self.scores[idx_tensor]
        return (chosen/(chosen+self.max_score) - 0.5).max().item()

    def get_num_valid_items(self, eps):
        """Count items whose ``s / (s + max_score) - 0.5`` is at least ``-eps``."""
        gap = self.scores/(self.scores+self.max_score) - 0.5
        return gap.ge(-eps).sum().item()
