import GPUtil
# Ask GPUtil for at most one available GPU, ordered by memory usage.
deviceIDs = GPUtil.getAvailable(order = 'memory', limit = 1, maxLoad = 1.0, maxMemory = 1.0, includeNan=False, excludeID=[], excludeUUID=[])
import os
# Pin the process to that GPU before torch is imported further down.
# GPUtil returns an empty list when no GPU qualifies; in that case leave the
# variable untouched so the script can still fall back to CPU instead of
# dying on deviceIDs[0].
if deviceIDs:
    os.environ["CUDA_VISIBLE_DEVICES"] = str(deviceIDs[0])
import sys, os, time
import numpy as np
import gym
from ruamel.yaml import YAML
from samples import envs
# from continuous import envs
import torch
from viskit import logger
from logging_utils.logx import EpochLogger
import datetime
import dateutil.tz
import json, copy
import d4rl
from continuous.divs.f_div import  contrastive_maxentirl_sigmoid_loss_validation,maxentirl_loss_sigmoid_validation, maxentirl_loss_sigmoid_validation_preferences
from continuous.divs.f_div import maxentirl_loss_sigmoid_validation_offline, maxentirl_loss_sigmoid_validation_preferences_weighted
from continuous.models.irl_agent import MLPReward, MLPSigmoidReward
from continuous.models.bnn_reward import EnsembleRewardModel
from continuous.models.sac import ReplayBuffer, SAC
from continuous.models.ppo import PPO
from continuous.utils import system, collect
from continuous.plots.train_plot import plot_disc as visual_disc
from continuous import eval
import pickle




def isPossible(arr, n, C, mid):
    """Check whether C magnets fit on sorted positions at least ``mid`` apart.

    The first magnet is placed at ``arr[0]``; every further magnet goes to
    the next position that lies at least ``mid`` beyond the previous one.

    Args:
        arr: Positions in ascending order (indexable, at least ``n`` long).
        n: Number of leading entries of ``arr`` to consider.
        C: Number of magnets that must be placed.
        mid: Minimum gap required between consecutive magnets.

    Returns:
        True if ``C`` magnets can be placed, False otherwise.
    """
    if n <= 0:
        # No positions available, so nothing can be placed.
        return False
    if C <= 1:
        # The magnet at arr[0] already satisfies the request. Without this
        # guard the count is only compared after an increment, so a request
        # for a single magnet could never succeed.
        return True

    placed = 1
    last_position = arr[0]
    for i in range(1, n):
        if arr[i] - last_position >= mid:
            placed += 1
            last_position = arr[i]
            if placed >= C:
                return True
    return False
 
# Function for modified binary search
def binarySearch(n, C, arr):
    """Find the largest integer gap that still allows placing C magnets.

    Binary-searches over integer gaps and uses ``isPossible`` as the
    feasibility test.

    Args:
        n: Number of positions in ``arr`` to consider.
        C: Number of magnets to place.
        arr: Positions; any array-like, flattened and sorted internally.

    Returns:
        The largest integer gap for which ``isPossible`` succeeds, or 0 when
        no positive gap is feasible.
    """
    # np.sort returns a sorted copy, so the result has to be kept (the input
    # is left untouched). Flattening lets column vectors such as (N, 1)
    # return arrays be sorted by value and compared as scalars.
    arr = np.sort(np.asarray(arr).reshape(-1))
    if n <= 0 or arr.shape[0] == 0:
        return 0

    lo = 0
    # No gap can exceed the spread of the data. Using the largest value
    # itself as the bound breaks when the values are negative.
    hi = arr[n - 1] - arr[0]
    ans = 0
    while lo <= hi:
        mid = int((lo + hi) / 2)
        if isPossible(arr, n, C, mid):
            ans = max(ans, mid)
            lo = mid + 1
        else:
            hi = mid - 1
    return ans

def generate_preference_dataset(env_name, env, preference_levels=10, episode_per_preference_level=1): # 3, 10
    """Build a return-ranked trajectory dataset from offline transitions.

    Transitions come from a saved replay buffer (``DoorFH`` / ``Pen``
    environments) or from a D4RL ``*-medium-expert-v0`` dataset
    (``HalfCheetah`` / ``Hopper`` / ``Walker``). They are cut into
    trajectories of ``max_steps`` steps; a trajectory that terminates early
    has its remaining ``state`` slots filled with its last observation (the
    other fields stay zero). Trajectories are ranked by return, where the
    reward of steps flagged ``done`` is masked out.

    Args:
        env_name: Environment id; selects the data source by substring.
        env: Environment whose observation/action spaces give the feature
            sizes of the buffers.
        preference_levels: Number of return levels to extract.
        episode_per_preference_level: Trajectories per level. With 1, levels
            are picked greedily in ascending return order so that
            neighbouring returns are at least the gap found by
            ``binarySearch`` apart. Otherwise each level takes consecutive
            trajectories starting at evenly spaced ranks of that ordering.

    Returns:
        Dict with ``state``/``action``/``next_state``/``reward``/``done``
        arrays whose first axis has ``preference_levels *
        episode_per_preference_level`` entries ordered from the lowest to the
        highest level, plus the scalars ``levels`` and ``episode_per_levels``.

    Raises:
        ValueError: If no data source is configured for ``env_name``, if the
            data contains no complete trajectory, or if fewer than
            ``preference_levels`` distinct levels can be selected.
    """
    replay_markers = ('Pen', 'Door', 'Hammer', 'DoorFH', 'LiftFH', 'TwoArmPegInHole')
    if any(marker in env_name for marker in replay_markers):
        if 'DoorFH' in env_name:
            dataset_name = 'samples/replay_data/doorfh/sac_dense_new_s0replay_buffer.npy'
            max_steps = 500
        elif 'Pen' in env_name:
            dataset_name = 'samples/replay_data/pen/sac_n_s0replay_buffer.npy'
            max_steps = 100
        else:
            # Only DoorFH and Pen have a replay buffer configured; fail with
            # a clear message instead of an UnboundLocalError further down.
            raise ValueError(
                "No replay buffer configured for environment '{}'".format(env_name))

        # NOTE(security): allow_pickle=True unpickles arbitrary objects;
        # only load replay buffers from trusted sources.
        replay_buffer = np.load(dataset_name, allow_pickle=True).item()
        size = replay_buffer['size']
        dataset = {
            'observations': replay_buffer['state'][:size],
            'actions': replay_buffer['action'][:size],
            'next_observations': replay_buffer['next_state'][:size],
            'terminals': replay_buffer['done'][:size],
            'rewards': replay_buffer['reward'][:size],
        }
    else:
        if 'HalfCheetah' in env_name:
            dataset_name = 'halfcheetah-medium-expert-v0'
        elif 'Hopper' in env_name:
            dataset_name = 'hopper-medium-expert-v0'
        elif 'Walker' in env_name:
            dataset_name = 'walker2d-medium-expert-v0'
        else:
            raise ValueError(
                "No D4RL dataset configured for environment '{}'".format(env_name))
        max_steps = 1000
        temp_env = gym.make(dataset_name)
        dataset = d4rl.qlearning_dataset(temp_env)

    obs_dim = env.observation_space.shape[0]
    act_dim = env.action_space.shape[0]
    field_dims = (('state', obs_dim), ('action', act_dim), ('next_state', obs_dim),
                  ('reward', 1), ('done', 1))
    source_keys = (('state', 'observations'), ('action', 'actions'),
                   ('next_state', 'next_observations'), ('reward', 'rewards'),
                   ('done', 'terminals'))

    # Cut the flat transition stream into fixed-length trajectories.
    n_traj = 10000
    offline_dataset = {key: np.zeros((n_traj, max_steps, dim)) for key, dim in field_dims}
    traj_no = 0
    episode_step_ctr = 0
    for i in range(dataset['observations'].shape[0]):
        if traj_no == n_traj:
            # Buffer is full; stop instead of indexing past its end.
            print("Trajectory buffer full ({} trajectories); ignoring remaining transitions".format(n_traj))
            break
        for target_key, source_key in source_keys:
            offline_dataset[target_key][traj_no, episode_step_ctr, :] = dataset[source_key][i]
        episode_step_ctr += 1
        if dataset['terminals'][i] or episode_step_ctr == max_steps:
            # Pad an early-terminated trajectory with its last observation
            # (an empty slice when the trajectory is already full length).
            offline_dataset['state'][traj_no, episode_step_ctr:, :] = dataset['observations'][i]
            traj_no += 1
            episode_step_ctr = 0

    # traj_no counts the completed trajectories, i.e. rows [0, traj_no).
    # A trailing partial trajectory sits at row traj_no and is excluded.
    total_trajs = traj_no
    if total_trajs == 0:
        raise ValueError("Dataset '{}' contains no complete trajectory".format(dataset_name))

    # Sample episode for preferences
    episode_returns = (offline_dataset['reward'][:total_trajs]
                       * (1 - offline_dataset['done'][:total_trajs])).sum(1)
    flat_returns = episode_returns.reshape(-1)
    sorted_idx = np.argsort(flat_returns)

    pref_trajs = preference_levels * episode_per_preference_level
    preference_dataset = {key: np.zeros((pref_trajs, max_steps, dim)) for key, dim in field_dims}
    preference_dataset['levels'] = preference_levels
    preference_dataset['episode_per_levels'] = episode_per_preference_level

    if episode_per_preference_level == 1:
        maximum_possible_distance = binarySearch(
            flat_returns.shape[0], preference_levels, flat_returns[sorted_idx])
        prev_val = -np.inf
        selected_idx = []
        for i in sorted_idx:
            # '>=' mirrors the feasibility test used by the gap search, so a
            # gap that was reported feasible always yields enough levels.
            if flat_returns[i] >= prev_val + maximum_possible_distance:
                selected_idx.append(i)
                prev_val = flat_returns[i]
        if len(selected_idx) < preference_levels:
            raise ValueError(
                "Only {} of the {} requested preference levels could be selected".format(
                    len(selected_idx), preference_levels))

    pref_returns = []
    for pref_level in range(preference_levels):
        if episode_per_preference_level == 1:
            idx = np.array([selected_idx[pref_level]])
            sample_idx = np.random.choice(idx, size=(episode_per_preference_level))
        else:
            first_rank = int(pref_level * sorted_idx.shape[0] / preference_levels)
            idx = sorted_idx[first_rank:first_rank + episode_per_preference_level]
            sample_idx = idx.reshape(-1)
            print("Average reward for pref level : {} is {}".format(pref_level,episode_returns[idx].mean() ))

        row_start = pref_level * episode_per_preference_level
        row_stop = row_start + episode_per_preference_level
        for key, _ in field_dims:
            preference_dataset[key][row_start:row_stop, :, :] = offline_dataset[key][sample_idx, :, :]
        pref_returns.append(preference_dataset['reward'][row_start:row_stop, :, :].sum(1).mean())

    print("Preference episode returns :{}".format(pref_returns))
    return preference_dataset

def load_expert_from_offline_dataset(env_name, num_traj=10):
    """Load the first expert trajectories of a D4RL ``*-expert-v0`` dataset.

    The dataset (``door``, ``pen`` or ``hammer``, chosen by substring of
    ``env_name``) is cut into trajectories of ``max_steps`` steps: 100 for
    Pen, 200 otherwise. A trajectory that terminates early has its remaining
    ``state`` slots filled with its last observation.

    Args:
        env_name: Environment id containing ``Door``, ``Pen`` or ``Hammer``.
        num_traj: Number of trajectories to extract.

    Returns:
        Tuple ``(states, actions)`` of arrays shaped
        ``(num_traj, max_steps, dim)``. Rows stay all-zero if the dataset
        holds fewer than ``num_traj`` complete trajectories.

    Raises:
        ValueError: If no expert dataset is configured for ``env_name``.
    """
    max_steps = 200
    if 'Door' in env_name:
        dataset_name = 'door-expert-v0'
    elif 'Pen' in env_name:
        max_steps = 100
        dataset_name = 'pen-expert-v0'
    elif 'Hammer' in env_name:
        dataset_name = 'hammer-expert-v0'
    else:
        # Fail with a clear message instead of an UnboundLocalError below.
        raise ValueError(
            "No expert dataset configured for environment '{}'".format(env_name))

    temp_env = gym.make(dataset_name)
    dataset = d4rl.qlearning_dataset(temp_env)

    n_traj = num_traj
    obs_dim = temp_env.observation_space.shape[0]
    act_dim = temp_env.action_space.shape[0]
    offline_dataset = {
        'state': np.zeros((n_traj, max_steps, obs_dim)),
        'action': np.zeros((n_traj, max_steps, act_dim)),
        'next_state': np.zeros((n_traj, max_steps, obs_dim)),
        'reward': np.zeros((n_traj, max_steps, 1)),
        'done': np.zeros((n_traj, max_steps, 1)),
    }
    source_keys = (('state', 'observations'), ('action', 'actions'),
                   ('next_state', 'next_observations'), ('reward', 'rewards'),
                   ('done', 'terminals'))

    traj_no = 0
    episode_step_ctr = 0
    for i in range(dataset['observations'].shape[0]):
        if traj_no == n_traj:
            break
        for target_key, source_key in source_keys:
            offline_dataset[target_key][traj_no, episode_step_ctr, :] = dataset[source_key][i]
        episode_step_ctr += 1
        if dataset['terminals'][i] or episode_step_ctr == max_steps:
            # Pad an early-terminated trajectory with its last observation
            # (an empty slice when the trajectory is already full length).
            offline_dataset['state'][traj_no, episode_step_ctr:, :] = dataset['observations'][i]
            traj_no += 1
            episode_step_ctr = 0

    if traj_no < n_traj:
        print("Warning: only {} of the {} requested expert trajectories were found".format(traj_no, n_traj))

    episode_returns = (offline_dataset['reward']*(1-offline_dataset['done'])).sum(1)
    print("Dataset return average: {}, std: {}".format(episode_returns.mean(),episode_returns.std()))
    return offline_dataset['state'], offline_dataset['action']



def generate_offline_dataset(env_name, env):
    """Sample state trajectories from a D4RL ``*-medium-v0`` dataset.

    The dataset (``halfcheetah``, ``hopper`` or ``walker2d``, chosen by
    substring of ``env_name``) is cut into trajectories of 1000 steps. A
    trajectory that terminates early is padded with its last observation.
    300 trajectories are then drawn with replacement from the complete ones.

    Args:
        env_name: Environment id containing ``HalfCheetah``, ``Hopper`` or
            ``Walker``.
        env: Environment whose observation space gives the state size.

    Returns:
        Array of shape ``(300, 1000, obs_dim)`` holding the sampled state
        trajectories.

    Raises:
        ValueError: If no dataset is configured for ``env_name`` or the
            dataset contains no complete trajectory.
    """
    if 'HalfCheetah' in env_name:
        dataset_name = 'halfcheetah-medium-v0'
    elif 'Hopper' in env_name:
        dataset_name = 'hopper-medium-v0'
    elif 'Walker' in env_name:
        dataset_name = 'walker2d-medium-v0'
    else:
        # Fail with a clear message instead of an UnboundLocalError below.
        raise ValueError(
            "No D4RL dataset configured for environment '{}'".format(env_name))

    temp_env = gym.make(dataset_name)
    dataset = d4rl.qlearning_dataset(temp_env)

    n_traj = 10000
    max_steps = 1000
    # Only states are returned, so only the state buffer is allocated.
    states = np.zeros((n_traj, max_steps, env.observation_space.shape[0]))
    traj_no = 0
    episode_step_ctr = 0
    for i in range(dataset['observations'].shape[0]):
        if traj_no == n_traj:
            # Buffer is full; stop instead of indexing past its end.
            print("Trajectory buffer full ({} trajectories); ignoring remaining transitions".format(n_traj))
            break
        states[traj_no, episode_step_ctr, :] = dataset['observations'][i]
        episode_step_ctr += 1
        if dataset['terminals'][i] or episode_step_ctr == max_steps:
            # Pad up to the buffer horizon with the last observation. The
            # slice is empty when the trajectory is already full length, so
            # no slot of a finished trajectory is left as zeros.
            states[traj_no, episode_step_ctr:, :] = dataset['observations'][i]
            traj_no += 1
            episode_step_ctr = 0

    # traj_no counts the completed trajectories, i.e. rows [0, traj_no).
    total_trajs = traj_no
    if total_trajs == 0:
        raise ValueError("Dataset '{}' contains no complete trajectory".format(dataset_name))

    # Sample episode for preferences
    offline_trajs = 300
    sample_idx = np.random.choice(total_trajs, size=(offline_trajs))
    return states[sample_idx]



def try_evaluate(itr: int, policy_type: str, sac_info):
    """Evaluate the current policy on the real reward and log the result.

    Relies on the module-level names ``v``, ``sac_agent``, ``env_name``,
    ``env_fn`` and ``logger`` that the ``__main__`` section defines.

    Args:
        itr: Index of the current IRL iteration; used to derive the logged
            update-time and environment-step counters.
        policy_type: Label used as prefix of the logged counters. Only
            ``"Running"`` is accepted.
        sac_info: Unused; kept for interface compatibility.

    Returns:
        Tuple ``(real_return_det, 0)``. The second entry is a constant
        placeholder because the stochastic evaluation is disabled.
    """
    assert policy_type in ["Running"]
    update_time = itr * v['reward']['gradient_step']
    env_steps = itr * v['sac']['epochs'] * v['env']['T']

    # eval real reward
    # (the trailing True presumably selects deterministic actions - TODO confirm)
    real_return_det = eval.evaluate_real_return(
        sac_agent.get_action, env_name, env_fn(),
        v['irl']['eval_episodes'], v['env']['T'], True)
    print(f"real det return avg: {real_return_det:.2f}")
    logger.log_tabular("Real Det Return", round(real_return_det, 2))

    logger.log_tabular(f"{policy_type} Update Time", update_time)
    logger.log_tabular(f"{policy_type} Env Steps", env_steps)

    return real_return_det, 0

if __name__ == "__main__":
    # Entry point: read the YAML variant given as argv[1], set up logging,
    # build the reward model, then alternate policy learning with reward
    # updates for v['irl']['n_itrs'] iterations.
    import shutil  # only needed for the log-folder snapshots below

    yaml = YAML()
    # Context manager so the config file handle is closed right away.
    with open(sys.argv[1]) as config_file:
        v = yaml.load(config_file)

    # common parameters
    env_name = v['env']['env_name']
    state_indices = v['env']['state_indices']
    seed = v['seed']
    num_expert_trajs = v['irl']['expert_episodes']

    # system: device, threads, seed, pid
    device = torch.device("cuda:0" if torch.cuda.is_available() and v['cuda'] >= 0 else "cpu")
    torch.set_num_threads(1)
    np.set_printoptions(precision=3, suppress=True)
    system.reproduce(seed)
    pid=os.getpid()
    
    # logs
    exp_id = f"samples/results/{env_name}/" + v['exp_name'] # task/obj/date structure
    os.makedirs(exp_id, exist_ok=True)

    now = datetime.datetime.now(dateutil.tz.tzlocal())
    log_folder = exp_id+'/'+v['exp_name'] +'_s'+str(seed) 
    logger_kwargs={'output_dir':log_folder, 'exp_name':exp_id}
    logger = EpochLogger(**logger_kwargs)
    # logger.configure(dir=log_folder, )            
    print(f"Logging to directory: {log_folder}")
    # Snapshot the training script and the config next to the logs. This is
    # best-effort: a failed copy is reported but does not abort the run.
    # Copying without a shell also keeps paths with spaces working.
    for snapshot_src, snapshot_dst in (('samples/irl.py', log_folder),
                                       (sys.argv[1], f'{log_folder}/variant.yml')):
        try:
            shutil.copy(snapshot_src, snapshot_dst)
        except OSError as copy_error:
            print(f"Could not copy {snapshot_src} to {snapshot_dst}: {copy_error}")
    with open(os.path.join(logger.output_dir, 'variant.json'), 'w') as f:
        json.dump(v, f, indent=2, sort_keys=True)
    print('pid', pid)
    os.makedirs(os.path.join(log_folder, 'plt'),exist_ok=True)
    os.makedirs(os.path.join(log_folder, 'model'),exist_ok=True)

    # environment
    env_fn = lambda: gym.make(env_name)
    gym_env = env_fn()
    state_size = gym_env.observation_space.shape[0]
    action_size = gym_env.action_space.shape[0]
    if state_indices == 'all':
        state_indices = list(range(state_size))

    # load expert samples from trained policy
    if 'PenFH-v0' in env_name or 'DoorAdroitFH-v0' in env_name or 'HammerFH-v0' in env_name:
        expert_trajs, expert_actions = load_expert_from_offline_dataset(env_name)
    else:
        # Read each file once; the stored object is either a torch tensor or
        # something that can be sliced directly.
        expert_state_data = torch.load(f'samples/expert_data/states/{env_name}_airl.pt')
        expert_action_data = torch.load(f'samples/expert_data/actions/{env_name}_airl.pt')
        if torch.is_tensor(expert_state_data):
            expert_trajs = expert_state_data.numpy()[:, :, state_indices]
            expert_actions = expert_action_data.numpy()[:num_expert_trajs,:,:]
        else:
            expert_trajs = expert_state_data[:, :, state_indices]
            expert_actions = expert_action_data[:num_expert_trajs,:,:]
    
    expert_trajs = expert_trajs[:num_expert_trajs, :, :] # select first expert_episodes
    expert_samples = expert_trajs.copy().reshape(-1, len(state_indices))
    # print(expert_trajs.shape, expert_samples.shape) # ignored starting state
    
    # Initialize reward as a neural network
    if v['reward']['type']=='vanilla':
        reward_func = MLPReward(len(state_indices), **v['reward'], device=device).to(device)
    elif v['reward']['type']=='sigmoid':
        reward_func = MLPSigmoidReward(len(state_indices), **v['reward'], device=device).to(device)
    elif v['reward']['type']=='ensemble_sigmoid':
        reward_func = EnsembleRewardModel(len(state_indices),'sigmoid').to(device)
        reward_func.scaler.fit(expert_samples)
    elif v['reward']['type']=='ensemble_clamp':
        reward_func = EnsembleRewardModel(len(state_indices),'clamp').to(device)
        reward_func.scaler.fit(expert_samples)
    else:
        # Fail here instead of with a NameError at the optimizer below.
        raise ValueError(f"Unknown reward type: {v['reward']['type']}")

    reward_optimizer = torch.optim.Adam(reward_func.parameters(), lr=v['reward']['lr'], 
        weight_decay=v['reward']['weight_decay'], betas=(v['reward']['momentum'], 0.999))
    

    max_real_return_det, max_real_return_sto = -np.inf, -np.inf
    cum_samples = None

    
    if 'preferences' in v['obj'] or v['obj']=='contrastive-maxentirl-with-preferences':
        preference_dataset = generate_preference_dataset(env_name,env_fn())


    if 'offline' in  v['obj']:
        offline_dataset = generate_offline_dataset(env_name,env_fn())

    for itr in range(v['irl']['n_itrs']):
        if v['rl']=='sac':
            if v['sac']['reinitialize'] or itr == 0:
                # Reset SAC agent with old policy, new environment, and new replay buffer
                print("Reinitializing sac")
                replay_buffer = ReplayBuffer(
                state_size, 
                action_size,
                device=device,
                size=v['sac']['buffer_size'])
                # ac_kwargs = {'hidden_sizes':(1024,1024)}
                ac_kwargs = {}
                sac_agent = SAC(env_fn, replay_buffer,
                    steps_per_epoch=v['env']['T'],
                    update_after=v['env']['T'] * v['sac']['random_explore_episodes'], 
                    max_ep_len=v['env']['T'],
                    seed=seed,
                    start_steps=v['env']['T'] * v['sac']['random_explore_episodes'],
                    reward_state_indices=state_indices,
                    device=device,
                    ac_kwargs=ac_kwargs,
                    **v['sac']
                )
            
        elif v['rl']=='ppo':
            if v['sac']['reinitialize'] or itr == 0:
                sac_agent =  PPO(env_fn, steps_per_epoch=v['env']['T'], epochs=v['sac']['epochs'],
            logger_kwargs=None)

        sac_agent.reward_function = reward_func.get_scalar_reward # only need to change reward in sac

        if v['obj'] !='offline-il':
            sac_info, samples = sac_agent.learn_mujoco_and_collect(print_out=True,state_indices=state_indices)

            start = time.time()
            agent_emp_states = samples[0].copy()
            # Keep every iteration's rollouts; the 'rank-ral' objective
            # trains on the accumulated set.
            if cum_samples is None:
                cum_samples = agent_emp_states
            else:
                cum_samples = np.concatenate((cum_samples,agent_emp_states),axis=0)
            agent_emp_states = agent_emp_states.reshape(-1,agent_emp_states.shape[2]) # n*T states
            print(f'collect trajs {time.time() - start:.0f}s', flush=True)

        incorrect_ordering_ratio = None
        identity_firing_ratio = None
        # optimization w.r.t. reward
        reward_losses = []
        for _ in range(v['reward']['gradient_step']):
            expert_trajs_train = None
            if v['obj'] == 'rank-pal':
                loss = maxentirl_loss_sigmoid_validation(v['obj'], samples, expert_samples, reward_func,reward_optimizer, device, regularization=v['irl']['regularization'],epochs=v['irl']['epochs'])
            elif v['obj'] == 'maxentirl-sigmoid-validation-offline':
                loss = maxentirl_loss_sigmoid_validation_offline(v['obj'], samples, expert_samples, offline_dataset, reward_func,reward_optimizer, itr, device, regularization=v['irl']['regularization'],epochs=v['irl']['epochs'])
            elif v['obj'] == 'maxentirl-sigmoid-validation-preferences':
                loss = maxentirl_loss_sigmoid_validation_preferences(v['obj'], samples, expert_samples,preference_dataset, reward_func,reward_optimizer, device, regularization=v['irl']['regularization'],epochs=v['irl']['epochs'])
            elif v['obj'] == 'pal-preferences-weighted':
                loss = maxentirl_loss_sigmoid_validation_preferences_weighted(v['obj'], samples, expert_samples,preference_dataset, reward_func,reward_optimizer, device, regularization=v['irl']['regularization'],epochs=v['irl']['epochs'])
            elif v['obj'] == 'rank-ral':
                incorrect_ordering_ratio, loss = contrastive_maxentirl_sigmoid_loss_validation(v['obj'], cum_samples, expert_samples, reward_func,reward_optimizer, device, regularization=v['irl']['regularization'],epochs=v['irl']['epochs'])            

        # evaluating the learned reward
        if v['obj']!='offline-il':
            real_return_det, real_return_sto = try_evaluate(itr, "Running", sac_info)
            # Track each running maximum on its own. Requiring both returns
            # to improve at once froze the tracker after the first
            # iteration, because the second value is a constant 0.
            max_real_return_det = max(max_real_return_det, real_return_det)
            max_real_return_sto = max(max_real_return_sto, real_return_sto)
        else:
            torch.save(reward_func.state_dict(), os.path.join(logger.output_dir, 
                        f"model/reward_model.pkl"))
        if incorrect_ordering_ratio is not None:
            logger.log_tabular("IncorrectOrderingRatio", incorrect_ordering_ratio)
        else:
            logger.log_tabular("IncorrectOrderingRatio", -1)
        if identity_firing_ratio is not None:
            logger.log_tabular("IdentityFiringRatio", identity_firing_ratio)
        else:
            logger.log_tabular("IdentityFiringRatio", -1)
        
        logger.log_tabular("Iteration", itr)
        # NOTE(review): `loss` is only bound when v['obj'] matches one of the
        # objectives above and gradient_step > 0, and `sac_info` only when
        # v['obj'] != 'offline-il'. For 'offline-il' (or an unknown
        # objective) the next two lines raise NameError - confirm the
        # intended behaviour before relying on that path.
        logger.log_tabular("Reward Loss", loss)
        logger.log_tabular("Learned Reward Eval", np.array(sac_info[0]).mean())
        if v['sac']['automatic_alpha_tuning']:
            logger.log_tabular("alpha", sac_agent.alpha.item())
        
        # if v['irl']['save_interval'] > 0 and (itr % v['irl']['save_interval'] == 0 or itr == v['irl']['n_itrs']-1):
        #     torch.save(reward_func.state_dict(), os.path.join(logger.get_dir(), f"model/reward_model_{itr}.pkl"))

        logger.dump_tabular()