# Standard library imports
from __future__ import annotations
import os
import argparse
import time
import pickle
import gc
from functools import partial
from dataclasses import dataclass
from collections import namedtuple
from typing import Any, NamedTuple, Dict, Sequence, Callable, Tuple, Optional, Union

# Environment variables
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

# Third-party imports
import numpy as np
import jax
import jax.numpy as jnp
import optax
import flax.linen as nn
from flax.linen.initializers import constant, orthogonal
from flax.training.train_state import TrainState
import flax.serialization as ser
import orbax.checkpoint as oc
import distrax
import gymnax
import gymnasium as gym
import matplotlib.pyplot as plt
import wandb
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

# Local imports
from wrappers import (
    LogWrapper, FlattenObservationWrapper, LogEnvState,
    BraxGymnaxWrapper, VecEnv, NormalizeVecObservation,
    NormalizeVecReward, ClipAction
)
from a2c_continuous import make_train
from models import ActorCriticDiscreteAction, ActorCriticContinuousAction, PredictabilityHead
from models import FeatExtractorDiscreteAction, PyTorchContinuousActor
from utils import Transition, load_feat_extractor_params, extract_submodel


# One logged rollout step (or a batch of them); fields are positional in this order.
Transition = NamedTuple(
    "Transition",
    [
        ("done", jnp.ndarray),
        ("action", jnp.ndarray),
        ("value", jnp.ndarray),
        ("reward", jnp.ndarray),
        ("log_prob", jnp.ndarray),
        ("obs", jnp.ndarray),
        ("info", Any),
    ],
)


def get_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the off-policy evaluation script.

    The options select which saved experiment to load (``--save_dir``,
    ``--experiment_name``), how many evaluation samples / rollouts to use, and
    the predictability settings that the script later encodes into file names.
    """
    cli = argparse.ArgumentParser(description="Evaluation")
    add = cli.add_argument

    add("--save_dir", type=str, default="./complete_continuous_longrun/",
        help="Directory to save results")
    add("--experiment_name", type=str,
        default="humanoidstandup_256envs_10steps_1000000.0ts_2seed_25dtr_1000gtr_5dt_humanoidstandup_128bs_0.001lr_100ep_4h_8l_128hd")
    add("--num_eval_samples", type=int, default=10)
    add("--num_trajs_per_state", type=int, default=10,
        help="Number of trajectories to sample per state during MC evaluation")
    add("--experiment_name_file", type=str,
        default="./train_scripts/complete_continuous_long_run_scripts/experiment_names.txt",
        help="File containing the experiment names to evaluate")
    add("--PREDICTABILITY_COEF", type=float, default=0.0,
        help="Predictability coefficient")
    add("--PRED_LR", type=float, default=0.0,
        help="Learning rate for the predictability transformer")
    add("--use_pretrained_transformer", type=int, default=0,
        help="Use pretrained transformer")
    return cli

# Parse CLI arguments at import time; everything below runs as a script.
parser = get_parser()
args = parser.parse_args()
save_dir = args.save_dir
experiment_name = args.experiment_name
num_eval_samples = args.num_eval_samples
num_trajs_per_state = args.num_trajs_per_state
# experiment_name_file is not used in the visible part of this file.
experiment_name_file = args.experiment_name_file
# NOTE(review): the numeric value of --PRED_LR only acts as an on/off switch here.
# The learning rate that goes into the artefact file names is hard-coded to 1e-4
# (pretrained transformer) or 1e-3 (otherwise), and to the int 0 when disabled.
# Presumably this mirrors the names written by the training script. Confirm before
# changing, because it determines which files are loaded below.
if args.PRED_LR != 0:
    pred_lr = 1e-4 if args.use_pretrained_transformer == 1 else 1e-3
else:
    pred_lr = 0
use_pretrained_transformer = args.use_pretrained_transformer
pred_coef = args.PREDICTABILITY_COEF

# NOTE(review): entity="" is passed explicitly; confirm this resolves to the intended W&B entity.
wandb.init(project="ope_continuous", entity="", config=vars(args), dir="./wandb")


# File-name stem that encodes the predictability settings. It must match the names
# the artefacts were saved under.
save_experiment_name = f"{experiment_name}_{pred_coef}pc_{use_pretrained_transformer}usepretrained_{pred_lr}predlr"

print(f"Loading A2C parameters from experiment: {save_experiment_name}")
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
with open(f'{save_dir}/{save_experiment_name}_a2c_evarl.pkl', "rb") as f:
    a2c_params = pickle.load(f)

config = a2c_params["config"]
actor_critic_params = a2c_params['actor_params']

# The leading axis of the parameter leaves is reported as the number of saved checkpoints.
print("Number of Actor-Critic Checkpoints", actor_critic_params["params"]["Dense_0"]["kernel"].shape[0])
rng = jax.random.PRNGKey(config["SEED"])
# a2c_ckpt_rng is not used in the visible part of this file.
rng, a2c_ckpt_rng = jax.random.split(rng)

a2c_ckpt_idx = -1 # Evaluate the last checkpoint


# Select one checkpoint's parameters. The slicing is done by utils.extract_submodel
# (presumably indexing the leading checkpoint axis; confirm).
actor_params = extract_submodel(actor_critic_params, a2c_ckpt_idx)

# Rebuild the environment stack: Brax -> logging -> action clipping -> vectorised,
# plus observation/reward normalisation when the training config enabled it.
env, env_params = BraxGymnaxWrapper(config["ENV_NAME"], backend=config["ENV_BACKEND"]), None
env = LogWrapper(env)
env = ClipAction(env)
env = VecEnv(env)
if config.get("NORMALIZE_ENV", False):
    env = NormalizeVecObservation(env)
    env = NormalizeVecReward(env, config["GAMMA"])

action_dim = env.action_size

# Offline data
# NOTE(review): this path uses config["SAVE_DIR"], while the A2C pickle above uses the
# --save_dir argument. Confirm both point to the same directory.
ckpt_dir = f'{config["SAVE_DIR"]}/{experiment_name}_offline_ckpt'
ckptr    = oc.Checkpointer(oc.PyTreeCheckpointHandler())   # TensorStore backend
offline_data = ckptr.restore(ckpt_dir)
trajs = offline_data["trajs"]
driver_trajs = offline_data["driver_trajs"]
env_states = offline_data["env_states"]


# NOTE(review): a *Discrete*-action feature extractor is built although this script
# evaluates a continuous-action policy. Neither feat_extractor nor
# feat_extractor_params is used in the visible part of this file.
feat_extractor = FeatExtractorDiscreteAction(activation=config["ACTIVATION"])

# Load predictor params (file name carries the predictability-settings suffix)
with open(f'{config["SAVE_DIR"]}/{save_experiment_name}_pred_transformer_evarl.pkl', "rb") as f:
    pred_transformer_params = pickle.load(f)
predictor_params = pred_transformer_params["predictor_params"]
predictor_params = extract_submodel(predictor_params, a2c_ckpt_idx)

# Load feature extractor (file name uses the bare experiment name, no suffix)
with open(f'{config["SAVE_DIR"]}/{experiment_name}_pred_transformer.pkl', "rb") as f:
    pred_transformer_params = pickle.load(f)
feat_extractor_params = pred_transformer_params["feat_extractor_params"]


# OPE dataset format
# NOTE(review): this rebinds `Transition`, shadowing both the `utils.Transition`
# import and the NamedTuple class defined above. The dataset helpers below read the
# offline data by string key (e.g. ds["done"]), not through this type.
Transition = namedtuple(
    "Transition",
    ["done", "action", "value", "reward", "log_prob", "obs", "info"],
)

def _flatten_policy_axis(x: jnp.ndarray) -> np.ndarray:
    if x.ndim ==3:
        n_pol, T, n_env = x.shape
        return np.asarray(x).transpose(0, 2, 1).reshape(-1, T)
    
    elif x.ndim == 4:               # (N, T, E, F)
        n_pol, T, n_env, feat = x.shape
        return np.asarray(x).transpose(0, 2, 1, 3).reshape(-1, T, feat)

def _prepare_one_dataset(ds: Dict[str, Any]) -> Dict[str, np.ndarray]:
    """
    Flatten the policy/environment axes of one rollout dataset.

    Parameters
    ----------
    ds : mapping indexed by the string keys "done", "value", "reward",
        "log_prob", "action" and "obs" (e.g. the dict restored from the
        checkpoint). A ``Transition`` NamedTuple would not support ``ds["done"]``.

    Returns
    -------
    dict with arrays of shape (assuming inputs are (n_pol, T, n_env[, feat])):
        • done, value, reward, log_prob : (n_traj, T)
        • action                        : (n_traj, T, act_dim)
        • obs                           : (n_traj, T, obs_dim)
        • terminal                      : (n_traj, T) float32, 1 only at the last step
    """
    keys = ("done", "value", "reward", "log_prob", "action", "obs")
    ds_data = {key: _flatten_policy_axis(ds[key]) for key in keys}
    ds_data["terminal"] = np.zeros_like(ds_data["done"], dtype=np.float32)
    # Every stored trajectory is treated as ending at its final step
    # (guarded so an empty time axis does not raise IndexError).
    # TODO: Make sure that every general_traj_len the episode is ended
    if ds_data["terminal"].shape[1] > 0:
        ds_data["terminal"][:, -1] = 1
    return ds_data

def _unsqueeze_num_driver_states(driver_traj_dataset, num_driver_states):
    for key in driver_traj_dataset:
        if driver_traj_dataset[key].ndim == 2:
            driver_traj_dataset[key] = driver_traj_dataset[key].reshape(-1, num_driver_states, *driver_traj_dataset[key].shape[1:])
            driver_traj_dataset[key] = driver_traj_dataset[key].reshape(driver_traj_dataset[key].shape[0], -1)
        elif driver_traj_dataset[key].ndim == 3:
            driver_traj_dataset[key] = driver_traj_dataset[key].reshape(-1, num_driver_states, *driver_traj_dataset[key].shape[1:])
            driver_traj_dataset[key] = driver_traj_dataset[key].reshape(driver_traj_dataset[key].shape[0], -1, driver_traj_dataset[key].shape[-1])
    return driver_traj_dataset
        

def build_logged_dataset(result: dict) -> Dict:
    """
    Convert and merge 'trajs' and 'driver_trajs' into a single OPE-ready dataset.

    Both datasets are flattened with ``_prepare_one_dataset``. The driver data is
    then regrouped with the module-level ``config["num_drivertest_states"]`` so that
    it has as many rows as the logged trajectories. Finally the two are
    concatenated along the time axis (axis 1). Each row is therefore one logged
    trajectory followed by its driver segment(s).

    Parameters
    ----------
    result : dict
        Dictionary with
            result["trajs"]         -> mapping of rollout arrays (string keys)
            result["driver_trajs"]  -> mapping of rollout arrays (string keys)

    Returns
    -------
    logged_dataset : dict
        A flat table with one row per (trajectory, step), containing "state",
        "action", "reward", "done", "terminal", "pscore" and "value". It also holds
        metadata ("size", "n_trajectories", "step_per_trajectory", "action_dim", ...).

    Raises
    ------
    ValueError
        If the two datasets do not produce the same number of rows after
        regrouping, so they cannot be concatenated along time.
    """
    # 1./2. Flatten the env axis of each dataset separately ----------------------
    traj = _prepare_one_dataset(result["trajs"])
    traj["terminal"][:, -1] = 1

    driver = _prepare_one_dataset(result["driver_trajs"])
    driver = _unsqueeze_num_driver_states(driver, config["num_drivertest_states"])
    driver["terminal"][:, -1] = 1

    if traj["done"].shape[0] != driver["done"].shape[0]:
        raise ValueError(
            "build_logged_dataset: 'trajs' yields "
            f"{traj['done'].shape[0]} trajectories but 'driver_trajs' yields "
            f"{driver['done'].shape[0]} after grouping by num_drivertest_states"
        )

    # 3. Concatenate along the time axis (axis 1) --------------------------------
    keys = ("done", "value", "reward", "log_prob", "action", "obs", "terminal")
    merged = {key: np.concatenate([traj[key], driver[key]], axis=1) for key in keys}

    # 4. Final flat-table shapes -----------------------------------------------
    n_traj, T = merged["done"].shape
    n_samples = n_traj * T
    act_dim = merged["action"].shape[-1]
    obs_dim = merged["obs"].shape[-1]

    # 5. Assemble schema --------------------------------------------------------
    logged_dataset = {
        "size":                 n_samples,
        "n_trajectories":       n_traj,
        "step_per_trajectory":  T,
        "action_type":          "continuous",
        "n_actions":            None,
        "action_dim":           act_dim,
        "action_keys":          None,
        "action_meaning":       None,
        "state_dim":            obs_dim,
        "state_keys":           None,
        "state":                merged["obs"].reshape(n_samples, obs_dim),
        "action":               merged["action"].reshape(n_samples, act_dim),
        "reward":               merged["reward"].reshape(n_samples, 1),
        "done":                 merged["done"].reshape(n_samples, 1).astype(np.float32),
        "terminal":             merged["terminal"].reshape(n_samples, 1).astype(np.float32),
        "info":                 None,
        # exp(log_prob); presumably a density value (not a probability) for continuous actions.
        "pscore":               np.exp(merged["log_prob"].reshape(n_samples, 1)),
        # NOTE(review): fixed label, independent of config["ENV_NAME"].
        "behavior_policy":      "hopper_continuous",
        "dataset_id":           0,
    }

    # Value estimates from the logged rollouts
    logged_dataset["value"] = merged["value"].reshape(n_samples, 1)

    return logged_dataset


# Merged trajs + driver_trajs flat dataset; FQE below is fit on it.
logged_dataset = build_logged_dataset(offline_data)

# -------------------------------------------------------------#
#                      2.  Utilities                           #
# -------------------------------------------------------------#
def to_tensor(x, device):
    """Return ``x`` as a float32 torch tensor placed on ``device``."""
    tensor = torch.as_tensor(x, device=device, dtype=torch.float32)
    return tensor


def one_hot(actions: torch.Tensor, n_actions: int) -> torch.Tensor:
    """Encode (possibly float-typed) action ids as float one-hot rows of width ``n_actions``."""
    indices = actions.long()
    encoded = torch.nn.functional.one_hot(indices, num_classes=n_actions)
    return encoded.float()


# -------------------------------------------------------------#
#   3.  PyTorch Dataset that holds transitions for FQE         #
# -------------------------------------------------------------#
class TransitionDataset(Dataset):
    """Flat (s, a, r, s', not_done) transitions built from the OPE dict format.

    ``data`` must provide "n_trajectories", "step_per_trajectory", "state",
    "action", "reward", "terminal" and "done". The arrays are viewed as
    (n_traj, H, ...) so that s' comes from the next step of the same trajectory.
    All stored tensors are float32 and live on ``device``.
    """

    def __init__(self, data: Dict, device: torch.device):
        self.device = device

        n_traj = int(data["n_trajectories"])
        horizon = int(data["step_per_trajectory"])

        states = data["state"].reshape(n_traj, horizon, -1)
        actions = data["action"].reshape(n_traj, horizon, -1)
        rewards = data["reward"].reshape(n_traj, horizon)
        episode_over = np.logical_or(data["terminal"], data["done"]).reshape(n_traj, horizon)

        # s_{t+1} from the same trajectory; the last step gets an all-zero placeholder.
        next_states = np.concatenate([states[:, 1:], np.zeros_like(states[:, :1])], axis=1)

        # 1.0 where bootstrapping from s_{t+1} is allowed. It is never allowed after a
        # terminal/done step, nor at the last stored step (its s_{t+1} is a placeholder).
        not_done = 1.0 - episode_over.astype(np.float32)
        not_done[:, -1] = 0.0

        state_dim = states.shape[-1]
        self.states = to_tensor(states.reshape(-1, state_dim), device)
        self.actions = to_tensor(actions.reshape(-1, actions.shape[-1]), device)
        self.rewards = to_tensor(rewards.reshape(-1, 1), device)
        self.next_states = to_tensor(next_states.reshape(-1, state_dim), device)
        self.not_done = to_tensor(not_done.reshape(-1, 1), device)

    # -- PyTorch dataset protocol ------------------------------------------
    def __len__(self):
        return self.states.shape[0]

    def __getitem__(self, idx):
        columns = (self.states, self.actions, self.rewards, self.next_states, self.not_done)
        return tuple(column[idx] for column in columns)


# -------------------------------------------------------------#
#                    4.  Q‑network (MLP)                       #
# -------------------------------------------------------------#
class QNetwork(nn.Module):
    """MLP critic Q(s, a) that maps the concatenated state and action to one scalar.

    Hidden layers are followed by ReLU and the output layer is linear.
    ``forward`` squeezes away the trailing size-1 dimension.
    """

    def __init__(self, state_dim: int, action_dim: int, hidden_sizes=(256, 256)):
        super().__init__()
        widths = [state_dim + action_dim, *hidden_sizes, 1]
        modules = []
        for position, (fan_in, fan_out) in enumerate(zip(widths[:-1], widths[1:])):
            if position > 0:
                modules.append(nn.ReLU())  # activation between consecutive linear layers
            modules.append(nn.Linear(fan_in, fan_out))
        self.net = nn.Sequential(*modules)

    def forward(self, s: torch.Tensor, a: torch.Tensor) -> torch.Tensor:
        state_action = torch.cat((s, a), dim=-1)
        return self.net(state_action).squeeze(-1)


# -------------------------------------------------------------#
#                   5.  Fitted‑Q Evaluation                     #
# -------------------------------------------------------------#
@dataclass
class FQEConfig:
    """Hyper-parameters for Fitted-Q Evaluation (fields are positional in this order)."""

    gamma: float = 0.99               # discount factor of the TD target
    lr: float = 3e-4                  # Adam learning rate for the Q-network
    batch_size: int = 1024            # mini-batch size
    epochs: int = 50                  # passes over the transition dataset
    target_update_freq: int = 1000  # gradient steps
    # Widths of the hidden layers; any number of layers is accepted by the Q-network.
    hidden_sizes: Tuple[int, ...] = (256, 256)
    # Chosen once, when this class is defined.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"


class FQE:
    """
    Fitted-Q Evaluation.

    Parameters
    ----------
    state_dim, action_dim :
        Input sizes of the Q-network (for discrete actions the action input is one-hot).
    eval_policy :
        Callable mapping a batch of states (Tensor [B, state_dim]) to
        either
            * deterministic actions  – Tensor [B, action_dim]             (continuous)
            * deterministic indices  – Tensor [B]                         (discrete)
            * action probabilities    – Tensor [B, n_actions] (OPTIONAL)  (discrete)
    action_type : {"continuous", "discrete"}
    n_actions :
        Required only for discrete actions.
    cfg :
        Hyper-parameters. A fresh ``FQEConfig()`` is created when omitted.
    """

    def __init__(
        self,
        state_dim: int,
        action_dim: int,
        eval_policy: Callable[[torch.Tensor], torch.Tensor],
        action_type: str,
        n_actions: Optional[int] = None,
        cfg: Optional["FQEConfig"] = None,
    ):
        # Build the default per instance. A `cfg=FQEConfig()` default would be one
        # object, created at definition time and shared by every FQE built without cfg.
        if cfg is None:
            cfg = FQEConfig()
        self.cfg = cfg
        self.device = torch.device(cfg.device)

        self.eval_policy = eval_policy
        self.action_type = action_type
        self.n_actions = n_actions

        # Online network (trained) and target network (hard-copied periodically).
        self.q = QNetwork(state_dim, action_dim, cfg.hidden_sizes).to(self.device)
        self.q_target = QNetwork(state_dim, action_dim, cfg.hidden_sizes).to(self.device)
        self.q_target.load_state_dict(self.q.state_dict())

        self.optim = optim.Adam(self.q.parameters(), lr=cfg.lr)
        self._grad_steps = 0  # counter for target-net updates

    # ------------------------------------------------------------------ #
    def _compute_q_next(self, next_states: torch.Tensor) -> torch.Tensor:
        """Target-network value of ``next_states`` under the evaluation policy (no grad).

        Continuous: Q_target(s', pi_e(s')).
        Discrete: sum_a p(a | s') * Q_target(s', a). Deterministic indices are first
        turned into one-hot rows, so they also go through this sum.
        """
        with torch.no_grad():
            act = self.eval_policy(next_states)
            if self.action_type == "discrete":
                # deterministic: indices -> one-hot
                if act.ndim == 1 or act.shape[-1] == 1:
                    act = one_hot(act.view(-1), self.n_actions)
                # probabilities (or one-hot rows): expectation over all actions
                if act.ndim == 2 and act.shape[-1] == self.n_actions:
                    actions_one_hot = torch.eye(self.n_actions, device=self.device)
                    # broadcast states to [B, n_actions, state_dim]
                    B = next_states.size(0)
                    s_rep = next_states.unsqueeze(1).expand(-1, self.n_actions, -1)
                    q_all = self.q_target(
                        s_rep.reshape(-1, next_states.size(-1)),
                        actions_one_hot.expand(B, -1, -1).reshape(-1, self.n_actions),
                    ).view(B, self.n_actions)
                    return (q_all * act).sum(-1)
            # continuous actions (or a discrete output of some other shape)
            return self.q_target(next_states, act)

    # ------------------------------------------------------------------ #
    def fit(self, data: Dict):
        """Regress Q onto TD targets built from ``data`` for ``cfg.epochs`` epochs.

        Each step fits Q(s, a) to r + gamma * not_done * Q_target(s', pi_e(s')).
        The target network is hard-copied from Q every ``cfg.target_update_freq``
        gradient steps, and the last mini-batch loss of each epoch is printed.

        Raises
        ------
        ValueError
            If the dataset holds fewer transitions than ``cfg.batch_size``. With
            ``drop_last=True`` no mini-batch could be formed, and the original code
            failed later with an UnboundLocalError on ``loss``.
        """
        ds = TransitionDataset(data, self.device)
        loader = DataLoader(ds, batch_size=self.cfg.batch_size, shuffle=True, drop_last=True)
        if len(loader) == 0:
            raise ValueError(
                f"FQE.fit: dataset has {len(ds)} transitions, fewer than "
                f"batch_size={self.cfg.batch_size}; no training batch can be formed."
            )

        for epoch in range(self.cfg.epochs):
            for s, a, r, s_next, not_done in loader:
                # TD target
                q_next = self._compute_q_next(s_next)
                y = r.squeeze(-1) + self.cfg.gamma * q_next * not_done.squeeze(-1)

                # prediction and loss
                q_pred = self.q(s, a)
                loss = (q_pred - y.detach()).pow(2).mean()

                # optimisation step
                self.optim.zero_grad()
                loss.backward()
                self.optim.step()

                # periodic hard-update of target network
                self._grad_steps += 1
                if self._grad_steps % self.cfg.target_update_freq == 0:
                    self.q_target.load_state_dict(self.q.state_dict())

            print(f"[FQE] epoch {epoch+1:03d}/{self.cfg.epochs} | loss={loss.item():.4f}")

    # ------------------------------------------------------------------ #
    @torch.no_grad()
    def estimate_v(self, s0, a0) -> np.ndarray:
        """Return the learned Q-values Q(s0, a0), one per row, as a NumPy array.

        To get the evaluation policy's state value V(s0), pass as ``a0`` the action
        the evaluation policy takes in ``s0``. Passing logged actions instead yields
        Q at those actions. No averaging is performed.
        """
        v0 = self.q(s0, a0)
        return v0.cpu().detach().numpy()


# `logged_dataset` was produced above by build_logged_dataset(offline_data).
# NOTE(review): `pickle` is already imported at the top of the file; this re-import is redundant.
import pickle

# NOTE(review): this rebinds `action_dim` (previously env.action_size) to the dataset's action width.
state_dim = logged_dataset["state_dim"]
action_dim = logged_dataset["action_dim"]
# Actor built from the selected JAX checkpoint parameters (models.PyTorchContinuousActor).
# NOTE(review): activation is hard-coded to "tanh", while the JAX network below uses
# config["ACTIVATION"]; confirm they match.
pytorch_actor = PyTorchContinuousActor(
    action_dim= action_dim,
    activation="tanh",
    jax_param_dict=actor_params,
    device="cuda" if torch.cuda.is_available() else "cpu",
)
# -------- evaluation policy (deterministic) ---------------
def eval_pi(s_batch: torch.Tensor) -> torch.Tensor:
    """Deterministic evaluation policy for FQE: the mode of the actor's action distribution.

    The actor is the behaviour policy that generated the dataset, so using it as
    the evaluation policy serves as a sanity check.
    """
    action_dist = pytorch_actor(s_batch)
    return action_dist.mode
print("Starting FQE")
# FQE hyper-parameters for this run.
# NOTE(review): gamma is hard-coded to 0.99 here, while the IS/DR estimators below use
# config["GAMMA"]. Confirm they agree; otherwise the estimates are not comparable.
fqe_config = FQEConfig(
    gamma=0.99,
    lr=1e-4,
    batch_size=256,
    epochs=10,
    target_update_freq=100,  # gradient steps
    hidden_sizes=(32, 32),
    device="cuda" if torch.cuda.is_available() else "cpu"
)

# eval_pi (defined just above) returns the mode of the behaviour-policy actor.
fqe = FQE(
    state_dim=state_dim,
    action_dim=action_dim,
    eval_policy=eval_pi,
    action_type=logged_dataset["action_type"],
    n_actions=logged_dataset["n_actions"],
    cfg=fqe_config
)

# Fit on the merged trajs + driver_trajs dataset.
fqe.fit(logged_dataset)



# NOTE(review): estimate_v evaluates Q at the *logged* actions, so these are Q-values
# of the logged state-action pairs, not V^{pi_e}(s) = Q(s, pi_e(s)).
# Each call is a single forward pass over all rows.
fqe_traj_values = fqe.estimate_v(
    to_tensor(np.array(trajs["obs"].reshape(-1, trajs["obs"].shape[-1])), fqe.device),
    to_tensor(np.array(trajs["action"].reshape(-1, action_dim)), fqe.device)
)
fqe_driver_values = fqe.estimate_v(
    to_tensor(np.array(driver_trajs["obs"].reshape(-1, driver_trajs["obs"].shape[-1])), fqe.device),
    to_tensor(np.array(driver_trajs["action"].reshape(-1, action_dim)), fqe.device)
)

print("FQE Traj Values:", fqe_traj_values.shape)
print("FQE Driver Values:", fqe_driver_values.shape)


# Implementing Trajectory-wise IS using JAX

@jax.jit
def compute_discounted_episodic_return(episodic_rewards, gamma=0.99):
    """Discounted reward-to-go along one episode.

    Scans the rewards backwards with G_t = gamma * G_{t+1} + r_t, where G after
    the final step is 0.

    Returns
    -------
    (G_0, G): the return from the first step, and the array of reward-to-go
    values G_t for every step (same length as ``episodic_rewards``).
    """
    def _accumulate(future_return, r_t):
        g_t = future_return * gamma + r_t
        return g_t, g_t

    first_step_return, reward_to_go = jax.lax.scan(
        _accumulate, 0.0, episodic_rewards, reverse=True
    )
    return first_step_return, reward_to_go

# Find discounted returns G_i for each trajectory
# Only the per-step reward-to-go (second output) is kept for each trajectory.
_, traj_returns = jax.vmap(compute_discounted_episodic_return, in_axes=(0, None))(offline_data["trajs"]["reward"].squeeze(-1), config["GAMMA"])

# Driver arrays have time on axis 1, which is moved last before forming rows of
# length config["driver_traj_len"]. The rows are therefore ordered (axis 0, axis 2).
# Restore that layout and transpose back. Reshaping straight to the original shape
# would interleave steps of different driver trajectories whenever axis 2 has
# more than one entry.
_driver_reward_shape = offline_data["driver_trajs"]["reward"].shape
_, driver_returns = jax.vmap(compute_discounted_episodic_return, in_axes=(0, None))(offline_data["driver_trajs"]["reward"].transpose(0, 2, 1).reshape(-1, config["driver_traj_len"]), config["GAMMA"])
driver_returns = driver_returns.reshape(_driver_reward_shape[0], _driver_reward_shape[2], _driver_reward_shape[1]).transpose(0, 2, 1)

# Log probabilities of behavior policy
traj_log_prob_b = offline_data["trajs"]["log_prob"].squeeze(-1)
driver_log_prob_b = offline_data["driver_trajs"]["log_prob"]

print("Starting T-IS")

# Log probabilities of evaluation policy
network = ActorCriticContinuousAction(action_dim=action_dim, activation=config["ACTIVATION"])
pi, value = network.apply(actor_params, trajs["obs"].transpose(0, 2, 1, 3).reshape(-1, trajs["obs"].shape[1], trajs["obs"].shape[-1]))

traj_log_prob_e = pi.log_prob(trajs["action"].squeeze(2))


pi, value = network.apply(actor_params, driver_trajs["obs"].transpose(0, 2, 1, 3).reshape(-1, driver_trajs["obs"].shape[1], driver_trajs["obs"].shape[-1]))
driver_log_prob_e = pi.log_prob(driver_trajs["action"].transpose(0, 2, 1, 3).reshape(-1, driver_trajs["action"].shape[1], driver_trajs["action"].shape[-1]))
# The rows above follow the (axis 0, axis 2) order created by the transpose. Undo it
# so that entries line up element-wise with driver_log_prob_b.
_driver_done_shape = driver_trajs["done"].shape
driver_log_prob_e = driver_log_prob_e.reshape(_driver_done_shape[0], _driver_done_shape[2], _driver_done_shape[1]).transpose(0, 2, 1)

# Reverse cumulative sum: log prod_{k >= t} rho_k (ratio from step t to the end).
traj_log_w_per_step = traj_log_prob_e - traj_log_prob_b
traj_log_w_stepwise = jnp.cumsum(traj_log_w_per_step[:, ::-1], axis=1)[:, ::-1]

driver_log_w_per_step = driver_log_prob_e - driver_log_prob_b
driver_log_w_stepwise = jnp.cumsum(driver_log_w_per_step[:, ::-1, :], axis=1)[:, ::-1, :]

# T-IS value at each step t: (prod_{k >= t} rho_k) * G_t.
traj_w_stepwise = jnp.exp(traj_log_w_stepwise)
tis_traj_values = (traj_w_stepwise * traj_returns)

driver_w_stepsize = jnp.exp(driver_log_w_stepwise)
tis_driver_values = (driver_w_stepsize * driver_returns)

tis_traj_values = tis_traj_values.reshape(-1)
tis_driver_values = tis_driver_values.reshape(-1)


@jax.jit
def compute_pdis_return(episodic_rewards, per_decision_logws, gamma=0.99):
    """Per-decision importance-weighted discounted reward-to-go along one episode.

    With w_t = exp(per_decision_logws[t]), the function scans backwards:
        G_t = gamma * G_{t+1} + w_t * r_t,   where G after the last step is 0.
    It returns the array of G_t for every step.
    """
    def _step(future, reward_and_logw):
        r_t, logw_t = reward_and_logw
        g_t = future * gamma + r_t * jnp.exp(logw_t)
        return g_t, g_t

    _, per_step_values = jax.lax.scan(
        _step, 0.0, (episodic_rewards, per_decision_logws), reverse=True
    )
    return per_step_values

# Per-decision log weights: log prod_{k <= t} rho_k, accumulated from the first step.
# NOTE(review): for t > 0 these include the ratios of steps before t, so
# pdis_*_values[:, t] matches the usual PDIS estimate of V(s_t) only at t == 0.
traj_per_decision_log_w = jnp.cumsum(traj_log_w_per_step, axis=1)
driver_per_decision_log_w = jnp.cumsum(driver_log_w_per_step, axis=1)
pdis_traj_values = jax.vmap(compute_pdis_return, in_axes=(0, 0, None))(trajs["reward"].squeeze(-1), traj_per_decision_log_w, config["GAMMA"])
# Driver rows are formed in (axis 0, axis 2) order after moving time (axis 1) last.
# Reshape to that order and transpose back, rather than reshaping directly to the
# original shape, which would mix steps across driver trajectories.
_driver_reward_shape = driver_trajs["reward"].shape
pdis_driver_values = jax.vmap(compute_pdis_return, in_axes=(0, 0, None))(driver_trajs["reward"].transpose(0, 2, 1).reshape(-1, config["driver_traj_len"]), driver_per_decision_log_w.transpose(0, 2, 1).reshape(-1, config["driver_traj_len"]), config["GAMMA"]).reshape(_driver_reward_shape[0], _driver_reward_shape[2], _driver_reward_shape[1]).transpose(0, 2, 1)
pdis_traj_values = pdis_traj_values.reshape(-1)
pdis_driver_values = pdis_driver_values.reshape(-1)

# Doubly robust
print("starting doubly robust")

# NOTE(review): log_prob_e is not used in the visible part of this file
# (the DR call below passes traj_log_prob_e directly).
log_prob_e = jnp.concatenate([traj_log_prob_e, driver_log_prob_e.transpose(0, 2, 1).reshape(driver_log_prob_e.shape[0], -1)], axis=1)
log_prob_e = log_prob_e.reshape(-1, 1)



def build_traj_only_logged_dataset(result):
    """Build an OPE-ready flat dataset from ``result["trajs"]`` alone (no driver data).

    The output has one row per (trajectory, step): "state", "action", "reward",
    "done", "terminal", "pscore" (= exp(log_prob)) and "value". It also carries
    metadata ("size", "n_trajectories", "step_per_trajectory", "action_dim", ...).
    The last step of every trajectory is marked terminal.
    """
    prepared = _prepare_one_dataset(result["trajs"])
    prepared["terminal"][:, -1] = 1

    n_traj, horizon = prepared["done"].shape
    rows = n_traj * horizon
    act_dim = prepared["action"].shape[-1]
    obs_dim = prepared["obs"].shape[-1]

    def as_column(arr):
        return arr.reshape(rows, 1)

    return {
        "size": rows,
        "n_trajectories": n_traj,
        "step_per_trajectory": horizon,
        "action_type": "continuous",
        "n_actions": None,
        "action_dim": act_dim,
        "action_keys": None,
        "action_meaning": None,
        "state_dim": obs_dim,
        "state_keys": None,
        "state": prepared["obs"].reshape(rows, obs_dim),
        "action": prepared["action"].reshape(rows, act_dim),
        "reward": as_column(prepared["reward"]),
        "done": as_column(prepared["done"]).astype(np.float32),
        "terminal": as_column(prepared["terminal"]).astype(np.float32),
        "info": None,
        "pscore": np.exp(as_column(prepared["log_prob"])),
        "behavior_policy": "hopper_continuous",
        "dataset_id": 0,
        "value": as_column(prepared["value"]),
    }



# logged_dataset = build_logged_dataset(offline_data)
# Dataset built from `trajs` only; the doubly-robust estimate below consumes it.
traj_only_logged_dataset = build_traj_only_logged_dataset(offline_data)


# NOTE(review): this re-imports a name already imported at the top of the file and
# builds a second PyTorchContinuousActor with the same arguments as the earlier one.
# The rebinding replaces `pytorch_actor` for all code below.
from models import PyTorchContinuousActor
pytorch_actor = PyTorchContinuousActor(
    action_dim= action_dim,
    activation="tanh",
    jax_param_dict=actor_params,
    device = "cuda" if torch.cuda.is_available() else "cpu"
)


# NOTE(review): these re-imports duplicate the top-of-file imports.
import torch
import numpy as np
from typing import Optional, Union, Callable

@torch.no_grad()
def doubly_robust_estimate(
    data: dict,
    q_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    eval_policy: Callable[[torch.Tensor], torch.distributions.Distribution],
    eval_log_prob,
    action_type: str,
    n_actions: Optional[int] = None,          # needed only for discrete policies
    gamma: float = 1.0,
    weighted: bool = False,
    mc_samples: int = 10,                     # K for MC expectation in continuous case
    device: Optional[str] = None,
    return_all: bool = False,                 # if True, return per-(traj,t) DR values
) -> Union[float, np.ndarray]:
    """
    Doubly-Robust off-policy state-value estimator.

    Parameters
    ----------
    data            : logged batch with keys
                      {state, action, reward, terminal, pscore,
                       n_trajectories, step_per_trajectory, state_dim};
                      rows are ordered (trajectory, step). ``data`` is not modified.
    q_fn            : callable  Q̂(s,a) -> shape [B,1] or [B].
    eval_policy     : callable  π_e(s)  -> torch.distributions.Distribution,
                      used only for V̂(s) = E_{a~π_e}[Q̂(s,a)].
    eval_log_prob   : array-like with N*H entries: log π_e(a_t|s_t) of the logged
                      actions, in the same (trajectory, step) order as ``data``.
    action_type     : "discrete"  |  "continuous".
    n_actions       : required for discrete to enumerate actions.
    gamma           : discount factor ∈ [0,1].
    weighted        : if True, use weighted DR (WDR); else ordinary DR.
    mc_samples      : #Monte-Carlo draws for continuous-action V̂ expectation.
    device          : torch device.  Default "cpu".
    return_all      : if True, return DR value for every (traj,t).

    Returns
    -------
    float           if return_all == False, or if weighted == True.
    np.ndarray      shape (N,H) of per-state DR values otherwise.

    Raises
    ------
    ValueError      if ``pscore`` is missing, or for discrete actions without ``n_actions``.

    Notes
    -----
    Runs under ``torch.no_grad()``. Steps after the first terminal of a trajectory
    are masked out; the terminal step itself is kept.
    """
    if "pscore" not in data or data["pscore"] is None:
        raise ValueError("Behaviour propensities `pscore` are required.")

    if device is None:
        device = torch.device("cpu")
    float_dtype = torch.float32

    # ─────────── reshape logged batch ──────────────────────────────────
    N = int(data["n_trajectories"])
    H = int(data["step_per_trajectory"])
    Sd = int(data["state_dim"])
    A_shape = data["action"].shape[-1:]        # (act_dim,) for a 2-D action table

    # torch.as_tensor may share memory with the caller's NumPy arrays, so nothing
    # below modifies these tensors in place.
    states = torch.as_tensor(data["state"], dtype=float_dtype, device=device).reshape(N, H, Sd)
    actions = torch.as_tensor(data["action"], dtype=float_dtype, device=device).reshape(N, H, *A_shape)
    rewards = torch.as_tensor(data["reward"], dtype=float_dtype, device=device).reshape(N, H)
    terminal = torch.as_tensor(data["terminal"], dtype=torch.bool, device=device).reshape(N, H)

    # logged behaviour log-probs  log π_b(a|s); the clamp guards against pscore == 0
    log_pi_b = (
        torch.as_tensor(data["pscore"], dtype=float_dtype, device=device)
        .clamp_min(1e-12)
        .log()
        .reshape(N, H)
    )

    # ─────────── helper: V̂_πe(s) = E_{a~πe}[ Q̂(s,a) ] ─────────────────
    def V_hat_pi_e(s_batch: torch.Tensor) -> torch.Tensor:        # [B]
        if action_type == "discrete":
            if n_actions is None:
                raise ValueError("n_actions required for discrete V̂ expectation.")
            dist = eval_policy(s_batch)                           # Categorical
            probs = dist.probs                                     # [B, n_actions]
            # enumerate actions 0..n_actions-1
            a_enum = torch.arange(n_actions, device=device)
            a_rep = a_enum.unsqueeze(0).repeat(s_batch.size(0), 1)  # [B,n_actions]
            q_vals = q_fn(
                s_batch.repeat_interleave(n_actions, dim=0),      # [B*n, Sd]
                a_rep.flatten().unsqueeze(-1).float()             # scalar IDs → [B*n,1]
            ).view(-1, n_actions)                                 # [B, n_actions]
            return (probs * q_vals).sum(-1)                       # expectation over a
        # continuous: Monte-Carlo expectation with mc_samples draws per state
        dist = eval_policy(s_batch)
        a_samples = dist.rsample((mc_samples,))                   # [K,B,*A_shape]
        q_vals = q_fn(
            s_batch.repeat(mc_samples, 1),                        # [K*B,Sd]
            a_samples.view(-1, *A_shape)                          # [K*B,*A_shape]
        ).view(mc_samples, -1)                                    # [K,B]
        return q_vals.mean(0)

    # ─────────── pre-compute V̂(s) and V̂(s') ────────────────────────────
    flat_states = states.reshape(-1, Sd)
    V_hat_all = V_hat_pi_e(flat_states).view(N, H)                # V̂(s_t)
    V_hat_next = torch.zeros_like(V_hat_all)
    V_hat_next[:, :-1] = V_hat_all[:, 1:]
    V_hat_next = V_hat_next * (~terminal).float()                 # zero after real terminals

    # ─────────── Q̂(s,a_b) on logged actions ────────────────────────────
    Q_hat_beh = q_fn(flat_states, actions.reshape(-1, *A_shape)).view(N, H)

    # ─────────── eval log-probs  log π_e(a_b | s) ───────────────────────
    log_pi_e = torch.as_tensor(
        np.asarray(eval_log_prob).reshape(N, H), dtype=float_dtype, device=device
    )

    # ─────────── importance ratios ──────────────────────────────────────
    log_rho = log_pi_e - log_pi_b                                 # log ρ_t
    log_cum_w = torch.cumsum(log_rho, dim=1)                      # log Π_{0..t} ρ
    cum_w = torch.exp(log_cum_w)

    # ─────────── mask steps after the first terminal per trajectory ────
    # A step is valid when no terminal occurred strictly before it, so the terminal
    # step itself still counts. Testing `done_cum <= 1` would re-admit the steps
    # between the first and second terminal.
    terminal_int = terminal.int()
    terminals_before = torch.cumsum(terminal_int, dim=1) - terminal_int
    valid_mask = (terminals_before == 0).to(float_dtype)
    cum_w = cum_w * valid_mask
    rewards = rewards * valid_mask        # out of place: `rewards` may alias data["reward"]

    # ─────────── TD residual and discounts ──────────────────────────────
    delta = rewards + gamma * V_hat_next - Q_hat_beh
    gammas = gamma ** torch.arange(H, dtype=float_dtype, device=device)  # [H]

    # ─────────── Weighted DR (WDR / SN-DR)  ─────────────────────────────
    if weighted:
        eps = 1e-2
        w_sum_t = cum_w.sum(0).clamp_min(eps)                     # [H]
        baseline = (cum_w[:, 0] * V_hat_all[:, 0]).sum() / w_sum_t[0]
        corr = (gammas * (cum_w * delta).sum(0) / w_sum_t).sum()
        return (baseline + corr).cpu().item()

    # ─────────── Ordinary (unweighted) DR  ──────────────────────────────
    if not return_all:
        dr_per_traj = V_hat_all[:, 0] + (gammas * cum_w * delta).sum(1)
        return dr_per_traj.mean().cpu().item()

    # ─────────── Per-state DR (return_all=True) ─────────────────────────
    # V̂_DR(s_t) = V̂(s_t) + Σ_{k≥t} γ^{k-t} (Π_{j=t..k} ρ_j) δ_k
    #           = V̂(s_t) + [Σ_{k≥t} γ^{k-t} cum_w_k δ_k] / cum_w_prev_t
    # cum_w_prev[:,t] = Π_{j< t} ρ_j
    cum_w_prev = torch.cat([torch.ones(N, 1, dtype=float_dtype, device=device),
                            cum_w[:, :-1]], dim=1)

    # Upper-triangular discount matrix: L[t,k] = γ^{k-t} for k ≥ t and 0 for k < t,
    # so only current and future steps contribute to the value at t.
    idx = torch.arange(H, dtype=float_dtype, device=device)
    lag = idx.unsqueeze(0) - idx.unsqueeze(1)                        # lag[t,k] = k - t
    L = torch.triu(gamma ** lag.clamp_min(0))                        # [H,H]

    M = cum_w * delta                                                # [N,H]
    out = torch.einsum('tk,nk->nt', L, M)                            # discounted future sums
    # The clamp bounds the division when the prefix weight vanishes (e.g. masked steps).
    V_dr_all = V_hat_all + out / (cum_w_prev.clamp_min(1e-1))

    return V_dr_all.detach().cpu().numpy()                           # shape (N,H)


# Torch device for the DR evaluation below: GPU when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def eval_pi(s_batch: torch.Tensor) -> torch.Tensor:
    """Evaluation policy handed to the doubly-robust estimator.

    Moves ``s_batch`` onto the module-level ``device`` and returns whatever
    ``pytorch_actor`` produces for it. As a sanity check, the behaviour policy
    that generated the dataset doubles as the evaluation policy here.

    NOTE(review): ``eval_log_prob`` calls ``.log_prob`` on the output of the
    same actor, so this presumably returns a distribution object rather than a
    plain tensor despite the annotation — confirm. Unlike ``eval_log_prob``,
    this is not wrapped in ``torch.no_grad()``.
    """
    batch_on_device = s_batch.to(device)
    return pytorch_actor(batch_on_device)

def eval_log_prob(states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
    """Log-probability of ``actions`` in ``states`` under the evaluation policy.

    Both inputs are moved to the module-level ``device``. The actor is then
    queried without autograd tracking, and ``.log_prob(actions)`` of its
    output is returned.
    """
    states_dev = states.to(device)
    actions_dev = actions.to(device)
    with torch.no_grad():
        action_dist = pytorch_actor(states_dev)
        return action_dist.log_prob(actions_dev)

# Per-state doubly-robust (DR) estimates over the logged trajectories.
# With weighted=False and return_all=True, doubly_robust_estimate returns an
# (N, H) NumPy array. It is flattened here so that the same flat
# `selected_indices` used for the FQE/TIS/PDIS arrays can index it.
dr_traj_values = doubly_robust_estimate(
    traj_only_logged_dataset,
    q_fn=fqe.q.to(device),
    eval_policy=eval_pi,
    # Precomputed log-probs are passed here, not the eval_log_prob() helper above.
    eval_log_prob=np.array(traj_log_prob_e),
    action_type=traj_only_logged_dataset["action_type"],
    n_actions=None,
    gamma=config["GAMMA"],
    weighted=False,
    device=device,
    return_all=True,
).reshape(-1)


# ─────────── Monte-Carlo (MC) evaluation ──────────────────────────────
# Pick random start states from the logged dataset, restore fresh simulators
# to those states, and roll the policy out to obtain MC value targets.
print("Starting MC")

# Draw `num_eval_samples` distinct flat indices from a grid with stride
# config["general_traj_len"]. The stride is meant to land on the first state
# of each logged trajectory.
# NOTE(review): this uses the global NumPy RNG, so results are not
# reproducible unless np.random is seeded elsewhere.
selected_indices = np.random.choice(np.arange(0, np.prod(trajs["done"].shape), config["general_traj_len"]), size=num_eval_samples, replace=False) # Select the first state of each trajectory
print(selected_indices)


# Convert the flat indices into two coordinates over every axis of
# trajs["done"] except the last.
# NOTE(review): the indices above range over np.prod of the FULL shape, but
# they are unravelled against shape[:-1]. That is only valid if the trailing
# axis has size 1; otherwise np.unravel_index raises for large indices.
# TODO confirm.
selected_indices_x, selected_indices_y = np.unravel_index(selected_indices, trajs["done"].shape[:-1])


# Observations at the selected start states, flattened to (-1, obs_dim).
eval_obs = trajs["obs"][selected_indices_x, selected_indices_y, :, :].reshape(-1, trajs["obs"].shape[-1])



# Reset `num_eval_samples` fresh environments with one key each. Only the
# env-state pytree is kept; it is later used as the structural template that
# replace_inner_env_state fills with the logged states.
rng, _rng = jax.random.split(rng)
reset_rng = jax.random.split(_rng, num_eval_samples)
_, eval_env_states = env.reset(reset_rng, None)


def replace_inner_env_state(driver_env_states: Any, env_states: Any, idx: jnp.ndarray):
    """Return ``driver_env_states`` with its batched leaves taken from ``env_states``.

    Both pytrees must share the same structure, because they are zipped with
    ``jax.tree_util.tree_map``. For every pair of leaves:

    * If the driver leaf is non-scalar and its leading dimension equals
      ``k = idx.size``, the env leaf is flattened over its leading axes to
      ``(-1, *driver_leaf.shape[1:])``. Rows ``idx`` are gathered and reshaped
      to the driver leaf's shape.
    * Otherwise the driver leaf is returned unchanged.

    Args:
        driver_env_states: Template pytree whose batched leaves have leading
            dimension ``k``.
        env_states: Pytree with the same structure holding the states to copy
            from. ``idx`` indexes its leaves after their leading axes are
            flattened.
        idx: Flat row indices, as an array or a Python list/tuple of ints.

    Returns:
        A pytree with the structure of ``driver_env_states``.

    Raises:
        ValueError: If an env leaf's trailing shape does not match the driver
            leaf's per-row shape.
        IndexError: If ``idx`` is a concrete NumPy array containing indices
            out of range for a flattened env leaf.
    """
    if isinstance(idx, (list, tuple)):
        # jnp indexing rejects plain Python sequences; normalise to an array.
        idx = np.asarray(idx)
    k = idx.size

    def tails_compatible(env_shape, rest_shape):
        # True when env_shape ends with rest_shape (always true for scalar rows).
        if not rest_shape:
            return True
        return env_shape[-len(rest_shape):] == rest_shape

    def _replace(driver_leaf, env_leaf):
        d = jnp.asarray(driver_leaf)
        e = jnp.asarray(env_leaf)
        # Only leaves batched like the driver (leading dim == k) are replaced.
        # NOTE: this is a shape heuristic. An unbatched leaf whose first
        # dimension happens to equal k is replaced as well.
        if d.ndim == 0 or d.shape[0] != k:
            return driver_leaf
        rest_shape = d.shape[1:]
        # Raise explicitly; `assert` is stripped under `python -O`.
        if not tails_compatible(e.shape, rest_shape):
            raise ValueError(
                f"Trailing shape mismatch: driver {d.shape}, env {e.shape}"
            )
        # After the check above, e's trailing axes equal rest_shape.
        # An empty rest_shape yields a flat (-1,) reshape.
        e_flat = e.reshape((-1, *rest_shape))
        n_rows = e_flat.shape[0]
        # JAX clamps out-of-bounds gather indices instead of raising, which
        # would silently copy the wrong states. Check concrete NumPy indices
        # here; traced/jnp indices are left untouched.
        if isinstance(idx, np.ndarray) and idx.size and (
            idx.max() >= n_rows or idx.min() < -n_rows
        ):
            raise IndexError(
                f"Index out of range for {n_rows} flattened env states: "
                f"min={idx.min()}, max={idx.max()}"
            )
        selected = e_flat[idx]
        return selected.reshape(d.shape)

    return jax.tree_util.tree_map(_replace, driver_env_states, env_states)


import flax.serialization as ser

# Rebuild the logged `env_states` with the exact pytree structure of the freshly
# reset `eval_env_states` via flax.serialization.from_state_dict(target, state).
# This lets the two trees be zipped leaf-by-leaf.
# NOTE(review): `env_states` is defined earlier in the file; presumably it is a
# state-dict form of the logged environment states. Confirm.
env_states_recasted = ser.from_state_dict(eval_env_states, env_states)

# Swap the selected logged states into the reset template. Every leaf whose
# leading dimension equals len(selected_indices) receives the rows at
# `selected_indices`.
new_eval_env_states = replace_inner_env_state(
    driver_env_states=eval_env_states,
    env_states=env_states_recasted,
    idx=selected_indices,
)


# Using these eval states and the observations corresponding to them, we will collect trajectories for value estimation


@partial(jax.jit, static_argnames=("traj_len",))
def collect_trajectory(params, env_state, obsv, rng, traj_len):
    """Roll the policy out for ``traj_len`` environment steps under ``jax.lax.scan``.

    Uses the module-level ``network`` (whose ``apply`` returns ``(pi, value)``),
    ``env`` and ``env_params``. At every step, one key samples the action from
    ``pi``. A second key is split into one key per environment for ``env.step``.

    Args:
        params: Parameters for ``network.apply``.
        env_state: Batched environment state to start from.
        obsv: Observations matching ``env_state``; the batch size is read from
            ``obsv.shape[0]``.
        rng: PRNG key driving both action sampling and environment stepping.
        traj_len: Number of steps (static under ``jit``).

    Returns:
        ``(traj_batch, env_states)``: a ``Transition`` and the post-step env
        states. ``scan`` stacks both along a new leading axis of length
        ``traj_len``.
    """

    def _one_step(carry, _):
        net_params, state, obs, key = carry
        n_envs = obs.shape[0]

        key, act_key = jax.random.split(key)
        pi, value = network.apply(net_params, obs)
        action = pi.sample(seed=act_key)
        log_prob = pi.log_prob(action)

        key, env_key = jax.random.split(key)
        next_obs, state, reward, done, info = env.step(
            jax.random.split(env_key, n_envs), state, action, env_params
        )

        # The record stores the observation the action was taken from.
        step_record = Transition(done, action, value, reward, log_prob, obs, info)
        return (net_params, state, next_obs, key), (step_record, state)

    _, (traj_batch, env_states) = jax.lax.scan(
        _one_step, (params, env_state, obsv, rng), xs=None, length=traj_len
    )
    return traj_batch, env_states



# ─────────── MC rollouts from the restored start states ───────────────
# Each selected start state is rolled out `num_trajs_per_state` times, with one
# independent PRNG key per rollout (vmapped over the key only).
# A dedicated key is split off `rng` first because `rng` is split again further
# down for the driver rollout. Feeding the same JAX key to two splits can yield
# overlapping sub-keys.
rng, traj_rng = jax.random.split(rng)
rng_traj_runs = jax.random.split(traj_rng, num_trajs_per_state)
# NOTE: this rebinds `trajs` (the logged dataset indexed above) to a Transition batch.
trajs, _ = jax.vmap(collect_trajectory, in_axes=(None, None, None, 0, None))(
    actor_params,
    new_eval_env_states,
    eval_obs,
    rng_traj_runs,
    config["general_traj_len"],
)


# vmap + scan stack trajs.reward as (num_trajs_per_state, traj_len, batch).
# This assumes the per-step reward is (batch,), as the driver rollout below
# also does; TODO confirm. Transposing (2, 1, 0) moves the rollout axis last,
# and rewards are averaged over rollouts before discounting.
# GAMMA is passed explicitly, exactly as for the driver returns, so the MC
# target uses the same discount as the DR estimate it is compared against.
mc_eval_returns, _ = jax.vmap(
    compute_discounted_episodic_return, in_axes=(0, None)
)(trajs.reward.transpose(2, 1, 0).mean(axis=-1), config["GAMMA"])



# ─────────── Predictor (in-context value estimate) ────────────────────
# Load the driver set saved alongside the predictability-transformer
# training data.
# NOTE(review): pickle.load executes arbitrary code from the file; only load
# files this project produced.
with open(f'{config["SAVE_DIR"]}/{experiment_name}_predtran_train_data.pkl', "rb") as f:
    predtran_train_data = pickle.load(f)

driver_states = predtran_train_data["driver_states"]
driver_env_states = predtran_train_data["driver_env_states"]


# The predictor head is built from config num_heads / hidden_dim / num_layers.
# Its parameters (`predictor_params`) are defined earlier in the file.
predictor = PredictabilityHead(config["num_heads"],
                               config["hidden_dim"],
                               config["num_layers"],)


rng, driver_rng = jax.random.split(rng)


# A single (non-vmapped) rollout of driver_traj_len steps from the driver states.
driver_trajs, _ = collect_trajectory(actor_params, 
                                    driver_env_states,
                                    driver_states,
                                    driver_rng,
                                    config["driver_traj_len"]
                                    )


# Reward is stacked as (driver_traj_len, batch) by scan. Transpose it to
# (batch, driver_traj_len), then discount each row with GAMMA to get one
# return per driver state.
driver_returns, _ = jax.vmap(
    compute_discounted_episodic_return, in_axes=(0, None)
)(driver_trajs.reward.transpose((1, 0)).reshape(-1, config['driver_traj_len']), config["GAMMA"])


# Feature extractor
# NOTE(review): a *discrete-action* feature extractor is used in what appears
# to be a continuous-action setup. Confirm it matches the extractor whose
# params are stored in pred_transformer_params["feat_extractor_params"].
feature_extractor = FeatExtractorDiscreteAction(activation=config["ACTIVATION"])


feature_extractor_params = pred_transformer_params["feat_extractor_params"]


# Embed the driver states and the MC query states (the selected eval
# observations) with the same extractor parameters.
driver_states_processed = feature_extractor.apply(feature_extractor_params, driver_states)


query_states = eval_obs
query_states_processed = feature_extractor.apply(feature_extractor_params, query_states)


# Predict values for the query states, conditioned on the driver
# (state, return) pairs.
predictor_value_estimates = predictor.apply(
    predictor_params, 
    driver_states_processed, 
    driver_returns,
    query_states_processed,
)


# Collect results
# Every estimate is read at the same flat `selected_indices` start states, so
# the arrays line up element-wise with the MC targets.
print("Collecting results..")
fqe_eval_values = np.array(fqe_traj_values[selected_indices])
tis_eval_values = np.array(tis_traj_values[selected_indices])
pdis_eval_values = np.array(pdis_traj_values[selected_indices])
dr_eval_values = np.array(dr_traj_values[selected_indices])
mc_eval_values = np.array(mc_eval_returns)
predictor_eval_values = np.array(predictor_value_estimates)
# NOTE(review): the MAE below subtracts these arrays element-wise. If any of
# them is not 1-D of length num_eval_samples (e.g. the predictor returns
# (N, 1)), NumPy broadcasting silently yields an (N, N) difference. Confirm
# the shapes.


# Save eval values
eval_values = {
    "fqe": fqe_eval_values,
    "tis": tis_eval_values,
    "pdis": pdis_eval_values,
    "dr": dr_eval_values,
    "mc": mc_eval_values,
    "predictor": predictor_eval_values
}

# Persist all estimates for offline analysis.
with open(f'{config["SAVE_DIR"]}/{save_experiment_name}_eval_values.pkl', "wb") as f:
    pickle.dump(eval_values, f)




# One wandb.log call per evaluated start state, with all estimators side by side.
for i in range(num_eval_samples):
    wandb.log({
        "fqe": fqe_eval_values[i],
        "tis": tis_eval_values[i],
        "pdis": pdis_eval_values[i],
        "dr": dr_eval_values[i],
        "mc": mc_eval_values[i],
        "predictor": predictor_eval_values[i]
    })
    
# Mean absolute error of each estimator against the MC target.
# The "mc" entry is trivially 0.
for key, values in eval_values.items():
    print("MAE: ", key, np.abs(values - mc_eval_values).mean())