import os

from models.dan_nn_wrapper import DANWrapper
from rgb_arrays_to_mp4 import rgb_arrays_to_mp4
os.environ["D4RL_SUPPRESS_IMPORT_ERROR"] = "1"
import numpy as np
import pickle
import torch
import torch.multiprocessing as mp
from nn_util import construct_env, eval_over, get_action_from_obs_batched, env_to_rgb_array, get_processed_obs, robotwin_reset, set_seed
import time
import torch.distributed as dist
from logging_utils import logger
import copy
import cv2

def _identity_decorator(func):
    """Return ``func`` unchanged; stand-in for line_profiler's ``@profile``."""
    return func


try:
    profile  # presumably injected into builtins by kernprof when line-profiling
except NameError:
    profile = _identity_decorator

# Module-wide debug switch.
DEBUG = False
# Cache of PersistentProcessPool objects, one slot per local GPU rank (indexed by LOCAL_RANK).
persistent_processes = [None for _slot in range(8)]
class PersistentProcessPool:
    """Long-lived pool of ``env_worker`` processes, one per evaluation trial.

    Each worker gets its own command queue; all workers reply on a single
    shared result queue. Workers are told to exit (by sending ``None``) on
    interpreter exit, SIGTERM, or SIGINT.

    Args:
        config: Environment config dict forwarded to every worker.
        start_trial: Trial index of the first worker; worker ``i`` runs trial
            ``start_trial + i``.
        num_trials: Number of worker processes (and command queues) to create.
        rank: Local GPU rank forwarded to the workers.
    """

    def __init__(self, config, start_trial, num_trials, rank):
        import atexit
        import signal

        self.command_queues = [mp.Queue() for _ in range(num_trials)]
        self.result_queue = mp.Queue()
        self.config = config
        self.workers = []
        self._is_shut_down = False

        for i in range(num_trials):
            p = mp.Process(target=env_worker,
                           args=(i, i + start_trial, self.config, self.command_queues[i],
                                 self.result_queue, rank, True))
            p.start()
            self.workers.append(p)

        atexit.register(self.shutdown)
        signal.signal(signal.SIGTERM, self._signal_handler)
        signal.signal(signal.SIGINT, self._signal_handler)

    def _signal_handler(self, signum=None, frame=None):
        """Stop the workers, then exit the process with status 0.

        ``signal.signal`` invokes handlers as ``handler(signum, frame)``, so
        both arguments must be accepted.
        """
        import sys

        self.shutdown()
        sys.exit(0)

    def shutdown(self):
        """Send the stop sentinel to every worker and wait for them to exit.

        Safe to call more than once (e.g. from the signal handler and then
        again from ``atexit``); only the first call does any work.
        """
        if getattr(self, "_is_shut_down", False):
            return
        self._is_shut_down = True
        for queue in self.command_queues:
            queue.put(None)
        for worker in self.workers:
            worker.join()

@profile
def single_trial_eval(config, agent, env, trial, reset, dan, debug_data=None, split_agent=None, step_offset=0, reward_offset=0, pre_obs_history=None, dump_trial=False, dump_reward=False, debug_stop=1, split_step=None):
    """Roll out one episode of ``agent`` in ``env``.

    Args:
        config: Environment config dict (reads ``name``, ``robosuite``,
            ``robotwin``, ``cams``, ``crops``, ``type``, ``reward_shaping``).
        agent: Policy queried through ``get_action_from_obs_batched``.
        env: Environment instance to step.
        trial: Trial index; used as the env seed and in output file names.
            For RoboTwin resets it is bumped by 100 until ``robotwin_reset``
            succeeds.
        reset: If true, reset the environment first; otherwise continue from
            the env's current state.
        dan: Whether ``agent`` is a DAN agent (an observation history is kept
            when ``agent.retrieval_agent.lookback > 1``).
        debug_data: Optional per-trial recorded data; for steps below
            ``debug_stop`` the env state is overwritten from it.
        split_agent: Optional second agent. At split steps the env is cloned
            and a recursive rollout with ``split_agent`` runs from that state.
        step_offset: Step counter to start from (used by split rollouts).
        reward_offset: Reward already accumulated before this rollout.
        pre_obs_history: Observation history to continue from.
        dump_trial: Pickle the observation history (or, with ``debug_data``
            and ``dump_reward``, the per-step rewards) under ``data/``.
        dump_reward: Record the per-step rewards.
        debug_stop: Number of leading steps whose state comes from
            ``debug_data``.
        split_step: Steps at which to fork ``split_agent``. ``None`` or empty
            means every 10th step.

    Side effects:
        When at least one step runs, the rendered frames are written to
        ``data/{trial}.mp4``.

    Returns:
        ``(episode_reward, success)``. ``success`` is 1 if the final ``info``
        contains a ``'success'`` key, else 0.
    """
    if split_step is None:
        split_step = []

    env_name = config['name']
    is_robosuite = config.get('robosuite', False)
    is_robotwin = config.get('robotwin', False)
    is_atari = env_name.startswith("ALE")
    cam_names = config.get("cams", [])
    crops = config.get('crops', {})
    #act_horizon = agent.policy_cfg.get("act_horizon", 1)
    act_horizon = 1

    if act_horizon > 1:
        act_size = agent.datasets['retrieval'].act_matrix[0][0].shape[-1]
        pred_act_matrix = np.empty([1, act_horizon, act_size])

    video_frames = []

    if not is_robosuite and not is_atari and not is_robotwin:
        env.seed(trial)

    if reset:
        if is_robotwin:
            local_rank = int(os.environ.get("LOCAL_RANK", 0))
            print(f"rank {local_rank} starting Trial {trial}")
            found_bad_seeds = []
            good_env = False

            # robotwin_reset can raise for some seeds; skip ahead by 100 until one works.
            while not good_env:
                try:
                    robotwin_reset(env, env_name, seed=trial, gpu_id=local_rank)
                    good_env = True
                except Exception:
                    # Exception (not a bare except) so KeyboardInterrupt still aborts the loop.
                    import traceback
                    traceback.print_exc()
                    found_bad_seeds.append(trial)
                    trial += 100

            if found_bad_seeds:
                logger.info(f"ADD BAD SEEDS {found_bad_seeds}")

            observation = env.get_obs()
        else:
            observation = env.reset()
            if config['name'].startswith("ALE"):
                observation = observation[0]
    else:
        if is_robosuite:
            observation = env.get_observation()
        else:
            observation = env.env._wrapped_env._get_obs()

    if env_name == "maze2d-umaze-v1":
        env.set_target()
        # Prepend env._target to the observation.
        observation = np.hstack((env._target, observation))

    if (dan and agent.retrieval_agent.lookback > 1) or split_agent is not None or dump_trial:
        if pre_obs_history is not None:
            # Tensor.to is not in-place: keep the returned tensor.
            obs_history = pre_obs_history.to(agent.device)
        else:
            obs_history = torch.empty((1, 0, 0), device=agent.device)
    else:
        obs_history = None

    if dump_reward:
        rewards = []

    episode_reward = reward_offset
    steps = step_offset

    done = False
    info = {}  # keeps the success check below defined even if no step runs
    while not (steps > 0 and (done or eval_over(steps, config, env))):
        should_split = steps in split_step if split_step else steps % 10 == 0
        if should_split and split_agent is not None:
            local_rank = int(os.environ.get("LOCAL_RANK", 0))
            if config.get("robosuite", False):
                split_env = construct_env(config, gpu_id=local_rank)
                # robosuite envs are not seeded: replay the same number of resets
                # as this trial, then copy the exact sim state.
                for _ in range(trial):
                    split_env.reset()
                state = env.get_state()
                state['model'] = env.env.sim.model.get_xml()
                split_env.reset_to(state)
            else:
                split_env = copy.deepcopy(env)
                # Rebuild (qpos, qvel) from the observation with the first qpos entry zeroed.
                # NOTE(review): assumes observation = [qpos[1:], qvel] — confirm per env.
                unobserved_nq = 1
                nq = env.model.nq - unobserved_nq
                nv = env.model.nv
                split_env.set_state(np.hstack((np.zeros(unobserved_nq), observation[:nq])), observation[-nv:])
            reward, _ = single_trial_eval(config, split_agent, split_env, trial, False, not dan, pre_obs_history=obs_history, step_offset=steps, reward_offset=episode_reward, dump_trial=True, debug_stop=debug_stop, debug_data=debug_data)
            if not dan:
                split_agent.episode_queries.append(split_agent.queries)
                split_agent.episode_deltas.append(split_agent.deltas)
                split_agent.queries = []
                split_agent.deltas = []
            print(f"Fork at step {steps}: {reward}")
            # Build the payload before opening the file so a failure leaves no empty file.
            fork_data = (split_agent.episode_queries[0][0], split_agent.episode_deltas[0][0].unsqueeze(0))
            with open(f"results/fork_data_{trial}.pkl", 'wb') as f:
                pickle.dump(fork_data, f)

        if debug_data is not None and steps < debug_stop:
            # Replay the recorded state for the first debug_stop steps.
            env.reset()
            if config.get("robosuite", False):
                initial_state = dict(states=debug_data[trial]['states'][steps])
                initial_state["model"] = debug_data[trial]["model_file"]
                env.reset_to(initial_state)
                observation = env.get_observation()
            else:
                unobserved_nq = 1
                nq = env.model.nq - unobserved_nq
                nv = env.model.nv
                env.set_state(
                    np.hstack((np.zeros(unobserved_nq), debug_data[trial]['observations'][steps][:nq])),
                    debug_data[trial]['observations'][steps][-nv:])
                observation = debug_data[trial]['observations'][steps]

        height, width = 224, 224
        if len(cam_names) > 0:
            # Tile all configured cameras side by side into one frame.
            full_frame = np.empty((height, 0, 3), dtype=np.uint8)
            for camera in cam_names:
                if is_robotwin:
                    frame = observation['observation'][camera]['rgb']
                    frame = cv2.resize(frame, (224, 224))
                else:
                    crop_corners = np.array(crops.get(camera, [[0, 0], [1.0, 1.0]]))
                    frame = env_to_rgb_array(env, camera, crop_corners, width, height)

                full_frame = np.hstack((full_frame, frame))
        else:
            if is_atari:
                # Reshape to (1, 336, 84, 1) uint8 — presumably 4 stacked 84x84 frames.
                observation = observation['image_observation'].transpose(2, 0, 1).reshape((1, 84 * 4, 84, 1)).astype(np.uint8)
                full_frame = observation
            else:
                full_frame = env.render(mode='rgb_array', height=height, width=width, camera_name="agentview")

        with torch.no_grad():
            if not ((split_agent is not None or dump_trial) and not (dan and agent.retrieval_agent.lookback > 1)):
                action, obs_history = get_action_from_obs_batched(config, agent, [env], [observation], [full_frame], obs_history=obs_history, numpy_action=False, is_first_ob=(steps == 0))
            else:
                # History is only recorded here, not consumed by the agent: query
                # without it and append the processed observation ourselves.
                action, _ = get_action_from_obs_batched(config, agent, [env], [observation], [full_frame], obs_history=None, numpy_action=False, is_first_ob=(steps == 0))

                observation = get_processed_obs(observation, full_frame, env, agent, config, config['type'])[0]
                if obs_history.shape[2] == 0:
                    obs_history = torch.empty((obs_history.shape[0], 0, observation.shape[-1]), device=obs_history.device)
                obs_history = torch.cat((obs_history, observation.unsqueeze(0).unsqueeze(0)), dim=1)

        action = action[0]

        if is_atari:
            full_frame = np.copy(observation).reshape((84 * 4, 84))[:84].flatten()
        video_frames.append(full_frame)
        if act_horizon > 1:
            # Temporal ensembling of chunked actions (disabled: act_horizon is fixed at 1 above).
            # Shape [steps, steps + act_horizon, action size]
            pred_act_matrix[steps][steps:steps+act_horizon] = action

            # Shape [act_horizon, action_size]
            all_timestep_actions = pred_act_matrix[steps:steps+act_horizon, steps]

            DECAY = -0.3
            weights = np.exp(DECAY * np.arange(len(all_timestep_actions)))
            weights /= np.sum(weights)

            action = np.sum(all_timestep_actions * np.flip(weights, axis=0), axis=0)

            # Make room for the next action
            # Shape [steps, steps + act_horizon + 1, action size]
            pred_act_matrix = np.pad(pred_act_matrix, ((0, 0), (0, 1), (0, 0)), mode='constant')

            # Shape [steps + 1, steps + act_horizon + 1, action size]
            pred_act_matrix = np.pad(pred_act_matrix, ((0, 1), (0, 0), (0, 0)), mode='constant')

        if is_atari:
            action = np.argmax(action)
            observation, reward, terminated, truncated, info = env.step(action)[:5]
            done = terminated | truncated
        else:
            if is_robotwin:
                env.take_action(action)
                observation = env.get_obs()
                done = env.check_success()
                reward = 1.0 if done else 0.0
                done = done or steps >= 200
                info = env.info
            else:
                observation, reward, done, info = env.step(action)[:4]

            if config['name'] == "maze2d-umaze-v1":
                observation = np.hstack((env._target, observation))

        if env_name == "push_t":
            episode_reward = max(episode_reward, reward)
        else:
            episode_reward += reward
            # robosuite: stop at the first positive reward (sparse) or when done (shaped).
            if is_robosuite and ((not config.get('reward_shaping', False) and episode_reward > 0) or (config.get('reward_shaping', False) and done)):
                break

        steps += 1
        if dump_reward:
            rewards.append(reward)

    if len(video_frames) > 0:
        video_frames = np.array(video_frames)
        rgb_arrays_to_mp4(video_frames, f"data/{trial}.mp4", greyscale=is_atari)

    if dump_trial:
        tag = 'dan' if dan else 'bc'
        if debug_data is not None:
            if dump_reward:
                payload = rewards
                path = f"data/{trial}_{tag}_reward_{env_name}.pkl"
            else:
                payload = obs_history.detach().cpu().numpy()
                path = f"data/{trial}_{tag}_{env_name}.pkl"
        else:
            payload = obs_history.detach().cpu().numpy()
            path = f"data/{trial}.pkl"
        with open(path, 'wb') as f:
            pickle.dump(payload, f)

    success = 1 if 'success' in info else 0

    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    print(f"rank {local_rank} finished Trial {trial}")

    return episode_reward, success

def prepare_env(config, trial=None, reset=False, gpu_id=0):
    """Construct and seed the environment for one trial.

    Args:
        config: Environment config dict (reads ``name``, ``robosuite``,
            ``robotwin``).
        trial: Trial index used as the seed. robosuite envs are not seeded
            here; instead they are reset ``trial`` times so different trials
            start from different initial states. ``None`` means no extra resets.
        reset: Unused; kept for call compatibility.
        gpu_id: GPU index forwarded to ``construct_env``.

    Returns:
        The constructed environment.

    Raises:
        ImportError: If the env is a robosuite env and robosuite's EGL context
            module cannot be imported.
    """
    env_name = config['name']
    is_robosuite = config.get('robosuite', False)
    is_robotwin = config.get('robotwin', False)
    is_atari = env_name.startswith("ALE")

    try:
        import robosuite.renderers.context.egl_context as egl_context
    except ImportError:
        # Only robosuite envs strictly need robosuite; don't break other envs.
        if is_robosuite:
            raise
    else:
        # Clear robosuite's cached EGL display so this process presumably creates its own.
        egl_context.EGL_DISPLAY = None

    set_seed(42)
    env = construct_env(config, gpu_id=gpu_id, seed=trial)

    if not is_robosuite and not is_atari and not is_robotwin:
        env.seed(trial)
    elif is_robosuite:
        # No seeding for robosuite: advance the reset sequence instead.
        for _ in range(trial or 0):
            env.reset()

    if env_name == "maze2d-umaze-v1":
        env.set_target()

    return env

def env_worker(worker_id, trial, config, command_queue, result_queue, local_rank, reset):
    """Worker-process loop that owns one environment and serves commands.

    Commands read from ``command_queue`` are ``(cmd_type, data)`` tuples, or
    ``None`` to exit:

    * ``('init_trial', _)``: create the env on first use, or restore the saved
      initial state for robosuite / push_t. Replies
      ``(worker_id, 'env_created', observation)``.
    * ``('step', (action, step_num))``: apply ``action``. Replies
      ``(worker_id, 'step_result', {'observation', 'reward', 'done', 'info'})``.
      For non-robosuite envs the env is dropped once ``done``.
    * ``('get_frame', (camera, crop_corners, width, height))``: render one
      camera. Replies ``(worker_id, 'frame', frame)``.

    Any exception prints its traceback, replies
    ``(worker_id, 'error', str(exc))`` and ends the loop.

    Args:
        worker_id: Index of this worker; used in replies and to pick a CPU.
        trial: Trial index / seed for the environment.
        config: Environment config dict.
        command_queue: Queue this worker reads commands from.
        result_queue: Queue shared by all workers for replies.
        local_rank: GPU index; exported as ``CUDA_VISIBLE_DEVICES``.
        reset: Whether ``init_trial`` resets the env to obtain the first
            observation. It is cleared after the first robosuite / push_t init,
            so later trials restore the saved initial state instead.
    """
    import psutil
    import traceback

    p = psutil.Process()
    available_cpus = list(os.sched_getaffinity(0))
    # Pin this worker to a single CPU, round-robin over the allowed CPUs.
    p.cpu_affinity([available_cpus[worker_id % len(available_cpus)]])

    os.environ['CUDA_VISIBLE_DEVICES'] = f"{local_rank}"
    is_robosuite = config.get('robosuite', False)
    is_robotwin = config.get('robotwin', False)
    env = None
    initial_state = None
    logger.debug(f"{local_rank}:{worker_id} created")
    try:
        while True:
            command = command_queue.get()
            if command is None:
                break

            cmd_type, data = command

            if cmd_type == 'step':
                assert env is not None
                action, step_num = data
                if is_robotwin:
                    from robotwin.envs.utils.action import ArmTag, Action
                    # 7 values per arm: 6 arm values followed by 1 gripper value.
                    # The right arm mirrors the left (indices 7..12 and 13).
                    left_action = Action(ArmTag("left"), "move", action[:6], action[6])
                    right_action = Action(ArmTag("right"), "move", action[7:13], action[13])
                    env.move(actions_by_arm1=(left_action.arm_tag, [left_action]), actions_by_arm2=(right_action.arm_tag, [right_action]))
                    observation = env.get_obs()
                    done = env.check_success()
                    reward = 1.0 if done else 0.0
                    done = done or step_num >= 200
                    info = env.info
                else:
                    observation, reward, done, info = env.step(action)[:4]

                if config['name'] == "maze2d-umaze-v1":
                    observation = np.hstack((env._target, observation))

                if is_robosuite and not config.get('reward_shaping', False) and reward > 0:
                    done = True

                done = done or eval_over(step_num, config, env)

                if done and not is_robosuite:
                    # Free the env; the next init_trial rebuilds it.
                    del env
                    env = None

                result_queue.put((worker_id, 'step_result', {
                    'observation': observation,
                    'reward': reward,
                    'done': done,
                    'info': info
                }))
            elif cmd_type == 'init_trial':
                if is_robosuite and env is not None:
                    env.reset_to(initial_state)
                elif config['name'] == "push_t" and env is not None:
                    env._set_state(initial_state)
                else:
                    if env is not None:
                        del env

                    logger.debug(f"{local_rank}:{worker_id} creating env for the first time")
                    env = prepare_env(config, trial=trial, gpu_id=local_rank, reset=reset)
                    logger.debug(f"{local_rank}:{worker_id} done env")

                    # Snapshot the initial state so later trials can restore it.
                    if is_robosuite:
                        env.reset()
                        initial_state = env.get_state()
                        initial_state['model'] = env.env.sim.model.get_xml()
                        env.reset_to(initial_state)
                        reset = False
                    elif config['name'] == "push_t":
                        env.reset()
                        initial_state = env._get_obs()
                        env._set_state(initial_state)
                        reset = False

                if reset:
                    if is_robotwin:
                        observation = env.get_obs()
                    else:
                        observation = env.reset()

                    if config['name'].startswith("ALE") or config['name'] == "push_t":
                        observation = observation[0]
                else:
                    observation = env.get_observation()

                if config['name'] == "maze2d-umaze-v1":
                    observation = np.hstack((env._target, observation))

                result_queue.put((worker_id, 'env_created', observation))

            elif cmd_type == 'get_frame':
                camera, crop_corners, width, height = data
                if is_robotwin:
                    env_obs = env.get_obs()
                    frame = env_obs['observation'][camera]['rgb']
                    frame = cv2.resize(frame, (224, 224))
                else:
                    frame = env_to_rgb_array(env, camera, crop_corners, width, height)
                result_queue.put((worker_id, 'frame', frame))

    except Exception as e:
        # Keep the traceback visible; the parent only receives str(e).
        traceback.print_exc()
        result_queue.put((worker_id, 'error', str(e)))

#@profile
def batched_nn_eval(config, agent, trials=10, results=None, reset=False, dan=True):
    """Evaluate ``agent`` on ``trials`` episodes stepped in lock-step across worker processes.

    Trials are split across distributed ranks. Each rank keeps a
    ``PersistentProcessPool`` (one ``env_worker`` per local trial) cached in
    ``persistent_processes[local_rank]`` and reuses it across calls; the pool
    is rebuilt if the per-rank trial count changes. On every step, frames are
    gathered from the unfinished workers, actions for all of them are computed
    in one batched call, and the actions are sent back to be executed.

    Args:
        config: Environment config dict.
        agent: Policy queried through ``get_action_from_obs_batched``.
        trials: Total number of trials across all ranks.
        results: If given, rank 0 pickles the episode rewards to
            ``results/{results}.pkl``.
        reset: Unused; the workers are always created with reset enabled.
        dan: Whether ``agent`` is a DAN agent. With
            ``agent.retrieval_agent.lookback > 1`` an observation history is
            maintained and passed to the agent.

    Returns:
        Mean episode reward (over all ranks when running distributed).

    Raises:
        RuntimeError: If any worker reports an error.
    """
    global persistent_processes
    mp.set_start_method('spawn', force=True)
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    rank = int(os.environ.get("RANK", 0))
    set_seed(42)

    if world_size > 1 and not dist.is_initialized():
        dist.init_process_group(backend="nccl")
        torch.cuda.set_device(local_rank)

    trials_per_proc = trials // world_size
    remainder = trials % world_size

    my_num_trials = trials_per_proc + (1 if rank < remainder else 0)
    start_trial = rank * trials_per_proc + min(rank, remainder)
    end_trial = start_trial + my_num_trials

    logger.info(f"GPU {local_rank} taking trials {start_trial + 1} to {end_trial}")

    env_name = config['name']
    cam_names = config.get("cams", [])
    crops = config.get('crops', {})
    height, width = 224, 224
    use_obs_history = dan and agent.retrieval_agent.lookback > 1

    # [B, H, O]
    if use_obs_history:
        obs_history = torch.empty((my_num_trials, 0, 0), device=local_rank)

    pool = persistent_processes[local_rank]
    if pool is not None and len(pool.command_queues) != my_num_trials:
        # The cached workers were started for a different number of trials.
        logger.info("Trial count changed, recreating PersistentProcessPool")
        pool.shutdown()
        pool = None
    if pool is None:
        logger.info("Creating PersistentProcessPool")
        pool = PersistentProcessPool(config, start_trial, my_num_trials, local_rank)
        persistent_processes[local_rank] = pool

    command_queues = pool.command_queues
    result_queue = pool.result_queue

    def _get_reply(error_prefix):
        """Pop one worker reply; raise RuntimeError if it reports an error."""
        worker_id, msg_type, data = result_queue.get()
        if msg_type == 'error':
            logger.error(f"Error in worker {worker_id}: {data}")
            raise RuntimeError(f"{error_prefix} in worker {worker_id}: {data}")
        return worker_id, msg_type, data

    for queue in command_queues:
        queue.put(('init_trial', None))

    observations = [None] * my_num_trials
    envs_created = 0
    while envs_created < my_num_trials:
        worker_id, msg_type, data = _get_reply("Failed to create environment")
        if msg_type == 'env_created':
            observations[worker_id] = data
            envs_created += 1
            logger.debug(f"{rank} Environment {worker_id + 1 + start_trial}/{trials} created, {(envs_created/my_num_trials) * 100:.0f}% of my envs created")

    print(f"1 Allocated memory: {torch.cuda.memory_allocated() / (1024**2):.2f} MB")
    dones = np.zeros(my_num_trials).astype(bool)
    episode_rewards = np.zeros(my_num_trials)
    steps = 0

    if hasattr(agent, "reset_obs_history"):
        agent.reset_obs_history()

    while not (steps > 0 and np.all(dones)):
        frames = None
        if len(cam_names) > 0:
            for i, queue in enumerate(command_queues):
                if not dones[i]:
                    for camera in cam_names:
                        crop_corners = np.array(crops.get(camera, [[0, 0], [1.0, 1.0]]))
                        queue.put(('get_frame', (camera, crop_corners, width, height)))
            frames = [[] for _ in range(my_num_trials)]
            expected_frames = len(cam_names) * int(np.count_nonzero(~dones))
            for _ in range(expected_frames):
                worker_id, _, frame = _get_reply("Failed to render frame")
                frames[worker_id].append(frame)

            for i in range(my_num_trials):
                if frames[i]:
                    # Tile the per-camera frames side by side.
                    frames[i] = np.hstack(frames[i])

        active_envs = [i for i in range(my_num_trials) if not dones[i]]
        if not active_envs:
            break

        active_observations = [observations[i] for i in active_envs]
        active_frames = [frames[i] for i in active_envs] if frames else None

        with torch.no_grad():
            if use_obs_history:
                actions, new_obs_history = get_action_from_obs_batched(config, agent, active_envs, active_observations, active_frames, numpy_action=False, is_first_ob=(steps == 0), obs_history=obs_history[active_envs])
                if steps == 0:
                    obs_history = new_obs_history
                else:
                    # Append one step for every trial; finished trials get a zero row.
                    full_new_obs_history = torch.zeros((my_num_trials, 1, new_obs_history.shape[-1]), device=agent.device)
                    full_new_obs_history[active_envs] = new_obs_history[:, -1].unsqueeze(1)
                    obs_history = torch.cat((obs_history, full_new_obs_history), dim=1)
            else:
                actions, _ = get_action_from_obs_batched(config, agent, active_envs, active_observations, active_frames, numpy_action=False, is_first_ob=(steps == 0))

        for idx, action in zip(active_envs, actions):
            command_queues[idx].put(('step', (action, steps)))

        for _ in range(len(active_envs)):
            worker_id, _, result = _get_reply("Failed to step environment")
            observations[worker_id] = result['observation']
            if env_name == "push_t":
                episode_rewards[worker_id] = max(episode_rewards[worker_id], result['reward'])
            else:
                episode_rewards[worker_id] += result['reward']
            dones[worker_id] = result['done']

        steps += 1

    if world_size > 1:
        all_rewards = [None for _ in range(world_size)]
        dist.all_gather_object(all_rewards, episode_rewards)
        episode_rewards = [r for proc_rewards in all_rewards for r in proc_rewards]

    # Save results (only on rank 0)
    if rank == 0:
        print(episode_rewards)
        if results is not None:
            os.makedirs('results', exist_ok=True)
            with open(f"results/{results}.pkl", 'wb') as f:
                pickle.dump(episode_rewards, f)

        logger.info(f"mean {round(np.mean(episode_rewards), 2)}, std {round(np.std(episode_rewards), 2)}")

    # Wait for all processes
    if world_size > 1:
        dist.barrier()

    print(f"2 Allocated memory: {torch.cuda.memory_allocated() / (1024**2):.2f} MB")
    return np.mean(episode_rewards)

#@profile
def parallel_nn_eval(config, nn_agent, trials=10, results=None, dan=False, split_agent=None, dump_trial=False, debug=True):
    """Evaluate ``nn_agent`` sequentially on this rank's share of ``trials`` episodes.

    Args:
        config: Environment config dict.
        nn_agent: Policy evaluated by ``single_trial_eval``.
        trials: Total number of trials across all ranks (ignored in debug mode).
        results: If given, rank 0 pickles the episode rewards to
            ``results/{results}.pkl``.
        dan: Forwarded to ``single_trial_eval``.
        split_agent: Optional fork agent forwarded to ``single_trial_eval``.
        dump_trial: Forwarded to ``single_trial_eval`` (non-debug mode only).
        debug: If true (the default), load ``data/hopper-expert-v2_1.pkl`` and
            run only trials 0, 4 and 6 with state replay from that data and
            ``dump_trial=True``. Pass ``False`` for a normal evaluation of this
            rank's trial range with fresh resets.

    Returns:
        Mean episode reward (over all ranks when running distributed).
    """
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    rank = int(os.environ.get("RANK", 0))

    if world_size > 1 and not dist.is_initialized():
        dist.init_process_group(backend="nccl")
        torch.cuda.set_device(local_rank)
        nn_agent = nn_agent.to(local_rank)

    set_seed(42)

    debug_data = None
    if debug:
        with open("data/hopper-expert-v2_1.pkl", 'rb') as f:
            debug_data = pickle.load(f)

    # Divide trials among processes; the first `remainder` ranks take one extra.
    trials_per_proc = trials // world_size
    remainder = trials % world_size
    my_num_trials = trials_per_proc + (1 if rank < remainder else 0)
    start_trial = rank * trials_per_proc + min(rank, remainder)

    env = construct_env(config, gpu_id=local_rank, seed=start_trial)

    # No seeding, so this is the only way to make sure we don't do repeat seeds
    if config.get('robosuite', False):
        for _ in range(start_trial):
            env.reset()

    end_trial = start_trial + my_num_trials

    logger.info(f"GPU {local_rank} taking trials {start_trial + 1} to {end_trial}")

    episode_rewards = []
    successes = 0

    if debug:
        # Per-trial fork step for split_agent (999: no fork unless the episode reaches step 999).
        split_steps = [999] * 10
        for trial in [0, 4, 6]:
            episode_reward, success = single_trial_eval(config, nn_agent, env, trial, reset=False, dan=dan, debug_data=debug_data, split_agent=split_agent, dump_trial=True, debug_stop=1, split_step=[split_steps[trial]], dump_reward=False)
            episode_rewards.append(episode_reward)
            successes += success
    else:
        for trial in range(start_trial, end_trial):
            episode_reward, success = single_trial_eval(config, nn_agent, env, trial, reset=True, dan=dan, debug_data=None, split_agent=split_agent, dump_trial=dump_trial)
            episode_rewards.append(episode_reward)
            successes += success

    # Gather results from all processes
    if world_size > 1:
        all_rewards = [None for _ in range(world_size)]
        all_successes = [None for _ in range(world_size)]

        dist.all_gather_object(all_rewards, episode_rewards)
        dist.all_gather_object(all_successes, successes)

        episode_rewards = [r for proc_rewards in all_rewards for r in proc_rewards]
        successes = sum(all_successes)

    # Save results (only on rank 0)
    if rank == 0:
        if results is not None:
            os.makedirs('results', exist_ok=True)
            with open(f"results/{results}.pkl", 'wb') as f:
                pickle.dump(episode_rewards, f)

        logger.debug(episode_rewards)
        logger.info(f"mean {round(np.mean(episode_rewards), 2)}, std {round(np.std(episode_rewards), 2)}")

    # Wait for all processes
    if world_size > 1:
        dist.barrier()

    return np.mean(episode_rewards)

def nn_eval(config, nn_agent, trials=10, results=None, dan=None,
            sanity_model=None, sanity_data_path=None):
    """Evaluate ``nn_agent`` for ``trials`` episodes in a single process.

    Args:
        config: environment config passed to ``construct_env`` and
            ``single_trial_eval``.
        nn_agent: agent under evaluation; its ``candidates``, ``lookback``,
            ``decay`` and ``final_neighbors_ratio`` attributes are printed in
            the summary line.
        trials: number of episodes; trial indices ``0..trials-1`` are used.
        results: if not None, the list of episode rewards is pickled to
            ``results/<results>.pkl``.
        dan: forwarded unchanged to ``single_trial_eval`` as ``dan``.
        sanity_model: model passed to ``initial_sanity``. Only used when
            ``sanity_data_path`` is given.
        sanity_data_path: optional path to a pickle of recorded trajectories.
            When given, ``initial_sanity`` is run before evaluation.

    Returns:
        Mean episode reward over all trials.
    """
    env = construct_env(config)

    # initial_sanity(env, model, comp_data_path, config, ...) needs both a model
    # and a data path. The old unconditional call passed the retrieval dataset in
    # the model slot and omitted config (TypeError), so the check is now opt-in.
    if sanity_data_path is not None:
        sane = initial_sanity(env, sanity_model, sanity_data_path, config)
        # Only report a verdict when initial_sanity actually returns one;
        # a None result carries no pass/fail information.
        if sane is not None:
            if sane:
                print("Initial observation matches expected")
            else:
                print("Initial observation is not expected!")

    episode_rewards = []
    for trial in range(trials):
        # single_trial_eval requires `reset` and `dan`. reset=True matches the
        # non-debug multi-process path (reset=not DEBUG).
        # NOTE(review): confirm single_trial_eval tolerates dan=None.
        episode_reward, _ = single_trial_eval(config, nn_agent, env, trial, reset=True, dan=dan)
        episode_rewards.append(episode_reward)

    if results is not None:
        os.makedirs('results', exist_ok=True)
        with open(f"results/{results}.pkl", 'wb') as f:
            pickle.dump(episode_rewards, f)
    print(
        f"Candidates {nn_agent.candidates}, lookback {nn_agent.lookback}, decay {nn_agent.decay}, ratio {nn_agent.final_neighbors_ratio}: "
        f"mean {round(np.mean(episode_rewards), 2)}, std {round(np.std(episode_rewards), 2)}"
    )

    return np.mean(episode_rewards)

def initial_sanity(env, model, comp_data_path, config, obs_scaler=None):
    """Replay recorded states in ``env`` and compare regenerated observations.

    For every trajectory in the pickle at ``comp_data_path``:
    - the environment is reset to the trajectory's first state and model file;
    - each recorded observation is regenerated with ``get_processed_obs``;
    - the regenerated observation is compared with the stored one by summed
      absolute difference.

    One line is printed per observation. Between observations the env is
    restored to the next recorded state. After the last one, the last recorded
    action is stepped.

    Args:
        env: environment exposing ``reset``, ``reset_to``, ``step`` and
            ``get_observation``.
        model: forwarded to ``get_processed_obs``.
        comp_data_path: path to a pickle whose entries are indexed
            ``0..len-1`` and carry ``'states'``, ``'model_file'``,
            ``'observations'`` and ``'actions'``.
        config: env config; ``'cams'``, ``'crops'`` and ``'type'`` are read.
        obs_scaler: optional object whose ``inverse_transform`` is applied to
            each stored observation before comparison.

    Returns:
        True if every observation matched within 1e-8, False otherwise
        (True for an empty dataset).
    """
    cam_names = config.get("cams", [])
    crops = config.get('crops', {})

    # NOTE: pickle.load can execute arbitrary code; only use trusted files.
    with open(comp_data_path, 'rb') as f:
        comp_data = pickle.load(f)

    height, width = 224, 224
    all_match = True
    for traj in range(len(comp_data)):
        traj_data = comp_data[traj]
        observations = traj_data['observations']
        num_obs = len(observations)

        initial_state = dict(states=traj_data['states'][0])
        initial_state["model"] = traj_data["model_file"]
        env.reset()
        env.reset_to(initial_state)

        for ob in range(num_obs):
            # Concatenate the rendered (cropped) camera views side by side;
            # stays a zero-width frame when no cameras are configured.
            full_frame = np.empty((height, 0, 3), dtype=np.uint8)
            if len(cam_names) > 0:
                for camera in cam_names:
                    crop_corners = np.array(crops.get(camera, [[0, 0], [1.0, 1.0]]))
                    frame = env_to_rgb_array(env, camera, crop_corners, width, height)
                    full_frame = np.hstack((full_frame, frame))

            # NOTE(review): is_first_ob=True is passed for every observation,
            # presumably because each one is produced from a freshly restored
            # state -- confirm.
            processed_obs = get_processed_obs(env.get_observation(), full_frame, env, model, config, config['type'], is_first_ob=True)
            assert processed_obs is not None
            ground_truth = observations[ob]
            if obs_scaler:
                ground_truth = obs_scaler.inverse_transform(ground_truth)
            diff = torch.sum(torch.abs(ground_truth - processed_obs.detach().cpu()))
            if diff < 1e-8:
                print(f"Sanity observation good within 1e-8! ({diff})")
            else:
                print(f"Sanity observations differ by {diff}!")
                all_match = False

            if ob != num_obs - 1:
                next_state = dict(states=traj_data['states'][ob + 1])
                env.reset_to(next_state)
            else:
                env.step(traj_data['actions'][-1])

    return all_match

# Set initial state to an initial state from the training dataset
def nn_eval_sanity(config, nn_agent, data, comp_data, results=None, dan=None):
    """Run one episode per trajectory in ``data``, starting from its first recorded state.

    After each episode, the observation pickled to ``debug_obs.pkl`` is compared
    with the first observation of the matching ``comp_data`` trajectory, and a
    match/mismatch line is printed.

    Args:
        config: environment config passed to ``construct_env`` and
            ``single_trial_eval``.
        nn_agent: agent under evaluation; hyper-parameters are printed in the
            summary line.
        data: indexable trajectories; ``data[i]['states'][0]`` and
            ``data[i]['model_file']`` define episode ``i``'s initial state.
        comp_data: indexable trajectories whose ``['observations'][0]`` is the
            expected first observation.
        results: if not None, episode rewards are pickled to
            ``results/<results>.pkl``.
        dan: forwarded unchanged to ``single_trial_eval`` as ``dan``.

    Returns:
        Mean episode reward over all trajectories.
    """
    env = construct_env(config)
    episode_rewards = []

    for trial in range(len(data)):
        # Build the initial state the same way initial_sanity does: first
        # recorded simulator state plus the trajectory's model file.
        # (Previously `initial_state` was used without being defined.)
        initial_state = dict(states=data[trial]['states'][0])
        initial_state["model"] = data[trial]["model_file"]
        env.reset_to(initial_state)
        # reset=False so the episode presumably starts from the state set above.
        episode_reward, _ = single_trial_eval(config, nn_agent, env, trial, reset=False, dan=dan)

        # NOTE(review): assumes single_trial_eval writes debug_obs.pkl during
        # the episode -- confirm.
        with open("debug_obs.pkl", 'rb') as f:
            debug_obs = pickle.load(f)
        if not (debug_obs == comp_data[trial]['observations'][0]).all():
            print("OBSERVATIONS DON'T MATCH")
        else:
            print("OBSERVATIONS MATCH")

        episode_rewards.append(episode_reward)

    if results is not None:
        os.makedirs('results', exist_ok=True)
        with open(f"results/{results}.pkl", 'wb') as f:
            pickle.dump(episode_rewards, f)
    print(
        f"Candidates {nn_agent.candidates}, lookback {nn_agent.lookback}, decay {nn_agent.decay}, ratio {nn_agent.final_neighbors_ratio}: "
        f"mean {round(np.mean(episode_rewards), 2)}, std {round(np.std(episode_rewards), 2)}"
    )

    return np.mean(episode_rewards)

