import numpy as np
import torch
import os
import math
import gym
import sys
import random
import time
import json
import copy
from tqdm import tqdm
import pickle

from collections import deque

import multiprocessing

from utils import parse_args, set_seed_everywhere, ReplayBuffer,  make_batch_dlock
# NOTE(review): the bare expression below is a no-op statement (it evaluates
# the literal 3 and discards it); it looks like a stray leftover -- confirm
# and remove.
(3)


# Ask OpenMP-backed libraries to use a single thread in this process.
# NOTE(review): this is assigned after numpy and torch have already been
# imported at the top of the file, so it may be too late to affect them --
# confirm, and move it above those imports if it is meant to apply to them.
os.environ["OMP_NUM_THREADS"] = "1"
def evaluate_policy(env, epsilon, args):
    """
    Roll the epsilon-greedy policy out for `args.horizon` steps and return the
    mean accumulated reward.

    Args:
        - :param: `env`: batched environment providing `reset()` and `step()`
        - :param: `epsilon`: exploration setting forwarded unchanged to
            `eps_greedy_actions`
        - :param: `args` (parsed argument): `num_envs` and `horizon` are read

    Return:
        - the mean over a `(args.num_envs, 1)` accumulator holding the sum of
          the rewards returned by `env.step` at every step
    """
    total_reward = np.zeros((args.num_envs, 1))

    # The observations are not needed here: `eps_greedy_actions` queries the
    # environment directly for the state it acts on.
    env.reset()
    for _ in range(args.horizon):
        chosen = eps_greedy_actions(env, args, epsilon)
        _, reward, _, _ = env.step(chosen)
        # In-place add on purpose: the accumulator keeps its (num_envs, 1) shape.
        total_reward += reward

    return total_reward.mean()

def eps_greedy_actions(env, args, epsilon = -1):
    """
    Return one epsilon-greedy action per environment for the current step.

    The greedy action of an environment is `env.opt_a[env.h]` when its entry
    in `env.get_state()` equals 0 and `env.opt_b[env.h]` otherwise.

    Args:
        - :param: `env`: batched environment; `num_envs`, `h`, `horizon`,
            `opt_a`, `opt_b`, `get_state()` and `action_space.n` are read
        - :param: `args` (parsed argument): only `horizon` is read
        - :param: `epsilon`: probability of replacing the greedy action by a
            uniformly random one. Two sentinel values are recognised:
            -1 (default) means `1 / env.horizon`;
            -2 means act greedily (epsilon 0) while `env.h` is below
            `args.horizon // 2` and uniformly at random from there on

    Return:
        - an array of `env.num_envs` actions. At step `args.horizon // 2` the
          actions are always the uniformly random ones, whatever `epsilon` is
    """
    step = env.h
    batch = env.num_envs
    midpoint = args.horizon // 2

    if epsilon == -1:
        epsilon = 1 / env.horizon

    # Greedy choice per environment, picked from the table matching its state.
    greedy = []
    for state in env.get_state():
        table = env.opt_a if state == 0 else env.opt_b
        greedy.append(table[step])
    greedy = np.array(greedy)

    uniform = np.random.randint(low=0, high=env.action_space.n, size=batch)

    if epsilon == -2:
        if step >= midpoint:
            return uniform
        epsilon = 0

    # A 1 keeps the greedy action, a 0 swaps in the uniform one. The draw and
    # the mixing happen even at the midpoint so the random stream is consumed
    # the same way on every call.
    keep_greedy = np.random.binomial(1, 1 - epsilon, batch)
    mixed = np.where(keep_greedy, greedy, uniform)

    return uniform if step == midpoint else mixed


def collect_offline_buffer(args, env, num_episodes, epsilon = -1, verbose = False, buffers = None):
    """
    Collect per-timestep offline replay buffers with an epsilon-greedy policy.

    For every run and every timestep `h`, the environment is reset, rolled
    forward `h` steps with `eps_greedy_actions`, and the transition taken at
    step `h` is added to `buffers[h]`.

    Side effects: reseeds through `set_seed_everywhere(args.seed)`, overwrites
    `args.alpha` (`args.horizon / 50` when `args.dense` is truthy, otherwise
    `args.horizon / 5`) and always prints two summary lines.

    Args:
        - :param: `args` (parsed argument): the main arguments; `seed`,
            `horizon`, `batch_size`, `dense` and `num_envs` are read
        - :param: `env`: batched environment to collect from
        - :param: `num_episodes`: collection budget; the number of outer runs
            is `int(num_episodes / env.horizon / env.num_envs)`
        - :param: `epsilon`: exploration setting forwarded to
            `eps_greedy_actions` (its default of -1 means 1/horizon there)
        - :param: `verbose`: if set to True, also print the fraction of
            episodes that reach the end and the elapsed collection time
        - :param: `buffers`: optional existing list with one ReplayBuffer per
            timestep to append to; a fresh list is created when it is None

    Return:
        - `buffers`: a list of `args.horizon` ReplayBuffer objects, one per
          timestep
    """

    set_seed_everywhere(args.seed)

    device = torch.device("cpu")

    num_runs = int(num_episodes / env.horizon / env.num_envs)

    # Identity test instead of `== None`: `==` can be overloaded by the
    # object being compared, `is` cannot.
    if buffers is None:
        buffers = [
            ReplayBuffer(env.observation_space.shape,
                         env.action_space.n,
                         int(num_episodes / args.horizon) * 3 + 1,
                         args.batch_size,
                         device)
            for _ in range(args.horizon)
        ]

    # Written back onto `args`; nothing in this function reads it afterwards.
    args.alpha = args.horizon / 50 if args.dense else args.horizon / 5

    # Running total of `env.get_counts()[:2].sum()`, sampled once per run after
    # the last-timestep transition (presumably the number of episodes that
    # reach the end -- confirm against the environment).
    num_reaches = 0
    # Never incremented in this function; kept so the printed summary below
    # stays unchanged.
    horizon_episodes = 0

    collection_time = time.time()
    for _ in tqdm(range(num_runs)):
        for h in range(args.horizon):
            # Fresh rollout for every timestep: replay the policy for h steps
            # to reach the state from which the stored transition is taken.
            obs = env.reset()
            for _t in range(h):
                warmup_action = eps_greedy_actions(env, args, epsilon)
                obs, _, _, _ = env.step(warmup_action)

            action = eps_greedy_actions(env, args, epsilon)
            next_obs, reward, _, _ = env.step(action)
            buffers[h].add_batch(obs, action, reward, next_obs, args.num_envs)

            if h == args.horizon - 1:
                count = env.get_counts()

        num_reaches += count[:2].sum()

    print(f'num_reaches: {num_reaches}')
    print(f'horizon episodes: {horizon_episodes}')
    collection_time = time.time() - collection_time
    if verbose:
        print(f"fraction of episodes reach the end: {num_reaches/num_episodes * env.horizon}")
        print(f'it took {collection_time}s')
    return buffers

def collect_uniform_random_buffer(args, env, num_episodes):
    """
    Collect per-timestep replay buffers of uniformly random transitions.

    For timestep 0 the transition starts from `env.reset()`. For every later
    timestep `h` the environment is placed at step `h` directly: a state is
    sampled uniformly from `[0, env.state_dim // 2)` for each environment, a
    lock index is sampled uniformly from `[0, len(env.locks))`, and every lock
    is reset and given the sampled states. One uniformly random action is then
    taken and the transition is stored in `buffers[h]`.

    Args:
        - :param: `args` (parsed argument): the main arguments; `seed`,
            `horizon`, `batch_size`, `recent_size` and `num_envs` are read
        - :param: `env`: batched environment to collect from
        - :param: `num_episodes`: collection budget; the number of outer runs
            is `int(num_episodes / env.horizon / env.num_envs)`

    Return:
        - `buffers`: a list of `args.horizon` ReplayBuffer objects, one per
          timestep
    """
    set_seed_everywhere(args.seed)

    cpu = torch.device("cpu")
    total_runs = int(num_episodes / env.horizon / env.num_envs)

    buffers = [
        ReplayBuffer(env.observation_space.shape,
                     env.action_space.n,
                     int(num_episodes / args.horizon) * 2 + 1,
                     args.batch_size,
                     cpu,
                     recent_size=args.recent_size)
        for _ in range(args.horizon)
    ]

    for _ in tqdm(range(total_runs)):
        for h in range(args.horizon):
            # Reset first, then draw the actions, so the random stream is
            # consumed in a fixed order.
            obs = env.reset()
            actions = np.random.randint(low=0, high=env.action_space.n,
                                        size=env.num_envs)

            placed = h != 0
            if placed:
                # Put the environment straight at timestep h in sampled states.
                env.h = h
                sampled = np.random.randint(low=0, high=env.state_dim // 2,
                                            size=env.num_envs)
                env.lock_index = np.random.randint(low=0, high=len(env.locks),
                                                   size=env.num_envs)
                obs = env.make_obs(sampled)
                for lock in env.locks:
                    lock.reset()
                    lock.h = h - 1
                    lock.state = sampled
                # Read the start state back from the environment.
                start_state = env.get_state()

            next_obs, reward, _, _ = env.step(actions)

            if placed:
                end_state = env.get_state()
                away_from_two = start_state.flatten() != 2
                # An environment that starts in state 2 earns no reward ...
                assert np.logical_or(away_from_two,
                                     np.array(reward).flatten() == 0).all()
                # ... and is still in state 2 after the step.
                assert np.logical_or(away_from_two,
                                     end_state.flatten() == 2).all()

            buffers[h].add_batch(obs, actions, reward, next_obs, args.num_envs)

    return buffers

def collect_uniform_random_buffer_lock(args, env, num_episodes):
    """
    Collect per-timestep replay buffers of uniformly random transitions by
    writing the sampled state straight onto `env.state`.

    For timestep 0 the transition starts from `env.reset()`. For every later
    timestep `h`, `env.h` is set to `h` and a state sampled uniformly from
    `[0, env.state_dim)` for each environment is assigned to `env.state`. One
    uniformly random action is then taken and the transition is stored in
    `buffers[h]`.

    Args:
        - :param: `args` (parsed argument): the main arguments; `seed`,
            `horizon`, `batch_size`, `recent_size` and `num_envs` are read
        - :param: `env`: batched environment to collect from
        - :param: `num_episodes`: collection budget; the number of outer runs
            is `int(num_episodes / env.horizon / env.num_envs)`

    Return:
        - `buffers`: a list of `args.horizon` ReplayBuffer objects, one per
          timestep
    """
    set_seed_everywhere(args.seed)

    cpu = torch.device("cpu")
    total_runs = int(num_episodes / env.horizon / env.num_envs)

    buffers = [
        ReplayBuffer(env.observation_space.shape,
                     env.action_space.n,
                     int(num_episodes / args.horizon) * 2 + 1,
                     args.batch_size,
                     cpu,
                     recent_size=args.recent_size)
        for _ in range(args.horizon)
    ]

    for _ in tqdm(range(total_runs)):
        for h in range(args.horizon):
            # Reset first, then draw the actions, so the random stream is
            # consumed in a fixed order.
            obs = env.reset()
            actions = np.random.randint(low=0, high=env.action_space.n,
                                        size=env.num_envs)

            placed = h != 0
            if placed:
                # Put the environment straight at timestep h in sampled states.
                env.h = h
                start_state = np.random.randint(low=0, high=env.state_dim,
                                                size=env.num_envs)
                env.state = start_state
                obs = env.make_obs(start_state)

            next_obs, reward, _, _ = env.step(actions)
            end_state = env.get_state()

            if placed:
                away_from_two = start_state.flatten() != 2
                # An environment that starts in state 2 earns no reward ...
                assert np.logical_or(away_from_two,
                                     np.array(reward).flatten() == 0).all()
                # ... and is still in state 2 after the step.
                assert np.logical_or(away_from_two,
                                     end_state.flatten() == 2).all()

            buffers[h].add_batch(obs, actions, reward, next_obs, args.num_envs)

    return buffers


                




# Script entry point: build the environments from the command-line arguments
# and run one evaluation rollout.
if __name__ == '__main__':

    args = parse_args()
    # make_batch_dlock returns a pair; only the first environment is used here.
    env, eval_env = make_batch_dlock(args)

    # epsilon = 0 is passed through to eps_greedy_actions. The returned mean
    # reward is discarded, so this call prints nothing by itself.
    evaluate_policy(env, 0, args)

    # Example kept for reference: collect offline buffers and save them with
    # pickle (temporary storage format).
    # buffers = collect_offline_buffer(args, env, num_episodes=2e5, epsilon=0, verbose=True)
    # with open('offline_3pa/buffers.pickle', 'wb') as fb:
    #     pickle.dump(buffers, fb)









