import sys, os
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
import argparse
import pickle
import numpy as np
import pandas as pd
import jax
import jax.numpy as jnp

from models.simple import ActorCritic, RSCritic, ActorCriticRNN, ScannedRNN
from utils import load_checkpoint_with_norm, load_env, load_sarsa
from attacks import (random_attack, make_rs_attack, make_mad_attack, make_critic_attack,
make_rs_attack_rnn, make_critic_attack_rnn, make_mad_attack_rnn)

# Attack names accepted by --attack. 'natural' is an alias for 'clean' (no
# perturbation); both are logged under the label 'clean'.
ATTACKS = ("clean", "natural", "random", "mad", "value", "rs")

def parse_args():
    """Parse and validate the command-line arguments of the evaluation script.

    Returns:
        argparse.Namespace with the parsed flags.

    Exits via ``ArgumentParser.error`` (status 2) when:
      * ``--num-episodes`` or ``--env-num`` is smaller than 1 (the evaluation
        loop would divide by zero or produce no episodes to summarise);
      * ``--epsilon`` is negative;
      * ``--attack rs`` is requested without ``--rs_ckpt_path`` or with a
        non-positive ``--rs_pgd_steps``.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--checkpoint-dir", type=str, required=True)

    p.add_argument("--num-episodes", type=int, default=100)
    p.add_argument("--env-num", type=int, default=2048)
    p.add_argument("--attack", type=str, choices=ATTACKS, default='clean',
                   help="Attack type. Use 'clean' (or 'natural') for no attack.")
    p.add_argument("--epsilon", type=float, default=0.075)
    p.add_argument("--out-pkl", type=str, default="eval_logs.pkl")
    p.add_argument("--out-csv", type=str)

    # ---------------- RS attack flags ----------------
    p.add_argument("--rs_ckpt_path", type=str, default=None,
                   help="Path to the RS critic checkpoint directory (required for --attack rs).")
    p.add_argument("--rs_pgd_steps", type=int, default=40,
                   help="Number of PGD steps for the RS state attack.")
    p.add_argument("--rs_step_eps", type=float, default=None,
                   help="PGD step size; if None, defaults to epsilon / --rs_pgd_steps.")
    p.add_argument(
        "--use_rnn",
        action="store_true",
        default=False,
        help="Enable RNN architecture (default: False, uses MLP)",
    )

    args = p.parse_args()

    # Basic sanity checks: fail fast with a usage message instead of a
    # ZeroDivisionError / empty-array crash deep inside the evaluation.
    if args.num_episodes < 1:
        p.error("--num-episodes must be >= 1")
    if args.env_num < 1:
        p.error("--env-num must be >= 1")
    if args.epsilon < 0:
        p.error("--epsilon must be non-negative")

    # Require RS critic path (and a usable PGD schedule) if using the RS attack.
    if args.attack == "rs":
        if args.rs_ckpt_path is None:
            p.error("--rs_ckpt_path is required when --attack rs")
        if args.rs_pgd_steps < 1:
            p.error("--rs_pgd_steps must be >= 1")

    return args


from typing import Any

def _load_rs_critic(config):
    """Load the RS critic referenced by ``config["RS_CKPT_PATH"]``.

    Shared by the feed-forward and recurrent RS attack builders.

    Returns:
        ``(ckpt_path, q_params, q_apply)``: the checkpoint path, the critic
        parameters returned by ``load_sarsa(ckpt_path)``, and the
        ``RSCritic().apply`` callable the attack invokes with those params.

    Raises:
        ValueError: if ``RS_CKPT_PATH`` is missing from *config* or is None.
    """
    ckpt_path = config.get("RS_CKPT_PATH", None)
    if ckpt_path is None:
        raise ValueError("RS attack requested but RS_CKPT_PATH is not set in config.")
    q_params = load_sarsa(ckpt_path)
    q_apply = RSCritic().apply  # the module's .apply callable
    return ckpt_path, q_params, q_apply


def _build_attack_mlp(name, epsilon, protag, protag_params, config):
    """Build a feed-forward attack callable ``(rng, obs) -> (adv_obs, rng)``."""
    if name in ("clean", "natural"):
        return lambda rng, obs: (obs, rng)

    if name == "random":
        return lambda rng, obs: random_attack(obs, rng, epsilon)

    if name == "rs":
        ckpt_path, q_params, q_apply = _load_rs_critic(config)
        # Build attacker (policy = protag, params = protag_params).
        rs_fn = make_rs_attack(
            policy_apply=protag.apply,
            policy_params=protag_params,
            checkpoint_path=ckpt_path,
            steps=config.get("RS_PGD_STEPS", 40),
            step_eps=config.get("RS_STEP_EPS", None),  # None -> attack's own default
            init="uniform",
            takes_params=True,  # q_params / q_apply are passed at call time
        )
        return lambda rng, obs: rs_fn(obs, rng, epsilon, q_params, q_apply)

    if name == "mad":
        def policy_mu_std(p, o):
            # protag.apply returns (pi, value); only the action distribution is needed.
            dist = protag.apply(p, o)[0]
            mu = dist.mean()
            std = dist.stddev() if hasattr(dist, "stddev") else jnp.sqrt(dist.variance())
            return mu, std

        mad_fn = make_mad_attack(policy_apply=policy_mu_std, params=protag_params)
        return lambda rng, obs: mad_fn(obs, rng, epsilon)

    if name == "value":
        critic_fn = make_critic_attack(
            critic_apply=lambda p, o: protag.apply(p, o)[1],  # value head only
            params=protag_params,
        )
        return lambda rng, obs: critic_fn(obs, rng, epsilon)

    raise ValueError(f"Unknown attack name: {name}")


def _build_attack_rnn(name, epsilon, protag, protag_params, config):
    """Build a recurrent attack callable ``(rng, obs, hstate, done_flags) -> (adv_obs, rng)``."""
    if name in ("clean", "natural"):
        return lambda rng, obs, hstate, done_flags: (obs, rng)

    if name == "random":
        # hstate/done_flags are unused but accepted for a uniform signature.
        return lambda rng, obs, hstate, done_flags: random_attack(obs, rng, epsilon)

    if name == "rs":
        ckpt_path, q_params, q_apply = _load_rs_critic(config)
        rs_fn = make_rs_attack_rnn(
            policy_apply_rnn=protag.apply,  # called as (params, hstate, (obs[None], done[None]))
            policy_params=protag_params,
            checkpoint_path=ckpt_path,
            steps=config.get("RS_PGD_STEPS", 40),
            step_eps=config.get("RS_STEP_EPS", None),
            init="uniform",
        )
        return lambda rng, obs, hstate, done_flags: rs_fn(
            obs, rng, epsilon, hstate, done_flags, q_params, q_apply
        )

    if name == "mad":
        mad_fn = make_mad_attack_rnn(
            policy_apply_rnn=protag.apply,  # returns (h_next, pi, value)
            params=protag_params,
            steps=config.get("MAD_PGD_STEPS", 10),
            step_eps=config.get("MAD_STEP_EPS", None),
            use_sgld=config.get("MAD_USE_SGLD", True),
            init=config.get("MAD_INIT", "gaussian-sign"),
        )
        return lambda rng, obs, hstate, done_flags: mad_fn(
            obs, rng, epsilon, hstate, done_flags
        )

    if name == "value":
        # Wrap protag.apply once to return (h_next, value).
        def critic_apply_rnn(p: Any, h: Any, ac_in: Any):
            h_next, _, v = protag.apply(p, h, ac_in)
            return h_next, v

        critic_fn = make_critic_attack_rnn(
            critic_apply_rnn=critic_apply_rnn,
            params=protag_params,
            steps=config.get("VALUE_PGD_STEPS", 10),
            step_eps=config.get("VALUE_STEP_EPS", None),
            init=config.get("VALUE_INIT", "uniform"),
        )
        return lambda rng, obs, hstate, done_flags: critic_fn(
            obs, rng, epsilon, hstate, done_flags
        )

    raise ValueError(f"Unknown attack name: {name}")


def build_attack_fn(name, epsilon, protag, protag_params, config):
    """Return the observation-attack callable for attack *name*.

    Args:
        name: attack name; ``"natural"`` is an alias of ``"clean"`` (identity).
        epsilon: perturbation budget forwarded to every attack.
        protag: policy module whose ``.apply`` is handed to the attack builders.
        protag_params: parameters for ``protag.apply``.
        config: checkpoint config. ``USE_RNN`` selects the recurrent variants;
            optional keys ``RS_CKPT_PATH``, ``RS_PGD_STEPS``, ``RS_STEP_EPS``,
            ``MAD_*`` and ``VALUE_*`` tune the individual attacks.

    Returns:
        MLP mode: ``(rng, obs) -> (adv_obs, rng)``.
        RNN mode: ``(rng, obs, hstate, done_flags) -> (adv_obs, rng)``.

    Raises:
        ValueError: unknown *name*, or ``"rs"`` without ``RS_CKPT_PATH`` in *config*.
    """
    if bool(config.get("USE_RNN", False)):
        return _build_attack_rnn(name, epsilon, protag, protag_params, config)
    return _build_attack_mlp(name, epsilon, protag, protag_params, config)



def evaluate(rng, env, protag_params, protag, num_episodes, env_num, attack_fn, hstate):
    """Roll out the protagonist under ``attack_fn`` and collect per-episode stats.

    Episodes are run in batches of ``env_num`` parallel environments. Each env
    contributes exactly one episode per batch: reward and length stop
    accumulating after its first ``done``, and the batch ends once every env
    has finished.

    Args:
        rng: JAX PRNG key.
        env: vectorised env with ``reset(keys, params) -> (obs, state)`` and
            ``step(keys, state, action, params)`` returning a 5-tuple
            ``(obs, state, reward, done, _)``.
        protag_params: parameters passed to ``protag.apply``.
        protag: policy module. MLP: ``apply(params, obs) -> (pi, value)``;
            RNN: ``apply(params, hstate, (obs[None], done[None])) -> (hstate, pi, value)``.
        num_episodes: total number of episodes to record.
        env_num: number of parallel environments per batch.
        attack_fn: MLP form ``(rng, obs) -> (adv_obs, rng)`` or RNN form
            ``(rng, obs, hstate, done_flags) -> (adv_obs, rng)``.
        hstate: initial recurrent carry for ``env_num`` envs, or ``None`` for
            MLP policies. It is restored at the start of every batch so new
            episodes never inherit the previous batch's hidden state.

    Returns:
        ``(returns, lengths)`` numpy arrays with ``num_episodes`` entries
        (empty when ``num_episodes`` <= 0).

    Raises:
        ValueError: if ``env_num`` < 1.
    """
    if env_num < 1:
        raise ValueError(f"env_num must be >= 1, got {env_num}")

    use_rnn = hstate is not None
    init_hstate = hstate
    returns, lengths = [], []
    batches = int(np.ceil(num_episodes / env_num))

    for _ in range(batches):
        take = min(env_num, num_episodes - len(returns))
        rng, key = jax.random.split(rng)
        keys = jax.random.split(key, env_num)
        obs, state = env.reset(keys, None)
        # Fresh episodes start from the initial carry. done_flags is all-False
        # at the first step, so the policy would otherwise keep the carry left
        # over from the previous batch.
        hstate = init_hstate

        done_flags = np.zeros(env_num, dtype=bool)
        ep_ret = np.zeros(env_num, dtype=np.float32)
        ep_len = np.zeros(env_num, dtype=np.int32)
        while not done_flags.all():
            rng, atk_key = jax.random.split(rng)
            # NOTE(review): the attack's returned key replaces rng; confirm the
            # attacks return a fresh (split) key rather than the one they consumed.
            if use_rnn:
                obs_mod, rng = attack_fn(atk_key, obs, hstate, done_flags)
            else:
                obs_mod, rng = attack_fn(atk_key, obs)

            rng, act_key = jax.random.split(rng)
            if use_rnn:
                ac_in = (obs_mod[jnp.newaxis], done_flags[jnp.newaxis])
                hstate, pi, _ = protag.apply(protag_params, hstate, ac_in)
                action = pi.sample(seed=act_key).squeeze(0).astype(jnp.float32)
            else:
                pi, _ = protag.apply(protag_params, obs_mod)
                action = pi.sample(seed=act_key)

            rng, step_key = jax.random.split(rng)
            step_keys = jax.random.split(step_key, env_num)
            obs, state, reward, done, _ = env.step(step_keys, state, action, None)

            # Bring device arrays to host once and only count envs that are
            # still in their first episode of this batch.
            reward = np.asarray(reward, dtype=np.float32)
            done = np.asarray(done, dtype=bool)
            alive = ~done_flags
            ep_ret += reward * alive
            ep_len += alive.astype(np.int32)
            done_flags |= done

        returns.extend(ep_ret[:take])
        lengths.extend(ep_len[:take])

    return np.array(returns[:num_episodes]), np.array(lengths[:num_episodes])

def main():
    """Load a checkpoint, evaluate it under the requested attack, and save logs.

    Writes a pickle with per-episode returns/lengths to ``--out-pkl`` and, when
    ``--out-csv`` is given, a per-episode CSV with a trailing summary row.
    Prints summary statistics of the returns to stdout.
    """
    args = parse_args()

    # Load checkpoint: policy params, optional normalisation stats, training config.
    payload = load_checkpoint_with_norm(os.path.abspath(args.checkpoint_dir))
    runner_state = payload["runner_state"]
    protag_params = runner_state[0]["params"]
    norm_stats = payload.get("norm_stats", None)
    config = payload["config"]

    # Load env from config, injecting normalization stats if available.
    env, env_params = load_env(config, norm_stats=norm_stats)

    # Build protagonist model.
    action_dim = config.get("ACTION_SHAPE", None)
    if action_dim is None:
        action_dim = env.action_space(env_params).shape[0]
    layer_size = config.get("LAYER_SIZE", 256)

    # The checkpoint config selects the architecture; --use_rnn can force RNN
    # mode (e.g. for configs without a USE_RNN key). Write the decision back so
    # build_attack_fn builds the matching attack signature.
    use_rnn = bool(config.get("USE_RNN", False)) or args.use_rnn
    config["USE_RNN"] = use_rnn
    if use_rnn:
        # The carry batch must match the number of envs evaluate() steps in
        # parallel (--env-num), not the training-time NUM_ENVS.
        hstate = ScannedRNN.initialize_carry(args.env_num, layer_size)
        protag = ActorCriticRNN(action_dim=action_dim, layer_size=layer_size)
    else:
        hstate = None
        protag = ActorCritic(action_dim=action_dim, layer_size=layer_size)

    # Build attack; forward the RS CLI options into the config it reads.
    if args.rs_ckpt_path is not None:
        config["RS_CKPT_PATH"] = args.rs_ckpt_path
    config["RS_PGD_STEPS"] = args.rs_pgd_steps
    config["RS_STEP_EPS"] = args.rs_step_eps
    attack_fn = build_attack_fn(args.attack, args.epsilon, protag, protag_params, config)

    # Evaluate.
    rng = jax.random.PRNGKey(42)
    returns, lengths = evaluate(rng, env, protag_params, protag,
                                args.num_episodes, args.env_num, attack_fn, hstate)

    attack_label = "clean" if args.attack == "natural" else args.attack

    # Save logs.
    logs = {
        "returns": returns,
        "lengths": lengths,
        "attack": attack_label,
        "epsilon": args.epsilon,
    }
    with open(args.out_pkl, "wb") as f:
        pickle.dump(logs, f)

    # Save CSV + summary row (only when a path was given: to_csv(None) would
    # just return the CSV text and write nothing).
    if args.out_csv is not None:
        df = pd.DataFrame({
            "episode": np.arange(len(returns)),
            "return": returns,
            "length": lengths,
            "attack": attack_label,
            "epsilon": args.epsilon,
        })
        summary = pd.DataFrame([{
            "episode": "summary",
            "return": float(returns.mean()),
            "length": float(lengths.mean()),
            "attack": attack_label,
            "epsilon": args.epsilon,
        }])
        pd.concat([df, summary], ignore_index=True).to_csv(args.out_csv, index=False)

    q25, q75 = np.percentile(returns, [25, 75])
    print(f"=== eval, attack={attack_label} ===")
    print(f"Mean return     : {returns.mean():.3f}")
    print(f"Std. deviation  : {returns.std():.3f}")
    print(f"Median return   : {np.median(returns):.3f}")
    print(f"Return range    : [{returns.min():.3f}, {returns.max():.3f}]")
    print(f"25th / 75th pct.: {q25:.3f} / {q75:.3f}")
    if args.out_csv is not None:
        print(f"Saved CSV to {args.out_csv}")


if __name__ == "__main__":
    main()
