
import concurrent.futures
from glob import glob
import pathlib
import datetime
import json
import h5py
import pandas as pd
import time
from collections import deque, defaultdict

import numpy as np
import torch

# original imports
import sys
from sample_factory.algo.learning.learner import Learner
from sample_factory.algo.sampling.batched_sampling import preprocess_actions
from sample_factory.algo.utils.action_distributions import argmax_actions
from sample_factory.algo.utils.env_info import extract_env_info
from sample_factory.algo.utils.make_env import make_env_func_batched
from sample_factory.algo.utils.misc import ExperimentStatus
from sample_factory.algo.utils.rl_utils import make_dones, prepare_and_normalize_obs
from sample_factory.algo.utils.tensor_utils import unsqueeze_tensor
from sample_factory.cfg.arguments import load_from_checkpoint
from sample_factory.huggingface.huggingface_utils import generate_model_card, generate_replay_video, push_to_hf
from sample_factory.model.actor_critic import create_actor_critic
from sample_factory.model.model_utils import get_rnn_size
from sample_factory.utils.attr_dict import AttrDict
from sample_factory.utils.utils import debug_log_every_n, experiment_dir, log

# dmlab examples
from sf_examples.dmlab.cumstom_enjoy import enjoy
from sf_examples.dmlab.train_hippo2025 import parse_dmlab_args, register_dmlab_components

# shared CLI args
mapname = "openfield_map2_fixed_loc3"
expname = 'Replicating_main'
cli = [
    "--algo", "APPO",
    "--env", mapname,
    "--experiment", expname,
    "--encoder_load_path", "./models/best_000025288_203030528_reward_94.185.pth",
    "--train_dir", "./train_dir",
    "--max_num_frames", "50000",
    "--num_envs", "8",
    "--dmlab_level_cache_path", "./.dmlab_cache",
    "--load_checkpoint_kind", "latest",
    "--use_jit", "False",
    "--with_pos_obs", "True",
    "--no_render",
]
cli_dict = {k.strip('-'): v for k, v in zip(cli[::2], cli[1::2])}


def _ensure_parent(path: pathlib.Path):
    path.parent.mkdir(parents=True, exist_ok=True)


def get_cfg():
    """Build the evaluation config from the module-level ``cli`` list.

    Registers the DMLab components, parses ``cli`` in evaluation mode, records
    which options were passed explicitly in ``cfg.cli_args``, and returns the
    result of ``load_from_checkpoint`` on that cfg. ``load_from_checkpoint`` is
    presumably the step that merges the saved experiment config with these
    CLI overrides.
    """
    register_dmlab_components()
    cfg = parse_dmlab_args(evaluation=True, argv=cli)
    # Take the override values from the parsed cfg (already type-converted by
    # the argument parser) instead of the raw strings in `cli_dict`. Raw strings
    # such as "False" or "50000" would otherwise replace typed config values,
    # leaving e.g. `use_jit` truthy and `max_num_frames` a str.
    cfg.cli_args = {key: getattr(cfg, key) for key in cli_dict if hasattr(cfg, key)}
    cfg = load_from_checkpoint(cfg)
    return cfg


def _register_activation_hooks(actor_critic, layer_names, act_buffers):
    """Attach forward hooks that append each named submodule's output to ``act_buffers[name]``."""
    modules = dict(actor_critic.named_modules())

    def make_hook(name):
        def _hook(_module, _inputs, output):
            # Modules returning a tuple/list (presumably the RNN core's
            # (output, state)) contribute only their first element.
            if isinstance(output, (tuple, list)):
                output = output[0]
            act_buffers[name].append(output.detach().cpu())
        return _hook

    for name in layer_names:
        modules[name].register_forward_hook(make_hook(name))


def _append_pose_records(pose_records, obs, infos, frame, num_agents):
    """Append one row per agent with the position/rotation read from ``obs`` and its info dict."""
    pos = obs['DEBUG.POS.TRANS']
    rot = obs['DEBUG.POS.ROT']
    for i in range(num_agents):
        pose_records.append({
            "frame": frame,
            "agent": i,
            "x": float(pos[i, 0]),
            "y": float(pos[i, 1]),
            "z": float(pos[i, 2]),
            "rot_x": float(rot[i, 0]),
            "rot_y": float(rot[i, 1]),
            "rot_z": float(rot[i, 2]),
            "info": json.dumps(infos[i], default=str),
        })


def _rollout(cfg, env, env_info, actor_critic, device, max_frames):
    """Step the policy in ``env`` until more than ``max_frames`` steps have run; return pose rows."""
    pose_records = []
    num_frames = 0
    obs, _infos = env.reset()
    rnn_states = torch.zeros([env.num_agents, get_rnn_size(cfg)], device=device)

    with torch.no_grad():
        while True:
            normalized_obs = prepare_and_normalize_obs(actor_critic, obs)
            outputs = actor_critic(normalized_obs, rnn_states)
            actions = outputs["actions"]
            if cfg.eval_deterministic:
                actions = argmax_actions(actor_critic.action_distribution())
            if actions.ndim == 1:
                actions = unsqueeze_tensor(actions, dim=-1)
            actions = preprocess_actions(env_info, actions)
            rnn_states = outputs["new_rnn_states"]

            obs, _rew, terminated, truncated, infos = env.step(actions)
            _append_pose_records(pose_records, obs, infos, num_frames, env.num_agents)

            # Zero the recurrent state of agents whose episode just ended so
            # hidden state (and the logged activations) does not carry across
            # episodes. Out-of-place masked_fill avoids mutating a tensor the
            # model returned. NOTE(review): assumes the env resets itself on
            # done, since this loop never calls env.reset() again — confirm.
            dones = make_dones(terminated, truncated)
            done_mask = torch.as_tensor(dones, dtype=torch.bool, device=rnn_states.device).view(-1, 1)
            rnn_states = rnn_states.masked_fill(done_mask, 0.0)

            num_frames += 1
            if num_frames > max_frames:
                break

    return pose_records


def _write_telemetry(cfg, pth_path, pose_records, act_buffers):
    """Write the pose CSV and per-layer activation HDF5 under ``<experiment_dir>/telemetry/<ckpt name>/``."""
    ts = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    telemetry = pathlib.Path(experiment_dir(cfg=cfg)) / "telemetry" / pathlib.Path(pth_path).name
    telemetry.mkdir(parents=True, exist_ok=True)

    pd.DataFrame(pose_records).to_csv(telemetry / f"pose_{ts}.csv", index=False)

    with h5py.File(telemetry / f"activations_{ts}.h5", "w") as h5f:
        for layer, chunks in act_buffers.items():
            if not chunks:
                continue
            data = torch.cat(chunks, dim=0).numpy()
            h5f.create_dataset(layer, data=data, compression="gzip")


def run_single(pth_path: str):
    """Evaluate one checkpoint, logging agent poses and selected layer activations.

    Builds a single env and actor-critic from ``get_cfg()``, loads the weights
    stored under ``"model"`` in ``pth_path``, and steps the policy for
    ``cfg.max_num_frames`` (+1) steps. It then writes ``pose_<ts>.csv`` and
    ``activations_<ts>.h5`` to ``<experiment_dir>/telemetry/<checkpoint file name>/``.

    Args:
        pth_path: path to the checkpoint file.

    Returns:
        The checkpoint's file name (``pathlib.Path(pth_path).name``).
    """
    cfg = get_cfg()
    cfg.use_jit = False
    # Evaluate at the training frameskip, so each policy action is applied once.
    cfg.eval_env_frameskip = cfg.env_frameskip
    cfg.num_envs = 1
    # Coerce once: a value overridden from the raw CLI string list may be a str,
    # and None/0 must not reach an int comparison.
    max_frames = int(cfg.max_num_frames or 0)

    layers_to_log = [
        'encoder.basic_encoder.mlp_layers.0',
        "encoder.DG_projection.linear",
        "core",
        "decoder.mlp.0",
        "decoder.mlp.2"
    ]

    env = make_env_func_batched(
        cfg,
        env_config=AttrDict(worker_index=0, vector_index=0, env_id=0),
        render_mode=None,
    )
    try:
        env_info = extract_env_info(env, cfg)
        if hasattr(env.unwrapped, "reset_on_init"):
            env.unwrapped.reset_on_init = False

        actor_critic = create_actor_critic(cfg, env.observation_space, env.action_space)
        actor_critic.eval()
        device = torch.device("cpu" if cfg.device == "cpu" else "cuda")
        actor_critic.model_to_device(device)

        act_buffers = defaultdict(list)
        _register_activation_hooks(actor_critic, layers_to_log, act_buffers)

        checkpoint = torch.load(pth_path, map_location=device)
        actor_critic.load_state_dict(checkpoint["model"])

        pose_records = _rollout(cfg, env, env_info, actor_critic, device, max_frames)
    finally:
        # Release the environment even if model setup or the rollout fails.
        env.close()

    _write_telemetry(cfg, pth_path, pose_records, act_buffers)
    return pathlib.Path(pth_path).name


def main(max_workers: int = 3):
    """Evaluate every milestone checkpoint of ``expname`` in parallel worker processes.

    Checkpoints (``*.pth``) are taken from
    ``./train_dir/<expname>/checkpoint_p0/milestones`` in file-name order, and
    each is passed to ``run_single`` in its own process. Per-checkpoint
    success/failure is printed; one failure does not stop the others.

    Args:
        max_workers: upper bound on concurrent worker processes (default 3).
    """
    root = pathlib.Path(f"./train_dir/{expname}") / "checkpoint_p0" / "milestones"
    ckpt_files = sorted(root.glob("*.pth"), key=lambda p: p.name)
    if not ckpt_files:
        # ProcessPoolExecutor rejects max_workers=0, so stop before creating it.
        print(f"[!] No checkpoints found in {root}")
        return

    workers = max(1, min(len(ckpt_files), max_workers))
    with concurrent.futures.ProcessPoolExecutor(max_workers=workers) as executor:
        futures = {executor.submit(run_single, str(p)): p for p in ckpt_files}
        for fut in concurrent.futures.as_completed(futures):
            name = futures[fut].name
            try:
                fut.result()
            except Exception as e:
                # Top-level boundary: report and continue with the remaining checkpoints.
                print(f"[✖] {name} failed: {e}")
            else:
                print(f"[✔] {name} completed.")


if __name__ == "__main__":
    main()
