import datetime
from d4rl.locomotion.wrappers import NormalizedBoxEnv
import numpy as np
from rlf.envs.env_interface import get_env_interface
from rlf.envs.widowx_interface import EasyObsWidowxWrapper
import torch
from torch import nn
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision.utils import make_grid, save_image

from acil_envs.half_cheetah_interface import HalfCheetah
from acil_envs.hopper_interface import Hopper
from acil_envs.maze2d_interface import Maze2d
from iq_learn.dataset.memory import Memory
import os
from PIL import Image
from omegaconf import DictConfig, OmegaConf

from demo_collection.utils.constrain_wrapper import ActionSpaceBoxWrapper, ActionDimBlockWrapper, EnvObsWrapper

def set_up_log_dirs(args, run_id):
    """Create the directory tree for one run and return the created paths.

    The run directory is ``<args.log_dir>/<run_id>``; inside it the
    sub-directories ``wandb``, ``agent_model``, ``results_best``,
    ``reward_model`` and ``videos`` are created. Every resulting path is
    reported through the module's ``logging`` helper.

    Args:
        args: Object exposing a ``log_dir`` attribute (root logging directory).
        run_id: Name of the run; used as the sub-directory under ``args.log_dir``.

    Returns:
        Tuple ``(log_dir, wandb_dir, agent_save_dir, agent_best_dir,
        reward_save_dir, video_save_dir)``.

    Raises:
        FileExistsError: If one of the target paths already exists but is not
            a directory.
    """
    log_dir = os.path.join(args.log_dir, run_id)

    wandb_dir = os.path.join(log_dir, 'wandb')
    agent_save_dir = os.path.join(log_dir, 'agent_model')
    agent_best_dir = os.path.join(log_dir, 'results_best')
    reward_save_dir = os.path.join(log_dir, 'reward_model')
    video_save_dir = os.path.join(log_dir, 'videos')

    # (label, path) pairs, in the order they are created and reported.
    labelled_dirs = (
        ('Log dir', log_dir),
        ('Wandb dir', wandb_dir),
        ('Agent save dir', agent_save_dir),
        ('Agent best dir', agent_best_dir),
        ('Reward save dir', reward_save_dir),
        ('Video save dir', video_save_dir),
    )

    # exist_ok=True avoids the check-then-create race of
    # "if not exists: makedirs", which can raise when two runs start at once.
    os.makedirs(args.log_dir, exist_ok=True)
    for _, path in labelled_dirs:
        os.makedirs(path, exist_ok=True)

    # Report only after every directory has been created.
    for label, path in labelled_dirs:
        logging(f'{label}: {path}')

    return log_dir, wandb_dir, agent_save_dir, agent_best_dir, reward_save_dir, video_save_dir

def logging(*msg):
    """Print ``msg`` to stdout, prefixed with the current local timestamp.

    The prefix has the form ``YYYY-MM-DD HH:MM:SS>`` and is followed by the
    given values, separated by single spaces as ``print`` does by default.

    Args:
        *msg: Values to print after the timestamp prefix.
    """
    stamp = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    print(f"{stamp}>", *msg)




def make_envs(args):
    """Build a single environment for ``args.env_name`` and apply its wrappers.

    Construction and wrapping depend on the environment name:

    * name contains ``'MBRL'``: created through the environment interface's
      ``create_from_id``; no wrapper is added.
    * name contains ``'WidowX'``: created with ``make_env`` and wrapped in
      ``NormalizedBoxEnv``.
    * name starts with ``'MiniGrid'``: created with ``make_env``; no wrapper
      is added.
    * anything else: created with ``make_env``, wrapped in ``EnvObsWrapper``
      (skipped for ``'AntGoal-v0'``), then ``NormalizedBoxEnv``, then
      ``ActionSpaceBoxWrapper`` with upper bound ``args.box_ub``.

    Afterwards, if ``args.dim_filter`` is below 1.0, the environment is
    wrapped in ``ActionDimBlockWrapper``.

    Args:
        args: Configuration object; this function reads ``env_name``,
            ``seed``, ``box_ub`` and ``dim_filter`` from it and also passes
            it on to the environment interface and to ``make_env``.

    Returns:
        The constructed (and possibly wrapped) environment.
    """
    name = args.env_name

    if 'MBRL' in name:
        interface = get_env_interface(name)(args)
        env = interface.create_from_id(name)
    else:
        # Imported lazily so the MBRL path does not need rlf.rl.envs.
        from rlf.rl.envs import make_env

        interface = get_env_interface(name)(args)
        env = make_env(
            0,
            name,
            args.seed,
            allow_early_resets=True,
            env_interface=interface,
            set_eval=False,
            alg_env_settings=None,
            args=args,
            immediate_call=True,
        )

        if 'WidowX' in name:
            env = NormalizedBoxEnv(env)
        elif not name.startswith('MiniGrid'):
            # Box constraint for the generic case.
            if name != 'AntGoal-v0':
                env = EnvObsWrapper(env)
            env = NormalizedBoxEnv(env)
            env = ActionSpaceBoxWrapper(env, ub=args.box_ub)

    # Dimension-filter constraint: the wrapper receives
    # int(action_dim * dim_filter), but never less than 1.
    if args.dim_filter < 1.0:
        n_action_dims = env.action_space.shape[0]
        kept_dims = max(1, int(n_action_dims * args.dim_filter))
        env = ActionDimBlockWrapper(env, kept_dims)

    return env



def make_envs_widowx(args):
    """Build a WidowX environment via ``gym.make`` and normalise it.

    The environment is created with ``constrained_action_space=False``. It is
    wrapped in ``EasyObsWidowxWrapper`` when ``args.widowx_easy_obs`` is
    truthy, and always wrapped in ``NormalizedBoxEnv`` last.

    Args:
        args: Configuration object; ``env_name`` and ``widowx_easy_obs`` are
            read here, and the object is passed to the environment interface.

    Returns:
        The ``NormalizedBoxEnv``-wrapped environment.
    """
    import gym

    # The interface instance is not used below. NOTE(review): presumably it is
    # instantiated for its side effects (e.g. env registration) - confirm
    # before removing.
    env_interface = get_env_interface(args.env_name)(args)

    base_env = gym.make(args.env_name, constrained_action_space=False)
    wrapped = EasyObsWidowxWrapper(base_env) if args.widowx_easy_obs else base_env
    return NormalizedBoxEnv(wrapped)