import pickle
import time
from datetime import timedelta
import random
import os

import torch
import numpy as np
from scipy.special import expit
from scipy.stats import dirichlet
from utils import seed_everything
from tqdm import tqdm

# Hard cap on the number of environment steps per episode; sim_env forces
# `done` once an episode reaches this many steps.
MAX_STEPS = 1000

def sim_env(env, episodes, noise_params, rand_init=True, gym_state=None,
            actions=None, replay_buffer=None, sample=False,
            dyna_model=None, dyna_horizon=10, policy=None, store=True,
            deterministic_policy=True, open_loop=False):
    """Roll out ``episodes`` episodes in ``env`` and optionally record them.

    At every step the action is chosen, in order of priority, from:
    the scripted ``actions`` sequence (while the step index is within it),
    ``policy[0].select_action(...)`` if ``policy`` is truthy, or a random
    sample from ``env.action_space`` (divided by 20 for ``Hopper-v2``).
    The action is clipped to ``[-action_space.high, action_space.high]``
    before being passed to ``env.step``.

    Args:
        env: Gym-style environment exposing ``observation_space``,
            ``action_space``, ``reset``, ``step`` and ``env.spec.id``.
        episodes: Number of episodes to simulate.
        noise_params: Unused by this function.
        rand_init: If true, call ``env.reset()``; otherwise call
            ``env.reset(state=gym_state)``.
        gym_state: State forwarded to ``env.reset`` when ``rand_init`` is
            false.
        actions: Optional sequence of scripted actions, consumed by step
            index. ``None`` or an empty sequence means no scripted actions.
        replay_buffer: Object with a ``push`` method; must be provided when
            ``store`` is true. Also forwarded to ``select_action``.
        sample: If true, write each ``next_state`` into the returned
            ``ground_truth`` array and end the episode after
            ``dyna_horizon`` steps.
        dyna_model: Unused by this function.
        dyna_horizon: Length of the last axis of ``ground_truth`` and the
            episode length limit when ``sample`` is true.
        policy: Optional sequence whose first element provides
            ``select_action(state, replay_buffer, evaluate=...)``.
        store: If true, push every transition to ``replay_buffer``. The
            stored states are truncated for ``Ant-v2``, ``Hopper-v2`` and
            ``Humanoid-v2``; the ``done`` flag stored is the one returned by
            ``env.step``.
        deterministic_policy: Forwarded as ``evaluate`` to
            ``select_action``.
        open_loop: Unused by this function.

    Returns:
        A tuple ``(ground_truth, tot_reward, steps)`` where ``ground_truth``
        is an array of shape ``(episodes, observe_dim, dyna_horizon)``
        (all zeros unless ``sample`` is true), ``tot_reward`` is the summed
        reward of the *last* episode only (``0`` when ``episodes`` is 0),
        and ``steps`` is the total number of environment steps taken.
    """
    observe_dim = env.observation_space.shape[0]
    ground_truth = np.zeros([episodes, observe_dim, dyna_horizon])
    env_id = env.env.spec.id

    # Observation indices dropped from Humanoid-v2 states before storing.
    rmv_idxs = None
    if env_id == 'Humanoid-v2':
        rmv_idxs = np.array([ 45,  46,  47,  48,  49,  50,  51,  52,  53,
            54,  64,  74,  84, 94, 104, 114, 124, 134, 144, 154, 164, 174,
            184, 185, 186, 187, 188, 189, 190, 269, 270, 271, 272, 273, 274,
            292, 293, 294, 295, 296, 297, 298, 299, 300, 301, 302, 303, 304,
            305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, 317,
            318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330,
            331, 332, 333, 334, 335, 336, 337, 338, 339, 340, 341, 342, 343,
            344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356,
            357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369,
            370, 371, 372, 373, 374, 375])
    humanoid_mask = None

    steps = 0
    # Defined up front so that `episodes == 0` returns 0 instead of raising.
    tot_reward = 0
    for j in tqdm(range(episodes)):
        if rand_init:
            state = env.reset()
        else:
            state = env.reset(state=gym_state)
        done = False
        i = 0
        tot_reward = 0
        while not done:
            # `and` (not `&`) so that len() is never evaluated on None.
            if actions is not None and i < len(actions):
                action = actions[i]
            elif policy:
                action = policy[0].select_action(
                    state, replay_buffer, evaluate=deterministic_policy)
            else:
                action = env.action_space.sample()
                if env_id == 'Hopper-v2':
                    action = action / 20
            unseen = np.clip(action, -env.action_space.high,
                             env.action_space.high)
            next_state, reward, done, _ = env.step(unseen)
            tot_reward += reward
            if store:
                # The truncated views are only needed for the replay buffer.
                if env_id == 'Ant-v2':
                    trunc_state = state[:27]
                    trunc_next_state = next_state[:27]
                elif env_id == 'Hopper-v2':
                    trunc_state = state[1:]
                    trunc_next_state = next_state[1:]
                elif env_id == 'Humanoid-v2':
                    # Build the keep-mask once and reuse it; rebuild only if
                    # the observation length ever changes.
                    if (humanoid_mask is None
                            or len(humanoid_mask) != len(state)):
                        humanoid_mask = np.ones(len(state), dtype=bool)
                        humanoid_mask[rmv_idxs] = False
                    trunc_state = state[humanoid_mask]
                    trunc_next_state = next_state[humanoid_mask]
                else:
                    trunc_state = state
                    trunc_next_state = next_state
                replay_buffer.push(trunc_state, action, reward,
                                   trunc_next_state, done, unseen)
            if sample:
                ground_truth[j, :, i] = next_state
                if (i + 1) >= dyna_horizon:
                    done = True
            state = next_state
            i += 1
            if i >= MAX_STEPS:
                done = True
        steps += i
    return ground_truth, tot_reward, steps

def sample_actions_seqs(env, policy):
    """Run one episode with ``policy[0]`` and record the actions taken.

    The environment is reset once, then stepped until it reports ``done``.
    Each action comes from ``policy[0].select_action(state, '',
    evaluate=False)`` and is passed to ``env.step`` unchanged.

    Returns:
        A tuple ``(state_0, actions)``: the state returned by ``env.reset()``
        and the list of actions in the order they were applied.
    """
    initial_state = env.reset()
    current = initial_state
    taken = []
    finished = False
    while not finished:
        chosen = policy[0].select_action(current, '', evaluate=False)
        current, _reward, finished, _info = env.step(chosen)
        taken.append(chosen)
    return initial_state, taken


def get_noise(noise_params, env):
    """Return an all-zero noise vector with one entry per action dimension.

    ``noise_params`` is accepted but not used; the length is taken from
    ``env.action_space.shape[0]``.
    """
    n_action_dims = env.action_space.shape[0]
    return np.zeros(n_action_dims)

def create_noise(args, store_dir, suffix, seed):
    """Return a new, empty list of noise parameters.

    None of the arguments are used.
    """
    return []

