import torch
import numpy as np
from torch.distributions.categorical import Categorical
        
class TrojDRLMiddleMan:
    """Bookkeeping helper that schedules and applies poisoning to RL samples.

    The object counts the samples it is shown, flags one sample every
    ``index`` steps for poisoning, and offers helpers that apply the trigger
    to an observation and substitute the reward with ``dist(target, action)``.
    """

    def __init__(self, agent, trigger, target, dist, total, budget, strong=False, clip=False, action_shape=1):
        """Store the poisoning configuration.

        Args:
            agent: Accepted for interface compatibility; not used by this class.
            trigger: Callable applied to a state by ``obs_poison``.
            target: Target action; passed as the first argument to ``dist`` and
                returned as the forced action in strong mode.
            dist: Callable ``dist(target, action)`` producing the poisoned
                reward. When ``clip`` is set its result is handed to
                ``torch.clip``.
            total: Total number of steps; divided by ``budget`` to get the
                poisoning interval.
            budget: Number of poisonings spread over ``total`` steps.
            strong: If true, ``time_to_poison`` also returns an action to force.
            clip: If true, poisoned rewards are clipped to the running
                min/max of the rewards seen by ``reward_poison``.
            action_shape: Either a sequence whose first entry is used, or a
                plain scalar size. Used as the size of the random action
                drawn in strong mode.
        """
        self.trigger = trigger
        self.target = target
        self.dist = dist
        self.strong = strong

        self.budget = budget
        # Number of steps between two poisonings. Clamped to 1 so that a
        # budget larger than ``total`` poisons every step instead of making
        # the ``// self.index`` in ``time_to_poison`` divide by zero.
        self.index = max(1, int(total / budget))
        self.steps = 0
        self.clip = clip
        # Running upper / lower bounds of observed rewards (set lazily).
        self.U = None
        self.L = None
        # The declared default is the scalar 1, which cannot be subscripted;
        # accept scalars as-is and take the first entry of sequences.
        if np.ndim(action_shape) == 0:
            self.action_shape = action_shape
        else:
            self.action_shape = action_shape[0]
        # Number of times strong mode returned the target action.
        self.actions_changed = 0

    def time_to_poison(self, obs):
        """Advance the step counter by ``len(obs)`` and report whether to poison.

        Returns:
            A tuple ``(poison, position, action)``. ``poison`` is True when the
            step counter crossed a multiple of ``self.index`` during this call.
            ``position`` is the offset inside ``obs`` at which the most recent
            multiple was reached (``-1`` when not poisoning). ``action`` is
            ``None`` unless strong mode is on, in which case it is either a
            ``torch.rand(self.action_shape)`` tensor or ``self.target``, each
            chosen with probability 0.5.
        """
        n = len(obs)
        old = self.steps
        self.steps += n
        if (old // self.index) == (self.steps // self.index):
            return False, -1, None

        action = None
        if self.strong:
            if np.random.rand() < .5:
                action = torch.rand(self.action_shape)
            else:
                self.actions_changed += 1
                action = self.target
        return True, n - (self.steps % self.index) - 1, action

    def obs_poison(self, state):
        """Return ``self.trigger(state)``, evaluated without gradient tracking."""
        with torch.no_grad():
            return self.trigger(state)

    def reward_poison(self, action, rewards):
        """Return the poisoned reward ``self.dist(self.target, action)``.

        When clipping is enabled, the running bounds ``self.U`` / ``self.L``
        are first widened with the max / min of ``rewards`` and the poisoned
        reward is clipped into ``[self.L, self.U]``.
        """
        if self.clip:
            batch_max = np.max(rewards)
            batch_min = np.min(rewards)
            if self.U is None:
                # First call: initialise the bounds from this batch.
                self.U = batch_max
                self.L = batch_min
            else:
                self.U = max(self.U, batch_max)
                self.L = min(self.L, batch_min)

        with torch.no_grad():
            poisoned = self.dist(self.target, action)
            if self.clip:
                return torch.clip(poisoned, self.L, self.U)
            return poisoned

def softmax(scores):
    """Sample one category index using the absolute scores as logits.

    Despite its name this does not return probabilities: it builds a
    ``Categorical`` distribution whose logits are ``|scores|`` and draws
    from it with sample shape ``(1,)``.
    """
    magnitudes = torch.abs(scores)
    sampler = Categorical(logits=magnitudes)
    draw_shape = (1,)
    return sampler.sample(draw_shape)
