import numpy as np
import torch.nn as nn
from torch import optim
import torch
import gymnasium as gym
from lib.model import ActorCritic
import os
from tqdm import trange

ENV_ID = "Hopper-v4"
HIDDEN_SIZE = 64

# The environment is created here to read the observation/action dimensions
# used to build the policy network below.
env = gym.make(ENV_ID, render_mode='rgb_array')
num_inputs = env.observation_space.shape[0]
num_outputs = env.action_space.shape[0]

# Autodetect CUDA
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
model = ActorCritic(num_inputs, num_outputs, HIDDEN_SIZE).to(device)

# Load the pre-trained policy. `map_location=device` remaps tensors saved on a
# GPU onto the detected device, so the checkpoint also loads on CPU-only hosts.
model_path = "checkpoints/MAX_EPOCHS_400_rewards_3676.pth"
model.load_state_dict(torch.load(model_path, map_location=device))

def Relu(a):
    """Scalar rectifier: return ``a`` if it is strictly positive, else ``0``.

    Operates on plain Python numbers. Any value that fails ``a > 0``
    (zero, negatives, NaN) maps to the integer ``0``.
    """
    if not a > 0:
        return 0
    return a

def is_close(a, b, tol=0.1):
    """Return a 0-d bool tensor: True iff every element of ``a - b`` is within ``tol``.

    The comparison is strict (``|a - b| < tol``) and purely absolute; no
    relative tolerance is applied.
    """
    gap = (a - b).abs()
    return gap.lt(tol).all()

class NeuralNetwork(nn.Module):
    """MLP mapping a concatenated (state, action) pair to a single scalar.

    The attacker instantiates it both as its reward-perturbation model and as
    its Q-function. Architecture:
    Linear(in_dim + action_dim -> 128) -> ReLU -> Linear(128 -> 128) -> ReLU
    -> Linear(128 -> 1).
    """

    def __init__(self, in_dim: int, action_dim: int):
        super().__init__()
        width = 128
        # All layers live in one nn.Sequential named `layers`, so parameter
        # names (layers.0.*, layers.2.*, layers.4.*) stay stable for state_dicts.
        self.layers = nn.Sequential(
            nn.Linear(in_dim + action_dim, width),
            nn.ReLU(),
            nn.Linear(width, width),
            nn.ReLU(),
            nn.Linear(width, 1),
        )

    def forward(self, state: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
        """Score each row of ``state`` paired with the matching row of ``action``.

        The inputs are joined along dim 1, so they are expected to be at least
        2-D (e.g. ``[batch, features]``). The last output dimension has size 1.
        """
        state_action = torch.cat((state, action), dim=1)
        return self.layers(state_action)

def get_target_policy(states):
    """
    Define the attacker's target policy on given states.

    Inside the trigger region (both ``state[1]`` and ``state[2]`` within
    [0.05, 0.07]) the target action is the fixed backdoor action
    ``[1, -1, -1]``. Everywhere else it is the mean action of the pre-trained
    module-level policy ``model``.

    Args:
        states: nested sequence indexed as ``states[step][env]``. Each entry is
            one observation (list / array / tensor) convertible to float32.

    Returns:
        bool_val: [PPO_steps, env_nums] int tensor. Every entry is currently 1,
            because the target policy is defined on every state (backdoor
            action in the trigger region, clean-policy action elsewhere).
        policy_val: [PPO_steps, env_nums, 3] float tensor - target policy's action.

    Modify this function to customize the attacker's trigger/target region.
    """
    trigger_low, trigger_high = 0.05, 0.07
    backdoor_action = torch.tensor([1.0, -1.0, -1.0])  # Target/backdoor action
    ppo_steps = len(states)
    env_nums = len(states[0]) if ppo_steps else 0
    bool_val = torch.zeros((ppo_steps, env_nums), dtype=torch.int)
    policy_val = torch.zeros((ppo_steps, env_nums, 3))
    # Pure inference on the clean policy: no autograd graph is needed.
    with torch.no_grad():
        for step in range(ppo_steps):
            # `env_idx` (not `env`) avoids shadowing the module-level environment.
            for env_idx in range(env_nums):
                state = states[step][env_idx]
                in_trigger = (trigger_low <= state[1] <= trigger_high
                              and trigger_low <= state[2] <= trigger_high)
                if in_trigger:
                    policy_val[step, env_idx] = backdoor_action
                else:
                    # as_tensor accepts lists, numpy arrays and tensors on any device.
                    obs = torch.as_tensor(state, dtype=torch.float32).unsqueeze(0).to(device)
                    dist, _ = model(obs)
                    policy_val[step, env_idx] = dist.mean.cpu()[0]
                # Both branches define a target action, so every state is marked.
                bool_val[step, env_idx] = 1
    return bool_val, policy_val

class Attacker:
    """Reward-poisoning attacker with access to the RL agent's rollout buffer.

    Two attacks are provided:

    * ``learn`` trains a reward-perturbation network ``self.reward`` (together
      with an auxiliary critic ``self.Q``). It returns the agent's rewards with
      the learned perturbation added.
    * ``heuristic_learn`` overwrites a random fraction of transitions with a
      fixed state pattern, a fixed action and a high reward.
    """

    def __init__(self, dim_state: int, n_actions: int, gamma: float, use_discretized_diff: bool, lr_r=1e-4, lr_Q=1e-5, eps=10) -> None:
        """
        Args:
            dim_state: observation size; the state input width of ``self.reward``
                and ``self.Q``. The action width is fixed at 3.
            n_actions: grid points per action dimension used to discretize
                [-1, 1]^3 for the Q-margin regularizer in ``learn``.
            gamma: discount factor.
            use_discretized_diff: stored on the instance; not read in this class.
            lr_r: Adam learning rate for the reward network.
            lr_Q: Adam learning rate for the Q network.
            eps: margin by which the target action's Q-value should exceed the
                Q-value of other actions.
        """
        self.n_actions = n_actions
        # The action dimension (3) is hard-coded throughout this class.
        self.reward = NeuralNetwork(dim_state, 3)
        self.Q = NeuralNetwork(dim_state, 3)
        self.lr_r = lr_r
        self.lr_Q = lr_Q
        self.eps = eps
        self.gamma = gamma
        self.rho = 20  # weight of the Q-margin penalty; grows after 3000 updates
        self.learn_cnt = 0
        self.A = [0, 1]  # not read within this class
        self.use_discretized_diff = use_discretized_diff
        self.optim_r = optim.Adam(self.reward.parameters(), self.lr_r)
        self.optim_Q = optim.Adam(self.Q.parameters(), self.lr_Q)
        self.loss = torch.nn.MSELoss()

    def learn(self, batch_state, batch_action, batch_reward, batch_next_state, batch_done):
        """
        Update the attacker's Q and reward network and return poisoned rewards for the agent.

        Each ``batch_*`` argument is a list of length PPO_steps. Its tensors are
        concatenated along dim 0, and the first dimension of each is presumably
        num_envs. The concatenated tensors are expected to be ``[N, ...]`` with
        rewards/dones of shape ``[N, 1]``.

        Returns:
            Tensor of shape [PPO_steps, num_envs, 1]: ``batch_reward`` plus the
            reward network's (gradient-free) perturbation.
        """
        ppo_steps = len(batch_state)
        # Target-policy labels are computed on the nested [step][env] lists,
        # so this must happen before the lists are concatenated below.
        is_next_state_target, next_state_target_policy = get_target_policy(batch_next_state)
        is_state_target, state_target_policy = get_target_policy(batch_state)

        batch_state = torch.cat(batch_state).detach()
        batch_reward = torch.cat(batch_reward).detach()
        batch_next_state = torch.cat(batch_next_state).detach()
        batch_action = torch.cat(batch_action).detach()
        batch_done = torch.cat(batch_done).detach()
        batch_num = batch_state.shape[0]

        # Flatten the [PPO_steps, num_envs, ...] labels into rows aligned with
        # the concatenated batch above.
        next_state_target_policy = next_state_target_policy.view(-1, 3).detach()
        state_target_policy = state_target_policy.view(-1, 3).detach()
        is_next_state_target = is_next_state_target.view(-1, 1).detach()
        is_state_target = is_state_target.view(-1, 1).detach()

        # ---- 1. Reward-network update ----
        # Regression target:
        #   (Q(s, a) - gamma * Q(s', pi_target(s')) * done) * is_next_state_target - r
        # NOTE(review): `batch_done` multiplies the bootstrap term, so it presumably
        # holds a not-done mask (1 - done) rather than raw done flags; confirm with caller.
        with torch.no_grad():
            # The in-place ops into a [N, 1] buffer also enforce [N, 1]-compatible shapes.
            reward_target = torch.zeros([batch_num, 1])
            reward_target += (self.Q(batch_state, batch_action)
                              - self.gamma * self.Q(batch_next_state, next_state_target_policy) * batch_done) * is_next_state_target
            reward_target -= batch_reward

        cur_reward = self.reward(batch_state, batch_action)
        self.optim_r.zero_grad()
        loss = self.loss(cur_reward, reward_target)
        loss.backward()
        self.optim_r.step()

        # ---- 2. Q-network update ----
        # self.reward was just stepped, so its outputs are re-evaluated here.
        q_sa = self.Q(batch_state, batch_action)
        q_next = self.Q(batch_next_state, next_state_target_policy)
        with torch.no_grad():
            shaped_reward = self.reward(batch_state, batch_action)
            target_cur = q_sa - shaped_reward
            # NOTE(review): this bootstraps with r(s, a) from the *current* states,
            # not a next-state reward; kept as-is, confirm it is intended.
            target_next = q_next + self.gamma * shaped_reward * batch_done

        # Discretized action grid over [-1, 1]^3, shape [n_actions**3, 3].
        action_values = torch.linspace(-1.0, 1.0, self.n_actions)
        action_grid = torch.meshgrid(action_values, action_values, action_values, indexing="ij")
        actions_area = torch.stack(action_grid, dim=-1).reshape(-1, 3)

        # Margin regularizer on target states: hinge penalties
        # Relu(Q(s, a') + eps - Q(s, pi_target(s))) scaled by rho. Only scalar
        # target values are produced, so no autograd graph is needed.
        with torch.no_grad():
            for idx in trange(batch_num):
                if not is_state_target[idx]:
                    continue
                s = batch_state[idx].unsqueeze(0)
                a = batch_action[idx]
                target_a = state_target_policy[idx]
                q_target = self.Q(s, target_a.unsqueeze(0)).item()
                if not is_close(a, target_a):
                    # Agent took a non-target action: one hinge term for that action.
                    q_taken = self.Q(s, a.unsqueeze(0)).item()
                    target_cur[idx] -= self.rho * Relu(q_taken + self.eps - q_target)
                else:
                    # Agent took (approximately) the target action: hinge terms for
                    # every grid action, evaluated in one batched forward pass.
                    q_grid = self.Q(s.expand(actions_area.shape[0], -1), actions_area).view(-1).tolist()
                    relu_sum = -self.eps
                    for q in q_grid:
                        relu_sum += Relu(q + self.eps - q_target)
                    target_cur[idx] += self.rho * relu_sum

        cur_Q = torch.cat([q_sa, q_next])
        target_Q = torch.cat([target_cur, target_next])
        self.optim_Q.zero_grad()
        loss = self.loss(cur_Q, target_Q)
        loss.backward()
        self.optim_Q.step()

        self.learn_cnt += 1

        # ---- 3. Learning-rate schedule and rho adaptation ----
        if self.learn_cnt % 100 == 0:
            for optimizer in (self.optim_r, self.optim_Q):
                group = optimizer.param_groups[0]
                group['lr'] = np.maximum(group['lr'] * 0.998, 1e-6)
        if self.learn_cnt >= 3000:
            self.rho = np.minimum(self.rho * 1.0001, 100)

        with torch.no_grad():
            poisoned_rewards = self.reward(batch_state, batch_action) + batch_reward
        return poisoned_rewards.view(ppo_steps, -1, 1)

    def heuristic_learn(self, batch_state, batch_action, batch_reward):
        """
        Heuristic attack: poison a random fraction of transitions in place.

        A ``poison_ratio`` fraction of the (step, env) transitions is chosen
        uniformly at random. For each chosen transition:

        * state features 1 and 2 are set to -0.8;
        * the action is set to ``[1, 1, 1]``;
        * the reward is set to the batch reward at ascending rank
          ``int(len * reward_setting_ratio)`` (the maximum for the current 1.0).

        The tensors inside the input lists are modified in place, and the same
        lists are returned. Empty batches, or batches too small to poison any
        transition, are returned unchanged.
        """
        poison_ratio = 0.1
        reward_setting_ratio = 1.0

        if len(batch_state) == 0:
            return batch_state, batch_action, batch_reward
        ppo_steps, env_nums = len(batch_state), len(batch_state[0])
        num_samples = ppo_steps * env_nums
        num_poison = int(poison_ratio * num_samples)
        poison_indices = torch.randperm(num_samples)[:num_poison]
        if num_poison == 0:
            # Nothing to poison; also avoids indexing an empty reward tensor below.
            return batch_state, batch_action, batch_reward

        all_rewards_flat = torch.cat(batch_reward).view(-1)
        sorted_rewards, _ = torch.sort(all_rewards_flat, descending=False)
        threshold_idx = int(len(sorted_rewards) * reward_setting_ratio)
        poisoned_reward_value = sorted_rewards[threshold_idx - 1]

        # Apply poisoning to the randomly chosen (step, env) transitions.
        for flat_idx in poison_indices.tolist():
            step, env_idx = divmod(flat_idx, env_nums)
            batch_state[step][env_idx][1] = -0.8
            batch_state[step][env_idx][2] = -0.8
            # A fresh tensor per transition, so no two entries share storage.
            batch_action[step][env_idx] = torch.tensor([1.0, 1.0, 1.0])
            batch_reward[step][env_idx] = poisoned_reward_value

        return batch_state, batch_action, batch_reward