import os
import argparse
import numpy as np
import torch
from networkx.algorithms.core import onion_layers

from config.config_loader import ConfigLoader
from environments.env_loader import parallel_env_maker
from torchrl.envs.utils import step_mdp
from utils import create_actor_critic, group_episodes_by_outcome, get_device
from minigrid.core.constants import OBJECT_TO_IDX

# Index that minigrid's OBJECT_TO_IDX table assigns to the "agent" object type.
# NOTE(review): not referenced anywhere in the visible part of this file —
# confirm nothing imports it before removing.
AGENT_IDX = OBJECT_TO_IDX["agent"]

def _attach_frame(td, env, config):
    """Render the current environment state into ``td['frame']`` in place.

    The render call depends on ``config['env']``:
      * ``'metaworld'``                    -> ``env.render()``
      * ``'single_cook'``                  -> no frame; a notice is printed and
                                              ``td`` is left unchanged
      * ``'multi_grid'``/``'two_goal_grid'`` -> ``env.get_frame(highlight=not
                                              env_config['see_through_walls'])``
      * anything else                      -> ``env.get_frame(highlight=not
                                              env_config['fully_observable'])``

    Used both before a regular step and for the terminal state, so both
    paths render every environment the same way.
    """
    env_name = config['env']
    if env_name == 'single_cook':
        # Frame capture is disabled for Overcooked.
        print("skip frame in overcooked")
        return
    if env_name == 'metaworld':
        frame = np.stack([f for f in env.render()])
    elif env_name in ['multi_grid', 'two_goal_grid']:
        frame = np.stack(env.get_frame(highlight=not config['env_config']['see_through_walls']))
    else:
        frame = np.stack(env.get_frame(highlight=not config['env_config']['fully_observable']))
    td['frame'] = torch.as_tensor(frame)


def _sample_random_action(td):
    """Overwrite ``td['action']`` with a uniformly random one-hot action.

    If ``td`` has an ``'action_mask'`` entry, the index is drawn only from
    positions where the mask (after ``squeeze(0)``) is non-zero. Otherwise any
    index below ``td['logits'].shape[-1]`` can be drawn. The result has shape
    ``(1, num_actions)`` and dtype float.
    """
    num_actions = td["logits"].shape[-1]
    if "action_mask" in td.keys():
        action_mask = td["action_mask"].squeeze(0)
        valid_actions = torch.nonzero(action_mask, as_tuple=True)[0]
        index = valid_actions[torch.randint(0, len(valid_actions), (1,))]
    else:
        index = torch.randint(0, num_actions, (1,))
    td["action"] = torch.nn.functional.one_hot(index, num_classes=num_actions).float()


def _checkpoint_tag(ckpt_path):
    """Return the short tag used to name outputs for a ``.pt`` checkpoint path.

    For ``.../model-100.pt`` this is ``'100'``: the text between the first and
    second ``-`` of the file stem, cut at the first ``.``. A stem without any
    ``-`` (e.g. ``model.pt``) is returned whole instead of raising IndexError.
    """
    stem = os.path.splitext(os.path.basename(ckpt_path))[0]
    parts = stem.split("-")
    if len(parts) < 2:
        return stem
    return parts[1].split(".")[0]


def _save_encodings(all_images, encoding_path):
    """Concatenate collected image tensors along dim 0, save, and verify a reload.

    ``all_images`` is either a list of tensors or a list of dicts with the same
    keys (``'image'`` and ``'recipe'`` in ``main``); dicts are concatenated per
    key. The file is fully written before it is reloaded for verification.

    Raises:
        RuntimeError: if the reloaded data does not match what was saved.
    """
    if isinstance(all_images[0], dict):
        keys = all_images[0].keys()
        stacked = {k: torch.cat([d[k] for d in all_images], dim=0) for k in keys}
        torch.save(stacked, encoding_path)
        reloaded = torch.load(encoding_path)
        for k in keys:
            if not torch.allclose(stacked[k], reloaded[k]):
                raise RuntimeError(f"Mismatch in key '{k}' after saving and loading!")
            print(f"Saved {k} tensor to {encoding_path} with shape: {tuple(stacked[k].shape)}")
    else:
        image_tensor = torch.cat(all_images, dim=0)
        torch.save(image_tensor, encoding_path)
        reloaded = torch.load(encoding_path)
        if not torch.allclose(image_tensor, reloaded):
            raise RuntimeError("Mismatch after saving and loading!")
        print(f"Saved image tensor to {encoding_path} with shape: {tuple(image_tensor.shape)}")


def main(args, config):
    """Evaluate a trained actor (or a random policy) for a number of episodes.

    Args:
        args: parsed CLI namespace. Reads ``random``, ``base_seed``,
            ``ckpt_dir``, ``n_episodes``, ``max_step``, ``continue_episode``,
            ``Encoding``, ``save`` and ``config``.
        config: loaded config dict. Reads ``config['env']`` and
            ``config['env_config']`` (``max_steps`` plus the render flags used
            by ``_attach_frame``).

    ``ckpt_dir`` is either a ``.pt`` file or a directory containing
    ``model-best.pt``. Per-episode reward/length and their means are printed.
    With ``--save`` the trajectories are written next to the checkpoint. With
    ``--Encoding`` the collected observation images are written under an
    ``Encodings`` sub-directory.
    """
    device = get_device(config)

    if args.random:
        torch.manual_seed(args.base_seed)
        np.random.seed(args.base_seed)

    actor, _ = create_actor_critic(config, device)

    is_ckpt_file = args.ckpt_dir.endswith('.pt')
    ckpt_path = args.ckpt_dir if is_ckpt_file else os.path.join(args.ckpt_dir, 'model-best.pt')
    checkpoint = torch.load(ckpt_path, map_location=device)

    # Accept either a training checkpoint dict with an 'actor' entry or a bare actor state_dict.
    if isinstance(checkpoint, dict) and 'actor' in checkpoint:
        actor.load_state_dict(checkpoint['actor'])
    else:
        actor.load_state_dict(checkpoint)

    env = parallel_env_maker(config, 1, device=device, base_seed=args.base_seed)
    env.eval()

    # Invariant across episodes, so compute it once.
    max_steps = args.max_step if args.max_step is not None else config['env_config']['max_steps']

    rewards = []
    lengths = []
    trajs = [] if args.save else None
    all_images = [] if args.Encoding else None

    for i in range(args.n_episodes):
        td = env.reset()
        traj = []
        print(f'Episode {i}')

        for _ in range(max_steps):
            _attach_frame(td, env, config)
            actor(td)
            if args.random:
                _sample_random_action(td)

            if args.Encoding and "image" in td.keys() and not td["done"].any():
                if "recipe" in td.keys():
                    all_images.append({"image": td["image"], "recipe": td["recipe"]})
                else:
                    all_images.append(td["image"])

            traj.append(td)
            env.step(td)
            td = step_mdp(td)

            if td["done"].any() and config['env'] == "single_cook":
                print(env.get_sparse_reward())
                break

            if td["done"].any() and not args.continue_episode:
                # Record the terminal state (with its frame) so the trajectory
                # ends on the final observation. The extra env.step fills in
                # td's "next" entries (presumably so every element of `traj`
                # has the same keys for torch.stack below). Then stop: without
                # --cont the episode must not keep running past done.
                _attach_frame(td, env, config)
                actor(td)
                if args.random:
                    _sample_random_action(td)
                traj.append(td)
                env.step(td)
                break

        data = torch.stack(traj, 0)

        # Score the episode from the transitions flagged as done.
        done_mask = data["next", "done"]
        reward_done = data["next", "reward"][done_mask]
        step_done = data["next", "step_count"][done_mask]

        if reward_done.numel() > 0:
            reward_value = reward_done.max().item()
            step_value = step_done.max().item()
        else:
            # Never reached done within max_steps.
            reward_value = 0.0
            step_value = max_steps

        print(f"E{i} total rewards:", round(reward_value, 3))

        rewards.append(reward_value)
        lengths.append(step_value)

        if args.save:
            trajs.append(traj)

    print(f'\nMean reward over {args.n_episodes} episodes: {np.array(rewards).mean()}')
    print(f'Mean length over {args.n_episodes} episodes: {np.array(lengths).mean()}')

    save_dir = os.path.dirname(args.ckpt_dir) if is_ckpt_file else args.ckpt_dir
    check_num = _checkpoint_tag(args.ckpt_dir) if is_ckpt_file else "model-best"

    name = f"{check_num}_{args.n_episodes}ep"
    if args.random:
        name = f"random_{name}"
    if args.continue_episode:
        name = f"cont_{name}"
    if args.max_step is not None:
        name = f"{name}_{args.max_step}step"

    if args.save:
        # os.path.join keeps the file relative when save_dir is "" (bare
        # "model-N.pt" argument) instead of writing to the filesystem root.
        traj_path = os.path.join(save_dir, f"{name}_{args.config}.pt")
        torch.save(trajs, traj_path)
        print(f"Saved trajectory to {traj_path}")

    if args.Encoding and all_images:
        encoding_dir = os.path.join(save_dir, "Encodings")
        os.makedirs(encoding_dir, exist_ok=True)
        encoding_path = os.path.join(encoding_dir, f"evaluate_{name}_{args.config}.pt")
        _save_encodings(all_images, encoding_path)



if __name__ == '__main__':
    # CLI entry point: parse evaluation options, load the named config, run main().
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default="multi_grid", help="Configuration file to use.")
    parser.add_argument("--n_episodes", type=int, default=100, help="Number of evaluation episodes")
    # Required: main() calls args.ckpt_dir.endswith(...), which would fail on None.
    parser.add_argument("--ckpt_dir", type=str, required=True,
                        help="Checkpoint directory (loads model-best.pt) or path to a .pt checkpoint file")
    parser.add_argument("--save", action='store_true')
    parser.add_argument("--random", action='store_true', help="If set, agent will take random actions")
    parser.add_argument("--base_seed", type=int, default=1000, help="Base seed to ensure reproducible env reset")
    parser.add_argument("--cont", dest="continue_episode", action='store_true', help="If set, continue episodes even after done is True")
    parser.add_argument("--max_step", type=int, default=None, help="Override max steps per episode")
    parser.add_argument("--Encoding", action='store_true', help="If set, save all valid image tensors and print actual shape")

    args = parser.parse_args()
    config = ConfigLoader.load_config(args.config, None)
    main(args, config)
