# train_utils.py
import os
from typing import Optional, Dict, Any, Tuple, List

import gymnasium as gym
import torch
import torch.nn as nn
import numpy as np
import json
import minigrid
from minigrid.wrappers import FullyObsWrapper
from minigrid.core.world_object import Door as MGDoor
from minigrid.core.world_object import Wall as MGWall
from gymnasium.wrappers import TimeLimit
from omegaconf import OmegaConf

# Not referenced in the visible part of this file; presumably the hidden-layer
# width of a network defined elsewhere -- TODO confirm.
HIDDEN_DIM = 256
# Default `max_steps` of setup_environment(), forwarded to gym.make as
# `max_episode_steps`.
MAX_STEPS_PER_EPISODE = 200
# Default `meta_fixed_seed` of setup_environment(), used as the fixed seed of
# the FixedSeedWrapper it builds.
META_FIXED_SEED = 42
# Default directory in which _load_cfg() looks for base.yaml and
# <game_code>.yaml.
DEFAULT_CONFIG_DIR = "configs"

class FixedSeedWrapper(gym.Wrapper):
    """Wrapper that controls which seed is forwarded to the wrapped ``reset``.

    Args:
        env: Environment to wrap.
        fixed_seed: Seed used whenever the caller's seed is absent or ignored.
            May be ``None`` only if every ``reset`` call supplies a seed and
            ``allow_override`` is true.
        allow_override: When true, a seed passed to ``reset`` takes precedence
            over ``fixed_seed``; when false, the caller's seed is ignored.
    """

    def __init__(self, env, fixed_seed: Optional[int] = None, allow_override: bool = True):
        super().__init__(env)
        self.fixed_seed = fixed_seed
        self.allow_override = allow_override

    def reset(self, *, seed=None, **kwargs):
        """Reset the wrapped environment with the selected seed.

        Args:
            seed: Optional per-call seed. Used only when it is not ``None``
                and ``allow_override`` is true.
            **kwargs: Forwarded unchanged to the wrapped ``reset``.

        Returns:
            Whatever the wrapped environment's ``reset`` returns.

        Raises:
            AssertionError: If the fixed seed is needed but was never set.
        """
        if seed is not None and self.allow_override:
            return self.env.reset(seed=int(seed), **kwargs)

        # Explicit raise instead of `assert`: asserts are stripped under
        # `python -O`, which would silently forward seed=None here. The
        # exception type and message are kept for existing callers.
        if self.fixed_seed is None:
            raise AssertionError("Fixed seed is not set!!!")
        return self.env.reset(seed=int(self.fixed_seed), **kwargs)

def _load_cfg(game_code: str, config_dir: str = DEFAULT_CONFIG_DIR):
    """Load and merge the base config with the per-game config.

    Args:
        game_code: Stem of the per-game YAML file (``<game_code>.yaml``).
        config_dir: Directory containing ``base.yaml`` and the per-game file.

    Returns:
        The merged OmegaConf config, with the per-game file merged on top of
        ``base.yaml`` and ``OmegaConf.resolve`` applied to it.
    """
    # Order matters: base first, then the per-game file merged over it.
    layers = [
        OmegaConf.load(os.path.join(config_dir, filename))
        for filename in ("base.yaml", f"{game_code}.yaml")
    ]
    merged = OmegaConf.merge(*layers)
    OmegaConf.resolve(merged)
    return merged



def setup_environment(game_code, max_steps=MAX_STEPS_PER_EPISODE, meta_fixed_seed=META_FIXED_SEED, render_mode=None):
    """Build the wrapped environment for a game code.

    The environment id is read from ``cfg.env.env_id`` of the config loaded by
    ``_load_cfg(game_code)``. The created environment is wrapped in
    ``FullyObsWrapper`` and then in ``FixedSeedWrapper``.

    Args:
        game_code: Config name passed to ``_load_cfg``.
        max_steps: Passed to ``gym.make`` as ``max_episode_steps``.
        meta_fixed_seed: Fixed seed given to ``FixedSeedWrapper``.
        render_mode: Render mode passed to ``gym.make``. ``None`` falls back
            to ``'rgb_array'``, the value that was previously always used.

    Returns:
        The ``FixedSeedWrapper``-wrapped environment (per-call seed overrides
        allowed).
    """
    cfg = _load_cfg(game_code)
    env_id = str(cfg.env.env_id)
    # Honour an explicit render_mode; the parameter used to be ignored.
    effective_render_mode = 'rgb_array' if render_mode is None else render_mode
    base_env = gym.make(env_id, render_mode=effective_render_mode, max_episode_steps=max_steps)
    base_env = FullyObsWrapper(base_env)
    env = FixedSeedWrapper(base_env, fixed_seed=meta_fixed_seed, allow_override=True)
    return env