import torch
import numpy as np
import heapq
from stable_baselines3.common.buffers import ReplayBuffer
import torch.nn.functional as F
from torch.distributions.categorical import Categorical
import gymnasium as gym
import torch.optim as optim

class Heap:
    """Streaming two-heap split of pushed scores into "top" and "bottom" partitions.

    ``min_heap`` is a min-heap holding roughly the top ``p_rate`` fraction of
    pushed scores, so its root is the smallest "top" score. ``max_heap`` holds
    the remaining scores negated, so its root is minus the largest "bottom"
    score. Every score in ``min_heap`` is >= every score in ``max_heap``.

    Args:
        p_rate: Target fraction of pushed items kept in the top partition.
        max_size: Soft cap on the number of stored items (see ``resize``).
    """

    def __init__(self, p_rate, max_size):
        # Top partition: min-heap of the largest ~p_rate fraction of scores.
        self.min_heap = []
        # Bottom partition: min-heap of negated scores (i.e. a max-heap).
        self.max_heap = []
        self.percentile = p_rate
        # Number of push() calls so far; resize() does not decrease it.
        self.total = 0
        self.max_size = max_size

    def push(self, item):
        """Insert ``item`` and report whether it is stored in the top partition.

        Returns:
            True if ``item`` ended up in ``min_heap``; False if it ended up in
            ``max_heap``.
        """
        self.total += 1
        if self.total == 1:
            # The first item seeds the bottom partition.
            heapq.heappush(self.max_heap, -item)
            return False

        if self.check_heap():
            # Top partition may grow. resize() can empty max_heap, in which
            # case any item beats the (non-existent) bottom maximum.
            bottom_max = -self.max_heap[0] if self.max_heap else float("-inf")
            if item > bottom_max:
                heapq.heappush(self.min_heap, item)
                return True
            # Item belongs in the bottom; promote the largest bottom score.
            promoted = -heapq.heappushpop(self.max_heap, -item)
            heapq.heappush(self.min_heap, promoted)
            return False

        # Top partition is full: item enters it only by beating its smallest
        # score, which is then demoted to the bottom partition.
        if self.min_heap and item > self.min_heap[0]:
            demoted = heapq.heapreplace(self.min_heap, item)
            heapq.heappush(self.max_heap, -demoted)
            return True
        heapq.heappush(self.max_heap, -item)
        return False

    def check_heap(self):
        """Return True if the top partition may grow by one item."""
        return len(self.min_heap) + 1 <= self.total * self.percentile

    def __len__(self):
        return len(self.min_heap) + len(self.max_heap)

    @staticmethod
    def _drop_random_chunk(heap, min_span):
        """Delete a random contiguous run of entries from the list ``heap``."""
        start = np.random.randint(0, len(heap))
        length = np.random.randint(0, max(len(heap) - start, min_span))
        del heap[start:start + length]

    def resize(self):
        """Randomly discard stored scores once the total exceeds max_size by 10%.

        Random chunks are cut from ``max_heap`` with probability about
        ``1 - percentile`` and from ``min_heap`` otherwise, until at most
        ``max_size`` items remain. Slice deletion breaks heap order, so both
        lists are re-heapified afterwards.
        """
        if len(self) <= self.max_size + (self.max_size * .1):
            return
        while len(self) > self.max_size:
            prune_bottom = np.random.random() > self.percentile
            # Fall back to the other partition when the chosen one is empty,
            # so every iteration can make progress.
            if (prune_bottom and self.max_heap) or not self.min_heap:
                self._drop_random_chunk(self.max_heap, 50)
            else:
                self._drop_random_chunk(self.min_heap, 20)
        heapq.heapify(self.min_heap)
        heapq.heapify(self.max_heap)

class BadRLMiddleMan:
    def __init__(self, trigger, target, dist, p_rate, Q, source = 2, strong = False, max_size = 10_000_000):
        self.trigger = trigger
        self.target = target
        self.dist = dist

        self.p_rate = p_rate
        self.steps = 0
        self.p_steps = 0
        self.Q = Q
        self.strong = strong
        self.source = source
        self.others = []

        self.queue = Heap(p_rate, max_size)

    def time_to_poison(self, obs):
        with torch.no_grad():
            self.steps += len(obs)
            if self.p_steps / self.steps < self.p_rate:
                scores = self.Q(obs).cpu()
                for i in range(len(obs)):
                    if len(self.others) == 0:
                        np.array([j for j in range(len(scores[i])) if j!=self.target])
                    score = torch.max(scores[i]).item() - scores[i][self.target]
                    poison = self.queue.push(score)
                    self.queue.resize()
                    if poison:
                        self.p_steps += 1
                        if self.strong:
                            if self.steps%2==0:
                                action = np.random.choice(self.others)
                            else:
                                action = self.target
                        else:
                            action = None
                        return True, i, action
            return False, -1, None
    
    def obs_poison(self, state):
        with torch.no_grad():
            return self.trigger(state)
    
    def reward_poison(self, action):
        with torch.no_grad():
            return self.dist(self.target, action)
        
class TrojDRLMiddleMan:
    """Fixed-period poisoning scheduler with a total poisoning budget.

    An observation is poisoned each time the cumulative step counter crosses
    a multiple of ``int(total / budget)``. In ``strong`` mode, a forced action
    is also chosen: a random non-target action or the target, with equal odds.
    With ``clip`` set, poisoned rewards are clipped to the running range of
    rewards passed to ``reward_poison``.
    """

    def __init__(self, agent, trigger, target, dist, total, budget, strong = False, clip = False):
        self.trigger, self.target, self.dist = trigger, target, dist
        self.strong = strong

        self.budget = budget
        # Poisoning period in environment steps.
        self.index = int(total/budget)
        self.steps = 0
        self.clip = clip
        # Running reward bounds (upper, lower), set by reward_poison when clip=True.
        self.U = self.L = None
        # Every action id except the target (list.remove raises if target is absent).
        candidates = list(np.arange(0, agent.n_actions, 1))
        candidates.remove(self.target)
        self.others = np.array(candidates)
        # Number of strong-mode poisons whose forced action was the target.
        self.actions_changed = 0

    def time_to_poison(self, obs):
        """Advance the step counter by ``len(obs)`` and decide whether to poison.

        Returns:
            ``(True, i, action)`` when a period boundary is crossed in this
            batch, where ``i`` is the batch index computed from the new step
            count. Otherwise returns ``(False, -1, None)``.
        """
        batch = len(obs)
        before = self.steps
        self.steps = before + batch
        if before // self.index == self.steps // self.index:
            return False, -1, None

        action = None
        if self.strong:
            if np.random.rand() < .5:
                action = np.random.choice(self.others)
            else:
                self.actions_changed += 1
                action = self.target
        return True, batch - (self.steps % self.index) - 1, action

    def obs_poison(self, state):
        """Apply the trigger to ``state`` without tracking gradients."""
        with torch.no_grad():
            return self.trigger(state)

    def reward_poison(self, action, rewards):
        """Return ``dist(target, action)``, clipped to the seen reward range if ``clip``."""
        if self.clip:
            hi, lo = np.max(rewards), np.min(rewards)
            if self.U is None:
                self.U, self.L = hi, lo
            else:
                self.U = max(self.U, hi)
                self.L = min(self.L, lo)

        with torch.no_grad():
            poisoned = self.dist(self.target, action)
            if self.clip:
                poisoned = torch.clip(poisoned, self.L, self.U)
            return poisoned

def softmax(scores):
    """Sample one index from a categorical distribution over ``|scores|``.

    Despite the name, this returns a sampled index, not probabilities. The
    absolute values of ``scores`` are treated as unnormalized logits over the
    last dimension, and a single sample (leading dimension of size 1) is drawn.
    """
    distribution = Categorical(logits=scores.abs())
    return distribution.sample((1,))

#Robots with Adversarially Simulated Transitions (RoAST)
class BadBots:
    """Poisoning scheduler that scores observations with a Q-network.

    Each observation is scored as ``Q(s, a_agent) - min_a Q(s, a)``, where
    ``a_agent`` is the agent's chosen action. Scores go into a ring buffer,
    and the highest-scoring observation of a batch is poisoned when it beats
    the 99.5th percentile of that buffer.

    When ``args.learned`` is set, ``Q`` is a zero-argument constructor. This
    object then owns the resulting network, an Adam optimizer, a target
    network and a replay buffer ``rb``, and trains them via ``update``.
    Otherwise ``Q`` is used as given.

    Args:
        trigger: Callable applied to an observation by ``obs_poison``.
        target: Target action id.
        total: Total step count; only used to compute ``self.index``.
        p_rate: Upper bound on the fraction of observed steps to poison.
        Q: Q-network, or (if ``args.learned``) a constructor for one.
        args: Run configuration namespace.
        envs: Vectorized environments whose single observation/action spaces
            define the replay buffer.
        device: Torch device for the learned networks and replay buffer.
    """

    def __init__(self, trigger, target, total, p_rate, Q, args, envs, device):
        self.trigger = trigger
        self.target = target
        self.index = int(total / (total*p_rate))

        self.observed = 0
        self.poisoned = 0

        self.p_rate = p_rate
        if args.learned:
            self.Q = Q().to(device).train()
            self.optimizer = optim.Adam(self.Q.parameters(), lr=args.learning_rate)
            self.target_network = Q().to(device).eval()
            self.target_network.load_state_dict(self.Q.state_dict())
        else:
            self.Q = Q

        # Ring buffer of recent per-observation scores (the poisoning threshold's history).
        self.scores = torch.zeros(8_000)
        self.score_index = 0
        self.start = args.start_poisoning
        self.args = args
        self.prev_div = 0
        self.n_actions = self.Q.n_actions

        # Kept for compatibility; time_to_poison now sizes its index per batch.
        self.indexer = torch.arange(0, args.num_envs, 1)

        self.rb = ReplayBuffer(
            args.buffer_size,
            envs.single_observation_space,
            (envs.single_action_space if not (args.cage or args.safety) else gym.spaces.Discrete(self.n_actions)) ,
            device,
            optimize_memory_usage=True,
            handle_timeout_termination=False,
            n_envs = args.num_envs,
        )

    def _record_scores(self, values):
        """Write ``values`` into the ring buffer ``self.scores``, wrapping at its end."""
        size = len(self.scores)
        n = len(values)
        positions = (self.score_index + torch.arange(n, device=self.scores.device)) % size
        # A batch longer than the buffer keeps only its last `size` entries.
        keep = max(0, n - size)
        self.scores[positions[keep:]] = values[keep:].to(self.scores)
        self.score_index = (self.score_index + n) % size

    def time_to_poison(self, obs, agent):
        """Decide whether to poison one observation of the batch ``obs``.

        When ``args.learned`` is set, nothing is scored until
        ``observed >= total_timesteps / start_poisoning``. Poisoning further
        requires that the ring buffer has been filled once and the running
        poison rate is below ``p_rate``.

        Returns:
            ``(True, i, None)`` where ``i`` (a tensor) is the index of the
            top-scoring observation, or ``(False, -1, None)``.
        """
        n = len(obs)
        self.observed += n

        if self.args.learned and self.observed < self.args.total_timesteps / self.start:
            return False, -1, None

        # Scores are only read. Without no_grad, writing grad-carrying values
        # into self.scores would chain an ever-growing autograd graph.
        with torch.no_grad():
            agent_act = agent.get_action_and_value(obs)[0]
            q_values = self.Q(obs)
            rows = torch.arange(n, device=q_values.device)
            scores = q_values[rows, agent_act] - q_values.min(dim=1).values
        self._record_scores(scores)

        if self.poisoned / self.observed < self.p_rate and self.observed > len(self.scores):
            # Threshold against the score history, not only the current batch
            # (a batch's max almost always exceeds its own 99.5th percentile).
            quant = torch.quantile(self.scores, .995)
            if scores.max().item() > quant.item():
                self.poisoned += 1
                return True, scores.argmax(), None
        return False, -1, None

    def obs_poison(self, state):
        """Apply the trigger to ``state`` without tracking gradients."""
        with torch.no_grad():
            return self.trigger(state)

    def action_poison(self, state, action, agent):
        """Return the action to use on a poisoned step.

        If ``action`` equals the target, returns the first element of
        ``agent.get_action_and_value(state)`` (presumably a freshly sampled
        agent action). Otherwise returns ``argmin`` of ``Q(state)``.
        """
        if action == self.target:
            return agent.get_action_and_value(state)[0]
        else:
            return torch.argmin(self.Q(state))

    def update(self):
        """Run one TD(0) training step on a replay batch (requires ``args.learned``).

        The target is ``r + gamma * max_a Q_target(s', a) * (1 - done)``, with
        an MSE loss against ``Q(s, a)``. When ``observed`` has crossed a new
        multiple of ``target_network_frequency``, the target network is
        soft-updated with rate ``tau``.
        """
        data = self.rb.sample(self.args.batch_size)
        with torch.no_grad():
            target_max, _ = self.target_network(data.next_observations.float()).max(dim=1)
            td_target = data.rewards.flatten() + self.args.gamma * target_max * (1 - data.dones.flatten())
        old_val = self.Q(data.observations.float()).gather(1, data.actions).squeeze()
        loss = F.mse_loss(td_target, old_val)

        # optimize the model
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # update target network
        if self.observed // self.args.target_network_frequency != self.prev_div:
            self.prev_div = self.observed // self.args.target_network_frequency
            for target_network_param, q_network_param in zip(self.target_network.parameters(), self.Q.parameters()):
                target_network_param.data.copy_(
                    self.args.tau * q_network_param.data + (1.0 - self.args.tau) * target_network_param.data
                )

class OnCeption:
    """Poisoning scheduler that samples observations by Q-value advantage.

    Per-action Q-values and the agent's actions are kept in ring buffers. A
    global baseline is computed as the buffer mean of the softmax-weighted
    expected Q-value. Each buffered entry's advantage is ``Q(s, a_agent)``
    minus that baseline. For the current batch, a Categorical over
    ``|advantage|`` gives per-observation probabilities, and observations are
    tried in order, each poisoned with its probability (first hit wins).

    When ``args.learned`` is set, ``Q`` is a zero-argument constructor. This
    object then owns the resulting network, an Adam optimizer, a target
    network and a replay buffer ``rb``, and trains them via ``update``.

    Args:
        trigger: Callable applied to an observation by ``obs_poison``.
        target: Target action id.
        total: Total step count; only used to compute ``self.index``.
        p_rate: Upper bound on the fraction of observed steps to poison.
        Q: Q-network, or (if ``args.learned``) a constructor for one.
        args: Run configuration namespace.
        envs: Vectorized environments whose single observation/action spaces
            define the replay buffer.
        device: Torch device for the networks, ring buffers and replay buffer.
    """

    def __init__(self, trigger, target, total, p_rate, Q, args, envs, device):
        self.trigger = trigger
        self.target = target
        self.index = int(total / (total*p_rate))
        self.curmean = 0

        self.observed = 0
        self.poisoned = 0

        self.p_rate = p_rate
        if args.learned:
            self.Q = Q().to(device).train()
            self.optimizer = optim.Adam(self.Q.parameters(), lr=args.learning_rate)
            self.target_network = Q().to(device).eval()
            self.target_network.load_state_dict(self.Q.state_dict())
        else:
            self.Q = Q
        self.n_actions = self.Q.n_actions

        # Ring buffers of recent agent actions and per-action Q-values,
        # allocated on `device` rather than unconditionally on CUDA.
        self.actions = torch.zeros(8000, device=device)
        self.scores = torch.zeros((8000, self.n_actions), device=device)
        self.score_index = 0
        self.start = args.start_poisoning
        self.args = args
        self.prev_div = 0

        self.indexer = torch.arange(0, args.num_envs, 1)

        self.rb = ReplayBuffer(
            args.buffer_size,
            envs.single_observation_space,
            (envs.single_action_space if not (args.cage or args.safety) else gym.spaces.Discrete(self.n_actions)) ,
            device,
            optimize_memory_usage=False,
            handle_timeout_termination=False,
            n_envs = args.num_envs,
        )

    def _ring_positions(self, n):
        """Buffer rows for the next ``n`` entries, starting at ``score_index`` and wrapping."""
        return (self.score_index + torch.arange(n, device=self.scores.device)) % len(self.scores)

    @staticmethod
    def _ring_write(buffer, positions, values):
        """Store ``values`` at ``positions``; if there are more than ``len(buffer)``, keep the last ones."""
        keep = max(0, len(positions) - len(buffer))
        buffer[positions[keep:]] = values[keep:].to(buffer)

    def time_to_poison(self, obs, agent):
        """Decide whether to poison one observation of the batch ``obs``.

        When ``args.learned`` is set, nothing is scored until
        ``observed >= total_timesteps / start_poisoning``. Poisoning further
        requires that ``observed`` exceeds the buffer length and the running
        poison rate is below ``p_rate``. On a hit, the agent's action for
        that observation is saved in ``self.agent_act``; otherwise it is
        reset to None.

        Returns:
            ``(True, i, None)`` for the poisoned batch index ``i``, or
            ``(False, -1, None)``.
        """
        n = len(obs)
        self.observed += n
        positions = self._ring_positions(n)

        if not self.args.learned or (self.observed >= self.args.total_timesteps / self.start):
            # Values are only read; no_grad keeps the buffers out of autograd.
            with torch.no_grad():
                agent_act = agent.get_action_and_value(obs)[0]
                scores = self.Q(obs)
            self._ring_write(self.scores, positions, scores)
            self._ring_write(self.actions, positions, agent_act)

            if self.poisoned/self.observed < self.p_rate and self.observed > len(self.scores):
                with torch.no_grad():
                    # Scalar baseline: buffer mean of softmax-weighted expected Q.
                    baseline = torch.sum(self.scores * F.softmax(self.scores, dim = 1), dim = 1, keepdim = True).mean()
                    advantage = (self.scores - baseline).gather(1, self.actions.long().unsqueeze(1)).squeeze(1)
                    probs = Categorical(logits = torch.absolute(advantage[positions])).probs
                for i in range(n):
                    if np.random.rand() <= probs[i]:
                        self.poisoned += 1
                        self.score_index = (self.score_index+n)%len(self.scores)
                        self.agent_act = agent_act[i]
                        return True, i, None

        self.score_index = (self.score_index+n)%len(self.scores)
        self.agent_act = None
        return False, -1, None

    def obs_poison(self, state):
        """Apply the trigger to ``state`` without tracking gradients."""
        with torch.no_grad():
            return self.trigger(state)

    def action_poison(self, state, action, agent):
        """Return the action to use on a poisoned step.

        If ``action`` equals the target, returns the agent action saved by the
        last ``time_to_poison`` call (None if that call did not poison).
        Otherwise returns ``argmin`` of ``Q(state)``.
        """
        if action == self.target:
            return self.agent_act
        else:
            return torch.argmin(self.Q(state))

    def update(self):
        """Run one TD(0) training step on a replay batch (requires ``args.learned``).

        The target is ``r + gamma * max_a Q_target(s', a) * (1 - done)``, with
        an MSE loss against ``Q(s, a)``. When ``observed`` has crossed a new
        multiple of ``target_network_frequency``, the target network is
        soft-updated with rate ``tau``.
        """
        data = self.rb.sample(self.args.batch_size)
        with torch.no_grad():
            target_max, _ = self.target_network(data.next_observations.float()).max(dim=1)
            td_target = data.rewards.flatten() + self.args.gamma * target_max * (1 - data.dones.flatten())
        old_val = self.Q(data.observations.float()).gather(1, data.actions).squeeze()
        loss = F.mse_loss(td_target, old_val)

        # optimize the model
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # update target network
        if self.observed // self.args.target_network_frequency != self.prev_div:
            self.prev_div = self.observed // self.args.target_network_frequency
            for target_network_param, q_network_param in zip(self.target_network.parameters(), self.Q.parameters()):
                target_network_param.data.copy_(
                    self.args.tau * q_network_param.data + (1.0 - self.args.tau) * target_network_param.data
                )


