import hydra.utils
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torch.optim as optim
import itertools
import tqdm
import copy
import scipy.stats as st
import os
import time

from scipy.stats import norm


# device = 'cpu'


class SingleRewardModel(nn.Module):
    """MLP reward network with an optional next-state prediction head.

    The network is split into an ``embedding`` trunk (``n_layers`` blocks of
    Linear + LeakyReLU) and a ``model`` head (Linear + output activation), so
    callers can also read the trunk features.

    Args:
        in_size: input feature size.
        out_size: output size of the reward head.
        H: hidden width of every trunk layer.
        n_layers: number of Linear + LeakyReLU trunk blocks (may be 0, in
            which case the embedding is the raw input).
        activation: 'tanh', 'sig' (sigmoid); anything else selects ReLU.
        next_state_predictive: if True, add ``model_next_state``, a Linear map
            from the embedding back to ``in_size`` features.
    """

    def __init__(self, in_size=1, out_size=1, H=128, n_layers=3, activation='tanh', next_state_predictive=False):
        super().__init__()
        self.next_state_predictive = next_state_predictive
        in_size_original = in_size

        trunk = []
        for _ in range(n_layers):
            trunk.append(nn.Linear(in_size, H))
            trunk.append(nn.LeakyReLU())
            in_size = H
        self.embedding = nn.Sequential(*trunk)
        # Width of the embedding: H normally, the raw input size when n_layers == 0.
        embed_size = in_size

        if activation == 'tanh':
            out_act = nn.Tanh()
        elif activation == 'sig':
            out_act = nn.Sigmoid()
        else:
            out_act = nn.ReLU()
        # Same module layout as before ("model.0" is the Linear) so existing
        # checkpoints keep loading.
        self.model = nn.Sequential(nn.Linear(embed_size, out_size), out_act)

        self.model_next_state = None
        if self.next_state_predictive:  # from embedding to next state
            self.model_next_state = nn.Linear(embed_size, in_size_original)

    def forward(self, x, embedding=False, next_state_prediction=False):
        """Compute the reward, and optionally the embedding / next-state prediction.

        Returns:
            ``out`` alone when neither flag is set; otherwise the list
            ``[out, (embedding), (next_state)]`` containing the requested extras
            in that order.

        Raises:
            ValueError: ``next_state_prediction`` requested on a model built
                without ``next_state_predictive=True``.
        """
        e = self.embedding(x)
        out = self.model(e)

        if not embedding and not next_state_prediction:
            return out

        result = [out]
        if embedding:
            result.append(e)
        if next_state_prediction:
            if self.model_next_state is None:
                raise ValueError(
                    'next_state_prediction requested but the model was built '
                    'with next_state_predictive=False')
            result.append(self.model_next_state(e))
        return result


def gen_net(in_size=1, out_size=1, H=128, n_layers=3, activation='tanh'):
    """Return the layer list of an MLP (not wrapped in nn.Sequential).

    Layout: ``n_layers`` x [Linear(width, H), LeakyReLU], then
    Linear(width, out_size) followed by Tanh ('tanh'), Sigmoid ('sig') or
    ReLU (any other value).
    """
    layers = []
    width = in_size
    for _ in range(n_layers):
        layers += [nn.Linear(width, H), nn.LeakyReLU()]
        width = H
    layers.append(nn.Linear(width, out_size))

    if activation == 'tanh':
        final_act = nn.Tanh()
    elif activation == 'sig':
        final_act = nn.Sigmoid()
    else:
        final_act = nn.ReLU()
    layers.append(final_act)
    return layers


def KCenterGreedy(obs, full_obs, num_new_sample, device='cpu'):
    """Greedy k-center selection over the rows of ``obs``.

    At each step the candidate whose nearest "center" is farthest away (L2)
    is picked. The centers are every row of ``full_obs`` plus every candidate
    picked so far. Nearest-center distances are updated incrementally, so
    each step only measures distances to the newly added center instead of
    re-scanning ``full_obs`` plus all previous picks.

    Args:
        obs: (N, D) numpy array of candidate points.
        full_obs: (M, D) numpy array of existing points. It may be empty, in
            which case the first pick is candidate 0.
        num_new_sample: number of candidates to pick; clamped to N.
        device: torch device used for the distance computations.

    Returns:
        List of ``min(num_new_sample, N)`` distinct row indices into ``obs``,
        in pick order.
    """
    num_candidates = obs.shape[0]
    num_new_sample = min(num_new_sample, num_candidates)
    if num_new_sample <= 0:
        return []

    if len(full_obs) == 0:
        nearest = torch.full((num_candidates,), float('inf'), device=device)
    else:
        nearest = compute_smallest_dist(obs, full_obs, device).squeeze(1)

    selected_index = []
    for count in range(num_new_sample):
        # torch.argmax returns the first maximum, i.e. the lowest index on ties.
        max_index = torch.argmax(nearest).item()
        selected_index.append(max_index)
        if count == num_new_sample - 1:
            break
        # The picked candidate becomes a center: fold in distances to it.
        new_center_dist = compute_smallest_dist(obs, obs[[max_index]], device).squeeze(1)
        nearest = torch.minimum(nearest, new_center_dist)
        nearest[max_index] = float('-inf')  # never pick the same candidate twice
    return selected_index


def compute_smallest_dist(obs, full_obs, device, batch_size=100):
    """For each row of ``obs``, the L2 distance to its nearest row of ``full_obs``.

    Distances are computed in (batch_size x batch_size) tiles, so the pairwise
    difference tensor never holds more than ``batch_size**2 * D`` elements.
    ``full_obs`` is moved to ``device`` one tile at a time, so large reference
    sets are never copied to the device all at once.

    Args:
        obs: (N, D) numpy array of query points.
        full_obs: (M, D) numpy array of reference points.
        device: torch device used for the computation.
        batch_size: tile size along both N and M.

    Returns:
        (N, 1) float32 tensor on ``device``. Every entry is ``inf`` when
        ``full_obs`` is empty; the shape is (0, 1) when ``obs`` is empty.
    """
    obs = torch.from_numpy(obs).float()
    full_obs = torch.from_numpy(full_obs).float()
    num_obs, num_full = len(obs), len(full_obs)

    if num_full == 0:
        return torch.full((num_obs, 1), float('inf'), device=device)
    if num_obs == 0:
        return torch.empty((0, 1), device=device)

    total_dists = []
    with torch.no_grad():
        for q_start in range(0, num_obs, batch_size):
            query = obs[q_start:q_start + batch_size].to(device)  # hoisted out of the inner loop
            nearest = None
            for r_start in range(0, num_full, batch_size):
                ref = full_obs[r_start:r_start + batch_size].to(device)
                dist = torch.norm(query[:, None, :] - ref[None, :, :], dim=-1, p=2)
                tile_min = torch.min(dist, dim=1).values
                nearest = tile_min if nearest is None else torch.minimum(nearest, tile_min)
            total_dists.append(nearest)

    return torch.cat(total_dists).unsqueeze(1)


class RewardModel:
    def __init__(self, ds, da,
                 ensemble_size=3, lr=3e-4, mb_size=128, size_segment=1,
                 env_maker=None, max_size=100, activation='tanh', capacity=5e5,
                 large_batch=1, label_margin=0.0,
                 teacher_beta=-1, teacher_gamma=1,
                 teacher_eps_mistake=0,
                 teacher_eps_skip=0,
                 teacher_eps_equal=0,
                 device='cpu',
                 reward_triplet_loss_cfg=None,
                 action_distance_loss_cfg=None,
                 surf_loss_cfg=None,
                 rdynamics_loss_cfg=None,
                 l2embed_loss_cfg=None):

        # train data is trajectories, must process to sa and s..
        self.rdynamics_loss_cfg = rdynamics_loss_cfg  # cfg if rewards can predict next states
        self.rdynamics = hydra.utils.instantiate(
            rdynamics_loss_cfg) if rdynamics_loss_cfg.name != 'none' else None

        self.device = device
        self.ds = ds
        self.da = da
        self.de = ensemble_size
        self.lr = lr
        self.ensemble = []
        self.paramlst = []
        self.opt = None
        self.model = None
        self.max_size = max_size
        self.activation = activation
        self.size_segment = size_segment

        self.capacity = int(capacity)
        self.buffer_seg1 = np.empty((self.capacity, size_segment, self.ds + self.da), dtype=np.float32)
        self.buffer_seg2 = np.empty((self.capacity, size_segment, self.ds + self.da), dtype=np.float32)
        self.buffer_label = np.empty((self.capacity, 1), dtype=np.float32)
        self.buffer_index = 0
        self.buffer_full = False

        self.construct_ensemble()
        self.inputs = []
        self.targets = []
        self.raw_actions = []
        self.img_inputs = []
        self.mb_size = mb_size
        self.origin_mb_size = mb_size
        self.train_batch_size = 128
        self.CEloss = nn.CrossEntropyLoss()
        self.running_means = []
        self.running_stds = []
        self.best_seg = []
        self.best_label = []
        self.best_action = []
        self.large_batch = large_batch

        # new teacher
        self.teacher_beta = teacher_beta
        self.teacher_gamma = teacher_gamma
        self.teacher_eps_mistake = teacher_eps_mistake
        self.teacher_eps_equal = teacher_eps_equal
        self.teacher_eps_skip = teacher_eps_skip
        self.teacher_thres_skip = 0
        self.teacher_thres_equal = 0

        self.label_margin = label_margin
        self.label_target = 1 - 2 * self.label_margin

        self.reward_triplet_loss_cfg = reward_triplet_loss_cfg
        self.reward_triplet_loss = hydra.utils.instantiate(
            reward_triplet_loss_cfg) if reward_triplet_loss_cfg.name != 'none' else None

        self.action_distance_loss_cfg = action_distance_loss_cfg
        self.action_distance_loss = hydra.utils.instantiate(
            action_distance_loss_cfg) if action_distance_loss_cfg.name != 'none' else None

        self.l2embed_loss_cfg = l2embed_loss_cfg
        self.l2embed_loss = hydra.utils.instantiate(
            l2embed_loss_cfg) if l2embed_loss_cfg.name != 'none' else None



        self.surf_loss = surf_loss_cfg  # TODO. configure via hydra

    def softXEnt_loss(self, input, target):
        """Cross entropy against soft targets, averaged over the batch.

        Args:
            input: logits; the softmax is taken over dim 1.
            target: target probabilities with the same shape as ``input``.

        Returns:
            Scalar tensor ``-sum(target * log_softmax(input)) / input.shape[0]``.
        """
        log_probs = F.log_softmax(input, dim=1)
        total = (target * log_probs).sum()
        return -total / input.shape[0]

    def change_batch(self, new_frac):
        self.mb_size = int(self.origin_mb_size * new_frac)

    def set_batch(self, new_batch):
        self.mb_size = int(new_batch)

    def set_teacher_thres_skip(self, new_margin):
        self.teacher_thres_skip = new_margin * self.teacher_eps_skip

    def set_teacher_thres_equal(self, new_margin):
        self.teacher_thres_equal = new_margin * self.teacher_eps_equal

    def construct_ensemble(self):
        """Append ``self.de`` reward networks to the ensemble and build one Adam optimizer.

        Each member is a SingleRewardModel over ``ds + da`` inputs (3 hidden
        layers of width 256). It gets a next-state head only when
        ``self.rdynamics`` is set. The optimizer covers the parameters of every
        member collected in ``self.paramlst``.
        """
        predicts_dynamics = self.rdynamics is not None
        for _ in range(self.de):
            member_net = SingleRewardModel(
                in_size=self.ds + self.da,
                out_size=1,
                H=256,
                n_layers=3,
                activation=self.activation,
                next_state_predictive=predicts_dynamics,
            )
            member_net = member_net.float().to(self.device)
            self.ensemble.append(member_net)
            self.paramlst.extend(member_net.parameters())

        self.opt = torch.optim.Adam(self.paramlst, lr=self.lr)

    def add_data(self, obs, act, rew, done):
        sa_t = np.concatenate([obs, act], axis=-1)
        r_t = rew

        flat_input = sa_t.reshape(1, self.da + self.ds)
        r_t = np.array(r_t)
        flat_target = r_t.reshape(1, 1)

        init_data = len(self.inputs) == 0
        if init_data:
            self.inputs.append(flat_input)
            self.targets.append(flat_target)
        elif done:
            self.inputs[-1] = np.concatenate([self.inputs[-1], flat_input])
            self.targets[-1] = np.concatenate([self.targets[-1], flat_target])
            # FIFO
            if len(self.inputs) > self.max_size:
                self.inputs = self.inputs[1:]
                self.targets = self.targets[1:]
            self.inputs.append([])
            self.targets.append([])
        else:
            if len(self.inputs[-1]) == 0:
                self.inputs[-1] = flat_input
                self.targets[-1] = flat_target
            else:
                self.inputs[-1] = np.concatenate([self.inputs[-1], flat_input])
                self.targets[-1] = np.concatenate([self.targets[-1], flat_target])

    def add_data_batch(self, obses, rewards):
        num_env = obses.shape[0]
        for index in range(num_env):
            self.inputs.append(obses[index])
            self.targets.append(rewards[index])

    def get_rank_probability(self, x_1, x_2):
        # get probability x_1 > x_2
        probs = []
        for member in range(self.de):
            probs.append(self.p_hat_member(x_1, x_2, member=member).cpu().numpy())
        probs = np.array(probs)

        return np.mean(probs, axis=0), np.std(probs, axis=0)

    def get_entropy(self, x_1, x_2):
        # get probability x_1 > x_2
        probs = []
        for member in range(self.de):
            probs.append(self.p_hat_entropy(x_1, x_2, member=member).cpu().numpy())
        probs = np.array(probs)
        return np.mean(probs, axis=0), np.std(probs, axis=0)

    def p_hat_member(self, x_1, x_2, member=-1):
        # softmaxing to get the probabilities according to eqn 1
        with torch.no_grad():
            r_hat1 = self.r_hat_member(x_1, member=member)
            r_hat2 = self.r_hat_member(x_2, member=member)
            r_hat1 = r_hat1.sum(axis=1)
            r_hat2 = r_hat2.sum(axis=1)
            r_hat = torch.cat([r_hat1, r_hat2], axis=-1)

        # taking 0 index for probability x_1 > x_2
        return F.softmax(r_hat, dim=-1)[:, 0]

    def p_hat_entropy(self, x_1, x_2, member=-1):
        # softmaxing to get the probabilities according to eqn 1
        with torch.no_grad():
            r_hat1 = self.r_hat_member(x_1, member=member)
            r_hat2 = self.r_hat_member(x_2, member=member)
            r_hat1 = r_hat1.sum(axis=1)
            r_hat2 = r_hat2.sum(axis=1)
            r_hat = torch.cat([r_hat1, r_hat2], axis=-1)

        ent = F.softmax(r_hat, dim=-1) * F.log_softmax(r_hat, dim=-1)
        ent = ent.sum(axis=-1).abs()
        return ent

    def r_hat_member(self, x, member=-1, embedding=False, next_state_prediction=False):
        # the network parameterizes r hat in eqn 1 from the paper
        return self.ensemble[member](torch.from_numpy(x).float().to(self.device), embedding=embedding,
                                     next_state_prediction=next_state_prediction)

    def r_hat(self, x):
        # they say they average the rewards from each member of the ensemble, but I think this only makes sense if the rewards are already normalized
        # but I don't understand how the normalization should be happening right now :(
        r_hats = []
        for member in range(self.de):
            r_hats.append(self.r_hat_member(x, member=member).detach().cpu().numpy())
        r_hats = np.array(r_hats)
        return np.mean(r_hats)

    def r_hat_batch(self, x):
        # they say they average the rewards from each member of the ensemble, but I think this only makes sense if the rewards are already normalized
        # but I don't understand how the normalization should be happening right now :(
        r_hats = []
        for member in range(self.de):
            r_hats.append(self.r_hat_member(x, member=member).detach().cpu().numpy())
        r_hats = np.array(r_hats)

        return np.mean(r_hats, axis=0)

    def save(self, model_dir, step):
        for member in range(self.de):
            torch.save(
                self.ensemble[member].state_dict(), '%s/reward_model_%s_%s.pt' % (model_dir, step, member)
            )

    def load(self, model_dir, step):
        for member in range(self.de):
            self.ensemble[member].load_state_dict(
                torch.load('%s/reward_model_%s_%s.pt' % (model_dir, step, member))
            )

    def get_train_acc(self):
        ensemble_acc = np.array([0 for _ in range(self.de)])
        max_len = self.capacity if self.buffer_full else self.buffer_index
        total_batch_index = np.random.permutation(max_len)
        batch_size = 256
        num_epochs = int(np.ceil(max_len / batch_size))

        total = 0
        for epoch in range(num_epochs):
            last_index = (epoch + 1) * batch_size
            if (epoch + 1) * batch_size > max_len:
                last_index = max_len

            sa_t_1 = self.buffer_seg1[epoch * batch_size:last_index]
            sa_t_2 = self.buffer_seg2[epoch * batch_size:last_index]
            labels = self.buffer_label[epoch * batch_size:last_index]
            labels = torch.from_numpy(labels.flatten()).long().to(self.device)
            total += labels.size(0)
            for member in range(self.de):
                # get logits
                r_hat1 = self.r_hat_member(sa_t_1, member=member)
                r_hat2 = self.r_hat_member(sa_t_2, member=member)
                r_hat1 = r_hat1.sum(axis=1)
                r_hat2 = r_hat2.sum(axis=1)
                r_hat = torch.cat([r_hat1, r_hat2], axis=-1)
                _, predicted = torch.max(r_hat.data, 1)
                correct = (predicted == labels).sum().item()
                ensemble_acc[member] += correct

        ensemble_acc = ensemble_acc / total
        return np.mean(ensemble_acc)

    def _get_max_len_trajectories(self):
        len_traj, max_len = len(self.inputs[0]), len(self.inputs)

        if len(self.inputs[-1]) < len_traj:
            max_len = max_len - 1

        # TODO. The issue I found is that the self.inputs is a list of np.ndarray and np.array(self.inputs[:max_len]) doesn't work since all the elements of self.inputs are not of the same length.
        # the original code first tries to create np.array of var length (not possible -> and causes err)
        # select the sa1 and sa2
        # and then trim the selected trajs to be of self.size_segment

        # filter the self.inputs to get
        # a. Arrays with length of exactly size_segment
        # b. Arrays start index are random.

        t_inputs, t_targets = [], []
        for xi, xt in zip(self.inputs[:max_len], self.targets[:max_len]):
            if (xi.shape[0] < self.size_segment):
                continue

            if xi.shape[0] != self.size_segment:
                start = np.random.randint(xi.shape[0] - self.size_segment)
            else:
                start = 0
            t_inputs.append(xi[start:start + self.size_segment])
            t_targets.append(xt[start:start + self.size_segment])

        return t_inputs, t_targets

    def get_queries(self, mb_size=20):

        img_t_1, img_t_2 = None, None

        t_inputs, t_targets = self._get_max_len_trajectories()

        max_len = len(t_inputs)
        # get train traj
        train_inputs = np.array(t_inputs)
        train_targets = np.array(t_targets)
        # train_inputs = np.array(self.inputs[:max_len])
        # train_targets = np.array(self.targets[:max_len])

        sample_size = min(mb_size, max_len)
        batch_index_2 = np.random.choice(max_len, size=sample_size, replace=True)
        sa_t_2 = train_inputs[batch_index_2]  # Batch x T x dim of s&a
        r_t_2 = train_targets[batch_index_2]  # Batch x T x 1

        batch_index_1 = np.random.choice(max_len, size=sample_size, replace=True)
        sa_t_1 = train_inputs[batch_index_1]  # Batch x T x dim of s&a
        r_t_1 = train_targets[batch_index_1]  # Batch x T x 1

        # sa_t_1 = sa_t_1.reshape(-1, sa_t_1.shape[-1])  # (Batch x T) x dim of s&a
        # r_t_1 = r_t_1.reshape(-1, r_t_1.shape[-1])  # (Batch x T) x 1
        # sa_t_2 = sa_t_2.reshape(-1, sa_t_2.shape[-1])  # (Batch x T) x dim of s&a
        # r_t_2 = r_t_2.reshape(-1, r_t_2.shape[-1])  # (Batch x T) x 1
        #
        # # Generate time index
        # time_index = np.array([list(range(i * len_traj,
        #                                   i * len_traj + self.size_segment)) for i in range(mb_size)])
        # time_index_2 = time_index + np.random.choice(len_traj - self.size_segment, size=mb_size, replace=True).reshape(
        #     -1, 1)
        # time_index_1 = time_index + np.random.choice(len_traj - self.size_segment, size=mb_size, replace=True).reshape(
        #     -1, 1)
        #
        # sa_t_1 = np.take(sa_t_1, time_index_1, axis=0)  # Batch x size_seg x dim of s&a
        # r_t_1 = np.take(r_t_1, time_index_1, axis=0)  # Batch x size_seg x 1
        # sa_t_2 = np.take(sa_t_2, time_index_2, axis=0)  # Batch x size_seg x dim of s&a
        # r_t_2 = np.take(r_t_2, time_index_2, axis=0)  # Batch x size_seg x 1

        return sa_t_1, sa_t_2, r_t_1, r_t_2

    def put_queries(self, sa_t_1, sa_t_2, labels):
        total_sample = sa_t_1.shape[0]
        next_index = self.buffer_index + total_sample
        if next_index >= self.capacity:
            self.buffer_full = True
            maximum_index = self.capacity - self.buffer_index
            np.copyto(self.buffer_seg1[self.buffer_index:self.capacity], sa_t_1[:maximum_index])
            np.copyto(self.buffer_seg2[self.buffer_index:self.capacity], sa_t_2[:maximum_index])
            np.copyto(self.buffer_label[self.buffer_index:self.capacity], labels[:maximum_index])

            remain = total_sample - (maximum_index)
            if remain > 0:
                np.copyto(self.buffer_seg1[0:remain], sa_t_1[maximum_index:])
                np.copyto(self.buffer_seg2[0:remain], sa_t_2[maximum_index:])
                np.copyto(self.buffer_label[0:remain], labels[maximum_index:])

            self.buffer_index = remain
        else:
            np.copyto(self.buffer_seg1[self.buffer_index:next_index], sa_t_1)
            np.copyto(self.buffer_seg2[self.buffer_index:next_index], sa_t_2)
            np.copyto(self.buffer_label[self.buffer_index:next_index], labels)
            self.buffer_index = next_index

    def get_label(self, sa_t_1, sa_t_2, r_t_1, r_t_2):
        """Label query pairs with a scripted teacher built from ground-truth rewards.

        Steps, in order:
        1. Optionally drop pairs where neither segment's return exceeds
           ``teacher_thres_skip`` (the skip step).
        2. Mark near-ties using the undiscounted returns.
        3. Decide the preference from gamma-weighted returns, either
           deterministically or, when ``teacher_beta > 0``, by Bradley-Terry
           sampling.
        4. Flip labels with probability ~``teacher_eps_mistake``.
        5. Overwrite near-ties with -1.

        Args:
            sa_t_1, sa_t_2: segment inputs, indexed by pair along axis 0.
            r_t_1, r_t_2: ground-truth rewards of those segments. Axis 1 is
                summed over, so it is the time axis. Presumably shaped
                (B, T, 1) as produced by get_queries — TODO confirm.

        Returns:
            (sa_t_1, sa_t_2, r_t_1, r_t_2, labels), restricted to the pairs that
            were not skipped. ``labels`` is 1 where segment 2 is preferred, 0
            where segment 1 is preferred, and -1 for "equally preferable".
            If every pair is skipped, returns ``(None, None, None, None, [])``.
        """
        sum_r_t_1 = np.sum(r_t_1, axis=1)
        sum_r_t_2 = np.sum(r_t_2, axis=1)

        # skip the query: keep only pairs whose better (undiscounted) return
        # exceeds teacher_thres_skip
        if self.teacher_thres_skip > 0:
            max_r_t = np.maximum(sum_r_t_1, sum_r_t_2)
            max_index = (max_r_t > self.teacher_thres_skip).reshape(-1)
            if sum(max_index) == 0:
                return None, None, None, None, []

            sa_t_1 = sa_t_1[max_index]
            sa_t_2 = sa_t_2[max_index]
            r_t_1 = r_t_1[max_index]
            r_t_2 = r_t_2[max_index]
            sum_r_t_1 = np.sum(r_t_1, axis=1)
            sum_r_t_2 = np.sum(r_t_2, axis=1)

        # equally preferable: computed from the undiscounted returns; applied
        # at the very end so it overrides any sampled / flipped label
        margin_index = (np.abs(sum_r_t_1 - sum_r_t_2) < self.teacher_thres_equal).reshape(-1)

        # perfectly rational
        # After this loop, step t of a length-T segment has been multiplied by
        # teacher_gamma ** (T - 1 - t): with gamma < 1, later steps weigh more.
        seg_size = r_t_1.shape[1]
        temp_r_t_1 = r_t_1.copy()
        temp_r_t_2 = r_t_2.copy()
        for index in range(seg_size - 1):
            temp_r_t_1[:, :index + 1] *= self.teacher_gamma
            temp_r_t_2[:, :index + 1] *= self.teacher_gamma
        sum_r_t_1 = np.sum(temp_r_t_1, axis=1)
        sum_r_t_2 = np.sum(temp_r_t_2, axis=1)

        # 1 where segment 2 has the strictly larger weighted return
        rational_labels = 1 * (sum_r_t_1 < sum_r_t_2)
        if self.teacher_beta > 0:  # Bradley-Terry rational model
            r_hat = torch.cat([torch.Tensor(sum_r_t_1),
                               torch.Tensor(sum_r_t_2)], axis=-1)
            r_hat = r_hat * self.teacher_beta
            # probability that segment 2 is preferred; sample the label from it
            ent = F.softmax(r_hat, dim=-1)[:, 1]
            labels = torch.bernoulli(ent).int().numpy().reshape(-1, 1)
        else:
            labels = rational_labels

        # making a mistake: flip each label when a uniform draw is
        # <= teacher_eps_mistake (this consumes one np.random.rand per pair)
        len_labels = labels.shape[0]
        rand_num = np.random.rand(len_labels)
        noise_index = rand_num <= self.teacher_eps_mistake
        labels[noise_index] = 1 - labels[noise_index]

        # equally preferable
        labels[margin_index] = -1

        return sa_t_1, sa_t_2, r_t_1, r_t_2, labels

    def kcenter_sampling(self):
        """Label up to ``mb_size`` queries chosen by greedy k-center over state features.

        Draws ``mb_size * large_batch`` candidate pairs, then keeps the pairs
        whose (state-only) features are farthest from every pair already in the
        label buffer.

        Returns:
            Number of labelled queries added to the buffer.
        """
        sa_t_1, sa_t_2, r_t_1, r_t_2 = self.get_queries(
            mb_size=self.mb_size * self.large_batch)
        return self._kcenter_label_and_store(sa_t_1, sa_t_2, r_t_1, r_t_2)

    def kcenter_disagree_sampling(self):
        """Like kcenter_sampling, but first keeps the half of the candidates
        with the highest ensemble disagreement (std of the preference probability).

        Returns:
            Number of labelled queries added to the buffer.
        """
        num_init = self.mb_size * self.large_batch
        num_init_half = int(num_init * 0.5)

        sa_t_1, sa_t_2, r_t_1, r_t_2 = self.get_queries(mb_size=num_init)
        if len(sa_t_1) == 0:
            return 0

        # get final queries based on uncertainty
        _, disagree = self.get_rank_probability(sa_t_1, sa_t_2)
        top_k_index = (-disagree).argsort()[:num_init_half]
        return self._kcenter_label_and_store(
            sa_t_1[top_k_index], sa_t_2[top_k_index],
            r_t_1[top_k_index], r_t_2[top_k_index])

    def kcenter_entropy_sampling(self):
        """Like kcenter_sampling, but first keeps the half of the candidates
        with the highest mean preference entropy.

        Returns:
            Number of labelled queries added to the buffer.
        """
        num_init = self.mb_size * self.large_batch
        num_init_half = int(num_init * 0.5)

        sa_t_1, sa_t_2, r_t_1, r_t_2 = self.get_queries(mb_size=num_init)
        if len(sa_t_1) == 0:
            return 0

        # get final queries based on uncertainty
        entropy, _ = self.get_entropy(sa_t_1, sa_t_2)
        top_k_index = (-entropy).argsort()[:num_init_half]
        return self._kcenter_label_and_store(
            sa_t_1[top_k_index], sa_t_2[top_k_index],
            r_t_1[top_k_index], r_t_2[top_k_index])

    def _kcenter_label_and_store(self, sa_t_1, sa_t_2, r_t_1, r_t_2):
        """Pick up to ``mb_size`` candidates by k-center, label them, store them; return the count.

        Uses the actual number of candidates rather than the requested batch
        size. get_queries may return fewer rows, and the uncertainty variants
        pass only half the pool, so the number of k-center picks is clamped to
        the pool size.
        """
        num_candidates = sa_t_1.shape[0]
        if num_candidates == 0:
            return 0

        # Each pair is one feature vector: both segments' states (actions
        # dropped), flattened over time and concatenated.
        temp_sa = np.concatenate([self._kcenter_state_features(sa_t_1),
                                  self._kcenter_state_features(sa_t_2)], axis=1)

        max_len = self.capacity if self.buffer_full else self.buffer_index
        tot_sa = np.concatenate([self._kcenter_state_features(self.buffer_seg1[:max_len]),
                                 self._kcenter_state_features(self.buffer_seg2[:max_len])], axis=1)

        num_select = min(self.mb_size, num_candidates)
        selected_index = KCenterGreedy(temp_sa, tot_sa, num_select, self.device)

        r_t_1, sa_t_1 = r_t_1[selected_index], sa_t_1[selected_index]
        r_t_2, sa_t_2 = r_t_2[selected_index], sa_t_2[selected_index]

        # get labels
        sa_t_1, sa_t_2, r_t_1, r_t_2, labels = self.get_label(
            sa_t_1, sa_t_2, r_t_1, r_t_2)

        if len(labels) > 0:
            self.put_queries(sa_t_1, sa_t_2, labels)

        return len(labels)

    def _kcenter_state_features(self, segments):
        """(B, T, ds + da) -> (B, T * ds): keep the state part and flatten time.

        Explicit target dimensions (no -1) so an empty buffer slice reshapes
        to (0, T * ds) instead of raising.
        """
        states = segments[:, :, :self.ds]
        return states.reshape(states.shape[0], states.shape[1] * states.shape[2])

    def uniform_sampling(self):
        # get queries
        sa_t_1, sa_t_2, r_t_1, r_t_2 = self.get_queries(
            mb_size=self.mb_size)

        # get labels
        sa_t_1, sa_t_2, r_t_1, r_t_2, labels = self.get_label(
            sa_t_1, sa_t_2, r_t_1, r_t_2)

        if len(labels) > 0:
            self.put_queries(sa_t_1, sa_t_2, labels)

        return len(labels)

    def disagreement_sampling(self):

        # get queries
        sa_t_1, sa_t_2, r_t_1, r_t_2 = self.get_queries(
            mb_size=self.mb_size * self.large_batch)

        # get final queries based on uncertainty
        _, disagree = self.get_rank_probability(sa_t_1, sa_t_2)
        top_k_index = (-disagree).argsort()[:self.mb_size]
        r_t_1, sa_t_1 = r_t_1[top_k_index], sa_t_1[top_k_index]
        r_t_2, sa_t_2 = r_t_2[top_k_index], sa_t_2[top_k_index]

        # get labels
        sa_t_1, sa_t_2, r_t_1, r_t_2, labels = self.get_label(
            sa_t_1, sa_t_2, r_t_1, r_t_2)
        if len(labels) > 0:
            self.put_queries(sa_t_1, sa_t_2, labels)

        return len(labels)

    def entropy_sampling(self):

        # get queries
        sa_t_1, sa_t_2, r_t_1, r_t_2 = self.get_queries(
            mb_size=self.mb_size * self.large_batch)

        # get final queries based on uncertainty
        entropy, _ = self.get_entropy(sa_t_1, sa_t_2)

        top_k_index = (-entropy).argsort()[:self.mb_size]
        r_t_1, sa_t_1 = r_t_1[top_k_index], sa_t_1[top_k_index]
        r_t_2, sa_t_2 = r_t_2[top_k_index], sa_t_2[top_k_index]

        # get labels
        sa_t_1, sa_t_2, r_t_1, r_t_2, labels = self.get_label(
            sa_t_1, sa_t_2, r_t_1, r_t_2)

        if len(labels) > 0:
            self.put_queries(sa_t_1, sa_t_2, labels)

        return len(labels)

    def compute_surf_loss(self, member):
        UNLABELED_SIZE = 128
        tau = 0.95
        # get trajectory data
        sa_1, sa_2, _, _ = self.get_queries(mb_size=64)
        r_hat1_vec = self.r_hat_member(sa_1, member=member)
        r_hat2_vec = self.r_hat_member(sa_1, member=member)

        r_hat1 = r_hat1_vec.sum(axis=1)
        r_hat2 = r_hat2_vec.sum(axis=1)
        r_hat = torch.cat([r_hat1, r_hat2], axis=-1)

        max_probs, max_indices = torch.max(torch.softmax(r_hat, dim=-1), dim=-1)
        selected_indxs = max_probs > tau
        r_hat = r_hat[selected_indxs]  # need to filter the data according to tau
        labels = max_indices[selected_indxs]

        if r_hat.numel():  # True if tensor is empty
            return 0
        else:
            # compute loss
            surf_loss = self.CEloss(r_hat, labels)
            return surf_loss

    def compute_triplet_loss(self, t1, t2, labels, r1, r2, member):
        """Triplet loss with random stored segments as anchors ("Method 2").

        Anchors are member ``member``'s reward predictions on up to
        ``reward_triplet_loss_cfg.batch_size`` windows sampled (with
        replacement) from _get_max_len_trajectories. Positives are the
        preferred-segment predictions: ``r1`` where label == 0 and ``r2`` where
        label == 1. Negatives are the dispreferred ones. If the anchors are, in
        aggregate, farther from the positives than from the negatives, the two
        sets are swapped before calling ``self.reward_triplet_loss``.

        Args:
            t1, t2: not used in this body.
            labels: preference labels, compared with 0 / 1 to build boolean
                masks over ``r1`` / ``r2``. Presumably a tensor of shape (B,)
                — TODO confirm.
            r1, r2: presumably member reward predictions for the two segments,
                (B, T, 1) tensors — TODO confirm.
            member: ensemble member used for the anchors.

        Returns:
            0 when fewer than 2 trajectories are stored, otherwise the scalar
            loss.
        """
        # anchor trajectories come from self.inputs (via _get_max_len_trajectories)
        if len(self.inputs) < 2:
            return 0
        num_anchor_samples = self.reward_triplet_loss_cfg.batch_size  # take these many samples from trajectory buffer.
        t_inputs, _ = self._get_max_len_trajectories()
        t_inputs = np.array(t_inputs)
        sample_size = min(num_anchor_samples, len(t_inputs))
        # NOTE(review): if no window survives _get_max_len_trajectories,
        # sample_size is 0 and the anchor forward pass below receives an empty
        # array — confirm this cannot happen once len(self.inputs) >= 2.
        batch_index = np.random.choice(len(t_inputs), size=sample_size, replace=True)
        t_inputs = t_inputs[batch_index]

        # compute r_hat over the sampled anchor windows with this member
        anchor = self.r_hat_member(t_inputs, member=member)

        # pos = preferred segments, neg = dispreferred segments (label 0 means
        # segment 1 preferred). Both always have #(label==0) + #(label==1) rows.
        a1_g, a2_g = r1[labels == 0], r2[labels == 1]
        a1_b, a2_b = r1[labels == 1], r2[labels == 0]
        pos = torch.concat([a1_g, a2_g], dim=0)
        neg = torch.concat([a1_b, a2_b], dim=0)

        anchor_size, pn_size = anchor.shape[0], pos.shape[0]

        # Pair every anchor with every pos/neg row. For 3-D inputs, repeat with
        # 4 sizes prepends a dimension, giving [anchor_size, pn_size, ...]
        # for a (after the transpose), p and n.
        a = anchor.repeat(pn_size, 1, 1, 1).transpose(0, 1)  # add pn dimension and swap axis
        p = pos.repeat(anchor_size, 1, 1, 1)
        n = neg.repeat(anchor_size, 1, 1, 1)

        # predict label for the anchor and setup pos / neg examples using the label
        # [ Method 2 : Using the same samples for finding the label & computing loss]
        # torch.norm without dim is a single norm over ALL elements, so this is
        # one global swap decision, not a per-anchor one.
        norm_pos, norm_neg = torch.norm(a - p, 2), torch.norm(a - n, 2)
        if norm_pos > norm_neg:  # if distance to pos is greater then swap
            p, n = n, p  # swap p & n

        # TODO. [Method 1 : Use different samples for finding the label and computing the loss ]

        loss = self.reward_triplet_loss(a, p, n).sum(dim=-1).mean()  # sum over the last dim of the loss output, then mean
        # NOTE(review): the output shape depends on the Hydra-configured loss
        # (reduction etc.) — confirm that dim -1 is the time axis.
        return loss

    def compute_triplet_loss_optimistic_version(self, t1, t2, labels, r1, r2, member):
        """Triplet loss with random stored segments as anchors, without the pos/neg swap.

        Same construction as compute_triplet_loss. Anchors are member
        ``member``'s predictions on windows sampled from
        _get_max_len_trajectories; positives and negatives are the preferred
        and dispreferred predictions from ``r1`` / ``r2``. The difference is
        that every anchor is always treated as closest to the positives.

        Args:
            t1, t2: not used in this body.
            labels: preference labels compared with 0 / 1. Presumably a tensor
                of shape (B,) — TODO confirm.
            r1, r2: presumably member reward predictions for the two segments,
                (B, T, 1) tensors — TODO confirm.
            member: ensemble member used for the anchors.

        Returns:
            0 when fewer than 2 trajectories are stored, otherwise the scalar
            loss.
        """
        # anchor trajectories come from self.inputs (via _get_max_len_trajectories)
        if len(self.inputs) < 2:
            return 0
        num_anchor_samples = self.reward_triplet_loss_cfg.batch_size  # take these many samples from trajectory buffer.
        t_inputs, _ = self._get_max_len_trajectories()
        t_inputs = np.array(t_inputs)
        sample_size = min(num_anchor_samples, len(t_inputs))
        batch_index = np.random.choice(len(t_inputs), size=sample_size, replace=True)
        t_inputs = t_inputs[batch_index]

        # compute r_hat over the sampled anchor windows with this member
        anchor = self.r_hat_member(t_inputs, member=member)

        # pos = preferred segments, neg = dispreferred segments (label 0 means
        # segment 1 preferred)
        a1_g, a2_g = r1[labels == 0], r2[labels == 1]
        a1_b, a2_b = r1[labels == 1], r2[labels == 0]
        pos = torch.concat([a1_g, a2_g], dim=0)
        neg = torch.concat([a1_b, a2_b], dim=0)

        anchor_size, pn_size = anchor.shape[0], pos.shape[0]

        # pair every anchor with every pos/neg row: [anchor_size, pn_size, ...]
        a = anchor.repeat(pn_size, 1, 1, 1).transpose(0, 1)  # add pn dimension and swap axis
        p = pos.repeat(anchor_size, 1, 1, 1)
        n = neg.repeat(anchor_size, 1, 1, 1)
        loss = self.reward_triplet_loss(a, p, n).sum(dim=-1).mean()  # sum over the last dim of the loss output, then mean
        # NOTE(review): the output shape depends on the Hydra-configured loss — confirm dim -1 is time.
        return loss

    def compute_action_distance_loss(self, k, thresh, member):
        # get last k trajectories
        t_inputs_, _ = self._get_max_len_trajectories()
        t_inputs = np.array(t_inputs_)[-k:]

        # create dataset from trajectories s.t. dy > thresh
        max_len = len(t_inputs[0])  # length of a trajectory
        indices = itertools.combinations(range(max_len), 2)
        indices = np.array(
            [x for x in indices if abs(x[0] - x[1]) > thresh])  # need to gather all these indices from each trajectory
        left, right = indices[:, 0], indices[:, 1]

        s1, s2 = t_inputs[:, left], t_inputs[:, right]
        dy = torch.from_numpy(np.expand_dims(abs(indices[:, 0] - indices[:, 1]), axis=0)).repeat(len(t_inputs),
                                                                                                 1).reshape(-1).to(
            self.device) / float(max_len)
        # forward pass on r_hat_member() to get the embedding
        _, e1 = self.r_hat_member(s1, member=member, embedding=True)
        _, e2 = self.r_hat_member(s2, member=member, embedding=True)
        d = ((e2 - e1) ** 2).mean(dim=-1).reshape(-1, )

        # compute adloss for input : d = || e2 - e1 || and dy  and backprop
        loss = self.action_distance_loss(y=dy, y_pred=d)
        return loss


    def compute_l2_embed_loss(self, k, thresh, member):
        """Regress embedding distances between time steps onto input-space distances.

        This is the same pairing scheme as the action-distance loss: the last ``k``
        trajectories, all step pairs ``i < j`` with ``j - i > thresh``.  Only the
        regression target differs.  It is the mean squared difference of the raw
        inputs at the two steps, divided by the trajectory length.  The prediction is
        the mean squared difference of ``member``'s embeddings at the two steps.

        Args:
            k: number of most recent trajectories to use.
            thresh: a pair is used only if its time gap is strictly greater than this.
            member: index of the ensemble member whose embedding is trained.

        Returns:
            The loss from ``self.l2embed_loss(y=target, y_pred=prediction)``.  If
            ``self.l2embed_loss`` is not callable, ``self.action_distance_loss`` is
            used instead.  Returns a zero scalar tensor when there are no trajectories
            or no qualifying pairs.
        """
        trajectories, _ = self._get_max_len_trajectories()
        # Stacking assumes all trajectories share a length.
        t_inputs = np.asarray(trajectories[-k:])
        if len(t_inputs) == 0:
            return torch.zeros((), device=self.device)

        max_len = len(t_inputs[0])  # length of a trajectory
        # All (i, j) with j - i > thresh, in row-major order.
        min_gap = max(int(np.floor(thresh)) + 1, 1)
        left, right = np.triu_indices(max_len, k=min_gap)
        if left.size == 0:
            return torch.zeros((), device=self.device)

        s1, s2 = t_inputs[:, left], t_inputs[:, right]

        # Forward pass on r_hat_member() to get the embeddings at both steps.
        _, e1 = self.r_hat_member(s1, member=member, embedding=True)
        _, e2 = self.r_hat_member(s2, member=member, embedding=True)
        d = ((e2 - e1) ** 2).mean(dim=-1).reshape(-1)

        # Target: input-space distance, already one value per (trajectory, pair), so
        # it lines up with ``d`` without any tiling.  NumPy reductions take ``axis=``;
        # the former ``dim=`` always raised TypeError.
        state_dist = ((s2 - s1) ** 2).mean(axis=-1).reshape(-1)
        # NOTE(review): the division by max_len is inherited from the action-distance
        # loss, where it normalises index gaps.  Confirm it is wanted for state distances.
        dy = torch.as_tensor(state_dist, dtype=d.dtype, device=d.device) / float(max_len)

        # This loss is enabled via ``self.l2embed_loss``, so use that criterion.
        # Otherwise keep the previous fallback.
        loss_fn = self.l2embed_loss if callable(self.l2embed_loss) else self.action_distance_loss
        return loss_fn(y=dy, y_pred=d)

    def compute_rdynamics_loss(self, sa_1, sa_2, sa_1_pred, sa_2_pred, member):
        """Next-step prediction loss for both segments of a preference pair.

        The prediction made at step t is compared with the ground truth at step t + 1.
        So the first ground-truth step and the last predicted step (which have no
        counterpart) are dropped before calling ``self.rdynamics``.  Slicing is on
        axis 1; the previous comment described the inputs as batch x len x statesize.

        ``member`` is accepted for signature compatibility but is not used.

        Returns:
            The sum of ``self.rdynamics`` over the two segments.
        """
        segments = ((sa_1, sa_1_pred), (sa_2, sa_2_pred))
        total = None
        for ground_truth, predicted in segments:
            term = self.rdynamics(y=ground_truth[:, 1:], y_pred=predicted[:, :-1])
            if total is None:
                total = term
            else:
                total += term
        return total

    def train_reward(self):
        """Train every ensemble member for one pass over the preference buffer.

        Each member gets its own random permutation of the filled part of the buffer.
        For every mini-batch, the member's per-step reward predictions for both
        segments are summed over axis 1, concatenated, and scored with
        ``self.CEloss`` against the stored labels.  Every enabled auxiliary objective
        (SURF, reward triplet, action distance, L2 embedding, reward dynamics) is
        added for that same member.  The losses of all members are summed, and one
        ``self.opt`` step is taken per mini-batch.

        Returns:
            np.ndarray of length ``self.de``: for each member, the fraction of labels
            whose argmax over the summed rewards matched the label.
        """
        ensemble_acc = np.array([0 for _ in range(self.de)])

        max_len = self.capacity if self.buffer_full else self.buffer_index
        total_batch_index = [np.random.permutation(max_len) for _ in range(self.de)]

        num_epochs = int(np.ceil(max_len / self.train_batch_size))
        total = 0

        # Ask r_hat_member for the extra outputs only when some loss consumes them.
        reward_embedding = self.action_distance_loss is not None or self.l2embed_loss is not None
        rdynamics_next_state = self.rdynamics is not None

        for epoch in range(num_epochs):
            self.opt.zero_grad()
            loss = 0.0

            start = epoch * self.train_batch_size
            end = min(start + self.train_batch_size, max_len)

            for member in range(self.de):
                # get random batch
                idxs = total_batch_index[member][start:end]
                sa_t_1 = self.buffer_seg1[idxs]
                sa_t_2 = self.buffer_seg2[idxs]
                labels = torch.from_numpy(self.buffer_label[idxs].flatten()).long().to(self.device)

                if member == 0:
                    total += labels.size(0)

                # get logits (plus embedding / next-state prediction when requested)
                r_hat1_vec = self.r_hat_member(sa_t_1, member=member, embedding=reward_embedding,
                                               next_state_prediction=rdynamics_next_state)
                r_hat2_vec = self.r_hat_member(sa_t_2, member=member, embedding=reward_embedding,
                                               next_state_prediction=rdynamics_next_state)

                # With extras requested, the output is unpacked as
                # (reward[, embedding][, next_state]).
                if reward_embedding and rdynamics_next_state:
                    r_hat1_vec, r1_embed, r1_next_state = r_hat1_vec
                    r_hat2_vec, r2_embed, r2_next_state = r_hat2_vec
                elif reward_embedding:
                    r_hat1_vec, r1_embed = r_hat1_vec
                    r_hat2_vec, r2_embed = r_hat2_vec
                elif rdynamics_next_state:
                    r_hat1_vec, r1_next_state = r_hat1_vec
                    r_hat2_vec, r2_next_state = r_hat2_vec

                r_hat1 = r_hat1_vec.sum(axis=1)
                r_hat2 = r_hat2_vec.sum(axis=1)
                r_hat = torch.cat([r_hat1, r_hat2], axis=-1)

                # preference (cross-entropy) loss
                loss += self.CEloss(r_hat, labels)

                # compute acc
                _, predicted = torch.max(r_hat.data, 1)
                ensemble_acc[member] += (predicted == labels).sum().item()

                # Auxiliary objectives are applied to *this* member.  They used to run
                # once after the loop, reusing the leaked ``member`` and batch variables,
                # so only the last ensemble member was ever regularised.
                if self.surf_loss is not None:
                    lambda_surfloss = 0.1
                    loss += lambda_surfloss * self.compute_surf_loss(member)

                if self.reward_triplet_loss is not None:
                    reward_triplet_loss = self.compute_triplet_loss(sa_t_1, sa_t_2, labels, r_hat1_vec,
                                                                    r_hat2_vec, member)
                    loss += self.reward_triplet_loss_cfg.weight * reward_triplet_loss

                if self.action_distance_loss is not None:
                    action_distance_loss = self.compute_action_distance_loss(k=5, thresh=10, member=member)
                    loss += self.action_distance_loss_cfg.weight * action_distance_loss

                if self.l2embed_loss is not None:
                    l2embed_loss = self.compute_l2_embed_loss(k=5, thresh=10, member=member)
                    loss += self.l2embed_loss_cfg.weight * l2embed_loss

                if self.rdynamics is not None:
                    rdynamics_loss = self.compute_rdynamics_loss(sa_t_1, sa_t_2, r1_next_state, r2_next_state,
                                                                 member=member)
                    loss += self.rdynamics_loss_cfg.weight * rdynamics_loss

            loss.backward()
            self.opt.step()

        ensemble_acc = ensemble_acc / total

        return ensemble_acc

    def train_soft_reward(self):
        """Soft-label variant of ``train_reward``: one pass over the preference buffer.

        Label -1 marks pairs rated as equally preferred.  Their target is 0.5 on both
        columns.  Every other label gets ``self.label_target`` on the labelled column,
        and ``self.label_margin`` is added to every entry before the uniform rows are
        overwritten.  The member's summed segment rewards are scored against these
        targets with ``self.softXEnt_loss``.  When ``self.reward_triplet_loss`` is
        set, the triplet loss for the same member is added, weighted by
        ``self.reward_triplet_loss_cfg.weight``, as in ``train_reward``.  The losses of
        all members are summed, and one ``self.opt`` step is taken per mini-batch.

        Returns:
            np.ndarray of length ``self.de``: per-member accuracy of the argmax over
            the summed rewards.  For this accuracy, -1 labels are counted as label 0.
        """
        ensemble_acc = np.array([0 for _ in range(self.de)])

        max_len = self.capacity if self.buffer_full else self.buffer_index
        total_batch_index = [np.random.permutation(max_len) for _ in range(self.de)]

        num_epochs = int(np.ceil(max_len / self.train_batch_size))
        total = 0

        for epoch in range(num_epochs):
            self.opt.zero_grad()
            loss = 0.0

            start = epoch * self.train_batch_size
            end = min(start + self.train_batch_size, max_len)

            for member in range(self.de):
                # get random batch
                idxs = total_batch_index[member][start:end]
                sa_t_1 = self.buffer_seg1[idxs]
                sa_t_2 = self.buffer_seg2[idxs]
                labels = torch.from_numpy(self.buffer_label[idxs].flatten()).long().to(self.device)
                # Un-mutated copy: the triplet loss must not see -1 rewritten to 0,
                # which would mean "segment 1 preferred".
                raw_labels = labels.clone()

                if member == 0:
                    total += labels.size(0)

                # Per-step rewards are kept for the triplet loss, and summed for the logits.
                r_hat1_vec = self.r_hat_member(sa_t_1, member=member)
                r_hat2_vec = self.r_hat_member(sa_t_2, member=member)
                r_hat = torch.cat([r_hat1_vec.sum(axis=1), r_hat2_vec.sum(axis=1)], axis=-1)

                # compute soft targets
                uniform_index = labels == -1
                labels[uniform_index] = 0
                target_onehot = torch.zeros_like(r_hat).scatter(1, labels.unsqueeze(1), self.label_target)
                target_onehot += self.label_margin
                if uniform_index.any():
                    target_onehot[uniform_index] = 0.5
                loss += self.softXEnt_loss(r_hat, target_onehot)

                # compute acc
                _, predicted = torch.max(r_hat.data, 1)
                ensemble_acc[member] += (predicted == labels).sum().item()

                # The old call ``self.compute_triplet_loss()`` passed no arguments and
                # always raised TypeError.  The arguments and weighting now match
                # train_reward.
                if self.reward_triplet_loss is not None:
                    reward_triplet_loss = self.compute_triplet_loss(sa_t_1, sa_t_2, raw_labels, r_hat1_vec,
                                                                    r_hat2_vec, member)
                    loss += self.reward_triplet_loss_cfg.weight * reward_triplet_loss

            loss.backward()
            self.opt.step()

        ensemble_acc = ensemble_acc / total

        return ensemble_acc
