import os
import ast
import random

import torch
import numpy as np
import importlib


# Torch device used by the project; stays None until set_device_and_logger() assigns it.
device = None
# Logger object shared by the project; stays None until set_device_and_logger() assigns it.
logger = None


def set_global_seed(seed):
    """Seed every random number generator this module knows about.

    The same `seed` is handed, in order, to NumPy, Python's `random`,
    PyTorch (CPU) and PyTorch (all CUDA devices).
    """
    seeders = (
        np.random.seed,
        random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)


def set_device_and_logger(gpu_id, logger_ent):
    """Choose the module-wide torch device and register the module-wide logger.

    Args:
        gpu_id: Index of the GPU to use. A negative value selects the CPU;
            the CPU is also used when CUDA is not available.
        logger_ent: Object stored in the module-level `logger` variable.
    """
    global device, logger
    use_cpu = gpu_id < 0 or not torch.cuda.is_available()
    if use_cpu:
        device = torch.device("cpu")
    else:
        device = torch.device("cuda:{}".format(gpu_id))
        # NOTE(review): the visible-device mask is set *after* the device index
        # has been chosen. If CUDA is not yet initialised in this process the
        # mask would renumber the selected GPU to index 0 -- confirm that
        # "cuda:<gpu_id>" is still the intended device for gpu_id > 0.
        os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_id)
    print("setting device:", device)
    logger = logger_ent


def relative_path_to_module_path(relative_path):
    """Convert a relative file path into a dotted module path.

    For example ``configs/hopper.py`` becomes ``configs.hopper``.

    Only a trailing ``.py`` extension is removed; a ``.py`` substring that
    appears elsewhere in the path (e.g. a directory called ``my.pyconfigs``)
    is left untouched so the result never silently points at another module.

    Args:
        relative_path: Path to a Python file, relative to an importable root.

    Returns:
        The path with the ``.py`` suffix stripped and path separators
        replaced by dots.
    """
    extension = ".py"
    path = relative_path
    if path.endswith(extension):
        path = path[:-len(extension)]
    path = path.replace(os.path.sep, '.')
    # On platforms with an alternative separator (e.g. "/" on Windows),
    # accept it as well; os.path.altsep is None elsewhere.
    if os.path.altsep:
        path = path.replace(os.path.altsep, '.')
    return path


def load_config(config_path, update_args):
    """Build the final argument dictionary for a run.

    Three layers are combined, later layers taking precedence:

    1. ``default_args`` from the ``default.py`` that sits next to `config_path`;
    2. ``overwrite_args`` from the module at `config_path`;
    3. command-line style overrides given in `update_args`.

    Afterwards, if the result has a ``'common'`` section, that section is
    merged into every other dict-valued top-level section.

    Args:
        config_path: Relative path of the config file. It is converted to a
            dotted module path and imported, so it must be importable from
            the current import root.
        update_args: Iterable of ``"key/path=value"`` strings. The key is a
            ``/``-separated path into the argument dictionary and the value
            is parsed with `ast.literal_eval`.

    Returns:
        The merged argument dictionary. It is a fresh copy; the dictionaries
        stored in the imported config modules are not modified.

    Raises:
        AssertionError: If ``default_args`` or ``overwrite_args`` is not a dict.
        ValueError: If an entry of `update_args` has no ``=`` or its value is
            not a valid Python literal.
        SyntaxError: If a value cannot be parsed by `ast.literal_eval`.
    """
    # Local import so this function does not depend on the file's import block.
    import copy

    default_config_path = os.path.join(os.path.dirname(config_path), "default.py")
    default_args_module = importlib.import_module(relative_path_to_module_path(default_config_path))
    overwrite_args_module = importlib.import_module(relative_path_to_module_path(config_path))
    default_args_dict = getattr(default_args_module, 'default_args')
    overwrite_args_dict = getattr(overwrite_args_module, 'overwrite_args')
    assert isinstance(default_args_dict, dict), "default args file should be default_args={...}"
    assert isinstance(overwrite_args_dict, dict), "args file should be overwrite_args={...}"

    # update_args is a sequence of "key=value" strings; convert it to a dict.
    update_args_dict = {}
    for update_arg in update_args:
        # Split on the first "=" only so that values may themselves contain "=".
        key, separator, val = update_arg.partition("=")
        if not separator:
            raise ValueError("update arg should look like key=value, got {!r}".format(update_arg))
        update_args_dict[key] = ast.literal_eval(val)

    # Work on deep copies: merge_dict mutates its first argument and stores
    # references to values of its second, and the config modules are cached
    # by the import system, so merging in place would leak into later calls.
    args_dict = merge_dict(copy.deepcopy(default_args_dict), copy.deepcopy(overwrite_args_dict))
    args_dict = update_parameters(args_dict, update_args_dict)

    # Propagate the shared 'common' section into every other dict section.
    if 'common' in args_dict:
        common_args = args_dict['common']
        for sub_key in args_dict:
            # Merging 'common' into itself changes nothing, so skip it.
            if sub_key != 'common' and isinstance(args_dict[sub_key], dict):
                args_dict[sub_key] = merge_dict(args_dict[sub_key], common_args, "common")
    return args_dict


def merge_dict(source_dict, update_dict, ignored_dict_name=""):
    """Recursively merge `update_dict` into `source_dict` in place.

    Keys missing from `source_dict` are added. When both sides hold a dict
    for the same key the two are merged recursively. In every other case the
    value from `update_dict` replaces the existing one and the change is
    printed; this includes the case where the two sides disagree on whether
    the value is a dict.

    Args:
        source_dict: Dictionary that receives the updates (mutated in place).
        update_dict: Dictionary whose entries take precedence.
        ignored_dict_name: Key that is skipped at every nesting level of
            `update_dict`.

    Returns:
        `source_dict`, the same object that was passed in.
    """
    for key, new_value in update_dict.items():
        if key == ignored_dict_name:
            continue
        if key not in source_dict:
            source_dict[key] = new_value
        elif isinstance(new_value, dict) and isinstance(source_dict[key], dict):
            source_dict[key] = merge_dict(source_dict[key], new_value, ignored_dict_name)
        else:
            # Either a plain value, or a dict/non-dict mismatch that cannot be
            # merged key by key: the update side wins.
            print("updated {} from {} to {}".format(key, source_dict[key], new_value))
            source_dict[key] = new_value
    return source_dict


def update_parameters(source_args, update_args):
    """Apply path-addressed overrides to an argument dictionary.

    Args:
        source_args: Argument dictionary to be updated.
        update_args: Mapping from a key path to the value that should be
            written at that path (the command-line overwriting case).

    Returns:
        The argument dictionary after every override has been passed through
        `overwrite_argument_from_path`.
    """
    print("updating args", update_args)
    # Each key path is decomposed and written by overwrite_argument_from_path.
    for key_path, target_value in update_args.items():
        print("key:{}\tvalue:{}".format(key_path, target_value))
        source_args = overwrite_argument_from_path(source_args, key_path, target_value)
    return source_args


def overwrite_argument_from_path(source_dict, key_path, target_value):
    """Write `target_value` at a ``/``-separated path inside a nested dict.

    Every component of `key_path` except the last must name an existing
    nested dict. If one is missing, or a value along the way is not a dict,
    the path is considered illegal: a message is printed and `source_dict`
    is returned unchanged. The last component is created if it does not
    exist yet.

    Args:
        source_dict: Nested dictionary to modify in place.
        key_path: Path such as ``"trainer/lr"``; a path without ``/`` refers
            to a top-level key.
        target_value: Value to store at the path.

    Returns:
        `source_dict`, the same object that was passed in.
    """
    *parent_keys, final_key = key_path.split("/")
    curr_dict = source_dict
    for key in parent_keys:
        if not isinstance(curr_dict, dict) or key not in curr_dict:
            # Illegal path: keep the arguments untouched.
            print("ignored illegal argument path: {}".format(key_path))
            return source_dict
        curr_dict = curr_dict[key]
    if not isinstance(curr_dict, dict):
        # The parent of the final key is a plain value, so nothing can be
        # stored under it.
        print("ignored illegal argument path: {}".format(key_path))
        return source_dict
    curr_dict[final_key] = target_value
    return source_dict


def second_to_time_str(remaining: int) -> str:
    """Format a duration in seconds as days, hours, minutes and seconds.

    Units whose count is zero are omitted, and every emitted unit is followed
    by two spaces, e.g. ``3661`` -> ``"1 hour  1 minute  1 second  "``.
    A duration shorter than one second yields an empty string.

    Args:
        remaining: Duration in seconds. A fractional part is dropped.

    Returns:
        The formatted duration string.
    """
    units = ((86400, 'day'), (3600, 'hour'), (60, 'minute'), (1, 'second'))
    parts = []
    for unit_seconds, unit_name in units:
        # Floor division keeps integer inputs exact; dividing as floats first
        # would lose precision for very large values.
        count = int(remaining // unit_seconds)
        remaining -= count * unit_seconds
        if count > 0:
            parts.append("{} {}  ".format(count, unit_name))
    return "".join(parts)

def calc_sim_next_obs(simulator, state, action, parallel=True, config=None):
    """Step `simulator` once from the given state(s) and return the observation(s).

    For each state the simulator is reset, put into that state and stepped
    with the matching action. The state vector is split using the size of
    ``simulator.sim.data.qpos``: its first ``qpos_dim - 1`` entries are used
    as qpos with a ``0`` prepended, and the remaining entries as qvel.

    Args:
        simulator: Environment exposing ``sim.data.qpos``, ``reset()``,
            ``set_state(qpos, qvel)`` and ``step(action)``; ``step`` must
            return a 4-tuple whose first element is the next observation.
        state: A single state vector, or an array of states indexed along
            its first axis.
        action: Action for the single state, or one action per state.
        parallel: Currently unused.
        config: Currently unused.

    Returns:
        The next observation for a single state; for several states, an
        array created with ``np.zeros_like(state)`` and filled row by row.

    NOTE: an earlier attempt to evaluate the states in parallel (one
    simulator copy per worker) did not work and has been removed; states are
    processed sequentially.
    """
    qpos_dim = simulator.sim.data.qpos.size

    def _step_from(single_state, single_action):
        # Restore the simulator to `single_state`, then advance it one step.
        simulator.reset()
        # presumably the prepended 0 stands for a qpos coordinate that is not
        # part of the state vector -- verify against the environment.
        cur_qpos = np.concatenate(([0], single_state[0:qpos_dim - 1]), axis=0)
        cur_qvel = single_state[qpos_dim - 1:]
        simulator.set_state(cur_qpos, cur_qvel)
        next_obs, _, _, _ = simulator.step(single_action)
        return next_obs

    # A 1-D state is a single sample.
    if len(state.shape) <= 1:
        return _step_from(state, action)

    # Several states: evaluate them one after another.
    next_sim_obs = np.zeros_like(state)
    for i in range(state.shape[0]):
        next_sim_obs[i] = _step_from(state[i], action[i])
    return next_sim_obs