import gymnasium as gym
import gymnasium_robotics
gym.register_envs(gymnasium_robotics)
from gymnasium.wrappers import FlattenObservation

import envs.GridWorld
import envs.GridWorldXOR
from envs.cube2x2 import *

from envs.wrappers import *

import minigrid
from minigrid.wrappers import *

import popgym
from popgym.wrappers import PreviousAction, Antialias, Flatten, DiscreteAction

import bsuite
from bsuite.utils import gym_wrapper


# Environment setup: build the environment selected by args.env and its run name.
def make_env(args, seed):
    """Build the environment selected by ``args.env`` and a descriptive run name.

    The base environment is chosen by substring-matching ``args.env`` in this
    order: "tmaze", "xormaze", "cube", "popgym", "MiniGrid", "bsuite", "Fetch",
    otherwise a plain ``gym.make`` followed by ``FlattenObservation``.
    Frame-stacking wrappers are then applied according to ``args.mask_type``,
    and ``TupleObs`` is added when ``args.algo == "QL"``.

    Args:
        args: Experiment configuration object. Always read: ``algo``, ``arch``,
            ``env``, ``num_stack``, ``mask_type``, ``run``; most branches also
            read ``render_mode``. Branch-specific attributes: ``maze_length``,
            ``random_length``, ``active``, ``continual`` (mazes);
            ``scramble_steps``, ``cube_cam`` (cube); ``max_episode_steps``
            (Fetch/default); ``visible_goal_steps`` (Fetch). The bsuite branch
            optionally reads ``bsuite_id`` and ``bsuite_results_dir``.
        seed: Offset added to ``args.run`` to seed the cube environment. Other
            branches do not use it.

    Returns:
        Tuple ``(env, name)``: the wrapped environment and a string that
        encodes the algorithm, architecture, environment settings, stacking
        configuration and run index.
    """
    prefix = f"{args.algo}-arch_{args.arch}-env_"
    suffix = f"-num_stack_{args.num_stack}-mask_type_{args.mask_type}-run_{args.run}"
    # Default name; the maze and cube branches replace it with a more specific one.
    name = f"{prefix}{args.env}{suffix}"

    if "tmaze" in args.env:
        env = gym.make(args.env, length=args.maze_length, random_length=args.random_length,
                       active=args.active, continual=args.continual, fix_start=True,
                       goal_obs=False, fully_obs=(args.mask_type == "fully_obs"),
                       render_mode=args.render_mode)
        maze_type = "active" if args.active else "passive"
        continual_tag = "-continual" if args.continual else ""
        # NOTE(review): the name reports maze_length + 2 here but plain
        # maze_length for the XOR maze; presumably the T-maze adds two cells - confirm.
        name = (f"{prefix}{maze_type}_tmaze{continual_tag}-v0"
                f"_maze_length_{args.maze_length + 2}"
                f"-random_length_{args.random_length}{suffix}")

    elif "xormaze" in args.env:
        # NOTE(review): continual is hard-coded to False for the XOR maze, yet
        # the run name below still gets a "-continual" tag when args.continual
        # is set - confirm which one is intended.
        env = gym.make(args.env, length=args.maze_length, random_length=args.random_length,
                       active=args.active, continual=False, fix_start=True,
                       goal_obs=False, fully_obs=(args.mask_type == "fully_obs"),
                       render_mode=args.render_mode)
        maze_type = "active" if args.active else "passive"
        continual_tag = "-continual" if args.continual else ""
        name = (f"{prefix}{maze_type}_xormaze{continual_tag}-v0"
                f"_maze_length_{args.maze_length}"
                f"-random_length_{args.random_length}{suffix}")

    elif "cube" in args.env:
        # A fully observed run always uses the "full" camera, whatever args.cube_cam says.
        cube_cam = "full" if args.mask_type == "fully_obs" else args.cube_cam
        env = gym.make(args.env, episode_steps=100, scramble_steps=args.scramble_steps,
                       random_length=args.random_length, cube_cam=cube_cam,
                       seed=args.run + seed, render_mode=args.render_mode)
        name = (f"{prefix}cube-v0_scramble_steps_{args.scramble_steps}"
                f"-random_length_{args.random_length}"
                f"-cube_cam_{cube_cam}{suffix}")

    elif "popgym" in args.env:
        env = _make_popgym_env(args)

    elif "MiniGrid" in args.env:
        # Only the image part of the observation is kept; mask_type is not
        # consulted for MiniGrid.
        env = ImgObsWrapper(gym.make(args.env, render_mode=args.render_mode))

    elif "bsuite" in args.env:
        # bsuite environments are loaded directly (no gym.make) and adapted from
        # the dm_env interface. The id and results directory can be overridden
        # on args; the defaults keep the previous hard-coded values.
        bsuite_id = getattr(args, "bsuite_id", "catch/0")
        # TODO: "/path/to/results" is a placeholder; set args.bsuite_results_dir.
        results_dir = getattr(args, "bsuite_results_dir", "/path/to/results")
        env = gym_wrapper.GymFromDMEnv(
            bsuite.load_and_record_to_csv(bsuite_id, results_dir=results_dir))

    elif "Fetch" in args.env:
        env = gym.make(args.env, render_mode=args.render_mode,
                       max_episode_steps=args.max_episode_steps)
        env = PartialObsGoal(env, visible_goal_steps=args.visible_goal_steps)
        env = FlattenObservation(env)

    else:
        env = gym.make(args.env, render_mode=args.render_mode,
                       max_episode_steps=args.max_episode_steps)
        env = FlattenObservation(env)

    env = _apply_history_wrappers(env, args)

    if args.algo == "QL":
        env = TupleObs(env)

    return env, name


def _make_popgym_env(args):
    """Create the POPGym environment named by ``args.env`` with the standard popgym wrappers.

    Known tasks are matched by substring, in order; the first match wins.
    The directly constructed tasks ignore ``args.render_mode``. Any other id
    goes through ``gym.make`` with ``render_mode``.
    """
    # Lambdas keep attribute lookup lazy, as in a plain if/elif chain.
    factories = (
        ("PositionOnlyCartPoleHard",
         lambda: popgym.envs.position_only_cartpole.PositionOnlyCartPoleHard()),
        ("VelocityOnlyCartPoleHard",
         lambda: popgym.envs.velocity_only_cartpole.VelocityOnlyCartPoleHard()),
        ("NoisyPositionOnlyCartPole",
         lambda: popgym.envs.noisy_position_only_cartpole.NoisyPositionOnlyCartPole()),
        ("concentration", lambda: popgym.envs.concentration.ConcentrationEasy()),
        ("autoencode", lambda: popgym.envs.autoencode.AutoencodeEasy()),
        ("repeat_previous", lambda: popgym.envs.repeat_previous.RepeatPreviousEasy()),
    )
    for key, factory in factories:
        if key in args.env:
            base = factory()
            break
    else:
        base = gym.make(args.env, render_mode=args.render_mode)
    return DiscreteAction(Flatten(PreviousAction(base)))


def _apply_history_wrappers(env, args):
    """Wrap ``env`` with the frame-stacking variant selected by ``args.mask_type``.

    "framestack", "demir" and any value containing "masked" are mutually
    exclusive, so at most one wrapper is applied; other values return ``env``
    unchanged.
    """
    if args.mask_type == "framestack":
        return FrameStack(env, args.num_stack)
    if args.mask_type == "demir":
        return DemirFrameStack(env, args.num_stack)
    if "masked" in args.mask_type:
        return MaskedFrameStack(env, args.num_stack, use_multidiscrete=True)
    return env
