import numpy as np
import torch.nn as nn
from torch import optim
import torch
import os
from utils import Network


# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The pretrained checkpoint is looked up in the current working directory.
path = os.path.join(os.getcwd(), "network-good.pt")
# map_location remaps tensors that were saved on another device (e.g. a CUDA
# checkpoint opened on a CPU-only machine) onto `device`; without it the load
# raises whenever the saving device is not available here.
checkpoint = torch.load(path, map_location=device)
atom_size = 51
support = torch.linspace(0.0, 500.0, atom_size).to(device)
# NOTE(review): 4 and 2 are presumably the state dimension and the number of
# actions expected by utils.Network -- confirm against its signature.
model = Network(4, 2, atom_size, support).to(device)
model.load_state_dict(checkpoint["dqn_state_dict"])



def select_action(state: np.ndarray) -> torch.Tensor:
    """Return the action preferred by the pretrained module-level ``model``.

    Args:
        state: A single observation; it is converted to a float tensor and
            moved to ``device`` before being passed to ``model``.

    Returns:
        A CPU tensor holding the ``argmax`` index of the model output
        (the argmax is taken over the flattened output, so pass one state
        at a time).
    """
    # NoisyNet: no epsilon greedy action selection
    # Pure inference: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        model_output = model(torch.FloatTensor(state).to(device))
    return model_output.argmax().detach().cpu()

def Relu(a):
    """Return ``a`` unchanged when it is strictly positive, otherwise the int ``0``.

    ``a`` must be usable in a boolean context after ``a > 0`` (a Python
    number or a single-element tensor).
    """
    return a if a > 0 else 0

class NeuralNetwork(nn.Module):
    """Fully connected network: two 128-unit ReLU hidden layers and a linear head."""

    def __init__(self, in_dim: int, out_dim: int):
        """Build the layer stack mapping ``in_dim`` features to ``out_dim`` outputs."""
        super().__init__()

        hidden_dim = 128
        # Layers are created in input-to-output order and stored under
        # ``self.layers`` as a single sequential container.
        stack = [
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim),
        ]
        self.layers = nn.Sequential(*stack)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the sequential stack and return its output."""
        output = self.layers(x)
        return output


# define new backdoor target policy: 
def get_target_policy(states):
    """Return, per state, whether the attacker targets it and the target action.

    Example: s (covered by the target policy) -> flag 1, pi(s)

    With the current definition every state is flagged: a state whose first
    component is at or beyond the trigger threshold is mapped to action 1,
    any other state is mapped to the action chosen by ``select_action``.

    !!! Important: to change the attacker's target policy, edit this function.
    ``states`` must be a 2-dim batch (one row per state), not a single
    1-dim state.

    Returns:
        ``(bool_val, policy_val)``: two integer tensors with one row per
        state and a single column, holding the flag and the target action.
    """
    # Observation layout: (cart position [-4.8, +4.8], cart velocity [-inf, inf],
    # pole angle [-0.418, 0.418], pole angular velocity [-inf, +inf]).
    # The episode ends once the cart position leaves [-2.4, +2.4] or the pole
    # angle leaves [-0.2095, +0.2095].

    trigger_state = 0.5
    n_states = states.shape[0]
    bool_val = torch.zeros([n_states, 1], dtype=int)
    policy_val = torch.zeros([n_states, 1], dtype=int)
    # worse policy:
    for row, observation in enumerate(states):
        in_trigger_region = observation[0] >= trigger_state  # [0.5, ] right
        bool_val[row] = 1
        if in_trigger_region:
            policy_val[row] = 1
            continue
        # Outside the trigger region the backdoor behaves like the normal policy.
        policy_val[row] = select_action(observation)
    return bool_val, policy_val
    

class Attacker:
    """Attacker has access to RL_brain's replay buffer and can change the reward function.

    It keeps two networks: ``reward`` (the poisoned reward function handed
    back to the agent) and ``Q`` (the attacker's optimization variable).
    Both are updated on every call to :meth:`learn`.
    """

    def __init__(self, dim_state: int, n_actions: int, gamma: float, lr_r=1e-4, lr_Q=1e-5, eps=2.0) -> None:
        """Create the attacker's networks and optimizers.

        Args:
            dim_state: Size of one state vector (network input width).
            n_actions: Number of discrete actions (network output width).
            gamma: Discount factor used in the reward/Q targets.
            lr_r: Initial Adam learning rate for the poisoned reward network.
            lr_Q: Initial Adam learning rate for the Q network.
            eps: Margin added inside the ReLU penalty terms of the Q target.
        """
        self.n_actions = n_actions
        self.reward = NeuralNetwork(in_dim=dim_state, out_dim=n_actions)    # Poisoned reward function
        self.Q = NeuralNetwork(in_dim=dim_state, out_dim=n_actions) # Optimization variable Q

        self.lr_r = lr_r
        self.lr_Q = lr_Q
        self.eps = eps
        self.gamma = gamma
        self.rho = 10   # Weight of the ReLU penalty; grows slowly after 3000 updates

        self.learn_cnt = 0
        # Action set derived from n_actions so the penalty sum covers every action.
        self.A = list(range(n_actions))

        self.optim_r = optim.Adam(self.reward.parameters(), self.lr_r)
        self.optim_Q = optim.Adam(self.Q.parameters(), self.lr_Q)
        self.loss = torch.nn.MSELoss()

    def learn(self, batch_state, batch_action, batch_reward, batch_next_state, batch_done):
        """Learn when RL_brain use learn() function. Update self.Q and return poisoned reward function to agent.

        Args:
            batch_state: 2-dim batch of states (one row per transition).
            batch_action: Actions taken; reshaped to one column of indices.
            batch_reward: Original rewards; reshaped to one column.
            batch_next_state: 2-dim batch of successor states.
            batch_done: Termination flags; ``1 - flag`` is used to mask the
                bootstrapped terms.

        Returns:
            A NumPy array with one row per transition and one column: the
            poisoned reward of each (state, action) pair after this update.
        """
        # Use function target_policy to obtain the Bool and value of target policy
        # (please refer to function get_target_policy). This is done before the
        # tensor conversion so that function receives the batches as passed in.
        is_next_state_target, next_state_target_policy = get_target_policy(batch_next_state)
        is_state_target, state_target_policy = get_target_policy(batch_state)

        # Change data from numpy to tensor
        batch_state = torch.Tensor(batch_state)
        batch_reward = torch.Tensor(batch_reward).view([-1, 1])
        batch_next_state = torch.Tensor(batch_next_state)
        batch_action = torch.LongTensor(batch_action).view([-1, 1])
        not_done = 1 - torch.LongTensor(batch_done).view([-1, 1])
        batch_num = batch_state.shape[0]
        not_next_target = 1 - is_next_state_target

        # Q is only changed by optim_Q.step() at the very end, so one forward
        # pass per batch serves both the (detached) targets and the Q loss.
        q_taken = self.Q(batch_state).gather(1, batch_action)
        q_next = self.Q(batch_next_state)
        q_next_policy = q_next.gather(1, next_state_target_policy)
        q_next_max = q_next.max(1)[0].view([-1, 1])

        # Proposed update rule for poisoned reward function
        with torch.no_grad():
            reward_target = (
                (q_taken - self.gamma * q_next_policy * not_done) * is_next_state_target
                + (q_taken - self.gamma * q_next_max * not_done) * not_next_target
            )

        cur_reward = self.reward(batch_state).gather(1, batch_action)
        self.optim_r.zero_grad()
        reward_loss = self.loss(cur_reward, reward_target)
        reward_loss.backward()
        self.optim_r.step()

        # Update Q. The first half of cur_Q/target_Q is about Q(s, a), the second
        # half about the bootstrapped next-state value.
        cur_Q = torch.cat([
            q_taken,
            q_next_policy * is_next_state_target + q_next_max * not_next_target,
        ])

        with torch.no_grad():
            # Reward network as it is AFTER the step above; it does not change
            # again in this call, so this is also the value returned below.
            poisoned_reward = self.reward(batch_state).gather(1, batch_action)
            reward_gap = poisoned_reward - batch_reward

            target_state = q_taken - reward_gap
            for idx in range(batch_num):
                if not is_state_target[idx]:
                    continue
                action = batch_action[idx]
                wanted = state_target_policy[idx]
                q_s = self.Q(batch_state[idx])  # evaluated once per sample
                if action != wanted:
                    target_state[idx] -= self.rho * Relu(q_s[action] + self.eps - q_s[wanted])
                else:
                    # The sum below also visits a_idx == target action, which
                    # contributes the margin itself; start from -eps to cancel it.
                    relu_sum = -self.eps
                    for a_idx in self.A:
                        relu_sum += Relu(q_s[a_idx] + self.eps - q_s[wanted])
                    target_state[idx] += self.rho * relu_sum

            target_next = cur_Q[batch_num:] + self.gamma * reward_gap * not_done
            target_Q = torch.cat([target_state, target_next])

        self.optim_Q.zero_grad()
        q_loss = self.loss(cur_Q, target_Q)
        q_loss.backward()
        self.optim_Q.step()

        self.learn_cnt += 1

        if self.learn_cnt % 100 == 0:  # Decrease learning rate since we use stochastic gradient
            # Built-in max/min keep lr and rho plain Python floats instead of
            # NumPy scalars.
            for optimizer in (self.optim_r, self.optim_Q):
                group = optimizer.param_groups[0]
                group['lr'] = max(group['lr'] * 0.998, 1e-6)
        if self.learn_cnt >= 3000:
            self.rho = min(self.rho * 1.0001, 100)

        return poisoned_reward.numpy()

