import json
import os
import sys

import imageio
import numpy as np
import torch
from overcooked_ai_py.mdp.overcooked_mdp import OvercookedGridworld
from overcooked_ai_py.visualization.state_visualizer import StateVisualizer

from zsceval.config import get_config
from zsceval.envs.env_wrappers import ShareDummyVecEnv
from zsceval.envs.overcooked.Overcooked_Env import Overcooked
from zsceval.envs.overcooked_new.Overcooked_Env import Overcooked as Overcooked_new
from zsceval.overcooked_config import get_overcooked_args
from zsceval.utils.train_util import setup_seed

# Load agent model
def load_agent(model_path, device):
    """Deserialize a checkpoint with ``torch.load`` and map it onto *device*.

    Args:
        model_path: Path to the checkpoint (``str`` or ``os.PathLike``), or a
            file-like object accepted by ``torch.load``. A leading ``~`` in a
            path is expanded to the user's home directory, because
            ``torch.load`` opens the path verbatim and would otherwise fail
            with ``FileNotFoundError``.
        device: Target passed to ``torch.load`` as ``map_location``.

    Returns:
        Whatever object was stored in the checkpoint. No post-processing
        (e.g. ``eval()`` or ``load_state_dict``) is applied here.
    """
    # Only path-like inputs can be expanded; file-like objects pass through.
    if isinstance(model_path, (str, os.PathLike)):
        model_path = os.path.expanduser(model_path)
    # NOTE(security): torch.load unpickles data -- only load trusted checkpoints.
    return torch.load(model_path, map_location=device)

# Initialize the evaluation environment
def make_eval_env(all_args, run_dir):
    """Build a single-environment ``ShareDummyVecEnv`` for evaluation.

    The environment class is chosen by ``all_args.overcooked_version``
    (``"old"`` selects ``Overcooked``, anything else ``Overcooked_new``), it is
    created with ``evaluation=True`` and seeded with
    ``all_args.seed * 50000 + rank * 10000``. Only rank 0 is instantiated.
    """

    def build_env(rank):
        # Construction is deferred until the vec-env invokes the thunk.
        use_legacy = all_args.overcooked_version == "old"
        env_cls = Overcooked if use_legacy else Overcooked_new
        instance = env_cls(all_args, run_dir, evaluation=True)
        instance.seed(all_args.seed * 50000 + rank * 10000)
        return instance

    env_thunks = [lambda: build_env(0)]
    return ShareDummyVecEnv(env_thunks)

def _to_jsonable(value):
    """Recursively convert *value* into plain types that ``json.dump`` accepts."""
    if isinstance(value, np.ndarray):
        # Object arrays keep their elements after tolist(), so recurse.
        return _to_jsonable(value.tolist())
    if isinstance(value, np.generic):
        return value.item()
    if isinstance(value, torch.Tensor):
        return value.detach().cpu().tolist()
    if isinstance(value, dict):
        return {
            (k if k is None or isinstance(k, (str, int, float, bool)) else str(k)): _to_jsonable(v)
            for k, v in value.items()
        }
    if isinstance(value, (list, tuple, set)):
        return [_to_jsonable(v) for v in value]
    if value is None or isinstance(value, (bool, int, float, str)):
        return value
    # Deliberate last resort: keep a readable trace of opaque objects rather
    # than losing the whole rollout to a TypeError at dump time.
    return repr(value)


def _unpack_step(result):
    """Normalize an ``env.step`` result to ``(obs, share_obs, rewards, dones, infos)``."""
    if len(result) == 6:
        obs, share_obs, rewards, dones, infos, _available_actions = result
        return obs, share_obs, rewards, dones, infos
    if len(result) == 4:
        obs, rewards, dones, infos = result
        # No shared observation is returned; assume the per-agent observation
        # is the right policy input here -- TODO confirm for 4-tuple envs.
        return obs, obs, rewards, dones, infos
    raise ValueError(f"Unexpected env.step() result with {len(result)} elements")


def rollout_and_save(env, agents, device, max_steps=100, layout_name="coordination_ring", output_path="trajectories.json"):
    """Roll out *agents* in *env* for one episode and dump the trajectory as JSON.

    Args:
        env: Vectorized environment whose ``reset()`` returns
            ``(obs, share_obs, available_actions)`` and whose ``step(actions)``
            returns either ``(obs, share_obs, rewards, dones, infos,
            available_actions)`` or ``(obs, rewards, dones, infos)``.
        agents: Sequence of policies exposing ``act(tensor)``; agent ``i`` is
            fed ``share_obs[0, i]`` as a float32 tensor on *device*.
        device: Torch device the observation tensors are created on.
        max_steps: Upper bound on the number of environment steps.
        layout_name: Currently unused; kept so existing callers keep working.
        output_path: Destination of the JSON file.

    Raises:
        ValueError: If *agents* is empty or ``env.step`` returns a tuple of an
            unsupported length.
    """
    if not agents:
        raise ValueError("rollout_and_save requires at least one agent")

    obs, share_obs, _available_actions = env.reset()
    trajectories = []
    done = False
    step = 0

    while not done and step < max_steps:
        per_agent_actions = []
        for i, agent in enumerate(agents):
            obs_tensor = torch.tensor(share_obs[0, i], dtype=torch.float32, device=device)
            with torch.no_grad():
                per_agent_actions.append(agent.act(obs_tensor).cpu().numpy())

        # (n_agents, n_envs) -> (n_envs, n_agents)
        actions = np.array(per_agent_actions).transpose((1, 0))
        next_obs, next_share_obs, rewards, dones, infos = _unpack_step(env.step(actions))

        # Convert eagerly so numpy arrays/scalars nested in ``infos`` cannot
        # make json.dump fail after the rollout has already finished.
        trajectories.append({
            "observations": _to_jsonable(obs),
            "actions": _to_jsonable(actions),
            "rewards": _to_jsonable(rewards),
            "dones": _to_jsonable(dones),
            "infos": _to_jsonable(infos),
        })

        # Advance BOTH observations: the policies read share_obs, so leaving it
        # untouched would replay the reset observation at every step.
        obs, share_obs = next_obs, next_share_obs
        done = bool(np.asarray(dones).any())
        step += 1

    # Save trajectories to a JSON file
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(trajectories, f)
    print(f"Saved trajectories to {output_path}")

def main(arg):
    """Parse CLI arguments, load both actor checkpoints and record one rollout.

    Args:
        arg: List of command-line tokens (without the program name) handed to
            the argument parser.
    """
    # Load args and setup
    parser = get_config()
    parser = get_overcooked_args(parser)
    parser.add_argument(
        "--use_phi",
        default=False,
        action="store_true",
        help="While existing other agent like planning or human model, use an index to fix the main RL-policy agent.",
    )

    parser.add_argument("--use_task_v_out", default=False, action="store_true")
    args = parser.parse_args(arg)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    setup_seed(args.seed)

    # "~" is not expanded by open()/torch.load, so resolve it up front.
    run_dir = os.path.expanduser(
        "~/ZSC/results/Overcooked/random1/mappo/hsp-testS1/wandb/run-20250217_162126-ylgvd4eq/files"
    )

    agent0_path = os.path.join(run_dir, "actor_agent0_periodic_10000000.pt")
    agent1_path = os.path.join(run_dir, "actor_agent1_periodic_10000000.pt")
    # Load models
    agent0 = load_agent(agent0_path, device)
    agent1 = load_agent(agent1_path, device)

    # Create environment
    eval_env = make_eval_env(args, run_dir)

    # Run rollout and save trajectories; always release the environment, even
    # when the rollout raises.
    try:
        rollout_and_save(eval_env, [agent0, agent1], device)
    finally:
        eval_env.close()

# Script entry point: forward the command-line tokens (minus the program name).
if __name__ == "__main__":
    cli_tokens = sys.argv[1:]
    main(cli_tokens)