import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
from policies import NeuralNetwork, NNwLinear, QBackBone, Matrix
import copy
import time
import estimators
import pdb
import utils
import utils_gw
from utils import soft_target_update, get_phi_stats,\
    load_mini_batches, PhiMetricTracker, reset_optimizer, online_target_difference
from scipy.linalg import schur
from sklearn.decomposition import PCA
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.metrics import pairwise_distances
from behavior_dataset import PWDataset

# Small numerical tolerance. It is not referenced anywhere in the visible part
# of this file. NOTE(review): the name suggests it guards square roots against
# zero arguments -- confirm against the code that actually uses it.
SQRT_EPS = 5e-5

class RepAlgo:
    """Shared base of the representation-learning algorithms in this module.

    The constructor records the target-network update settings and creates
    the ``PhiMetricTracker`` that collects training metrics. ``best_phi`` is
    NOT assigned here: the accessors below rely on a subclass setting it
    (the subclasses in this file do so inside their ``train`` methods), so
    calling them earlier raises ``AttributeError``.
    """

    def __init__(self, pie, gamma, soft_update_tau = 0, hard_update_freq = 0):
        # ``pie`` and ``gamma`` are only forwarded to the tracker here
        self.pmt = PhiMetricTracker(pie, gamma)
        self.hard_update_freq = hard_update_freq
        self.soft_update_tau = soft_update_tau

    def get_metrics(self):
        """Return the ``metrics`` attribute of the current metric tracker."""
        return self.pmt.metrics

    def get_phi(self):
        """Return ``self.best_phi``."""
        return self.best_phi

    def get_critic(self):
        """Return ``self.best_phi`` directly (not routed through ``get_phi``)."""
        return self.best_phi

class KernelROPE(RepAlgo):
    def __init__(self,
        ground_state_dims,
        action_dims,
        abs_state_action_dims,
        hidden_dim = 32,
        hidden_layers = 1,
        activation = 'relu',
        final_activation = None,
        lr = 3e-4,
        reg_param = 0,
        gamma = None,
        image_state = False,
        mdp = None,
        pie = None,
        tabular = False,
        norm_type = None,
        soft_update_tau = 5e-3,
        hard_update_freq = 5):
        """Record the hyper-parameters and build the encoder for non-image inputs.

        Every argument except ``image_state`` and ``norm_type`` is stored on
        the instance under its own name (``soft_update_tau`` and
        ``hard_update_freq`` via the base-class constructor).

        When ``image_state`` is truthy the constructor returns right after
        storing the arguments: ``phi``, ``target_phi`` and ``optimizer`` are
        NOT created in that case.

        Otherwise ``phi`` is a ``NeuralNetwork`` mapping the ground
        state-action input to ``abs_state_action_dims`` outputs,
        ``target_phi`` is a deep copy of it, and ``optimizer`` is an AdamW
        optimiser over ``phi``'s parameters with betas ``(0.9, 0.9)``.
        """
        super().__init__(pie, gamma, soft_update_tau, hard_update_freq)

        # dimensions
        self.ground_state_dims = ground_state_dims
        self.action_dims = action_dims
        self.abs_state_action_dims = abs_state_action_dims

        # network shape
        self.hidden_dim = hidden_dim
        self.hidden_layers = hidden_layers
        self.activation = activation
        self.final_activation = final_activation

        # optimisation / problem settings
        self.lr = lr
        self.reg_param = reg_param
        self.gamma = gamma
        self.mdp = mdp
        self.pie = pie
        self.tabular = tabular

        # encoder input width: product of the two sizes in the tabular case,
        # their sum otherwise
        encoder_input_dims = (ground_state_dims * action_dims
                              if self.tabular
                              else ground_state_dims + action_dims)

        # no encoder is built for image states
        if image_state:
            return

        self.phi = NeuralNetwork(
            input_dims = encoder_input_dims,
            output_dims = abs_state_action_dims,
            hidden_dim = hidden_dim,
            hidden_layers = hidden_layers,
            activation = self.activation,
            final_activation = self.final_activation,
            norm_type = norm_type)

        self.target_phi = copy.deepcopy(self.phi)
        self.optimizer = torch.optim.AdamW(
            self.phi.parameters(), lr = self.lr, betas = (0.9, 0.9))

    def train(self, tr_data, test_data, epochs = 2000, mini_batch_size = 256, lspe = None, ope_evaluator = None, print_log = True):
        """Fit ``self.phi`` on ``tr_data`` by minimising the loss of ``_ksme_loss``.

        Epoch 0 only evaluates the loss (no gradient step); updates start at
        epoch 1 and the loop runs through epoch ``epochs`` inclusive.

        Args:
            tr_data: dataset handed to ``torch.utils.data.DataLoader``. Each
                batch must provide ``'curr_sa'``, ``'next_sa'``, ``'rewards'``
                and ``'terminal_masks'``; a ``PWDataset`` must additionally
                provide the ``'other_*'`` counterparts. It must expose
                ``max_abs_reward_diff`` (> 0).
            test_data: accepted for interface compatibility; not used.
            epochs: index of the last epoch.
            mini_batch_size: batch size of the data loaders.
            lspe: optional evaluator object exposing ``train``, ``get_theta``
                and ``get_metrics``.
            ope_evaluator: optional object exposing ``evaluate``.
            print_log: when falsy the per-epoch progress line is suppressed.

        The periodic value-error check (every 10 epochs and at the last one)
        runs only when BOTH ``lspe`` and ``ope_evaluator`` are given.
        Previously they were dereferenced unconditionally, so the default
        arguments crashed at epoch 0.
        """
        num_workers = 6
        self.pmt = PhiMetricTracker(self.pie, self.gamma)
        self.rew_range = tr_data.max_abs_reward_diff
        self.best_phi = copy.deepcopy(self.target_phi)
        assert self.rew_range > 0

        can_run_ope = lspe is not None and ope_evaluator is not None
        check_realizability = self.tabular and can_run_ope

        params = {'batch_size': mini_batch_size, 'shuffle': True, 'num_workers': num_workers}
        dataloader1 = torch.utils.data.DataLoader(tr_data, **params)
        dataloader2 = torch.utils.data.DataLoader(tr_data, **params)
        pw_dataset = isinstance(tr_data, PWDataset)

        for epoch in range(0, epochs + 1):
            losses = []
            for mb1, mb2 in zip(dataloader1, dataloader2):
                curr_sa = mb1['curr_sa']
                next_sa = mb1['next_sa']
                rewards = mb1['rewards']
                terminal_masks = mb1['terminal_masks']

                # a pairwise dataset ships its own partner sample; otherwise
                # the partner comes from the second, independently shuffled loader
                if pw_dataset:
                    other_curr_sa = mb1['other_curr_sa']
                    other_next_sa = mb1['other_next_sa']
                    other_rewards = mb1['other_rewards']
                    other_terminal_masks = mb1['other_terminal_masks']
                else:
                    other_curr_sa = mb2['curr_sa']
                    other_next_sa = mb2['next_sa']
                    other_rewards = mb2['rewards']
                    other_terminal_masks = mb2['terminal_masks']

                objs = self._ksme_loss(rewards, other_rewards,\
                            curr_sa, other_curr_sa,\
                            next_sa, other_next_sa, terminal_masks, other_terminal_masks)
                total_obj = objs['total_obj']
                # keep a plain float: storing the tensor would keep every
                # batch's autograd graph alive until the end of the epoch
                losses.append(total_obj.item())

                # start updates for epoch 1 onwards
                if epoch > 0:
                    self.optimizer.zero_grad()

                    try:
                        total_obj.backward()
                    except Exception as err:
                        # best effort: give up on the rest of this epoch, but
                        # let KeyboardInterrupt/SystemExit propagate
                        print ('error on backprop, breaking')
                        print (repr(err))
                        break

                    nn.utils.clip_grad_value_(self.phi.parameters(), clip_value = 1.0)
                    self.optimizer.step()
            avg_loss = torch.Tensor(losses).mean()

            # measured before the target network is (possibly) refreshed below
            diff = online_target_difference(self.phi, self.target_phi)

            # hard update target network; a frequency of 0 (the base-class
            # default) means "only at the final epoch" instead of dividing by zero
            periodic_update = self.hard_update_freq > 0 and epoch % self.hard_update_freq == 0
            if periodic_update or epoch == epochs:
                soft_target_update(self.phi, self.target_phi, tau = 1)

            # replacing with loss across whole epoch
            objs['total_obj'] = avg_loss

            # Note: stats other than total_obj are based on the last batch sampled above
            self.best_phi = self.phi
            # periodically check value error (needs both evaluators)
            if can_run_ope and (epoch % 10 == 0 or epoch == epochs):
                lspe.train(tr_data, epochs = 3000, num_action_samples = 20, phi = self.best_phi)
                ope_error = ope_evaluator.evaluate(self.best_phi, lspe.get_theta())
                lspe_stats = lspe.get_metrics()['phi_lspe_stats']
                objs['phi_ope_error'] = ope_error
                objs['phi_spectral_radius'] = lspe_stats['spectral_radius']
                objs['phi_pos_eigen_frac'] = lspe_stats['pos_eigen_frac']
            stats = self.pmt.track_phi_training_stats(epoch, self.best_phi,\
                objs, curr_sa, next_sa, pie = self.pie, ope_evaluator = ope_evaluator,\
                check_realizability = check_realizability, terminals = terminal_masks)
            if print_log:
                print (f'{epoch}, {diff}, stats: {stats}')

    def _ksme_loss(self, rews, other_rews,\
                state_actions, other_state_actions,\
                next_state_actions, other_next_state_actions, terminal_masks, other_terminal_masks):

        # pairwise reward distances
        # reward_dist = torch.abs((rews.unsqueeze(1) - rews))
        # phi_x = self.phi.forward(state_actions)
        # curr_Kxy = torch.mm(phi_x, phi_x.T)

        # next_phi_x = torch.Tensor(self.target_phi(next_state_actions))
        # next_Kxy = torch.mm(next_phi_x, next_phi_x.T)
        # target = self.rew_range - reward_dist + self.gamma * next_Kxy
        # #target = 1. - (reward_dist / self.rew_range) + self.gamma * next_Kxy

        # obj = torch.nn.functional.mse_loss(curr_Kxy, target, reduction = 'none')
        # obj = obj[~torch.eye(*obj.shape,dtype = torch.bool)]
        # obj = obj.mean()

        reward_dist = torch.Tensor(np.abs(rews - other_rews))
        phi_x = self.phi.forward(state_actions)
        phi_y = self.phi.forward(other_state_actions)
        curr_dotprod = torch.sum(phi_x * phi_y, axis = 1)
        curr_Kxy = curr_dotprod

        next_phi_x = torch.tensor(self.target_phi(next_state_actions))
        next_phi_y = torch.tensor(self.target_phi(other_next_state_actions))
        next_dotprod = torch.sum(next_phi_x * next_phi_y, axis = 1)
        next_Kxy = next_dotprod
        #target = self.rew_range - reward_dist + self.gamma * next_Kxy
        target = 1. - (reward_dist / self.rew_range) + terminal_masks * other_terminal_masks * self.gamma * next_Kxy
        #target = 1. - (reward_dist / self.rew_range) + self.gamma * next_Kxy
        #print (self.target_phi(np.unique(next_state_actions, axis = 0)))
        #pdb.set_trace()

        #obj = torch.nn.functional.huber_loss(curr_Kxy, target)

        obj = torch.nn.functional.mse_loss(curr_Kxy, target, reduction = 'none')
        obj = obj.mean()

        frac = torch.Tensor([-1])#torch.count_nonzero(torch.abs(curr_Kxy - target) < 1) / curr_Kxy.shape[0]


        cov = -1#torch.bmm(phi_x.unsqueeze(-1), phi_x.unsqueeze(-2)).mean(dim=0)
        design_loss = 0#-self.log_det(cov)

        # dt = torch.sum(phi_x * next_phi_x, axis = 1)
        # ortho_obj = torch.square(dt) \
        #     - torch.square(torch.linalg.norm(phi_x, axis = 1))\
        #     - torch.square(torch.linalg.norm(next_phi_x, axis = 1))\
        #     + self.abs_state_action_dims
        # ortho_obj = 0*1e-2 * torch.mean(ortho_obj)
        
        #class_input = torch.sigmoid(curr_Kxy)
        #class_target = 1. - (reward_dist / self.rew_range) + self.gamma * torch.sigmoid(next_Kxy)
        #obj = torch.nn.functional.binary_cross_entropy(class_input, class_target)
        total_obj = obj + 1e-2 * design_loss#ortho_obj
        objs = {
            'obj': obj,
            'total_obj': total_obj,
            'frac': frac
        }
        return objs

    def log_det(self, A):
            assert A.dim() in [2, 3]
            # regularize when computing log-det
            A = A + 1e-5 * torch.eye(A.shape[1], device=A.device)
            return 2 * torch.linalg.cholesky(A).diagonal(dim1=-2, dim2=-1).log().sum(-1)

class BCRL(RepAlgo):
    def __init__(self,
        ground_state_dims,
        action_dims,
        abs_state_action_dims,
        hidden_dim = 32,
        hidden_layers = 1,
        activation = 'relu',
        final_activation = None,
        phi_lr = 3e-4,
        M_lr = 3e-4,
        reg_param = 0,
        gamma = None,
        image_state = False,
        mdp = None,
        pie = None,
        tabular = False,
        norm_type = None,
        bcrl_type = 'both',
        logdet_coeff = 0,
        norm_selfpred = False,
        soft_update_tau = 5e-3,
        hard_update_freq = 5,
        ope_method = 'lspe'):
        """Record the hyper-parameters and build the encoder plus the two linear heads.

        When ``image_state`` is truthy the constructor returns after storing
        the arguments; no network, optimiser or loss switch is created.

        Otherwise it creates:
            * ``phi``: encoder from the ground state-action input to
              ``abs_state_action_dims`` outputs, and ``target_phi`` (deep copy);
            * ``M_phi`` / ``M_rew``: ``NeuralNetwork`` instances with zero
              hidden layers and no activation, mapping the embedding to an
              embedding-sized vector and to a scalar respectively;
            * ``phi_optimizer`` (AdamW, betas ``(0.9, 0.9)``) and
              ``M_optimizer`` (AdamW over both heads);
            * ``use_rew_pred`` / ``use_lat_pred``: 0./1. switches selected by
              ``bcrl_type`` (``'rew'`` -> reward only, ``'lat'`` -> latent
              only, anything else -> both).
        """
        super().__init__(pie, gamma, soft_update_tau, hard_update_freq)

        # dimensions
        self.ground_state_dims = ground_state_dims
        self.action_dims = action_dims
        self.abs_state_action_dims = abs_state_action_dims

        # network shape
        self.hidden_dim = hidden_dim
        self.hidden_layers = hidden_layers
        self.activation = activation
        self.final_activation = final_activation

        # optimisation / problem settings
        self.phi_lr = phi_lr
        self.M_lr = M_lr
        self.reg_param = reg_param
        self.gamma = gamma
        self.mdp = mdp
        self.pie = pie
        self.tabular = tabular
        self.logdet_coeff = logdet_coeff
        self.ope_method = ope_method

        # encoder input width: product in the tabular case, sum otherwise
        encoder_input_dims = (ground_state_dims * action_dims
                              if self.tabular
                              else ground_state_dims + action_dims)

        # nothing below is built for image states
        if image_state:
            return

        self.phi = NeuralNetwork(
            input_dims = encoder_input_dims,
            output_dims = abs_state_action_dims,
            hidden_dim = hidden_dim,
            hidden_layers = hidden_layers,
            activation = self.activation,
            final_activation = self.final_activation,
            norm_type = norm_type)

        # both heads share the same layer-free configuration
        head_config = dict(input_dims = abs_state_action_dims,
                           hidden_dim = -1,
                           hidden_layers = 0,
                           activation = None,
                           final_activation = None)
        self.M_phi = NeuralNetwork(output_dims = abs_state_action_dims, **head_config)
        self.M_rew = NeuralNetwork(output_dims = 1, **head_config)

        self.target_phi = copy.deepcopy(self.phi)
        self.phi_optimizer = torch.optim.AdamW(
            self.phi.parameters(), lr = self.phi_lr, betas = (0.9, 0.9))
        self.M_params = [*self.M_phi.parameters(), *self.M_rew.parameters()]
        self.M_optimizer = torch.optim.AdamW(self.M_params, lr = self.M_lr)

        self.norm_selfpred = norm_selfpred

        # each loss term is switched off only by the *other* explicit type
        self.use_rew_pred = 0. if bcrl_type == 'lat' else 1.
        self.use_lat_pred = 0. if bcrl_type == 'rew' else 1.

    def train(self, tr_data, test_data, epochs = 2000, mini_batch_size = 256, ope_algo = None, ope_evaluator = None, print_log = True):
        """Alternate between fitting the linear heads and the encoder.

        Every epoch makes one pass over ``tr_data`` updating ``M_phi`` /
        ``M_rew`` (``_M_loss``) and a second pass updating ``phi``
        (``_phi_loss``). Epoch 0 only evaluates the losses; gradient steps
        start at epoch 1 and the loop runs through epoch ``epochs`` inclusive.

        Args:
            tr_data: dataset handed to ``torch.utils.data.DataLoader``; each
                batch must provide ``'curr_sa'``, ``'next_sa'``,
                ``'rewards'`` and ``'terminal_masks'``.
            test_data: accepted for interface compatibility; not used.
            epochs: index of the last epoch.
            mini_batch_size: batch size of the data loader.
            ope_algo: optional evaluator object exposing ``train``,
                ``get_theta`` and ``get_metrics``.
            ope_evaluator: optional object exposing ``evaluate``.
            print_log: when falsy the per-epoch progress line is suppressed.

        The periodic value-error check (every 10 epochs and at the last one)
        runs only when BOTH ``ope_algo`` and ``ope_evaluator`` are given.
        Previously they were dereferenced unconditionally, so the default
        arguments crashed at epoch 0.
        """
        num_workers = 6
        self.pmt = PhiMetricTracker(self.pie, self.gamma)
        self.best_phi = copy.deepcopy(self.target_phi)

        can_run_ope = ope_algo is not None and ope_evaluator is not None
        check_realizability = self.tabular and can_run_ope

        params = {'batch_size': mini_batch_size, 'shuffle': True, 'num_workers': num_workers}
        dataloader = torch.utils.data.DataLoader(tr_data, **params)

        for epoch in range(0, epochs + 1):
            # update M
            for mb in dataloader:
                curr_sa = mb['curr_sa']
                next_sa = mb['next_sa']
                rewards = mb['rewards']
                terminal_masks = mb['terminal_masks']

                M_obj, reward_obj, BC_obj = self._M_loss(rewards, curr_sa, next_sa, terminal_masks)

                if epoch > 0:
                    self.M_optimizer.zero_grad()
                    M_obj.backward()
                    nn.utils.clip_grad_value_(self.M_params, clip_value = 1.0)
                    self.M_optimizer.step()

            # the reported total adds the head loss of the LAST head batch
            last_M_obj = M_obj.detach()

            # update phi
            losses = []
            for mb in dataloader:
                curr_sa = mb['curr_sa']
                next_sa = mb['next_sa']
                rewards = mb['rewards']
                terminal_masks = mb['terminal_masks']

                phi_obj = self._phi_loss(rewards, curr_sa, next_sa, terminal_masks)
                # keep a plain float: storing the tensor would keep every
                # batch's autograd graph alive until the end of the epoch
                losses.append((phi_obj.detach() + last_M_obj).item())

                if epoch > 0:
                    self.phi_optimizer.zero_grad()
                    phi_obj.backward()
                    nn.utils.clip_grad_value_(self.phi.parameters(), clip_value = 1.0)
                    self.phi_optimizer.step()
            avg_loss = torch.Tensor(losses).mean()

            # hard update target network; a frequency of 0 (the base-class
            # default) means "only at the final epoch" instead of dividing by zero
            periodic_update = self.hard_update_freq > 0 and epoch % self.hard_update_freq == 0
            if periodic_update or epoch == epochs:
                soft_target_update(self.phi, self.target_phi, tau = 1)

            objs = {'total_obj': avg_loss}

            self.best_phi = self.phi
            # periodically check value error (needs both evaluators)
            if can_run_ope and (epoch % 10 == 0 or epoch == epochs):
                ope_algo.train(tr_data, epochs = 3000, num_action_samples = 20, phi = self.best_phi)
                ope_error = ope_evaluator.evaluate(self.best_phi, ope_algo.get_theta())
                ope_ret = ope_evaluator.evaluate(self.best_phi, ope_algo.get_theta(), metric_type = 'return')
                lspe_stats = ope_algo.get_metrics()['phi_lspe_stats']
                objs['phi_ope_error'] = ope_error
                objs['phi_ope_ret'] = ope_ret
                objs['phi_spectral_radius'] = lspe_stats['spectral_radius']
                objs['phi_pos_eigen_frac'] = lspe_stats['pos_eigen_frac']
            # curr_sa / next_sa / terminal_masks are those of the last encoder batch
            stats = self.pmt.track_phi_training_stats(epoch, self.best_phi,\
                objs, curr_sa, next_sa, pie = self.pie, ope_evaluator = ope_evaluator,\
                check_realizability = check_realizability, terminals = terminal_masks)
            if print_log:
                print (f'{epoch}, stats: {stats}, bc_obj: {BC_obj.item()}, reward_obj: {reward_obj.item()}')

    def _phi_loss(self, rewards, curr_sa, next_sa, terminal_masks):
        phi_x = self.phi.forward(curr_sa)
        terminal_masks = terminal_masks.reshape(-1, 1)
        #next_phi_x = terminal_masks * self.phi.forward(next_sa)
        with torch.no_grad():
            next_phi_x = terminal_masks * torch.Tensor(self.target_phi(next_sa))

        pred_next_phi_x = self.M_phi.forward(phi_x)
        pred_rew = self.M_rew.forward(phi_x)

        design_loss = 0
        if self.logdet_coeff > 0:
            cov = torch.bmm(phi_x.unsqueeze(-1), phi_x.unsqueeze(-2)).mean(dim=0)
            design_loss = -self.log_det(cov)
        # ncov = self.gamma * torch.bmm(phi_x.unsqueeze(-1), next_phi_x.unsqueeze(-2)).mean(dim=0)
        # ncov = ncov + 1e-5 * torch.eye(*ncov.shape)
        # inv_ncov = torch.linalg.inv(ncov)
        # #res = torch.matmul(inv_ncov, cov)
        # res = cov - ncov
        # design_loss = -torch.log(torch.linalg.det(res))
        
        rewards = torch.Tensor(rewards.reshape((-1, 1)))
        reward_loss = F.mse_loss(pred_rew, rewards)

        BC_loss = self._self_pred_error(pred_next_phi_x, next_phi_x)

        # Combined Loss
        phi_loss = (
            self.use_rew_pred * 1.0 * reward_loss
            + self.use_lat_pred * BC_loss
            + self.logdet_coeff * design_loss
        )
        return phi_loss

    def _M_loss(self, rewards, curr_sa, next_sa, terminal_masks):
        terminal_masks = terminal_masks.reshape(-1, 1)
        with torch.no_grad():
            phi_x = torch.Tensor(self.phi(curr_sa))
            next_phi_x = terminal_masks * torch.Tensor(self.target_phi(next_sa))

        pred_next_phi_x = self.M_phi.forward(phi_x)
        pred_rew = self.M_rew.forward(phi_x)
        rewards = torch.Tensor(rewards.reshape((-1, 1)))
        reward_loss = F.mse_loss(pred_rew, rewards)

        BC_loss = self._self_pred_error(pred_next_phi_x, next_phi_x)

        M_loss = self.use_rew_pred * 1. * reward_loss\
            + self.use_lat_pred * BC_loss

        return M_loss, reward_loss, BC_loss

    def _self_pred_error(self, pred_next_phi_x, next_phi_x):

        BC_loss = torch.square(torch.linalg.vector_norm(pred_next_phi_x - self.gamma * next_phi_x, dim=1))
        
        if self.norm_selfpred:
            # pred_next_phi_x_norm = torch.linalg.vector_norm(pred_next_phi_x, dim = 1)
            # next_phi_x_norm = torch.linalg.vector_norm(next_phi_x, dim = 1)
            # BC_loss = BC_loss / (pred_next_phi_x_norm * next_phi_x_norm)
            next_phi_x_norm = torch.linalg.norm(next_phi_x, axis = 1)
            clipped_value = torch.full_like(next_phi_x_norm, fill_value=1e-8, dtype=torch.float)
            next_phi_x_norm = torch.maximum(next_phi_x_norm, clipped_value)
            cosine_sim = torch.sum(pred_next_phi_x * self.gamma * next_phi_x, axis = 1)\
               / (next_phi_x_norm * torch.linalg.norm(pred_next_phi_x, axis = 1))
            BC_loss = -cosine_sim.mean()
        
        BC_loss = BC_loss.mean()
        return BC_loss

    def log_det(self, A):
            assert A.dim() in [2, 3]
            # regularize when computing log-det
            A = A + 1e-5 * torch.eye(A.shape[1], device=A.device)
            return 2 * torch.linalg.cholesky(A).diagonal(dim1=-2, dim2=-1).log().sum(-1)

class FQE(RepAlgo):
    def __init__(self,
        ground_state_dims,
        action_dims,
        abs_state_action_dims,
        hidden_dim = 32,
        hidden_layers = 1,
        activation = 'relu',
        final_activation = None,
        lr = 3e-4,
        reg_param = 0,
        gamma = None,
        image_state = False,
        mdp = None,
        pie = None,
        tabular = False,
        norm_type = None,
        soft_update_tau = 5e-3,
        hard_update_freq = 5,
        use_penultimate = False,
        ope_method = 'lspe'):
        """Record the hyper-parameters and build the Q-network for non-image inputs.

        When ``image_state`` is truthy the constructor returns after storing
        the arguments; ``phi``, ``target_phi``, ``optimizer`` and ``Qs`` are
        NOT created.

        Otherwise a ``QBackBone`` is built and stored as ``phi``, wrapped in
        ``NNwLinear`` when ``use_penultimate`` is truthy. ``target_phi`` is a
        deep copy, ``optimizer`` is AdamW with betas ``(0.9, 0.9)`` and
        ``Qs`` starts as an empty list.

        NOTE(review): ``train`` and ``_td_loss`` read ``.backbone`` and
        ``.linear`` from the network -- confirm that a bare ``QBackBone``
        (``use_penultimate = False``) provides them.
        """
        super().__init__(pie, gamma, soft_update_tau, hard_update_freq)

        # dimensions
        self.ground_state_dims = ground_state_dims
        self.action_dims = action_dims
        self.abs_state_action_dims = abs_state_action_dims

        # network shape
        self.hidden_dim = hidden_dim
        self.hidden_layers = hidden_layers
        self.activation = activation
        self.final_activation = final_activation

        # optimisation / problem settings
        self.lr = lr
        self.reg_param = reg_param
        self.gamma = gamma
        self.mdp = mdp
        self.pie = pie
        self.tabular = tabular
        self.use_penultimate = use_penultimate
        self.ope_method = ope_method

        # network input width: product in the tabular case, sum otherwise
        network_input_dims = (ground_state_dims * action_dims
                              if self.tabular
                              else ground_state_dims + action_dims)

        # nothing below is built for image states
        if image_state:
            return

        backbone = QBackBone(
            input_dims = network_input_dims,
            output_dims = abs_state_action_dims,
            hidden_dim = hidden_dim,
            hidden_layers = hidden_layers,
            activation = self.activation,
            final_activation = self.final_activation,
            norm_type = norm_type)
        self.phi = NNwLinear(backbone) if self.use_penultimate else backbone

        self.target_phi = copy.deepcopy(self.phi)
        self.optimizer = torch.optim.AdamW(
            self.phi.parameters(), lr = self.lr, betas = (0.9, 0.9))
        self.Qs = []

    def train(self, tr_data, test_data, epochs = 2000, mini_batch_size = 256, ope_algo = None, ope_evaluator = None, print_log = True):
        """Fit the Q-network with the TD loss of ``_td_loss``.

        Epoch 0 only evaluates the loss (no gradient step); updates start at
        epoch 1 and the loop runs through epoch ``epochs`` inclusive.

        Args:
            tr_data: dataset handed to ``torch.utils.data.DataLoader``; each
                batch must provide ``'curr_sa'``, ``'next_sa'``,
                ``'rewards'`` and ``'terminal_masks'``.
            test_data: accepted for interface compatibility; not used.
            epochs: index of the last epoch.
            mini_batch_size: batch size of the data loader.
            ope_algo: optional evaluator object (``train``, ``get_theta``,
                ``get_metrics``); only needed for ``ope_method == 'lspe'``.
            ope_evaluator: optional object exposing ``evaluate`` and
                ``value_eval_path``.
            print_log: when falsy the per-epoch progress line is suppressed.

        Raises:
            ValueError: if ``self.ope_method`` is neither ``'lspe'`` nor
                ``'fqe-e2e'``. Previously such a value surfaced as an
                ``UnboundLocalError`` at the first periodic evaluation.

        Evaluation steps whose evaluator objects were not supplied are
        skipped instead of dereferencing ``None``.
        """
        if self.ope_method not in ('lspe', 'fqe-e2e'):
            raise ValueError(
                f"unsupported ope_method {self.ope_method!r}; expected 'lspe' or 'fqe-e2e'")

        num_workers = 6
        self.pmt = PhiMetricTracker(self.pie, self.gamma)
        self.best_phi = copy.deepcopy(self.target_phi)

        params = {'batch_size': mini_batch_size, 'shuffle': True, 'num_workers': num_workers}
        dataloader = torch.utils.data.DataLoader(tr_data, **params)

        has_evaluator = ope_evaluator is not None
        has_ope_algo = ope_algo is not None
        check_realizability = self.tabular and has_ope_algo and has_evaluator

        for epoch in range(0, epochs + 1):
            losses = []
            brm_losses = []
            for mb in dataloader:
                curr_sa = mb['curr_sa']
                next_sa = mb['next_sa']
                rewards = mb['rewards']
                terminal_masks = mb['terminal_masks']

                objs = self._td_loss(tr_data, rewards, curr_sa, next_sa, terminal_masks)
                total_obj = objs['total_obj']
                # keep plain floats: storing the tensors would keep every
                # batch's autograd graph alive until the end of the epoch
                losses.append(total_obj.item())
                brm_losses.append(objs['brm_obj'].item())

                # start updates for epoch 1 onwards
                if epoch > 0:
                    self.optimizer.zero_grad()

                    try:
                        total_obj.backward()
                    except Exception as err:
                        # best effort: give up on the rest of this epoch, but
                        # let KeyboardInterrupt/SystemExit propagate
                        print ('error on backprop, breaking')
                        print (repr(err))
                        break

                    nn.utils.clip_grad_value_(self.phi.parameters(), clip_value = 1.0)
                    self.optimizer.step()
            avg_loss = torch.Tensor(losses).mean()
            avg_brm = torch.Tensor(brm_losses).mean()
            # measured before the target network is (possibly) refreshed below
            diff = online_target_difference(self.phi, self.target_phi)

            # hard update target network; a frequency of 0 (the base-class
            # default) means "only at the final epoch" instead of dividing by zero
            periodic_update = self.hard_update_freq > 0 and epoch % self.hard_update_freq == 0
            if periodic_update or epoch == epochs:
                soft_target_update(self.phi, self.target_phi, tau = 1)

            # replacing with loss across whole epoch
            objs['total_obj'] = avg_loss
            objs['brm_obj'] = avg_brm

            # Note: remaining stats are based on the last batch sampled above
            self.best_phi = self.phi
            # periodically check value error
            if epoch % 10 == 0 or epoch == epochs:
                if self.ope_method == 'lspe' and has_ope_algo and has_evaluator:
                    backbone = self.best_phi.backbone
                    ope_algo.train(tr_data, epochs = 3000, num_action_samples = 20, phi = backbone)
                    ope_error = ope_evaluator.evaluate(backbone, ope_algo.get_theta())
                    ope_ret = ope_evaluator.evaluate(backbone, ope_algo.get_theta(), metric_type = 'return')
                    lspe_stats = ope_algo.get_metrics()['phi_lspe_stats']
                    objs['phi_spectral_radius'] = lspe_stats['spectral_radius']
                    objs['phi_pos_eigen_frac'] = lspe_stats['pos_eigen_frac']
                    objs['phi_ope_error'] = ope_error
                    objs['phi_ope_ret'] = ope_ret
                elif self.ope_method == 'fqe-e2e' and has_evaluator:
                    # first parameter tensor of the linear head, flattened, is
                    # passed as the weight vector to the evaluator
                    w_weights = [param.detach() for param in self.best_phi.linear.parameters()][0].numpy().reshape(-1)
                    ope_error = ope_evaluator.evaluate(self.best_phi.backbone, w_weights)
                    ope_ret = ope_evaluator.evaluate(self.best_phi.backbone, w_weights, metric_type = 'return')
                    objs['phi_ope_error'] = ope_error
                    objs['phi_ope_ret'] = ope_ret

            # every 21 snapshots, score the whole snapshot path and start over;
            # snapshots are only useful when an evaluator can consume them
            if has_evaluator:
                self.Qs.append(copy.deepcopy(self.best_phi))
                if len(self.Qs) == 21:
                    objs['phi_val_eval_path_err'] = ope_evaluator.value_eval_path(self.Qs)
                    self.Qs.clear()

            stats = self.pmt.track_phi_training_stats(epoch, self.best_phi.backbone,\
                objs, curr_sa, next_sa, pie = self.pie,\
                ope_evaluator = ope_evaluator, check_realizability = check_realizability, terminals = terminal_masks)
            if print_log:
                print (f'{epoch}, {diff}, stats: {stats}')

    # def _get_test_error(self, test_data, test_dataloader):
    #     count = 0
    #     with torch.no_grad():
    #         total_error = 0
    #         for idx, mb in enumerate(test_dataloader):
    #             curr_sa = mb['curr_sa']
    #             next_sa = mb['next_sa']
    #             rewards = mb['rewards']
    #             terminal_masks = mb['terminal_masks']

    #             objs = self._td_loss(test_data, rewards, curr_sa, next_sa, terminal_masks)
    #             total_obj = objs['total_obj']
    #             total_error += total_obj
    #             count += 1
    #         avg_error = (total_error / count).item()
    #         return avg_error

    def _td_loss(self, data, rews, curr_sa, next_sa, term_masks):

        q_curr_outputs = self.phi.forward(curr_sa).reshape(-1)
        with torch.no_grad():
            q_next_rep = torch.Tensor(self.target_phi.backbone(next_sa))
            q_next_outputs = self.target_phi.linear(q_next_rep).reshape(-1).numpy()
        
        target = torch.Tensor(rews + self.gamma * term_masks * q_next_outputs)


        # reps = self.phi.backbone.forward(curr_sa)

        # pw = torch.cdist(reps, reps)
        # off_diagonal = pw[~torch.eye(*pw.shape, dtype = torch.bool)]

        # sim_matrix = cosine_similarity(q_next_rep)

        # rounded_vals = np.round(q_next_outputs / 0.1).astype(int).reshape(-1, 1)
        # pw_qdiff = pairwise_distances(rounded_vals, metric = 'l1')
        # pw_qdiff[np.eye(*pw_qdiff.shape).astype(bool)] = np.inf
        # nearest_neigh = np.argmin(pw_qdiff, axis = 1)
        # pos_sim = torch.exp(torch.Tensor(sim_matrix[np.arange(sim_matrix.shape[0]), nearest_neigh]))

        # neg_logits = torch.sum(torch.exp(torch.Tensor(sim_matrix)))

        # pwc_obj = -torch.log(pos_sim / neg_logits).mean()
        #neg_logits = torch.logsumexp(torch.Tensor(sim_matrix), dim = 1)
        #pwc_obj = torch.mean(neg_logits - pos_sim)
        pwc_obj = 0#-off_diagonal.mean()
        # if self.clip_target:
        #     target = torch.clip(target,\
        #         min = data.min_reward / (1. - self.gamma),\
        #         max = data.max_reward / (1. - self.gamma))

        obj = torch.nn.functional.mse_loss(q_curr_outputs, target, reduction = 'none')
        obj = obj.mean()
        total_obj = obj + 1 * pwc_obj

        brm_q_next_outputs = self.phi(next_sa).reshape(-1)
        brm_target = torch.Tensor(rews + self.gamma * term_masks * brm_q_next_outputs)
        brm_obj = torch.nn.functional.mse_loss(q_curr_outputs, brm_target, reduction = 'none')
        brm_obj = brm_obj.mean()

        objs = {
            'obj': obj,
            'total_obj': total_obj,
            'brm_obj': brm_obj
        }
        return objs

    def get_grouping(self, array):
        groups = {}
        for num in array:
            if num in groups:
                groups[num].append(num)
            else:
                groups[num] = [num]

        group_ids = list(groups.keys())
        sa_to_group = {}
        for sa_idx, val in enumerate(array):
            group_idx = group_ids.index(val)
            sa_to_group[sa_idx] = group_idx
        return groups, sa_to_group    

    def get_phi(self):
        return self.best_phi.backbone

class TCL(RepAlgo):
    def __init__(self,
        ground_state_dims,
        action_dims,
        abs_state_action_dims,
        hidden_dim = 32,
        hidden_layers = 1,
        activation = 'relu',
        final_activation = None,
        lr = 3e-4,
        M_lr = 3e-4,
        reg_param = 0,
        gamma = None,
        image_state = False,
        mdp = None,
        pie = None,
        tabular = False,
        norm_type = None,
        soft_update_tau = 5e-3,
        hard_update_freq = 5):
        """Store the hyper-parameters and build the networks and optimizers.

        ``phi`` maps the ground state-action input to ``abs_state_action_dims``
        outputs; ``M_phi`` is a second network (configured with
        ``hidden_layers = 0``) applied to those outputs. When ``image_state``
        is true the constructor returns right after storing the
        hyper-parameters, so no network, target copy or optimizer is created.
        """
        super().__init__(pie, gamma, soft_update_tau, hard_update_freq)

        # Input / output sizes.
        self.ground_state_dims = ground_state_dims
        self.action_dims = action_dims
        self.abs_state_action_dims = abs_state_action_dims

        # Architecture of ``phi``.
        self.hidden_dim = hidden_dim
        self.hidden_layers = hidden_layers
        self.activation = activation
        self.final_activation = final_activation

        # Optimisation settings.
        self.lr = lr
        self.M_lr = M_lr
        self.reg_param = reg_param
        self.gamma = gamma

        # Problem description.
        self.mdp = mdp
        self.pie = pie
        self.tabular = tabular

        # Input width: product of the two sizes when tabular, their sum otherwise.
        if self.tabular:
            ground_state_action_dims = ground_state_dims * action_dims
        else:
            ground_state_action_dims = ground_state_dims + action_dims

        # No networks are built for image states.
        if image_state:
            return

        self.phi = NeuralNetwork(input_dims = ground_state_action_dims,
                                output_dims = abs_state_action_dims,
                                hidden_dim = hidden_dim,
                                hidden_layers = hidden_layers,
                                activation = self.activation,
                                final_activation = self.final_activation,
                                norm_type = norm_type)

        self.M_phi = NeuralNetwork(input_dims = abs_state_action_dims,
                                output_dims = abs_state_action_dims,
                                hidden_dim = -1,
                                hidden_layers = 0,
                                activation = None,
                                final_activation = None)

        # ``M_phi`` is optimised separately, with its own learning rate.
        self.M_params = list(self.M_phi.parameters())
        self.M_optimizer = torch.optim.AdamW(self.M_params, lr = self.M_lr)

        self.target_phi = copy.deepcopy(self.phi)
        self.optimizer = torch.optim.AdamW(self.phi.parameters(), lr = self.lr, betas = (0.9, 0.9))

    def train(self, tr_data, test_data, epochs = 2000, mini_batch_size = 256, lspe = None, ope_evaluator = None, print_log = True):
        """Fit ``phi`` and ``M_phi`` by minimising the loss from ``_tcl_loss``.

        Epoch 0 only evaluates the loss; parameter updates start at epoch 1.
        After every epoch the online/target difference is measured, the
        target network is hard-synchronised on the ``hard_update_freq``
        schedule (and on the last epoch), and statistics are recorded through
        ``self.pmt``.

        Args:
            tr_data: Training dataset. Must expose ``max_abs_reward_diff`` and
                yield dict batches with ``curr_sa``, ``next_sa`` and
                ``terminal_masks`` (plus ``other_curr_sa`` when it is a
                ``PWDataset``).
            test_data: Unused.
            epochs: Index of the last epoch; epochs ``0..epochs`` are run.
            mini_batch_size: Batch size of the data loaders.
            lspe: Optional estimator trained on the current features every 10
                epochs. The periodic evaluation is skipped unless both
                ``lspe`` and ``ope_evaluator`` are given.
            ope_evaluator: Optional evaluator used for the periodic error.
            print_log: When true, print one summary line per epoch.
        """
        num_workers = 6
        self.pmt = PhiMetricTracker(self.pie, self.gamma)
        self.rew_range = tr_data.max_abs_reward_diff
        self.best_phi = copy.deepcopy(self.target_phi)
        assert self.rew_range > 0

        # The periodic evaluation needs both helpers; their default is None.
        can_evaluate = lspe is not None and ope_evaluator is not None
        check_realizability = self.tabular and can_evaluate

        params = {'batch_size': mini_batch_size, 'shuffle': True, 'num_workers': num_workers}
        dataloader1 = torch.utils.data.DataLoader(tr_data, **params)
        dataloader2 = torch.utils.data.DataLoader(tr_data, **params)
        pw_dataset = isinstance(tr_data, PWDataset)

        for epoch in range(0, epochs + 1):
            losses = []
            for mb1, mb2 in zip(dataloader1, dataloader2):
                curr_sa = mb1['curr_sa']
                next_sa = mb1['next_sa']
                terminal_masks = mb1['terminal_masks']

                # The comparison sample comes from the same batch for a
                # pairwise dataset, otherwise from the second shuffled loader.
                if pw_dataset:
                    other_curr_sa = mb1['other_curr_sa']
                else:
                    other_curr_sa = mb2['curr_sa']

                objs = self._tcl_loss(curr_sa, other_curr_sa,\
                            next_sa, terminal_masks)
                total_obj = objs['total_obj']
                # Keep a plain float so the autograd graph of each batch is
                # released instead of being held until the end of the epoch.
                losses.append(total_obj.item())

                # start updates for epoch 1 onwards
                if epoch > 0:
                    # clear gradients
                    self.optimizer.zero_grad()
                    self.M_optimizer.zero_grad()

                    # compute gradients; only autograd failures abort the epoch
                    try:
                        total_obj.backward()
                    except RuntimeError as err:
                        print (f'error on backprop, breaking: {err}')
                        break

                    # processing gradients
                    nn.utils.clip_grad_value_(self.phi.parameters(), clip_value = 1.0)
                    nn.utils.clip_grad_value_(self.M_params, clip_value = 1.0)

                    # gradient step
                    self.optimizer.step()
                    self.M_optimizer.step()
            avg_loss = torch.Tensor(losses).mean()

            diff = online_target_difference(self.phi, self.target_phi)
            # hard update target network (a frequency of 0 disables the
            # periodic sync instead of dividing by zero)
            periodic_sync = bool(self.hard_update_freq) and epoch % self.hard_update_freq == 0
            if periodic_sync or epoch == epochs:
                soft_target_update(self.phi, self.target_phi, tau = 1)

            # replacing with loss across whole epoch
            objs['total_obj'] = avg_loss

            # Note: stats, losses etc are based on the last batch sampled above
            # ``best_phi`` aliases the online network (no copy is made).
            self.best_phi = self.phi
            # periodically check value error
            if can_evaluate and (epoch % 10 == 0 or epoch == epochs):
                lspe.train(tr_data, epochs = 3000, num_action_samples = 20, phi = self.best_phi)
                ope_error = ope_evaluator.evaluate(self.best_phi, lspe.get_theta())
                lspe_stats = lspe.get_metrics()['phi_lspe_stats']
                objs['phi_ope_error'] = ope_error
                objs['phi_spectral_radius'] = lspe_stats['spectral_radius']
                objs['phi_pos_eigen_frac'] = lspe_stats['pos_eigen_frac']
            stats = self.pmt.track_phi_training_stats(epoch, self.best_phi,\
                objs, curr_sa, next_sa, pie = self.pie, ope_evaluator = ope_evaluator,\
                check_realizability = check_realizability, terminals = terminal_masks)
            if print_log:
                print (f'{epoch}, {diff}, stats: {stats}')

    def _tcl_loss(self, state_actions, other_state_actions,\
                next_state_actions, terminal_masks):

        phi_csa = self.phi.forward(state_actions)
        phi_nsa = self.phi.forward(next_state_actions)
        phi_other_csa = self.phi.forward(other_state_actions)

        M_curr_sa = self.M_phi.forward(phi_csa)

        successive_dot = -torch.sum(phi_nsa * M_curr_sa, axis = 1)
        curr_other_dot = torch.exp(torch.sum(phi_other_csa * M_curr_sa, axis = 1))

        obj = successive_dot.mean() + torch.log(curr_other_dot.mean())

        total_obj = obj
        objs = {
            'obj': obj,
            'total_obj': total_obj,
        }
        return objs

class Recon(RepAlgo):
    def __init__(self,
        ground_state_dims,
        action_dims,
        abs_state_action_dims,
        hidden_dim = 32,
        hidden_layers = 1,
        activation = 'relu',
        final_activation = None,
        lr = 3e-4,
        M_lr = 3e-4,
        reg_param = 0,
        gamma = None,
        image_state = False,
        mdp = None,
        pie = None,
        tabular = False,
        norm_type = None,
        soft_update_tau = 5e-3,
        hard_update_freq = 5):
        """Store the hyper-parameters and build the networks and optimizers.

        ``phi`` maps the ground state-action input to ``abs_state_action_dims``
        outputs; ``M_phi`` (configured with ``hidden_layers = 0``) maps those
        outputs back to the ground state-action width. When ``image_state`` is
        true the constructor returns right after storing the
        hyper-parameters, so no network, target copy or optimizer is created.
        """
        super().__init__(pie, gamma, soft_update_tau, hard_update_freq)

        # Input / output sizes.
        self.ground_state_dims = ground_state_dims
        self.action_dims = action_dims
        self.abs_state_action_dims = abs_state_action_dims

        # Architecture of ``phi``.
        self.hidden_dim = hidden_dim
        self.hidden_layers = hidden_layers
        self.activation = activation
        self.final_activation = final_activation

        # Optimisation settings.
        self.lr = lr
        self.M_lr = M_lr
        self.reg_param = reg_param
        self.gamma = gamma

        # Problem description.
        self.mdp = mdp
        self.pie = pie
        self.tabular = tabular

        # Input width: product of the two sizes when tabular, their sum otherwise.
        if self.tabular:
            ground_state_action_dims = ground_state_dims * action_dims
        else:
            ground_state_action_dims = ground_state_dims + action_dims

        # No networks are built for image states.
        if image_state:
            return

        self.phi = NeuralNetwork(input_dims = ground_state_action_dims,
                                output_dims = abs_state_action_dims,
                                hidden_dim = hidden_dim,
                                hidden_layers = hidden_layers,
                                activation = self.activation,
                                final_activation = self.final_activation,
                                norm_type = norm_type)
        self.M_phi = NeuralNetwork(input_dims = abs_state_action_dims,
                                output_dims = ground_state_action_dims,
                                hidden_dim = -1,
                                hidden_layers = 0,
                                activation = None,
                                final_activation = None)

        # ``M_phi`` is optimised separately, with its own learning rate.
        self.M_params = list(self.M_phi.parameters())
        self.M_optimizer = torch.optim.AdamW(self.M_params, lr = self.M_lr)

        self.target_phi = copy.deepcopy(self.phi)
        self.optimizer = torch.optim.AdamW(self.phi.parameters(), lr = self.lr, betas = (0.9, 0.9))

    def train(self, tr_data, test_data, epochs = 2000, mini_batch_size = 256, lspe = None, ope_evaluator = None, print_log = True):
        """Fit ``phi`` and ``M_phi`` by minimising the loss from ``_recon_loss``.

        Epoch 0 only evaluates the loss; parameter updates start at epoch 1.
        After every epoch the online/target difference is measured, the
        target network is hard-synchronised on the ``hard_update_freq``
        schedule (and on the last epoch), and statistics are recorded through
        ``self.pmt``.

        Args:
            tr_data: Training dataset yielding dict batches with ``curr_sa``,
                ``next_sa`` and ``terminal_masks``.
            test_data: Unused.
            epochs: Index of the last epoch; epochs ``0..epochs`` are run.
            mini_batch_size: Batch size of the data loader.
            lspe: Optional estimator trained on the current features every 10
                epochs. The periodic evaluation is skipped unless both
                ``lspe`` and ``ope_evaluator`` are given.
            ope_evaluator: Optional evaluator used for the periodic error.
            print_log: When true, print one summary line per epoch.
        """
        num_workers = 6
        self.pmt = PhiMetricTracker(self.pie, self.gamma)
        self.best_phi = copy.deepcopy(self.target_phi)

        params = {'batch_size': mini_batch_size, 'shuffle': True, 'num_workers': num_workers}
        dataloader = torch.utils.data.DataLoader(tr_data, **params)

        # The periodic evaluation needs both helpers; their default is None.
        can_evaluate = lspe is not None and ope_evaluator is not None
        check_realizability = self.tabular and can_evaluate

        for epoch in range(0, epochs + 1):
            losses = []
            for mb in dataloader:
                curr_sa = mb['curr_sa']
                next_sa = mb['next_sa']
                terminal_masks = mb['terminal_masks']

                objs = self._recon_loss(curr_sa, next_sa, terminal_masks)
                total_obj = objs['total_obj']
                # Keep a plain float so the autograd graph of each batch is
                # released instead of being held until the end of the epoch.
                losses.append(total_obj.item())

                # start updates for epoch 1 onwards
                if epoch > 0:
                    # clear gradients
                    self.optimizer.zero_grad()
                    self.M_optimizer.zero_grad()

                    # compute gradients; only autograd failures abort the epoch
                    try:
                        total_obj.backward()
                    except RuntimeError as err:
                        print (f'error on backprop, breaking: {err}')
                        break

                    # processing gradients
                    nn.utils.clip_grad_value_(self.phi.parameters(), clip_value = 1.0)
                    nn.utils.clip_grad_value_(self.M_params, clip_value = 1.0)

                    # gradient step
                    self.optimizer.step()
                    self.M_optimizer.step()
            avg_loss = torch.Tensor(losses).mean()
            diff = online_target_difference(self.phi, self.target_phi)

            # hard update target network (a frequency of 0 disables the
            # periodic sync instead of dividing by zero)
            periodic_sync = bool(self.hard_update_freq) and epoch % self.hard_update_freq == 0
            if periodic_sync or epoch == epochs:
                soft_target_update(self.phi, self.target_phi, tau = 1)

            # replacing with loss across whole epoch
            objs['total_obj'] = avg_loss

            # Note: stats, losses etc are based on the last batch sampled above
            # ``best_phi`` aliases the online network (no copy is made).
            self.best_phi = self.phi
            # periodically check value error
            if can_evaluate and (epoch % 10 == 0 or epoch == epochs):
                lspe.train(tr_data, epochs = 3000, num_action_samples = 20, phi = self.best_phi)
                ope_error = ope_evaluator.evaluate(self.best_phi, lspe.get_theta())
                lspe_stats = lspe.get_metrics()['phi_lspe_stats']
                objs['phi_ope_error'] = ope_error
                objs['phi_spectral_radius'] = lspe_stats['spectral_radius']
                objs['phi_pos_eigen_frac'] = lspe_stats['pos_eigen_frac']
            stats = self.pmt.track_phi_training_stats(epoch, self.best_phi,\
                objs, curr_sa, next_sa, pie = self.pie,\
                ope_evaluator = ope_evaluator, check_realizability = check_realizability, terminals = terminal_masks)
            if print_log:
                print (f'{epoch}, {diff}, stats: {stats}')

    def _recon_loss(self, curr_sa, next_sa, term_masks):

        phi_csa = self.phi.forward(curr_sa)
        next_recon = self.M_phi.forward(phi_csa)
        terminal_masks = term_masks.reshape(-1, 1)

        cosine_sim = torch.sum(next_recon * terminal_masks * next_sa, axis = 1)\
            / (torch.linalg.norm(next_recon, axis = 1) * torch.linalg.norm(next_sa, axis = 1))
        recon_loss = -cosine_sim.mean()

        obj = recon_loss.mean()
        total_obj = obj

        objs = {
            'obj': obj,
            'total_obj': total_obj,
        }
        return objs

    def get_phi(self):
        return self.best_phi

class FQEAux(RepAlgo):
    def __init__(self,
        ground_state_dims,
        action_dims,
        abs_state_action_dims,
        hidden_dim = 32,
        hidden_layers = 1,
        activation = 'relu',
        final_activation = None,
        lr = 3e-4,
        M_lr = 3e-4,
        logdet_coeff = 0,
        norm_selfpred = False,
        reg_param = 0,
        gamma = None,
        image_state = False,
        mdp = None,
        pie = None,
        tabular = False,
        norm_type = None,
        soft_update_tau = 5e-3,
        hard_update_freq = 5,
        use_penultimate = False,
        aux_task = None,
        aux_alpha = 0.1,
        krope_kernel = 'dot',
        krope_sigma = 1e-2,
        ope_method = 'lspe'):
        """Store the hyper-parameters and build the networks and optimizers.

        The main network is a ``QBackBone``, wrapped in ``NNwLinear`` when
        ``use_penultimate`` is true. Extra heads (``M_phi`` and, for
        ``'bcrl'``, ``M_rew``) with their own optimizer are only created for
        the ``'recon'`` and ``'bcrl'`` auxiliary tasks; otherwise
        ``M_optimizer`` stays ``None``. When ``image_state`` is true the
        constructor returns right after storing the hyper-parameters, so no
        network, optimizer, ``Qs`` list or auxiliary setting is created.
        """
        super().__init__(pie, gamma, soft_update_tau, hard_update_freq)

        # Input / output sizes.
        self.ground_state_dims = ground_state_dims
        self.action_dims = action_dims
        self.abs_state_action_dims = abs_state_action_dims

        # Architecture of the main network.
        self.hidden_dim = hidden_dim
        self.hidden_layers = hidden_layers
        self.activation = activation
        self.final_activation = final_activation
        self.use_penultimate = use_penultimate

        # Optimisation settings.
        self.lr = lr
        self.M_lr = M_lr
        self.reg_param = reg_param
        self.gamma = gamma

        # Problem description.
        self.mdp = mdp
        self.pie = pie
        self.tabular = tabular

        # Auxiliary-loss and evaluation settings.
        self.logdet_coeff = logdet_coeff
        self.norm_selfpred = norm_selfpred
        self.krope_kernel = krope_kernel
        self.krope_sigma = krope_sigma
        self.ope_method = ope_method

        # Input width: product of the two sizes when tabular, their sum otherwise.
        if self.tabular:
            ground_state_action_dims = ground_state_dims * action_dims
        else:
            ground_state_action_dims = ground_state_dims + action_dims

        # No networks are built for image states.
        if image_state:
            return

        def make_head(output_dims):
            # Head on top of the abstract features, configured without hidden layers.
            return NeuralNetwork(input_dims = abs_state_action_dims,
                                output_dims = output_dims,
                                hidden_dim = -1,
                                hidden_layers = 0,
                                activation = None,
                                final_activation = None)

        backbone = QBackBone(input_dims = ground_state_action_dims,
                                output_dims = abs_state_action_dims,
                                hidden_dim = hidden_dim,
                                hidden_layers = hidden_layers,
                                activation = self.activation,
                                final_activation = self.final_activation,
                                norm_type = norm_type)
        self.phi = NNwLinear(backbone) if self.use_penultimate else backbone

        self.target_phi = copy.deepcopy(self.phi)
        self.optimizer = torch.optim.AdamW(self.phi.parameters(), lr = self.lr, betas = (0.9, 0.9))
        self.Qs = []
        self.aux_task = aux_task
        self.aux_alpha = aux_alpha
        self.M_optimizer = None

        if self.aux_task == 'recon':
            self.M_phi = make_head(ground_state_action_dims)
            self.M_params = list(self.M_phi.parameters())
        elif self.aux_task == 'bcrl':
            self.M_phi = make_head(abs_state_action_dims)
            self.M_rew = make_head(1)
            self.M_params = list(self.M_phi.parameters()) + list(self.M_rew.parameters())

        if self.aux_task in ('recon', 'bcrl'):
            self.M_optimizer = torch.optim.AdamW(self.M_params, lr = self.M_lr)

    def train(self, tr_data, test_data, epochs = 2000, mini_batch_size = 256, ope_algo = None, ope_evaluator = None, print_log = True):
        """Train the network on a mix of the TD loss and the auxiliary loss.

        Each batch loss is ``aux_alpha * aux + (1 - aux_alpha) * td`` for
        every key of the TD-loss dict; when no auxiliary loss is available
        (``_get_aux_loss`` returns ``None``) the TD loss is used on its own.
        Epoch 0 only evaluates the loss; parameter updates start at epoch 1.

        Args:
            tr_data: Training dataset. Must expose ``max_abs_reward_diff`` and
                yield dict batches with ``curr_sa``, ``next_sa``, ``rewards``
                and ``terminal_masks`` (plus the ``other_*`` keys when it is a
                ``PWDataset``).
            test_data: Unused.
            epochs: Index of the last epoch; epochs ``0..epochs`` are run.
            mini_batch_size: Batch size of the data loaders.
            ope_algo: Optional estimator used when ``ope_method == 'lspe'``.
            ope_evaluator: Optional evaluator for the periodic error/return
                and the value-path error; those steps are skipped when it is
                ``None``.
            print_log: When true, print one summary line per epoch.
        """
        num_workers = 6
        self.pmt = PhiMetricTracker(self.pie, self.gamma)
        self.rew_range = tr_data.max_abs_reward_diff
        self.best_phi = copy.deepcopy(self.target_phi)
        assert self.rew_range > 0

        check_realizability = self.tabular and ope_algo is not None and ope_evaluator is not None

        params = {'batch_size': mini_batch_size, 'shuffle': True, 'num_workers': num_workers}
        dataloader1 = torch.utils.data.DataLoader(tr_data, **params)
        dataloader2 = torch.utils.data.DataLoader(tr_data, **params)
        pw_dataset = isinstance(tr_data, PWDataset)

        for epoch in range(0, epochs + 1):
            losses = []
            for mb1, mb2 in zip(dataloader1, dataloader2):
                curr_sa = mb1['curr_sa']
                next_sa = mb1['next_sa']
                rewards = mb1['rewards']
                terminal_masks = mb1['terminal_masks']

                # The comparison transition comes from the same batch for a
                # pairwise dataset, otherwise from the second shuffled loader.
                if pw_dataset:
                    other_curr_sa = mb1['other_curr_sa']
                    other_next_sa = mb1['other_next_sa']
                    other_rewards = mb1['other_rewards']
                    other_terminal_masks = mb1['other_terminal_masks']
                else:
                    other_curr_sa = mb2['curr_sa']
                    other_next_sa = mb2['next_sa']
                    other_rewards = mb2['rewards']
                    other_terminal_masks = mb2['terminal_masks']

                td_loss = self._td_loss(tr_data, rewards, curr_sa, next_sa, terminal_masks)

                aux_loss = self._get_aux_loss(rewards, other_rewards,\
                            curr_sa, other_curr_sa,\
                            next_sa, other_next_sa, terminal_masks, other_terminal_masks)

                if aux_loss is None:
                    # No auxiliary task configured (the default): plain TD loss.
                    objs = dict(td_loss)
                else:
                    alpha = self.aux_alpha
                    objs = {k: alpha * aux_loss[k] + (1. - alpha) * td_loss[k] for k in td_loss}

                total_obj = objs['total_obj']
                # Keep a plain float so the autograd graph of each batch is
                # released instead of being held until the end of the epoch.
                losses.append(total_obj.item())

                # start updates for epoch 1 onwards
                if epoch > 0:
                    # clear gradients
                    self.optimizer.zero_grad()
                    if self.M_optimizer: self.M_optimizer.zero_grad()

                    # compute gradients; only autograd failures abort the epoch
                    try:
                        total_obj.backward()
                    except RuntimeError as err:
                        print (f'error on backprop, breaking: {err}')
                        break

                    # processing gradients
                    nn.utils.clip_grad_value_(self.phi.parameters(), clip_value = 1.0)
                    if self.M_optimizer: nn.utils.clip_grad_value_(self.M_params, clip_value = 1.0)

                    # gradient step
                    self.optimizer.step()
                    if self.M_optimizer: self.M_optimizer.step()
            avg_loss = torch.Tensor(losses).mean()

            diff = online_target_difference(self.phi, self.target_phi)
            # hard update target network (a frequency of 0 disables the
            # periodic sync instead of dividing by zero)
            periodic_sync = bool(self.hard_update_freq) and epoch % self.hard_update_freq == 0
            if periodic_sync or epoch == epochs:
                soft_target_update(self.phi, self.target_phi, tau = 1)

            # replacing with loss across whole epoch
            objs['total_obj'] = avg_loss

            # Note: stats, losses etc are based on the last batch sampled above
            # ``best_phi`` aliases the online network (no copy is made).
            self.best_phi = self.phi
            # periodically check value error
            if ope_evaluator is not None and (epoch % 10 == 0 or epoch == epochs):
                backbone = self.best_phi.backbone
                evaluated = False
                if self.ope_method == 'lspe' and ope_algo is not None:
                    ope_algo.train(tr_data, epochs = 3000, num_action_samples = 20, phi = backbone)
                    ope_error = ope_evaluator.evaluate(backbone, ope_algo.get_theta())
                    ope_ret = ope_evaluator.evaluate(backbone, ope_algo.get_theta(), metric_type = 'return')
                    lspe_stats = ope_algo.get_metrics()['phi_lspe_stats']
                    objs['phi_spectral_radius'] = lspe_stats['spectral_radius']
                    objs['phi_pos_eigen_frac'] = lspe_stats['pos_eigen_frac']
                    evaluated = True
                elif self.ope_method == 'fqe-e2e':
                    # weights of the first parameter of the linear layer, flattened
                    w_weights = [param.detach() for param in self.best_phi.linear.parameters()][0].numpy().reshape(-1)
                    ope_error = ope_evaluator.evaluate(backbone, w_weights)
                    ope_ret = ope_evaluator.evaluate(backbone, w_weights, metric_type = 'return')
                    evaluated = True
                # Only report metrics that were actually computed; any other
                # ``ope_method`` leaves them out instead of raising.
                if evaluated:
                    objs['phi_ope_error'] = ope_error
                    objs['phi_ope_ret'] = ope_ret

            self.Qs.append(copy.deepcopy(self.best_phi))
            if len(self.Qs) == 21:
                # evaluate the path of stored snapshots, then start a new window
                if ope_evaluator is not None:
                    objs['phi_val_eval_path_err'] = ope_evaluator.value_eval_path(self.Qs)
                self.Qs.clear()

            stats = self.pmt.track_phi_training_stats(epoch, self.best_phi.backbone,\
                objs, curr_sa, next_sa, pie = self.pie, ope_evaluator = ope_evaluator,\
                check_realizability = check_realizability, terminals = terminal_masks)
            if print_log:
                print (f'{epoch}, {diff}, stats: {stats}')

    def _get_aux_loss(self, rews, other_rews,\
                state_actions, other_state_actions,\
                next_state_actions, other_next_state_actions,\
                terminal_masks, other_terminal_masks):
        
        if self.aux_task == 'krope':
            return self._ksme_loss(rews, other_rews,\
                            state_actions, other_state_actions,\
                            next_state_actions, other_next_state_actions, terminal_masks, other_terminal_masks)
        elif self.aux_task == 'recon':
            return self._recon_loss(state_actions, next_state_actions, terminal_masks)
        elif self.aux_task == 'bcrl':    
            M_obj, _, _ = self._M_loss(rews, state_actions, next_state_actions, terminal_masks)
            phi_obj = self._phi_loss(rews, state_actions, next_state_actions, terminal_masks)
            combined_obj = {
                'obj': M_obj + phi_obj,
                'total_obj': M_obj + phi_obj,
            }
            return combined_obj
        elif self.aux_task == 'dr3':
            obj = self._dr3_loss(state_actions, next_state_actions)
            return obj
        elif self.aux_task == 'beer':
            obj = self._beer_loss(state_actions, next_state_actions, rews)
            return obj

    def _ksme_loss(self, rews, other_rews,\
                state_actions, other_state_actions,\
                next_state_actions, other_next_state_actions,\
                terminal_masks, other_terminal_masks):

        reward_dist = torch.Tensor(np.abs(rews - other_rews))
        phi_x = self.phi.backbone.forward(state_actions)
        phi_y = self.phi.backbone.forward(other_state_actions)

        next_phi_x = torch.tensor(self.target_phi.backbone(next_state_actions))
        next_phi_y = torch.tensor(self.target_phi.backbone(other_next_state_actions))

        if self.krope_kernel == 'dot':
            curr_dotprod = torch.sum(phi_x * phi_y, axis = 1)
            curr_Kxy = curr_dotprod
            next_dotprod = torch.sum(next_phi_x * next_phi_y, axis = 1)
            next_Kxy = next_dotprod
        elif self.krope_kernel == 'gaussian':
            sq_euc_dist_xy = torch.square(torch.linalg.norm(phi_x - phi_y, axis = 1))
            curr_Kxy = torch.exp(-self.krope_sigma * sq_euc_dist_xy)
            
            next_sq_euc_dist_xy = torch.square(torch.linalg.norm(next_phi_x - next_phi_y, axis = 1))
            next_Kxy = torch.exp(-self.krope_sigma * next_sq_euc_dist_xy)

        #target = self.rew_range - reward_dist + self.gamma * next_Kxy
        target = 1. - (reward_dist / self.rew_range) + terminal_masks * other_terminal_masks * self.gamma * next_Kxy

        #obj = torch.nn.functional.huber_loss(curr_Kxy, target)

        obj = torch.nn.functional.mse_loss(curr_Kxy, target, reduction = 'none')
        obj = obj.mean()

        frac = torch.Tensor([-1])#torch.count_nonzero(torch.abs(curr_Kxy - target) < 1) / curr_Kxy.shape[0]
        cov = -1#torch.bmm(phi_x.unsqueeze(-1), phi_x.unsqueeze(-2)).mean(dim=0)
        design_loss = 0#-self.log_det(cov)
        total_obj = obj + 1e-2 * design_loss
        objs = {
            'obj': obj,
            'total_obj': total_obj,
            'frac': frac
        }
        return objs

    def _recon_loss(self, curr_sa, next_sa, term_masks):

        phi_csa = self.phi.backbone.forward(curr_sa)
        next_recon = self.M_phi.forward(phi_csa)
        terminal_masks = term_masks.reshape(-1, 1)

        cosine_sim = torch.sum(next_recon * terminal_masks * next_sa, axis = 1)\
            / (torch.linalg.norm(next_recon, axis = 1) * torch.linalg.norm(next_sa, axis = 1))
        recon_loss = -cosine_sim.mean()

        obj = recon_loss.mean()
        total_obj = obj

        objs = {
            'obj': obj,
            'total_obj': total_obj,
        }
        return objs

    def _phi_loss(self, rewards, curr_sa, next_sa, terminal_masks):
        phi_x = self.phi.backbone.forward(curr_sa)
        terminal_masks = terminal_masks.reshape(-1, 1)
        #next_phi_x = terminal_masks * self.phi.forward(next_sa)
        with torch.no_grad():
            next_phi_x = terminal_masks * torch.Tensor(self.target_phi.backbone(next_sa))

        pred_next_phi_x = self.M_phi.forward(phi_x)
        pred_rew = self.M_rew.forward(phi_x)

        design_loss = 0
        if self.logdet_coeff > 0:
            cov = torch.bmm(phi_x.unsqueeze(-1), phi_x.unsqueeze(-2)).mean(dim=0)
            design_loss = -self.log_det(cov)
        
        rewards = torch.Tensor(rewards.reshape((-1, 1)))
        reward_loss = F.mse_loss(pred_rew, rewards)

        BC_loss = self._self_pred_error(pred_next_phi_x, next_phi_x)

        # Combined Loss
        phi_loss = (
            1 * 1.0 * reward_loss
            + 1 * BC_loss
            + self.logdet_coeff * design_loss
        )
        return phi_loss

    def _M_loss(self, rewards, curr_sa, next_sa, terminal_masks):
        terminal_masks = terminal_masks.reshape(-1, 1)
        with torch.no_grad():
            phi_x = torch.Tensor(self.phi.backbone(curr_sa))
            next_phi_x = terminal_masks * torch.Tensor(self.target_phi.backbone(next_sa))

        pred_next_phi_x = self.M_phi.forward(phi_x)
        pred_rew = self.M_rew.forward(phi_x)
        rewards = torch.Tensor(rewards.reshape((-1, 1)))
        reward_loss = F.mse_loss(pred_rew, rewards)

        BC_loss = self._self_pred_error(pred_next_phi_x, next_phi_x)

        M_loss = 1 * 1. * reward_loss\
            + 1 * BC_loss

        return M_loss, reward_loss, BC_loss

    def _self_pred_error(self, pred_next_phi_x, next_phi_x):

        BC_loss = torch.square(torch.linalg.vector_norm(pred_next_phi_x - self.gamma * next_phi_x, dim=1))
        
        if self.norm_selfpred:
            # pred_next_phi_x_norm = torch.linalg.vector_norm(pred_next_phi_x, dim = 1)
            # next_phi_x_norm = torch.linalg.vector_norm(next_phi_x, dim = 1)
            # BC_loss = BC_loss / (pred_next_phi_x_norm * next_phi_x_norm)
            next_phi_x_norm = torch.linalg.norm(next_phi_x, axis = 1)
            clipped_value = torch.full_like(next_phi_x_norm, fill_value=1e-8, dtype=torch.float)
            next_phi_x_norm = torch.maximum(next_phi_x_norm, clipped_value)
            cosine_sim = torch.sum(pred_next_phi_x * self.gamma * next_phi_x, axis = 1)\
               / (next_phi_x_norm * torch.linalg.norm(pred_next_phi_x, axis = 1))
            BC_loss = -cosine_sim.mean()
        
        BC_loss = BC_loss.mean()
        return BC_loss

    def log_det(self, A):
        """Return the log-determinant of ``A + 1e-5 * I`` via its Cholesky factor.

        ``A`` must be 2- or 3-dimensional; the result is twice the sum of the
        logs of the factor's diagonal, taken over the last axis.
        """
        assert A.dim() in [2, 3]
        # regularize when computing log-det
        jitter = 1e-5 * torch.eye(A.shape[1], device=A.device)
        cholesky_factor = torch.linalg.cholesky(A + jitter)
        diagonal = cholesky_factor.diagonal(dim1=-2, dim2=-1)
        return 2 * diagonal.log().sum(-1)

    def _dr3_loss(self, curr_sa, next_sa):

        phi_csa = self.phi.backbone.forward(curr_sa)
        phi_nsa = self.phi.backbone.forward(next_sa)
        dot_prod = torch.abs(torch.sum(phi_csa * phi_nsa, axis = 1))

        obj = dot_prod.mean()
        total_obj = obj

        objs = {
            'obj': obj,
            'total_obj': total_obj,
        }
        return objs

    def _beer_loss(self, curr_sa, next_sa, rewards):

        phi_csa = self.phi.backbone.forward(curr_sa)
        phi_csa_norm = torch.linalg.norm(phi_csa, axis = 1)
        phi_csa_sg = phi_csa.detach()
        phi_csa_norm_sg = phi_csa_norm.detach()

        phi_nsa_sg = torch.Tensor(self.phi.backbone(next_sa))
        phi_nsa_norm_sg = torch.linalg.norm(phi_nsa_sg, axis = 1)

        w_norm_sq = torch.square(torch.linalg.norm([param.detach() for param in self.phi.linear.parameters()][0]))
        r_sq = torch.square(rewards)

        cosine = torch.sum(phi_csa * phi_nsa_sg, axis = 1) \
            / (phi_csa_norm * phi_nsa_norm_sg)

        norm_first = torch.square(phi_csa_norm_sg)\
            + (self.gamma ** 2) * torch.square(phi_nsa_norm_sg)\
            - (r_sq / w_norm_sq)
        norm_second = 1. / (2 * self.gamma * phi_csa_norm_sg * phi_nsa_norm_sg)

        norm = norm_first * norm_second

        diff_term = cosine - norm
        final = nn.ReLU()(diff_term)

        obj = final.mean()
        total_obj = obj

        objs = {
            'obj': obj,
            'total_obj': total_obj,
        }
        return objs

    def _td_loss(self, data, rews, curr_sa, next_sa, term_masks):

        q_curr_outputs = self.phi.forward(curr_sa).reshape(-1)
        with torch.no_grad():
            q_next_outputs = self.target_phi(next_sa).reshape(-1)
        
        target = torch.Tensor(rews + self.gamma * term_masks * q_next_outputs)

        # if self.clip_target:
        #     target = torch.clip(target,\
        #         min = data.min_reward / (1. - self.gamma),\
        #         max = data.max_reward / (1. - self.gamma))

        obj = torch.nn.functional.mse_loss(q_curr_outputs, target, reduction = 'none')
        obj = obj.mean()
        total_obj = obj

        objs = {
            'obj': obj,
            'total_obj': total_obj,
        }
        return objs

    def get_phi(self):
        return self.best_phi.backbone

class KernelLihong(RepAlgo):
    """Learner trained by minimising a kernel-weighted Bellman-residual loss.

    The loss (see ``_kernel_loss``) weights pairwise products of one-step
    residuals by a linear kernel computed on the raw state-action inputs.
    """

    def __init__(self,
        ground_state_dims,
        action_dims,
        abs_state_action_dims,
        hidden_dim = 32,
        hidden_layers = 1,
        activation = 'relu',
        final_activation = None,
        lr = 3e-4,
        reg_param = 0,
        gamma = None,
        image_state = False,
        mdp = None,
        pie = None,
        tabular = False,
        norm_type = None,
        soft_update_tau = 5e-3,
        hard_update_freq = 5,
        use_penultimate = False):
        """Build the online network, its target copy and the optimizer.

        Args:
            ground_state_dims: size of the state part of the input.
            action_dims: size of the action part of the input.
            abs_state_action_dims: output width of the ``NeuralNetwork``.
            hidden_dim, hidden_layers, activation, final_activation,
                norm_type: forwarded to ``NeuralNetwork``.
            lr: AdamW learning rate.
            reg_param: stored on the instance; not used by this class.
            gamma: discount factor used in the residual.
            image_state: when true, no network is constructed (``phi``,
                ``target_phi`` and ``optimizer`` are left unset).
            mdp, pie: stored on the instance; ``pie`` is also handed to the
                metric tracker.
            tabular: when true the input width is
                ``ground_state_dims * action_dims`` instead of their sum.
            soft_update_tau, hard_update_freq: forwarded to ``RepAlgo``;
                ``hard_update_freq`` is the epoch period of the target sync.
            use_penultimate: wrap the network in ``NNwLinear``.
        """
        super().__init__(pie, gamma, soft_update_tau, hard_update_freq)

        self.ground_state_dims = ground_state_dims
        self.action_dims = action_dims
        self.abs_state_action_dims = abs_state_action_dims
        self.hidden_dim = hidden_dim
        self.hidden_layers = hidden_layers
        self.activation = activation
        self.final_activation = final_activation
        self.lr = lr
        self.reg_param = reg_param
        self.gamma = gamma
        self.mdp = mdp
        self.pie = pie
        self.tabular = tabular
        self.use_penultimate = use_penultimate

        ground_state_action_dims = ground_state_dims + action_dims
        if self.tabular:
            ground_state_action_dims = ground_state_dims * action_dims

        # Image inputs are not supported: bail out before building anything.
        if image_state:
            return

        phi = NeuralNetwork(input_dims = ground_state_action_dims,
                                output_dims = abs_state_action_dims,
                                hidden_dim = hidden_dim,
                                hidden_layers = hidden_layers,
                                activation = self.activation,
                                final_activation = self.final_activation,
                                norm_type = norm_type)
        self.phi = NNwLinear(phi) if self.use_penultimate else phi

        self.target_phi = copy.deepcopy(self.phi)
        self.optimizer = torch.optim.AdamW(self.phi.parameters(), lr = self.lr, betas = (0.9, 0.9))

    def train(self, tr_data, test_data, epochs = 2000, mini_batch_size = 256, lspe = None, ope_evaluator = None, print_log = True):
        """Run ``epochs`` passes of mini-batch training on ``tr_data``.

        Epoch 0 only evaluates the loss and logs statistics; parameter
        updates start at epoch 1.

        Args:
            tr_data: dataset wrapped in a ``DataLoader``; must expose
                ``num_samples`` and ``max_abs_reward_diff`` and yield dict
                batches with keys ``'curr_sa'``, ``'next_sa'``, ``'rewards'``
                and ``'terminal_masks'``.
            test_data: unused.
            epochs: index of the last epoch (inclusive).
            mini_batch_size: ``DataLoader`` batch size.
            lspe, ope_evaluator: optional evaluation helpers. When both are
                given, every 10th epoch (and the last one) ``lspe`` is fit on
                the current backbone and the resulting error / spectral
                statistics are added to the tracked objectives. When either
                is ``None`` this periodic evaluation is skipped.
            print_log: print the tracked statistics after every epoch.

        Raises:
            ValueError: if ``tr_data`` yields no mini-batches.
        """
        self.total_examples = tr_data.num_samples
        num_workers = 6
        self.pmt = PhiMetricTracker(self.pie, self.gamma)
        self.rew_range = tr_data.max_abs_reward_diff
        self.best_phi = copy.deepcopy(self.target_phi)
        assert self.rew_range > 0

        can_evaluate = lspe is not None and ope_evaluator is not None
        check_realizability = self.tabular and can_evaluate

        params = {'batch_size': mini_batch_size, 'shuffle': True, 'num_workers': num_workers}
        dataloader = torch.utils.data.DataLoader(tr_data, **params)

        for epoch in range(0, epochs + 1):
            objs = None
            for mb in dataloader:
                curr_sa = mb['curr_sa']
                next_sa = mb['next_sa']
                rewards = mb['rewards']
                term_masks = mb['terminal_masks']

                objs = self._kernel_loss(rewards, curr_sa, next_sa, term_masks)
                total_obj = objs['total_obj']

                # start updates for epoch 1 onwards
                if epoch > 0:
                    self.optimizer.zero_grad()

                    # Only autograd failures are treated as recoverable; the
                    # rest of this epoch's batches are skipped.
                    try:
                        total_obj.backward()
                    except RuntimeError as err:
                        print ('error on backprop, breaking')
                        print (repr(err))
                        break

                    nn.utils.clip_grad_value_(self.phi.parameters(), clip_value = 1.0)
                    self.optimizer.step()

            if objs is None:
                raise ValueError('tr_data produced no mini-batches')

            # hard update target network (tau = 1); a frequency of 0 disables it
            if self.hard_update_freq and epoch % self.hard_update_freq == 0:
                soft_target_update(self.phi, self.target_phi, tau = 1)

            # Note: stats, losses etc are based on the last batch sampled above
            self.best_phi = self.phi

            # periodically check value error (needs both evaluation helpers)
            if can_evaluate and (epoch % 10 == 0 or epoch == epochs):
                lspe.train(tr_data, epochs = 3000, num_action_samples = 20, phi = self.best_phi.backbone)
                ope_error = ope_evaluator.evaluate(self.best_phi.backbone, lspe.get_theta())
                lspe_stats = lspe.get_metrics()['phi_lspe_stats']
                objs['phi_ope_error'] = ope_error
                objs['phi_spectral_radius'] = lspe_stats['spectral_radius']
                objs['phi_pos_eigen_frac'] = lspe_stats['pos_eigen_frac']

            stats = self.pmt.track_phi_training_stats(epoch, self.best_phi.backbone,\
                objs, curr_sa, next_sa, pie = self.pie,\
                ope_evaluator = ope_evaluator, check_realizability = check_realizability, terminals = term_masks)
            if print_log:
                print (f'{epoch}, stats: {stats}')

    def _kernel_loss(self, rews, state_actions, next_state_actions, term_masks):
        """Kernel-weighted pairwise Bellman-residual objective for one batch.

        The residual is ``rews + gamma * term_masks * phi(next) - phi(curr)``.
        Every pair of residuals is multiplied together (the second factor
        detached) and weighted by the inner product of the corresponding raw
        inputs. The result mixes, with equal weight, the mean of the
        off-diagonal terms and ``v_obj``, a combination of the diagonal and
        off-diagonal sums rescaled by the dataset size ``n`` and batch size
        ``m``.

        A batch with a single row has no off-diagonal pairs; both
        off-diagonal contributions are then taken as zero instead of
        dividing by ``m - 1 == 0``.

        Returns:
            dict with keys ``'obj'`` and ``'total_obj'`` (same tensor).
        """
        # Linear kernel on the raw inputs (learned features are not used here).
        raw_inputs = torch.Tensor(state_actions)
        pairwise_kernel = torch.mm(raw_inputs, raw_inputs.T)

        resid = rews + self.gamma * term_masks * self.phi.forward(next_state_actions).reshape(-1) \
            - self.phi.forward(state_actions).reshape(-1)
        resid = resid.reshape(-1, 1)

        # Gradient flows through the first factor of each product only.
        pairwise_resid = torch.mm(resid, resid.detach().T)
        weighted_resid = pairwise_kernel * pairwise_resid

        diag_mask = torch.eye(*weighted_resid.shape, dtype = torch.bool,
                              device = weighted_resid.device)
        off_diagonal = weighted_resid[~diag_mask]
        on_diag_sum = weighted_resid[diag_mask].sum()
        off_diag_sum = off_diagonal.sum()

        n = self.total_examples
        m = state_actions.shape[0]
        if m > 1:
            off_diag_weight = (n - 1) / (m - 1)
            off_diag_mean = off_diagonal.mean()
        else:
            # No pairs: the empty sum is already zero; avoid 0-division / NaN.
            off_diag_weight = 0.
            off_diag_mean = off_diag_sum
        v_obj = (1 / (m * n)) * (on_diag_sum + off_diag_weight * off_diag_sum)

        obj = 0.5 * off_diag_mean + 0.5 * v_obj

        return {
            'obj': obj,
            'total_obj': obj,
        }

    def get_phi(self):
        """Return the ``backbone`` of the network stored in ``self.best_phi``."""
        return self.best_phi.backbone

class RandomMatrixProjection(RepAlgo):
    """Representation backed by a fixed ``Matrix`` module; training is a no-op."""

    def __init__(self,
        ground_state_dims,
        action_dims,
        abs_state_action_dims,
        gamma = None,
        mdp = None,
        pie = None,
        tabular = False,
        norm_type = None):
        """Create the projection from state-action inputs to the abstract space.

        ``norm_type`` is accepted but not used. The same ``Matrix`` instance
        is exposed as ``M_phi``, ``phi`` and ``best_phi``.
        """
        super().__init__(pie, gamma)

        self.mdp = mdp
        self.pie = pie
        self.gamma = gamma
        self.tabular = tabular
        self.ground_state_dims = ground_state_dims
        self.action_dims = action_dims
        self.abs_state_action_dims = abs_state_action_dims

        # Tabular inputs have one slot per (state, action) pair; otherwise the
        # state and action parts are laid side by side.
        if self.tabular:
            input_width = ground_state_dims * action_dims
        else:
            input_width = ground_state_dims + action_dims

        projection = Matrix(input_dims = input_width,
                                output_dims = abs_state_action_dims)
        self.M_phi = projection
        self.best_phi = projection
        self.phi = projection

    def train(self, tr_data, test_data, epochs = 2000, mini_batch_size = 256, lspe = None, ope_evaluator = None, print_log = True):
        """Do nothing and return ``self``; the projection is not trained."""
        return self

class QpieBisim:
    """One-hot abstraction that groups state-actions with similar values.

    Each entry of ``q_pie_values`` is divided by ``eps`` and rounded to an
    integer bucket; state-actions that share a bucket share a group. Calling
    the instance maps one-hot state-action encodings to one-hot group
    encodings.
    """

    def __init__(self, mdp, q_pie_values, pi = None, eps = 0.1):
        """
        Args:
            mdp: object exposing ``n_state`` and ``n_action``.
            q_pie_values: array of values; flattened to index state-actions.
            pi: unused.
            eps: bucket width used when rounding the values.
        """
        self.mdp = mdp
        self.q_pie_values = q_pie_values
        self.n_state = mdp.n_state
        self.n_action = mdp.n_action
        self.eps = eps

        self.vals = np.round(q_pie_values.flatten() / eps).astype(int)
        self.groups, self.sa_to_group = self.get_grouping(self.vals)
        self.num_abs_sa = len(self.groups)
        self.phi_outdim = self.num_abs_sa
        print (len(self.groups), self.groups)

    def get_grouping(self, array):
        """Group equal entries of ``array``.

        Returns:
            ``(groups, sa_to_group)`` where ``groups`` maps each distinct
            value to the list of its occurrences and ``sa_to_group`` maps
            every position in ``array`` to the index of its value's group
            (groups are numbered in order of first appearance).
        """
        groups = {}
        for num in array:
            groups.setdefault(num, []).append(num)

        # Dicts preserve insertion order, so enumerating the keys yields the
        # same ids as searching a list of keys, in O(1) per lookup.
        group_index = {val: idx for idx, val in enumerate(groups)}
        sa_to_group = {sa_idx: group_index[val] for sa_idx, val in enumerate(array)}
        return groups, sa_to_group

    def __call__(self, state_actions):
        """Map one-hot state-action encodings to one-hot group encodings.

        Args:
            state_actions: a 1-D encoding or a 2-D batch of encodings (array
                or torch tensor); the arg-max along the last axis selects the
                state-action index.

        Returns:
            Array of shape ``(num_abs_sa,)`` for 1-D input, otherwise
            ``(batch, num_abs_sa)``.

        Raises:
            KeyError: if an index has no group. Errors propagate to the
                caller instead of opening a debugger.
        """
        if torch.is_tensor(state_actions):
            state_actions = state_actions.detach().cpu().numpy()
        state_actions = np.asarray(state_actions)

        single = state_actions.ndim == 1
        sa_idx = np.argmax(state_actions, axis = -1)
        if single:
            sa_idx = [sa_idx]

        features = np.zeros((len(sa_idx), self.num_abs_sa))
        for row, idx in enumerate(sa_idx):
            features[row, self.sa_to_group[idx]] = 1.
        return features[0] if single else features

    def train(self, val):
        """Do nothing and return ``self``; the grouping is fixed."""
        return self

class Krylov:
    """Tabular features ``[r, P r, P^2 r, ...]`` built from rewards and transitions.

    ``P`` is the state-action transition matrix returned by
    ``mdp.get_policy_probs`` for the given policy and ``r`` the flattened
    reward table. Row ``i`` of ``self.kry`` is the feature vector of
    state-action ``i``.
    """

    def __init__(self, mdp, phi_outdim, pi, orthogonal = False):
        """
        Args:
            mdp: object exposing ``n_state``, ``n_action``, ``rewards`` and
                ``get_policy_probs``.
            phi_outdim: number of feature columns to generate.
            pi: policy object exposing ``pi_matrix``.
            orthogonal: replace the features by left singular vectors of the
                feature matrix (see ``_get_ortho``).
        """
        self.mdp = mdp
        self.phi_outdim = phi_outdim
        self.n_state = mdp.n_state
        self.n_action = mdp.n_action

        _, _, pi_tr_sa = mdp.get_policy_probs(pi.pi_matrix)
        r = mdp.rewards
        r = r.reshape(r.shape[0] * r.shape[1], -1)
        p_sa = pi_tr_sa.reshape(-1, pi_tr_sa.shape[2] * pi_tr_sa.shape[3])

        # Column i is (P^i) r; accum_psa holds the running power of P.
        kry = [r]
        accum_psa = p_sa
        for _ in range(1, phi_outdim):
            kry.append(np.matmul(accum_psa, r))
            accum_psa = np.matmul(accum_psa, p_sa)
        self.kry = np.array(kry).squeeze(-1).T

        if orthogonal:
            self.kry = self._get_ortho()

    def _get_ortho(self, th = 0.01):
        """Return the first ``phi_outdim`` left singular vectors of ``self.kry``.

        ``th`` is kept for backward compatibility but has no effect: the
        number of columns is always ``phi_outdim`` (no spectrum-based
        truncation is applied).
        """
        U, _, _ = np.linalg.svd(self.kry)
        return U[:, :self.phi_outdim]

    def __call__(self, state_actions):
        """Look up features for one-hot state-action encodings.

        The arg-max along the last axis selects the row of ``self.kry``.
        Errors (e.g. an out-of-range index) propagate to the caller instead
        of opening a debugger.
        """
        sa_idx = np.argmax(state_actions, axis = state_actions.ndim - 1)
        return self.kry[sa_idx]

    def train(self, val):
        """Do nothing and return ``self``; the features are fixed."""
        return self

class Schur:
    """Tabular features taken from the Schur vectors of the transition matrix.

    The state-action transition matrix returned by ``mdp.get_policy_probs``
    is decomposed with ``scipy.linalg.schur``; the first ``phi_outdim``
    columns of the unitary factor are the features, one row per state-action.
    """

    def __init__(self, mdp, phi_outdim, pi):
        """
        Args:
            mdp: object exposing ``n_state``, ``n_action`` and
                ``get_policy_probs``.
            phi_outdim: number of Schur vectors (columns) to keep.
            pi: policy object exposing ``pi_matrix``.
        """
        self.mdp = mdp
        self.phi_outdim = phi_outdim
        self.n_state = mdp.n_state
        self.n_action = mdp.n_action

        _, _, pi_tr_sa = mdp.get_policy_probs(pi.pi_matrix)
        p_sa = pi_tr_sa.reshape(-1, pi_tr_sa.shape[2] * pi_tr_sa.shape[3])

        # schur returns (T, Z); only the unitary factor Z is needed.
        _, Q = schur(p_sa)
        self.schur = Q[:, :self.phi_outdim]

    def __call__(self, state_actions):
        """Look up features for one-hot state-action encodings.

        The arg-max along the last axis selects the row of ``self.schur``.
        Errors (e.g. an out-of-range index) propagate to the caller instead
        of opening a debugger.
        """
        sa_idx = np.argmax(state_actions, axis = state_actions.ndim - 1)
        return self.schur[sa_idx]

    def train(self, val):
        """Do nothing and return ``self``; the features are fixed."""
        return self

class RandomFeatures:
    """Tabular features drawn uniformly at random, one row per state-action."""

    def __init__(self, mdp, phi_outdim):
        """
        Args:
            mdp: object exposing ``n_state`` and ``n_action``; also passed to
                ``utils_gw._generate_all_one_hots`` to enumerate all
                state-action encodings.
            phi_outdim: number of feature columns.
        """
        self.mdp = mdp
        self.phi_outdim = phi_outdim
        self.n_state = mdp.n_state
        self.n_action = mdp.n_action

        # Use the helper from utils_gw, as the sibling feature classes do.
        all_sa = utils_gw._generate_all_one_hots(mdp)
        self.all_sa = all_sa.reshape(-1, all_sa.shape[-1])
        self.rand_features = np.random.rand(self.all_sa.shape[0], self.phi_outdim)
        print ('rank of feats {}'.format(np.linalg.matrix_rank(self.rand_features)))

    def __call__(self, state_actions):
        """Look up features for one-hot state-action encodings.

        The arg-max along the last axis selects the row of
        ``self.rand_features``. Errors (e.g. an out-of-range index) propagate
        to the caller instead of opening a debugger.
        """
        sa_idx = np.argmax(state_actions, axis = state_actions.ndim - 1)
        return self.rand_features[sa_idx]

    def train(self, val):
        """Do nothing and return ``self``; the features are fixed."""
        return self

class ConstantFeatures:
    """Tabular features where every entry equals the same constant."""

    def __init__(self, mdp, phi_outdim, constant = 1e-1):
        """
        Args:
            mdp: object exposing ``n_state`` and ``n_action``; also passed to
                ``utils_gw._generate_all_one_hots`` to enumerate all
                state-action encodings.
            phi_outdim: number of feature columns.
            constant: value stored in every feature entry.
        """
        self.mdp = mdp
        self.phi_outdim = phi_outdim
        self.n_state = mdp.n_state
        self.n_action = mdp.n_action

        all_sa = utils_gw._generate_all_one_hots(mdp)
        self.all_sa = all_sa.reshape(-1, all_sa.shape[-1])
        # One row per state-action, all entries equal to `constant`.
        self.cons_features = np.zeros((self.all_sa.shape[0], self.phi_outdim)) + constant
        print ('rank of feats {}'.format(np.linalg.matrix_rank(self.cons_features)))

    def __call__(self, state_actions):
        """Look up features for one-hot state-action encodings.

        The arg-max along the last axis selects the row of
        ``self.cons_features``. Errors (e.g. an out-of-range index) propagate
        to the caller instead of opening a debugger.
        """
        sa_idx = np.argmax(state_actions, axis = state_actions.ndim - 1)
        return self.cons_features[sa_idx]

    def train(self, val):
        """Do nothing and return ``self``; the features are fixed."""
        return self

class PCAFeatures:
    """Tabular features given by a PCA projection of the one-hot encodings."""

    def __init__(self, mdp, phi_outdim, constant = 1e-1):
        """
        Args:
            mdp: object exposing ``n_state`` and ``n_action``; also passed to
                ``utils_gw._generate_all_one_hots`` to enumerate all
                state-action encodings.
            phi_outdim: number of principal components to keep.
            constant: unused.
        """
        self.mdp = mdp
        self.phi_outdim = phi_outdim
        self.n_state = mdp.n_state
        self.n_action = mdp.n_action

        all_sa = utils_gw._generate_all_one_hots(mdp)
        self.all_sa = all_sa.reshape(-1, all_sa.shape[-1])
        # Fit on, and project, the full table of encodings.
        pca = PCA(n_components=phi_outdim, svd_solver='full')
        pca.fit(self.all_sa)
        self.feats = pca.transform(self.all_sa)
        print ('rank of feats {}'.format(np.linalg.matrix_rank(self.feats)))

    def __call__(self, state_actions):
        """Look up features for one-hot state-action encodings.

        The arg-max along the last axis selects the row of ``self.feats``.
        Errors (e.g. an out-of-range index) propagate to the caller instead
        of opening a debugger.
        """
        sa_idx = np.argmax(state_actions, axis = state_actions.ndim - 1)
        return self.feats[sa_idx]

    def train(self, val):
        """Do nothing and return ``self``; the features are fixed."""
        return self

class OffPolicySA(RepAlgo):
    def __init__(self,
        ground_state_dims,
        action_dims,
        abs_state_action_dims,
        hidden_dim = 32,
        hidden_layers = 1,
        activation = 'relu',
        final_activation = None,
        lr = 3e-4,
        reg_param = 0,
        beta = 0.1,
        gamma = None,
        image_state = False,
        mdp = None,
        pie = None,
        tabular = False,
        loss_function = 'huber'):
        """Store the hyper-parameters and build the encoder, target and optimizer.

        The encoder is a ``NeuralNetwork`` whose input width is
        ``ground_state_dims * action_dims`` for tabular inputs and
        ``ground_state_dims + action_dims`` otherwise. When ``image_state``
        is true nothing is built and ``phi``, ``target_phi`` and
        ``optimizer`` stay unset.
        """
        super().__init__(pie, gamma)

        # Problem handles.
        self.mdp = mdp
        self.pie = pie
        self.tabular = tabular

        # Loss hyper-parameters.
        self.gamma = gamma
        self.beta = beta
        self.reg_param = reg_param
        self.loss_function = loss_function
        self.lr = lr

        # Architecture description.
        self.ground_state_dims = ground_state_dims
        self.action_dims = action_dims
        self.abs_state_action_dims = abs_state_action_dims
        self.hidden_dim = hidden_dim
        self.hidden_layers = hidden_layers
        self.activation = activation
        self.final_activation = final_activation

        if self.tabular:
            input_width = ground_state_dims * action_dims
        else:
            input_width = ground_state_dims + action_dims

        # Image inputs are not handled: leave the networks unset.
        if image_state:
            return

        encoder = NeuralNetwork(input_dims = input_width,
                                output_dims = abs_state_action_dims,
                                hidden_dim = hidden_dim,
                                hidden_layers = hidden_layers,
                                activation = self.activation,
                                final_activation = self.final_activation)
        self.phi = encoder
        self.target_phi = copy.deepcopy(encoder)
        self.optimizer = torch.optim.AdamW(encoder.parameters(), lr = self.lr, betas = (0.9, 0.9))

    def train(self, data, epochs = 2000, print_log = True):
        """Train the encoder for ``epochs`` mini-batch steps.

        Each step samples a mini-batch of paired transitions with
        ``load_mini_batches``, takes one clipped-gradient AdamW step on the
        loss from ``_rope_loss`` and soft-updates the target network.
        Statistics are tracked on the first step, every 1000th step and the
        last step; at those points ``best_phi`` is set to the target network.

        Args:
            data: dataset handed to ``load_mini_batches``.
            epochs: number of optimisation steps.
            print_log: print the tracked statistics when they are computed.
        """
        mini_batch_size = 256
        self.best_phi = copy.deepcopy(self.target_phi)

        for epoch in range(1, epochs + 1):
            mini_batch = load_mini_batches(data, self.tabular, self.pie, mini_batch_size)
            curr_sa = mini_batch['curr_sa']
            next_sa = mini_batch['next_sa']
            rewards = mini_batch['rewards']

            other_curr_sa = mini_batch['other_curr_sa']
            other_next_sa = mini_batch['other_next_sa']
            other_rewards = mini_batch['other_rewards']

            objs = self._rope_loss(rewards, other_rewards,\
                        curr_sa, other_curr_sa,\
                        next_sa, other_next_sa)

            # clear gradients
            self.optimizer.zero_grad()

            # compute gradients; only autograd failures stop training, so
            # interrupts and programming errors are no longer swallowed
            try:
                objs['total_obj'].backward()
            except RuntimeError as err:
                print ('error on backprop, breaking')
                print (repr(err))
                break

            # processing gradients
            nn.utils.clip_grad_value_(self.phi.parameters(), clip_value = 1.0)

            # gradient step
            self.optimizer.step()

            # soft update target network
            soft_target_update(self.phi, self.target_phi)

            if epoch % 1000 == 0 or epoch == epochs or epoch == 1:
                self.best_phi = self.target_phi
                stats = self.pmt.track_phi_training_stats(epoch, self.target_phi,\
                    objs, curr_sa, next_sa, pie = self.pie)
                if print_log:
                    print (f'{epoch}, stats: {stats}')

    def _rope_loss(self, rews, other_rews,\
                state_actions, other_state_actions,\
                next_state_actions, other_next_state_actions):

        reward_dist = torch.Tensor(np.abs(rews - other_rews))
        phi_x = self.phi.forward(state_actions)
        phi_x_norm = torch.linalg.norm(phi_x, axis = 1)
        phi_y = torch.tensor(self.target_phi(other_state_actions)) # no gradients
        #phi_y = self.phi.forward(other_state_actions)
        phi_y_norm = torch.linalg.norm(phi_y, axis = 1)
        cs = torch.sum(phi_x * phi_y, axis = 1) / (phi_x_norm * phi_y_norm)
        angle = torch.atan2(torch.sqrt(1. + SQRT_EPS - torch.square(cs)), cs)
        norm_avg = 0.5 * (torch.square(phi_x_norm) + torch.square(phi_y_norm))
        cs_dist = angle
        curr_Uxy = norm_avg + self.beta * cs_dist

        next_phi_x = torch.tensor(self.target_phi(next_state_actions))
        next_phi_x_norm = torch.linalg.norm(next_phi_x, axis = 1)
        next_phi_y = torch.tensor(self.target_phi(other_next_state_actions))
        next_phi_y_norm = torch.linalg.norm(next_phi_y, axis = 1)
        cs = torch.sum(next_phi_x * next_phi_y, axis = 1) / (next_phi_x_norm * next_phi_y_norm)
        angle = torch.atan2(torch.sqrt(1. + SQRT_EPS - torch.square(cs)), cs)
        next_norm_avg = 0.5 * (torch.square(next_phi_x_norm) + torch.square(next_phi_y_norm))
        next_cs_dist = angle
        next_Uxy = next_norm_avg + self.beta * next_cs_dist

        target = reward_dist + self.gamma * next_Uxy

        # target = torch.clip(target,\
        #         min = data.min_abs_reward_diff / (1. - self.gamma),
        #         max = data.max_abs_reward_diff / (1. - self.gamma))
        if self.loss_function == 'huber':
            obj = torch.nn.functional.huber_loss(curr_Uxy, target)
        elif self.loss_function == 'mse':
            obj = torch.nn.functional.mse_loss(curr_Uxy, target)

        reg = self.reg_param * torch.Tensor([0])
        total_obj = obj + reg 
        objs = {
            'obj': obj,
            'reg': reg,
            'total_obj': total_obj,
        }
        return objs

    def log_det(self, A):
            assert A.dim() in [2, 3]
            # regularize when computing log-det
            A = A + 1e-8 * torch.eye(A.shape[1], device=A.device)
            return 2 * torch.linalg.cholesky(A).diagonal(dim1=-2, dim2=-1).log().sum(-1)