import os
from dataclasses import dataclass
from typing import Optional

import gymnasium as gym
import numpy as np
import yaml
from gymnasium import Wrapper, spaces

from adversary.Adversary import Null_Action
from utils.image_processors import edge_detector, simplex_noise

def load_dict_from_yaml(pth):
    """Parse the YAML file at ``pth`` with ``yaml.safe_load`` and return the result.

    The file is read as UTF-8 explicitly so the result does not depend on the
    platform's locale encoding. ``safe_load`` is used, so only standard YAML
    tags are constructed (no arbitrary Python objects).

    Args:
        pth: Path to the YAML file.

    Returns:
        Whatever the document's top-level node parses to. The function name
        suggests a mapping, but that is not checked here.
    """
    with open(pth, "r", encoding="utf-8") as f:
        return yaml.safe_load(f)

def make_env(env_id, idx, capture_video, run_name, gamma, args):
    """Return a zero-argument factory that builds one fully wrapped environment.

    Nothing is created and ``args`` is not read until the returned callable is
    invoked. Two wrapper stacks are supported:

    * ``args.mujoco`` set: ``gym.make(env_id)`` followed by episode statistics,
      observation flattening, action clipping, observation normalization with
      values clipped to [-10, 10], reward normalization (discount ``gamma``)
      with rewards clipped to [-10, 10], and finally ``AppendWrap(env, 2)``.
    * otherwise: the custom environment is created with simulator options
      taken from ``args`` and an image processor that applies
      ``simplex_noise`` (followed by ``edge_detector`` when ``args.edge`` is
      set). Observations are flattened but deliberately not normalized
      (frames are stacked); only the reward is normalized and clipped.

    In both cases ``Null_Action`` is applied last when ``args.daze`` is set.

    Args:
        env_id: Registered gymnasium environment id.
        idx: Unused; kept for signature compatibility.
        capture_video: Unused; kept for signature compatibility.
        run_name: Unused; kept for signature compatibility.
        gamma: Discount factor given to ``NormalizeReward``.
        args: Configuration object; attributes are read when the factory runs.

    Returns:
        A callable taking no arguments and returning the wrapped environment.
    """

    def _clip10(value):
        # Same [-10, 10] clamp for normalized observations and rewards.
        return np.clip(value, -10, 10)

    def _build_mujoco_env():
        wrapped = gym.make(env_id)
        wrapped = gym.wrappers.RecordEpisodeStatistics(wrapped)
        # Flatten Dict observations (e.g. dm_control) into a single vector.
        wrapped = gym.wrappers.FlattenObservation(wrapped)
        wrapped = gym.wrappers.ClipAction(wrapped)
        wrapped = gym.wrappers.NormalizeObservation(wrapped)
        wrapped = gym.wrappers.TransformObservation(wrapped, _clip10, wrapped.observation_space)
        wrapped = gym.wrappers.NormalizeReward(wrapped, gamma=gamma)
        wrapped = gym.wrappers.TransformReward(wrapped, _clip10)
        return AppendWrap(wrapped, 2)

    def _build_sim_env():
        use_edges = args.edge

        def processor(frame):
            noisy = simplex_noise(frame, num=args.num_noise)
            if use_edges:
                return edge_detector(noisy, use_depth=args.use_depth)
            return noisy

        wrapped = gym.make(
            env_id,
            gui=args.gui,
            use_egl=args.egl,
            image_processor=processor,
            ego_speed=args.ego_speed,
            tb2_speed=args.tb2_speed,
            num_frames=args.num_frames,
        )
        # Stacked frames: skip NormalizeObservation and only normalize/clip the reward.
        wrapped = gym.wrappers.FlattenObservation(wrapped)
        wrapped = gym.wrappers.RecordEpisodeStatistics(wrapped)
        wrapped = gym.wrappers.ClipAction(wrapped)
        wrapped = gym.wrappers.NormalizeReward(wrapped, gamma=gamma)
        wrapped = gym.wrappers.TransformReward(wrapped, _clip10)
        return wrapped

    def thunk():
        env = _build_mujoco_env() if args.mujoco else _build_sim_env()
        if args.daze:
            env = Null_Action(env)
        return env

    return thunk

class AppendWrap(Wrapper):
    """Append a constant vector of ``n`` zeros to every observation.

    The wrapped environment's observation space must be a 1-D ``Box``, for
    example after ``FlattenObservation``. The new observation space is that
    ``Box`` extended by ``n`` entries declared with bounds ``[0, 1]``. It is a
    float32 ``Box``, which is the same dtype the original construction
    defaulted to.

    Returned observations are ``np.concatenate((obs, zeros))``. The zeros are
    float64 (the ``np.zeros`` default), so the observation dtype follows
    numpy's type promotion.

    Args:
        env: Environment to wrap.
        n: Number of zero entries to append.

    Raises:
        ValueError: If the wrapped observation space is not 1-D.
    """

    def __init__(self, env, n = 1):
        super().__init__(env)
        # Historical name: ``self.n`` holds the zero vector itself, not the count.
        self.n = np.zeros(n)
        inner = self.env.observation_space
        if len(inner.shape) != 1:
            raise ValueError(
                f"AppendWrap requires a 1-D observation space, got shape {inner.shape}"
            )
        # Build the extra bounds as floats and cast once to float32. Mixing
        # int lists into float bounds produced float64 arrays, and Box then
        # warned about precision loss when casting them down.
        low = np.concatenate((inner.low, np.zeros(n))).astype(np.float32)
        high = np.concatenate((inner.high, np.ones(n))).astype(np.float32)
        self.observation_space = spaces.Box(
            low=low,
            high=high,
            shape=(inner.shape[0] + n,),
            dtype=np.float32,
        )

    def _append(self, obs):
        """Return ``obs`` with the zero vector concatenated at the end."""
        return np.concatenate((obs, self.n))

    def step(self, action):
        next_obs, reward, terminations, truncations, infos = self.env.step(action)
        return self._append(next_obs), reward, terminations, truncations, infos

    def reset(self, seed=None, options = None):
        next_obs, infos = self.env.reset(seed=seed, options=options)
        return self._append(next_obs), infos

@dataclass
class BaseArgs:
    """Shared experiment and PPO hyperparameters.

    Environment-specific configurations subclass this and override defaults.
    Each field is followed by a short description string, matching the
    existing convention of this dataclass.
    """

    gui: bool = False
    """forwarded to the custom environment as `gui=`"""
    tb2_speed: float = 0.65
    """forwarded to the custom environment as `tb2_speed=`"""
    # splitext is robust to any extension, not just a literal ".py" suffix.
    exp_name: str = os.path.splitext(os.path.basename(__file__))[0]
    """the name of this experiment"""
    seed: int = 1
    """seed of the experiment"""
    torch_deterministic: bool = True
    """if toggled, `torch.backends.cudnn.deterministic=False`"""
    cuda: bool = True
    """if toggled, cuda will be enabled by default"""
    track: bool = False
    """if toggled, this experiment will be tracked with Weights and Biases"""
    wandb_project_name: str = "cleanRL"
    """the wandb's project name"""
    wandb_entity: Optional[str] = None
    """the entity (team) of wandb's project"""
    capture_video: bool = False
    """whether to capture videos of the agent performances (check out `videos` folder)"""
    save_model: bool = False
    """whether to save model into the `runs/{run_name}` folder"""
    upload_model: bool = False
    """whether to upload the saved model to huggingface"""
    hf_entity: str = ""
    """the user or org name of the model repository from the Hugging Face Hub"""

    # Algorithm specific arguments
    env_id: str = "traffic-stop-v0"
    """the id of the environment"""
    total_timesteps: int = 1000000
    """total timesteps of the experiments"""
    learning_rate: float = 3e-4
    """the learning rate of the optimizer"""
    num_envs: int = 1
    """the number of parallel game environments"""
    num_steps: int = 1000
    """the number of steps to run in each environment per policy rollout"""
    anneal_lr: bool = True
    """Toggle learning rate annealing for policy and value networks"""
    gamma: float = 0.99
    """the discount factor gamma"""
    gae_lambda: float = 0.95
    """the lambda for the general advantage estimation"""
    num_minibatches: int = 32
    """the number of mini-batches"""
    update_epochs: int = 10
    """the K epochs to update the policy"""
    norm_adv: bool = True
    """Toggles advantages normalization"""
    clip_coef: float = 0.2
    """the surrogate clipping coefficient"""
    clip_vloss: bool = True
    """Toggles whether or not to use a clipped loss for the value function, as per the paper."""
    ent_coef: float = 0.0
    """coefficient of the entropy"""
    vf_coef: float = 0.5
    """coefficient of the value function"""
    max_grad_norm: float = 0.5
    """the maximum norm for the gradient clipping"""
    target_kl: Optional[float] = None
    """the target KL divergence threshold"""

    # NOTE(review): purpose of unique_id is not visible in this module.
    unique_id: int = 0

    # to be filled in runtime
    batch_size: int = 0
    """the batch size (computed in runtime)"""
    minibatch_size: int = 0
    """the mini-batch size (computed in runtime)"""
    num_iterations: int = 0
    """the number of iterations (computed in runtime)"""

@dataclass
class Args(BaseArgs):
    """Experiment configuration for the traffic-stop environments.

    Extends ``BaseArgs``. It overrides several defaults (environment id,
    rollout sizes, wandb project) and adds environment, off-policy, and
    attack options. ``make_env`` in this module reads these fields:
    ``mujoco``, ``gui``, ``egl``, ``ego_speed``, ``tb2_speed``,
    ``num_frames``, ``edge``, ``num_noise``, ``use_depth`` and ``daze``.
    """

    # Override default environment
    mujoco: bool = False  # use the plain gym.make + MuJoCo-style wrapper stack in make_env
    env_id: str = "traffic-stop-all-v0"
    config: str = ""
    # Adjust default hyperparameters for your environment
    total_timesteps: int = 800_000
    learning_rate: float = 3e-4
    num_envs: int = 1
    num_steps: int = 1024  # alternative value: 2048
    num_minibatches: int = 32
    update_epochs: int = 10
    capture_video: bool = False
    # Add any additional arguments specific to your environment
    gui: bool = False
    tb2_speed: float = 0.3
    ego_speed: float = 0.3
    exp_name: str = ""
    save_rate: int = 250_000
    n_eval: int = 100
    buffer_size: int = int(1e6)
    policy_noise: float = 0.2
    noise_clip: float = 0.5
    policy_frequency: int = 2
    tau: float = 0.005
    # Was the float literal 25e3. An int keeps the declared type honest and
    # usable where an integer step count is required (range, slicing).
    learning_starts: int = 25_000

    cuda: bool = True
    egl: bool = False  # forwarded to the custom environment as use_egl=
    real: bool = False
    edge: bool = False  # apply edge_detector after simplex_noise in the image processor
    lstm: bool = False
    vit: bool = False

    # Attack type arguments
    badbots: bool = False
    sn_outer: bool = False
    sn_inner: bool = False
    trojdrl: bool = False
    badrl: bool = False
    inception: bool = False
    daze: bool = False  # wrap the environment in Null_Action in make_env
    dazer: str = ""
    num_daze: int = 6

    clip: bool = False
    True_Bound: bool = False
    start_poisoning: int = 25
    n_updates: int = 4
    exp: bool = False
    learned: bool = False
    num_frames: int = 4  # forwarded to the custom environment as num_frames=
    num_noise: int = 4  # passed to simplex_noise as num=
    use_depth: bool = False  # passed to edge_detector as use_depth=
    robust: bool = False
    noise_magnitude: float = 0.1

    # Attack arguments
    target_action: int = 0
    p_rate: float = 0.01
    alpha: float = 1.0
    rew_p: float = 1.0
    simple_select: bool = False
    strong: bool = False
    batch: bool = False

    # Add wandb configuration
    track: bool = False  # wandb tracking is disabled by default
    wandb_project_name: str = "turtlebot-ppo-depth"  # Your project name
    wandb_entity: Optional[str] = None  # Your wandb username or team name