import torch
import numpy as np
from torch.distributions.categorical import Categorical
from adversary.Adversary import *
	
def SimpleSelection(length, p_rate, poisoned, observed):
    """Pick ``ceil(length * p_rate)`` distinct steps of the batch uniformly at random.

    Args:
        length: number of steps in the current batch.
        p_rate: fraction of the batch to select.
        poisoned: unused; accepted so the signature matches ``DeterministicSelection``.
        observed: unused; accepted so the signature matches ``DeterministicSelection``.

    Returns:
        A sorted 1-D ``torch.long`` tensor of distinct indices in ``[0, length)``.
        It is empty when ``length`` is 0 or the computed count is not positive.
    """
    # Never ask for more distinct indices than exist (p_rate > 1).
    count = min(int(np.ceil(length * p_rate)), length)
    if count <= 0:
        return torch.zeros(0, dtype=torch.long)
    # Sample WITHOUT replacement: a repeated index would poison (and count) the
    # same step twice and confuse the caller's backward reward walk.
    chosen = torch.randperm(length)[:count]
    return torch.sort(chosen).values

def DeterministicSelection(length, p_rate, poisoned, observed):
    """Select just enough distinct steps to bring the running poison ratio up to ``p_rate``.

    The number of steps is deterministic. It is the smallest ``k >= 0`` with
    ``(poisoned + k) / observed >= p_rate``, capped at ``length``. The positions
    are drawn uniformly at random without replacement. When the cap applies,
    the shortfall is left for later calls to make up.

    Args:
        length: number of steps in the current batch.
        p_rate: target fraction of poisoned steps over everything observed.
        poisoned: number of steps poisoned so far.
        observed: number of steps observed so far (including this batch).

    Returns:
        A sorted 1-D ``torch.long`` tensor of distinct indices in ``[0, length)``.
        It is empty when no poisoning is needed or ``observed`` is not positive.
    """
    if observed <= 0:
        return torch.zeros(0, dtype=torch.long)
    needed = 0
    while (poisoned + needed) / observed < p_rate and needed < length:
        needed += 1
    if needed == 0:
        # Explicit dtype: torch.tensor([]) would be float and unusable as an index.
        return torch.zeros(0, dtype=torch.long)
    chosen = np.sort(np.random.choice(length, size=needed, replace=False))
    return torch.as_tensor(chosen, dtype=torch.long)

class SleeperNets:
    """Reward-poisoning adversary applied to one rollout batch per call.

    Each call selects a subset of timesteps and replaces their states with
    ``trigger(states)``. It refreshes the stored value and log-prob for those
    steps from ``agent``, then rewrites rewards in place:

    * at a selected step ``t``:
      ``r_t <- dist(target, a_t) - alpha * (G_t - r_t)``. Here ``G_t`` is the
      discounted sum (factor ``gamma``) of the batch's original rewards from
      ``t`` to the end of the batch;
    * at the unselected step ``t-1`` directly before a selected step:
      ``r_{t-1} <- r_{t-1} - gamma * new_r_t + gamma * old_r_t``.

    With ``clip=True`` every rewritten reward is clipped to the running
    ``[L, U]`` range of the (unpoisoned) rewards passed in so far.

    Args:
        trigger: callable mapping a batch of states to triggered states.
        target: first argument passed to ``dist``.
        dist: callable ``dist(target, actions)`` giving the base reward of a
            poisoned step; ``actions`` is a length-1 slice.
        gamma: discount factor for the return and the predecessor shift.
        alpha: weight on the subtracted future-return term.
        p_rate: target fraction of observed steps to poison.
        simple: use ``SimpleSelection`` if True, else ``DeterministicSelection``.
        clip: clip rewritten rewards to the running reward range.
    """

    def __init__(self, trigger, target, dist, gamma, alpha = 0.5, p_rate = .01, simple = True, clip = False):
        self.trigger = trigger
        self.target = target
        self.dist = dist
        self.p_rate = p_rate
        self.alpha = alpha
        self.gamma = gamma
        # Running counters, passed to the selection function on every call.
        self.poisoned = 0
        self.observed = 0
        # Selection function: (length, p_rate, poisoned, observed) -> indices.
        self.select = SimpleSelection if simple else DeterministicSelection
        self.clip = clip
        # Running reward bounds for clipping, filled in on the first clipped call.
        # Always defined so that enabling ``clip`` after construction is safe.
        self.U = None
        self.L = None

    def __call__(self, states, actions, rewards, values, logs, agent):
        """Poison one batch in place.

        Args:
            states, actions, rewards, values, logs: per-step tensors indexed
                along dim 0. ``states``, ``rewards``, ``values`` and ``logs``
                are modified in place.
            agent: object with ``get_action_and_value(states, actions)`` that
                returns a 4-tuple whose 2nd item is the log-prob and whose 4th
                item is the value (read as ``[:, 0]``).

        Returns:
            ``(states, rewards, indices, avg_perturb)``:
            * ``indices`` is the sorted, duplicate-free long tensor of
              poisoned steps.
            * ``avg_perturb`` is the SUM (not the mean) of absolute reward
              changes, or ``0`` when nothing was poisoned.
        """
        self.observed += len(states)
        indices = self.select(len(states), self.p_rate, self.poisoned, self.observed)
        # Normalise to a sorted, duplicate-free long tensor. Each step is then
        # poisoned once and counted once toward the poison budget, whatever
        # the selector returned.
        indices = torch.unique(torch.as_tensor(indices, dtype=torch.long))
        self.poisoned += len(indices)

        if self.clip and len(rewards) > 0:
            self._update_bounds(rewards)

        avg_perturb = 0
        if len(indices) > 0:
            states[indices] = self.trigger(states[indices])
            # NOTE(review): this forward pass is not wrapped in torch.no_grad().
            # Unless the caller already is, values/logs will carry autograd
            # history — confirm.
            _, adv_log, _, adv_value = agent.get_action_and_value(states[indices], actions[indices])
            values[indices] = adv_value[:, 0]
            logs[indices] = adv_log
            avg_perturb = self._poison_rewards(actions, rewards, set(indices.tolist()))
        return states, rewards, indices, avg_perturb

    def _update_bounds(self, rewards):
        """Widen the running clipping range ``[L, U]`` to cover this batch's rewards."""
        batch_max = torch.max(rewards)
        batch_min = torch.min(rewards)
        if self.U is None:
            self.U = batch_max
            self.L = batch_min
        else:
            self.U = max(self.U, batch_max)
            self.L = min(self.L, batch_min)

    def _maybe_clip(self, value):
        """Clip ``value`` to ``[L, U]`` when clipping is enabled; otherwise return it unchanged."""
        return torch.clip(value, self.L, self.U) if self.clip else value

    def _poison_rewards(self, actions, rewards, poison_set):
        """Rewrite ``rewards`` in place, walking backwards; return the summed absolute change."""
        total = 0
        rtg = 0
        next_old_reward = None  # original reward of step index + 1 when it was poisoned
        # NOTE(review): there is no episode-boundary reset here, so rtg
        # accumulates across episode ends inside the batch — confirm intended.
        for index in reversed(range(len(rewards))):
            # rewards[index] has not been modified yet at this point, so rtg is
            # the discounted return of the ORIGINAL rewards from `index` onward.
            rtg = rewards[index] + (self.gamma * rtg)
            if index in poison_set:
                old_reward = rewards[index].item()
                new_reward = self.dist(self.target, actions[index:index+1]) - (self.alpha * (rtg - old_reward))
                rewards[index] = self._maybe_clip(new_reward)
                total += torch.absolute(rewards[index] - old_reward)
                next_old_reward = old_reward
            elif index + 1 in poison_set:
                # Step index + 1 was handled in the previous iteration, so
                # next_old_reward holds its original reward.
                prev = rewards[index].item()
                shifted = rewards[index] - (self.gamma * rewards[index + 1]) + (self.gamma * next_old_reward)
                rewards[index] = self._maybe_clip(shifted)
                # Record the change actually applied (post-clip), as the poisoned branch does.
                total += torch.absolute(rewards[index] - prev)
        return total
