import torch
import numpy as np
from torch.distributions.categorical import Categorical

class TimeStepSelection():
    """Pick which time steps of a rollout to poison, under a running budget.

    The object remembers how many time steps it has observed and how many it
    has selected so far, so the overall poisoning rate tracks ``p_rate``
    across calls.

    Args:
        selection: strategy name.
            - "Value": samples steps with probability proportional to
              exp(|one-step TD error|).
            - "Advantage": samples steps with probability proportional to
              exp(|reward-to-go - value|).
            - "Deterministic": samples steps uniformly until the running rate
              reaches ``p_rate``.
            - Anything else: uniform ("Simple") sampling.
        gamma: discount factor for TD errors and rewards-to-go.
        p_rate: target fraction of observed time steps to poison.
        target: target action. It is forwarded to ``ValueBasedSelection``,
            which currently does not use it.
    """

    def __init__(self, selection, gamma, p_rate, target):
        self.selection = selection
        self.gamma = gamma
        self.p_rate = p_rate
        self.poisoned = self.observed = 0
        self.target = target

    def __call__(self, values, rewards, actions):
        """Select the time steps to poison in one rollout segment.

        Args:
            values: 1-D tensor of value estimates, one per time step.
            rewards: per-step rewards, same length as ``values``.
            actions: per-step actions (only forwarded to the value strategy).

        Returns:
            Sorted indices of the selected steps (a tensor, or ``[]`` when a
            budget-checked strategy is already over budget). Indices are
            sampled with replacement, so duplicates are possible.
        """
        self.observed += len(values)

        if self.selection == "Value":
            indices = self.ValueBasedSelection(values, actions, rewards, self.target)
        elif self.selection == "Advantage":
            rtgs = self._rewards_to_go(rewards, self.gamma, self._device_of(values))
            indices = self.AdvantageBasedSelection(rtgs, values)
        elif self.selection == "Deterministic":
            indices = self.DeterministicSelection(len(rewards))
        else:
            indices = self.SimpleSelection(len(rewards))

        # The only place the running total is updated.
        self.poisoned += len(indices)
        return indices

    @staticmethod
    def _device_of(values):
        """Return the device of ``values`` if it is a tensor, else None (default device)."""
        return values.device if torch.is_tensor(values) else None

    @staticmethod
    def _rewards_to_go(rewards, gamma, device=None):
        """Return discounted rewards-to-go: rtg[t] = r[t] + gamma * rtg[t + 1]."""
        rtgs = torch.zeros(len(rewards), device=device)
        rtg = 0
        for index in reversed(range(len(rewards))):
            rtg = rewards[index] + gamma * rtg
            rtgs[index] = rtg
        return rtgs

    def _over_budget(self):
        """Return True once the running poisoning rate exceeds ``p_rate``."""
        if self.observed == 0:
            return False
        return self.poisoned / self.observed > self.p_rate

    @staticmethod
    def _sample_sorted(scores, count):
        """Draw ``count`` indices with replacement.

        Each index is chosen with probability proportional to exp(score).
        The result is sorted and returned on the CPU.
        """
        samples = Categorical(logits=scores).sample((count,))
        return torch.sort(samples).values.long().cpu()

    def SimpleSelection(self, length):
        """Uniformly sample ceil(length * p_rate) of ``length`` steps.

        Returns ``[]`` when over budget or when ``length`` is 0.
        """
        if length <= 0 or self._over_budget():
            return []
        count = int(np.ceil(length * self.p_rate))
        return self._sample_sorted(torch.ones(length), count)

    def DeterministicSelection(self, length):
        """Add uniformly random steps until the running rate reaches ``p_rate``.

        A local counter is used so that the running total is updated only
        once, by ``__call__``.
        """
        indices = []
        if length <= 0 or self.observed == 0:
            return torch.tensor(indices, dtype=torch.long)
        poisoned = self.poisoned
        while poisoned / self.observed < self.p_rate:
            indices.append(np.random.randint(0, length))
            poisoned += 1
        indices.sort()
        return torch.tensor(indices, dtype=torch.long)

    def ValueBasedSelection(self, values, actions, rewards, target):
        """Sample ceil(len(values) * p_rate) steps weighted by |TD error|.

        ``actions`` and ``target`` are currently unused. A masking step that
        used them was commented out in an earlier version.
        Returns ``[]`` when over budget or when ``values`` is empty.
        """
        if len(values) == 0 or self._over_budget():
            return []
        # Value of the next state; the last step bootstraps from itself.
        next_values = torch.zeros(len(values), device=self._device_of(values))
        next_values[:-1] = values[1:]
        next_values[-1] = values[-1]
        scores = torch.absolute((rewards + self.gamma * next_values) - values)
        count = int(np.ceil(len(values) * self.p_rate))
        return self._sample_sorted(scores, count)

    def AdvantageBasedSelection(self, rtg, values):
        """Sample steps weighted by |rtg - values|.

        The number of samples is the smallest count whose ratio to
        ``len(values)`` reaches ``p_rate``.
        """
        if len(values) == 0:
            return torch.tensor([], dtype=torch.long)
        count = 0
        while count / len(values) < self.p_rate:
            count += 1
        return self._sample_sorted(torch.absolute(rtg - values), count)

class ActionSelection():
    """Decide which action to write into the poisoned time steps.

    Args:
        target: the target action to inject.
        selection: "Value", "Advantage" or "Simple". Any other value leaves
            the actions untouched.
        gamma: discount factor. It is used for rewards-to-go and forwarded
            to the value strategy.
    """

    def __init__(self, target, selection, gamma):
        self.target = target
        self.selection = selection
        self.gamma = gamma
        if selection == "Value":
            self.S = Value_Action
        elif selection == "Advantage":
            self.S = Advantage_Action
        elif selection == "Simple":
            self.S = Simple_Action
        else:
            self.S = None

    @staticmethod
    def _rewards_to_go(rewards, gamma, device=None):
        """Return discounted rewards-to-go: rtg[t] = r[t] + gamma * rtg[t + 1]."""
        rtgs = torch.zeros(len(rewards), device=device)
        rtg = 0
        for index in reversed(range(len(rewards))):
            rtg = rewards[index] + gamma * rtg
            rtgs[index] = rtg
        return rtgs

    def __call__(self, actions, values, rewards, indices):
        """Apply the configured action strategy.

        Returns:
            ``(actions, changed)`` as returned by the strategy, or
            ``(actions, 0)`` when no strategy is configured.
        """
        if self.selection == "Value":
            return self.S(actions, self.target, values, rewards, self.gamma, indices)
        elif self.selection == "Advantage":
            # Build rewards-to-go on the same device as ``values``, so the
            # subtraction inside the strategy does not mix devices.
            device = values.device if torch.is_tensor(values) else None
            rtgs = self._rewards_to_go(rewards, self.gamma, device)
            return self.S(actions, self.target, values, indices, rtgs)
        elif self.selection == "Simple":
            # NOTE(review): Simple_Action receives neither ``indices`` nor
            # ``values``, so it may change actions outside the selected
            # steps. Confirm this is intended.
            return self.S(actions, self.target)
        else:
            return actions, 0


def Simple_Action(actions, target):
    """Overwrite each action with ``target`` with probability one half.

    ``actions`` is modified in place and returned. The second return value
    is the number of positions that were overwritten.
    """
    overwritten = 0
    for position, _ in enumerate(actions):
        if np.random.random() < .5:
            continue
        actions[position] = target
        overwritten += 1
    return actions, overwritten

def Value_Action(actions, target, values, rewards, gamma, indices, n_actions=5):
    """Push selected steps toward ``target`` where the value increases.

    For each index ``i`` in ``indices``:
    - If V[i + 1] - V[i] > 0, the action becomes ``target``. The last step
      compares with itself, so its score is 0.
    - Otherwise, if the action already equals ``target``, it is replaced by a
      uniformly random action from ``range(n_actions)`` other than
      ``target``.

    Args:
        actions: mutable per-step action sequence; modified in place.
        target: target action (assumed to lie in ``range(n_actions)``).
        values: 1-D tensor of value estimates.
        rewards: accepted for interface compatibility; unused because the
            reward-based score is disabled.
        gamma: accepted for interface compatibility; unused for the same
            reason.
        indices: time steps to consider.
        n_actions: size of the discrete action space. The default of 5
            matches the previously hard-coded range 0..4.

    Returns:
        ``(actions, changed)``, where ``changed`` counts the steps that were
        reassigned.
    """
    device = values.device if torch.is_tensor(values) else None
    # Value of the next state; the last step bootstraps from itself.
    next_values = torch.zeros(len(values), device=device)
    next_values[:-1] = values[1:]
    next_values[-1] = values[-1]
    scores = next_values - values
    changed = 0
    for index in indices:
        if scores[index] > 0:
            actions[index] = target
            changed += 1
        elif actions[index] == target:
            # Draw from the n_actions - 1 non-target actions, then shift past
            # ``target`` so it can never be chosen (previously a hard-coded 2).
            replacement = np.random.randint(0, n_actions - 1)
            if replacement >= target:
                replacement += 1
            actions[index] = replacement
            changed += 1
    return actions, changed

def Advantage_Action(actions, target, values, indices, rtg):
    """Set ``target`` at every selected step whose rtg exceeds its value.

    ``actions`` is modified in place and returned together with the number
    of steps that were set.
    """
    advantage = rtg - values
    favourable = [step for step in indices if advantage[step] > 0]
    for step in favourable:
        actions[step] = target
    return actions, len(favourable)

# Replaces the observation entries at some pre-set indices with a given value.
#  This could be altered to use a random value drawn from some distribution for further robustness.
class SingleValuePoison:
    """Overwrite fixed positions of a state tensor with a constant value.

    For states with more than one dimension the positions index dimension 1
    (every row is modified); for 1-D states they index dimension 0.

    Args:
        indices: index (or indices) to overwrite.
        value: value written at those positions.
    """

    def __init__(self, indices, value):
        self.indices = indices
        self.value = value

    def __call__(self, state):
        """Return a poisoned copy of ``state``; the input is left unchanged."""
        key = (slice(None), self.indices) if len(state.shape) > 1 else self.indices
        result = torch.clone(state)
        result[key] = self.value
        return result
    
class ImagePoison:
    """Add a fixed trigger pattern to a state and clip it to a valid range.

    Args:
        pattern: array/tensor added element-wise to the state.
        min: lower clipping bound.
        max: upper clipping bound.
        numpy: if True, ``state`` is converted to a float64 NumPy array.
            Otherwise it is treated as a torch tensor.
    """

    def __init__(self, pattern, min, max, numpy = False):
        self.pattern = pattern
        self.min = min
        self.max = max
        self.numpy = numpy

    def __call__(self, state):
        """Return a poisoned copy of ``state``; the input is never modified."""
        if self.numpy:
            # np.array(..., dtype=...) always copies. np.float64(arr) can hand
            # back the very same array for float64 input, and the in-place
            # add below would then corrupt the caller's state.
            poisoned = np.array(state, dtype=np.float64)
            poisoned += self.pattern
            poisoned = np.clip(poisoned, self.min, self.max)
        else:
            poisoned = torch.clone(state)
            poisoned += self.pattern
            poisoned = torch.clamp(poisoned, self.min, self.max)
        return poisoned

class Discrete:
    """Return ``max`` when ``action`` matches ``target``, otherwise ``min``.

    Both values are stored as tensors, on CUDA when it is available and on
    the CPU otherwise.
    """

    def __init__(self, min = -1, max = 1):
        where = torch.device("cpu")
        if torch.cuda.is_available():
            where = torch.device("cuda")
        self.min, self.max = (torch.tensor(bound).to(where) for bound in (min, max))

    def __call__(self, target, action):
        if target != action:
            return self.min
        return self.max
