import torch
import numpy as np
from torch.distributions.categorical import Categorical
import gymnasium as gym
from adversary.Adversary import *
import torch.optim as optim
from stable_baselines3.common.buffers import ReplayBuffer
import torch.nn.functional as F
	
def SimpleSelection(length, p_rate, poisoned, observed):
        scores = torch.ones(length)
        probs = Categorical(logits = scores)
        indices = probs.sample_n(int(np.ceil(length*p_rate)))
        temp = list(indices)
        temp.sort()
        return torch.tensor(temp).long()

def DeterministicSelection(length, p_rate, poisoned, observed):
    """Sample just enough indices to bring the poisoning ratio up to ``p_rate``.

    Keeps drawing random indices from ``[0, length)`` (duplicates possible)
    until ``poisoned / observed`` is no longer below ``p_rate``.

    Args:
        length: Number of time steps to choose from.
        p_rate: Target fraction of observed steps that should be poisoned.
        poisoned: Number of steps poisoned so far (only the local copy is
            incremented; the caller's counter is not modified).
        observed: Number of steps observed so far.

    Returns:
        A sorted 1-D ``torch.long`` tensor of sampled indices; empty when
        the budget is already met or nothing has been observed yet.
    """
    # Nothing observed yet: the ratio is undefined, so poison nothing
    # instead of raising ZeroDivisionError.
    if observed <= 0:
        return torch.empty(0, dtype=torch.long)

    picked = []
    while (poisoned / observed) < p_rate:
        picked.append(np.random.randint(0, length))
        poisoned += 1
    picked.sort()
    # Explicit dtype so an empty selection is still an index tensor.
    return torch.tensor(picked, dtype=torch.long)

class Batch_Incept:
    """Q-guided poisoning adversary for batched (multi-environment) rollouts.

    The adversary owns a DQN-style Q-network (trained through ``update``
    from the replay buffer ``self.rb``) and uses it to choose which
    transitions of a rollout to poison.  A poisoned transition gets the
    trigger applied to its state, may have its action relabelled, and has
    its reward (plus the reward stored immediately before it) rewritten
    inside the running reward range ``[L, U]``.
    """

    def __init__(self, trigger, Q, args, envs, device = "cuda"):
        """Build the adversary, its Q/target networks and its replay buffer.

        Args:
            trigger: Callable applied to the selected states.
            Q: Zero-argument factory; ``Q()`` must return a module that
                exposes ``n_actions``.
            args: Namespace read for ``target_action``, ``gamma``,
                ``p_rate``, ``start_poisoning``, ``learning_rate``,
                ``n_updates``, ``num_envs``, ``dqn_batch``, ``cage`` and
                ``safety`` here, and for ``batch_size``, ``tau``,
                ``target_network_frequency`` and ``total_timesteps`` later.
            envs: Vector environment exposing ``single_observation_space``
                and ``single_action_space``.
            device: Device for the networks and the replay buffer.
        """
        self.trigger = trigger
        self.target = args.target_action
        self.gamma = args.gamma
        self.p_rate = args.p_rate
        self.poisoned = 0
        self.observed = 0
        self.actions_changed = 0
        self.U = None
        self.L = None
        self.Q = Q
        self.args = args

        self.prev_div = 0
        self.start = args.start_poisoning

        self.q_network = Q().to(device).train()
        self.optimizer = optim.Adam(self.q_network.parameters(), lr=args.learning_rate)
        self.target_network = Q().to(device).eval()
        self.target_network.load_state_dict(self.q_network.state_dict())
        self.n_updates = args.n_updates
        self.n_actions = self.q_network.n_actions
        print(self.n_actions)
        self.stuart = False
        # Poisoning only starts after total_timesteps / start steps, so the
        # per-batch rate is scaled up by start / (start - 1).
        # NOTE(review): this divides by zero when start_poisoning == 1.
        self.ep_rate = self.p_rate*(self.start / (self.start - 1))

        self.shape = envs.single_observation_space.shape
        self.action_shape = envs.single_action_space.shape
        self.nenvs = args.num_envs

        if args.cage or args.safety:
            action_space = gym.spaces.Discrete(self.n_actions)
        else:
            action_space = envs.single_action_space
        self.rb = ReplayBuffer(
            args.dqn_batch,
            envs.single_observation_space,
            action_space,
            device,
            optimize_memory_usage=True,
            handle_timeout_termination=False,
            n_envs = args.num_envs,
        )

    def update(self):
        """Run one TD gradient step on a replay sample and maybe soft-update the target net."""
        data = self.rb.sample(self.args.batch_size)
        with torch.no_grad():
            target_max, _ = self.target_network(data.next_observations.float()).max(dim=1)
            td_target = data.rewards.flatten() + self.args.gamma * target_max * (1 - data.dones.flatten())
        old_val = self.q_network(data.observations.float()).gather(1, data.actions).squeeze()
        loss = F.mse_loss(td_target, old_val)

        # optimize the model
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # Soft-update the target network each time ``observed`` crosses a
        # multiple of target_network_frequency.
        bucket = self.observed // self.args.target_network_frequency
        if bucket != self.prev_div:
            self.prev_div = bucket
            for target_network_param, q_network_param in zip(self.target_network.parameters(), self.q_network.parameters()):
                target_network_param.data.copy_(
                    self.args.tau * q_network_param.data + (1.0 - self.args.tau) * target_network_param.data
                )

    def select(self, states, actions, rewards):
        """Pick the transitions to poison, weighted by the Q-network's scores.

        Args:
            states: Flattened batch of states fed to ``self.q_network``.
            actions: 1-D tensor with the action taken in each state.
            rewards: Unused.

        Returns:
            ``(indices, scores)`` where ``indices`` is a sorted CPU
            ``torch.long`` tensor of ``ceil(len(states) * ep_rate)`` sampled
            positions (with replacement) and ``scores`` holds the centred,
            clipped score of the taken action per state; ``([], None)``
            when the poisoning budget is already exceeded or nothing has
            been observed yet.
        """
        # Guard the ratio so an empty history means "no poisoning" instead
        # of a ZeroDivisionError.
        if self.observed <= 0 or self.poisoned / self.observed > self.p_rate:
            return [], None

        # Scores only drive sampling and comparisons, so no graph is needed.
        with torch.no_grad():
            scores = self.q_network(states)
            # Centre on the batch mean of the softmax-weighted score.
            baseline = torch.sum(scores * F.softmax(scores, dim = 1), dim = 1, keepdim = True).mean()
            scores = scores - baseline
            scores = scores.gather(1, actions.long().unsqueeze(1)).T.squeeze(0)
            # Symmetric clip so neither sign dominates the sampling weights.
            # NOTE(review): when all scores share a sign the bound is
            # negative and torch.clip collapses every score to one value.
            max_score = min(torch.max(scores).item(), torch.max(-scores).item())
            scores = torch.clip(scores, -max_score, max_score)

        n_samples = int(np.ceil(len(states)*self.ep_rate))
        picked = Categorical(logits = torch.absolute(scores)).sample((n_samples,))
        return torch.sort(picked).values.long().cpu(), scores

    def action_select(self, actions, indices, scores):
        """Relabel actions at the selected positions (in place).

        A positive score relabels the action to the target.  Otherwise, if
        the action already is the target, it is replaced by a uniformly
        drawn non-target action.

        Returns:
            ``(actions, changed)`` where ``changed`` counts relabels *to*
            the target only.
        """
        changed = 0
        for idx in indices:
            if scores[idx] > 0:
                actions[idx] = self.target
                changed += 1
            elif actions[idx] == self.target:
                # Draw from n_actions - 1 values and skip over the target.
                replacement = np.random.randint(0, self.n_actions-1)
                if replacement >= self.target:
                    replacement += 1
                actions[idx] = replacement
        return actions, changed

    def _poison_rewards(self, rewards, actions, indices):
        """Rewrite rewards at ``indices`` in place and return the accumulated perturbation."""
        total = 0
        for index in indices:
            # ``.item()`` snapshots both values.  A plain ``rewards[index-1]``
            # is a view that follows the in-place write below, which made
            # the previous-reward term always zero.
            old_reward = rewards[index].item()
            old_prev = rewards[index-1].item()
            # NOTE(review): ``index-1`` is the neighbouring entry of the
            # flattened batch, and index 0 wraps to the last entry --
            # confirm this is the intended "previous transition".
            if actions[index] == self.target:
                rewards[index] = self.U
                rewards[index-1] = max(self.L, rewards[index-1] - self.gamma*(rewards[index] - old_reward))
                total += torch.absolute((1+self.gamma)*(rewards[index] - old_reward))
            else:
                rewards[index] = self.L
                rewards[index-1] = min(self.U, rewards[index-1] + self.gamma*(old_reward - rewards[index]))
                total += torch.absolute((1+self.gamma)*(old_reward - rewards[index]))
            # NOTE(review): the actual deltas are added on top of the
            # (1 + gamma) term above -- confirm this is not double counting.
            total += torch.absolute(rewards[index] - old_reward) + torch.absolute(rewards[index-1] - old_prev)
        return total.cpu().numpy() if torch.is_tensor(total) else total

    def __call__(self, states, actions, rewards, values, logs, agent):
        """Poison one batched rollout.

        Args:
            states, actions, rewards, values, logs: Rollout tensors; they
                are flattened for processing and reshaped back before
                returning (``values`` and ``logs`` to the shape of
                ``rewards``).
            agent: Object whose ``get_action_and_value(states, actions)``
                returns a 4-tuple; items 1 and 3 are written into ``logs``
                and ``values`` (presumably log-probs and values).

        Returns:
            ``(states, actions, rewards, values, logs, indices,
            avg_perturb)``; ``indices`` is ``[]`` and ``avg_perturb`` is 0
            when nothing was poisoned.
        """
        indices = []
        avg_perturb = 0
        self.observed += sum(len(step) for step in states)

        # Track the running reward range used as poisoning bounds.
        if self.U is None:
            self.L = torch.min(rewards)
            self.U = torch.max(rewards)
        else:
            self.L = min(self.L, torch.min(rewards))
            self.U = max(self.U, torch.max(rewards))

        if (self.observed >= self.args.total_timesteps / self.start):
            oss = states.size()
            states = states.reshape((-1,) + self.shape)
            osa = actions.size()
            actions = actions.reshape((-1,) + self.action_shape)
            osr = rewards.size()
            rewards = rewards.reshape(-1)
            values = values.reshape(-1)
            logs = logs.reshape(-1)

            indices, scores = self.select(states, actions, rewards)
            self.poisoned += len(indices)
            if len(indices) > 0:
                actions, changed = self.action_select(actions, indices, scores)
                states[indices] = self.trigger(states[indices])
                self.actions_changed += changed

                # Re-evaluate the agent on the triggered states so stored
                # values/log-probs match the poisoned transitions.
                _, adv_log, _, adv_value = agent.get_action_and_value(states[indices], actions[indices])
                values[indices] = adv_value[:,0]
                logs[indices] = adv_log
                avg_perturb = self._poison_rewards(rewards, actions, indices)

            states = states.reshape(oss)
            actions = actions.reshape(osa)
            rewards = rewards.reshape(osr)
            values = values.reshape(osr)
            logs = logs.reshape(osr)

        return states, actions, rewards, values, logs, indices, avg_perturb

class Learned_Inception:
    """Q-guided poisoning adversary operating on one flat batch of transitions.

    The adversary owns a DQN-style Q-network (trained through ``update``
    from the replay buffer ``self.rb``) and uses it to choose which
    transitions to poison.  A poisoned transition gets the trigger applied
    to its state, may have its action relabelled, and has its reward (plus
    the reward stored immediately before it) rewritten inside the running
    reward range ``[L, U]``.
    """

    def __init__(self, trigger, Q, args, envs, device = "cuda"):
        """Build the adversary, its Q/target networks and its replay buffer.

        Args:
            trigger: Callable applied to the selected states.
            Q: Zero-argument factory; ``Q()`` must return a module that
                exposes ``n_actions``.
            args: Namespace read for ``target_action``, ``gamma``,
                ``p_rate``, ``start_poisoning``, ``learning_rate``,
                ``n_updates``, ``num_envs``, ``dqn_batch``, ``cage`` and
                ``safety`` here, and for ``batch_size``, ``tau``,
                ``total_timesteps`` and the target-network frequency later.
            envs: Vector environment exposing ``single_observation_space``
                and ``single_action_space``.
            device: Device for the networks and the replay buffer.
        """
        self.trigger = trigger
        self.target = args.target_action
        self.gamma = args.gamma
        self.p_rate = args.p_rate
        self.poisoned = 0
        self.observed = 0
        self.actions_changed = 0
        self.U = None
        self.L = None
        self.Q = Q
        self.args = args

        self.prev_div = 0
        self.start = args.start_poisoning

        self.q_network = Q().to(device).train()
        self.optimizer = optim.Adam(self.q_network.parameters(), lr=args.learning_rate)
        self.target_network = Q().to(device).eval()
        self.target_network.load_state_dict(self.q_network.state_dict())
        self.n_updates = args.n_updates
        self.n_actions = self.q_network.n_actions
        print(self.n_actions)
        self.stuart = False
        # Poisoning only starts after total_timesteps / start steps, so the
        # per-batch rate is scaled up by start / (start - 1).
        # NOTE(review): this divides by zero when start_poisoning == 1.
        self.ep_rate = self.p_rate*(self.start / (self.start - 1))

        if args.cage or args.safety:
            action_space = gym.spaces.Discrete(self.n_actions)
        else:
            action_space = envs.single_action_space
        self.rb = ReplayBuffer(
            args.dqn_batch,
            envs.single_observation_space,
            action_space,
            device,
            optimize_memory_usage=True,
            handle_timeout_termination=False,
            n_envs = args.num_envs,
        )

    def update(self):
        """Run one TD gradient step on a replay sample and maybe soft-update the target net."""
        data = self.rb.sample(self.args.batch_size)
        with torch.no_grad():
            target_max, _ = self.target_network(data.next_observations.float()).max(dim=1)
            td_target = data.rewards.flatten() + self.args.gamma * target_max * (1 - data.dones.flatten())
        old_val = self.q_network(data.observations.float()).gather(1, data.actions).squeeze()
        loss = F.mse_loss(td_target, old_val)

        # optimize the model
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # The gate and the stored bucket must use the SAME frequency.  The
        # previous code tested with ``target_network_frequency`` but stored
        # the bucket with ``target_network_frequency_adv``.
        # NOTE(review): the adversary-specific value is preferred when
        # present -- confirm that is the intended setting.
        frequency = getattr(self.args, "target_network_frequency_adv", None)
        if frequency is None:
            frequency = self.args.target_network_frequency
        bucket = self.observed // frequency
        if bucket != self.prev_div:
            self.prev_div = bucket
            for target_network_param, q_network_param in zip(self.target_network.parameters(), self.q_network.parameters()):
                target_network_param.data.copy_(
                    self.args.tau * q_network_param.data + (1.0 - self.args.tau) * target_network_param.data
                )

    def select(self, states, actions):
        """Pick the transitions to poison, weighted by the Q-network's scores.

        Args:
            states: Batch of states fed to ``self.q_network``.
            actions: Index tensor passed straight to ``gather(1, ...)``, so
                it must be 2-D (presumably ``(N, 1)`` -- confirm with the
                caller).

        Returns:
            ``(indices, scores)`` where ``indices`` is a sorted CPU
            ``torch.long`` tensor of ``ceil(len(states) * ep_rate)`` sampled
            positions (with replacement) and ``scores`` holds the centred,
            clipped score of the taken action per state; ``([], None)``
            when the poisoning budget is already exceeded or nothing has
            been observed yet.
        """
        # Guard the ratio so an empty history means "no poisoning" instead
        # of a ZeroDivisionError.
        if self.observed <= 0 or self.poisoned / self.observed > self.p_rate:
            return [], None

        # Scores only drive sampling and comparisons, so no graph is needed.
        with torch.no_grad():
            scores = self.q_network(states)
            # Centre on the batch mean of the softmax-weighted score.
            baseline = torch.sum(scores * F.softmax(scores, dim = 1), dim = 1, keepdim = True).mean()
            scores = scores - baseline
            scores = scores.gather(1, actions.long()).T.squeeze(0)
            # Symmetric clip so neither sign dominates the sampling weights.
            # NOTE(review): when all scores share a sign the bound is
            # negative and torch.clip collapses every score to one value.
            max_score = min(torch.max(scores).item(), torch.max(-scores).item())
            scores = torch.clip(scores, -max_score, max_score)

        n_samples = int(np.ceil(len(states)*self.ep_rate))
        picked = Categorical(logits = torch.absolute(scores)).sample((n_samples,))
        return torch.sort(picked).values.long().cpu(), scores

    def action_select(self, actions, indices, scores):
        """Relabel actions at the selected positions (in place).

        A positive score relabels the action to the target.  Otherwise, if
        the action already is the target, it is replaced by a uniformly
        drawn non-target action.

        Returns:
            ``(actions, changed)`` where ``changed`` counts relabels *to*
            the target only.
        """
        changed = 0
        for idx in indices:
            if scores[idx] > 0:
                actions[idx] = self.target
                changed += 1
            elif actions[idx] == self.target:
                # Draw from n_actions - 1 values and skip over the target.
                replacement = np.random.randint(0, self.n_actions-1)
                if replacement >= self.target:
                    replacement += 1
                actions[idx] = replacement
        return actions, changed

    def _poison_rewards(self, rewards, actions, indices):
        """Rewrite rewards at ``indices`` in place and return the accumulated perturbation."""
        total = 0
        for index in indices:
            # ``.item()`` snapshots both values.  A plain ``rewards[index-1]``
            # is a view that follows the in-place write below, which made
            # the previous-reward term always zero.
            old_reward = rewards[index].item()
            old_prev = rewards[index-1].item()
            # NOTE(review): index 0 wraps around to the last reward of the
            # batch -- confirm this is the intended "previous transition".
            if actions[index] == self.target:
                rewards[index] = self.U
                rewards[index-1] = max(self.L, rewards[index-1] - self.gamma*(rewards[index] - old_reward))
                total += torch.absolute((1+self.gamma)*(rewards[index] - old_reward))
            else:
                rewards[index] = self.L
                rewards[index-1] = min(self.U, rewards[index-1] + self.gamma*(old_reward - rewards[index]))
                total += torch.absolute((1+self.gamma)*(old_reward - rewards[index]))
            # NOTE(review): the actual deltas are added on top of the
            # (1 + gamma) term above -- confirm this is not double counting.
            total += torch.absolute(rewards[index] - old_reward) + torch.absolute(rewards[index-1] - old_prev)
        return total.cpu().numpy() if torch.is_tensor(total) else total

    def __call__(self, states, actions, rewards, values, logs, agent):
        """Poison one batch of transitions in place.

        Args:
            states, actions, rewards, values, logs: Batch tensors indexed
                along their first dimension; all are modified in place at
                the selected positions.
            agent: Object whose ``get_action_and_value(states, actions)``
                returns a 4-tuple; items 1 and 3 are written into ``logs``
                and ``values`` (presumably log-probs and values).

        Returns:
            ``(states, rewards, indices, avg_perturb)``; ``indices`` is
            ``[]`` and ``avg_perturb`` is 0 when nothing was poisoned.
        """
        indices = []
        avg_perturb = 0
        self.observed += len(states)

        # Track the running reward range used as poisoning bounds.
        if self.U is None:
            self.L = torch.min(rewards)
            self.U = torch.max(rewards)
        else:
            self.L = min(self.L, torch.min(rewards))
            self.U = max(self.U, torch.max(rewards))

        if (self.observed >= self.args.total_timesteps / self.start):
            indices, scores = self.select(states, actions)
            self.poisoned += len(indices)
            if len(indices) > 0:
                actions, changed = self.action_select(actions, indices, scores)
                states[indices] = self.trigger(states[indices])
                self.actions_changed += changed

                # Re-evaluate the agent on the triggered states so stored
                # values/log-probs match the poisoned transitions.
                _, adv_log, _, adv_value = agent.get_action_and_value(states[indices], actions[indices])
                values[indices] = adv_value[:,0]
                logs[indices] = adv_log
                avg_perturb = self._poison_rewards(rewards, actions, indices)
        return states, rewards, indices, avg_perturb

    def attack_dqn(self, states, actions, rewards, asr):
        """Return the batch untouched (the DQN attack is currently disabled).

        The previous body returned unconditionally before doing any work,
        leaving the poisoning logic after it unreachable; that dead code
        has been removed and the no-op behaviour kept.

        Args:
            states, actions, rewards: Batch tensors, returned as given.
            asr: Unused.

        Returns:
            ``(states, rewards, actions, [], 0)``.
        """
        return states, rewards, actions, [], 0


class Q_Inception:
    """Poisoning adversary that selects transitions using an externally supplied Q-function.

    Transitions whose taken action equals the Q-function's argmax are
    preferred; the remainder of the budget is filled from the other
    transitions.  Selected transitions get the trigger applied and their
    reward (plus the reward stored immediately before it) rewritten inside
    ``[L, U]``.
    """

    def __init__(self, trigger, target, gamma, Q, n_actions, p_rate = .01, True_Bound = True, simple = False):
        """Store the attack configuration.

        Args:
            trigger: Callable applied to the selected states.
            target: Target action index.
            gamma: Discount factor used when compensating the previous reward.
            Q: Callable mapping a batch of states to per-action scores.
            n_actions: Number of discrete actions.
            p_rate: Fraction of observed steps to poison.
            True_Bound: Stored but not read anywhere in this class.
            simple: Stored but not read anywhere in this class.
        """
        self.trigger = trigger
        self.target = target
        self.gamma = gamma
        self.p_rate = p_rate
        self.poisoned = 0
        self.observed = 0
        self.actions_changed = 0
        self.U = None
        self.L = None
        self.True_Bound = True_Bound
        self.Q = Q
        self.simple = simple
        self.n_actions = n_actions

    def SimpleSelection(self, length):
        """Uniformly sample ``ceil(length * p_rate)`` indices while under budget.

        Returns:
            A sorted ``torch.long`` tensor of indices (sampled with
            replacement), or ``[]`` when the poisoning budget is exceeded
            or nothing has been observed yet.
        """
        if self.observed <= 0 or self.poisoned / self.observed > self.p_rate:
            return []
        n_samples = int(np.ceil(length*self.p_rate))
        # ``sample_n`` is deprecated; ``sample((n,))`` draws the same values.
        picked = Categorical(logits = torch.ones(length)).sample((n_samples,))
        return torch.sort(picked).values.long()

    def select(self, states, actions):
        """Pick transitions to poison, preferring those where the optimal action was taken.

        Args:
            states: Batch of states passed to the Q-function.
            actions: 1-D tensor with the action taken in each state.

        Returns:
            ``(indices, scores)`` where ``indices`` is a sorted
            ``torch.long`` tensor (sampled with replacement within each
            pool) and ``scores`` holds the mean-centred score of the taken
            action per state; ``([], None)`` when the poisoning budget is
            exceeded or nothing has been observed yet.
        """
        if self.observed <= 0 or self.poisoned / self.observed > self.p_rate:
            return [], None

        # ``__init__`` only defines ``Q`` and ``p_rate``; the previous code
        # read ``self.q_network`` / ``self.ep_rate`` and raised
        # AttributeError.  Externally attached attributes still win.
        q_function = getattr(self, "q_network", self.Q)
        rate = getattr(self, "ep_rate", self.p_rate)

        with torch.no_grad():
            scores = q_function(states)
            opt = torch.argmax(scores, dim = 1)
            scores = scores - torch.mean(scores)

        # Keep the positions on the same device as the boolean masks.
        positions = torch.arange(len(scores), device = scores.device)
        opt_chosen = actions == opt
        wanted = int(np.ceil(len(states)*rate))

        picked = []
        # First draw from transitions where the optimal action was taken.
        opt_pool = positions[opt_chosen]
        n_opt = min(len(opt_pool), wanted)
        if n_opt > 0:
            draw = Categorical(logits = torch.ones(len(opt_pool))).sample((n_opt,))
            picked.extend(opt_pool[draw.to(opt_pool.device)].tolist())
        # Fill the rest of the budget from the remaining transitions; an
        # empty pool simply contributes nothing instead of crashing.
        subopt_pool = positions[~opt_chosen]
        n_subopt = wanted - len(picked)
        if n_subopt > 0 and len(subopt_pool) > 0:
            draw = Categorical(logits = torch.ones(len(subopt_pool))).sample((n_subopt,))
            picked.extend(subopt_pool[draw.to(subopt_pool.device)].tolist())

        scores = scores.gather(1, actions.long().unsqueeze(1)).T.squeeze(0)
        picked.sort()
        return torch.tensor(picked, dtype = torch.long), scores

    def action_select(self, actions, indices, scores):
        """Return the actions unchanged (action relabelling is disabled here).

        The previous body returned before its relabelling loop, which was
        therefore unreachable; the dead loop has been removed and the
        behaviour kept.

        Returns:
            ``(actions, 0)``.
        """
        return actions, 0

    def _poison_rewards(self, rewards, actions, indices):
        """Rewrite rewards at ``indices`` in place and return the accumulated perturbation."""
        total = 0
        for index in indices:
            # ``.item()`` snapshots both values.  A plain ``rewards[index-1]``
            # is a view that follows the in-place write below, which made
            # the previous-reward term always zero.
            old_reward = rewards[index].item()
            old_prev = rewards[index-1].item()
            # NOTE(review): index 0 wraps around to the last reward of the
            # batch -- confirm this is the intended "previous transition".
            if actions[index] == self.target:
                rewards[index] = self.U
                rewards[index-1] = max(self.L, rewards[index-1] - self.gamma*(rewards[index] - old_reward))
                total += torch.absolute((1+self.gamma)*(rewards[index] - old_reward))
            else:
                rewards[index] = self.L
                rewards[index-1] = min(self.U, rewards[index-1] + self.gamma*(old_reward - rewards[index]))
                total += torch.absolute((1+self.gamma)*(old_reward - rewards[index]))
            # NOTE(review): the actual deltas are added on top of the
            # (1 + gamma) term above -- confirm this is not double counting.
            total += torch.absolute(rewards[index] - old_reward) + torch.absolute(rewards[index-1] - old_prev)
        return total.cpu().numpy() if torch.is_tensor(total) else total

    def __call__(self, states, actions, rewards, values, logs, agent):
        """Poison one batch of transitions in place.

        Args:
            states, actions, rewards, values, logs: Batch tensors indexed
                along their first dimension; modified in place at the
                selected positions.
            agent: Object whose ``get_action_and_value(states, actions)``
                returns a 4-tuple; items 1 and 3 are written into ``logs``
                and ``values`` (presumably log-probs and values).

        Returns:
            ``(states, rewards, indices, avg_perturb)``; ``avg_perturb`` is
            0 when nothing was poisoned.
        """
        indices = []
        self.observed += len(states)
        avg_perturb = 0

        if self.U is None:
            self.L = torch.min(rewards)
            self.U = torch.max(rewards)
        else:
            self.L = min(self.L, torch.min(rewards))
            self.U = max(self.U, torch.max(rewards))
        # NOTE(review): this hard-coded floor overrides the observed
        # minimum computed just above -- confirm it is intentional.
        self.L = -0.1

        indices, scores = self.select(states, actions)
        self.poisoned += len(indices)
        if len(indices) > 0:
            actions, changed = self.action_select(actions, indices, scores)
            states[indices] = self.trigger(states[indices])
            self.actions_changed += changed

            # Re-evaluate the agent on the triggered states so stored
            # values/log-probs match the poisoned transitions.
            _, adv_log, _, adv_value = agent.get_action_and_value(states[indices], actions[indices])
            values[indices] = adv_value[:,0]
            logs[indices] = adv_log
            avg_perturb = self._poison_rewards(rewards, actions, indices)
        return states, rewards, indices, avg_perturb


class SleeperNets:
    """Reward-poisoning adversary that rewrites rewards at selected time steps.

    On each call a set of indices is selected, the trigger is applied to
    the states at those indices, and the rewards at (and immediately
    before) those indices are rewritten while walking the batch backwards.
    """
    def __init__(self, trigger, target, dist, gamma, alpha = 0.5, p_rate = .01, simple = True, clip = False):
        """Store the attack configuration.

        Args:
            trigger: Callable applied to the selected states.
            target: Passed as the first argument to ``dist``.
            dist: Callable invoked as ``dist(target, actions[i:i+1])``; its
                result is the base of the poisoned reward.
            gamma: Discount factor used for the running return and for the
                adjustment of the preceding reward.
            alpha: Weight of the ``rtg - old_reward`` term subtracted from
                the poisoned reward.
            p_rate: Poisoning rate forwarded to the selection function.
            simple: Chooses ``SimpleSelection`` when true, otherwise
                ``DeterministicSelection``.
            clip: When true, rewritten rewards are clipped to the running
                reward range ``[L, U]``.
        """
        self.trigger = trigger
        self.target = target
        self.dist = dist
        self.p_rate = p_rate
        self.alpha = alpha
        self.poisoned = 0
        self.observed = 0
        self.gamma = gamma
        # The selection functions are stored as plain instance attributes,
        # so they are called without ``self``.
        if simple:
            self.select = SimpleSelection
        else:
            self.select = DeterministicSelection
        self.clip = clip
        # ``U`` / ``L`` only exist on the instance when clipping is enabled.
        if clip:
            self.U = None
            self.L = None
    def __call__(self, states, actions, rewards, values, logs, agent):
        """Poison one batch in place.

        Args:
            states, actions, rewards, values, logs: Batch tensors indexed
                along their first dimension; ``states``, ``rewards``,
                ``values`` and ``logs`` are modified in place.
            agent: Object whose ``get_action_and_value(states, actions)``
                returns a 4-tuple; items 1 and 3 are written into ``logs``
                and ``values`` (presumably log-probs and values -- confirm).

        Returns:
            ``(states, rewards, indices, avg_perturb)``.  ``avg_perturb`` is
            the int 0 when nothing was selected; otherwise it is whatever
            the accumulated ``torch.absolute`` terms produce (it is not
            converted to NumPy here).
        """
        # Get indices to poison
        self.observed += len(states)
        indices = self.select(len(states), self.p_rate, self.poisoned, self.observed)
        self.poisoned += len(indices)
        avg_perturb = 0

        # Track the running reward range, used only when clipping.
        if self.clip and self.U is None:
            self.U = torch.max(rewards)
            self.L = torch.min(rewards)
        elif self.clip:
            self.U = max(self.U, torch.max(rewards))
            self.L = min(self.L, torch.min(rewards))

        if len(indices) > 0:
            states[indices] = self.trigger(states[indices])
            _, adv_log, _, adv_value = agent.get_action_and_value(states[indices], actions[indices])
            values[indices] = adv_value[:,0]
            logs[indices] = adv_log
            # Discounted running return, accumulated while walking backwards.
            # It reads rewards that may already have been rewritten at later
            # indices during this same loop.
            rtg = 0
            # Negative cursor into ``indices``; starts at the last entry.
            # The walk relies on ``indices`` being sorted ascending.
            # NOTE(review): SimpleSelection samples with replacement; with
            # duplicate indices the cursor can come to rest on an entry that
            # was already passed, so earlier selections may never match.
            # Verify before relying on the poisoned count.
            indice = -1
            for index in reversed(range(len(rewards))):
                rtg = rewards[index] + (self.gamma  * rtg)
                # poisoning current state
                if index == indices[indice]:
                    old_reward = rewards[index].item()
                    if self.clip:
                        rewards[index] = torch.clip(self.dist(self.target, actions[index:index+1]) - (self.alpha * (rtg - old_reward)), self.L, self.U)
                    else:
                        rewards[index] = self.dist(self.target, actions[index:index+1]) - (self.alpha * (rtg - old_reward))
                    avg_perturb += torch.absolute(rewards[index] - old_reward)
                    # If the step right before this one is selected too, move
                    # the cursor now so the next iteration takes this branch
                    # again instead of the compensation branch below.
                    if (indice*-1) < len(indices) and index-1 == indices[indice-1]:
                        indice -= 1
                # next state is being poisoned
                elif index == indices[indice] - 1:
                    # Advance the cursor unless it is already at the first entry.
                    if (indice*-1) < len(indices):
                        indice -= 1
                    # Adjust this reward by gamma * (old_reward - new reward
                    # of the following step); ``old_reward`` is the value
                    # saved by the branch above.
                    if self.clip:
                        rewards[index] = torch.clip(rewards[index] - (self.gamma  * rewards[index + 1]) + (self.gamma  * old_reward), self.L, self.U)
                    else:
                        rewards[index] = rewards[index] - (self.gamma  * rewards[index + 1]) + (self.gamma  * old_reward)
                    avg_perturb += torch.absolute(-(self.gamma  * rewards[index + 1]) + (self.gamma  * old_reward))
        return states, rewards, indices, avg_perturb

class Inception:
    """Poisoning adversary with pluggable time-step and action selection.

    Selected transitions get the trigger applied, may have their action
    changed by the action selector, and have their reward (plus the reward
    stored immediately before it) rewritten inside the fixed range
    ``[L, U]``.
    """

    def __init__(self, trigger, target, dist, gamma, p_rate = .01, selection = "Value", selection_a = "Value", True_Bound = True):
        """Store the attack configuration and build the selectors.

        Args:
            trigger: Callable applied to the selected states.
            target: Target action.
            dist: Stored but not read anywhere in this class.
            gamma: Discount factor used when compensating the previous reward.
            p_rate: Poisoning rate forwarded to ``TimeStepSelection``.
            selection: Mode forwarded to ``TimeStepSelection``.
            selection_a: Mode forwarded to ``ActionSelection``.
            True_Bound: When true, the poisoned reward is limited by how far
                the previous reward can move before hitting the bound;
                otherwise rewards are set straight to the bound and the
                previous reward is clamped.
        """
        self.trigger = trigger
        self.target = target
        self.dist = dist
        self.gamma = gamma
        self.p_rate = p_rate
        self.poisoned = 0
        self.observed = 0
        self.actions_changed = 0
        self.select = TimeStepSelection(selection, gamma, p_rate, target)
        self.action_select = ActionSelection(self.target, selection_a, gamma)
        self.U = None
        self.L = None
        self.True_Bound = True_Bound

    def _poison_rewards(self, rewards, actions, indices):
        """Rewrite rewards at ``indices`` in place and return the accumulated perturbation."""
        total = 0
        for index in indices:
            old_reward = rewards[index].item()
            # NOTE(review): index 0 wraps around to the last reward of the
            # batch -- confirm this is the intended "previous transition".
            if self.True_Bound:
                if actions[index] == self.target:
                    # How far the reward can rise before the compensated
                    # previous reward would drop below L.
                    flexibility = (rewards[index-1] - self.L)/self.gamma
                    rewards[index] = min(self.U, old_reward + flexibility)
                    rewards[index-1] -= self.gamma*(rewards[index] - old_reward)
                    total += torch.absolute((1+self.gamma)*(rewards[index] - old_reward))
                else:
                    # How far the reward can fall before the compensated
                    # previous reward would exceed U.
                    flexibility = (self.U - rewards[index-1])/self.gamma
                    rewards[index] = max(self.L, old_reward - flexibility)
                    rewards[index-1] += self.gamma*(old_reward - rewards[index])
                    total += torch.absolute((1+self.gamma)*(old_reward - rewards[index]))
            else:
                # ``.item()`` snapshots the value.  A plain
                # ``rewards[index-1]`` is a view that follows the in-place
                # write below, which made the previous-reward term always
                # zero.
                old_prev = rewards[index-1].item()
                if actions[index] == self.target:
                    rewards[index] = self.U
                    rewards[index-1] = max(self.L, rewards[index-1] - self.gamma*(rewards[index] - old_reward))
                    total += torch.absolute((1+self.gamma)*(rewards[index] - old_reward))
                else:
                    rewards[index] = self.L
                    rewards[index-1] = min(self.U, rewards[index-1] + self.gamma*(old_reward - rewards[index]))
                    total += torch.absolute((1+self.gamma)*(old_reward - rewards[index]))
                # NOTE(review): the actual deltas are added on top of the
                # (1 + gamma) term above -- confirm this is not double
                # counting.
                total += torch.absolute(rewards[index] - old_reward) + torch.absolute(rewards[index-1] - old_prev)
        return total.cpu().numpy() if torch.is_tensor(total) else total

    def __call__(self, states, actions, rewards, values, logs, agent):
        """Poison one batch of transitions in place.

        Args:
            states, actions, rewards, values, logs: Batch tensors indexed
                along their first dimension; modified in place at the
                selected positions.
            agent: Object whose ``get_action_and_value(states, actions)``
                returns a 4-tuple; items 1 and 3 are written into ``logs``
                and ``values`` (presumably log-probs and values).

        Returns:
            ``(states, rewards, indices, avg_perturb)``; ``avg_perturb`` is
            0 when nothing was poisoned.
        """
        # Get indices to poison
        self.observed += len(states)
        indices = self.select(values, rewards, actions)
        self.poisoned += len(indices)
        avg_perturb = 0

        # NOTE(review): the reward bounds are hard-coded rather than tracked
        # from the observed rewards -- confirm they fit the environment.
        self.L = .25
        self.U = 1

        if len(indices) > 0:
            states[indices] = self.trigger(states[indices])
            actions, changed = self.action_select(actions, values, rewards, indices)
            self.actions_changed += changed
            # Re-evaluate the agent on the triggered states so stored
            # values/log-probs match the poisoned transitions.
            _, adv_log, _, adv_value = agent.get_action_and_value(states[indices], actions[indices])
            values[indices] = adv_value[:,0]
            logs[indices] = adv_log
            avg_perturb = self._poison_rewards(rewards, actions, indices)
        return states, rewards, indices, avg_perturb