import os
import sys
import getpass

# Put the MuJoCo 2.1 binaries on the search paths before ``d4rl`` is imported below
# (it presumably loads MuJoCo at import time).
# Fall back to getpass when $USER is not exported (e.g. in some containers).
username = os.environ.get('USER') or getpass.getuser()
_mujoco_bin_dir = f'/home/{username}/.mujoco/mujoco210/bin'
sys.path.append(_mujoco_bin_dir)
# Extend LD_LIBRARY_PATH without requiring it to exist and without creating an
# empty path entry (which the loader would interpret as the current directory).
_ld_library_path = os.environ.get('LD_LIBRARY_PATH')
os.environ['LD_LIBRARY_PATH'] = f'{_ld_library_path}:{_mujoco_bin_dir}' if _ld_library_path else _mujoco_bin_dir
import d4rl
import collections
import h5py
import os
import yaml
import wandb
import torch
import time
import smart_settings
import numpy as np
from matplotlib import pyplot as plt
from tqdm import tqdm
from smart_settings.param_classes import AttributeDict, ImmutableAttributeDict



class Ticker:
    """Class-level holder for the active run configuration and its start time.

    All state lives on the class itself (not on instances), so after a ``Ticker``
    has been constructed, ``Ticker.config`` / ``Ticker.working_dir`` etc. are
    visible to every module that imports ``Ticker``.
    """

    # for testing
    config = None
    start_time = None
    restart_time = None
    working_dir = None

    def __init__(self, config=None, restart_time=60 * 4) -> None:
        """Register ``config`` as the active configuration and record the start time.

        Args:
            config: run configuration; must expose a ``working_dir`` attribute.
            restart_time: presumably a duration in minutes; it is stored multiplied
                by 60, i.e. in the seconds unit of ``time.time()``.

        Raises:
            ValueError: if ``config`` is None. If a config is already registered, the
                message is 'Ticker.config is already set'.
        """
        if config is None:
            if Ticker.config is not None:
                raise ValueError('Ticker.config is already set')
            # Fail before touching any class state instead of half-initialising it.
            raise ValueError('Ticker requires a config with a working_dir')

        Ticker.config = config
        Ticker.start_time = time.time()
        Ticker.restart_time = restart_time * 60
        Ticker.working_dir = config.working_dir

def recursive_dictify(obj):
    """Recursively convert ``smart_settings`` attribute dicts into plain containers.

    Lists, tuples and sets are rebuilt with converted elements. ``AttributeDict``,
    ``ImmutableAttributeDict`` and plain ``dict`` instances become plain ``dict``s
    with converted values. Any other object is returned unchanged. Used to turn a
    config into plain data before dumping it to YAML.
    """
    if isinstance(obj, list):
        return [recursive_dictify(v) for v in obj]
    if isinstance(obj, tuple):
        return tuple(recursive_dictify(v) for v in obj)
    if isinstance(obj, set):
        return {recursive_dictify(v) for v in obj}
    if isinstance(obj, (AttributeDict, ImmutableAttributeDict, dict)):
        # Plain dicts are recursed into too, so attribute dicts nested inside them
        # are converted as well.
        return {k: recursive_dictify(v) for k, v in obj.items()}
    return obj


def get_freer_gpu():
    """Return the index of the GPU with the most free memory according to ``nvidia-smi``.

    Returns:
        The (numpy integer) index of the GPU reporting the largest free memory.

    Raises:
        FileNotFoundError: if ``nvidia-smi`` is not installed.
        subprocess.CalledProcessError: if ``nvidia-smi`` exits with an error.
        ValueError: if no GPU is reported.
    """
    import subprocess

    # Ask for free memory directly (one integer in MiB per GPU) instead of grepping
    # the human-readable report into a scratch file in the working directory.
    result = subprocess.run(
        ['nvidia-smi', '--query-gpu=memory.free', '--format=csv,noheader,nounits'],
        capture_output=True, text=True, check=True,
    )
    memory_available = [int(line) for line in result.stdout.splitlines() if line.strip()]
    return np.argmax(memory_available)


def dice_dataset(env, standardize_observation=True, absorbing_state=True, standardize_reward=True, dataset=None):
    """Flatten a D4RL-style dataset into transition arrays for DICE-style training.

    Transitions flagged by ``timeouts`` (or, without that key, the step where
    ``episode_step == env._max_episode_steps - 1``) are skipped, so no terminal is
    applied on the last step of a truncated episode.

    Args:
        env: used for ``env.get_dataset()`` when ``dataset`` is None, and for
            ``env._max_episode_steps`` when the dataset has no ``'timeouts'`` key.
        standardize_observation: z-score (initial/next) observations with the dataset
            mean and std.
        absorbing_state: append a 0 indicator coordinate to every observation. Every
            terminal transition is redirected into the absorbing state (indicator 1,
            zeros elsewhere), and an absorbing -> absorbing transition is appended
            (same action, reward 0, terminal 1, expert flag copied).
        standardize_reward: z-score rewards with the dataset mean and std.
        dataset: optional dict with ``observations``, ``actions``, ``rewards``,
            ``terminals`` and optionally ``timeouts``.

    Returns:
        ``(initial_obs_dataset, dataset, dataset_statistics)``. The statistics are
        computed before standardization / augmentation, so ``N`` and
        ``observation_dim`` describe the un-augmented transitions.
    """
    if dataset is None:
        dataset = env.get_dataset()
    num_samples = dataset['rewards'].shape[0]
    use_timeouts = 'timeouts' in dataset

    initial_obs_, obs_, next_obs_, action_, reward_, done_, expert_ = [], [], [], [], [], [], []
    episode_step = 0
    for i in range(num_samples - 1):
        obs = dataset['observations'][i].astype(np.float32)
        new_obs = dataset['observations'][i + 1].astype(np.float32)
        action = dataset['actions'][i].astype(np.float32)
        reward = dataset['rewards'][i].astype(np.float32)
        done_bool = bool(dataset['terminals'][i])
        is_final_timestep = dataset['timeouts'][i] if use_timeouts else (episode_step == env._max_episode_steps - 1)
        if is_final_timestep:
            # Skip this transition and don't apply terminals on the last step of an episode
            episode_step = 0
            continue

        if episode_step == 0:
            initial_obs_.append(obs)

        obs_.append(obs)
        next_obs_.append(new_obs)
        action_.append(action)
        reward_.append(reward)
        done_.append(done_bool)
        expert_.append(False)  # assume not expert
        episode_step += 1

        if done_bool:
            episode_step = 0

    initial_obs_dataset = {
        'initial_observations': np.array(initial_obs_, dtype=np.float32)
    }
    dataset = {
        'observations': np.array(obs_, dtype=np.float32),
        'actions': np.array(action_, dtype=np.float32),
        'next_observations': np.array(next_obs_, dtype=np.float32),
        'rewards': np.array(reward_, dtype=np.float32),
        'terminals': np.array(done_, dtype=np.float32),
        'experts': np.array(expert_, dtype=np.float32)
    }
    dataset_statistics = {
        'observation_mean': np.mean(dataset['observations'], axis=0),
        'observation_std': np.std(dataset['observations'], axis=0),
        'reward_mean': np.mean(dataset['rewards']),
        'reward_std': np.std(dataset['rewards']),
        'N_initial_observations': len(initial_obs_),
        'N': len(obs_),
        'observation_dim': dataset['observations'].shape[-1],
        'action_dim': dataset['actions'].shape[-1]
    }

    if standardize_observation:
        obs_mean = dataset_statistics['observation_mean']
        obs_scale = dataset_statistics['observation_std'] + 1e-10
        initial_obs_dataset['initial_observations'] = (initial_obs_dataset['initial_observations'] - obs_mean) / obs_scale
        dataset['observations'] = (dataset['observations'] - obs_mean) / obs_scale
        dataset['next_observations'] = (dataset['next_observations'] - obs_mean) / obs_scale
    if standardize_reward:
        dataset['rewards'] = (dataset['rewards'] - dataset_statistics['reward_mean']) / (dataset_statistics['reward_std'] + 1e-10)

    if absorbing_state:
        n_initial = dataset_statistics['N_initial_observations']
        n_transitions = dataset_statistics['N']
        # add additional dimension to observations to deal with absorbing state (0 = real state)
        initial_obs_dataset['initial_observations'] = np.concatenate(
            (initial_obs_dataset['initial_observations'], np.zeros((n_initial, 1))), axis=1).astype(np.float32)
        dataset['observations'] = np.concatenate(
            (dataset['observations'], np.zeros((n_transitions, 1))), axis=1).astype(np.float32)
        dataset['next_observations'] = np.concatenate(
            (dataset['next_observations'], np.zeros((n_transitions, 1))), axis=1).astype(np.float32)

        terminal_indices = np.where(dataset['terminals'])[0]
        num_terminals = len(terminal_indices)
        absorbing_obs = np.eye(dataset_statistics['observation_dim'] + 1)[-1].astype(np.float32)
        absorbing_rows = np.tile(absorbing_obs, (num_terminals, 1))

        # Terminal transitions now lead into the absorbing state ...
        dataset['next_observations'][terminal_indices] = absorbing_obs
        # ... which loops onto itself with zero reward, reusing the terminal action.
        dataset['observations'] = np.concatenate((dataset['observations'], absorbing_rows)).astype(np.float32)
        dataset['actions'] = np.concatenate(
            (dataset['actions'], dataset['actions'][terminal_indices])).astype(np.float32)
        dataset['rewards'] = np.concatenate((dataset['rewards'], np.zeros(num_terminals))).astype(np.float32)
        dataset['next_observations'] = np.concatenate(
            (dataset['next_observations'], absorbing_rows)).astype(np.float32)
        dataset['terminals'] = np.concatenate((dataset['terminals'], np.ones(num_terminals))).astype(np.float32)
        # Keep 'experts' row-aligned with the other arrays; each absorbing row
        # inherits the flag of the terminal transition it follows.
        dataset['experts'] = np.concatenate(
            (dataset['experts'], dataset['experts'][terminal_indices])).astype(np.float32)

    return initial_obs_dataset, dataset, dataset_statistics


def dice_combined_dataset(expert_env, env, num_expert_traj=2000, num_offline_traj=2000, expert_dataset=None, offline_dataset=None,
                            standardize_observation=True, absorbing_state=True, standardize_reward=True, reverse=False):
    """Concatenate expert and offline D4RL-style data into one DICE transition dataset.

    Expert transitions come first and carry ``experts == 1``; offline transitions
    follow with ``experts == 0``. Preprocessing matches ``dice_dataset``:
    - timeout steps are skipped;
    - statistics are computed over the combined data;
    - optional z-scoring;
    - optional absorbing-state augmentation. Absorbing rows copy the expert flag of
      their terminal transition.

    Args:
        expert_env, env: d4rl environments, used for ``get_dataset()`` when the
            corresponding dataset is None and for ``_max_episode_steps`` when a
            dataset has no ``'timeouts'`` key.
        num_expert_traj, num_offline_traj: maximum number of trajectories to take
            from each source. The value 2000 is treated as "use every trajectory".
        expert_dataset, offline_dataset: optional pre-loaded datasets.
        standardize_observation, absorbing_state, standardize_reward: see
            ``dice_dataset``.
        reverse: accepted for interface compatibility; currently unused.

    Returns:
        ``(initial_obs_dataset, dataset, dataset_statistics)``.
    """
    initial_obs_, obs_, next_obs_, action_, reward_, done_, expert_ = [], [], [], [], [], [], []

    def add_data(env, num_traj, dataset=None, expert_data=False):
        # Append up to ``num_traj`` trajectories of ``dataset`` to the shared lists above.
        if dataset is None:
            dataset = env.get_dataset()
        num_samples = dataset['rewards'].shape[0]
        use_timeouts = 'timeouts' in dataset
        traj_count = 0
        episode_step = 0
        for i in range(num_samples - 1):
            # only use this condition when num_traj < 2000 (2000 means "all trajectories")
            if num_traj != 2000 and traj_count == num_traj:
                break
            obs = dataset['observations'][i].astype(np.float32)
            new_obs = dataset['observations'][i + 1].astype(np.float32)
            action = dataset['actions'][i].astype(np.float32)
            reward = dataset['rewards'][i].astype(np.float32)
            done_bool = bool(dataset['terminals'][i])

            is_final_timestep = dataset['timeouts'][i] if use_timeouts else (episode_step == env._max_episode_steps - 1)
            if is_final_timestep:
                # Skip this transition and don't apply terminals on the last step of an episode
                traj_count += 1
                episode_step = 0
                continue

            if episode_step == 0:
                initial_obs_.append(obs)

            obs_.append(obs)
            next_obs_.append(new_obs)
            action_.append(action)
            reward_.append(reward)
            done_.append(done_bool)
            expert_.append(expert_data)
            episode_step += 1

            if done_bool:
                traj_count += 1
                episode_step = 0

    add_data(expert_env, num_expert_traj, dataset=expert_dataset, expert_data=True)
    expert_size = len(obs_)
    print(f"Expert Traj {num_expert_traj}, Expert Size {expert_size}")
    add_data(env, num_offline_traj, dataset=offline_dataset, expert_data=False)
    offline_size = len(obs_) - expert_size
    print(f"Offline Traj {num_offline_traj}, Offline Size {offline_size}")

    initial_obs_dataset = {
        'initial_observations': np.array(initial_obs_, dtype=np.float32)
    }
    dataset = {
        'observations': np.array(obs_, dtype=np.float32),
        'actions': np.array(action_, dtype=np.float32),
        'next_observations': np.array(next_obs_, dtype=np.float32),
        'rewards': np.array(reward_, dtype=np.float32),
        'terminals': np.array(done_, dtype=np.float32),
        'experts': np.array(expert_, dtype=np.float32)
    }
    dataset_statistics = {
        'observation_mean': np.mean(dataset['observations'], axis=0),
        'observation_std': np.std(dataset['observations'], axis=0),
        'reward_mean': np.mean(dataset['rewards']),
        'reward_std': np.std(dataset['rewards']),
        'N_initial_observations': len(initial_obs_),
        'N': len(obs_),
        'observation_dim': dataset['observations'].shape[-1],
        'action_dim': dataset['actions'].shape[-1]
    }

    if standardize_observation:
        obs_mean = dataset_statistics['observation_mean']
        obs_scale = dataset_statistics['observation_std'] + 1e-10
        initial_obs_dataset['initial_observations'] = (initial_obs_dataset['initial_observations'] - obs_mean) / obs_scale
        dataset['observations'] = (dataset['observations'] - obs_mean) / obs_scale
        dataset['next_observations'] = (dataset['next_observations'] - obs_mean) / obs_scale
    if standardize_reward:
        dataset['rewards'] = (dataset['rewards'] - dataset_statistics['reward_mean']) / (dataset_statistics['reward_std'] + 1e-10)

    if absorbing_state:
        n_initial = dataset_statistics['N_initial_observations']
        n_transitions = dataset_statistics['N']
        # add additional dimension to observations to deal with absorbing state (0 = real state)
        initial_obs_dataset['initial_observations'] = np.concatenate(
            (initial_obs_dataset['initial_observations'], np.zeros((n_initial, 1))), axis=1).astype(np.float32)
        dataset['observations'] = np.concatenate(
            (dataset['observations'], np.zeros((n_transitions, 1))), axis=1).astype(np.float32)
        dataset['next_observations'] = np.concatenate(
            (dataset['next_observations'], np.zeros((n_transitions, 1))), axis=1).astype(np.float32)

        terminal_indices = np.where(dataset['terminals'])[0]
        num_terminals = len(terminal_indices)
        absorbing_obs = np.eye(dataset_statistics['observation_dim'] + 1)[-1].astype(np.float32)
        absorbing_rows = np.tile(absorbing_obs, (num_terminals, 1))

        # Terminal transitions now lead into the absorbing state ...
        dataset['next_observations'][terminal_indices] = absorbing_obs
        # ... which loops onto itself with zero reward, reusing the terminal action.
        dataset['observations'] = np.concatenate((dataset['observations'], absorbing_rows)).astype(np.float32)
        dataset['actions'] = np.concatenate(
            (dataset['actions'], dataset['actions'][terminal_indices])).astype(np.float32)
        dataset['rewards'] = np.concatenate((dataset['rewards'], np.zeros(num_terminals))).astype(np.float32)
        dataset['next_observations'] = np.concatenate(
            (dataset['next_observations'], absorbing_rows)).astype(np.float32)
        dataset['terminals'] = np.concatenate((dataset['terminals'], np.ones(num_terminals))).astype(np.float32)
        # Keep 'experts' row-aligned with the other arrays; each absorbing row
        # inherits the flag of the terminal transition it follows.
        dataset['experts'] = np.concatenate(
            (dataset['experts'], dataset['experts'][terminal_indices])).astype(np.float32)

    return initial_obs_dataset, dataset, dataset_statistics



def evaluate(env, agent, dataset_statistics, absorbing_state=True, num_evaluation=10,
             pid=None, normalize=True, make_gif=False, iteration=0, max_steps=None, run_name='',
             prefix='', disc_idxs=None, div_idxs=None):
    """Roll out ``agent`` in ``env`` for ``num_evaluation`` episodes and report its score.

    Each observation is standardized with ``dataset_statistics``. A trailing 0 is
    appended when ``absorbing_state`` is set, the result is restricted to the
    ``disc_idxs`` columns, and it is passed to ``agent.step``. The chosen action is
    clipped to the env's action bounds.

    Args:
        normalize: if True, scale each episode return so that D4RL's
            ``REF_MIN_SCORE`` maps to 0 and ``REF_MAX_SCORE`` to 100.
        make_gif: record rendered frames of the first episode.
        max_steps: episode length cap; defaults to ``env._max_episode_steps``.
        disc_idxs: observation columns fed to the agent; ``None`` means all columns.
        pid: if given, a progress line is printed per episode.
        iteration, run_name, prefix, div_idxs: accepted for interface compatibility;
            unused here.

    Returns:
        ``(mean_score, std_score, info)``. ``info`` holds ``imgs``, ``observations``,
        ``actions`` and ``observations_agent``, each zero-padded to ``max_steps``.
    """
    normalized_scores = []
    if max_steps is None:
        max_steps = env._max_episode_steps
    imgs = []

    raw_obs_dim = env.observation_space.shape[0]
    obs_dim = raw_obs_dim + 1 if absorbing_state else raw_obs_dim
    if disc_idxs is None:
        # Select every column. Indexing with ``None`` would insert a new axis instead.
        disc_idxs = slice(None)

    observations = np.zeros((num_evaluation, max_steps, raw_obs_dim))
    observations_agent = np.zeros((num_evaluation, max_steps, obs_dim))
    actions_arr = np.zeros((num_evaluation, max_steps, env.action_space.shape[0]))
    obs_mean = dataset_statistics['observation_mean']
    obs_scale = dataset_statistics['observation_std'] + 1e-10
    start_time = time.time()
    for eval_iter in range(num_evaluation):
        obs = env.reset()
        episode_reward = 0

        for t in tqdm(range(max_steps), ncols=70, desc='evaluate', ascii=True, disable=os.environ.get("DISABLE_TQDM", False)):
            obs_standardized = (obs - obs_mean) / obs_scale
            if absorbing_state:
                obs_standardized = np.append(obs_standardized, 0)

            actions = agent.step((np.array([obs_standardized])[:, disc_idxs]).astype(np.float32))
            action = actions[0][0].numpy()

            # prevent NAN
            action = np.clip(action, env.action_space.low, env.action_space.high)

            actions_arr[eval_iter, t] = action
            observations[eval_iter, t] = obs
            observations_agent[eval_iter, t] = obs_standardized

            next_obs, reward, done, info = env.step(action)

            if make_gif and eval_iter == 0:
                imgs.append(env.render(mode="rgb_array"))
            episode_reward += reward
            if done:
                break
            obs = next_obs

        if normalize:
            ref_min = d4rl.infos.REF_MIN_SCORE[env.spec.id]
            ref_max = d4rl.infos.REF_MAX_SCORE[env.spec.id]
            normalized_score = 100 * (episode_reward - ref_min) / (ref_max - ref_min)
        else:
            normalized_score = episode_reward
        if pid is not None:
            print(f'PID [{pid}], Eval Iteration {eval_iter}')
        normalized_scores.append(normalized_score)

    if make_gif:
        imgs = np.array(imgs)
    mu_score, std_score = np.mean(normalized_scores), np.std(normalized_scores)
    inf = {'imgs': imgs, 'observations': observations,
           'actions': actions_arr,
           'observations_agent': observations_agent}
    print("-"*50)
    print(f"mu_score: {mu_score:.3f}, std_score: {std_score:.3f}, time: {time.time() - start_time:.3f}")
    print("-"*50)
    return mu_score, std_score, inf


def evaluate_bp(env, behavior_policy, dataset_statistics, action_max=1.0, action_min=-1.0):
    """Roll out ``behavior_policy`` for ``env._max_episode_steps`` steps and log videos to wandb.

    Observations from the batched ``env`` are standardized with ``dataset_statistics``.
    The policy output is clamped to [-10, 10] and mapped from [-1, 1] onto
    [``action_min``, ``action_max``]. The frames returned by ``env.step`` are logged
    under ``'pretrain/behavior_policy_evaluate'``.

    Args:
        action_max, action_min: action bounds; scalars (the defaults) or
            per-dimension numpy arrays.
    """
    max_steps = env._max_episode_steps
    imgs = []
    # ``torch.as_tensor`` accepts numpy arrays and plain floats alike
    # (``torch.from_numpy`` rejected the float defaults).
    obs_mean = torch.as_tensor(dataset_statistics['observation_mean'], device=env.device).unsqueeze(0)
    obs_std = torch.as_tensor(dataset_statistics['observation_std'], device=env.device).unsqueeze(0)
    act_max = torch.as_tensor(action_max, device=env.device).unsqueeze(0)
    act_min = torch.as_tensor(action_min, device=env.device).unsqueeze(0)
    with torch.no_grad():
        env.reset()
        obs = env.get_observations().detach()

        for t in tqdm(range(max_steps), ncols=70, desc='evaluate behavior policy', ascii=True,
                      disable=os.environ.get("DISABLE_TQDM", False)):
            obs_standardized = (obs - obs_mean) / (obs_std + 1e-10)

            (actions, _, _, _, _), _ = behavior_policy((obs_standardized,))

            # prevent NAN
            action = torch.clamp(actions, -10, 10).detach()

            # Map from [-1, 1] onto [action_min, action_max].
            action_ = (action + 1.0) / 2.0 * (act_max - act_min) + act_min
            next_obs, _, _, _, img_f = env.step(action_)
            imgs.append(img_f)
            obs = next_obs.detach()
        del obs_mean, obs_std, act_max, act_min
        torch.cuda.empty_cache()
        # NOTE(review): assumes each frame batch is (num_envs, H, W, C), giving
        # (num_envs, T, C, H, W) after the transpose -- confirm against env.step.
        imgs = np.array(imgs).transpose(1, 0, 4, 2, 3)
        videos = [wandb.Video(img_seq, fps=50, format="mp4") for img_seq in imgs]
        wandb.log({'pretrain/behavior_policy_evaluate': videos})


def evaluate_solo(env, agent, dataset_statistics, absorbing_state=True, num_evaluation=2,
                  pid=None, normalize=True, make_gif=False, iteration=0, max_steps=None, run_name='',
                  prefix='', action_max=1.0, action_min=-1.0, div_idxs=None, disc_idxs=None):
    """Evaluate every skill in ``agent._skills`` in parallel in a batched environment.

    Each skill is repeated ``num_evaluation`` times, so env index
    ``skill_idx * num_evaluation + eval_idx`` runs skill ``skill_idx``. Observations
    are standardized with ``dataset_statistics``, restricted to ``disc_idxs``, and
    fed to ``agent.step``. Actions are clamped to [-10, 10] and mapped from [-1, 1]
    onto [``action_min``, ``action_max``].

    Args:
        absorbing_state: if True, ``observations_agent`` gets an extra trailing
            coordinate, which stays 0 for real observations.
        normalize: linearly rescale episode returns so that -250 maps to 0 and
            -23.01 to 100.
        make_gif: collect the frames returned by ``env.step``.
        action_max, action_min: scalars (the defaults) or per-dimension numpy arrays.
        disc_idxs: observation columns fed to the agent; ``None`` means all.
        pid, iteration, run_name, prefix, div_idxs: accepted for interface
            compatibility; unused here.

    Returns:
        ``(mean_per_skill, std_per_skill, info)`` with numpy arrays of length
        ``num_skills``. ``info`` holds ``imgs``, ``observations``,
        ``observations_agent`` and ``heights``.
    """
    # ``torch.as_tensor`` accepts numpy arrays and plain floats alike
    # (``torch.from_numpy`` rejected the float defaults).
    obs_mean = torch.as_tensor(dataset_statistics['observation_mean'], device=env.device).unsqueeze(0)
    obs_std = torch.as_tensor(dataset_statistics['observation_std'], device=env.device).unsqueeze(0)
    act_max = torch.as_tensor(action_max, device=env.device).unsqueeze(0)
    act_min = torch.as_tensor(action_min, device=env.device).unsqueeze(0)

    if max_steps is None:
        max_steps = env._max_episode_steps
    imgs = []

    num_obs = env.num_obs
    obs_dim = num_obs + 1 if absorbing_state else num_obs

    skills = agent._skills
    num_skills = len(skills)
    num_envs = num_evaluation * num_skills
    # repeat each skill num_evaluation times
    skills = torch.repeat_interleave(skills, num_evaluation, 0)
    observations = torch.zeros((num_envs, max_steps, num_obs), device=env.device)
    observations_agent = torch.zeros((num_envs, max_steps, obs_dim), device=env.device)
    heights = torch.zeros((num_envs, max_steps, 1), device=env.device)
    start_time = time.time()

    if disc_idxs is None:
        disc_idxs = np.arange(num_obs)

    with torch.no_grad():
        env.reset()
        obs = env.get_observations().detach()
        episode_reward = torch.zeros(num_envs, device=env.device)

        for t in tqdm(range(max_steps), ncols=70, desc='evaluate', ascii=True,
                      disable=os.environ.get("DISABLE_TQDM", False)):
            obs_standardized = (obs - obs_mean) / (obs_std + 1e-10)
            actions, _ = agent.step(obs_standardized[:, disc_idxs], gpu=True, eval_skills=skills)
            actions = torch.clamp(actions, -10, 10).detach()
            observations[:, t] = obs
            # Fill only the real-observation columns; the absorbing coordinate
            # (present when absorbing_state is set) stays 0.
            observations_agent[:, t, :num_obs] = obs_standardized
            action_ = (actions + 1.0) / 2.0 * (act_max - act_min) + act_min
            next_obs, reward, _, _, img = env.step(action_)
            heights[:, t] = env.root_states[:, 2].reshape(-1, 1)

            if make_gif:
                imgs.append(img)
            episode_reward += reward.detach()
            obs = next_obs.detach()
        if normalize:
            # Ax+b = y
            # x = -23.01, y = 100, expert
            # x = -250, y = 0, random, (abs velocity error = 1, cannot track at all)
            normalized_score = (episode_reward + 250) * 100 / (250 - 23.01)
        else:
            normalized_score = episode_reward
        normalized_score_ = normalized_score.detach()
        print('\n')
        print(f'normalized_score: {normalized_score_} (elapsed_time={time.time() - start_time:.3f}) ')

    del obs_mean, obs_std, act_max, act_min
    torch.cuda.empty_cache()

    if make_gif:
        # (T, num_image_envs, ...) -> (num_image_envs, T, ...)
        imgs = np.array(imgs).transpose(1, 0, 2, 3, 4)
    normalized_score_ = normalized_score_.reshape(num_skills, num_evaluation).cpu().numpy()

    return np.mean(normalized_score_, axis=1), np.std(normalized_score_, axis=1), {'imgs': imgs, 'observations': observations.cpu().numpy(),
                                                                   'observations_agent': observations_agent.cpu().numpy(),
                                                                   'heights': heights.cpu().numpy()}


def compute_successor_metrics(observations_per_skill, max_steps, eval_skills, discount=0.99):
    """Compute discounted successor features per skill and their pairwise distances.

    Args:
        observations_per_skill: array of shape (num_skills, num_evals, max_steps, obs_dim).
        max_steps: length of the time axis (axis 2) used for the discount weights.
        eval_skills: the evaluated skills; kept for API compatibility (the number of
            skills is taken from ``observations_per_skill``).
        discount: per-step discount factor.

    Returns:
        ``(succ_feat, avg_dist, max_dist, min_dist, dist)``:
        - ``succ_feat`` has shape (num_skills, obs_dim): the discounted sum over time,
          averaged over evaluations.
        - ``dist`` is the (num_skills, num_skills) Euclidean distance matrix.
        - The three scalars are the mean / max / min over distinct skill pairs; they
          are NaN when fewer than two skills are given.
    """
    weights = np.power(discount, np.arange(max_steps))[None, None, :, None]
    succ_feat = np.sum(observations_per_skill * weights, axis=2)  # (num_skills, num_evals, obs_dim)
    succ_feat = np.mean(succ_feat, axis=1)  # (num_skills, obs_dim)
    # pairwise distance, (num_skills, num_skills)
    dist = np.linalg.norm(succ_feat[:, None, :] - succ_feat[None, :, :], axis=-1)

    num_skills = dist.shape[0]
    if num_skills < 2:
        nan = float('nan')
        return succ_feat, nan, nan, nan, dist

    # Exclude self-distances with a mask rather than large sentinel offsets.
    off_diagonal = dist[~np.eye(num_skills, dtype=bool)]
    avg_dist = off_diagonal.mean()
    max_dist = off_diagonal.max()
    min_dist = off_diagonal.min()
    return succ_feat, avg_dist, max_dist, min_dist, dist


def compute_discriminator_reward(discriminator, obs, skill, div_idxs):
    """Score one rollout of observations under ``skill`` with the skill discriminator.

    Args:
        discriminator: torch module exposing
            ``skill_reward((obs,), skill_batch) -> (reward, logits, info)``.
        obs: sequence of observation vectors (anything ``np.array`` can stack into 2-D).
        skill: skill tensor (presumably 1-D); it is tiled once per observation.
        div_idxs: optional observation columns fed to the discriminator; ``None``
            keeps every column.

    Returns:
        ``(reward, info)``: the reward as a numpy array and ``info`` with every
        tensor value converted to numpy.
    """
    device = next(discriminator.parameters()).device
    obs = np.array(obs)
    if div_idxs is not None:
        obs = obs[:, div_idxs]
    obs = torch.FloatTensor(obs).to(device)
    # Move the skill before tiling so the batch lives on the discriminator's device.
    skill_batch = torch.tile(skill.to(device), (len(obs), 1))
    with torch.no_grad():
        reward, _, info = discriminator.skill_reward((obs,), skill_batch)
        info = {k: v.cpu().numpy() for k, v in info.items()}
        return reward.cpu().numpy(), info

import time
def evaluate_skills(env, agent, dataset_statistics, proj, absorbing_state=True, num_evaluation=2,
                    pid=None, normalize=True, make_gif=False, iteration=0, max_steps=None, run_name='',
                    prefix='', log_wandb=True, solo=False, action_max=1.0, action_min=-1.0, div_idxs=None, disc_idxs=None):
    """Evaluate every skill from ``agent.get_eval_skills()`` and compute diversity metrics.

    Each skill is rolled out ``num_evaluation`` times, via ``evaluate_solo`` when
    ``solo`` is set and otherwise one ``evaluate`` call per skill. The function then
    gathers:
    - task-reward statistics;
    - discriminator reward / log-prob, when the agent has a ``_skill_discriminator``;
    - successor-feature distance metrics, on raw observations, on agent
      observations and, if ``proj`` is given, on ``proj.transform``-ed observations.
    The metrics are optionally logged to wandb.

    An all-zero last skill is reported as the expert: its video is logged as
    ``eval/expert_video``, and all-zero skills are skipped for discriminator rewards.

    Returns:
        ``(avg_mu, avg_std, skill_metrics, observations_per_skill, observations_per_skill_agent)``.
        Both observation arrays have shape ``(num_skills, num_evaluation, max_steps, dim)``.
    """
    fps_ = 50 if solo else 30  # frames per second
    eval_skills = agent.get_eval_skills()
    num_skills = len(eval_skills)
    imgs_per_skill = []
    mu_per_skill = []
    std_per_skill = []
    if max_steps is None:
        max_steps = env._max_episode_steps
    raw_obs_dim = env.observation_space.shape[0]
    # Raw environment observations: (num_skills, num_eval, max_steps, obs_dim)
    observations_per_skill = np.zeros((num_skills, num_evaluation, max_steps, raw_obs_dim))
    # Standardized observations seen by the agent (+1 absorbing coordinate if enabled)
    observations_per_skill_agent = np.zeros((num_skills, num_evaluation, max_steps, raw_obs_dim + absorbing_state))

    height_std_across_skills = None
    if solo:
        # run with Isaac Gym; record images from the first rollout of every skill
        env.set_image_env_idxs([i * num_evaluation for i in range(num_skills)])
        mu, std, info = evaluate_solo(env, agent, dataset_statistics, absorbing_state, num_evaluation, pid,
                                      normalize, make_gif, iteration, max_steps, run_name,
                                      action_max=action_max, action_min=action_min, disc_idxs=disc_idxs, div_idxs=div_idxs)
        # Rollouts are ordered skill-major, so a plain reshape splits them per skill.
        observations_per_skill = info['observations'].reshape(observations_per_skill.shape)
        observations_per_skill_agent = info['observations_agent'].reshape(observations_per_skill_agent.shape)
        heights = info['heights'].reshape(num_skills, num_evaluation, max_steps, 1)
        mu_height_per_skill = heights.reshape(num_skills, -1).mean(axis=1)
        height_std_across_skills = mu_height_per_skill.std()

        mu_per_skill = mu.reshape(num_skills, -1)
        std_per_skill = std.reshape(num_skills, -1)
        imgs_per_skill = info['imgs']
    else:
        for i, skill in tqdm(enumerate(eval_skills), desc='Evaluating skill...'):
            agent.set_eval_skill(skill)
            mu, std, info = evaluate(env, agent, dataset_statistics, absorbing_state, num_evaluation, pid,
                                     normalize, make_gif, iteration, max_steps, run_name, disc_idxs=disc_idxs, div_idxs=div_idxs)
            observations_per_skill[i] = info['observations']
            observations_per_skill_agent[i] = info['observations_agent']
            imgs_per_skill.append(info['imgs'])
            mu_per_skill.append(mu)
            std_per_skill.append(std)

    expert_skill_video = {}
    if eval_skills[-1].sum() == 0:
        # Frames only exist when make_gif is set; transposing an empty list would fail.
        if make_gif:
            expert_skill_video = {
                'eval/expert_video': wandb.Video(np.array(imgs_per_skill[-1]).transpose(0, 3, 1, 2), fps=fps_, format="mp4")
            }
        imgs_per_skill = imgs_per_skill[:-1]

    # compute discriminator reward
    avg_disdain_per_skill = []
    disc_metrics = {}
    if hasattr(agent, '_skill_discriminator') and agent._skill_discriminator:
        avg_disc_reward_per_skill = []
        avg_logp_per_skill = []
        for i, skill in tqdm(enumerate(eval_skills), desc='Computing discriminator reward...'):
            if skill.sum() == 0:
                continue
            rollout_rewards, rollout_disdain, rollout_logp = [], [], []
            for obs_rollout in observations_per_skill_agent[i]:
                disc_rewards, disc_info = compute_discriminator_reward(
                    agent._skill_discriminator, obs_rollout, skill, div_idxs=div_idxs)
                rollout_rewards.append(np.sum(disc_rewards))
                if 'disdain' in disc_info:
                    rollout_disdain.append(disc_info['disdain'].sum())
                rollout_logp.append(disc_info['logp'].sum())
            avg_disc_reward_per_skill.append(np.mean(rollout_rewards))
            if rollout_disdain:
                avg_disdain_per_skill.append(np.mean(rollout_disdain))
            avg_logp_per_skill.append(np.mean(rollout_logp))

        disc_metrics = {
            'eval/disc_reward': np.mean(avg_disc_reward_per_skill),
            'eval/logp': np.mean(avg_logp_per_skill)
        }
        if height_std_across_skills is not None:
            disc_metrics['eval/height_std'] = height_std_across_skills

    proj_succ_metrics = {}
    if proj:
        start_proj_time = time.time()
        flat_obses = observations_per_skill.reshape(-1, observations_per_skill.shape[-1])
        proj_obses = proj.transform(flat_obses).reshape((num_skills, num_evaluation, max_steps, 2))
        print('Projection time: ', time.time() - start_proj_time)
        # compute_successor_metrics returns five values; only the distance summaries are used.
        _, proj_avg_dist, proj_max_dist, proj_min_dist, _ = compute_successor_metrics(proj_obses, max_steps, eval_skills)
        proj_succ_metrics = {
            'successor_eval/proj_succ_avg_dist': proj_avg_dist,
            'successor_eval/proj_succ_max_dist': proj_max_dist,
            'successor_eval/proj_succ_min_dist': proj_min_dist,
        }

    _, avg_dist, max_dist, min_dist, _ = compute_successor_metrics(observations_per_skill, max_steps, eval_skills)
    succ_metrics = {
        'successor_eval/succ_avg_dist': avg_dist,
        'successor_eval/succ_max_dist': max_dist,
        'successor_eval/succ_min_dist': min_dist,
    }

    _, avg_dist, max_dist, min_dist, _ = compute_successor_metrics(observations_per_skill_agent, max_steps, eval_skills)
    # Normalize by sqrt(dim) so distances are comparable across observation sizes.
    sqrt_d = np.sqrt(observations_per_skill_agent.shape[-1])
    succ_metrics2 = {
        'successor_eval/norm_succ_avg_dist': avg_dist/sqrt_d,
        'successor_eval/norm_succ_max_dist': max_dist/sqrt_d,
        'successor_eval/norm_succ_min_dist': min_dist/sqrt_d,
    }
    skill_metrics = {
        'eval/skill_avg_dist': avg_dist,
        'eval/skill_max_dist': max_dist,
        'eval_iter': iteration
    }
    skill_metrics.update(disc_metrics)
    if avg_disdain_per_skill:
        skill_metrics['eval/disdain'] = np.mean(avg_disdain_per_skill)

    skill_metrics.update(succ_metrics)
    skill_metrics.update(proj_succ_metrics)
    skill_metrics.update(succ_metrics2)

    avg_mu = np.mean(mu_per_skill)
    avg_std = np.mean(std_per_skill)
    skill_metrics['eval/task_reward_mu'] = avg_mu
    skill_metrics['eval/task_reward_std'] = avg_std

    if make_gif:
        skill_videos = [wandb.Video(np.array(imgs).transpose(0, 3, 1, 2), fps=fps_, format="mp4") for imgs in
                        imgs_per_skill]
        skill_metrics.update({'eval/skill_video': skill_videos})
        skill_metrics.update(expert_skill_video)
    if log_wandb:
        wandb.log(skill_metrics)
    return avg_mu, avg_std, skill_metrics, observations_per_skill, observations_per_skill_agent


def save_checkpoint(state, filename=None):
    """Serialize ``state`` with ``torch.save`` into ``<Ticker.working_dir>/checkpoints``.

    The checkpoints directory is created if needed.

    Args:
        state: any object ``torch.save`` can serialize.
        filename: file name inside the checkpoints directory; ``None`` selects
            ``'last.pth.tar'``.
    """
    target_name = 'last.pth.tar' if filename is None else filename
    checkpoint_dir = os.path.join(Ticker.working_dir, 'checkpoints')
    os.makedirs(checkpoint_dir, exist_ok=True)
    torch.save(state, os.path.join(checkpoint_dir, target_name))

def load_checkpoint(filename='last.pth.tar'):
    """Load a checkpoint from ``<Ticker.working_dir>/checkpoints`` with ``torch.load``.

    Args:
        filename: file name inside the checkpoints directory; ``None`` is treated
            like the default ``'last.pth.tar'``.

    Returns:
        Whatever object ``torch.load`` deserializes from that file.
    """
    target_name = 'last.pth.tar' if filename is None else filename
    checkpoint_path = os.path.join(Ticker.working_dir, 'checkpoints', target_name)
    return torch.load(checkpoint_path)

import smart_settings

# Directory containing shared pretrained checkpoints named
# '<env_name>_<dataset>_pretrain_last.pth.tar'; used by maybe_load_last_checkpoint
# when config.pretrained_models_path == 'default'.
# NOTE(review): empty by default -- set it before relying on 'default'.
DEFAULT_PATH = ''

def maybe_load_last_checkpoint(config=None):
    """Load the run's latest agent checkpoint and a pretrained checkpoint, if available.

    Args:
        config: run configuration; defaults to ``Ticker.config``. The attributes read
            are ``working_dir``, ``pretrained_models_path`` and
            ``load_pretrained_only``, plus ``env_name`` / ``dataset`` when
            ``pretrained_models_path == 'default'``.

    Returns:
        ``(ckpt, pretrain_ckpt)``; either may be ``None``. When
        ``pretrained_models_path == 'save'`` nothing is loaded and ``(None, None)``
        is returned.

    Side effects:
        If no agent checkpoint was loaded, ``config`` is written to
        ``<working_dir>/config.yaml``.

    Raises:
        FileNotFoundError: if ``pretrained_models_path == 'default'`` and the default
            pretrained checkpoint does not exist.
    """
    if config is None:
        config = Ticker.config
    working_dir = config.working_dir
    checkpoint_dir = os.path.join(working_dir, 'checkpoints')
    pretrained_models_path = config.pretrained_models_path

    if pretrained_models_path == 'save':
        # Recompute everything and save as default.
        return None, None

    ckpt = None
    if not config.load_pretrained_only:
        try:
            ckpt = torch.load(os.path.join(checkpoint_dir, 'last.pth.tar'))
            print('\033[91m' + 'Loaded agent checkpoint!' + '\033[0m')
        except Exception as e:  # best effort: start from scratch, but say why
            print('\033[91m' + 'No checkpoint loaded :(' + '\033[0m')
            print(f'  reason: {e!r}')
            ckpt = None

    if pretrained_models_path:
        if pretrained_models_path == 'default':
            # os.path.join avoids producing an absolute '/<file>' path when DEFAULT_PATH is ''.
            pretrained_models_path = os.path.join(DEFAULT_PATH, f'{config.env_name}_{config.dataset}_pretrain_last.pth.tar')
            if not os.path.exists(pretrained_models_path):
                raise FileNotFoundError(f'Pretrained model path {pretrained_models_path} does not exist!')
        else:
            pretrained_models_path = os.path.join(pretrained_models_path, 'checkpoints', 'pretrain_last.pth.tar')
        print('\033[91m' + 'Loading pretrained models from {}'.format(pretrained_models_path) + '\033[0m')
        pretrain_ckpt = torch.load(pretrained_models_path, map_location=torch.device('cuda:0'))
    else:
        try:
            pretrain_ckpt = torch.load(os.path.join(checkpoint_dir, 'pretrain_last.pth.tar'))
            print('\033[91m' + 'Loaded pretrain checkpoint!' + '\033[0m')
        except Exception:  # best effort: no pretrain checkpoint in this run
            pretrain_ckpt = None

    if ckpt:
        print('\033[94m' + 'Loading settings from {}'.format(working_dir) + '\033[0m')
        # load previous settings file
        # TODO: set wandb id from ``settings`` (not good practice); currently unused.
        settings = smart_settings.load(os.path.join(working_dir, 'config.yaml'), make_immutable=False)
    else:
        with open(os.path.join(working_dir, 'config.yaml'), 'w') as f:
            yaml.dump(recursive_dictify(config), f, default_flow_style=False)

    return ckpt, pretrain_ckpt

def sequence_dataset(env, dataset=None, sparse=False, **kwargs):
    """
    Returns an iterator through trajectories.

    An episode ends at a step whose ``terminals`` entry is true, or at a
    timeout. If the dataset has a ``timeouts`` field it is used directly.
    Otherwise a timeout is the ``env._max_episode_steps``-th step of the
    episode. Transitions after the last episode boundary are not yielded.

    Args:
        env: An OfflineEnv object. Only used when ``dataset`` is None (to call
            ``env.get_dataset``) or when the dataset has no ``timeouts`` field
            (to read ``env._max_episode_steps``).
        dataset: An optional dataset to pass in for processing. If None,
            the dataset will default to env.get_dataset()
        sparse: if set True, only yield trajectories in which a reward equal
            to 1 is attained; other trajectories are dropped.
        **kwargs: Arguments to pass to env.get_dataset().
    Returns:
        An iterator through dictionaries with keys:
            observations
            actions
            rewards
            terminals
        plus ``infos/qpos``/``infos/qvel``, ``timeouts`` and
        ``next_observations`` when those fields exist in the dataset.
    """
    if dataset is None:
        dataset = env.get_dataset(**kwargs)

    N = dataset['rewards'].shape[0]
    data_ = collections.defaultdict(list)

    fields = ['actions', 'observations', 'rewards', 'terminals']
    if 'infos/qpos' in dataset:
        fields.append('infos/qpos')
        fields.append('infos/qvel')
    # The newer version of the dataset adds an explicit
    # timeouts field. Keep old method for backwards compatibility.
    use_timeouts = 'timeouts' in dataset
    if use_timeouts:
        fields.append('timeouts')
    if 'next_observations' in dataset:
        fields.append('next_observations')

    # 0-based index of the current step within the current episode.
    episode_step = 0
    for i in range(N):
        done_bool = bool(dataset['terminals'][i])
        if use_timeouts:
            final_timestep = dataset['timeouts'][i]
        else:
            final_timestep = (episode_step == env._max_episode_steps - 1)

        for k in fields:
            data_[k].append(dataset[k][i])

        if not (done_bool or final_timestep):
            episode_step += 1
            continue

        # Episode boundary. Always start a fresh buffer and step counter,
        # including when the episode is filtered out below. Otherwise
        # discarded transitions would leak into the next episode.
        episode_data = {k: np.array(v) for k, v in data_.items()}
        data_ = collections.defaultdict(list)
        episode_step = 0

        if not sparse or 1 in episode_data['rewards']:
            yield episode_data

def add_absorbing_state(dataset):
    """Augment ``dataset`` with an absorbing state after every terminal transition.

    Every observation and next observation gets one extra trailing feature set
    to 0. The absorbing state is the one-hot vector with a 1 in that extra
    feature. For each index ``t`` where ``dataset['terminals'][t]`` is truthy:

      * ``next_observations[t]`` is replaced by the absorbing state, and
      * one absorbing -> absorbing transition is appended. It reuses
        ``actions[t]``, with reward 0 and terminal 1.

    The ``observations``, ``actions``, ``rewards``, ``next_observations`` and
    ``terminals`` entries are replaced with new float32 arrays. The caller's
    original arrays are not modified. ``dataset`` itself is updated in place
    and also returned.

    NOTE(review): other keys (e.g. ``timeouts``) are left untouched, so they
    end up shorter than the five rewritten arrays whenever there are terminals.
    """
    N = dataset['observations'].shape[0]
    obs_dim = dataset['observations'].shape[1]

    # Extra feature column: 0 for every real (non-absorbing) state.
    flag_column = np.zeros((N, 1))
    observations = np.concatenate((dataset['observations'], flag_column), axis=1).astype(np.float32)
    next_observations = np.concatenate((dataset['next_observations'], flag_column), axis=1).astype(np.float32)

    terminal_indices = np.where(dataset['terminals'])[0]
    num_terminals = len(terminal_indices)

    # One-hot on the extra feature (same as np.eye(obs_dim + 1)[-1]).
    absorbing_state = np.zeros(obs_dim + 1, dtype=np.float32)
    absorbing_state[-1] = 1.0
    absorbing_rows = np.tile(absorbing_state, (num_terminals, 1))

    # Terminal transitions now lead into the absorbing state.
    next_observations[terminal_indices] = absorbing_state

    actions = np.asarray(dataset['actions'])
    rewards = np.asarray(dataset['rewards'])
    terminals = np.asarray(dataset['terminals'])

    # Append one absorbing -> absorbing transition per terminal in a single
    # vectorized step, instead of round-tripping every row through Python lists.
    dataset['observations'] = np.concatenate((observations, absorbing_rows)).astype(np.float32, copy=False)
    dataset['actions'] = np.concatenate((actions, actions[terminal_indices])).astype(np.float32, copy=False)
    dataset['rewards'] = np.concatenate(
        (rewards, np.zeros((num_terminals,) + rewards.shape[1:]))).astype(np.float32, copy=False)
    dataset['next_observations'] = np.concatenate((next_observations, absorbing_rows)).astype(np.float32, copy=False)
    dataset['terminals'] = np.concatenate(
        (terminals, np.ones((num_terminals,) + terminals.shape[1:]))).astype(np.float32, copy=False)
    return dataset

def get_keys(h5file):
    """List the full path of every ``h5py.Dataset`` inside ``h5file``.

    Groups are walked recursively with ``visititems``. Group nodes themselves
    are not included. Paths are returned in the order ``visititems`` visits them.
    """
    dataset_paths = []

    def _collect(path, node):
        # Returns None implicitly: a non-None return would stop visititems early.
        if not isinstance(node, h5py.Dataset):
            return
        dataset_paths.append(path)

    h5file.visititems(_collect)
    return dataset_paths

def get_dataset(h5path):
    """Read every dataset in the HDF5 file at ``h5path`` into a dict.

    Keys are the dataset paths returned by ``get_keys``. Array-valued datasets
    are read with ``[:]``; datasets that reject slicing with ``ValueError``
    (scalars) are read with ``[()]``.
    """
    loaded = {}
    with h5py.File(h5path, 'r') as h5file:
        for key in tqdm(get_keys(h5file), desc="load datafile"):
            node = h5file[key]
            try:
                loaded[key] = node[:]
            except ValueError:
                # Scalar datasets cannot be sliced; read them with an empty index.
                loaded[key] = node[()]
    return loaded

def makedir(path):
    """Ensure the parent directory of ``path`` exists, then return ``path`` unchanged.

    Only the directory component of ``path`` is created, along with any missing
    intermediate directories. ``path`` itself is not created. A bare filename
    such as ``'out.txt'`` has no directory component, so nothing is created.
    """
    parent = os.path.dirname(path)
    # os.makedirs('') raises FileNotFoundError, so skip when there is no parent.
    if parent:
        os.makedirs(parent, exist_ok=True)
    return path



######### our additions #################

import torch
import numpy as np
def efficient_from_numpy(x):
    """Return ``x`` as a torch tensor, on the GPU when CUDA is selected.

    A numpy array is wrapped with ``torch.from_numpy``; an existing tensor is
    used as-is. When CUDA is available and ``Ticker.device == 'cuda'``, the
    tensor is moved with ``.cuda(non_blocking=True)``. Otherwise it is returned
    on its current device.

    ``Ticker.device`` is only read when CUDA is available. ``Ticker.__init__``
    does not set it, so it is presumably assigned elsewhere before use.
    """
    use_gpu = torch.cuda.is_available() and Ticker.device == 'cuda'
    tensor = x if isinstance(x, torch.Tensor) else torch.from_numpy(x)
    return tensor.cuda(non_blocking=True) if use_gpu else tensor
        

class SmodiceLoader:
    """Mini-batch iterator over an in-memory offline dataset.

    Iterating yields dicts of tensors (built with ``efficient_from_numpy``)
    containing:
      * every key of ``dataset`` that also appears in ``batch_keys``, indexed by
        the batch's sample indices;
      * ``'initial_observations'``: rows of
        ``initial_obs_dataset['initial_observations']`` drawn uniformly with
        replacement, one per sample in the batch;
      * ``'w_e'`` and ``'e_v'``, indexed like the samples, only when a dataset
        has been registered with ``set_w_e``;
      * ``'idxs'``: the sample indices themselves.

    One pass covers ``dataset_statistics['N']`` samples in order, or in a random
    permutation when ``shuffle`` is true. The last batch may be smaller than
    ``batch_size``.
    """

    # Dataset keys copied into each batch; other dataset keys are ignored.
    batch_keys = ['observations', 'actions',
                  'rewards', 'terminals',
                  'next_observations',
                  'experts', 'classifier_rewards',
                  'expert_w_e', 'expert_e_v']

    def __init__(self,
                 batch_size,
                 dataset,
                 initial_obs_dataset,
                 dataset_statistics,
                 shuffle):
        self._batch_size = batch_size
        self._dataset = dataset
        # Optional per-sample arrays, attached later through set_w_e().
        # _e_v_dataset is kept for compatibility; this class never reads it.
        self._w_e_dataset = None
        self._e_v_dataset = None
        self._initial_obs_dataset = initial_obs_dataset
        self._dataset_statistics = dataset_statistics
        self._shuffle = shuffle
        self._dataset_size = dataset_statistics['N']

    def __len__(self):
        # Number of samples, not number of batches.
        return self._dataset_size

    def set_w_e(self, w_e_dataset):
        """Register a mapping whose ``'w_e'`` and ``'e_v'`` arrays are added to batches."""
        self._w_e_dataset = w_e_dataset

    def set_batch_size(self, batch_size):
        self._batch_size = batch_size

    def set_shuffle(self, shuffle):
        self._shuffle = shuffle

    def __iter__(self):
        n_samples = self._dataset_statistics['N']
        order = np.arange(n_samples)
        if self._shuffle:
            order = np.random.permutation(order)

        for start in range(0, n_samples, self._batch_size):
            stop = min(n_samples, start + self._batch_size)
            sample_idxs = order[start:stop]

            init_rows = np.random.randint(
                0, self._dataset_statistics['N_initial_observations'], len(sample_idxs))
            init_obs = self._initial_obs_dataset['initial_observations'][init_rows]

            batch = {}
            for key in self._dataset:
                if key in self.batch_keys:
                    batch[key] = efficient_from_numpy(self._dataset[key][sample_idxs])
            batch['initial_observations'] = efficient_from_numpy(init_obs)

            extra = self._w_e_dataset
            if extra is not None:
                batch['w_e'] = efficient_from_numpy(extra['w_e'][sample_idxs])
                batch['e_v'] = efficient_from_numpy(extra['e_v'][sample_idxs])

            batch['idxs'] = efficient_from_numpy(sample_idxs)
            yield batch
    
def stable_log_mean_exp(x, dim=0, offset=0.0):
    """Numerically stable log-mean-exp of ``x`` along ``dim``.

    Computes ``log(mean(exp(x - m), dim) + offset) + m``, where ``m`` is the
    (detached) max of ``x`` along ``dim``. With ``offset=0`` this equals
    ``log(mean(exp(x), dim))`` without overflow. A nonzero ``offset`` is added
    in the max-shifted space, i.e. the result is
    ``log(mean(exp(x)) + offset * exp(m))``.

    The reduced dimension is kept with size 1, as with ``keepdim=True``. For the
    default ``dim=0`` this is the same shape as before. The ``keepdim`` mean
    fixes the previous mis-broadcast against ``m`` for any other ``dim``.
    """
    max_x, _ = torch.max(x, dim=dim, keepdim=True)
    shift = max_x.detach()
    mean_exp = torch.mean(torch.exp(x - shift), dim=dim, keepdim=True)
    return torch.log(mean_exp + offset) + shift

def weighted_log_mean_exp(x, w, dim=0):
    """Numerically stable ``log(sum(w * exp(x), dim))``.

    Subtracts the (detached) max of ``x`` along ``dim`` before exponentiating
    and adds it back afterwards. ``w`` is not normalized here. The result is a
    weighted log-mean-exp only if ``w`` already sums to 1 along ``dim`` (the
    caller's responsibility). ``w`` must broadcast against ``x``.

    The reduced dimension is kept with size 1. For the default ``dim=0`` this
    is the same shape as before. The ``keepdim`` sum fixes the previous
    mis-broadcast against the max term for any other ``dim``.
    """
    max_x, _ = torch.max(x, dim=dim, keepdim=True)
    shift = max_x.detach()
    exp = torch.exp(x - shift)
    return torch.log(torch.sum(w * exp, dim=dim, keepdim=True)) + shift

def stable_log_sum_exp(x, dim=0):
    """Numerically stable ``log(sum(exp(x), dim))``.

    Subtracts the (detached) max of ``x`` along ``dim`` before exponentiating
    and adds it back afterwards. The reduced dimension is kept with size 1,
    matching ``torch.logsumexp(x, dim, keepdim=True)``. This is consistent with
    ``stable_log_mean_exp`` and ``weighted_log_mean_exp``.

    The previous version passed ``dim`` to ``torch.log`` instead of
    ``torch.sum``. That raised ``TypeError``, and the sum was over all elements.
    """
    max_x, _ = torch.max(x, dim=dim, keepdim=True)
    shift = max_x.detach()
    return torch.log(torch.sum(torch.exp(x - shift), dim=dim, keepdim=True)) + shift

def stable_softmax(x, dim=0):
    """Softmax of ``x`` along ``dim``, computed after subtracting the per-slice max.

    Shifting by the max keeps ``exp`` from overflowing. The output has the same
    shape as ``x`` and sums to 1 along ``dim``.
    """
    shifted = x - torch.max(x, dim=dim, keepdim=True)[0]
    numerator = torch.exp(shifted)
    return numerator / numerator.sum(dim=dim, keepdim=True)
