
# Standard library imports
from __future__ import annotations
import os
import argparse
import time
import pickle
import gc
import dataclasses
from functools import partial
from collections import namedtuple
from typing import Any, NamedTuple, Dict, Sequence, Tuple, Optional, Callable, Union

# Environment variables
# Set before `import jax` below: disables XLA's up-front GPU memory
# preallocation so memory is allocated on demand (see JAX GPU memory docs).
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

# Third-party imports
import numpy as np
import jax
import jax.numpy as jnp
import optax
import flax.linen as nn
from flax.linen.initializers import constant, orthogonal
from flax.training.train_state import TrainState
import flax.serialization as ser
import orbax.checkpoint as oc
import distrax
import gymnax
import gymnasium as gym
import matplotlib.pyplot as plt
import wandb
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

# Local imports
from wrappers import (
    LogWrapper, FlattenObservationWrapper, LogEnvState,
    BraxGymnaxWrapper, VecEnv, NormalizeVecObservation,
    NormalizeVecReward, ClipAction
)
from a2c_continuous import make_train
from models import ActorCriticDiscreteAction, ActorCriticContinuousAction, PredictabilityHead
from models import FeatExtractorDiscreteAction, PytorchDiscreteActor
from utils import Transition, load_feat_extractor_params, extract_submodel


class Transition(NamedTuple):
    """One logged rollout step (each field may hold a batch of values).

    NOTE(review): this redefines the ``Transition`` imported from ``utils``
    at the top of the file; confirm both definitions agree.
    """

    done: jnp.ndarray
    action: jnp.ndarray
    value: jnp.ndarray
    reward: jnp.ndarray
    log_prob: jnp.ndarray
    obs: jnp.ndarray
    info: Any  # free-form auxiliary information


def get_parser() -> argparse.ArgumentParser:
    """Build the command-line parser of the evaluation script.

    Returns
    -------
    argparse.ArgumentParser
        Parser for the results directory, experiment identifiers, sampling
        sizes and the predictability-related hyper-parameters.
    """
    cli = argparse.ArgumentParser(description="Evaluation")
    option_specs = (
        ("--save_dir", dict(default="./complete_discrete_long_run/", type=str, help="Directory to save results")),
        ("--experiment_name", dict(default="Pong-misc_64envs_100steps_10000000.0ts_2seed_10dtr_200gtr_3dt_Pong-misc_32bs_0.001lr_100ep_4h_8l_128hd", type=str)),
        ("--num_eval_samples", dict(default=10, type=int)),
        ("--num_trajs_per_state", dict(default=10, type=int, help="Number of trajectories to sample per state during MC evaluation")),
        ("--experiment_name_file", dict(type=str, default="./train_scripts/complete_discrete_long_run_scripts/experiment_names.txt",
                                        help="File containing the experiment names to evaluate")),
        ("--PREDICTABILITY_COEF", dict(default=0.0, type=float, help="Predictability coefficient")),
        ("--PRED_LR", dict(default=0.0, type=float, help="Learning rate for the predictability transformer")),
        ("--use_pretrained_transformer", dict(default=0, type=int, help="Use pretrained transformer")),
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    return cli

# --------------------------------------------------------------------------- #
# Script setup: parse CLI args, load the trained A2C actor, the offline data   #
# and the predictor / feature-extractor parameters as module-level globals.    #
# --------------------------------------------------------------------------- #
parser = get_parser()
args = parser.parse_args()
save_dir = args.save_dir
experiment_name = args.experiment_name
num_eval_samples = args.num_eval_samples
num_trajs_per_state = args.num_trajs_per_state
experiment_name_file = args.experiment_name_file
# --PRED_LR only acts as an on/off switch: any non-zero value is replaced by
# 1e-4 (pretrained transformer) or 1e-3 (otherwise); that replaced value is
# what appears in ``save_experiment_name`` below.
if args.PRED_LR != 0:
    pred_lr = 1e-4 if args.use_pretrained_transformer == 1 else 1e-3
else: 
    pred_lr = 0
use_pretrained_transformer = args.use_pretrained_transformer
pred_coef = args.PREDICTABILITY_COEF

wandb.init(project="ope_discrete", entity="", config=vars(args), dir="./wandb")
# File-name prefix of the artefacts produced by the matching training run.
save_experiment_name = f"{experiment_name}_{pred_coef}pc_{use_pretrained_transformer}usepretrained_{pred_lr}predlr"

print(f"Loading A2C parameters from experiment: {save_experiment_name}")

# NOTE(review): pickle.load executes code embedded in the file; only point
# --save_dir at trusted, locally produced artefacts.
with open(f'{save_dir}/{save_experiment_name}_a2c_evarl.pkl', "rb") as f:
    a2c_params = pickle.load(f)

# Training-time config stored next to the A2C parameters.
config = a2c_params["config"]
actor_critic_params = a2c_params['actor_params']

# The leading axis of the first Dense kernel is reported as the checkpoint count.
print("Number of Actor-Critic Checkpoints", actor_critic_params["params"]["Dense_0"]["kernel"].shape[0])

# NOTE(review): a2c_ckpt_rng is not used in this section, and rng is re-seeded
# from config["SEED"] at the end of this block.
rng = jax.random.PRNGKey(config["SEED"])
rng, a2c_ckpt_rng = jax.random.split(rng)
a2c_ckpt_idx = -1 # Use the last checkpoint

env, env_params = gymnax.make(config["ENV_NAME"])
env = FlattenObservationWrapper(env)
env = LogWrapper(env)

n_actions = env.action_space().n

# Parameters of a single checkpoint (presumably indexed along the leading axis).
actor_params = extract_submodel(actor_critic_params, a2c_ckpt_idx)

# Offline data (read from config["SAVE_DIR"], not from --save_dir)
ckpt_dir = f'{config["SAVE_DIR"]}/{experiment_name}_offline_ckpt'
ckptr = oc.Checkpointer(oc.PyTreeCheckpointHandler())   # TensorStore backend
offline_data = ckptr.restore(ckpt_dir)
trajs = offline_data["trajs"]
driver_trajs = offline_data["driver_trajs"]
env_states = offline_data["env_states"]

feat_extractor = FeatExtractorDiscreteAction(activation=config["ACTIVATION"])

# Load predictor params (file keyed by save_experiment_name)
with open(f'{config["SAVE_DIR"]}/{save_experiment_name}_pred_transformer_evarl.pkl', "rb") as f:
    pred_transformer_params = pickle.load(f)
predictor_params = pred_transformer_params["predictor_params"]
predictor_params = extract_submodel(predictor_params, a2c_ckpt_idx)

# Load feature extractor (file keyed by the plain experiment_name; this
# rebinds pred_transformer_params)
with open(f'{config["SAVE_DIR"]}/{experiment_name}_pred_transformer.pkl', "rb") as f:
    pred_transformer_params = pickle.load(f)
feat_extractor_params = pred_transformer_params["feat_extractor_params"]

# NOTE(review): re-seeds rng, discarding the split above; confirm this is intended.
rng = jax.random.PRNGKey(config["SEED"])

# -----------------------------------------------------------------------------#
# Public interface                                                              #
# -----------------------------------------------------------------------------#
def build_logged_dataset(result: dict) -> Dict:
    """
    Convert and merge ``result["trajs"]`` and ``result["driver_trajs"]`` into a
    single flat, OPE-ready dataset.

    Parameters
    ----------
    result : dict
        Offline data containing the mappings
            result["trajs"]         -> arrays indexed [row, time, ...]
            result["driver_trajs"]  -> arrays indexed [row, time, rollout, ...]
        each with the fields ``obs``, ``action``, ``value``, ``reward``,
        ``log_prob`` and ``done``.  Axis 1 of the driver arrays is the driver
        time axis, which is how the rest of this script consumes them.

    Returns
    -------
    logged_dataset : dict
        Flat table with one sample per (row, step): ``state``, ``action``,
        ``reward``, ``done``, ``terminal``, ``pscore`` and ``value`` of shape
        ``(size, ·)``, plus metadata (``n_trajectories``,
        ``step_per_trajectory``, ``n_actions``, ...).  Within each row, the
        ``trajs`` segment comes first, followed by every driver rollout of that
        row, each rollout contiguous in time.  ``done`` is all zeros and
        ``terminal`` marks the last step of each segment or rollout.
        Uses the module-level ``n_actions`` and ``config``.
    """
    # 1. Pull the two datasets --------------------------------------------------
    traj_ds = result["trajs"]
    driver_ds = result["driver_trajs"]
    for key in ("obs", "action", "value", "reward", "log_prob", "done"):
        print(key, traj_ds[key].shape, driver_ds[key].shape)

    fields = ("obs", "action", "value", "reward", "log_prob")

    # 2. Prepare each dataset ----------------------------------------------------
    # In our wrapper over the original environment the done information is
    # removed, so ``done`` is all zeros and only the final step is terminal.
    traj = {k: np.asarray(traj_ds[k]) for k in fields}
    traj["done"] = np.zeros(np.shape(traj_ds["done"]), dtype=np.float32)
    traj["terminal"] = np.zeros_like(traj["done"])
    traj["terminal"][:, -1] = 1

    def _rollout_major(x):
        """[row, time, rollout, ...] -> [row, rollout * time, ...] with each rollout contiguous."""
        x = np.swapaxes(np.asarray(x), 1, 2)
        return x.reshape(x.shape[0], -1, *x.shape[3:])

    # A plain reshape of [row, time, rollout, ...] would flatten time-major and
    # interleave rollouts: the sample after (t, b) would be (t, b + 1) from a
    # different rollout, and TransitionDataset would pair unrelated s -> s'.
    driver = {k: _rollout_major(driver_ds[k]) for k in fields}
    driver_terminal = np.zeros(np.shape(driver_ds["done"]), dtype=np.float32)
    driver_terminal[:, -1, :] = 1  # last time step of every driver rollout
    driver["terminal"] = _rollout_major(driver_terminal)
    driver["done"] = np.zeros_like(driver["terminal"])

    print("Shapes")
    print("traj obs", traj["obs"].shape)
    print("driver obs", driver["obs"].shape)
    print("traj action", traj["action"].shape)
    print("driver action", driver["action"].shape)

    # 3. Concatenate along the trajectory (time) axis ---------------------------
    merged = {
        k: np.concatenate([traj[k], driver[k]], axis=1)
        for k in ("done", "value", "reward", "log_prob", "action", "obs", "terminal")
    }
    done = merged["done"]
    obs = merged["obs"]

    # 4. Final flat-table shapes -----------------------------------------------
    n_traj, T = done.shape
    n_samples = n_traj * T
    obs_dim = obs.shape[-1]

    # 5. Assemble schema --------------------------------------------------------
    logged_dataset = {
        "size":                 n_samples,
        "n_trajectories":       n_traj,
        "step_per_trajectory":  T,
        "action_type":          "discrete",
        "n_actions":            n_actions,
        "action_dim":           1,
        "action_keys":          None,
        "action_meaning":       None,
        "state_dim":            obs_dim,
        "state_keys":           None,
        "state":                obs.reshape(n_samples, obs_dim),
        "action":               merged["action"].reshape(n_samples, 1),
        "reward":               merged["reward"].reshape(n_samples, 1),
        "done":                 done.reshape(n_samples, 1).astype(np.float32),
        "terminal":             merged["terminal"].reshape(n_samples, 1).astype(np.float32),
        "info":                 None,
        # Behaviour propensities pi_b(a|s), from the logged log-probabilities.
        "pscore":               np.exp(merged["log_prob"].reshape(n_samples, 1)),
        "behavior_policy":      config["ENV_NAME"],
        "dataset_id":           0,
    }

    # Value estimates logged by the behaviour actor-critic.
    logged_dataset["value"] = merged["value"].reshape(n_samples, 1)

    return logged_dataset



# Merged trajs + driver's-test dataset; FQE below is trained on it.
logged_dataset = build_logged_dataset(offline_data)

 
# ## Methods for evaluation
# 
# 1. Predictor-based evaluation using the driver's test
# 
# Use only the trajectories from the offline dataset; for the driver's test, generate the same amount of data as in the offline driver's-test trajectories.
# 2. MC evaluation
# 3. FQE
# 4. TIS
# 5. PDIS
# 6. DR

# # FQE Evaluation
# -------------------------------------------------------------#
#                      1.  Imports                              #
# -------------------------------------------------------------#

from dataclasses import dataclass
from typing import Callable, Dict, Tuple, Optional

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset


# -------------------------------------------------------------#
#                      2.  Utilities                           #
# -------------------------------------------------------------#
def to_tensor(x, device):
    """Return ``x`` as a float32 tensor on ``device``; may alias ``x`` when already compatible."""
    return torch.as_tensor(x, device=device, dtype=torch.float32)


def one_hot(actions: torch.Tensor, n_actions: int) -> torch.Tensor:
    """Encode integer action indices as float one-hot vectors.

    ``actions`` is cast to ``long`` first, so float-typed indices are accepted.
    """
    encoded = torch.nn.functional.one_hot(actions.long(), num_classes=n_actions)
    return encoded.float()


# -------------------------------------------------------------#
#   3.  PyTorch Dataset that holds transitions for FQE         #
# -------------------------------------------------------------#
class TransitionDataset(Dataset):
    """Flat (s, a, r, s', not_done) transitions built from the OPE logged-dataset dict.

    ``data`` is read as ``n_trajectories`` equal-length rows of
    ``step_per_trajectory`` steps.  The successor of the last step of each row
    is a zero placeholder, and that step always has ``not_done == 0``.
    """

    def __init__(self, data: Dict, device: torch.device):
        self.device = device

        num_rows = int(data["n_trajectories"])
        horizon = int(data["step_per_trajectory"])
        states = data["state"].reshape(num_rows, horizon, -1)
        acts = data["action"].reshape(num_rows, horizon, -1)
        rews = data["reward"].reshape(num_rows, horizon)
        ended = np.logical_or(data["terminal"], data["done"]).reshape(num_rows, horizon)

        # Successor states: shift each row left by one step; the final slot stays zero.
        successors = np.zeros_like(states)
        successors[:, :-1] = states[:, 1:]

        # No bootstrapping past an episode end, nor past the last step of a row.
        continue_mask = (~ended).astype(np.float32)
        continue_mask[:, -1] = 0.0

        feat_dim = states.shape[-1]
        self.states = to_tensor(states.reshape(-1, feat_dim), device)
        self.actions = to_tensor(acts.reshape(-1, acts.shape[-1]), device)
        self.rewards = to_tensor(rews.reshape(-1, 1), device)
        self.next_states = to_tensor(successors.reshape(-1, feat_dim), device)
        self.not_done = to_tensor(continue_mask.reshape(-1, 1), device)

    # -- PyTorch dataset protocol ------------------------------------------
    def __len__(self):
        return len(self.states)

    def __getitem__(self, idx):
        columns = (self.states, self.actions, self.rewards, self.next_states, self.not_done)
        return tuple(col[idx] for col in columns)


# -------------------------------------------------------------#
#                    4.  Q‑network (MLP)                       #
# -------------------------------------------------------------#
class QNetwork(nn.Module):
    """MLP critic Q(s, a): ``concat(s, a)`` -> hidden ReLU layers -> scalar."""

    def __init__(self, state_dim: int, action_dim: int, hidden_sizes=(256, 256)):
        super().__init__()
        widths = [state_dim + action_dim, *hidden_sizes, 1]
        pairs = list(zip(widths[:-1], widths[1:]))
        last = len(pairs) - 1
        modules = []
        for i, (fan_in, fan_out) in enumerate(pairs):
            modules.append(nn.Linear(fan_in, fan_out))
            if i != last:  # no activation after the output layer
                modules.append(nn.ReLU())
        self.net = nn.Sequential(*modules)

    def forward(self, s: torch.Tensor, a: torch.Tensor) -> torch.Tensor:
        joint = torch.cat((s, a), dim=-1)
        return self.net(joint).squeeze(-1)


# -------------------------------------------------------------#
#                   5.  Fitted‑Q Evaluation                     #
# -------------------------------------------------------------#
@dataclass
class FQEConfig:
    """Hyper-parameters for :class:`FQE` (fields listed in positional order).

    gamma               discount factor of the TD target.
    lr                  Adam learning rate of the Q-network.
    batch_size          mini-batch size of the transition DataLoader.
    epochs              number of passes over the logged transitions.
    target_update_freq  gradient steps between hard target-network syncs.
    hidden_sizes        widths of the Q-network's hidden layers (any depth).
    device              torch device string; "cuda" when available at import.
    """

    gamma: float = 0.99
    lr: float = 3e-4
    batch_size: int = 1024
    epochs: int = 50
    target_update_freq: int = 1000  # gradient steps
    # QNetwork accepts any number of hidden layers, not just two.
    hidden_sizes: Tuple[int, ...] = (256, 256)
    device: str = "cuda" if torch.cuda.is_available() else "cpu"


class FQE:
    """
    Fitted-Q Evaluation.

    Learns Q^{pi_e}(s, a) from logged transitions by regressing onto the TD
    target ``r + gamma * Q_target(s', pi_e(s')) * not_done``.  The target
    network is hard-synced every ``cfg.target_update_freq`` gradient steps.

    Parameters
    ----------
    state_dim, action_dim :
        Input sizes of the Q-network, which consumes ``concat(s, a)``.  For
        discrete actions ``action_dim`` is 1 and the action index is fed as a
        float scalar.
    eval_policy :
        Callable mapping a batch of states (Tensor [B, state_dim]) to
        either
            * deterministic actions  – Tensor [B, action_dim]             (continuous)
            * deterministic indices  – Tensor [B] or [B, 1]               (discrete)
            * action probabilities   – float Tensor [B, n_actions]        (discrete, n_actions > 1)
    action_type : {"continuous", "discrete"}
    n_actions :
        Required only for discrete actions (used to recognise and expand
        action probabilities).
    cfg :
        Hyper-parameters; a fresh ``FQEConfig()`` is created when omitted.
    """

    def __init__(
        self,
        state_dim: int,
        action_dim: int,
        eval_policy: Callable[[torch.Tensor], torch.Tensor],
        action_type: str,
        n_actions: Optional[int] = None,
        cfg: Optional[FQEConfig] = None,
    ):
        # One config object per instance instead of a shared default instance.
        self.cfg = cfg if cfg is not None else FQEConfig()
        self.device = torch.device(self.cfg.device)

        self.eval_policy = eval_policy
        self.action_type = action_type
        self.n_actions = n_actions

        self.q = QNetwork(state_dim, action_dim, self.cfg.hidden_sizes).to(self.device)
        self.q_target = QNetwork(state_dim, action_dim, self.cfg.hidden_sizes).to(self.device)
        self.q_target.load_state_dict(self.q.state_dict())

        self.optim = optim.Adam(self.q.parameters(), lr=self.cfg.lr)
        self._grad_steps = 0  # counter for target-net updates

    # ------------------------------------------------------------------ #
    def _is_action_probs(self, out: torch.Tensor) -> bool:
        """True when ``out`` looks like discrete action probabilities [B, n_actions]."""
        return (
            self.action_type == "discrete"
            and self.n_actions is not None
            and self.n_actions > 1
            and out.dim() == 2
            and out.shape[-1] == self.n_actions
            and torch.is_floating_point(out)
        )

    def _compute_q_next(self, next_states: torch.Tensor) -> torch.Tensor:
        """Return Q_target(s', pi_e(s')) as [B], or its expectation under pi_e for probabilities."""
        with torch.no_grad():
            out = self.eval_policy(next_states)
            batch = next_states.shape[0]
            if self._is_action_probs(out):
                # E_{a ~ pi_e}[Q_target(s', a)] by enumerating every action id.
                action_ids = torch.arange(
                    self.n_actions, device=next_states.device, dtype=next_states.dtype
                )
                q_all = self.q_target(
                    next_states.repeat_interleave(self.n_actions, dim=0),
                    action_ids.repeat(batch).unsqueeze(-1),
                ).view(batch, self.n_actions)
                return (out.to(q_all.dtype) * q_all).sum(-1)
            # Indices [B] -> [B, 1]; continuous [B, action_dim] keeps its width.
            act = out.unsqueeze(-1) if out.dim() == 1 else out.reshape(batch, -1)
            # Integer indices must match the state dtype before torch.cat.
            return self.q_target(next_states, act.to(next_states.dtype))

    # ------------------------------------------------------------------ #
    def fit(self, data: Dict):
        """Fit the Q-network on the transitions in ``data`` for ``cfg.epochs`` epochs.

        Raises
        ------
        ValueError
            If ``data`` contains no transitions.
        """
        ds = TransitionDataset(data, self.device)
        if len(ds) == 0:
            raise ValueError("FQE.fit received an empty dataset.")
        # Drop the ragged last batch as before, but keep a single partial batch
        # when the dataset is smaller than batch_size (otherwise nothing trains).
        loader = DataLoader(
            ds,
            batch_size=self.cfg.batch_size,
            shuffle=True,
            drop_last=len(ds) >= self.cfg.batch_size,
        )

        for epoch in range(self.cfg.epochs):
            for s, a, r, s_next, not_done in loader:
                # TD target
                q_next = self._compute_q_next(s_next)
                y = r.squeeze(-1) + self.cfg.gamma * q_next * not_done.squeeze(-1)

                # prediction and loss
                q_pred = self.q(s, a)
                loss = (q_pred - y.detach()).pow(2).mean()

                # optimisation step
                self.optim.zero_grad()
                loss.backward()
                self.optim.step()

                # periodic hard-update of target network
                self._grad_steps += 1
                if self._grad_steps % self.cfg.target_update_freq == 0:
                    self.q_target.load_state_dict(self.q.state_dict())

            # The loader is non-empty (see above), so ``loss`` is always bound.
            print(f"[FQE] epoch {epoch+1:03d}/{self.cfg.epochs} | loss={loss.item():.4f}")

    # ------------------------------------------------------------------ #
    @torch.no_grad()
    def estimate_v(self, s0, a0) -> np.ndarray:
        r"""Return the learned \hat{Q}^{\pi_e}(s, a) for every row of ``(s0, a0)``.

        No averaging is performed: the result is a NumPy array with one value
        per input row.  Average it (e.g. over initial states) to obtain a
        \hat{V} estimate.  Inputs are moved to the model device as float32;
        1-D action indices are treated as a single action column.
        """
        s0 = torch.as_tensor(s0, dtype=torch.float32, device=self.device)
        a0 = torch.as_tensor(a0, dtype=torch.float32, device=self.device)
        if a0.dim() == s0.dim() - 1:
            a0 = a0.unsqueeze(-1)
        v0 = self.q(s0, a0)
        return v0.cpu().numpy()

# Input sizes of the FQE Q-network, taken from the merged logged dataset.
state_dim = logged_dataset["state_dim"]
action_dim = logged_dataset["action_dim"]
# PyTorch actor built from the JAX actor parameters (last checkpoint).
# NOTE(review): activation is hard-coded to "tanh", while the JAX networks in
# this script use config["ACTIVATION"]; confirm they agree for this experiment.
pytorch_actor = PytorchDiscreteActor(
    action_dim= env.action_space(None).n,
    activation="tanh",
    jax_param_dict=actor_params
).to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
# -------- evaluation policy (deterministic) ---------------
# Greedy (mode) action of the actor that generated the dataset.
def eval_pi(s_batch: torch.Tensor) -> torch.Tensor:
    """Greedy (mode) action of the PyTorch actor for each state in ``s_batch``.

    The actor is the behaviour policy of the dataset, used here as the
    evaluation policy as a sanity check.
    """
    action_distribution = pytorch_actor(s_batch)
    return action_distribution.mode

# NOTE(review): FQE discounts with gamma=0.99, while the DR call further down
# uses config["GAMMA"]; confirm these match.
fqe_config = FQEConfig(
    gamma=0.99,
    lr=1e-4,
    batch_size=256,
    epochs=10,
    target_update_freq=100,  # gradient steps
    hidden_sizes=(32, 32),
    device="cuda" if torch.cuda.is_available() else "cpu"
)

# FQE keeps a reference to the greedy eval_pi defined just above.
fqe = FQE(
    state_dim=state_dim,
    action_dim=action_dim,
    eval_policy=eval_pi,
    action_type=logged_dataset["action_type"],
    n_actions=logged_dataset["n_actions"],
    cfg=fqe_config
)

fqe.fit(logged_dataset)

# Q-values of every logged (state, action) pair, flattened in the arrays' native
# index order; states and actions are flattened the same way, so they stay paired.
fqe_traj_values = fqe.estimate_v(
    to_tensor(np.array(trajs["obs"].reshape(-1, trajs["obs"].shape[-1])), fqe.device),
    to_tensor(np.array(trajs["action"].reshape(-1, 1)), fqe.device)
)
fqe_driver_values = fqe.estimate_v(
    to_tensor(np.array(driver_trajs["obs"].reshape(-1, driver_trajs["obs"].shape[-1])), fqe.device),
    to_tensor(np.array(driver_trajs["action"].reshape(-1, 1)), fqe.device)
)

print("FQE Traj Values:", fqe_traj_values.shape)
print("FQE Driver Values:", fqe_driver_values.shape)
 
# # Doubly Robust
# Evaluation-policy log-probabilities log pi_e(a|s) of the logged actions,
# computed with the JAX actor-critic.
network = ActorCriticDiscreteAction(action_dim=env.action_space(None).n, activation=config["ACTIVATION"])
pi, value = network.apply(actor_params, trajs["obs"])
traj_log_prob_e = pi.log_prob(trajs["action"])

# driver_trajs arrays are [row, time, rollout, ...].  Each (row, rollout) pair is
# evaluated as one time sequence, giving [row * rollout, time]; invert that
# transpose so driver_log_prob_e is back in [row, time, rollout] order.  A direct
# reshape to [row, time, rollout] would scramble the time and rollout indices.
pi, value = network.apply(actor_params, driver_trajs["obs"].transpose(0, 2, 1, 3).reshape(-1, driver_trajs["obs"].shape[1], driver_trajs["obs"].shape[-1]))
driver_log_prob_e = pi.log_prob(driver_trajs["action"].transpose(0, 2, 1).reshape(-1, driver_trajs["action"].shape[1]))
driver_log_prob_e = driver_log_prob_e.reshape(
    driver_trajs["action"].shape[0],
    driver_trajs["action"].shape[2],
    driver_trajs["action"].shape[1],
).transpose(0, 2, 1)


def build_logged_dataset_wo_dt(result: dict) -> Dict:
    """
    Convert only ``result["trajs"]`` into a flat, OPE-ready dataset.

    Unlike ``build_logged_dataset``, the driver's-test rollouts are NOT
    included, so each row is exactly one offline trajectory.

    Parameters
    ----------
    result : dict
        Offline data; only ``result["trajs"]`` is read.  It is a mapping with
        the fields ``obs``, ``action``, ``value``, ``reward``, ``log_prob`` and
        ``done``, indexed [row, time, ...].

    Returns
    -------
    logged_dataset : dict
        Flat table with one sample per (row, step): ``state``, ``action``,
        ``reward``, ``done``, ``terminal``, ``pscore`` and ``value`` of shape
        ``(size, ·)``, plus metadata (``n_trajectories``,
        ``step_per_trajectory``, ``n_actions``, ...).  ``done`` is all zeros
        and ``terminal`` marks the last step of every row.
        Uses the module-level ``n_actions`` and ``config``.
    """
    traj_ds = result["trajs"]

    obs = np.asarray(traj_ds["obs"])
    action = np.asarray(traj_ds["action"])
    value = np.asarray(traj_ds["value"])
    reward = np.asarray(traj_ds["reward"])
    log_prob = np.asarray(traj_ds["log_prob"])
    # In our wrapper over the original environment the done information is
    # removed, so ``done`` is all zeros and only the final step is terminal.
    done = np.zeros(np.shape(traj_ds["done"]), dtype=np.float32)
    terminal = np.zeros_like(done)
    terminal[:, -1] = 1

    # Final flat-table shapes
    n_traj, T = done.shape
    n_samples = n_traj * T
    obs_dim = obs.shape[-1]

    logged_dataset = {
        "size":                 n_samples,
        "n_trajectories":       n_traj,
        "step_per_trajectory":  T,
        "action_type":          "discrete",
        "n_actions":            n_actions,
        "action_dim":           1,
        "action_keys":          None,
        "action_meaning":       None,
        "state_dim":            obs_dim,
        "state_keys":           None,
        "state":                obs.reshape(n_samples, obs_dim),
        "action":               action.reshape(n_samples, 1),
        "reward":               reward.reshape(n_samples, 1),
        "done":                 done.reshape(n_samples, 1).astype(np.float32),
        "terminal":             terminal.reshape(n_samples, 1).astype(np.float32),
        "info":                 None,
        # Behaviour propensities pi_b(a|s), from the logged log-probabilities.
        "pscore":               np.exp(log_prob.reshape(n_samples, 1)),
        "behavior_policy":      config["ENV_NAME"],
        "dataset_id":           0,
    }

    # Value estimates logged by the behaviour actor-critic.
    logged_dataset["value"] = value.reshape(n_samples, 1)

    return logged_dataset

# Offline trajs only (no driver's-test rollouts); its rows are the trajs rows,
# the same rows that traj_log_prob_e (passed to the DR call below) is computed over.
traj_only_logged_dataset = build_logged_dataset_wo_dt(offline_data)

@torch.no_grad()
def doubly_robust_estimate(
    data: dict,
    q_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    eval_policy: Callable[[torch.Tensor], torch.distributions.Distribution],
    eval_log_prob,
    action_type: str,
    n_actions: Optional[int] = None,          # needed only for discrete policies
    gamma: float = 1.0,
    weighted: bool = False,
    mc_samples: int = 10,                     # K for MC expectation in continuous case
    device: Optional[str] = None,
    return_all: bool = False,                 # if True, return per-(traj,t) DR values
) -> Union[float, np.ndarray]:
    """
    Doubly-Robust off-policy state-value estimator.

    Runs without autograd: no gradients are needed, and ``q_fn`` may be an
    ``nn.Module`` (e.g. ``fqe.q``) whose graph would otherwise be retained.

    Parameters
    ----------
    data            : logged batch with keys
                      {state, action, reward, terminal, pscore,
                       n_trajectories, step_per_trajectory, state_dim}.
                      Not modified.  ``torch.as_tensor`` can alias these arrays
                      on CPU, so every update below is done out of place.
    q_fn            : callable  Q̂(s,a) -> shape [B,1] or [B].
    eval_policy     : callable  π_e(s)  -> torch.distributions.Distribution
                      (discrete: must expose ``.probs``; continuous: ``.rsample``).
    eval_log_prob   : array-like of log π_e(a_t|s_t) for the logged actions,
                      reshapeable to (N, H).
    action_type     : "discrete"  |  "continuous".
    n_actions       : required for discrete to enumerate actions.
    gamma           : discount factor ∈ [0,1].
    weighted        : if True, use weighted DR (WDR); else ordinary DR.
    mc_samples      : #Monte-Carlo draws for continuous-action V̂ expectation.
    device          : torch device.  Default "cpu".
    return_all      : if True, return DR value for every (traj,t)
                      (ignored when ``weighted`` is True).

    Returns
    -------
    float           if return_all == False (or weighted == True).
    np.ndarray      shape (N,H) of per-state DR values otherwise.

    Raises
    ------
    ValueError      if ``pscore`` is missing, or ``n_actions`` is None for a
                    discrete policy.

    Notes
    -----
    Steps after the first ``terminal`` of a trajectory are masked out; the
    terminal step itself still counts.  The per-state value is
        V_DR(s_t) = V̂(s_t) + Σ_{k≥t} γ^{k-t} (Π_{j=t..k} ρ_j) δ_k,
    with δ_k = r_k + γ V̂(s_{k+1}) − Q̂(s_k, a_k).  The prefix weight Π_{j<t} ρ_j
    that is divided out is clamped to at least 0.1 for stability.
    """
    if "pscore" not in data or data["pscore"] is None:
        raise ValueError("Behaviour propensities `pscore` are required.")

    if device is None:
        device = torch.device("cpu")
    float_dtype = torch.float32

    # ─────────── reshape logged batch ──────────────────────────────────
    N   = int(data["n_trajectories"])
    H   = int(data["step_per_trajectory"])
    Sd  = int(data["state_dim"])
    A_shape = data["action"].shape[-1:]        # tuple (act_dim,) for cont - or () for scalar IDs

    states   = torch.as_tensor(data["state"],    dtype=float_dtype, device=device).reshape(N, H, Sd)
    actions  = torch.as_tensor(data["action"],   dtype=float_dtype, device=device).reshape(N, H, *A_shape)
    rewards  = torch.as_tensor(data["reward"],   dtype=float_dtype, device=device).reshape(N, H)
    terminal = torch.as_tensor(data["terminal"], dtype=torch.bool,  device=device).reshape(N, H)

    # logged behaviour log-probs  log π_b(a|s)
    log_pi_b = torch.as_tensor(data["pscore"], dtype=float_dtype, device=device) \
                 .clamp_min(1e-12).log().reshape(N, H)

    # ─────────── helper: V̂_πe(s) = E_{a~πe}[ Q̂(s,a) ] ─────────────────
    def V_hat_pi_e(s_batch: torch.Tensor) -> torch.Tensor:        # [B]
        if action_type == "discrete":
            if n_actions is None:
                raise ValueError("n_actions required for discrete V̂ expectation.")
            dist  = eval_policy(s_batch)                          # Categorical
            probs = dist.probs                                     # [B, n_actions]
            # enumerate actions 0..n_actions-1
            a_enum = torch.arange(n_actions, device=device)
            a_rep  = a_enum.unsqueeze(0).repeat(s_batch.size(0), 1)  # [B,n_actions]
            q_vals = q_fn(
                s_batch.repeat_interleave(n_actions, dim=0),      # [B*n, Sd]
                a_rep.flatten().unsqueeze(-1).float()             # scalar IDs → [B*n,1]
            ).view(-1, n_actions)                                 # [B, n_actions]
            return (probs * q_vals).sum(-1)                       # expectation over a
        # continuous
        dist = eval_policy(s_batch)                               # some torch Distribution
        a_samples = dist.rsample((mc_samples,))                   # [K,B,*A_shape]
        q_vals = q_fn(
            s_batch.repeat(mc_samples, 1),                        # [K*B,Sd]
            a_samples.view(-1, *A_shape)                          # [K*B,*A_shape]
        ).view(mc_samples, -1)                                    # [K,B]
        return q_vals.mean(0)                                     # MC expectation

    # ─────────── pre-compute V̂(s) and V̂(s') ────────────────────────────
    flat_states = states.reshape(-1, Sd)
    V_hat_all   = V_hat_pi_e(flat_states).view(N, H)              # V̂(s_t)
    V_hat_next  = torch.zeros_like(V_hat_all)
    V_hat_next[:, :-1] = V_hat_all[:, 1:]
    V_hat_next  = V_hat_next * (~terminal).to(float_dtype)        # zero after real terminals

    # ─────────── Q̂(s,a_b) on logged actions ────────────────────────────
    Q_hat_beh = q_fn(flat_states, actions.reshape(-1, *A_shape)).view(N, H)

    # ─────────── eval log-probs  log π_e(a_b | s) ───────────────────────
    log_pi_e = torch.as_tensor(np.array(eval_log_prob).reshape(N, H), dtype=float_dtype, device=device)

    # ─────────── importance ratios ──────────────────────────────────────
    log_rho   = log_pi_e - log_pi_b                               # log ρ_t
    cum_w     = torch.exp(torch.cumsum(log_rho, dim=1))           # Π_{0..t} ρ

    # ─────────── mask steps beyond first terminal per trajectory ────────
    # A step is valid while no terminal occurred at an EARLIER step; the
    # terminal step itself is kept.
    term_int        = terminal.to(torch.int64)
    prior_terminals = torch.cumsum(term_int, dim=1) - term_int
    valid_mask      = (prior_terminals == 0).to(float_dtype)
    cum_w   = cum_w * valid_mask
    rewards = rewards * valid_mask                                # out of place: may alias data

    # ─────────── TD residual and discounts ──────────────────────────────
    delta   = rewards + gamma * V_hat_next - Q_hat_beh
    gammas  = gamma ** torch.arange(H, dtype=float_dtype, device=device)  # [H]

    # ─────────── Weighted DR (WDR / SN-DR)  ─────────────────────────────
    if weighted:
        eps = 1e-2
        w_sum_t = cum_w.sum(0).clamp_min(eps)                     # [H]
        baseline = (cum_w[:, 0] * V_hat_all[:, 0]).sum() / w_sum_t[0]
        corr = (gammas * (cum_w * delta).sum(0) / w_sum_t).sum()
        return (baseline + corr).cpu().item()

    # ─────────── Ordinary (unweighted) DR  ──────────────────────────────
    if not return_all:
        dr_per_traj = V_hat_all[:, 0] + (gammas * cum_w * delta).sum(1)
        return dr_per_traj.mean().cpu().item()

    # ─────────── Per-state DR (return_all=True) ─────────────────────────
    # cum_w_prev[:,t] = Π_{j< t} ρ_j
    cum_w_prev = torch.cat([torch.ones(N, 1, dtype=float_dtype, device=device),
                            cum_w[:, :-1]], dim=1)

    # Discount matrix L[t, k] = γ^{k-t} for k ≥ t and 0 for k < t (upper-triangular).
    idx = torch.arange(H, device=device)
    lag = idx.unsqueeze(0) - idx.unsqueeze(1)                          # lag[t, k] = k - t
    L   = torch.triu(gamma ** lag.clamp_min(0).to(float_dtype))        # [H,H]

    M   = cum_w * delta                                                # [N,H]
    out = torch.einsum('tk,nk->nt', L, M)                              # Σ_{k≥t} γ^{k-t} w_k δ_k
    V_dr_all = V_hat_all + out / (cum_w_prev.clamp_min(1e-1))

    return V_dr_all.detach().cpu().numpy()                             # shape (N,H)

# Device for the DR computation; also read by eval_log_prob below.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
def eval_pi(s_batch: torch.Tensor) -> torch.distributions.Distribution:
    """Action distribution of the PyTorch actor for each state in ``s_batch``.

    This rebinds the greedy ``eval_pi`` used for FQE: ``doubly_robust_estimate``
    needs the full distribution (it reads ``.probs``), not a single action.
    The actor is the behaviour policy of the dataset, used here as the
    evaluation policy as a sanity check.
    """
    return pytorch_actor(s_batch)

def eval_log_prob(states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
    """Log-probability of ``actions`` under the PyTorch actor, computed without autograd.

    Both inputs are first moved to the module-level ``device``.  Not called in
    this section: the DR call passes precomputed log-probabilities instead.
    """
    states = states.to(device)
    actions = actions.to(device)
    with torch.no_grad():
        return pytorch_actor(states).log_prob(actions)


# Per-state DR values on the trajs-only dataset, using the FQE Q-network as Q̂
# and the distribution-returning eval_pi defined just above.  The (N, H) result
# is flattened row-major (trajectory, step) to N*H values.
# NOTE(review): FQE was trained with gamma=0.99, but DR uses config["GAMMA"].
dr_traj_values = doubly_robust_estimate(
    traj_only_logged_dataset,
    q_fn=fqe.q,
    eval_policy=eval_pi,
    eval_log_prob=np.array(traj_log_prob_e),
    action_type=traj_only_logged_dataset["action_type"],
    n_actions=traj_only_logged_dataset["n_actions"],
    gamma=config["GAMMA"],
    weighted=False,
    device=device,
    return_all=True  # Set to True if you want per-state values
)
dr_traj_values = dr_traj_values.reshape(-1)

 
# # Trajectory-IS
# Implementing Trajectory-wise IS using JAX
@jax.jit
def compute_discounted_episodic_return(episodic_rewards, gamma=0.99):
    """Discounted reward-to-go for a reward sequence.

    Args:
        episodic_rewards: rewards r_0..r_{T-1}, scanned along axis 0 (a 1-D
            sequence in this file's vmapped calls).
        gamma: discount factor.

    Returns:
        ``(G_0, G)`` where ``G[t] = sum_{k>=t} gamma**(k-t) * r_k`` and ``G_0 == G[0]``.
    """

    def accumulate(future_return, r_t):
        # Scanned back-to-front, so `future_return` is G_{t+1}.
        g_t = future_return * gamma + r_t
        return g_t, g_t

    first_return, per_step_returns = jax.lax.scan(
        accumulate, 0.0, episodic_rewards, reverse=True
    )
    return first_return, per_step_returns

# Discounted reward-to-go G_t for every logged trajectory (per-step returns,
# not just G_0). Each row of offline_data["trajs"]["reward"] is one trajectory.
_, traj_returns = jax.vmap(compute_discounted_episodic_return, in_axes=(0, None))(
    offline_data["trajs"]["reward"], config["GAMMA"]
)

# Driver arrays are laid out (A, T, C) with time on axis 1 (the axis matched to
# driver_traj_len below); the meaning of A and C is not visible here. To scan over
# time we move it last and flatten to (A*C, T). The result must be un-flattened
# as (A, C, T) and transposed back: reshaping (A*C, T) directly to (A, T, C)
# would interleave time steps and driver trajectories whenever C > 1.
_driver_reward = offline_data["driver_trajs"]["reward"]
_driver_n_outer, _, _driver_n_inner = _driver_reward.shape
_, driver_returns = jax.vmap(compute_discounted_episodic_return, in_axes=(0, None))(
    _driver_reward.transpose(0, 2, 1).reshape(-1, config["driver_traj_len"]), config["GAMMA"]
)
driver_returns = driver_returns.reshape(_driver_n_outer, _driver_n_inner, -1).transpose(0, 2, 1)

# Log probabilities of behavior policy
traj_log_prob_b = offline_data["trajs"]["log_prob"]
driver_log_prob_b = offline_data["driver_trajs"]["log_prob"]  # assumed (A, T, C) like the rewards

# Log probabilities of evaluation policy
network = ActorCriticDiscreteAction(action_dim=env.action_space(None).n, activation=config["ACTIVATION"])
pi, value = network.apply(actor_params, trajs["obs"])
traj_log_prob_e = pi.log_prob(trajs["action"])

# Same (A, T, C, ...) -> (A*C, T, ...) flattening for driver observations/actions,
# then undo it exactly so driver_log_prob_e lines up with driver_log_prob_b.
_driver_obs = driver_trajs["obs"]
_driver_act = driver_trajs["action"]
pi, value = network.apply(
    actor_params,
    _driver_obs.transpose(0, 2, 1, 3).reshape(-1, _driver_obs.shape[1], _driver_obs.shape[-1]),
)
driver_log_prob_e = pi.log_prob(_driver_act.transpose(0, 2, 1).reshape(-1, _driver_act.shape[1]))
driver_log_prob_e = driver_log_prob_e.reshape(
    _driver_act.shape[0], _driver_act.shape[2], -1
).transpose(0, 2, 1)

# Per-step log importance ratios: log rho_t = log pi_e(a_t|s_t) - log pi_b(a_t|s_t).
traj_log_w_per_step = traj_log_prob_e - traj_log_prob_b
# Reverse cumulative sum over time: log prod_{k>=t} rho_k, the trajectory-wise
# weight for an estimate that starts at step t.
traj_log_w_stepwise = jnp.cumsum(traj_log_w_per_step[:, ::-1], axis=1)[:, ::-1]

driver_log_w_per_step = driver_log_prob_e - driver_log_prob_b
driver_log_w_stepwise = jnp.cumsum(driver_log_w_per_step[:, ::-1, :], axis=1)[:, ::-1, :]

# Trajectory-wise IS value at every step (W_t * G_t), flattened to one value per state.
traj_w_stepwise = jnp.exp(traj_log_w_stepwise)
tis_traj_values = (traj_w_stepwise * traj_returns).reshape(-1)

driver_w_stepsize = jnp.exp(driver_log_w_stepwise)
tis_driver_values = (driver_w_stepsize * driver_returns).reshape(-1)

# # Per-Decision IS
@jax.jit
def compute_pdis_return(episodic_rewards, per_decision_logws, gamma=0.99):
    """Importance-weighted discounted reward-to-go for a reward sequence.

    Each reward ``r_k`` is scaled by ``exp(per_decision_logws[k])`` before being
    discounted back to earlier steps.

    Args:
        episodic_rewards: rewards r_0..r_{T-1}, scanned along axis 0.
        per_decision_logws: log weights aligned with ``episodic_rewards``.
        gamma: discount factor.

    Returns:
        ``R`` with ``R[t] = sum_{k>=t} gamma**(k-t) * exp(per_decision_logws[k]) * r_k``.
    """

    def accumulate(future_value, step):
        r_k, logw_k = step
        value_k = future_value * gamma + r_k * jnp.exp(logw_k)
        return value_k, value_k

    _, per_step_values = jax.lax.scan(
        accumulate, 0.0, (episodic_rewards, per_decision_logws), reverse=True
    )
    return per_step_values

# Cumulative log ratios from the start of each trajectory: cum[t] = sum_{j<=t} log rho_j.
traj_per_decision_log_w = jnp.cumsum(traj_log_w_per_step, axis=1)
driver_per_decision_log_w = jnp.cumsum(driver_log_w_per_step, axis=1)

# compute_pdis_return weights reward r_k by exp(cum[k]) = prod_{j<=k} rho_j, i.e.
# ratios counted from step 0. A per-state PDIS estimate starting at step t needs
# prod_{t<=j<=k} rho_j, so the prefix prod_{j<t} rho_j = exp(cum[t] - log rho_t) is
# divided out below (the per-state DR estimator divides by cum_w_prev for the same
# reason). At t = 0 the prefix is 1, so first-state values are unchanged.
traj_pdis_prefix_log_w = traj_per_decision_log_w - traj_log_w_per_step
driver_pdis_prefix_log_w = driver_per_decision_log_w - driver_log_w_per_step

pdis_traj_values = jax.vmap(compute_pdis_return, in_axes=(0, 0, None))(
    trajs["reward"], traj_per_decision_log_w, config["GAMMA"]
) * jnp.exp(-traj_pdis_prefix_log_w)

# Driver arrays are (A, T, C) with time on axis 1: flatten to (A*C, T) rows for the
# scan, then un-flatten as (A, C, T) and transpose back. Reshaping the (A*C, T)
# result directly to (A, T, C) would interleave time steps and driver trajectories.
_pdis_n_outer, _, _pdis_n_inner = driver_trajs["reward"].shape
pdis_driver_values = jax.vmap(compute_pdis_return, in_axes=(0, 0, None))(
    driver_trajs["reward"].transpose(0, 2, 1).reshape(-1, config["driver_traj_len"]),
    driver_per_decision_log_w.transpose(0, 2, 1).reshape(-1, config["driver_traj_len"]),
    config["GAMMA"],
)
pdis_driver_values = (
    pdis_driver_values.reshape(_pdis_n_outer, _pdis_n_inner, -1).transpose(0, 2, 1)
    * jnp.exp(-driver_pdis_prefix_log_w)
)

pdis_traj_values = pdis_traj_values.reshape(-1)
pdis_driver_values = pdis_driver_values.reshape(-1)


# Plot the tis, pdis and dr values
import matplotlib.pyplot as plt 

# # MC Evaluation

# Draw `num_eval_samples` distinct flat indices from 0, L, 2L, ... (L =
# config["general_traj_len"]) into trajs["done"]. Assuming trajs["done"] is
# (num_trajs, general_traj_len), each index is the first step of one trajectory.
selected_indices = np.random.choice(
    np.arange(0, np.prod(trajs["done"].shape), config["general_traj_len"]),
    size=num_eval_samples,
    replace=False,
)
print(selected_indices)

# Flat index -> (trajectory, step) coordinates, then gather those observations.
selected_indices_x, selected_indices_y = np.unravel_index(selected_indices, trajs["done"].shape)
eval_obs = trajs["obs"][selected_indices_x, selected_indices_y, :]
eval_obs = eval_obs.reshape(-1, trajs["obs"].shape[-1])

# One freshly reset env state per evaluation sample; the inner env_state is
# replaced afterwards via replace_inner_env_state.
rng, _rng = jax.random.split(rng)
reset_rng = jax.random.split(_rng, num_eval_samples)
_, eval_env_states = jax.vmap(env.reset, in_axes=(0, None))(reset_rng, env_params)


# ---------- helper functions for env-state replacement --------------------- #
def _flatten_first_two(x):
    if x.ndim <= 1:
        return x
    t, b, *rest = x.shape
    return x.reshape(t * b, *rest)


def _select_rows(x, idx):
    """Flatten the two leading axes of *x*, then gather the rows at *idx*."""
    flat = _flatten_first_two(x)
    return flat[idx]


def _match_shape(template, x):
    return x if x.shape == template.shape else jnp.reshape(x, template.shape)


def replace_inner_env_state(driver_env_states, env_states, idx):
    """Return *driver_env_states* with its ``env_state`` swapped for selected rows.

    Every leaf of ``env_states.env_state`` has its two leading axes flattened and
    rows *idx* gathered; each result is then reshaped to the matching leaf of
    ``driver_env_states.env_state``. The outer wrapper is copied via
    ``dataclasses.replace`` when it is a dataclass, otherwise via ``._replace``.
    """
    gathered = jax.tree_util.tree_map(
        lambda leaf: _select_rows(leaf, idx), env_states.env_state
    )
    reshaped = jax.tree_util.tree_map(_match_shape, driver_env_states.env_state, gathered)

    if not dataclasses.is_dataclass(driver_env_states):
        return driver_env_states._replace(env_state=reshaped)
    return dataclasses.replace(driver_env_states, env_state=reshaped)


# ---------- build driver_env_states --------------------------------------- #
import flax.serialization as ser
# Restore `env_states` into the pytree structure of `eval_env_states`, then
# overwrite the inner env state with the rows picked by `selected_indices`.
env_states_recasted = ser.from_state_dict(eval_env_states, env_states)
new_eval_env_states = replace_inner_env_state(
    eval_env_states,
    env_states_recasted,
    selected_indices,
)

@partial(jax.jit, static_argnames=("traj_len",))
def collect_trajectory(params, env_state, obsv, rng, traj_len):
    """Roll out ``traj_len`` steps of the batched env under ``network.apply(params, .)``.

    Args:
        params: policy parameters passed to ``network.apply``.
        env_state: batched env state, one entry per row of *obsv*.
        obsv: batched observations; ``obsv.shape[0]`` is the number of envs.
        rng: PRNG key.
        traj_len: number of steps (static, so it is baked into the compiled scan).

    Returns:
        ``(traj_batch, env_states)``: a ``Transition`` whose leaves are stacked
        along a new leading time axis of length ``traj_len``, and the env state
        after each step stacked the same way.
    """

    def step(carry, _unused):
        policy_params, state, obs, key = carry
        n_envs = obs.shape[0]

        # Sample actions from the policy.
        key, act_key = jax.random.split(key)
        dist, v = network.apply(policy_params, obs)
        act = dist.sample(seed=act_key)
        logp = dist.log_prob(act)

        # Advance every env with its own key.
        key, env_key = jax.random.split(key)
        env_keys = jax.random.split(env_key, n_envs)
        next_obs, next_state, rew, done, info = jax.vmap(
            env.step, in_axes=(0, 0, 0, None)
        )(env_keys, state, act, env_params)

        record = Transition(
            done=done,
            action=act,
            value=v,
            reward=rew,
            log_prob=logp,
            obs=obs,
            info=info,
        )
        return (policy_params, next_state, next_obs, key), (record, next_state)

    initial_carry = (params, env_state, obsv, rng)
    _final_carry, (traj_batch, env_states) = jax.lax.scan(
        step, initial_carry, xs=None, length=traj_len
    )
    return traj_batch, env_states

# Monte-Carlo ground truth: `num_trajs_per_state` independent rollouts from each
# selected start state. Advance `rng` first so these keys are not reused by the
# later `jax.random.split(rng)` calls.
rng, _mc_rng = jax.random.split(rng)
rng_traj_runs = jax.random.split(_mc_rng, num_trajs_per_state)
# NOTE: this rebinds `trajs` (previously the offline dataset) to the rollout batch.
trajs, _ = jax.vmap(collect_trajectory, in_axes=(None, None, None, 0, None))(
    actor_params,
    new_eval_env_states,
    eval_obs,
    rng_traj_runs,
    config["general_traj_len"]
)

# trajs.reward is (num_trajs_per_state, T, num_eval_samples): move samples first,
# average rewards over rollouts per step, then discount with the same gamma as
# the other estimators (the function's default 0.99 would silently differ).
mc_eval_returns, _ = jax.vmap(compute_discounted_episodic_return, in_axes=(0, None))(
    trajs.reward.transpose(2, 1, 0).mean(axis=-1), config["GAMMA"]
)

# Predictor-based Evaluation
# NOTE(review): pickle.load can execute arbitrary code; only load artifacts
# produced by this pipeline.
with open(f'{config["SAVE_DIR"]}/{experiment_name}_predtran_train_data.pkl', "rb") as f:
    predtran_train_data = pickle.load(f)

driver_states = predtran_train_data["driver_states"]
driver_env_states = predtran_train_data["driver_env_states"]

predictor = PredictabilityHead(config["num_heads"],
                               config["hidden_dim"],
                               config["num_layers"],)

rng, driver_rng = jax.random.split(rng)

# Fresh rollouts from the driver states. NOTE: this rebinds `driver_trajs` and
# `driver_returns`, replacing the offline driver data used by the IS estimators.
driver_trajs, _ = collect_trajectory(actor_params,
                                    driver_env_states,
                                    driver_states,
                                    driver_rng,
                                    config["driver_traj_len"]
                                    )

# driver_trajs.reward is stacked (T, num_driver_states) by the scan; transpose so
# each row is one driver trajectory and keep G_0 per driver state.
driver_returns, _ = jax.vmap(
    compute_discounted_episodic_return, in_axes=(0, None)
)(driver_trajs.reward.transpose((1, 0)).reshape(-1, config['driver_traj_len']), config["GAMMA"])

feature_extractor = FeatExtractorDiscreteAction(activation=config["ACTIVATION"])
feature_extractor_params = pred_transformer_params["feat_extractor_params"]

driver_states_processed = feature_extractor.apply(feature_extractor_params, driver_states)

query_states = eval_obs
query_states_processed = feature_extractor.apply(feature_extractor_params, query_states)
predictor_value_estimates = predictor.apply(
    predictor_params,
    driver_states_processed,
    driver_returns,
    query_states_processed,
)

# Collect results: every estimator evaluated at the same selected start states.
fqe_eval_values = np.array(fqe_traj_values[selected_indices])
tis_eval_values = np.array(tis_traj_values[selected_indices])
pdis_eval_values = np.array(pdis_traj_values[selected_indices])
dr_eval_values = np.array(dr_traj_values[selected_indices])
mc_eval_values = np.array(mc_eval_returns)
predictor_eval_values = np.array(predictor_value_estimates)

# Save eval values
eval_values = {
    "fqe": fqe_eval_values,
    "tis": tis_eval_values,
    "pdis": pdis_eval_values,
    "dr": dr_eval_values,
    "mc": mc_eval_values,
    "predictor": predictor_eval_values
}

# Log all histograms in a single wandb.log call: each call advances the W&B step
# by default, so separate calls would put every estimator on a different step.
wandb.log({f"{key}_eval_values": wandb.Histogram(values) for key, values in eval_values.items()})

with open(f'{config["SAVE_DIR"]}/{save_experiment_name}_eval_values.pkl', "wb") as f:
    pickle.dump(eval_values, f)

# Plot the results
def plot_eval_values(eval_values, title, bins="auto"):
    """Show a histogram of a collection of evaluation values.

    Args:
        eval_values: array-like of scalar values (e.g. per-state absolute errors);
            it is flattened so the whole array forms one distribution.
        title: figure title.
        bins: forwarded to ``plt.hist``; ``"auto"`` lets NumPy choose the bin edges.
    """
    plt.figure(figsize=(12, 6))
    # The axis labels ("Values" / "Frequency") describe a histogram, so draw one
    # instead of a line plot of the raw sequence.
    plt.hist(np.ravel(eval_values), bins=bins, alpha=0.7)
    plt.title(title)
    plt.xlabel('Values')
    plt.ylabel('Frequency')
    # No labelled artists are drawn, so plt.legend() would only emit a warning.
    plt.grid()
    plt.show()

# Mean absolute error of each estimator against the Monte-Carlo returns
# ("mc" is compared against itself and should report 0, a sanity check).
for key, values in eval_values.items():
    # Flatten both sides so an (N, 1) estimate cannot silently broadcast against
    # the (N,) MC returns into an (N, N) difference matrix.
    mae = np.abs(np.ravel(values) - np.ravel(mc_eval_values)).mean()
    print(f"{key} MAE:", mae)