import numpy as np
from rlutil.logging import logger
import rlutil.torch as torch
import rlutil.torch.pytorch_util as ptu
import torch
import time
import tqdm
import os.path as osp
import copy
import pickle
import seaborn as sns
from huge.algo import buffer, networks
import matplotlib.cm as cm
import os
from datetime import datetime
import shutil
from huge.envs.room_env import PointmassGoalEnv
from huge.envs.sawyer_push import SawyerPushGoalEnv
from huge.envs.sawyer_push_hard import SawyerHardPushGoalEnv
from huge.envs.kitchen_simplified_state_space import KitchenGoalEnv

from huge.envs.kitchen_env_sequential import KitchenSequentialGoalEnv
from huge.envs.kitchen_env_3d import Kitchen3DGoalEnv

import wandb
import skvideo.io

#from gcsl.envs.kitchen_env import KitchenGoalEnv

# TensorBoard logging is optional: remember whether the writer is importable.
try:
    from torch.utils.tensorboard import SummaryWriter
    tensorboard_enabled = True
except Exception:
    # ``Exception`` (instead of a bare ``except``) keeps the best-effort
    # fallback for any import-time failure while still letting
    # KeyboardInterrupt / SystemExit propagate.
    print('Tensorboard not installed!')
    tensorboard_enabled = False

import tkinter
import matplotlib
matplotlib.use('TkAgg')
import matplotlib.pyplot as plt

# Compute device used for the models and tensors created in this module.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Preference most recently chosen through the matplotlib buttons
# (0 or 1); written by the ``Index`` callbacks.
curr_label = 0
class Index:
    """Click handlers for the preference buttons.

    Each handler stores its choice in the module-level ``curr_label`` and
    then closes the current matplotlib figure.
    """

    @staticmethod
    def _record(choice):
        """Store ``choice`` in ``curr_label`` and close the current figure."""
        global curr_label
        curr_label = choice
        plt.close()

    def first(self, event):
        """Select label 0 (``event`` is ignored)."""
        self._record(0)

    def second(self, event):
        """Select label 1 (``event`` is ignored)."""
        self._record(1)

#TODO: missing to dump trajectories

# New version GCSL with preferences
# Sample random goals
# Search on the buffer the set of achieved goals and pick up the closest achieved goal
# Launch batch of trajectories with all new achieved goals 
# we can launch one batch without exploration, just to reinforce stopping at the point and then another one with exploration
# add all trajectories to the buffer
# train standard GCSL
# THIS SHOULD WORK BY 11am, 12pm we have positive results on the 2d point environment

class GCSL:
    """Goal-conditioned Supervised Learning (GCSL).

    Parameters:
        env: A gcsl.envs.goal_env.GoalEnv
        policy: The policy to be trained (likely from gcsl.algo.networks)
        replay_buffer: The replay buffer where data will be stored
        validation_buffer: If provided, then 20% of sampled trajectories will
            be stored in this buffer, and used to compute a validation loss
        max_timesteps: int, The number of timesteps to run GCSL for.
        max_path_length: int, The length of each trajectory in timesteps

        # Exploration strategy
        
        explore_timesteps: int, The number of timesteps to explore randomly
        expl_noise: float, The noise to use for standard exploration (eps-greedy)

        # Evaluation / Logging Parameters

        goal_threshold: float, The distance at which a trajectory is considered
            a success. Only used for logging, and not the algorithm.
        eval_freq: int, The policy will be evaluated every k timesteps
        eval_episodes: int, The number of episodes to collect for evaluation.
        save_every_iteration: bool, If True, policy and buffer will be saved
            for every iteration. Use only if you have a lot of space.
        log_tensorboard: bool, If True, log Tensorboard results as well

        # Policy Optimization Parameters
        
        start_policy_timesteps: int, The number of timesteps after which
            GCSL will begin updating the policy
        batch_size: int, Batch size for GCSL updates
        n_accumulations: int, If desired batch size doesn't fit, use
            this many passes. Effective batch_size is n_acc * batch_size
        policy_updates_per_step: float, Perform this many gradient updates for
            every environment step. Can be fractional.
        train_policy_freq: int, How frequently to actually do the gradient updates.
            Number of gradient updates is dictated by `policy_updates_per_step`
            but when these updates are done is controlled by train_policy_freq
        lr: float, Learning rate for Adam.
        demonstrations_kwargs: Arguments specifying pretraining with demos.
            See GCSL.pretrain_demos for exact details of parameters.
    """
    def __init__(self,
        env,
        policy,
        reward_model,
        replay_buffer,
        reward_model_buffer,
        validation_buffer=None,
        max_timesteps=1e6,
        max_path_length=50,
        # Exploration Strategy
        explore_timesteps=1e4,
        expl_noise=0.1,
        # Evaluation / Logging
        goal_threshold=0.05,
        eval_freq=5e3,
        eval_episodes=200,
        save_every_iteration=False,
        log_tensorboard=False,
        # Policy Optimization Parameters
        start_policy_timesteps=0,
        batch_size=100,
        n_accumulations=1,
        policy_updates_per_step=1,
        train_policy_freq=None,
        hallucinate_policy_freq=None,
        demonstrations_kwargs=dict(),
        train_with_hallucination=True,
        lr=5e-4,
        rewardmodel_epochs = 300,
        train_rewardmodel_freq = 10,#5000,
        display_trajectories_freq = 15,
        use_oracle=False,
        exploration_horizon=30,
        expanding_horizon=False,
        reward_model_num_samples=100,
        comment="",
        select_best_sample_size = 1000,
        load_buffer=False,
        save_buffer=-1,
        rewardmodel_batch_size = 1000,
        train_regression = False,
        load_rewardmodel=False,
        render=False,
        sample_softmax = False,
        display_plots=False,
        data_folder="data",
        clip=5,
        stop_training_rewardmodel_steps = 2e6,
        remove_last_steps_when_stopped = True,
        exploration_when_stopped = True,
        distance_noise_std = 0.0,
        save_videos=True,
        logger_dump=False,
        stop_rewardmodel=-1,
        human_input=False,
        num_envs = 1,
        epsilon_greedy=0.2,
        set_desired_when_stopped=True,
        last_k_steps=10, # steps to look into for checking whether it stopped
        explore_length=10,
    ):
        """Store the configuration, build the optimizers and create output folders.

        Most arguments are stored verbatim on the instance under the same
        name; see the class docstring for the documented ones. In addition
        this constructor:

        * keeps a deep copy of ``policy`` as ``self.random_policy``;
        * creates an Adam optimizer for the policy (learning rate ``lr``) and,
          when ``use_oracle`` is False, one for ``reward_model`` (default
          Adam settings);
        * uses ``self.oracle_model`` as the reward model when ``use_oracle``
          is True, unless ``load_rewardmodel`` is set, in which case the
          weights in ``reward_model.pth`` are loaded into ``reward_model``;
        * moves the policy (and the learned reward model) to the module-level
          ``device``;
        * derives ``self.env_name`` from the type of ``env.wrapped_env``;
        * creates ``data_folder`` and ``data_folder/eval_trajectories``.

        ``demonstrations_kwargs`` is accepted but not read here.
        """
        # --- Exploration-after-stopping settings -------------------------
        self.epsilon_greedy=epsilon_greedy
        self.set_desired_when_stopped=set_desired_when_stopped
        self.last_k_steps=last_k_steps # steps to look into for checking whether it stopped
        self.explore_length=explore_length

        # --- Environment and policies -------------------------------------
        self.num_envs = num_envs
        self.env = env
        self.policy = policy
        self.random_policy = copy.deepcopy(policy)

        self.stop_training_rewardmodel_steps = stop_training_rewardmodel_steps
        self.rewardmodel_batch_size = rewardmodel_batch_size
        self.train_regression = train_regression

        # NOTE: loading of a recorded human-label dataset into
        # ``self.human_data`` is currently disabled:
        #with open(f'human_dataset_06_10_2022_20:15:53.pickle', 'rb') as handle:
        #    self.human_data = pickle.load(handle)
        #    print(len(self.human_data))

        self.total_timesteps = 0

        # --- Buffer persistence ---------------------------------------------
        self.buffer_filename = "buffer_saved.csv"
        self.val_buffer_filename = "val_buffer_saved.csv"
        self.data_folder = data_folder

        self.exploration_when_stopped = exploration_when_stopped
        self.load_buffer = load_buffer
        self.save_buffer = save_buffer

        self.stop_rewardmodel = stop_rewardmodel

        self.comment = comment
        self.display_plots = display_plots
        self.lr = lr
        self.clip = clip
        self.evaluate_reward_model = True

        self.reward_model_buffer = reward_model_buffer

        self.select_best_sample_size = select_best_sample_size

        self.store_model = False

        self.num_labels_queried = 0
        self.save_videos = save_videos

        self.load_rewardmodel = load_rewardmodel

        self.remove_last_steps_when_stopped = remove_last_steps_when_stopped

        self.train_with_hallucination = train_with_hallucination
        self.replay_buffer = replay_buffer
        self.validation_buffer = validation_buffer

        # Discrete action spaces expose ``n``.
        self.is_discrete_action = hasattr(self.env.action_space, 'n')

        self.max_timesteps = max_timesteps
        self.max_path_length = max_path_length

        self.explore_timesteps = explore_timesteps
        self.expl_noise = expl_noise
        self.render = render
        self.goal_threshold = goal_threshold
        self.eval_freq = eval_freq
        self.eval_episodes = eval_episodes
        self.save_every_iteration = save_every_iteration

        self.reward_model_num_samples = reward_model_num_samples

        self.start_policy_timesteps = start_policy_timesteps

        self.train_rewardmodel_freq = train_rewardmodel_freq
        self.display_trajectories_freq = display_trajectories_freq

        self.human_exp_idx = 0
        self.distance_noise_std = distance_noise_std

        # Default: one round of policy training per collected trajectory.
        if train_policy_freq is None:
            self.train_policy_freq = self.max_path_length
        else:
            self.train_policy_freq = train_policy_freq

        if hallucinate_policy_freq is None:
            hallucinate_policy_freq = self.max_path_length*300

        self.hallucinate_policy_freq = hallucinate_policy_freq

        # --- Policy optimisation --------------------------------------------
        self.batch_size = batch_size
        self.n_accumulations = n_accumulations
        self.policy_updates_per_step = policy_updates_per_step
        self.policy_optimizer = torch.optim.Adam(self.policy.parameters(), lr=lr)

        self.log_tensorboard = log_tensorboard and tensorboard_enabled
        self.summary_writer = None

        self.exploration_horizon = exploration_horizon

        self.logger_dump = logger_dump

        # Accumulates every human-labelled comparison for dumping to disk.
        self.dict_labels = {
            'state_1': [],
            'state_2': [],
            'label': [],
            'goal':[],
        }
        now = datetime.now()
        self.dt_string = now.strftime("%d_%m_%Y_%H:%M:%S")

        # --- Reward model -----------------------------------------------------
        self.use_oracle = use_oracle
        if self.use_oracle:
            self.reward_model = self.oracle_model
            if load_rewardmodel:
                self.reward_model = reward_model
                self.reward_model.load_state_dict(torch.load("reward_model.pth"))
        else:
            self.reward_model = reward_model
            if load_rewardmodel:
                self.reward_model.load_state_dict(torch.load("reward_model.pth"))
            self.reward_optimizer = torch.optim.Adam(list(self.reward_model.parameters()))
            self.reward_model.to(device)

        self.policy.to(device)

        self.rewardmodel_epochs = rewardmodel_epochs

        # Keep ``self.device`` a plain string, but follow the device that was
        # actually selected at module level instead of assuming CUDA exists.
        self.device = device.type

        self.expanding_horizon = expanding_horizon

        self.sample_softmax = sample_softmax

        self.human_input = human_input

        # --- Logging accumulators -----------------------------------------
        self.traj_num_file = 0
        self.collected_trajs_dump = []
        self.success_ratio_eval_arr = []
        self.train_loss_arr = []
        self.distance_to_goal_eval_arr = []
        self.success_ratio_relabelled_arr = []
        self.eval_trajectories_arr = []
        self.train_loss_rewardmodel_arr = []
        self.eval_loss_arr = []
        self.distance_to_goal_eval_relabelled = []

        # Independent ``if``s on purpose: when several checks match, the last
        # matching one wins. ``env_name`` stays unset if none matches.
        if isinstance(self.env.wrapped_env, PointmassGoalEnv):
            self.env_name = "pointmass"
        if isinstance(self.env.wrapped_env, SawyerPushGoalEnv):
            self.env_name ="pusher"
        if isinstance(self.env.wrapped_env, SawyerHardPushGoalEnv):
            self.env_name ="pusher_hard"
        if isinstance(self.env.wrapped_env, KitchenGoalEnv):
            self.env_name ="kitchen"
        if isinstance(self.env.wrapped_env, Kitchen3DGoalEnv):
            self.env_name ="kitchen3D"
        if isinstance(self.env.wrapped_env, KitchenSequentialGoalEnv):
            self.env_name ="kitchenSeq"

        os.makedirs(self.data_folder, exist_ok=True)
        os.makedirs(os.path.join(self.data_folder, 'eval_trajectories'), exist_ok=True)


    def contrastive_loss(self, pred, label):
        label = label.float()
        pos = label@torch.clamp(pred[:,0]-pred[:,1], min=0)
        neg = (1-label)@torch.clamp(pred[:,1]-pred[:,0], min=0)

        #print("pos shape", pos.shape)
        return  pos + neg
    
    def eval_rewardmodel(self, eval_data, batch_size=32):
        """Estimate the preference loss and accuracy of the reward model.

        Args:
            eval_data: tuple ``(achieved_states_1, achieved_states_2, goals,
                labels)`` of index-able arrays with matching first dimension.
            batch_size: number of comparisons drawn per batch.

        Returns:
            ``(mean_loss, accuracy)``: the cross-entropy averaged over the
            evaluated batches and the fraction of comparisons whose argmax
            matches the label.

        Batches are drawn at random with replacement
        (``len(goals) // batch_size + 1`` of them), so this is a stochastic
        estimate rather than an exact pass over ``eval_data``.
        """
        achieved_states_1, achieved_states_2, goals, labels = eval_data

        num_batches = len(goals) // batch_size + 1
        loss_fn = torch.nn.CrossEntropyLoss()

        mean_loss = 0.0
        total_samples = 0
        accuracy = 0
        # Pure evaluation: no autograd graph is needed.
        with torch.no_grad():
            for _ in range(num_batches):
                t_idx = np.random.randint(len(goals), size=(batch_size,))

                state1 = torch.Tensor(achieved_states_1[t_idx]).to(device)
                state2 = torch.Tensor(achieved_states_2[t_idx]).to(device)
                goal = torch.Tensor(goals[t_idx]).to(device)
                label_t = torch.Tensor(labels[t_idx]).long().to(device)

                # One score per state; the pair of scores acts as 2-way logits.
                g1g2 = torch.cat([self.reward_model(state1, goal), self.reward_model(state2, goal)], axis=-1)
                loss = loss_fn(g1g2, label_t)
                pred = torch.argmax(g1g2, dim=-1)
                accuracy += torch.sum(pred == label_t)
                total_samples += len(label_t)
                mean_loss += loss.item()

        mean_loss /= num_batches
        accuracy = accuracy.cpu().numpy() / total_samples

        return mean_loss, accuracy

    # TODO: try train regression on it
    def train_rewardmodel_regression(self,device, eval_data=None, batch_size=32, num_epochs=400):
        # Train standard goal conditioned policy

        loss_fn = torch.nn.MSELoss() 
        losses_eval = []

        self.reward_model.train()
        running_loss = 0.0
        
        # Train the model with regular SGD
        for epoch in range(num_epochs):  # loop over the dataset multiple times
            start = time.time()
            
            achieved_states, _,  goals ,distance = self.reward_model_buffer.sample_batch(batch_size)
            
            self.reward_optimizer.zero_grad()

            t_idx = np.random.randint(len(goals), size=(batch_size,)) # Indices of first trajectory
            
            state = torch.Tensor(achieved_states[t_idx]).to(device)
            goal = torch.Tensor(goals[t_idx]).to(device)
            dist_t = torch.Tensor(distance[t_idx]).to(device).float()
            pred = self.reward_model(state, goal)
            loss = loss_fn(pred, dist_t)
            loss.backward()
            torch.nn.utils.clip_grad_norm(self.reward_model.parameters(), 5)
            self.reward_optimizer.step()

            # print statistics
            running_loss += float(loss.item())
            
            #if epoch % 10 == 0 and epoch > 0:
                #losses_eval, acc_eval = self.eval_rewardmodel(batch_size)
        
                #print("Accuracy eval is ", acc_eval)
                # print('[%d, %5d] loss: %.8f' %
                #     (epoch + 1, i + 1, running_loss / 100.))
        #if eval_data is not None:
        #    eval_loss, _ = self.eval_rewardmodel(eval_data, batch_size)
        #    losses_eval.append(eval_loss)
        return running_loss/batch_size, 0#, (losses_eval, acc_eval)

    def train_rewardmodel(self,device, eval_data=None, batch_size=32, num_epochs=400):
        # Train standard goal conditioned policy

        loss_fn = torch.nn.CrossEntropyLoss() 
        losses_eval = []

        self.reward_model.train()
        running_loss = 0.0
        
        # Train the model with regular SGD
        for epoch in range(num_epochs):  # loop over the dataset multiple times
            start = time.time()
            
            achieved_states_1, achieved_states_2, goals ,labels = self.reward_model_buffer.sample_batch(batch_size)
            
            self.reward_optimizer.zero_grad()

            t_idx = np.random.randint(len(goals), size=(batch_size,)) # Indices of first trajectory
            
            state1 = torch.Tensor(achieved_states_1[t_idx]).to(device)
            state2 = torch.Tensor(achieved_states_2[t_idx]).to(device)
            goal = torch.Tensor(goals[t_idx]).to(device)
            label_t = torch.Tensor(labels[t_idx]).long().to(device)

            g1g2 = torch.cat([self.reward_model(state1, goal), self.reward_model(state2, goal)], axis=-1)
            loss = loss_fn(g1g2, label_t)
            loss.backward()
            self.reward_optimizer.step()

            # print statistics
            running_loss += float(loss.item())
            
            #if epoch % 10 == 0 and epoch > 0:
                #losses_eval, acc_eval = self.eval_rewardmodel(batch_size)
        
                #print("Accuracy eval is ", acc_eval)
                # print('[%d, %5d] loss: %.8f' %
                #     (epoch + 1, i + 1, running_loss / 100.))
        #if eval_data is not None:
        #    eval_loss, _ = self.eval_rewardmodel(eval_data, batch_size)
        #    losses_eval.append(eval_loss)
        return running_loss, 0#, (losses_eval, acc_eval)


    def get_closest_achieved_state(self, goal_candidates, device, use_oracle=False):
        reached_state_idxs = []
        
        observations, actions, goals, _, horizons, weights = self.replay_buffer.sample_batch_last_steps(self.select_best_sample_size)

        #print("observations 0", observations[0])
        achieved_states = self.env.observation(observations)
        #print("achieved states", achieved_states[0])

        request_goals = []

        for goal_candidate in goal_candidates:
            
            state_tensor = torch.Tensor(achieved_states).to(device)
            goal_tensor = torch.Tensor(np.repeat(goal_candidate[None], len(achieved_states), axis=0)).to(device)

            if use_oracle:
                reward_vals = self.oracle_model(state_tensor, goal_tensor).cpu().detach().numpy()
                self.num_labels_queried += 1
            else:
                reward_vals = self.reward_model(state_tensor, goal_tensor).cpu().detach().numpy()
            
            if self.sample_softmax:
                best_idx = torch.distributions.Categorical(logits=torch.tensor(reward_vals.reshape(-1))).sample()
            else:
                best_idx = reward_vals.argmax()

            request_goals.append(achieved_states[best_idx])

        request_goals = np.array(request_goals)

        return request_goals

    def env_distance(self, state, goal):
        """Return the environment's shaped distance between ``state`` and ``goal``.

        ``state`` is first mapped through ``env.observation``. For a
        ``PointmassGoalEnv`` the distance comes from the room object of the
        wrapped base environment; every other environment uses its own
        ``get_shaped_distance``.
        """
        obs = self.env.observation(state)
        is_pointmass = isinstance(self.env.wrapped_env, PointmassGoalEnv)
        distance_source = self.env.wrapped_env.base_env.room if is_pointmass else self.env
        return distance_source.get_shaped_distance(obs, goal)
    def oracle_model(self, state, goal):
        state = state.detach().cpu().numpy()

        goal = goal.detach().cpu().numpy()

        dist = [
            self.env_distance(state[i], goal[i]) + np.random.normal(scale=self.distance_noise_std)
            for i in range(goal.shape[0])
        ] #- np.linalg.norm(state - goal, axis=1)

        scores = - torch.tensor(np.array([dist])).T
        return scores
        
    # TODO: generalise this
    def oracle(self, state1, state2, goal):
        """Return 0 if ``state1`` is strictly closer to ``goal``, else 1.

        Each distance is ``env_distance`` plus independent Gaussian noise of
        standard deviation ``self.distance_noise_std``; ties go to label 1.
        """
        noisy_first = self.env_distance(state1, goal) + np.random.normal(scale=self.distance_noise_std)
        noisy_second = self.env_distance(state2, goal) + np.random.normal(scale=self.distance_noise_std)
        return 0 if noisy_first < noisy_second else 1

    def generate_pref_labels_regression(self, goal_states, extract=False):
        observations_1, _, _, _, _, _ = self.replay_buffer.sample_batch(self.reward_model_num_samples) # TODO: add
   
        goals = []
        labels = []
        achieved_state = []

        # TODO: remove
        #goal_states = np.array([[0.3,0.3]])
        num_goals = len(goal_states)
        for state_1 in observations_1:
            for goal in goal_states:
                if extract:
                    goal = self.env.extract_goal(goal)
                labels.append(self.env_distance(state_1, goal)) # oracle TODO: we will use human labels here

                achieved_state.append(state_1) 
                goals.append(goal)

        achieved_state = np.array(achieved_state)
        goals = np.array(goals)
        labels = np.array(labels)
        
        return achieved_state, achieved_state, goals, labels # TODO: check ordering
    def display_wall_fig(self, fig, ax):
        walls = self.env.base_env.room.get_walls()
        for wall in walls:
            start, end = wall
            sx, sy = start
            ex, ey = end
            ax.plot([sx, ex], [sy, ey], marker='o',  color = 'b')

    def stop_training_rewardmodel(self, event):
        """Callback that sets ``self.stop_rewardmodel`` to 0.

        ``event`` is whatever the caller passes (a button click when wired up
        in ``ask_human_labels``) and is ignored.
        """
        self.stop_rewardmodel = 0
    def ask_human_labels(self, state1, state2, goal):
        """Return a preference label (0 or 1) for ``state1`` vs ``state2``.

        While pre-recorded labels remain in ``self.human_data['label']`` they
        are replayed in order (advancing ``self.human_exp_idx``). Otherwise a
        matplotlib window shows the walls, ``state1`` in blue, ``state2`` in
        red and the goal in green, and blocks until it is closed:

        * 'Blue' sets the label to 0 and closes the figure,
        * 'Red' sets the label to 1 and closes the figure,
        * 'black' calls ``stop_training_rewardmodel``.

        If the window is closed without picking Blue/Red the label is 0.
        """
        global curr_label

        # ``human_data`` is only present when a recorded dataset was loaded;
        # without it we go straight to the interactive prompt instead of
        # failing with AttributeError.
        recorded = getattr(self, 'human_data', None)
        if recorded is not None and self.human_exp_idx < len(recorded['label']):
            label = recorded['label'][self.human_exp_idx]
            self.human_exp_idx += 1
            return label

        from matplotlib.widgets import Button

        curr_label = 0
        callback = Index()
        fig, ax = plt.subplots()
        self.display_wall_fig(fig, ax)
        fig.subplots_adjust(bottom=0.2)  # leave room for the buttons
        axfirst = fig.add_axes([0.7,0.05, 0.1, 0.075])
        axsecond = fig.add_axes([0.81,0.05,0.1,0.075])
        axthird = fig.add_axes([0.9,0.05,0.1,0.075])
        ax.scatter(state1[0], state1[1], color="blue")
        ax.scatter(state2[0], state2[1], color="red")
        ax.scatter(goal[0], goal[1], marker='o', s=100, color='seagreen')
        bfirst = Button(axfirst, 'Blue')
        bfirst.color = 'royalblue'
        bfirst.hovercolor = 'blue'
        bfirst.on_clicked(callback.first)
        bsecond = Button(axsecond, 'Red')
        bsecond.color = 'salmon'
        bsecond.hovercolor = 'red'
        bsecond.on_clicked(callback.second)
        bthird = Button(axthird, 'black')
        bthird.color = 'black'
        bthird.hovercolor = 'black'
        bthird.on_clicked(self.stop_training_rewardmodel)
        # Blocks until the figure is closed (the Blue/Red callbacks close it).
        plt.show()
        return curr_label
    
    def generate_pref_from_human(self, goal_states):
        observations_1, _, _, _, _, _ = self.replay_buffer.sample_batch_last_steps(self.reward_model_num_samples) # TODO: add
        observations_2, _, _, _, _, _ = self.replay_buffer.sample_batch_last_steps(self.reward_model_num_samples) # TODO: add
   
        goals = []
        labels = []
        achieved_state_1 = []
        achieved_state_2 = []

        num_goals = len(goal_states)
        for state_1, state_2 in zip(observations_1, observations_2):
            goal_idx = np.random.randint(0, len(goal_states)) 
            goal = self.env.extract_goal(goal_states[goal_idx])
            label_oracle = self.oracle(state_1, state_2, goal)
            label = self.ask_human_labels(state_1, state_2, goal)
            print("Correct:", label==label_oracle, "label", label, "label_oracle", label_oracle)

            labels.append(label) 

            self.num_labels_queried += 1 

            achieved_state_1.append(state_1) 
            achieved_state_2.append(state_2) 
            goals.append(goal)

            # dump data
            self.dict_labels['state_1'].append(state_1)
            self.dict_labels['state_2'].append(state_2)
            self.dict_labels['label'].append(label)
            self.dict_labels['goal'].append(goal)
            with open(f'human_dataset_{self.dt_string}.pickle', 'wb') as handle:
                pickle.dump(self.dict_labels, handle)

        achieved_state_1 = np.array(achieved_state_1)
        achieved_state_2 = np.array(achieved_state_2)
        goals = np.array(goals)
        labels = np.array(labels)
        
        return achieved_state_1, achieved_state_2, goals, labels # TODO: check ordering


    # TODO: this is not working too well with the shaped distances
    def generate_pref_labels(self, goal_states):
        observations_1, _, _, _, _, _ = self.replay_buffer.sample_batch_last_steps(self.reward_model_num_samples) # TODO: add
        observations_2, _, _, _, _, _ = self.replay_buffer.sample_batch_last_steps(self.reward_model_num_samples) # TODO: add
   
        goals = []
        labels = []
        achieved_state_1 = []
        achieved_state_2 = []

        num_goals = len(goal_states)
        for state_1, state_2 in zip(observations_1, observations_2):
            goal_idx = np.random.randint(0, len(goal_states)) 
            goal = self.env.extract_goal(goal_states[goal_idx])
            labels.append(self.oracle(state_1, state_2, goal)) 

            self.num_labels_queried += 1 

            achieved_state_1.append(state_1) 
            achieved_state_2.append(state_2) 
            goals.append(goal)

        achieved_state_1 = np.array(achieved_state_1)
        achieved_state_2 = np.array(achieved_state_2)
        goals = np.array(goals)
        labels = np.array(labels)
        
        return achieved_state_1, achieved_state_2, goals, labels # TODO: check ordering

    def loss_fn(self, observations, goals, actions, horizons, weights):
        """Weighted mean negative log-likelihood of ``actions`` under the policy.

        Inputs are converted to tensors on the module-level ``device``
        (float32, except actions which are int64 for discrete action spaces).
        ``horizons`` may be ``None``, in which case ``None`` is forwarded to
        ``policy.nll``.
        """
        float_dtype = torch.float32
        act_dtype = torch.int64 if self.is_discrete_action else torch.float32

        def on_device(values, dtype):
            return torch.tensor(values, dtype=dtype).to(device)

        obs_t = on_device(observations, float_dtype)
        goal_t = on_device(goals, float_dtype)
        act_t = on_device(actions, act_dtype)
        horizon_t = None if horizons is None else on_device(horizons, float_dtype)
        weight_t = on_device(weights, torch.float32)

        per_sample_nll = self.policy.nll(obs_t, goal_t, act_t, horizon=horizon_t)
        return torch.mean(per_sample_nll * weight_t)
    
    def traj_stopped(self, states):
        thresh = 0.05
        if np.shape(states)[0] <= self.last_k_steps:
            return [False for _ in range(self.num_envs)]


        state1 = states[-self.last_k_steps]
        state2 = states[-1]

        return np.linalg.norm(state1-state2, axis=1) < thresh

    def create_video(self, images, video_filename):
        """Write ``images`` to an mp4 and log it to Weights & Biases.

        The file is ``<self.trajectories_videos_folder>/<video_filename>.mp4``.
        It is logged under ``eval_video_trajectories`` when the file name
        contains ``'eval'``, otherwise under ``video_trajectories``.
        """
        frames = np.array(images).astype(np.uint8)
        video_path = f"{self.trajectories_videos_folder}/{video_filename}.mp4"
        skvideo.io.vwrite(video_path, frames)

        log_key = "eval_video_trajectories" if 'eval' in video_filename else "video_trajectories"
        wandb.log({log_key: wandb.Video(video_path)})
    


    def sample_trajectory(self, goal= None, greedy=False, noise=0, with_preferences=False, exploration_enabled=False,save_video_trajectory=False, video_filename='traj_0'):
        if goal is None:
            #print("i")
            goal_state = self.env.sample_goal()
            desired_goal_state = goal_state.copy()
            desired_goal = self.env.extract_goal(goal_state.copy())
            #print("goal state", goal_state)
            commanded_goal_state = goal_state.copy()
            commanded_goal = self.env.extract_goal(goal_state.copy())

            # Get closest achieved state
            # TODO: this might be too much human querying, except if we use the reward model
            if with_preferences:
                goal = self.get_closest_achieved_state([commanded_goal], self.device,)[0]
                #print(f"goal {goal}, commanded_goal {commanded_goal}")
                if np.linalg.norm(commanded_goal - goal) < self.goal_threshold:
                    goal = commanded_goal
                    exploration_enabled = False
                    print("Goals too close, prefrences disabled")
                else:
                    commanded_goal = goal.copy()
                    print("Using preferences")
            else:
                goal = commanded_goal

        else:
            # TODO: URGENT should fix this
            commanded_goal = goal.copy()
            desired_goal = goal.copy()
            commanded_goal_state = np.concatenate([goal.copy(), goal.copy(), goal.copy()])
            desired_goal_state = commanded_goal_state.copy()

        commanded_goal_state = np.concatenate([goal.copy(), goal.copy(), goal.copy()])

        states = []
        actions = []
        video = []

        state = self.env.reset()
        stopped = np.zeros(self.num_envs)
        stopped[:] = False
        ts = np.zeros(self.num_envs)
        t = 0
        ts_stopped = - np.ones(self.num_envs)
        while np.any(ts < self.max_path_length) and not np.all(stopped):
            if self.render:
                self.env.render()

            if save_video_trajectory and self.save_videos: #and False: # TODO: remove
                imgs = self.env.get_images()[0]#render_image()[0]
                video.append(imgs)

            states.append(state)

            observation = self.env.observation(state)

            new_goal = np.array([goal for _ in range(self.num_envs)])
                
            random_action = self.random_policy.act_vectorized(observation, new_goal,  greedy=False, noise=noise)
            
            policy_action = self.policy.act_vectorized(observation, new_goal, greedy=greedy, noise=noise)
            
            if self.set_desired_when_stopped:
                desired_new_goal = np.array([desired_goal for _ in range(self.num_envs)])
                policy_action_to_goal = self.policy.act_vectorized(observation, desired_new_goal, greedy=greedy, noise=noise)
            else:
                policy_action_to_goal = policy_action

            greedy_vect = np.random.random(self.num_envs) < self.epsilon_greedy
            random_action = random_action*greedy_vect + policy_action_to_goal*(1- greedy_vect)

            action = stopped*random_action + (1-stopped)*policy_action
            action = action.astype(np.int)
            if not self.is_discrete_action:
                policy_action = np.clip(action, self.env.action_space.low, self.env.action_space.high)

            actions.append(action)
            
            if self.exploration_when_stopped and exploration_enabled:
                res_stopped = self.traj_stopped(states)
                old_stopped = stopped.copy()
                stopped = np.logical_or(stopped, res_stopped) 
                new_stopped = np.logical_xor(old_stopped, stopped)

                if np.any(new_stopped) and self.remove_last_steps_when_stopped:
                    ts[new_stopped]-=self.last_k_steps
                    ts_stopped[new_stopped] = t
            state, _, _, _ = self.env.step(action)
            ts += 1
            t += 1

        new_states = []
        new_actions = []
        actions = np.array(actions)
        states = np.array(states)
        ts_stopped = ts_stopped.astype(np.int)
        for i in range(self.num_envs):
            if stopped[i]:
                inter_states = np.concatenate([states[:ts_stopped[i] - self.last_k_steps,i], states[ts_stopped[i]:min(self.max_path_length + self.last_k_steps, ts_stopped[i]+self.explore_length+self.last_k_steps),i]])
                inter_actions = np.concatenate([actions[:ts_stopped[i] - self.last_k_steps,i], actions[ts_stopped[i]:min(self.max_path_length + self.last_k_steps, ts_stopped[i]+self.explore_length + self.last_k_steps),i]])
                if len(inter_states) > self.max_path_length:
                    import IPython
                    IPython.embed()
            else:
                inter_states = states[:self.max_path_length, i]
                inter_actions = actions[:self.max_path_length, i]

            new_states.append(inter_states)
            new_actions.append(inter_actions)
        
            print("newstate shape", np.shape(new_states[-1]))
        video = np.array(video)
        states = np.array(new_states)
        actions = np.array(new_actions)

        if save_video_trajectory and self.save_videos:
            #for i in range(self.num_envs):
                i=0
                final_dist = self.env_distance(states[i][-1], desired_goal)
                #video_env = video[:,i]
                #print("video env", video_env.shape)
                self.create_video(video, f"{video_filename}_{i}_{final_dist}")
                with open(f'{self.trajectories_videos_folder}/{video_filename}_{i}_{final_dist}', 'w') as f:
                    f.write(str(goal))

        return states, actions, np.array([commanded_goal_state for _ in range(self.num_envs)]), np.array([desired_goal_state for _ in range(self.num_envs)])
    

    def take_policy_step(self, buffer=None):
        """Apply one optimizer step to the policy using gradient accumulation.

        ``self.n_accumulations`` batches are sampled from ``buffer``; each
        batch loss is backpropagated so the gradients add up, the accumulated
        gradient norm is clipped to ``self.clip`` and a single optimizer step
        is taken.

        Parameters:
            buffer: Object exposing ``sample_batch(batch_size)``. Falls back to
                ``self.replay_buffer`` when ``None``.

        Returns:
            The mean of the per-batch losses, each converted with
            ``ptu.to_numpy``.
        """
        if buffer is None:
            buffer = self.replay_buffer

        avg_loss = 0
        self.policy_optimizer.zero_grad()

        for _ in range(self.n_accumulations):
            observations, actions, goals, _lengths, horizons, weights = buffer.sample_batch(self.batch_size)

            loss = self.loss_fn(observations, goals, actions, horizons, weights)

            # Gradients are summed over the accumulation batches; the loss is
            # deliberately left unscaled to keep the original update magnitude.
            loss.backward()
            avg_loss += ptu.to_numpy(loss)

        # clip_grad_norm_ is the in-place replacement for the deprecated
        # clip_grad_norm and has the same arguments.
        torch.nn.utils.clip_grad_norm_(self.policy.parameters(), self.clip)
        self.policy_optimizer.step()

        return avg_loss / self.n_accumulations

    def validation_loss(self, buffer=None):

        if buffer is None:
            buffer = self.validation_buffer

        if buffer is None or buffer.current_buffer_size == 0:
            return 0, 0

        avg_loss = 0
        avg_rewardmodel_loss = 0
        for _ in range(self.n_accumulations):
            observations, actions, goals, lengths, horizons, weights = buffer.sample_batch(self.batch_size)
            loss = self.loss_fn(observations, goals, actions, horizons, weights)
            #eval_data = self.generate_pref_labels(observations, actions, [goals], extract=False)
            #print("eval data", eval_data)
            #loss_rewardmodel =self.eval_rewardmodel(eval_data)
            # TODO: implement eval loss
            loss_rewardmodel = torch.tensor(0)
            avg_loss += ptu.to_numpy(loss)
            avg_rewardmodel_loss += ptu.to_numpy(loss_rewardmodel)

        return avg_loss / self.n_accumulations, avg_rewardmodel_loss / self.n_accumulations

    def pretrain_demos(self, demo_replay_buffer=None, demo_validation_replay_buffer=None, demo_train_steps=0):
        if demo_replay_buffer is None:
            return

        self.policy.train()
        with tqdm.trange(demo_train_steps) as looper:
            for _ in looper:
                loss = self.take_policy_step(buffer=demo_replay_buffer)
                validation_loss, rewardmodel_val_loss = self.validation_loss(buffer=demo_validation_replay_buffer)

                if running_loss is None:
                    running_loss = loss
                else:
                    running_loss = 0.99 * running_loss + 0.01 * loss
                if running_validation_loss is None:
                    running_validation_loss = validation_loss
                else:
                    running_validation_loss = 0.99 * running_validation_loss + 0.01 * validation_loss

                looper.set_description('Loss: %.03f Validation Loss: %.03f'%(running_loss, running_validation_loss))
        
    # TODO: why isn't this working??
    def test_rewardmodel(self, itr):
        """Save scatter plots of the learned and oracle reward models for one goal.

        A goal is sampled from the environment, 10000 2-D points are drawn
        uniformly from [-0.6, 0.6], and both ``self.reward_model`` and
        ``self.oracle_model`` are evaluated on every (point, goal) pair. Each
        result is saved as a colour-coded scatter plot.

        Parameters:
            itr: Integer used in the output file names.

        Output files:
            ``<env_name>/rewardmodel_test/test_rewardmodel_itr<itr>.png`` and
            ``rewardmodel_test/test_oracle_itr<itr>.png``.

        NOTE(review): the sampled points are always 2-D while the "pusher"
        branch reads ``goal_pos[2]`` and ``goal_pos[3]`` -- confirm that both
        models accept this state/goal pairing.
        """
        goal = self.env.sample_goal()
        goal_pos = self.env.extract_goal(goal)
        goals = np.repeat(goal_pos[None], 10000, axis=0)
        states = np.random.uniform(-0.6, 0.6, size=(10000, 2))
        # Use the module-level device instead of a hard-coded .cuda() so the
        # method also runs on CPU-only hosts.
        states_t = torch.Tensor(states).to(device)
        goals_t = torch.Tensor(goals).to(device)

        def render(model, out_path):
            # Evaluate one model on the sampled points and save the scatter plot.
            with torch.no_grad():
                r_val = model(states_t, goals_t)
            r_val = r_val.cpu().detach().numpy()
            plt.clf()
            plt.cla()
            plt.scatter(states[:, 0], states[:, 1], c=r_val[:, 0], cmap=cm.jet)

            if self.env_name == "pusher":
                self.display_wall_pusher()
                plt.scatter(goal_pos[2], goal_pos[3], marker='o', s=100, color='black')
            else:
                self.display_wall()
                plt.scatter(goal_pos[0], goal_pos[1], marker='o', s=100, color='black')

            # The two figures go to different folders; make sure each exists.
            os.makedirs(os.path.dirname(out_path), exist_ok=True)
            plt.savefig(out_path)

        render(self.reward_model, self.env_name+"/rewardmodel_test/test_rewardmodel_itr%d.png"%itr)
        render(self.oracle_model, "rewardmodel_test/test_oracle_itr%d.png"%itr)
        
        

    def plot_visit_freq(self, itr):
        """Save a placeholder visit-frequency scatter plot.

        No visit counts are computed yet: 10000 points are drawn uniformly
        from [-0.5, 0.5] in 2-D and coloured with an all-zero value, and a
        random 2-D position is drawn as a black marker.

        Parameters:
            itr: Integer used in the output file name.

        Output file:
            ``rewardmodel_test/test_rewardmodel_itr<itr>.png``.
        """
        # The previous body referenced undefined ``states`` / ``goal_pos`` and
        # indexed a 1-D array with [:, 0]; the names and shapes below follow
        # the commented-out code it replaced.
        goal_pos = np.random.uniform(-0.5, 0.5, size=(2,))
        states = np.random.uniform(-0.5, 0.5, size=(10000, 2))
        # TODO: replace the zero placeholder with real visit frequencies.
        r_val = np.zeros((states.shape[0], 1))
        os.makedirs('rewardmodel_test', exist_ok=True)
        plt.clf()
        plt.cla()
        self.display_wall()
        plt.scatter(states[:, 0], states[:, 1], c=r_val[:, 0], cmap=cm.jet)
        plt.scatter(goal_pos[0], goal_pos[1], marker='o', s=100, color='black')
        plt.savefig("rewardmodel_test/test_rewardmodel_itr%d.png"%itr)

    def full_grid_evaluation(self, itr):
        """Evaluate the greedy policy on a 20x20 grid of 2-D goals.

        For every goal on the grid (coordinates evenly spaced in [-0.6, 0.6])
        one greedy rollout is sampled and the Euclidean distance between the
        goal and the last two entries of the final state is recorded.

        Parameters:
            itr: Value used in the output file names.

        Output files:
            ``heatmap_performance/eval_<itr>.png`` (distances) and
            ``heatmap_accuracy/eval_<itr>.png`` (distance below
            ``self.goal_threshold``).
        """
        grid_size = 20
        goals = np.linspace(-0.6, 0.6, grid_size)
        distances = np.zeros((grid_size, grid_size))

        for x, goal_x in enumerate(goals):
            for y, goal_y in enumerate(goals):
                goal = np.array([goal_x, goal_y])
                # Always sample: the former ``num_envs == 1`` guard left
                # ``states`` unbound for any other setting.
                states, _, _, _ = self.sample_trajectory(goal=goal, greedy=True)
                # sample_trajectory returns one trajectory per environment, so
                # take the first trajectory and its final state. The old
                # ``states[-1][-2:]`` picked the last two time steps instead.
                final_state = states[0][-1]
                distances[x, y] = np.linalg.norm(goal - final_state[-2:])

        os.makedirs('heatmap_performance', exist_ok=True)
        os.makedirs('heatmap_accuracy', exist_ok=True)

        # sns.heatmap draws on the current axes; clear the figure before each
        # heatmap so the second one is not painted over the first.
        plt.clf()
        plot = sns.heatmap(distances, xticklabels=goals, yticklabels=goals)
        plot.get_figure().savefig(f'heatmap_performance/eval_{itr}.png')

        plt.clf()
        plot = sns.heatmap(distances < self.goal_threshold, xticklabels=goals, yticklabels=goals)
        plot.get_figure().savefig(f'heatmap_accuracy/eval_{itr}.png')
    
    def get_distances(self, state, goal):
        """Return per-element distances reported by the sequential kitchen env.

        Parameters:
            state: State converted with ``self.env.observation`` before being
                passed to ``self.env.success_distance``.
            goal: Not used by this method.

        Returns:
            A 6-tuple. The first three entries come from the first dictionary
            returned by ``success_distance`` and the last three from the
            second; each triple is ordered 'slide_cabinet', 'hinge_cabinet',
            'microwave'. Six ``None`` values are returned when the wrapped
            environment is not a ``KitchenSequentialGoalEnv``.
        """
        observation = self.env.observation(state)

        is_sequential_kitchen = isinstance(self.env.wrapped_env, KitchenSequentialGoalEnv)
        if not is_sequential_kitchen:
            return (None,) * 6

        position_dists, object_dists = self.env.success_distance(observation)
        elements = ('slide_cabinet', 'hinge_cabinet', 'microwave')

        from_positions = tuple(position_dists[name] for name in elements)
        from_objects = tuple(object_dists[name] for name in elements)
        return from_positions + from_objects

    def plot_trajectories(self, traj_accumulated_states, traj_accumulated_goal_states, extract=True, filename=""):
        """Dispatch trajectory plotting to the plotter for the wrapped env type.

        The first matching environment class selects the plotter. It receives
        copies of both arrays, ``extract`` unchanged, and ``filename`` prefixed
        with an environment-specific folder ("pointmass/", "pusher/" or
        "pusher_hard/"). Returns the plotter's result, or ``None`` when the
        wrapped environment matches none of the known classes.
        """
        plotters = (
            (PointmassGoalEnv, self.plot_trajectories_rooms, "pointmass/"),
            (SawyerPushGoalEnv, self.plot_trajectories_pusher, "pusher/"),
            (SawyerHardPushGoalEnv, self.plot_trajectories_pusher_hard, "pusher_hard/"),
        )
        for env_type, plotter, folder in plotters:
            if not isinstance(self.env.wrapped_env, env_type):
                continue
            states_copy = traj_accumulated_states.copy()
            goals_copy = traj_accumulated_goal_states.copy()
            return plotter(states_copy, goals_copy, extract, folder + filename)

    def plot_trajectories_rooms(self, traj_accumulated_states, traj_accumulated_goal_states, extract=True, filename=""):
        """Plot trajectories and their goals for the room environment.

        Each trajectory is drawn as a line through columns 0 and 1 of its
        observations, and its goal is drawn as a dot at the last two entries
        of the goal state, in the same colour.

        Parameters:
            traj_accumulated_states: Array indexed by trajectory along axis 0.
            traj_accumulated_goal_states: Goal state per trajectory.
            extract: Not used by this method.
            filename: Path handed to ``plt.savefig``.
        """
        plt.clf()
        self.display_wall()

        num_trajectories = traj_accumulated_states.shape[0]
        colors = sns.color_palette('hls', num_trajectories)
        for j in range(num_trajectories):
            color = colors[j]
            # Convert once per trajectory; it was previously converted twice,
            # once for each axis.
            observations = self.env.observation(traj_accumulated_states[j])
            plt.plot(observations[:, 0], observations[:, 1], color=color, zorder=-1)

            goal_state = traj_accumulated_goal_states[j]
            plt.scatter(goal_state[-2], goal_state[-1], marker='o', s=20, color=color, zorder=1)

        plt.savefig(filename)

    def plot_trajectories_pusher(self, traj_accumulated_states, traj_accumulated_goal_states, extract=True, filename=""):
        """Plot pusher trajectories and their goals on top of the pusher wall.

        Both inputs are passed through ``self.env._extract_sgoal``. Columns 2
        and 3 of each extracted trajectory are drawn as a line, and entries 2
        and 3 of the matching extracted goal as a dot of the same colour.

        Parameters:
            traj_accumulated_states: Trajectory states, one trajectory per
                index along axis 0 after extraction.
            traj_accumulated_goal_states: Goal state per trajectory.
            extract: Not used by this method (extraction is always applied).
            filename: Path handed to ``plt.savefig``.
        """
        plt.clf()
        plt.cla()
        self.display_wall_pusher()

        paths = self.env._extract_sgoal(traj_accumulated_states)
        targets = self.env._extract_sgoal(traj_accumulated_goal_states)

        palette = sns.color_palette('hls', (paths.shape[0]))
        for idx in range(paths.shape[0]):
            shade = palette[idx]
            path = paths[idx]
            plt.plot(path[:, 2], path[:, 3], color=shade)
            target = targets[idx]
            plt.scatter(target[2], target[3], marker='o', s=20, color=shade)

        plt.savefig(filename)

    def plot_trajectories_pusher_hard(self, traj_accumulated_states, traj_accumulated_goal_states, extract=True, filename=""):
        """Plot hard-pusher trajectories and their goals (no wall is drawn).

        Both inputs are passed through ``self.env._extract_sgoal``. Columns 2
        and 3 of each extracted trajectory are drawn as a line, and entries 2
        and 3 of the matching extracted goal as a dot of the same colour.

        Parameters:
            traj_accumulated_states: Trajectory states, one trajectory per
                index along axis 0 after extraction.
            traj_accumulated_goal_states: Goal state per trajectory.
            extract: Not used by this method (extraction is always applied).
            filename: Path handed to ``plt.savefig``.
        """
        plt.clf()
        plt.cla()

        paths = self.env._extract_sgoal(traj_accumulated_states)
        targets = self.env._extract_sgoal(traj_accumulated_goal_states)

        palette = sns.color_palette('hls', (paths.shape[0]))
        for idx in range(paths.shape[0]):
            shade = palette[idx]
            path = paths[idx]
            plt.plot(path[:, 2], path[:, 3], color=shade)
            target = targets[idx]
            plt.scatter(target[2], target[3], marker='o', s=20, color=shade)

        plt.savefig(filename)

    def collect_and_train_rewardmodel(self, desired_goal_states_rewardmodel,total_timesteps):
        print("Collecting and training rewardmodel")
        # TODO: we are gonna substitute generate pref labels with human labelling
        if self.train_regression:
            achieved_state_1, achieved_state_2, goals, labels = self.generate_pref_labels_regression(desired_goal_states_rewardmodel)
        elif self.human_input:
            achieved_state_1, achieved_state_2, goals, labels = self.generate_pref_from_human(desired_goal_states_rewardmodel)
        else:
            achieved_state_1, achieved_state_2, goals, labels = self.generate_pref_labels(desired_goal_states_rewardmodel)
        # TODO: add validation buffer
        
        self.reward_model_buffer.add_multiple_data_points(achieved_state_1, achieved_state_2, goals, labels)

        # Train reward model
        if not self.use_oracle:
            # Generate labels with preferences
            if self.train_regression:
                losses_reward_model, eval_loss_reward_model = self.train_rewardmodel_regression(device, self.rewardmodel_batch_size)
            else:
                losses_reward_model, eval_loss_reward_model = self.train_rewardmodel(device, self.rewardmodel_batch_size)

            print("Computing reward model loss ", np.mean(losses_reward_model))
            if self.summary_writer:
                self.summary_writer.add_scalar('LossesRewardModel/Train', np.mean(losses_reward_model), total_timesteps)
            wandb.log({'LossesRewardModel/Train':np.mean(losses_reward_model), 'timesteps':total_timesteps, 'num_labels_queried':self.num_labels_queried})

            self.train_loss_rewardmodel_arr.append((np.mean(losses_reward_model), total_timesteps))
        
        return losses_reward_model, eval_loss_reward_model

    def dump_data(self):
        metrics = {
            'success_ratio_eval_arr':self.success_ratio_eval_arr,
            'train_loss_arr':self.train_loss_arr,
            'distance_to_goal_eval_arr':self.distance_to_goal_eval_arr,
            'success_ratio_relabelled_arr':self.success_ratio_relabelled_arr,
            'eval_trajectories_arr':self.eval_trajectories_arr,
            'train_loss_rewardmodel_arr':self.train_loss_rewardmodel_arr,
            'eval_loss_arr':self.eval_loss_arr,
            'distance_to_goal_eval_relabelled':self.distance_to_goal_eval_relabelled,
        }
        with open(os.path.join(self.data_folder, 'metrics.pkl'), 'wb') as f:
            pickle.dump(metrics, f)

    def dump_trajectories(self):
        
        with open(os.path.join(self.data_folder, f'eval_trajectories/traj_{self.traj_num_file}.pkl'), 'wb') as f:
            pickle.dump(self.collected_trajs_dump, f)
        self.traj_num_file +=1

        self.collected_trajs_dump = []

    def train(self):
        start_time = time.time()
        last_time = start_time

        # Evaluate untrained policy
        total_timesteps = 0
        timesteps_since_train = 0
        timesteps_since_eval = 0
        timesteps_since_reset = 0

        iteration = 0
        running_loss = None
        running_validation_loss = None
        rewardmodel_running_val_loss = None

        losses_reward_model_acc = None
        if self.display_plots:
            os.makedirs("relabeled_states_preferences", exist_ok=True)
            shutil.rmtree("relabeled_states_preferences")
            os.makedirs("train_states_preferences", exist_ok=True)
            os.makedirs("relabeled_states_preferences", exist_ok=True)
            os.makedirs("explore_states_trajectories", exist_ok=True)
            os.makedirs("train_states_preferences", exist_ok=True)
            shutil.rmtree("explore_states_trajectories")
            os.makedirs("heatmap_performance", exist_ok=True)
            os.makedirs("explore_states_trajectories", exist_ok=True)
            shutil.rmtree("heatmap_performance")
            os.makedirs("heatmap_accuracy", exist_ok=True)
            os.makedirs("heatmap_performance", exist_ok=True)
            shutil.rmtree("heatmap_accuracy")
            os.makedirs(self.env_name+'/rewardmodel_test', exist_ok=True)        
            os.makedirs("heatmap_accuracy", exist_ok=True)
            os.makedirs('preferences_distance', exist_ok=True)
            shutil.rmtree(self.env_name+"/rewardmodel_test")
            os.makedirs(self.env_name+'/rewardmodel_test', exist_ok=True)        
            shutil.rmtree("preferences_distance")
            os.makedirs('preferences_distance', exist_ok=True)

        now = datetime.now()
        dt_string = now.strftime("%d_%m_%Y_%H:%M")
        os.makedirs(f'{self.env_name}', exist_ok=True)
        self.trajectories_videos_folder = f'{self.env_name}/trajectories_videos_{dt_string}'
        os.makedirs(self.trajectories_videos_folder, exist_ok=True)

        
        
        now = datetime.now()
        dt_string = now.strftime("%d_%m_%Y_%H:%M:%S")

        if logger.get_snapshot_dir() and self.log_tensorboard:
            info = self.comment
            if self.train_with_hallucination:
                info+="preferences"
            info+= f"_hallucination_freq_{self.hallucinate_policy_freq}"
            info+= f"_start_policy_{self.start_policy_timesteps}"
            info+= f"_use_oracle_{self.use_oracle}"
            info+= f"_lr_{self.lr}"
            info+= f"_batch_size_{self.batch_size}"
            info+= f"_select_best_sample_size_{self.select_best_sample_size}"
            info+= f"_max_path_length_{self.max_path_length}"
            

            tensorboard_path = osp.join(logger.get_snapshot_dir(), info)

            print("tensorboard directory", tensorboard_path)
            self.summary_writer = SummaryWriter(tensorboard_path)
        else:
            print("Tensorboard failed", logger.get_snapshot_dir(), self.log_tensorboard)

        # Evaluation Code
        self.policy.eval()
        if self.train_with_hallucination and self.display_plots:
            if os.path.exists(self.env_name+"/train_states_preferences"):
                shutil.rmtree(self.env_name+"/train_states_preferences")

            os.makedirs(self.env_name+"/train_states_preferences", exist_ok=True)

            os.makedirs(self.env_name+"/plots_preferences", exist_ok=True)
            shutil.rmtree(self.env_name+"/plots_preferences")
            os.makedirs(self.env_name+"/plots_preferences", exist_ok=True)
            os.makedirs(self.env_name+"/plots_preferences_requested", exist_ok=True)
            shutil.rmtree(self.env_name+"/plots_preferences_requested")
            os.makedirs(self.env_name+"/plots_preferences_requested", exist_ok=True)
            plots_folder = "plots_preferences"
            plots_folder_requested = "plots_preferences_requested"

        elif self.display_plots:
            os.makedirs(self.env_name+"/plots", exist_ok=True)
            shutil.rmtree(self.env_name+"/plots")
            os.makedirs(self.env_name+"/plots", exist_ok=True)
            plots_folder = self.env_name+"/plots"
            os.makedirs(self.env_name+"/plots_requested", exist_ok=True)
            shutil.rmtree(self.env_name+"/plots_requested")
            os.makedirs(self.env_name+"/plots_requested", exist_ok=True)
            if os.path.exists(self.env_name+"/train_states"):
                shutil.rmtree(self.env_name+"/train_states")

            os.makedirs(self.env_name+"/train_states", exist_ok=True)


            plots_folder = "/plots"
            plots_folder_requested = "/plots_requested"
        else:
            plots_folder = ""
            plots_folder_requested = ""


        self.evaluate_policy(self.eval_episodes, total_timesteps=0, greedy=True, prefix='Eval', plots_folder=plots_folder)
        logger.record_tabular('policy loss', 0)
        logger.record_tabular('reward model train loss', 0)
        logger.record_tabular('reward model eval loss', 0)
        logger.record_tabular('timesteps', total_timesteps)
        logger.record_tabular('epoch time (s)', time.time() - last_time)
        logger.record_tabular('total time (s)', time.time() - start_time)
        last_time = time.time()
        logger.dump_tabular()
        # End Evaluation Code

        # Trajectory states being accumulated
        traj_accumulated_states = []
        traj_accumulated_actions = []
        traj_accumulated_goal_states = []
        desired_goal_states_rewardmodel = []
        traj_accumulated_desired_goal_states = []
        goal_states_rewardmodel = []
        full_iters = 0

        
        with tqdm.tqdm(total=self.eval_freq, smoothing=0) as ranger:
            while total_timesteps < self.max_timesteps:
                self.total_timesteps = total_timesteps
                full_iters +=1
                if self.save_buffer != -1 and total_timesteps > self.save_buffer:
                    self.save_buffer = -1
                    self.replay_buffer.save(self.buffer_filename)
                    self.validation_buffer.save(self.val_buffer_filename)


                #print("total timesteps", total_timesteps, "max timesteps", self.max_timesteps)
                # Interact in environmenta according to exploration strategy.
                # TODO: we can probably skip this in preferences or use it to learn a rewardmodel
                if total_timesteps < self.explore_timesteps:
                    #print("Sample trajectory noise")
                    states, actions, goal_state, desired_goal_state = self.sample_trajectory(noise=1, exploration_enabled=False)
                    for i in range(self.num_envs):
                        traj_accumulated_states.append(states[i])
                        traj_accumulated_desired_goal_states.append(desired_goal_state[i])
                        traj_accumulated_actions.append(actions[i])
                        traj_accumulated_goal_states.append(goal_state[i])
                        """
                        if self.train_with_hallucination and not self.use_oracle:
                            self.collect_and_train_rewardmodel(np.array([goal_state]))
                        """
                        if total_timesteps != 0 and self.validation_buffer is not None and np.random.rand() < 0.2:
                            self.validation_buffer.add_trajectory(states[i], actions[i], goal_state[i])
                        else:
                            self.replay_buffer.add_trajectory(states[i], actions[i], goal_state[i])

                elif not self.train_with_hallucination:
                    assert not self.use_oracle and not self.sample_softmax
                    #print("sample trajectory greedy")
                    states, actions, goal_state, desired_goal_state = self.sample_trajectory(greedy=False, noise=self.expl_noise, exploration_enabled=False)
                    for i in range(self.num_envs):
                        traj_accumulated_states.append(states[i])
                        traj_accumulated_desired_goal_states.append(desired_goal_state[i])
                        traj_accumulated_actions.append(actions[i])
                        traj_accumulated_goal_states.append(goal_state[i])
                        #desired_goal_states_rewardmodel.append(desired_goal_state)
                        #goal_states_rewardmodel.append(goal_state)
                        if total_timesteps != 0 and self.validation_buffer is not None and np.random.rand() < 0.2:
                            self.validation_buffer.add_trajectory(states[i], actions[i], goal_state[i])
                        else:
                            self.replay_buffer.add_trajectory(states[i], actions[i], goal_state[i])
                
                
                # Interact in environmenta according to exploration strategy.
                # TODO: should we try increasing the explore timesteps?
                if self.train_with_hallucination and total_timesteps > self.explore_timesteps:
                    save_video_trajectory = full_iters % 10 == 0
                    video_filename = f"traj_{total_timesteps}"

                    explore_states, explore_actions, explore_goal_state, desired_goal_state = self.sample_trajectory(greedy=False, noise=self.expl_noise, with_preferences=True, exploration_enabled=True, save_video_trajectory=save_video_trajectory, video_filename=video_filename)
                    for i in range(self.num_envs):
                        traj_accumulated_states.append(explore_states[i])
                        traj_accumulated_desired_goal_states.append(desired_goal_state[i])
                        traj_accumulated_actions.append(explore_actions[i])
                        traj_accumulated_goal_states.append(explore_goal_state[i])
                        desired_goal_states_rewardmodel.append(desired_goal_state[i])
                        goal_states_rewardmodel.append(explore_goal_state[i])

                    
                        if self.validation_buffer is not None and np.random.rand() < 0.2:
                            self.validation_buffer.add_trajectory(explore_states[i], explore_actions[i], explore_goal_state[i])
                        else:
                            self.replay_buffer.add_trajectory(explore_states[i], explore_actions[i], explore_goal_state[i])

                #if total_timesteps < self.explore_timesteps: # TODO: remove
                    # With some probability, put this new trajectory into the validation buffer

                
                
                
                #print(f"Attr: train with hallucination: {self.train_with_hallucination}, hallucinate freq. {self.hallucinate_policy_freq}, policy_timesteps:{self.start_policy_timesteps}")
                if  self.train_with_hallucination and full_iters % self.train_rewardmodel_freq == 0 and total_timesteps > self.explore_timesteps:
                    #print("total timesteps", total_timesteps)
                    desired_goal_states_rewardmodel = np.array(desired_goal_states_rewardmodel)
                    goal_states_rewardmodel = np.array(goal_states_rewardmodel)

                    dist = np.array([
                            self.env_distance(desired_goal_states_rewardmodel[i], self.env.extract_goal(goal_states_rewardmodel)[i])
                            for i in range(desired_goal_states_rewardmodel.shape[0])
                    ])

                    if self.summary_writer:
                        #print(dist, np.mean(dist))
                        self.summary_writer.add_scalar("Preferences/DistanceCommandedToDesiredGoal", np.mean(dist), total_timesteps)
                    wandb.log({'Preferences/DistanceCommandedToDesiredGoal':np.mean(dist), 'timesteps':total_timesteps, 'num_labels_queried':self.num_labels_queried})
                    
                    self.distance_to_goal_eval_arr.append((np.mean(dist), total_timesteps))
                    if self.display_plots:
                        plt.clf()
                        #self.display_wall()
                        
                        colors = sns.color_palette('hls', (goal_states_rewardmodel.shape[0]))
                        for j in range(desired_goal_states_rewardmodel.shape[0]):
                            color = colors[j]
                            plt.scatter(desired_goal_states_rewardmodel[j][-2],
                                    desired_goal_states_rewardmodel[j][-1], marker='o', s=20, color=color)
                            plt.scatter(goal_states_rewardmodel[j][-2],
                                    goal_states_rewardmodel[j][-1], marker='x', s=20, color=color)
                        
                        plt.savefig(f'preferences_distance/distance_commanded_to_desired_goal%d.png'%total_timesteps)
                    # relabel and add to buffer
                    if not self.use_oracle and (self.stop_rewardmodel==-1 or self.total_timesteps < self.stop_rewardmodel): #and self.stop_training_rewardmodel_steps > self.total_timesteps:
                        losses_reward_model, eval_loss_reward_model = self.collect_and_train_rewardmodel(desired_goal_states_rewardmodel, total_timesteps)
                    
                    desired_goal_states_rewardmodel = []
                    goal_states_rewardmodel = []

                
                if len(traj_accumulated_actions) % self.display_trajectories_freq == 0:
                    traj_accumulated_states = np.array(traj_accumulated_states)
                    traj_accumulated_actions = np.array(traj_accumulated_actions)
                    traj_accumulated_goal_states = np.array(traj_accumulated_goal_states)
                    if self.display_plots:
                        if self.train_with_hallucination:
                            self.plot_trajectories(traj_accumulated_states, traj_accumulated_goal_states, filename=f'train_states_preferences/train_trajectories_%d.png'%total_timesteps)
                        else:
                            self.plot_trajectories(traj_accumulated_states, traj_accumulated_goal_states, filename=f'train_states/train_trajectories_%d.png'%total_timesteps)


                    #if self.train_with_hallucination and not self.use_oracle and self.display_plots:
                    #    self.test_rewardmodel(total_timesteps)               

                    self.dump_data()



                    if self.env_name == "kitchenSeq":
                        avg_distance_to_hinge = 0
                        avg_distance_to_slide = 0
                        avg_distance_to_microwave = 0
                        avg_distance_joint_hinge = 0
                        avg_distance_joint_slide = 0
                        avg_distance_joint_microwave = 0
                        avg_success = 0
                        avg_distance_total = 0
                        traj_accumulated_desired_goal_states = np.array(traj_accumulated_desired_goal_states)
                        print(traj_accumulated_desired_goal_states.shape)
                        for i in range(traj_accumulated_desired_goal_states.shape[0]):

                            distance_to_slide, distance_to_hinge, distance_to_microwave, distance_joint_slide, distance_joint_hinge, distance_joint_microwave = self.get_distances(traj_accumulated_states[i][-1], self.env.extract_goal(traj_accumulated_desired_goal_states[i]))
                            success = self.env.compute_success(traj_accumulated_states[i][-1], self.env.extract_goal(traj_accumulated_desired_goal_states[i]))
                            distance_total = self.env.compute_shaped_distance(traj_accumulated_states[i][-1], self.env.extract_goal(traj_accumulated_desired_goal_states[i]))

                            if distance_to_hinge is None:
                                break

                            avg_distance_to_hinge += distance_to_hinge
                            avg_distance_to_slide += distance_to_slide
                            avg_distance_to_microwave += distance_to_microwave
                            avg_distance_joint_hinge += distance_joint_hinge
                            avg_distance_joint_slide += distance_joint_slide
                            avg_distance_joint_microwave += distance_joint_microwave
                            avg_success += success
                            avg_distance_total += distance_total

                        
                        avg_distance_to_hinge /= traj_accumulated_desired_goal_states.shape[0]
                        avg_distance_to_slide /= traj_accumulated_desired_goal_states.shape[0]
                        avg_distance_to_microwave /= traj_accumulated_desired_goal_states.shape[0]
                        avg_distance_joint_hinge /= traj_accumulated_desired_goal_states.shape[0]
                        avg_distance_joint_slide /= traj_accumulated_desired_goal_states.shape[0]
                        avg_distance_joint_microwave /= traj_accumulated_desired_goal_states.shape[0]
                        avg_success /= traj_accumulated_desired_goal_states.shape[0]
                        avg_distance_total /= traj_accumulated_desired_goal_states.shape[0]

                        if self.summary_writer:           
                            self.summary_writer.add_scalar("DistanceToHinge", avg_distance_to_hinge, self.total_timesteps)
                            self.summary_writer.add_scalar("DistanceToSlide", avg_distance_to_slide, self.total_timesteps)
                            self.summary_writer.add_scalar("DistanceToMicrowave", avg_distance_to_microwave, self.total_timesteps)
                            self.summary_writer.add_scalar("DistanceJointSlide", avg_distance_joint_slide, self.total_timesteps)
                            self.summary_writer.add_scalar("DistanceJointHinge", avg_distance_joint_hinge, self.total_timesteps)
                            self.summary_writer.add_scalar("DistanceJointMicrowave", avg_distance_joint_microwave, self.total_timesteps)
                            self.summary_writer.add_scalar("TrainingSuccess", avg_success, self.total_timesteps)
                            self.summary_writer.add_scalar("TrainingDistance", avg_distance_total, self.total_timesteps)

                        wandb.log({'DistanceToHinge':avg_distance_to_hinge, 'timesteps':self.total_timesteps,  'num_labels_queried':self.num_labels_queried})
                        wandb.log({'DistanceToSlide':avg_distance_to_slide, 'timesteps':self.total_timesteps,  'num_labels_queried':self.num_labels_queried})
                        wandb.log({'DistanceToMicrowave':avg_distance_to_microwave, 'timesteps':self.total_timesteps,  'num_labels_queried':self.num_labels_queried})
                        wandb.log({'DistanceJointSlide':avg_distance_joint_slide, 'timesteps':self.total_timesteps, 'num_labels_queried':self.num_labels_queried})
                        wandb.log({'DistanceJointHinge':avg_distance_joint_hinge, 'timesteps':self.total_timesteps, 'num_labels_queried':self.num_labels_queried})
                        wandb.log({'DistanceJointMicrowave':avg_distance_joint_microwave, 'timesteps':self.total_timesteps, 'num_labels_queried':self.num_labels_queried})
                        wandb.log({'TrainingSuccess':avg_success, 'timesteps':self.total_timesteps, 'num_labels_queried':self.num_labels_queried})
                        wandb.log({'TrainingDistance':avg_distance_total, 'timesteps':self.total_timesteps,  'num_labels_queried':self.num_labels_queried})

                    traj_accumulated_states = []
                    traj_accumulated_actions = []
                    traj_accumulated_goal_states = []
                    traj_accumulated_desired_goal_states = []

                total_timesteps += self.max_path_length*self.num_envs
                timesteps_since_train += self.max_path_length
                timesteps_since_eval += self.max_path_length
                
                ranger.update(self.max_path_length)
                
                # Take training steps
                #print(f"timesteps since train {timesteps_since_train}, train policy freq {self.train_policy_freq}, total_timesteps {total_timesteps}, start policy timesteps {self.start_policy_timesteps}")
                if full_iters % self.train_policy_freq == 0 and total_timesteps >= self.start_policy_timesteps:
                    timesteps_since_train %= self.train_policy_freq
                    self.policy.train()
                    for idx in range(int(self.policy_updates_per_step*self.train_policy_freq)):
                        loss = self.take_policy_step()
                        validation_loss, rewardmodel_val_loss = self.validation_loss()

                        if running_loss is None:
                            running_loss = loss
                        else:
                            running_loss = 0.9 * running_loss + 0.1 * loss

                        if running_validation_loss is None:
                            running_validation_loss = validation_loss
                        else:
                            running_validation_loss = 0.9 * running_validation_loss + 0.1 * validation_loss

                        if rewardmodel_running_val_loss is None:
                            rewardmodel_running_val_loss = rewardmodel_val_loss
                        else:
                            rewardmodel_running_val_loss = 0.9 * rewardmodel_running_val_loss + 0.1 * rewardmodel_val_loss

                    self.policy.eval()
                    ranger.set_description('Loss: %s Validation Loss: %s'%(running_loss, running_validation_loss))
                    
                    if self.summary_writer:
                        self.summary_writer.add_scalar('Losses/Train', running_loss, total_timesteps)
                        self.summary_writer.add_scalar('Losses/Validation', running_validation_loss, total_timesteps)
                        self.summary_writer.add_scalar('LossesRewardModel/Eval', rewardmodel_running_val_loss, total_timesteps)
                    wandb.log({'Losses/Train':running_loss, 'timesteps':total_timesteps,  'num_labels_queried':self.num_labels_queried})
                    wandb.log({'Losses/Validation':running_validation_loss, 'timesteps':total_timesteps, 'num_labels_queried':self.num_labels_queried})
                    wandb.log({'LossesRewardModel/Eval':rewardmodel_running_val_loss, 'timesteps':total_timesteps, 'num_labels_queried':self.num_labels_queried})
                    
                    self.train_loss_arr.append((running_loss, total_timesteps))
                    self.eval_loss_arr.append((running_validation_loss, total_timesteps))
                    self.train_loss_rewardmodel_arr.append((rewardmodel_running_val_loss, total_timesteps))

                
                # Evaluate, log, and save to disk
                if timesteps_since_eval >= self.eval_freq:
                    timesteps_since_eval %= self.eval_freq
                    iteration += 1
                    # Evaluation Code
                    self.policy.eval()
                    self.evaluate_policy(self.eval_episodes, total_timesteps=total_timesteps, greedy=True, prefix='Eval', plots_folder=plots_folder)
                    _, _, goals, _, _, _ = self.replay_buffer.sample_batch(self.eval_episodes)
                    self.evaluate_policy_requested(goals, total_timesteps=total_timesteps, greedy=True, prefix='EvalRequested', plots_folder=plots_folder_requested)

                    logger.record_tabular('policy loss', running_loss or 0) # Handling None case

                    #if iteration % 10 == 0:
                    #    self.full_grid_evaluation(iteration)

                    if self.train_with_hallucination:
                        
                        if self.store_model:
                            torch.save(self.reward_model.state_dict(), f'reward_models/reward_model_{dt_string}.pth')
                
                    if self.logger_dump:
                        logger.record_tabular('reward model train loss', 0)
                        logger.record_tabular('reward model eval loss', 0)
                            
                        logger.record_tabular('timesteps', total_timesteps)
                        logger.record_tabular('epoch time (s)', time.time() - last_time)
                        logger.record_tabular('total time (s)', time.time() - start_time)
                        last_time = time.time()
                        logger.dump_tabular()

                        
                        # Logging Code
                        if logger.get_snapshot_dir():
                            modifier = str(iteration) if self.save_every_iteration else ''
                            torch.save(
                                self.policy.state_dict(),
                                osp.join(logger.get_snapshot_dir(), 'policy%s.pkl'%modifier)
                            )
                            if hasattr(self.replay_buffer, 'state_dict'):
                                with open(osp.join(logger.get_snapshot_dir(), 'buffer%s.pkl'%modifier), 'wb') as f:
                                    pickle.dump(self.replay_buffer.state_dict(), f)

                            full_dict = dict(env=self.env, policy=self.policy)
                            with open(osp.join(logger.get_snapshot_dir(), 'params%s.pkl'%modifier), 'wb') as f:
                                pickle.dump(full_dict, f)
                        
                        ranger.reset()
                        
                    
    def evaluate_policy(self, eval_episodes=200, greedy=True, prefix='Eval', total_timesteps=0, plots_folder="plots"):
        """Roll out the current policy and report final-distance / success metrics.

        Makes ``eval_episodes`` noise-free calls to ``self.sample_trajectory``.
        Each call is treated as a batch of ``self.num_envs`` trajectories; for
        every trajectory the distance between its last state and its goal is
        computed with ``self.env.goal_distance``. Per call, the mean distance
        and the fraction of trajectories closer than ``self.goal_threshold``
        are stored, and the averages over all calls are printed and sent to
        the tabular logger, the summary writer (if any) and wandb.

        Args:
            eval_episodes: Number of ``sample_trajectory`` calls to make.
            greedy: Forwarded unchanged to ``sample_trajectory``.
            prefix: Label prepended to every printed / logged metric name.
            total_timesteps: Step value attached to the logged metrics; also
                embedded in the video and plot file names.
            plots_folder: Folder used for the trajectory plot that is written
                when ``self.display_plots`` is truthy.

        Returns:
            ``(all_states, all_goal_states)``: the per-environment state
            trajectories and goals from every call, stacked with ``np.stack``.
        """
        print("Evaluate policy")
        env = self.env

        all_states = []
        all_goal_states = []
        final_dist_vec = np.zeros(eval_episodes)
        success_vec = np.zeros(eval_episodes)

        # The file name does not depend on the episode, so build it once.
        # Only the first episode is asked to save a video.
        video_filename = f"eval_traj_{total_timesteps}"

        for index in tqdm.trange(eval_episodes, leave=True):
            # The returned actions are not needed for any metric below.
            states, _, goal_state, _ = self.sample_trajectory(
                noise=0,
                greedy=greedy,
                save_video_trajectory=index == 0,
                video_filename=video_filename,
            )
            final_dist = []
            for i in range(self.num_envs):
                all_states.append(states[i])
                all_goal_states.append(goal_state[i])
                # TODO: should we compute shaped distance?
                final_dist.append(env.goal_distance(states[i, -1], goal_state[i]))

            final_dist = np.array(final_dist)
            final_dist_vec[index] = np.mean(final_dist)
            success_vec[index] = np.mean(final_dist < self.goal_threshold)

        all_states = np.stack(all_states)
        all_goal_states = np.stack(all_goal_states)

        # Aggregate once and reuse for every sink below.
        mean_final_dist = np.mean(final_dist_vec)
        mean_success = np.mean(success_vec)

        # NOTE: the printed episode count is the number of stacked goals
        # (one per environment per call), while the tabular logger records
        # ``eval_episodes``.
        print('%s num episodes'%prefix, len(all_goal_states))
        print('%s avg final dist'%prefix,  mean_final_dist)
        print('%s success ratio'%prefix, mean_success)

        logger.record_tabular('%s num episodes'%prefix, eval_episodes)
        logger.record_tabular('%s avg final dist'%prefix,  mean_final_dist)
        logger.record_tabular('%s success ratio'%prefix, mean_success)
        if self.summary_writer:
            self.summary_writer.add_scalar('%s/avg final dist'%prefix, mean_final_dist, total_timesteps)
            self.summary_writer.add_scalar('%s/success ratio'%prefix,  mean_success, total_timesteps)

        # Kept as two separate wandb.log calls, exactly as before.
        wandb.log({'%s/avg final dist'%prefix:mean_final_dist, 'timesteps':total_timesteps, 'num_labels_queried':self.num_labels_queried})
        wandb.log({'%s/success ratio'%prefix:mean_success, 'timesteps':total_timesteps, 'num_labels_queried':self.num_labels_queried})

        self.success_ratio_eval_arr.append((mean_success, total_timesteps))
        self.distance_to_goal_eval_arr.append((mean_final_dist, total_timesteps))

        diagnostics = env.get_diagnostics(all_states, all_goal_states)
        for key, value in diagnostics.items():
            print('%s %s'%(prefix, key), value)
            logger.record_tabular('%s %s'%(prefix, key), value)

        if self.display_plots:
            self.plot_trajectories(all_states, all_goal_states, extract=False, filename=f'{plots_folder}/eval_%d.png'%total_timesteps)

        return all_states, all_goal_states


    def display_wall(self):
        """Plot the room's walls with pyplot.

        The walls are read from ``self.env.wrapped_env.base_env.room.get_walls()``;
        each one is unpacked as a pair of ``(x, y)`` end points and drawn as a
        blue segment with circle markers via ``plt.plot``.
        """
        room = self.env.wrapped_env.base_env.room
        for (x_from, y_from), (x_to, y_to) in room.get_walls():
            plt.plot([x_from, x_to], [y_from, y_to], marker='o',  color = 'b')
    def display_wall_pusher_hard(self):
        """Plot a fixed square outline with pyplot.

        The square has corners at x = -0.025 / 0.025 and y = 0.575 / 0.625.
        Its four sides are drawn one at a time as blue segments with circle
        markers. NOTE(review): presumably this marks the obstacle in the
        hard pusher environment -- confirm against the env definition.
        """
        # Corners listed in drawing order; each side joins a corner to the
        # next one, and the last side closes the loop back to the first.
        corners = [
            (-0.025, 0.625),
            (0.025, 0.625),
            (0.025, 0.575),
            (-0.025, 0.575),
        ]
        following = corners[1:] + corners[:1]

        for (x_from, y_from), (x_to, y_to) in zip(corners, following):
            plt.plot([x_from, x_to], [y_from, y_to], marker='o',  color = 'b')
    def display_wall_pusher(self):
        """Plot a fixed square outline with pyplot.

        Draws the four sides of the square spanning x in [-0.025, 0.025] and
        y in [0.575, 0.625], each as a blue segment with circle markers.
        NOTE(review): presumably this marks the obstacle in the pusher
        environment -- confirm against the env definition.
        """
        left, right = -0.025, 0.025
        top, bottom = 0.625, 0.575

        # Sides in drawing order: top, right, bottom, left.
        outline = (
            ((left, top), (right, top)),
            ((right, top), (right, bottom)),
            ((right, bottom), (left, bottom)),
            ((left, bottom), (left, top)),
        )

        for side_start, side_end in outline:
            xs = [side_start[0], side_end[0]]
            ys = [side_start[1], side_end[1]]
            plt.plot(xs, ys, marker='o',  color = 'b')


    def evaluate_policy_requested(self, requested_goals, greedy=True, prefix='Eval', total_timesteps=0, plots_folder="plots"):
        """Evaluate the policy on an explicit collection of goals.

        For every goal in ``requested_goals`` one noise-free call to
        ``self.sample_trajectory`` is made. Each call is treated as a batch of
        ``self.num_envs`` trajectories; the distance between each trajectory's
        last state and its extracted goal is computed with
        ``self.env_distance``. Per goal, the mean distance and the fraction of
        trajectories closer than ``self.goal_threshold`` are stored, and the
        averages over all goals are printed and sent to the summary writer
        (if any) and wandb. Unlike ``evaluate_policy``, nothing is recorded
        in the tabular logger; diagnostics are only printed.

        Args:
            requested_goals: Sized iterable of goals, each passed as the first
                positional argument of ``sample_trajectory``.
            greedy: Forwarded unchanged to ``sample_trajectory``.
            prefix: Label prepended to every printed / logged metric name.
            total_timesteps: Step value attached to the logged metrics; also
                embedded in the plot file name.
            plots_folder: Folder used for the trajectory plot that is written
                when ``self.display_plots`` is truthy.

        Returns:
            ``(all_states, all_goal_states)``: the per-environment state
            trajectories and goals from every call, stacked with ``np.stack``.
        """
        env = self.env

        all_states = []
        all_goal_states = []
        final_dist_vec = np.zeros(len(requested_goals))
        success_vec = np.zeros(len(requested_goals))

        for index, goal in enumerate(requested_goals):
            # The returned actions are not needed for any metric below.
            states, _, goal_state, _ = self.sample_trajectory(goal, noise=0, greedy=greedy)
            final_dist = []
            for i in range(self.num_envs):
                all_states.append(states[i])
                all_goal_states.append(goal_state[i])
                # TODO: should we compute shaped distance?
                final_dist.append(
                    self.env_distance(states[i, -1], self.env.extract_goal(goal_state[i]))
                )
            final_dist = np.array(final_dist)
            final_dist_vec[index] = np.mean(final_dist)
            success_vec[index] = np.mean(final_dist < self.goal_threshold)

        all_states = np.stack(all_states)
        all_goal_states = np.stack(all_goal_states)

        # Aggregate once and reuse for every sink below.
        mean_final_dist = np.mean(final_dist_vec)
        mean_success = np.mean(success_vec)

        print('%s num episodes'%prefix, len(requested_goals))
        print('%s avg final dist relabelled goals'%prefix,  mean_final_dist)
        print('%s success ratio relabelled goals'%prefix, mean_success)

        if self.summary_writer:
            self.summary_writer.add_scalar('%s/avg final dist relabelled goals'%prefix, mean_final_dist, total_timesteps)
            self.summary_writer.add_scalar('%s/success ratio relabelled goals'%prefix,  mean_success, total_timesteps)
        # Kept as two separate wandb.log calls, exactly as before.
        wandb.log({'%s/avg final dist relabelled goals'%prefix:mean_final_dist, 'timesteps':total_timesteps,'num_labels_queried':self.num_labels_queried})
        wandb.log({'%s/success ratio relabelled goals'%prefix:mean_success, 'timesteps':total_timesteps, 'num_labels_queried':self.num_labels_queried})

        self.success_ratio_relabelled_arr.append((mean_success, total_timesteps))
        # The distance history stores the mean final distance; previously the
        # success ratio was appended here by mistake.
        self.distance_to_goal_eval_relabelled.append((mean_final_dist, total_timesteps))

        diagnostics = env.get_diagnostics(all_states, all_goal_states)
        for key, value in diagnostics.items():
            print('%s %s'%(prefix, key), value)

        if self.display_plots:
            self.plot_trajectories(all_states, all_goal_states, extract=False, filename=f'{plots_folder}/eval_requested_%d.png'%total_timesteps)


        return all_states, all_goal_states
