import sys
sys.path.append("/home/yz639/just_fqi")
import torch
import argparse
import numpy as np

import random
import os

from envs.Lock_batch import LockBatch
from envs.Dlock import DiabolicalLockMaze
from tqdm import tqdm
import time

from collections import deque

def save(args, save_name, model, wandb, ep=None):
    """Save ``model``'s state dict under ./trained_models/ and register it with ``wandb``.

    The file is named ``args.run_name + save_name + str(ep) + ".pth"`` when ``ep``
    is given, and ``args.run_name + save_name + ".pth"`` otherwise. The same path
    is then passed to ``wandb.save``.
    """
    save_dir = './trained_models/'
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs
    os.makedirs(save_dir, exist_ok=True)
    # `ep` may legitimately be 0, so compare against None explicitly
    suffix = '' if ep is None else str(ep)
    path = save_dir + args.run_name + save_name + suffix + ".pth"
    torch.save(model.state_dict(), path)
    wandb.save(path)

def collect_random(env, dataset, num_samples=200):
    """Record ``num_samples`` transitions driven by ``env.action_space.sample()``.

    Each transition is passed to ``dataset.add(obs, action, reward, next_obs, done)``.
    The environment is reset once up front and again after every terminal step.
    """
    current = env.reset()
    for _step in range(num_samples):
        chosen = env.action_space.sample()
        following, gained, finished, _info = env.step(chosen)
        dataset.add(current, chosen, gained, following, finished)
        current = env.reset() if finished else following

def parse_args():
    parser = argparse.ArgumentParser()

    parser.add_argument('--exp_name', default="test", type=str)
    parser.add_argument('--num_threads', default=10, type=int)
    parser.add_argument('--update_frequency', default=1, type=int)

    parser.add_argument('--temp_path', default="temp", type=str)
    parser.add_argument('--epsilon', default=0.5, type=float)

    parser.add_argument('--num_envs', default=50, type=int)
    parser.add_argument('--recent_size', default=10000, type=int)
    parser.add_argument('--lsvi_recent_size', default=10000, type=int)
    parser.add_argument('--load', default=False, type=bool)
    parser.add_argument('--dense', default=False, type=bool)

    parser.add_argument('--seed', default=12, type=int)
    parser.add_argument('--num_warm_start', default=0, type=int)
    parser.add_argument('--num_episodes', default=1e7, type=int)
    parser.add_argument('--batch_size', default=512, type=int)


    #environment
    parser.add_argument('--horizon', default=100, type=int)
    parser.add_argument('--switch_prob', default=0.5, type=float)
    parser.add_argument('--anti_reward', default=0.1, type=float)
    parser.add_argument('--anti_reward_prob', default=0.5, type=float)
    parser.add_argument('--num_actions', default=10, type=int)
    parser.add_argument('--observation_noise', default=0.1, type=float)
    parser.add_argument('--variable_latent', default=False, type=bool)
    parser.add_argument('--env_temperature', default=0.2, type=float)
    parser.add_argument('--optimal_reward', default=5, type=float)
    parser.add_argument('--sub_optimal_reward', default=2, type=float)

    #rep
    parser.add_argument('--fqi_num_update', default=30, type=int)
    parser.add_argument('--learning_rate', default=1e-2, type=float)
    parser.add_argument('--beta', default=0.9, type=float)
    parser.add_argument('--hidden_dim', default=256, type=int)
    parser.add_argument('--temperature', default=1, type=float)

    parser.add_argument('--reuse_weights', default=True, type=bool)
    parser.add_argument('--optimizer', default='Adam', type=str)

    parser.add_argument('--softmax', default='vanilla', type=str)

    #lsvi
    parser.add_argument('--alpha', default=10, type=float)
    parser.add_argument('--lsvi_lamb', default=1, type=float)

    #eval
    parser.add_argument('--num_eval', default=20, type=int)

    args = parser.parse_args()
    return args

def make_batch_env(args):
    """Build a training LockBatch and an evaluation LockBatch from ``args``.

    Both environments share every ``init`` option except ``num_envs``
    (``args.num_envs`` for training, ``args.num_eval`` for evaluation) and are
    seeded with ``args.seed``; only the training env also seeds its action space.
    The evaluation env is given the training env's ``opt_a``/``opt_b`` so both use
    the same action sequences (``eps_greedy_actions`` reads these as greedy actions).

    Returns:
        (env, eval_env)
    """
    shared_init = dict(
        horizon=args.horizon,
        action_dim=args.num_actions,
        p_switch=args.switch_prob,
        p_anti_r=args.anti_reward_prob,
        anti_r=args.anti_reward,
        noise=args.observation_noise,
        temperature=args.env_temperature,
        variable_latent=args.variable_latent,
        dense=args.dense,
    )

    train_env = LockBatch()
    train_env.init(num_envs=args.num_envs, **shared_init)
    train_env.seed(args.seed)
    train_env.action_space.seed(args.seed)

    evaluation_env = LockBatch()
    evaluation_env.init(num_envs=args.num_eval, **shared_init)
    evaluation_env.seed(args.seed)
    evaluation_env.opt_a = train_env.opt_a
    evaluation_env.opt_b = train_env.opt_b

    return train_env, evaluation_env

def set_seed_everywhere(seed):
    """Seed Python's ``random``, NumPy's global RNG and PyTorch with ``seed``."""
    # Per the PyTorch docs, torch.manual_seed seeds the generators on all devices,
    # so no separate CUDA seeding call is made here.
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

class Buffer(object):
    """Growable list-backed store of (obs, one-hot action, reward, next_obs) tuples.

    ``idx`` counts how many transitions have been added.
    """
    def __init__(self, num_actions):
        self.num_actions = num_actions
        self.obses, self.next_obses, self.actions, self.rewards = [], [], [], []
        self.idx = 0

    def add(self, obs, action, reward, next_obs):
        """Append one transition, storing ``action`` as a float one-hot vector."""
        self.obses.append(obs)
        one_hot = np.zeros(self.num_actions)
        one_hot[action] = 1
        self.actions.append(one_hot)
        self.rewards.append(reward)
        self.next_obses.append(next_obs)
        self.idx += 1

    def get_batch(self):
        """Return ``(count, obses, actions, rewards, next_obses)`` with each list as an array."""
        columns = (self.obses, self.actions, self.rewards, self.next_obses)
        return (self.idx, *(np.array(column) for column in columns))

    def get(self, h):
        """Return the ``h``-th stored ``(obs, action, reward, next_obs)`` tuple."""
        columns = (self.obses, self.actions, self.rewards, self.next_obses)
        return tuple(column[h] for column in columns)


class ReplayBuffer(object):
    """Buffer to store environment transitions.

    A fixed-capacity ring buffer backed by NumPy arrays. Actions are stored
    one-hot encoded. ``idx`` is the next write position and ``full`` becomes
    True once writes have reached the end of storage. A non-zero
    ``recent_size`` restricts :meth:`sample` to the ``recent_size`` most
    recently written transitions.
    """
    def __init__(self, obs_shape, num_actions, capacity, batch_size, device, recent_size=0):
        self.capacity = capacity
        self.batch_size = batch_size
        self.device = device
        self.num_actions = num_actions

        self.obses = np.empty((capacity, *obs_shape), dtype=np.float32)
        self.next_obses = np.empty((capacity, *obs_shape), dtype=np.float32)
        # `np.int` was an alias of builtin `int` and was removed in NumPy 1.24
        self.actions = np.empty((capacity, num_actions), dtype=int)
        self.rewards = np.empty((capacity, 1), dtype=np.float32)
        self.dones = np.empty((capacity, 1), dtype=np.float32)

        self.recent_size = recent_size

        self.idx = 0
        self.last_save = 0
        self.full = False

    def add(self, obs, action, reward, next_obs, done):
        """Store one transition; ``action`` is an integer index that is one-hot encoded."""
        np.copyto(self.obses[self.idx], obs)
        aoh = np.zeros(self.num_actions, dtype=int)
        aoh[action] = 1
        np.copyto(self.actions[self.idx], aoh)
        np.copyto(self.rewards[self.idx], reward)
        np.copyto(self.next_obses[self.idx], next_obs)
        np.copyto(self.dones[self.idx], done)

        self.idx = (self.idx + 1) % self.capacity
        self.full = self.full or self.idx == 0

    def _write(self, size, obs, actions, rewards, next_obs, dones):
        """Store ``size`` consecutive transitions at ``self.idx``, wrapping around the ring.

        ``actions`` must already be one-hot encoded.

        Raises:
            ValueError: if ``size`` exceeds the buffer capacity.
        """
        if size > self.capacity:
            raise ValueError(
                f"cannot write {size} transitions into a buffer of capacity {self.capacity}")
        start = self.idx
        end = start + size
        if end <= self.capacity:
            dst = slice(start, end)
            np.copyto(self.obses[dst], obs)
            np.copyto(self.actions[dst], actions)
            np.copyto(self.rewards[dst], rewards)
            np.copyto(self.next_obses[dst], next_obs)
            np.copyto(self.dones[dst], dones)
        else:
            # The batch straddles the end of storage: scatter with wrapped indices.
            dst = np.arange(start, end) % self.capacity
            self.obses[dst] = np.asarray(obs)
            self.actions[dst] = np.asarray(actions)
            self.rewards[dst] = np.asarray(rewards)
            self.next_obses[dst] = np.asarray(next_obs)
            self.dones[dst] = np.asarray(dones)

        self.idx = end % self.capacity
        self.full = self.full or end >= self.capacity

    def add_batch(self, obs, action, reward, next_obs, done, size):
        """Store ``size`` transitions; ``action`` holds ``size`` integer action indices."""
        aoh = np.zeros((size, self.num_actions), dtype=int)
        aoh[np.arange(size), action] = 1
        self._write(size, obs, aoh, reward, next_obs, done)

    def add_from_buffer(self, buf, batch_size = 1):
        """Copy ``batch_size`` transitions sampled (via ``buf.sample``) from ``buf``."""
        obs, action, reward, next_obs, done = buf.sample(batch_size = batch_size)
        # presumably `buf` is a ReplayBuffer, whose sample() returns the stored
        # one-hot actions, so no re-encoding is done here
        self._write(batch_size, obs, action, reward, next_obs, done)

    def _recent_window(self, recent_size):
        """Indices of the last ``recent_size`` writes in chronological order (full buffer only)."""
        return (self.idx - recent_size + np.arange(recent_size)) % self.capacity

    def get_full(self, recent_size=0, device=None):
        """Return ``(obses, actions, rewards, next_obses)`` as tensors on ``device``.

        With ``recent_size == 0`` every stored transition is returned. Otherwise
        only the ``recent_size`` most recent ones are returned (all of them if
        fewer are stored). ``device`` defaults to the buffer's device.
        """
        if device is None:
            device = self.device

        if self.full and 0 < recent_size < self.capacity:
            # after wrapping, the most recent entries end just before `idx`
            selection = self._recent_window(recent_size)
        else:
            if self.idx <= recent_size or recent_size == 0:
                start_index = 0
            else:
                start_index = self.idx - recent_size
            selection = slice(start_index, None if self.full else self.idx)

        return tuple(
            torch.as_tensor(storage[selection], device=device)
            for storage in (self.obses, self.actions, self.rewards, self.next_obses)
        )

    def sample(self, batch_size=None):
        """Sample ``batch_size`` transitions uniformly with replacement.

        Returns ``(obses, actions, rewards, next_obses, dones)`` tensors. When
        ``recent_size`` is non-zero, sampling is limited to the most recent
        ``recent_size`` transitions.

        Raises:
            ValueError: if the buffer is empty.
        """
        if batch_size is None:
            batch_size = self.batch_size

        stored = self.capacity if self.full else self.idx
        if stored == 0:
            raise ValueError("cannot sample from an empty ReplayBuffer")

        if self.full and 0 < self.recent_size < self.capacity:
            # the recent window may wrap around the end of storage
            offsets = np.random.randint(0, self.recent_size, size=batch_size)
            idxs = (self.idx - self.recent_size + offsets) % self.capacity
        elif self.recent_size == 0 or self.idx < self.recent_size:
            idxs = np.random.randint(0, stored, size=batch_size)
        else:
            idxs = np.random.randint(self.idx - self.recent_size, stored, size=batch_size)

        obses = torch.as_tensor(self.obses[idxs], device=self.device)
        actions = torch.as_tensor(self.actions[idxs], device=self.device)
        rewards = torch.as_tensor(self.rewards[idxs], device=self.device)
        next_obses = torch.as_tensor(self.next_obses[idxs], device=self.device)
        done = torch.as_tensor(self.dones[idxs], device=self.device)

        return obses, actions, rewards, next_obses, done

    def save(self, save_dir):
        """Write transitions added since the last save to ``save_dir/<last_save>_<idx>.pt``."""
        # NOTE(review): slices last_save:idx, so data written after the ring
        # wraps (idx < last_save) is not saved correctly — confirm this never happens
        if self.idx == self.last_save:
            return
        path = os.path.join(save_dir, '%d_%d.pt' % (self.last_save, self.idx))
        payload = [
            self.obses[self.last_save:self.idx],
            self.next_obses[self.last_save:self.idx],
            self.actions[self.last_save:self.idx],
            self.rewards[self.last_save:self.idx],
            self.dones[self.last_save:self.idx]
        ]
        self.last_save = self.idx
        torch.save(payload, path, pickle_protocol = 4)

    def load(self, save_dir):
        """Load every ``<start>_<end>.pt`` chunk in ``save_dir`` in order of ``start``."""
        chunks = sorted(os.listdir(save_dir), key=lambda name: int(name.split('_')[0]))
        for chunk in chunks:
            start, end = [int(x) for x in chunk.split('.')[0].split('_')]
            path = os.path.join(save_dir, chunk)
            # NOTE(review): torch >= 2.6 defaults torch.load to weights_only=True,
            # which may reject these NumPy payloads — check the installed version
            payload = torch.load(path)
            assert self.idx == start
            self.obses[start:end] = payload[0]
            self.next_obses[start:end] = payload[1]
            self.actions[start:end] = payload[2]
            self.rewards[start:end] = payload[3]
            self.dones[start:end] = payload[4]
            self.idx = end

def collect_uniform_random_buffer_lock(args, env, num_episodes):
    """Collect single-step transitions from uniformly random states and actions.

    For each of ``int(num_episodes / env.horizon / env.num_envs)`` runs and each
    step ``h`` of the horizon:
    - the environments are reset;
    - for ``h > 0``, ``env.h`` and ``env.state`` are overwritten with ``h`` and
      uniformly random states in ``[0, env.state_dim)``, and observations are
      rebuilt via ``env.make_obs``;
    - one uniformly random action per environment is taken and the transition
      is stored.

    Seeds all RNGs with ``args.seed`` first.

    Returns:
        A ReplayBuffer with capacity ``int(num_episodes) * 2 + 1`` and
        ``recent_size=args.recent_size``.
    """
    set_seed_everywhere(args.seed)

    device = torch.device("cpu")
    # read batch size / horizon from the env so action arrays and add_batch agree
    num_envs = env.num_envs
    horizon = env.horizon

    num_runs = int(num_episodes / horizon / num_envs)

    buffer = ReplayBuffer(env.observation_space.shape,
                          env.action_space.n,
                          int(num_episodes) * 2 + 1,
                          args.batch_size,
                          device,
                          recent_size=args.recent_size)

    for _run in tqdm(range(num_runs)):
        for h in range(horizon):
            # reset the environments
            obs = env.reset()
            random_actions = np.random.randint(low=0, high=env.action_space.n, size=num_envs)
            if h != 0:
                env.h = h
                random_state = np.random.randint(low=0, high=env.state_dim, size=num_envs)
                env.state = random_state
                obs = env.make_obs(random_state)
            next_obs, reward, done, _info = env.step(random_actions)
            new_state = env.get_state()
            # sanity check: from state 2 the reward is 0 and the next state stays 2
            if h != 0:
                assert np.logical_or(random_state.flatten() != 2, np.array(reward).flatten() == 0).all()
                assert np.logical_or(random_state.flatten() != 2, new_state.flatten() == 2).all()
            buffer.add_batch(obs, random_actions, reward, next_obs, done, num_envs)

    return buffer

def eps_greedy_actions(env, args, epsilon = -1):
    """
    Return one epsilon-greedy action per environment in ``env``.

    The greedy action at step ``env.h`` is ``env.opt_a[h]`` for environments in
    state 0 and ``env.opt_b[h]`` otherwise; with probability ``epsilon`` it is
    replaced by a uniformly random action. ``epsilon == -1`` means
    ``1 / env.horizon``. ``args`` is unused.
    """
    if epsilon == -1:
        epsilon = 1 / env.horizon
    step = env.h
    greedy = np.array([(env.opt_a if s == 0 else env.opt_b)[step] for s in env.get_state()])
    explore = np.random.randint(low=0, high=env.action_space.n, size=env.num_envs)
    keep_greedy = np.random.binomial(1, 1 - epsilon, env.num_envs)
    return np.where(keep_greedy, greedy, explore)


def collect_offline_buffer(args, env, num_episodes, epsilon = 0, verbose = False, buffer = None):
    """
    Collect an offline replay buffer by greedy roll-in plus one exploratory step.

    For each of ``int(num_episodes / env.horizon / env.num_envs)`` runs and every
    step ``h < args.horizon``:
    - the environments are reset and ``h`` greedy actions are taken
      (``eps_greedy_actions`` with epsilon 0);
    - one transition is recorded using ``eps_greedy_actions`` with epsilon 0.5.

    Side effect: overwrites ``args.alpha`` with ``args.horizon / 50`` when
    ``args.dense`` is set, and ``args.horizon / 5`` otherwise.

    Args:
        - :param: `args` (parsed argument): the main arguments
        - :param: 'env': the batched environment to collect from
        - :param: 'num_episodes': the number of episodes to collect
        - :param: 'epsilon': currently unused; the recorded step always uses epsilon 0.5
        - :param: 'verbose': if set to True, print the fraction of episodes that
            reach the end and the collection time
        - :param: 'buffer': an existing ReplayBuffer to append to; when None a new
            one with capacity ``int(num_episodes) * 2 + 1`` is created

    Return:
        - :param: 'buffer': the ReplayBuffer holding the collected transitions
    """

    set_seed_everywhere(args.seed)

    device = torch.device("cpu")

    num_runs = int(num_episodes/ env.horizon/ env.num_envs)

    if buffer is None:
        buffer = ReplayBuffer(env.observation_space.shape,
                              env.action_space.n,
                              int(num_episodes)*2+1,
                              args.batch_size,
                              device,
                              recent_size=0)

    # num_reaches accumulates env.get_counts()[:2].sum() taken at the last step of each run
    num_reaches = 0

    if args.dense:
        args.alpha = args.horizon / 50
    else:
        args.alpha = args.horizon / 5

    collection_time = time.time()
    # NOTE(review): never incremented, so the print below always reports 0
    horizon_episodes = 0
    for _run in tqdm(range(num_runs)):
        # stays None when args.horizon == 0 (avoids an unbound name below)
        count = None
        for h in range(args.horizon):
            obs = env.reset()
            # greedy roll-in to step h
            for _step in range(h):
                action = eps_greedy_actions(env, args, 0)
                obs, reward, done, _info = env.step(action)
            action = eps_greedy_actions(env, args, 0.5)
            next_obs, reward, done, _info = env.step(action)
            # actions are sized by env.num_envs, so store that many transitions
            buffer.add_batch(obs, action, reward, next_obs, done, env.num_envs)

            if h == args.horizon - 1:
                count = env.get_counts()

        if count is not None:
            num_reaches += count[:2].sum()
    print(f'num_reaches: {num_reaches}')
    print(f'horizon episodes: {horizon_episodes}')
    collection_time = time.time() - collection_time
    if verbose:
        print(f"fraction of episodes reach the end: {num_reaches/num_episodes * env.horizon}")
        print(f'it took {collection_time}s')
    return buffer

