# Choose one GPU and export it through CUDA_VISIBLE_DEVICES. This runs at import
# time, before `torch` is imported further down in this file.
import GPUtil
# Ask GPUtil for at most one device id (limit = 1), ordered by 'memory'.
deviceIDs = GPUtil.getAvailable(order = 'memory', limit = 1, maxLoad = 1.0, maxMemory = 1.0, includeNan=False, excludeID=[], excludeUUID=[])
import os
import psutil
# Handle on the current process; not referenced again in the visible part of this file.
process = psutil.Process(os.getpid())
# NOTE(review): deviceIDs[0] raises IndexError if GPUtil returns an empty list
# (e.g. no GPU available) — confirm this script is only launched on GPU machines.
os.environ["CUDA_VISIBLE_DEVICES"] = str(deviceIDs[0]) # expose only the selected device id
import sys, os, time
import numpy as np
import gym
from ruamel.yaml import YAML
from samples import envs
import torch
from viskit import logger
from logging_utils.logx import EpochLogger
import datetime
import dateutil.tz
import json, copy
import d4rl
import d4rl.gym_mujoco
from continuous.divs.f_div import  contrastive_maxentirl_sigmoid_loss_validation,maxentirl_loss_sigmoid_validation, maxentirl_loss_sigmoid_validation_preferences, maxentirl_loss_sigmoid_validation_preferences_weighted
from continuous.divs.f_div import maxentirl_loss_sigmoid_validation_offline, snippet_loss_validation_preferences_weighted
from continuous.models.irl_agent import MLPReward, MLPSigmoidReward
from continuous.models.bnn_reward import EnsembleRewardModel
from continuous.models.sac import ReplayBuffer, SAC
from continuous.models.ppo import PPO
from continuous.utils import system, collect
from continuous import eval
import pickle




def isPossible(arr, n, C, mid):
    """Check whether C magnets fit on sorted positions with gaps of at least mid.

    Greedy feasibility test: the first magnet is put on ``arr[0]`` and every
    further magnet goes on the first position that is at least ``mid`` away
    from the previously placed one.

    Args:
        arr: Positions in ascending order (only ``arr[0] .. arr[n-1]`` are read).
        n: Number of positions to consider.
        C: Number of magnets that must be placed.
        mid: Minimum required distance between consecutive magnets.

    Returns:
        True if at least ``C`` magnets can be placed, otherwise False.
        Always False when ``n <= 0`` because there is no position to use.
    """
    if n <= 0:
        return False

    # The first magnet always goes on the first position.
    placed = 1
    last_position = arr[0]

    # The loop below only re-checks the count after placing a *further*
    # magnet, so a request for a single magnet must be answered here.
    if placed >= C:
        return True

    for i in range(1, n):
        # Far enough from the last placed magnet: place another one here.
        if arr[i] - last_position >= mid:
            placed += 1
            last_position = arr[i]
            if placed >= C:
                return True

    # Ran out of positions before placing C magnets.
    return False
 
# Function for modified binary search
def binarySearch(n, C, arr):
    """Return the largest integer gap for which C of the n values can be chosen.

    Binary-searches over integer distances ``d`` and uses ``isPossible`` to
    test whether ``C`` values can be picked with every consecutive pair at
    least ``d`` apart.

    Args:
        n: Number of leading values of ``arr`` to consider.
        C: Number of values that must be picked.
        arr: Values to spread over; a flat sequence or an ``(n, 1)`` array.
            It does not need to be sorted.

    Returns:
        The largest feasible integer distance, or 0 if no positive distance
        is feasible (including when there are no values).
    """
    # np.sort returns a sorted copy, so the result has to be kept. The input
    # is flattened first so that column vectors of shape (n, 1) are sorted
    # along the sample axis and yield scalars when indexed.
    positions = np.sort(np.asarray(arr).reshape(-1)[:n])
    n = positions.shape[0]
    if n == 0:
        return 0

    # Minimum possible distance
    lo = 0

    # Maximum possible distance: no gap can exceed the overall spread. Using
    # the largest value alone is wrong when values are negative or straddle 0.
    hi = positions[-1] - positions[0]
    ans = 0

    # Run the loop until lo becomes greater than hi
    while lo <= hi:
        mid = int((lo + hi) / 2)

        if isPossible(positions, n, C, mid):
            # Feasible: remember it and try a larger distance.
            ans = max(ans, mid)
            lo = mid + 1
        else:
            # Not feasible: only smaller distances can work.
            hi = mid - 1

    # Return maximum possible distance
    return ans

def generate_preference_dataset(env_name, env, preference_levels=10, episode_per_preference_level=1): # 3, 10
    """Build a set of trajectories ordered from lowest to highest return.

    Transitions are read either from a saved SAC replay buffer (Pen / Door /
    Hammer) or from a D4RL ``*-medium-expert-v0`` dataset (HalfCheetah /
    Hopper / Walker), cut into fixed-length trajectories, and a subset is
    copied out so that successive preference levels have increasing return.

    Args:
        env_name: Environment id; matched by substring to pick the data source
            and the trajectory length.
        env: Environment whose observation/action space sizes set array widths.
        preference_levels: Number of return levels to extract.
        episode_per_preference_level: Trajectories stored per level. With 1,
            trajectories are chosen so their returns are spread as far apart
            as ``binarySearch`` allows; otherwise each level takes consecutive
            trajectories from the return-sorted list.

    Returns:
        dict with arrays ``state``, ``action``, ``next_state``, ``reward`` and
        ``done`` of shape ``(levels * episodes_per_level, max_steps, dim)``,
        plus the integers ``levels`` and ``episode_per_levels``.

    Raises:
        ValueError: If ``env_name`` matches no supported environment, if no
            complete trajectory is available, or if fewer than
            ``preference_levels`` sufficiently separated trajectories exist.
    """
    if 'Pen' in env_name or 'Door' in env_name or 'Hammer' in env_name:
        # Adroit tasks: transitions come from a replay buffer saved on disk.
        if 'Pen' in env_name:
            dataset_name = 'samples/replay_data/pen/sac_n_s0replay_buffer.npy'
            max_steps = 100
        elif 'Door' in env_name:
            dataset_name = 'samples/replay_data/door/sac_n_s0replay_buffer.npy'
            max_steps = 200
        else:
            dataset_name = 'samples/replay_data/hammer/sac_n_s0replay_buffer.npy'
            max_steps = 200

        # NOTE: allow_pickle=True unpickles the file; only load trusted data.
        replay_buffer = np.load(dataset_name, allow_pickle=True).item()
        size = replay_buffer['size']
        dataset = {
            'observations': replay_buffer['state'][:size],
            'actions': replay_buffer['action'][:size],
            'next_observations': replay_buffer['next_state'][:size],
            'terminals': replay_buffer['done'][:size],
            'rewards': replay_buffer['reward'][:size],
        }
    else:
        if 'HalfCheetah' in env_name:
            dataset_name = 'halfcheetah-medium-expert-v0'
        elif 'Hopper' in env_name:
            dataset_name = 'hopper-medium-expert-v0'
        elif 'Walker' in env_name:
            dataset_name = 'walker2d-medium-expert-v0'
        else:
            # Previously this fell through and failed on an unbound variable.
            raise ValueError(
                "No preference data source is defined for environment '{}'".format(env_name))
        max_steps = 1000
        temp_env = gym.make(dataset_name)
        dataset = d4rl.qlearning_dataset(temp_env)

    # Cut the flat transition stream into trajectories of max_steps steps.
    obs_dim = env.observation_space.shape[0]
    act_dim = env.action_space.shape[0]
    n_traj = 10000  # fixed capacity of the trajectory buffers
    offline_dataset = {
        'state': np.zeros((n_traj, max_steps, obs_dim)),
        'action': np.zeros((n_traj, max_steps, act_dim)),
        'next_state': np.zeros((n_traj, max_steps, obs_dim)),
        'reward': np.zeros((n_traj, max_steps, 1)),
        'done': np.zeros((n_traj, max_steps, 1)),
    }
    traj_no = 0
    episode_step_ctr = 0
    for i in range(dataset['observations'].shape[0]):
        offline_dataset['state'][traj_no, episode_step_ctr, :] = dataset['observations'][i]
        offline_dataset['action'][traj_no, episode_step_ctr, :] = dataset['actions'][i]
        offline_dataset['next_state'][traj_no, episode_step_ctr, :] = dataset['next_observations'][i]
        offline_dataset['reward'][traj_no, episode_step_ctr, :] = dataset['rewards'][i]
        offline_dataset['done'][traj_no, episode_step_ctr, :] = dataset['terminals'][i]
        episode_step_ctr += 1
        if dataset['terminals'][i] or episode_step_ctr == max_steps:
            # Episode ended early: repeat its last observation in the
            # remaining state slots (other fields stay zero).
            while episode_step_ctr != max_steps:
                offline_dataset['state'][traj_no, episode_step_ctr, :] = dataset['observations'][i]
                episode_step_ctr += 1
            traj_no += 1
            episode_step_ctr = 0

    total_trajs = traj_no - 1
    # Per-trajectory return; steps flagged done contribute nothing.
    episode_returns = (offline_dataset['reward'] * (1 - offline_dataset['done'])).sum(1)[:total_trajs]
    if episode_returns.shape[0] == 0:
        raise ValueError("No complete trajectories available to build preferences from")

    pref_trajs = preference_levels * episode_per_preference_level
    preference_dataset = {
        'state': np.zeros((pref_trajs, max_steps, obs_dim)),
        'action': np.zeros((pref_trajs, max_steps, act_dim)),
        'next_state': np.zeros((pref_trajs, max_steps, obs_dim)),
        'reward': np.zeros((pref_trajs, max_steps, 1)),
        'done': np.zeros((pref_trajs, max_steps, 1)),
        'levels': preference_levels,
        'episode_per_levels': episode_per_preference_level,
    }

    # Trajectory indices ordered by ascending return (loop-invariant).
    sorted_args = np.argsort(episode_returns.reshape(-1))

    if episode_per_preference_level == 1:
        # Pick trajectories whose returns are spread as far apart as possible.
        maximum_possible_distance = binarySearch(
            episode_returns.shape[0], preference_levels, episode_returns[sorted_args])
        prev_val = -np.inf
        selected_idx = []
        for i in sorted_args:
            if episode_returns[i] > prev_val + maximum_possible_distance:
                selected_idx.append(i)
                prev_val = episode_returns[i]
        if len(selected_idx) < preference_levels:
            # Fail with a clear message instead of an IndexError further down.
            raise ValueError(
                "Only {} sufficiently separated trajectories found; {} preference levels requested".format(
                    len(selected_idx), preference_levels))

    for pref_level in range(preference_levels):
        if episode_per_preference_level == 1:
            idx = np.array([selected_idx[pref_level]])
            # Single candidate, so the draw is fixed; the call is kept so the
            # global NumPy random stream advances exactly as before.
            sample_idx = np.random.choice(idx, size=(episode_per_preference_level))
        else:
            start = int(pref_level * sorted_args.shape[0] / preference_levels)
            idx = sorted_args[start:start + episode_per_preference_level]
            sample_idx = idx.reshape(-1)
            print("Average reward for pref level : {} is {}".format(pref_level,episode_returns[idx].mean() ))

        rows = slice(pref_level * episode_per_preference_level,
                     (pref_level + 1) * episode_per_preference_level)
        for key in ('state', 'action', 'next_state', 'reward', 'done'):
            preference_dataset[key][rows, :, :] = offline_dataset[key][sample_idx, :, :]

    print("Preference episode returns :{}".format(preference_dataset['reward'].sum(1)))
    return preference_dataset


def generate_offline_dataset(env_name, env):
    """Sample 300 state trajectories from a D4RL ``*-medium-v0`` dataset.

    The flat D4RL transition stream is cut into trajectories of 1000 steps
    and 300 of them are drawn with ``np.random.choice`` (default settings,
    so the same trajectory can be drawn more than once).

    Args:
        env_name: Environment id; must contain 'HalfCheetah', 'Hopper' or
            'Walker' to select the dataset.
        env: Environment providing ``observation_space`` and
            ``max_episode_len``.

    Returns:
        Array of shape ``(300, 1000, obs_dim)`` holding observations only.

    Raises:
        ValueError: If ``env_name`` matches no supported environment.
    """
    if 'HalfCheetah' in env_name:
        dataset_name = 'halfcheetah-medium-v0'
    elif 'Hopper' in env_name:
        dataset_name = 'hopper-medium-v0'
    elif 'Walker' in env_name:
        dataset_name = 'walker2d-medium-v0'
    else:
        # Previously this fell through and failed on an unbound variable.
        raise ValueError(
            "No offline dataset is defined for environment '{}'".format(env_name))

    temp_env = gym.make(dataset_name)
    dataset = d4rl.qlearning_dataset(temp_env)
    observations = dataset['observations']
    terminals = dataset['terminals']

    # Only observations are returned, so only the state buffer is allocated
    # (action / next_state / reward / done buffers were never filled or read).
    n_traj = 10000
    states = np.zeros((n_traj, 1000, env.observation_space.shape[0]))

    traj_no = 0
    episode_step_ctr = 0
    for i in range(observations.shape[0]):
        states[traj_no, episode_step_ctr, :] = observations[i]
        episode_step_ctr += 1
        if terminals[i] or episode_step_ctr == 1000:
            # Early termination: repeat the last observation up to the
            # environment's episode length.
            # NOTE(review): this mixes env.max_episode_len with the literal
            # 1000 used for the buffer, and skips padding when the counter
            # equals max_episode_len - 1 — confirm both are intended.
            if terminals[i] and episode_step_ctr != env.max_episode_len - 1:
                while episode_step_ctr != env.max_episode_len:
                    states[traj_no, episode_step_ctr, :] = observations[i]
                    episode_step_ctr += 1
            traj_no += 1
            episode_step_ctr = 0

    total_trajs = traj_no - 1

    offline_trajs = 300
    sample_idx = np.random.choice(total_trajs, size=(offline_trajs))
    return states[sample_idx]

def load_expert_from_offline_dataset(env_name, num_traj=10):
    """Load the first ``num_traj`` expert trajectories from a D4RL dataset.

    The dataset ('door-expert-v0', 'pen-expert-v0' or 'hammer-expert-v0') is
    chosen by substring match on ``env_name``. Trajectories are 200 steps
    long, except for Pen where they are 100 steps. When an episode ends
    early, its last observation is repeated in the remaining state slots;
    the matching action slots stay zero.

    Args:
        env_name: Environment id containing 'Door', 'Pen' or 'Hammer'.
        num_traj: Number of trajectories to extract.

    Returns:
        Tuple ``(states, actions)`` of arrays with shapes
        ``(num_traj, max_steps, obs_dim)`` and ``(num_traj, max_steps, act_dim)``.
    """
    horizon = 200
    for tag, d4rl_id, tag_horizon in (
        ('Door', 'door-expert-v0', 200),
        ('Pen', 'pen-expert-v0', 100),
        ('Hammer', 'hammer-expert-v0', 200),
    ):
        if tag in env_name:
            dataset_name, horizon = d4rl_id, tag_horizon
            break

    expert_env = gym.make(dataset_name)
    transitions = d4rl.qlearning_dataset(expert_env)

    obs_dim = expert_env.observation_space.shape[0]
    act_dim = expert_env.action_space.shape[0]
    states = np.zeros((num_traj, horizon, obs_dim))
    actions = np.zeros((num_traj, horizon, act_dim))
    next_states = np.zeros((num_traj, horizon, obs_dim))
    rewards = np.zeros((num_traj, horizon, 1))
    dones = np.zeros((num_traj, horizon, 1))

    traj, step = 0, 0
    for t in range(transitions['observations'].shape[0]):
        # Stop as soon as the requested number of trajectories is filled.
        if traj == num_traj:
            break

        states[traj, step, :] = transitions['observations'][t]
        actions[traj, step, :] = transitions['actions'][t]
        next_states[traj, step, :] = transitions['next_observations'][t]
        rewards[traj, step, :] = transitions['rewards'][t]
        dones[traj, step, :] = transitions['terminals'][t]
        step += 1

        if transitions['terminals'][t] or step == horizon:
            # Fill whatever is left of this trajectory with the last
            # observation (an empty slice when the horizon was reached).
            states[traj, step:horizon, :] = transitions['observations'][t]
            traj += 1
            step = 0

    # Return per trajectory; steps flagged done contribute nothing.
    episode_returns = (rewards * (1 - dones)).sum(1)
    print("Dataset return average: {}, std: {}".format(episode_returns.mean(),episode_returns.std()))
    return states, actions


def try_evaluate(itr: int, policy_type: str, sac_info):
    """Evaluate the current policy on the real reward and log the result.

    Relies on names bound by the ``__main__`` section of this script
    (``v``, ``samples``, ``sac_agent``, ``env_name``, ``env_fn``, ``logger``,
    ``reward_func``, ``log_folder``, ``gym_env``).

    NOTE(review): ``critic_loss``, ``visual_disc``, ``disc`` and ``disc_loss``
    are not defined in the visible part of this file, so the "emd" and
    "PointMaze" branches would fail with NameError unless they are provided
    elsewhere — confirm.

    Args:
        itr: Index of the current outer iteration.
        policy_type: Label used in log keys; must be "Running".
        sac_info: Info returned by the agent's training call; only forwarded
            to the PointMaze visualisation.

    Returns:
        Tuple ``(deterministic real return, 0)``.
    """
    assert policy_type == "Running"

    total_updates = itr * v['reward']['gradient_step']
    total_env_steps = itr * v['sac']['epochs'] * v['env']['T']
    rollout_states = samples[0].copy()

    # Real-reward evaluation; the trailing True matches the "Det" label used
    # in the log key below.
    det_return = eval.evaluate_real_return(
        sac_agent.get_action,
        env_name,
        env_fn(),
        v['irl']['eval_episodes'],
        v['env']['T'],
        True,
    )
    metrics = {'Real Det Return': det_return}
    print("real det return avg: {:.2f}".format(det_return))
    logger.log_tabular("Real Det Return", round(det_return, 2))

    if v['obj'] == "emd":
        # Average the most recent 10% of the main critic losses.
        main_losses = critic_loss["main"]
        tail = int(0.1 * len(main_losses))
        emd = -np.array(main_losses[-tail:]).mean()
        metrics['emd'] = emd
        logger.log_tabular(policy_type + " EMD", emd)

    if "PointMaze" in env_name:
        visual_disc(
            rollout_states,
            reward_func.get_scalar_reward,
            disc.log_density_ratio,
            v['obj'],
            log_folder,
            total_env_steps,
            gym_env.range_lim,
            sac_info,
            disc_loss,
            metrics,
        )

    logger.log_tabular(policy_type + " Update Time", total_updates)
    logger.log_tabular(policy_type + " Env Steps", total_env_steps)

    return det_return, 0

if __name__ == "__main__":
    # Script entry point: load the config, fit a reward model from expert
    # data plus ranked preference trajectories, then repeatedly train an RL
    # agent on the learned reward and log its real return.

    import argparse
    import shutil

    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', '-s', type=int, default=0)
    parser.add_argument('--sac_epochs', type=int, default=5)
    parser.add_argument('--irl_epochs', type=int, default=1)
    parser.add_argument('--exp_name', type=str, default='dump')
    parser.add_argument('--config', type=str, default='samples/configs/agents/hopper.yml')
    args = parser.parse_args()

    yaml = YAML()
    # Use a context manager so the config file handle is closed promptly.
    with open(args.config) as config_file:
        v = yaml.load(config_file)

    # common parameters (command-line values override the config file)
    env_name = v['env']['env_name']
    state_indices = v['env']['state_indices']
    seed = args.seed
    v['irl']['epochs'] = args.irl_epochs
    v['sac']['epochs'] = args.sac_epochs
    v['exp_name'] = args.exp_name

    num_expert_trajs = v['irl']['expert_episodes']

    # system: device, threads, seed, pid
    device = torch.device("cuda:0" if torch.cuda.is_available() and v['cuda'] >= 0 else "cpu")
    torch.set_num_threads(1)
    np.set_printoptions(precision=3, suppress=True)
    system.reproduce(seed)
    pid = os.getpid()

    # logs
    exp_id = f"samples/results/{env_name}/" + v['exp_name']
    os.makedirs(exp_id, exist_ok=True)

    # NOTE(review): exp_id appears twice in this path, producing a nested
    # "samples/results/.../samples/results/..." folder. Kept unchanged so
    # existing result locations stay valid — confirm whether it is intended.
    log_folder = exp_id + '/' + exp_id + '_s' + str(seed)
    logger_kwargs = {'output_dir': log_folder, 'exp_name': exp_id}
    logger = EpochLogger(**logger_kwargs)
    print(f"Logging to directory: {log_folder}")

    # Snapshot the script and config next to the logs. This stays best-effort
    # (a failed copy only prints a message) but no longer builds shell
    # command strings from user-supplied paths.
    for snapshot_src, snapshot_dst in (
        ('samples/irl.py', log_folder),
        (args.config, f'{log_folder}/variant.yml'),
    ):
        try:
            shutil.copy(snapshot_src, snapshot_dst)
        except OSError as copy_error:
            print(f"Could not copy {snapshot_src} to {snapshot_dst}: {copy_error}")

    with open(os.path.join(logger.output_dir, 'variant.json'), 'w') as f:
        json.dump(v, f, indent=2, sort_keys=True)
    print('pid', pid)
    os.makedirs(os.path.join(log_folder, 'plt'), exist_ok=True)
    os.makedirs(os.path.join(log_folder, 'model'), exist_ok=True)

    # environment
    env_fn = lambda: gym.make(env_name)
    gym_env = env_fn()
    state_size = gym_env.observation_space.shape[0]
    action_size = gym_env.action_space.shape[0]
    if state_indices == 'all':
        state_indices = list(range(state_size))

    # load expert samples
    if 'PenFH-v0' in env_name or 'DoorAdroitFH-v0' in env_name or 'HammerFH-v0' in env_name:
        expert_trajs, expert_actions = load_expert_from_offline_dataset(env_name)
    else:
        # Load each file once; the stored object may be a torch tensor or an
        # already-converted array.
        expert_trajs = torch.load(f'samples/expert_data/states/{env_name}_airl.pt')
        expert_actions = torch.load(f'samples/expert_data/actions/{env_name}_airl.pt')
        if torch.is_tensor(expert_trajs):
            expert_trajs = expert_trajs.numpy()
            expert_actions = expert_actions.numpy()
        expert_trajs = expert_trajs[:, :, state_indices]
        expert_actions = expert_actions[:num_expert_trajs, :, :]

    expert_trajs = expert_trajs[:num_expert_trajs, :, :]  # select first expert_episodes
    expert_samples = expert_trajs.copy().reshape(-1, len(state_indices))

    # Initialize reward as a neural network
    reward_type = v['reward']['type']
    if reward_type == 'vanilla':
        reward_func = MLPReward(len(state_indices), **v['reward'], device=device).to(device)
    elif reward_type == 'sigmoid':
        reward_func = MLPSigmoidReward(len(state_indices), **v['reward'], device=device).to(device)
    elif reward_type == 'ensemble_sigmoid':
        reward_func = EnsembleRewardModel(len(state_indices), 'sigmoid').to(device)
        reward_func.scaler.fit(expert_samples)
    elif reward_type == 'ensemble_clamp':
        reward_func = EnsembleRewardModel(len(state_indices), 'clamp').to(device)
        reward_func.scaler.fit(expert_samples)
    else:
        # Previously an unknown type only failed later with a NameError.
        raise ValueError(f"Unknown reward type in config: {reward_type!r}")

    reward_optimizer = torch.optim.Adam(reward_func.parameters(), lr=v['reward']['lr'],
        weight_decay=v['reward']['weight_decay'], betas=(v['reward']['momentum'], 0.999))

    max_real_return_det, max_real_return_sto = -np.inf, -np.inf

    if 'preferences' in v['obj'] or v['obj'] == 'contrastive-maxentirl-with-preferences':
        preference_dataset = generate_preference_dataset(env_name, env_fn())

    if 'offline' in v['obj']:
        offline_dataset = generate_offline_dataset(env_name, env_fn())

    # Pick the reward objective once instead of re-checking it on each of the
    # 10000 iterations; the two objectives take different expert inputs.
    if v['obj'] == 'maxentirl-sigmoid-validation-preferences-weighted':
        reward_loss_fn = maxentirl_loss_sigmoid_validation_preferences_weighted
        reward_expert_data = expert_samples
    elif v['obj'] == 'snippet-validation-preferences-weighted':
        reward_loss_fn = snippet_loss_validation_preferences_weighted
        reward_expert_data = expert_trajs
    else:
        print("Use snippet-validation-preferences-weighted or maxentirl-sigmoid-validation-preferences-weighted in config")
        sys.exit()

    # Reward training runs for a fixed number of iterations; no agent samples
    # exist yet, so `samples` is None throughout.
    samples = None
    for itr in range(10000):
        print("Reward training iteration :{}".format(itr))
        loss = reward_loss_fn(v['obj'], samples, reward_expert_data, preference_dataset,
                              reward_func, reward_optimizer, device,
                              regularization=v['irl']['regularization'], epochs=v['irl']['epochs'])

    for itr in range(v['irl']['n_itrs']):
        if v['rl'] == 'sac':
            if v['sac']['reinitialize'] or itr == 0:
                # Reset SAC agent with a new environment and a new replay buffer
                print("Reinitializing sac")
                replay_buffer = ReplayBuffer(
                    state_size,
                    action_size,
                    device=device,
                    size=v['sac']['buffer_size'])
                sac_agent = SAC(env_fn, replay_buffer,
                    steps_per_epoch=v['env']['T'],
                    update_after=v['env']['T'] * v['sac']['random_explore_episodes'],
                    max_ep_len=v['env']['T'],
                    seed=seed,
                    start_steps=v['env']['T'] * v['sac']['random_explore_episodes'],
                    reward_state_indices=state_indices,
                    device=device,
                    **v['sac']
                )
        elif v['rl'] == 'ppo':
            if v['sac']['reinitialize'] or itr == 0:
                sac_agent = PPO(env_fn, steps_per_epoch=v['env']['T'], epochs=v['sac']['epochs'],
                                logger_kwargs=None)

        sac_agent.reward_function = reward_func.get_scalar_reward  # only need to change reward in sac

        if v['obj'] != 'offline-il':
            # Train the agent on the learned reward, then evaluate it on the
            # real reward.
            sac_info, samples = sac_agent.learn_mujoco_and_collect(print_out=True, state_indices=state_indices)

            real_return_det, real_return_sto = try_evaluate(itr, "Running", sac_info)
            # try_evaluate always returns 0 for the stochastic return, so
            # requiring it to improve as well froze the running best after
            # the first iteration; track the deterministic return only.
            if real_return_det > max_real_return_det:
                max_real_return_det, max_real_return_sto = real_return_det, real_return_sto
        else:
            torch.save(reward_func.state_dict(), os.path.join(logger.output_dir,
                        "model/reward_model.pkl"))

        logger.log_tabular("Iteration", itr)
        logger.log_tabular("Learned Reward Eval", np.array(sac_info[0]).mean())
        if v['sac']['automatic_alpha_tuning']:
            logger.log_tabular("alpha", sac_agent.alpha.item())

        logger.dump_tabular()