import os
import time
import copy
import pickle
import jax
import random
import string
import d4rl
import d4rl.gym_mujoco
import d4rl.locomotion
import gym
import numpy as np
import tqdm
from absl import app, flags
import cloudpickle as pickle

import gymnasium

from ..utils.env_utils import GymnasiumWrapper, TrajSampler, get_d4rl_dataset, evaluate, wrap_gym

try:
    from flax.training import checkpoints
except:
    print("Not loading checkpointing functionality.")
from ml_collections import config_flags

import wandb
from ..agents.rlpd import RLPDSamplerPolicy, get_rlpd_policy_from_model, SACLearner
from ..agents.rlphi import RLPhiLearner
from ..agents.iql import get_iql_policy_from_model, IQLSamplerPolicy
from ..utils.dataset_utils import ReplayBuffer, AdjusterReplayBuffer
from ..utils.dataset_utils import D4RLDataset

from ..models.model import SamplerPolicy, get_policy_from_model, load_model, evaluate_policy, LSTM_RLPD

from ..utils.utils import define_flags_with_default, set_random_seed
from ..utils.dataset_utils import expert_dict
from functools import partial


@partial(jax.jit, static_argnames="apply_fn")
def _get_adjuster_q_value(apply_fn, params, observations):
    return apply_fn({"params": params}, observations, False).squeeze()

@partial(jax.jit, static_argnames="apply_fn")
def _get_dynamics_output(apply_fn, params, observations, actions):
    return apply_fn({"params": params}, observations, actions, True)

if __name__ == "__main__":

    # Script flags and their defaults; absl parses them when app.run(main)
    # starts at the bottom of this file.
    FLAGS_DEF = define_flags_with_default(
        project_name="rlpd_itv_test",
        env_name="hopper-expert-v2",
        sparse_env='Hopper-v2',
        # main() raises ValueError for any offline_ratio > 0 (the offline
        # dataset is inserted into the replay buffer instead), so a positive
        # default would make every default run fail.
        offline_ratio=0.0,
        seed=43,
        train_sparse=False,
        dataset_dir='',  # optional pickled dataset; '' -> get_d4rl_dataset(env)

        expert_dir='./AGSA/experts/rlpd_experts/s24_hopper-expert-v2env/model.pkl',
        ground_truth_agent_dir='./AGSA/experts/rlpd_experts/s24_hopper-expert-v2env/model.pkl',
        # value trigger: intervene when policy_q < intervene_threshold * reference_q
        intervene_threshold=0.0,
        # uncertainty trigger: mean std (over axis 0) of the dynamics output
        p_intervene_threshold=0.2,
        # weight of the accuracy term in the RL adjuster reward
        adj_acc_rew_ratio=1.0,
        use_p_intervene=False, # deprecated
        intervene_type='value', # value, uncertainty, neural, neural_rl
        # substring switches: 'ref' -> value check compares against the expert
        # action instead of the ground-truth action; 'label' -> the agent keeps
        # acting while an intervention is active
        intervention_strategy='',
        intervene_n_steps=4,  # number of steps each intervention lasts
        threshold_n_steps=2, # number of steps reaching intervening threshold before intervening
        use_pbrl=False,
        use_pbrl_in_learning_agent=False,
        pbrl_normalizing_ratio=1e-4,
        pbrl_rew_ratio=0.01,
        combine_un=False,  # OR the uncertainty trigger into the neural triggers

        eval_episodes=100,
        log_interval=1000,
        eval_interval=10000,
        save_interval=10000,
        max_traj_length=200,
        batch_size=256,
        pbrl_batch_size=64,
        max_steps=int(1e6),
        start_training=0,
        pretrain_steps=0,

        tqdm=True,
        save_video=False,
        save_model=False,
        checkpoint_model=False,
        checkpoint_buffer=False,
        utd_ratio=1,
        binary_include_bc=True,
        )

    # Learner hyperparameters (including `model_cls`) come from a config file.
    config_flags.DEFINE_config_file(
        "config",
        "./AGSA/configs/rlphi_config.py",
        "File path to the training hyperparameter configuration.",
        lock_config=False,
    )


def combine(one_dict, other_dict):
    """Concatenate two batches key by key along the leading axis.

    Entries with more than one dimension are joined with ``np.vstack``;
    1-D (or 0-d) entries are joined with ``np.hstack``. Keys come from
    ``one_dict``; each must also be present in ``other_dict`` (extra keys in
    ``other_dict`` are ignored).
    """
    def _join(first, second):
        joiner = np.vstack if len(first.shape) > 1 else np.hstack
        return joiner((first, second))

    return {key: _join(value, other_dict[key]) for key, value in one_dict.items()}

def sample_from_pbrl_batch(pbrl_batch, batch_size):
    """Sample preference pairs and pack them into stacked arrays.

    Args:
        pbrl_batch: list of dicts with keys ``"better_traj"`` and
            ``"worse_traj"``; each trajectory is a list of per-step dicts
            holding ``"observations"`` and ``"actions"``.
        batch_size: number of pairs drawn without replacement via
            ``random.sample``.

    Returns:
        ``{}`` when fewer than ``batch_size`` pairs are available. Otherwise a
        dict with ``"obs"`` of shape (batch_size, traj_len, 2, obs_dim) and
        ``"action"`` of shape (batch_size, traj_len, 2, action_dim); index 0
        of the third axis is the better trajectory, index 1 the worse one.
    """
    if batch_size > len(pbrl_batch):
        return {}

    chosen_pairs = random.sample(pbrl_batch, batch_size)

    def _gather(role, field):
        # One (traj_len, dim) array per pair, stacked to (batch, traj_len, dim).
        per_pair = [
            np.array([step[field] for step in pair[role]])
            for pair in chosen_pairs
        ]
        return np.stack(per_pair)

    packed = {}
    for out_key, field in (("obs", "observations"), ("action", "actions")):
        packed[out_key] = np.stack(
            [_gather("better_traj", field), _gather("worse_traj", field)],
            axis=-2,
        )
    return packed


def main(_):
    """Train an RLPhi agent with simulated expert interventions.

    Pipeline (driven by the absl ``FLAGS`` and the ``config`` config file):

    1. Build the training/eval environments, load the intervening expert
       policy (``expert_dir``) and the ground-truth agent
       (``ground_truth_agent_dir``), whose critic drives value-based
       intervention decisions.
    2. Create the learner named by ``config.model_cls`` and fill an
       ``AdjusterReplayBuffer`` with the offline dataset (rewards zeroed,
       random adjuster labels).
    3. Optionally pretrain for ``pretrain_steps`` on ``D4RLDataset`` batches.
    4. Online loop: each step decide (value / uncertainty / neural /
       neural_rl trigger) whether the expert takes over for
       ``intervene_n_steps`` steps, write adjuster-labelled transitions to the
       buffer, optionally build preference pairs from consecutive intervention
       segments, update the agent, and log/evaluate/checkpoint periodically.

    Args:
        _: argv forwarded by ``absl.app.run`` (unused).

    Raises:
        ValueError: if ``offline_ratio`` > 0, ``ground_truth_agent_dir`` is
            empty, the expert or ground-truth agent type cannot be inferred
            from its path, or ``intervene_type`` is unknown.
    """
    FLAGS = flags.FLAGS
    assert FLAGS.offline_ratio >= 0.0 and FLAGS.offline_ratio <= 1.0
    # Offline transitions are inserted directly into the replay buffer below
    # and updates never mix in separately sampled offline batches. Reject a
    # positive ratio before wandb init, data loading and pretraining.
    if FLAGS.offline_ratio > 0:
        raise ValueError("Offline ratio not supported")
    # Every online step queries the ground-truth policy (and, when not
    # intervening, its critic), so a ground-truth agent is mandatory.
    if FLAGS.ground_truth_agent_dir == '':
        raise ValueError(
            "ground_truth_agent_dir must be set: the online loop queries the "
            "ground-truth policy and critic"
        )

    # NOTE(review): set after `import jax`; these only take effect if no JAX
    # backend has been initialised yet — confirm.
    os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
    os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

    # Drawn before set_random_seed, presumably so runs sharing a seed still
    # get distinct names.
    rand_str = ''.join(random.choices(string.ascii_uppercase + string.digits, k=6))
    set_random_seed(FLAGS.seed)

    # The expert type is looked up by the directory containing the model file.
    expert_key = FLAGS.expert_dir.split("/")[-2]
    if expert_key not in expert_dict:
        raise ValueError("Expert type not supported")
    expert_type = expert_dict[expert_key]
    wandb_name = f"s{FLAGS.seed}-{FLAGS.pretrain_steps}pt-{FLAGS.threshold_n_steps}th-{expert_type}_exp-pinterv_{FLAGS.use_p_intervene}-pthresh{FLAGS.p_intervene_threshold}-{rand_str}"
    wandb.init(project=FLAGS.project_name,
               mode='online',
               name=wandb_name)
    wandb.config.update(FLAGS)

    exp_prefix = f"{FLAGS.sparse_env}_{expert_type}exp_s{FLAGS.seed}_{FLAGS.pretrain_steps}pretrain_{FLAGS.utd_ratio}utd_{FLAGS.offline_ratio}offline"
    if hasattr(FLAGS.config, "critic_layer_norm") and FLAGS.config.critic_layer_norm:
        exp_prefix += "_LN"

    # `log_dir` is not among the flags defined in this file, so plain
    # `FLAGS.log_dir` would raise AttributeError unless another module defines
    # it. Fall back to the current working directory in that case.
    log_dir = os.path.join(getattr(FLAGS, "log_dir", ""), exp_prefix)

    if FLAGS.checkpoint_model:
        chkpt_dir = os.path.join(log_dir, "checkpoints")
        os.makedirs(chkpt_dir, exist_ok=True)

    if FLAGS.checkpoint_buffer:
        buffer_dir = os.path.join(log_dir, "buffers")
        os.makedirs(buffer_dir, exist_ok=True)

    if FLAGS.save_model:
        model_dir = os.path.join(log_dir, "model")
        os.makedirs(model_dir, exist_ok=True)

    # Training env (with episode statistics for logging) and a separately
    # seeded evaluation env.
    env = gym.make(FLAGS.env_name)
    env = wrap_gym(env, rescale_actions=True)
    env = gym.wrappers.RecordEpisodeStatistics(env, deque_size=1)
    env.seed(FLAGS.seed)
    ds = D4RLDataset(env)

    eval_env = gym.make(FLAGS.env_name)
    eval_env = wrap_gym(eval_env, rescale_actions=True)
    eval_env.seed(FLAGS.seed + 42)

    # Success-rate evaluation on `sparse_env`, created through gymnasium.
    sparse_eval_sampler = TrajSampler(GymnasiumWrapper(gymnasium.make(FLAGS.sparse_env).unwrapped), FLAGS.max_traj_length)

    # Intervening expert; the policy wrapper is chosen from a path substring.
    expert_model_pkl_dir = FLAGS.expert_dir
    saved_ckpt_expert = load_model(expert_model_pkl_dir)
    if 'iql' in expert_model_pkl_dir:
        intervene_policy = get_iql_policy_from_model(eval_env, saved_ckpt_expert)
    elif 'rlpd' in expert_model_pkl_dir:
        intervene_policy = get_rlpd_policy_from_model(eval_env, saved_ckpt_expert)
    else:
        intervene_policy = get_policy_from_model(eval_env, saved_ckpt_expert)

    # Ground-truth agent: its policy provides the default reference action and
    # its critic scores policy vs. reference actions. Type from path substring.
    if 'iql' in FLAGS.ground_truth_agent_dir:
        ground_truth_agent = load_model(FLAGS.ground_truth_agent_dir)['iql']
        ground_truth_policy = IQLSamplerPolicy(ground_truth_agent.actor)
        ground_truth_agent_type = 'iql'
    elif 'sac' in FLAGS.ground_truth_agent_dir or 'bc' in FLAGS.ground_truth_agent_dir:
        ground_truth_agent = load_model(FLAGS.ground_truth_agent_dir)['sac']
        ground_truth_policy = SamplerPolicy(ground_truth_agent.policy, ground_truth_agent.train_params['policy'])
        ground_truth_agent_type = 'sac'
    elif 'rlpd' in FLAGS.ground_truth_agent_dir:
        ground_truth_agent = load_model(FLAGS.ground_truth_agent_dir)['rlpd']
        ground_truth_policy = RLPDSamplerPolicy(ground_truth_agent.actor)
        ground_truth_agent_type = 'rlpd'
    else:
        raise ValueError("agent type not supported")

    # Instantiate the learner class named by config.model_cls (looked up in
    # this module's globals) with the remaining config entries as kwargs.
    kwargs = dict(FLAGS.config)
    model_cls = kwargs.pop("model_cls")
    agent: RLPhiLearner = globals()[model_cls].create(
        FLAGS.seed, env.observation_space, env.action_space, **kwargs
    )

    # Offline data: a pickled dict if dataset_dir is set, else get_d4rl_dataset(env).
    if FLAGS.dataset_dir != '':
        # NOTE: unpickling executes arbitrary code; only load trusted files.
        with open(FLAGS.dataset_dir, 'rb') as handle:
            dataset = pickle.load(handle)
    else:
        dataset = get_d4rl_dataset(env)

    # Offline rewards are discarded; actions are clipped just inside [-1, 1].
    dataset['actions'] = np.clip(dataset['actions'], -0.999, 0.999)
    dataset['rewards'] = np.zeros_like(dataset['rewards'])
    dataset['masks'] = 1 - dataset['dones']

    replay_buffer = AdjusterReplayBuffer(
        env.observation_space, env.action_space,
        FLAGS.max_steps, None, FLAGS.intervene_n_steps,
    )
    replay_buffer.seed(FLAGS.seed)

    # Seed the buffer with every offline transition: no human segment, zero
    # rewards and a coin-flip adjuster label.
    for i in range(len(dataset['rewards'])):
        replay_buffer.insert(
            dict(
                observations=dataset['observations'][i],
                actions=dataset['actions'][i],
                human_observation_list=[],
                human_action_list=[],
                has_human=False,
                adjuster_actions=np.random.randint(2),
                rewards=0,
                adjuster_rewards=0,
                masks=dataset['masks'][i],
                dones=dataset['dones'][i],
                next_observations=dataset['next_observations'][i],
                gt_intervenes=0,
            )
        )

    # Optional offline pretraining on D4RLDataset batches.
    # NOTE(review): this calls agent.update(batch, utd_ratio) while the online
    # loop passes nine arguments — confirm the learner accepts both forms.
    for i in tqdm.tqdm(
        range(0, FLAGS.pretrain_steps), smoothing=0.1, disable=not FLAGS.tqdm
    ):
        offline_batch = ds.sample(FLAGS.batch_size * FLAGS.utd_ratio)
        batch = {}
        for k, v in offline_batch.items():
            batch[k] = v
            if "antmaze" in FLAGS.env_name and k == "rewards":
                batch[k] -= 1

        agent, update_info = agent.update(batch, FLAGS.utd_ratio)

        if i % FLAGS.log_interval == 0:
            for k, v in update_info.items():
                wandb.log({f"offline-training/{k}": v}, step=i)

        if i % FLAGS.eval_interval == 0:
            eval_info = evaluate(agent, eval_env, num_episodes=FLAGS.eval_episodes)

            for k, v in eval_info.items():
                wandb.log({f"offline-evaluation/{k}": v}, step=i)

            sampler_policy = RLPDSamplerPolicy(agent.actor)
            sparse_trajs = sparse_eval_sampler.sample(
                    sampler_policy,
                    FLAGS.eval_episodes, deterministic=False
                )
            avg_success = evaluate_policy(sparse_trajs,
                                            success_rate=True,
                                            success_function=lambda t: np.all(t['rewards'][-1:]>=10),
                                            )
            wandb.log({f"offline-evaluation/avg_success": avg_success}, step=i)

    # Per-step histories: the source of the transitions written to the buffer
    # one step late, and of the rate/consistency metrics logged at eval time.
    all_observations = []
    all_actions = []
    all_rewards = []
    all_adjuster_rewards = []
    all_masks = []
    all_dones = []
    all_next_observations = []
    all_intervene = []
    all_p_intervene = []
    all_n_intervene = []
    all_final_intervene = []
    all_adjuster_acc = []
    all_adjuster_preds = []

    # for preference-based rl
    current_human_traj = []
    prev_human_traj = []
    pbrl_dataset = []

    observation, done = env.reset(), False
    t = 0  # step index within the current episode
    intervene = False  # value-based trigger
    p_intervene = False  # uncertainty-based trigger
    n_intervene = False  # neural / neural_rl trigger
    final_intervene_flag = False  # expert currently in control
    prev_final_intervene = False
    stop_final_intervene_time = -1
    threshold_n_steps = 0  # consecutive steps the active trigger has fired
    adjuster_reward = 0
    first_intervene_action_mask = []
    first_p_intervene_action_mask = []
    # LSTM carries: `hstate` goes through agent.update; the batch-size-1
    # `inference_hstate` is advanced step by step in 'neural' mode.
    hstate = LSTM_RLPD.initialize_carry(FLAGS.batch_size, FLAGS.config.adjuster_hidden_dim)
    inference_hstate = LSTM_RLPD.initialize_carry(1, FLAGS.config.adjuster_hidden_dim)

    def _recent_mean(values):
        """Mean of the last ``eval_interval`` entries of a per-step history."""
        return np.mean(values[-FLAGS.eval_interval:])

    for i in tqdm.tqdm(
        range(1, FLAGS.max_steps + 1), smoothing=0.1, disable=not FLAGS.tqdm
    ):
        # wandb step for the online phase, continuing after pretraining.
        global_step = i + FLAGS.pretrain_steps

        # An active intervention ends once its scheduled duration has elapsed.
        if t == stop_final_intervene_time:
            final_intervene_flag = False

        policy_action, agent = agent.sample_actions(observation)

        expert_action = intervene_policy(observation.reshape(1, -1), deterministic=False).reshape(-1)
        ground_truth_action = ground_truth_policy(observation.reshape(1, -1), deterministic=False).reshape(-1)

        # Action the policy action is compared against in the value check.
        if 'ref' in FLAGS.intervention_strategy:
            reference_action = expert_action
        else:
            reference_action = ground_truth_action

        if not final_intervene_flag:
            # Value-based check: score the policy action and the reference
            # action with the ground-truth critic.
            if ground_truth_agent_type == 'iql':
                gt_q1, gt_q2 = ground_truth_agent.critic(observation, reference_action)
                gt_q = np.min([gt_q1, gt_q2])
                policy_q1, policy_q2 = ground_truth_agent.critic(observation, policy_action)

                policy_q = np.min([policy_q1, policy_q2])
            elif ground_truth_agent_type == 'sac':
                gt_q1 = ground_truth_agent.qf.apply(ground_truth_agent.train_params['qf1'], observation, reference_action)
                gt_q2 = ground_truth_agent.qf.apply(ground_truth_agent.train_params['qf2'], observation, reference_action)
                gt_q = np.min([gt_q1, gt_q2])

                policy_q1 = ground_truth_agent.qf.apply(ground_truth_agent.train_params['qf1'], observation, policy_action)
                policy_q2 = ground_truth_agent.qf.apply(ground_truth_agent.train_params['qf2'], observation, policy_action)
                policy_q = np.min([policy_q1, policy_q2])
            else:
                # rlpd critic: mean over axis 0 (presumably the ensemble).
                gt_qs = ground_truth_agent.critic.apply_fn(
                    {"params": ground_truth_agent.critic.params},
                    observation,
                    reference_action,
                    True,
                )
                gt_q = gt_qs.mean(axis=0)

                policy_qs = ground_truth_agent.critic.apply_fn(
                    {"params": ground_truth_agent.critic.params},
                    observation,
                    policy_action,
                    True,
                )
                policy_q = policy_qs.mean(axis=0)

            # Fire with probability 0.95 when the policy action is worse than
            # intervene_threshold * reference value, otherwise with 0.05.
            if policy_q < gt_q * FLAGS.intervene_threshold:
                intervene = np.random.choice([0, 1], p=[0.05, 1-0.05])
            else:
                intervene = np.random.choice([0, 1], p=[1-0.05, 0.05])

            intervene = bool(intervene)

        if (FLAGS.intervene_type == 'uncertainty' or FLAGS.combine_un) and not final_intervene_flag:
            # Uncertainty-based check: mean std over axis 0 of the dynamics
            # model's next-state prediction.
            next_state_prediction = _get_dynamics_output(
                agent.dynamics.apply_fn,
                agent.dynamics.params,
                observation,
                policy_action,
            )
            state_uncertainty = next_state_prediction.std(axis=0).mean()
            if state_uncertainty > FLAGS.p_intervene_threshold:
                p_intervene = np.random.choice([0, 1], p=[0.05, 1-0.05])
            else:
                p_intervene = np.random.choice([0, 1], p=[1-0.05, 0.05])

            p_intervene = bool(p_intervene)

        if FLAGS.intervene_type == 'neural':
            # Run the adjuster every step, mainly to keep its hidden state current.
            inference_hstate, pred = agent.adjuster.apply_fn(
                {"params": agent.adjuster.params},
                inference_hstate,
                (observation, policy_action, [done]),
                False,
            )

        if not final_intervene_flag:

            if FLAGS.intervene_type == 'neural_rl':
                # Epsilon-greedy (epsilon = 0.1) over the adjuster Q-values,
                # averaged over axis 0 (presumably an ensemble).
                if np.random.rand() < 0.1:
                    n_intervene = np.random.randint(2)
                else:
                    adjuster_q_value = _get_adjuster_q_value(
                        agent.adjuster_Q.apply_fn,
                        agent.adjuster_Q.params,
                        observation,
                    )
                    n_intervene = np.mean(adjuster_q_value, axis=0).argmax()

            elif FLAGS.intervene_type == 'neural':

                if jax.nn.sigmoid(pred) > 0.5:
                    n_intervene = np.random.choice([0, 1], p=[0.05, 1-0.05])
                else:
                    n_intervene = np.random.choice([0, 1], p=[1-0.05, 0.05])

            else:
                n_intervene = 0

            n_intervene = bool(n_intervene)

        # Select the trigger that drives interventions for this run.
        if FLAGS.intervene_type in ['neural', 'neural_rl']:
            if FLAGS.combine_un:
                one_step_intervene_flag = p_intervene or n_intervene
            else:
                one_step_intervene_flag = n_intervene
        elif FLAGS.intervene_type == 'uncertainty':
            one_step_intervene_flag = p_intervene
        elif FLAGS.intervene_type == 'value':
            one_step_intervene_flag = intervene
        else:
            raise ValueError("intervene type not supported")

        if not final_intervene_flag:
            # Start an intervention only after the trigger has fired on
            # threshold_n_steps consecutive steps.
            if one_step_intervene_flag:
                threshold_n_steps += 1
                if threshold_n_steps == FLAGS.threshold_n_steps:
                    final_intervene_flag = True
                    stop_final_intervene_time = t + FLAGS.intervene_n_steps
                    threshold_n_steps = 0
            else:
                threshold_n_steps = 0

        if final_intervene_flag:

            if t != 0 and not prev_final_intervene:
                # First step of a new intervention: mark the previous step's
                # action as the one that led to it.
                first_intervene_action_mask[-1] = 1

                # RL adjuster reward: +2 when the trigger agrees with the
                # value-based label, -1 when it disagrees, plus a fixed -1
                # cost per intervention (-int(final_intervene_flag) was noted
                # as a possible alternative cost).
                adjuster_accuracy_reward = 2 * (final_intervene_flag == intervene) - (final_intervene_flag != intervene)
                adjuster_intervene_reward = -1

                adjuster_reward = FLAGS.adj_acc_rew_ratio * adjuster_accuracy_reward + adjuster_intervene_reward

                replay_buffer.insert(
                    dict(
                        observations=all_observations[-1],
                        actions=all_actions[-1],
                        human_observation_list=[],
                        human_action_list=[],
                        has_human=False,
                        adjuster_actions=1,
                        rewards=-1,
                        adjuster_rewards=adjuster_reward,
                        masks=all_masks[-1],
                        dones=all_dones[-1],
                        next_observations=all_next_observations[-1],
                        gt_intervenes=intervene,
                    )
                )

            # 'label' strategy: the agent keeps acting; only the label changes.
            if 'label' in FLAGS.intervention_strategy:
                action = policy_action
            else:
                action = expert_action
        else:

            if prev_final_intervene and FLAGS.use_pbrl:
                # An intervention just ended: store the agent step just before
                # the last intervene_n_steps steps together with those steps.
                last_agent_action_step = -1-FLAGS.intervene_n_steps
                # if intervened at the first step, skip adding obs since no last_agent_action
                if len(all_observations) >= abs(last_agent_action_step):
                    replay_buffer.insert(
                        dict(
                            observations=all_observations[last_agent_action_step],
                            actions=all_actions[last_agent_action_step],
                            human_observation_list=np.array([all_observations[last_agent_action_step+1:]]),
                            human_action_list=np.array([all_actions[last_agent_action_step+1:]]),
                            has_human=True,
                            adjuster_actions=1,
                            rewards=-1,
                            adjuster_rewards=adjuster_reward,
                            masks=all_masks[last_agent_action_step],
                            dones=all_dones[last_agent_action_step],
                            next_observations=all_next_observations[last_agent_action_step],
                            gt_intervenes=intervene,
                        )
                    )

            action = policy_action
            adjuster_reward = 0

            # Store the previous step's transition with adjuster label 0.
            # NOTE(review): at t == 0 the previous entry is the last step of
            # the previous episode, so terminal transitions are never
            # inserted here — confirm this is intended.
            if t != 0:
                replay_buffer.insert(
                    dict(
                        observations=all_observations[-1],
                        actions=all_actions[-1],
                        human_observation_list=[],
                        human_action_list=[],
                        has_human=False,
                        adjuster_actions=0,
                        rewards=0,
                        adjuster_rewards=0,
                        masks=all_masks[-1],
                        dones=all_dones[-1],
                        next_observations=all_next_observations[-1],
                        gt_intervenes=intervene,
                    )
                )

        next_observation, reward, done, info = env.step(action)

        if final_intervene_flag and FLAGS.use_pbrl:
            # add human trajectory to the preference dataset
            current_human_traj.append(
                dict(
                    observations=observation,
                    actions=action,
                    rewards=reward,
                )
            )

        # Time-limit truncations keep bootstrapping (mask 1).
        if not done or "TimeLimit.truncated" in info:
            mask = 1.0
        else:
            mask = 0.0

        prev_final_intervene = copy.deepcopy(final_intervene_flag)
        all_observations += [observation]
        all_actions += [action]
        all_rewards += [0]
        if FLAGS.intervene_type == 'neural_rl':
            all_adjuster_rewards += [adjuster_reward]
        all_masks += [mask]
        all_dones += [done]
        all_next_observations += [next_observation]
        first_intervene_action_mask.append(0)
        all_intervene += [intervene]
        all_p_intervene += [p_intervene]
        all_n_intervene += [n_intervene]
        all_final_intervene += [final_intervene_flag]
        all_adjuster_acc += [final_intervene_flag==intervene]
        if FLAGS.intervene_type == 'neural':
            all_adjuster_preds += [jax.nn.sigmoid(pred)]
        else:
            all_adjuster_preds += [0]
        t += 1

        observation = next_observation

        # NOTE(review): `t > max_traj_length` lets an episode run
        # max_traj_length + 1 steps — confirm whether `>=` was intended.
        if done or t > FLAGS.max_traj_length:
            observation, done = env.reset(), False
            intervene = False
            p_intervene = False
            n_intervene = False
            final_intervene_flag = False
            prev_final_intervene = False
            stop_final_intervene_time = -1
            threshold_n_steps = 0
            t = 0

            # reset human trajectory for preference-based rl
            current_human_traj = []
            prev_human_traj = []

            try:
                for k, v in info["episode"].items():
                    decode = {"r": "return", "l": "length", "t": "time"}
                    wandb.log({f"training/{decode[k]}": v}, step=global_step)
            except KeyError:
                # No episode statistics in `info` (presumably when the reset
                # came from max_traj_length rather than `done`), or an
                # unexpected statistic key: skip logging.
                pass

        # offline_ratio == 0 is enforced at the top of main, so every update
        # batch is drawn from the adjuster replay buffer.
        batch = replay_buffer.sample(
            int(FLAGS.batch_size * FLAGS.utd_ratio * (1 - FLAGS.offline_ratio))
        )

        if "antmaze" in FLAGS.env_name:
            batch["rewards"] -= 1

        # prepare batch for preference-based training
        if FLAGS.use_pbrl:
            # Once a full intervention segment is collected, pair it with the
            # previous one; the segment with the higher summed reward is
            # "better" (ties go to the previous segment).
            if len(current_human_traj) == FLAGS.intervene_n_steps:
                if len(prev_human_traj) == FLAGS.intervene_n_steps:
                    current_traj_rew = sum([step['rewards'] for step in current_human_traj])
                    prev_traj_rew = sum([step['rewards'] for step in prev_human_traj])
                    better_traj = current_human_traj if current_traj_rew > prev_traj_rew else prev_human_traj
                    worse_traj = current_human_traj if current_traj_rew <= prev_traj_rew else prev_human_traj
                    one_traj = {
                        "better_traj": better_traj,
                        "worse_traj": worse_traj,
                    }
                    pbrl_dataset.append(one_traj)

                prev_human_traj = copy.deepcopy(current_human_traj)
                current_human_traj = []

            pbrl_batch = sample_from_pbrl_batch(pbrl_dataset, FLAGS.pbrl_batch_size)
        else:
            pbrl_batch = {}

        agent, update_info, hstate = agent.update(
            batch,
            pbrl_batch,
            hstate,
            FLAGS.intervene_type,
            FLAGS.utd_ratio,
            FLAGS.use_pbrl,
            FLAGS.pbrl_normalizing_ratio,
            FLAGS.pbrl_rew_ratio,
            FLAGS.use_pbrl_in_learning_agent,
        )

        if i % FLAGS.log_interval == 0:
            for k, v in update_info.items():
                if "/" in k:
                    wandb.log({f"{k}": v}, step=global_step)
                else:
                    wandb.log({f"training/{k}": v}, step=global_step)

        if i % FLAGS.eval_interval == 0:
            eval_info = evaluate(
                agent,
                eval_env,
                num_episodes=FLAGS.eval_episodes,
                save_video=FLAGS.save_video,
            )

            for k, v in eval_info.items():
                wandb.log({f"evaluation/{k}": v}, step=global_step)

            # value-based intervene rate
            wandb.log({f"evaluation/intervene_rate": _recent_mean(all_intervene)}, step=global_step)
            # uncertainty-based intervene rate
            wandb.log({f"evaluation/p_intervene_rate": _recent_mean(all_p_intervene)}, step=global_step)
            # neural-based intervene rate
            wandb.log({f"evaluation/n_intervene_rate": _recent_mean(all_n_intervene)}, step=global_step)
            wandb.log({f"evaluation/final_intervene_rate": _recent_mean(all_final_intervene)}, step=global_step)
            wandb.log({f"evaluation/test_adjuster_consistency": _recent_mean(all_adjuster_acc)}, step=global_step)
            if FLAGS.intervene_type == 'neural':
                wandb.log({f"evaluation/avg_adjuster_preds": _recent_mean(all_adjuster_preds)}, step=global_step)
            if FLAGS.intervene_type == 'neural_rl':
                wandb.log({f"evaluation/avg_adjuster_rewards": _recent_mean(all_adjuster_rewards)}, step=global_step)
                if FLAGS.use_pbrl:
                    wandb.log({f"training/pbrl_dataset_size": len(pbrl_dataset)}, step=global_step)

            sampler_policy = RLPDSamplerPolicy(agent.actor)
            sparse_trajs = sparse_eval_sampler.sample(
                    sampler_policy,
                    FLAGS.eval_episodes, deterministic=False
                )
            avg_success = evaluate_policy(sparse_trajs,
                                            success_rate=True,
                                            success_function=lambda t: np.all(t['rewards'][-1:]>=10),
                                            )
            wandb.log({f"evaluation/avg_success": avg_success}, step=global_step)

            # Checkpointing is best effort: report and keep training. (If the
            # optional flax import failed, `checkpoints` is undefined and the
            # resulting NameError is reported here.)
            if FLAGS.checkpoint_model:
                try:
                    checkpoints.save_checkpoint(
                        chkpt_dir, agent, step=i, keep=20, overwrite=True
                    )
                except Exception as e:
                    print("Could not save model checkpoint.", e)

            if FLAGS.checkpoint_buffer:
                try:
                    with open(os.path.join(buffer_dir, f"buffer"), "wb") as f:
                        pickle.dump(replay_buffer, f, pickle.HIGHEST_PROTOCOL)
                except Exception as e:
                    print("Could not save agent buffer.", e)

        if i % FLAGS.save_interval == 0:

            if FLAGS.save_model:
                save_data = {'rlpd': agent}
                with open(os.path.join(model_dir, f"model_{i//1000}k.pkl"), 'wb') as fout:
                    pickle.dump(save_data, fout)




if __name__ == "__main__":
    # absl parses the flags defined above, then calls main(argv).
    app.run(main)