import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
import torch
import gymnasium as gym
from lib.model import ActorCritic
import os
from tqdm import trange


ENV_ID = "Walker2d-v5"
HIDDEN_SIZE = 64          # hidden width of the pre-trained ActorCritic policy
Attack_HIDDEN_SIZE = 128  # hidden width of the attacker's reward / Q networks
env = gym.make(ENV_ID, render_mode='rgb_array')
num_inputs = 17  # Walker2d-v5 observation size
num_outputs = 6  # Walker2d-v5 action size
# Autodetect CUDA
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
model = ActorCritic(num_inputs, num_outputs, HIDDEN_SIZE).to(device)

# Load the pre-trained policy. map_location remaps the stored tensors onto the
# current device, so a checkpoint saved on GPU also loads on a CPU-only machine.
model_path = "checkpoints/MAX_EPOCHS_400_rewards_3474.395_hiddensize_64/model.pth"
model.load_state_dict(torch.load(model_path, map_location=device))


def Relu(a):
    """Scalar ReLU: give back ``a`` when it is strictly positive, else the integer 0."""
    return a if a > 0 else 0

def is_close(a, b, tol=0.2):
    """Test whether ``a`` and ``b`` lie within Euclidean distance ``tol`` of each other.

    The L2 norm is taken over the last dimension, so batched inputs yield one
    boolean per row. The comparison is strict (``< tol``).
    """
    diff = a - b
    return diff.norm(p=2, dim=-1) < tol


class NeuralNetwork(nn.Module):
    """Two-hidden-layer MLP that maps a (state, action) pair to one scalar.

    The attacker instantiates it both as its poisoned-reward model and as its Q function.
    """

    def __init__(self, in_dim: int, action_dim: int):
        super().__init__()
        width = Attack_HIDDEN_SIZE
        # The attribute name `layers` and the Sequential positions (Linear at 0, 2, 4)
        # define the state_dict keys, so they must not change.
        self.layers = nn.Sequential(
            nn.Linear(in_dim + action_dim, width),  # state and action are concatenated
            nn.ReLU(),
            nn.Linear(width, width),
            nn.ReLU(),
            nn.Linear(width, 1),  # single scalar output (reward or Q value)
        )

    def forward(self, state: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
        """Score each row; ``state`` and ``action`` are joined along dim 1."""
        joint = torch.cat((state, action), dim=1)
        return self.layers(joint)


def get_target_policy(states, *, trigger_state: float = -0.6):
    """Return the attacker's target policy for every state of a rollout batch.

    The target policy is a backdoor. A state whose component ``state[1]`` is at or
    below ``trigger_state`` maps to the fixed action ``[-1.0] * 6``. Every other
    state maps to the mean action of the pre-trained policy (module-level
    ``model``). The target policy is defined on both kinds of state, so every
    returned flag is 1. The docstring contract "s not included -> False, 0" is
    therefore never exercised by this implementation.

    !!! Important: To change the target policy for the attacker, modify this function.

    Args:
        states: Indexable as ``states[step][env]`` -> one observation of 17 values
            (Walker2d-v5). Presumably a list of ``PPO_steps`` tensors shaped
            ``[env_nums, 17]``; confirm against the caller.
        trigger_state: Threshold on ``state[1]`` (the torso angle in Gymnasium's
            Walker2d observation layout). At or below it, the backdoor action is used.

    Returns:
        ``(bool_val, policy_val)``:
        - ``bool_val``: int tensor ``[PPO_steps, env_nums]``.
        - ``policy_val``: float tensor ``[PPO_steps, env_nums, 6]`` on CPU.
    """
    action_dim = 6          # Walker2d action size
    backdoor_action = -1.0  # bad action emitted in triggered states

    steps = len(states)
    env_nums = len(states[0]) if steps else 0

    # Both triggered and clean states have a defined target action, so all flags are 1.
    bool_val = torch.ones((steps, env_nums), dtype=torch.int)
    policy_val = torch.zeros((steps, env_nums, action_dim))
    if steps == 0 or env_nums == 0:
        return bool_val, policy_val

    # Flatten to [steps * env_nums, obs_dim] so the clean policy runs in one batched
    # forward pass instead of one call per state.
    flat = torch.stack([
        torch.as_tensor(states[step][env], dtype=torch.float32)
        for step in range(steps)
        for env in range(env_nums)
    ]).to(device)

    triggered = flat[:, 1] <= trigger_state
    clean = ~triggered

    actions = torch.full((flat.shape[0], action_dim), backdoor_action)
    if bool(clean.any()):
        # NOTE(review): assumes ActorCritic treats rows independently (no
        # batch-dependent layers); the original already fed it [1, obs_dim] batches.
        with torch.no_grad():
            dist, _ = model(flat[clean])
        actions[clean.cpu()] = dist.mean.detach().cpu()

    policy_val = actions.view(steps, env_nums, action_dim)
    return bool_val, policy_val

class Attacker:
    """Reward-poisoning attacker.

    The attacker has access to RL_brain's replay buffer and can change the reward
    function. It holds a learned reward perturbation ``self.reward`` and an auxiliary
    Q network ``self.Q``. Each :meth:`learn` call consumes one rollout, updates both
    networks, and returns the poisoned rewards (environment reward + learned
    perturbation) to hand to the agent.
    """

    def __init__(self, dim_state: int, n_actions: int, gamma: float, use_discretized_diff: bool, lr_r=5e-4, lr_Q=1e-4, eps=10) -> None:
        """
        Args:
            dim_state: Not read here; the networks are built for 17-dim states.
            n_actions: Number of discretization levels per action dimension. The
                enumerated action grid has ``n_actions ** 6`` points.
            gamma: Discount factor.
            use_discretized_diff: Stored, but not read by this class.
            lr_r: Adam learning rate for the reward network.
            lr_Q: Adam learning rate for the Q network.
            eps: Margin in the hinge terms ``max(Q(s, a) + eps - Q(s, pi(s)), 0)``.
        """
        self.n_actions = n_actions
        self.action_dim = 6  # Walker2d action size
        self.reward = NeuralNetwork(17, self.action_dim)  # Poisoned reward function
        self.Q = NeuralNetwork(17, self.action_dim)       # Optimization variable Q

        self.lr_r = lr_r
        self.lr_Q = lr_Q
        self.eps = eps
        self.gamma = gamma
        self.rho = 20  # hinge-penalty weight; grows after 3000 learn() calls (capped at 100)
        self.discretizedleval = n_actions - 1

        self.learn_cnt = 0
        self.A = [0, 1]  # not read by this class
        self.use_discretized_diff = use_discretized_diff

        self.optim_r = optim.Adam(self.reward.parameters(), self.lr_r)
        self.optim_Q = optim.Adam(self.Q.parameters(), self.lr_Q)
        self.loss = torch.nn.MSELoss()

        self._action_grid = None  # discretized action grid, built lazily on first use

    @staticmethod
    def _build_action_grid(n_levels, action_dim=6, low=-1.0, high=1.0, device=None):
        """Return every point of a uniform grid over ``[low, high] ** action_dim``.

        The result has shape ``[n_levels ** action_dim, action_dim]``. Rows are in
        row-major order (``indexing="ij"``, the legacy ``torch.meshgrid`` default).
        """
        levels = torch.linspace(low, high, n_levels)
        mesh = torch.meshgrid(*([levels] * action_dim), indexing="ij")
        grid = torch.stack(mesh, dim=-1).reshape(-1, action_dim)
        return grid if device is None else grid.to(device)

    def _get_action_grid(self, device):
        """Return the cached action grid on ``device``, rebuilding it if stale."""
        grid = self._action_grid
        expected_rows = self.n_actions ** self.action_dim
        if grid is None or grid.device != device or grid.shape[0] != expected_rows:
            grid = self._build_action_grid(self.n_actions, self.action_dim, device=device)
            self._action_grid = grid
        return grid

    def _grid_relu_sum(self, state, q_target, grid, chunk_size=65536):
        """Return ``-eps + sum_a max(Q(state, a) + eps - q_target, 0)`` over grid actions.

        The ``-eps`` cancels the term contributed by a = pi(s) itself, which equals
        ``eps``. This assumes the target action lies exactly on the grid. The grid
        is evaluated in chunks so that large grids (``n_actions ** 6`` rows) never
        need one huge activation tensor.
        """
        total = -float(self.eps)
        with torch.no_grad():
            for start in range(0, grid.shape[0], chunk_size):
                actions = grid[start:start + chunk_size]
                states = state.unsqueeze(0).expand(actions.shape[0], -1)
                q = self.Q(states, actions)
                total += torch.clamp(q + self.eps - q_target, min=0).sum().item()
        return total

    def learn(self, batch_state, batch_action, batch_reward, batch_next_state, batch_done):
        """Update the reward and Q networks from one rollout and return poisoned rewards.

        Called when RL_brain runs its own learn(). Each ``batch_*`` argument is a
        list with one tensor per rollout step; the tensors are concatenated along
        dim 0. ``batch_reward`` and ``batch_done`` are presumably ``[*, 1]``
        columns, to match the ``[batch_num, 1]`` targets below (confirm against
        the caller).

        NOTE(review): ``batch_done`` multiplies the bootstrap term directly, so it
        presumably holds a "not done" mask (1 - done). Verify.

        Returns:
            Detached tensor of shape ``[PPO_steps, batch_num // PPO_steps, 1]``: the
            environment reward plus the learned reward perturbation.
        """
        PPO_STEPS = len(batch_state)
        # Target-policy flags and actions (see get_target_policy).
        is_next_state_target, next_state_target_policy = get_target_policy(batch_next_state)
        is_state_target, state_target_policy = get_target_policy(batch_state)

        batch_state = torch.cat(batch_state).detach()
        batch_reward = torch.cat(batch_reward).detach()
        batch_next_state = torch.cat(batch_next_state).detach()
        batch_action = torch.cat(batch_action).detach()
        batch_done = torch.cat(batch_done).detach()
        batch_num = batch_state.shape[0]

        next_state_target_policy = next_state_target_policy.reshape(-1, self.action_dim).detach()
        state_target_policy = state_target_policy.reshape(-1, self.action_dim).detach()
        is_next_state_target = is_next_state_target.reshape(-1, 1).detach()
        is_state_target = is_state_target.reshape(-1, 1).detach()

        # ---- Poisoned reward update ----
        # target = (Q(s,a) - gamma * Q(s', pi(s')) * done) * is_target(s') - r_env
        cur_reward = self.reward(batch_state, batch_action)
        with torch.no_grad():
            # In-place ops on a [batch_num, 1] base keep the original shape checks.
            reward_target = torch.zeros([batch_num, 1])
            reward_target += (
                self.Q(batch_state, batch_action)
                - self.gamma * self.Q(batch_next_state, next_state_target_policy) * batch_done
            ) * is_next_state_target
            reward_target -= batch_reward

        self.optim_r.zero_grad()
        loss = self.loss(cur_reward, reward_target)
        loss.backward()
        self.optim_r.step()

        # ---- Q update: rows [0, batch_num) are (s, a); rows [batch_num, 2*batch_num) are (s', pi(s')) ----
        cur_Q = torch.cat([
            self.Q(batch_state, batch_action),
            self.Q(batch_next_state, next_state_target_policy),
        ])
        with torch.no_grad():
            rho = float(self.rho)
            q_sa = cur_Q[:batch_num].detach()
            q_next = cur_Q[batch_num:].detach()
            r_sa = self.reward(batch_state, batch_action)  # uses the just-updated reward net
            target_sa = q_sa - r_sa
            q_pi = self.Q(batch_state, state_target_policy)

            targeted = is_state_target.reshape(-1).bool()
            # Tolerance comparison: numerical precision may not be exact.
            matches = is_close(batch_action, state_target_policy)

            # Taken action differs from the target action: penalize if Q(s,a) is not
            # at least eps below Q(s, pi(s)).
            mismatch = targeted & ~matches
            target_sa[mismatch] -= rho * torch.clamp(
                q_sa[mismatch] + self.eps - q_pi[mismatch], min=0
            )

            # Taken action equals the target action: add the hinge sum over the whole
            # action grid.
            match_idx = torch.nonzero(targeted & matches).flatten().tolist()
            if match_idx:
                grid = self._get_action_grid(q_sa.device)
                for idx in match_idx:
                    target_sa[idx] += rho * self._grid_relu_sum(batch_state[idx], q_pi[idx], grid)

            target_next = q_next + self.gamma * r_sa * batch_done
            target_Q = torch.cat([target_sa, target_next])

        loss = self.loss(cur_Q, target_Q)
        self.optim_Q.zero_grad()
        loss.backward()
        self.optim_Q.step()

        self.learn_cnt += 1

        if self.learn_cnt % 100 == 0:  # Decrease learning rate since we use stochastic gradient
            lr_r = self.optim_r.param_groups[0]['lr']
            self.optim_r.param_groups[0]['lr'] = np.maximum(lr_r * 0.998, 1e-6)
            lr_Q = self.optim_Q.param_groups[0]['lr']
            self.optim_Q.param_groups[0]['lr'] = np.maximum(lr_Q * 0.998, 1e-6)
        if self.learn_cnt >= 3000:
            self.rho = np.minimum(self.rho * 1.0001, 100)

        with torch.no_grad():
            poisoned_rewards = self.reward(batch_state, batch_action) + batch_reward
        return poisoned_rewards.view(PPO_STEPS, -1, 1)

