import numpy as np
import pdb
import os
import torch
import torch.nn.functional as F
from torch import nn
from policies import NeuralNetwork, StablebaselinePolicy, D4RLPolicy, StablebaselinePolicyMixture
import copy
import random
import matplotlib.pyplot as plt
#from umap import UMAP
import seaborn as sns
import pandas as pd
import gymnasium as g
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.metrics.pairwise import pairwise_distances
import warnings
warnings.filterwarnings("ignore", "is_categorical_dtype")
warnings.filterwarnings("ignore", "use_inf_as_na")
warnings.filterwarnings("error")

# Module-wide setting, overwritten from the run flags by set_global_lam().
# NOTE(review): presumably a regularisation strength for matrix inversion --
# confirm against the callers that read it.
global_lam_inv = 0

def set_global_lam(FLAGS):
    """Copy ``FLAGS.lam_inv`` into the module-level ``global_lam_inv``.

    Args:
        FLAGS: any object exposing a ``lam_inv`` attribute.
    """
    global global_lam_inv
    lam_inv = FLAGS.lam_inv
    global_lam_inv = lam_inv

def set_seed_everywhere(seed):
    """Seed every random number generator this module relies on.

    Seeds Python's ``random``, NumPy's global generator and torch (CPU, plus
    all CUDA devices when CUDA is available) with the same value.

    Args:
        seed: integer seed handed unchanged to each library.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    cuda_present = torch.cuda.is_available()
    if cuda_present:
        torch.cuda.manual_seed_all(seed)

def clip_target(clip_target, ret, min_rew, max_rew, gamma):
    """Optionally clip return targets to ``[2*min_rew/(1-gamma), 2*max_rew/(1-gamma)]``.

    Args:
        clip_target: flag; when falsy ``ret`` is returned untouched. (The
            parameter shadows the function name inside the body only.)
        ret: value or array of targets to clip.
        min_rew: smallest reward, used for the lower bound.
        max_rew: largest reward, used for the upper bound.
        gamma: discount factor.

    Returns:
        ``ret`` itself when clipping is disabled, otherwise ``np.clip(ret, low, up)``.
        If the bounds cannot be computed (e.g. ``gamma == 1``) the bounds fall
        back to ``(-inf, inf)``, i.e. no effective clipping.
    """
    if not clip_target:
        return ret

    try:
        low_limit = 2 * min_rew / (1 - gamma)
        up_limit = 2 * max_rew / (1 - gamma)
    except Exception:
        # Best-effort fallback: covers ZeroDivisionError for gamma == 1 and
        # NumPy RuntimeWarnings that this module promotes to errors.
        # ``Exception`` (not a bare except) so that KeyboardInterrupt and
        # SystemExit still propagate.
        low_limit = -np.inf
        up_limit = np.inf
    return np.clip(ret, low_limit, up_limit)

class DeepOPEEvaluator:
    """Scores value estimates of the evaluation policy ``pie`` against reference returns.

    The constructor only stores its arguments. ``data`` must expose
    ``initial_states``, ``curr_state_actions`` and ``unnormalize_states``;
    ``pie`` must expose ``batch_sample``.
    """

    def __init__(self, mdp, pie, data, oracle_est_ret, rand_est_ret, min_rew, max_rew, gamma):
        self.mdp = mdp
        self.pie = pie
        self.data = data
        self.init_s = data.initial_states
        self.oracle_est_ret = oracle_est_ret
        self.rand_est_ret = rand_est_ret
        self.min_rew = min_rew
        self.max_rew = max_rew
        self.gamma = gamma
        self.csa = self.data.curr_state_actions

    def _prep(self, phi, weights):
        """Return the mean of ``phi(s0, a0) @ weights`` over the stored initial states.

        Actions ``a0`` are drawn with ``pie.batch_sample`` on the unnormalised
        initial states; ``phi`` is skipped when it is ``None``.
        """
        sampled_actions = self.pie.batch_sample(self.data.unnormalize_states(self.init_s))
        init_sa = np.concatenate((self.init_s, sampled_actions), axis=1)
        if phi is not None:
            init_sa = phi(init_sa)
        init_est_qvals = init_sa @ weights
        return np.mean(init_est_qvals)

    def evaluate(self, phi, weights, metric_type='error'):
        """Evaluate a linear value estimate on the initial state-action pairs.

        Args:
            phi: optional feature map applied to the initial state-actions.
            weights: weights multiplied with the (encoded) state-actions.
            metric_type: ``'error'`` returns the squared error to
                ``oracle_est_ret`` divided by the squared gap between
                ``oracle_est_ret`` and ``rand_est_ret``; ``'return'`` returns
                the raw estimated return.

        Returns:
            The requested metric, or ``None`` for any other ``metric_type``.
            A NaN error (0/0) is reported as ``inf``; note that
            ``np.nan_to_num`` also maps an infinite ratio to the largest
            finite float.
        """
        est_return = self._prep(phi, weights)

        if metric_type == 'error':
            num = np.mean(np.square(est_return - self.oracle_est_ret))
            denom = np.mean(np.square(self.oracle_est_ret - self.rand_est_ret))

            # A zero denominator is handled by nan_to_num below. Silence the
            # floating-point warning so the fallback is actually reached even
            # when warnings are turned into errors (as this module does).
            with np.errstate(divide='ignore', invalid='ignore'):
                normalized_err = np.divide(num, denom)
            ope_error = np.nan_to_num(normalized_err, nan=np.inf)
            print (f'final OPE error: {ope_error}, numerator {num}, denom {denom}')
            return ope_error
        elif metric_type == 'return':
            print (f'final OPE return: {est_return}')
            return est_return

    def value_eval_path(self, Qs):
        """Measure how well each Q in ``Qs`` is linearly representable by one shared feature map.

        The reference features are the ``backbone`` of the middle element of
        ``Qs`` applied to the dataset state-actions (or the raw state-actions
        when that backbone is ``None``). For every ``q`` the least-squares
        residual of fitting ``q(csa)`` from those features is divided by the
        number of samples and by the squared oracle/random return gap (the
        latter is skipped when it is not positive).

        Args:
            Qs: sequence of callables, each exposing a ``backbone`` attribute.

        Returns:
            List with one normalised residual per element of ``Qs`` (empty for
            empty input). When ``lstsq`` reports no residual, the entry is a
            zero array of shape ``(1,)``.
        """
        if len(Qs) == 0:
            return []

        phi = Qs[len(Qs) // 2].backbone
        ref_phi = self.csa if phi is None else phi(self.csa)

        # Loop-invariant normaliser, computed once.
        normalization = np.mean(np.square(self.oracle_est_ret - self.rand_est_ret))
        scale = normalization if normalization > 0 else 1

        errs = []
        for q in Qs:
            est_q = q(self.csa)
            _, residuals, _, _ = np.linalg.lstsq(ref_phi, est_q, rcond=None)
            # lstsq returns an empty residual array for rank-deficient or
            # under-determined systems; treat that as zero error.
            err_norm = residuals[0] / len(self.csa) if len(residuals) else np.array([0])
            errs.append(err_norm / scale)
        return errs

def pca_plot(features, name):
    """Scatter the first two principal components of ``features`` and save the figure.

    The rows are randomly permuted (global NumPy RNG) before fitting, the
    explained-variance ratios are printed, and the plot is written to
    ``'<name>_pca.jpg'``.

    Args:
        features: 2-D array indexable by an integer index array.
        name: prefix of the output file name.
    """
    from sklearn.decomposition import PCA

    shuffle = np.random.choice(len(features), len(features), replace = False)
    feats = features[shuffle]

    pca = PCA(n_components=2)
    pca_result = pca.fit_transform(feats)

    print('Explained variation per principal component: {}'.format(pca.explained_variance_ratio_))

    # Only two components are fitted; the former 'third' column was merely a
    # copy of the second component and has been dropped.
    df = pd.DataFrame({'first': pca_result[:, 0], 'second': pca_result[:, 1]})

    plt.figure(figsize=(16,10))
    sns.scatterplot(
        x='first', y='second',
        data=df,
        legend="full",
        alpha=0.3
    )
    plt.savefig('{}_pca.jpg'.format(name))
    plt.close()

def nearest_neighbors(data, pie, num_samples = int(1e4), phi = None, enc_name = None):
    """Save a t-SNE scatter of current state-action features coloured by critic Q-values.

    The figure is written to ``'<enc_name>_temp.jpg'``. Afterwards cosine
    similarities and 5-nearest-neighbours between next and current
    state-action features are computed, but none of those results are printed
    or returned; the function returns ``None``.

    Args:
        data: dataset exposing ``curr_states``, ``curr_actions``,
            ``next_states``, ``num_samples`` and ``unnormalize_states``.
        pie: policy exposing ``batch_sample`` and ``pi.critic_target.qf0`` /
            ``qf1``.
        num_samples: ignored -- overwritten below with
            ``min(8192, data.num_samples)``.
        phi: optional feature map applied to both state-action arrays.
        enc_name: prefix of the output file name.

    Side effects:
        Reseeds the global NumPy RNG with 0.
    """
    
    from scipy.spatial.distance import cdist
    from sklearn.neighbors import NearestNeighbors
    curr_s = data.curr_states
    curr_a = data.curr_actions
    curr_sa = np.concatenate((curr_s, curr_a), axis = 1)
    next_s = data.next_states

    # Fixed seed: action sampling (if it uses NumPy's RNG) and the subsample
    # below become reproducible, at the cost of resetting global RNG state.
    np.random.seed(0)
    next_sampled_acts = pie.batch_sample(data.unnormalize_states(next_s))
    next_sa = np.concatenate((next_s, next_sampled_acts), axis = 1)

    # NOTE(review): this overrides the `num_samples` argument.
    num_samples = min(4096 * 2, data.num_samples)
    shuffle = np.random.choice(curr_sa.shape[0], num_samples, replace = False)
    curr_sa = curr_sa[shuffle]
    next_sa = next_sa[shuffle]

    # Q-value of each pair = mean of the two critic outputs, evaluated on the
    # raw (un-encoded) state-action inputs.
    with torch.no_grad():
        curr_sa_val0 = pie.pi.critic_target.qf0(torch.Tensor(curr_sa)).numpy()
        curr_sa_val1 = pie.pi.critic_target.qf1(torch.Tensor(curr_sa)).numpy()
        curr_sa_val = (curr_sa_val0 + curr_sa_val1) / 2

        next_sa_val0 = pie.pi.critic_target.qf0(torch.Tensor(next_sa)).numpy()
        next_sa_val1 = pie.pi.critic_target.qf1(torch.Tensor(next_sa)).numpy()
        next_sa_val = (next_sa_val0 + next_sa_val1) / 2

    if phi is not None:
        curr_sa = phi(curr_sa)
        next_sa = phi(next_sa)
    curr_sa = curr_sa.astype(np.float64)
    next_sa = next_sa.astype(np.float64)

    # NOTE(review): the stacked current+next arrays are immediately replaced,
    # so only the current state-actions are embedded and plotted.
    X = np.vstack((curr_sa, next_sa))
    q_vals = np.vstack((curr_sa_val, next_sa_val)).reshape(-1)
    X = curr_sa
    q_vals = curr_sa_val.reshape(-1)

    n_components = 2
    from sklearn.manifold import TSNE
    #from umap import UMAP
    from sklearn.decomposition import PCA
    high_d_visual = TSNE(n_components=n_components, metric = 'cosine')
    viz_type = 'tsne'
    #high_d_visual = UMAP(n_components=n_components)
    #viz_type = 'umap'
    #high_d_visual = PCA(n_components=n_components)
    #viz_type = 'pca'
    high_d_visual_result = high_d_visual.fit_transform(X)

    # NOTE(review): `markers` is never used in this function.
    markers = {
        0: "P",
        1: "o"
    }

    high_d_visual_result_df = pd.DataFrame({'{}_1'.format(viz_type): high_d_visual_result[:, 0], 
                                    '{}_2'.format(viz_type): high_d_visual_result[:, 1]})

    fig, ax = plt.subplots(1, figsize=(12, 8))
    scatter_plot = sns.scatterplot(x='{}_1'.format(viz_type), y='{}_2'.format(viz_type), hue = q_vals,
        palette="viridis", data=high_d_visual_result_df, ax=ax, s=120, legend=False)

    # Stand-alone mappable carrying the Q-values so a colorbar can be drawn
    # (the scatter itself is created with legend=False).
    sm = plt.cm.ScalarMappable(cmap='viridis')
    sm.set_array(q_vals)

    cbar = plt.colorbar(
        sm,
        ax=plt.gca()
    )
    cbar.set_label('Q values', rotation=270, labelpad=15)

    plt.savefig('{}_temp.jpg'.format(enc_name))
    plt.close()


    # Pairwise cosine similarity between every next and every current feature
    # row (rows: next, columns: current).
    # NOTE(review): everything from here on is computed and then discarded.
    dot_prods = np.dot(next_sa, curr_sa.T)
    nsa_norms = np.linalg.norm(next_sa, axis = 1).reshape(-1,1)
    csa_norms = np.linalg.norm(curr_sa, axis = 1).reshape(-1,1)
    denoms = np.dot(nsa_norms, csa_norms.T)
    cosine_sim = dot_prods / denoms

    # Column indices of the 3 least / 3 most similar current rows per next row.
    argsrted = np.argsort(cosine_sim, axis=1)
    indices_smallest = argsrted[:, :3]
    indices_largest = argsrted[:, -3:]

    # 5 nearest current rows for each next row (NearestNeighbors default metric).
    k_neighbors = 5
    nn_model = NearestNeighbors(n_neighbors=k_neighbors)
    nn_model.fit(curr_sa)
    distances, indices = nn_model.kneighbors(next_sa)
    #for i in range(next_sa.shape[0]):
    # for i in range(50):
    #     print(f"Nearest neighbor of X[{i}] with distance {distances[i, :k_neighbors]}")
    #     print(f"Q value of X[{i}] is {next_sa_val[i, 0]} with Q values of NNs are: {curr_sa_val[indices[i]].reshape(-1)}")# and farthest {curr_sa_val[farthest[i]]}")
    #     #print(f"Nearest neighbor of X[{i}] with cosine sim {cosine_sim[i, indices_largest[i]]} and farthest with cosine sim {cosine_sim[i, indices_smallest[i]]}")
    #     #print(f"Q value of X[{i}] is {next_sa_val[i, 0]} with Q values of NNs are: {curr_sa_val[indices_largest[i]].reshape(-1)} and farthest {curr_sa_val[indices_smallest[i]].reshape(-1)}")


    # distances = cdist(next_sa, curr_sa, metric='cosine')
    # min_dists = np.min(distances, axis=1)
    # print ('average min distance {}'.format(np.mean(min_dists)))

    # for i range(num_samples):
    #     print(f"Nearest neighbor of X[{i}] is Y[{idx}] with distance {distances[i, idx]}")

def high_d_plot_fa(data, pie, num_samples = int(1e4), phi = None, fname = None, typ = 'umap'):
    """Save a 2-D PCA scatter of (optionally encoded) dataset state-action features.

    A random subsample of at most 4096 transitions is drawn, encoded with
    ``phi`` when given, summarised via ``get_phi_stats`` (printed), and the
    current state-action rows are projected and written to
    ``'<fname>_pca_ds.jpg'``. A failure while plotting is reported on stdout
    and does not raise.

    Args:
        data: dataset exposing ``curr_states``, ``curr_actions``,
            ``next_states``, ``num_samples`` and ``unnormalize_states``.
        pie: policy exposing ``batch_sample``.
        num_samples: ignored -- overwritten with ``min(4096, data.num_samples)``.
        phi: optional feature map applied to the state-action arrays.
        fname: prefix of the output file name.
        typ: ignored -- the projection is always PCA.
    """
    from sklearn.manifold import TSNE
    from sklearn.manifold import MDS
    from sklearn.decomposition import PCA

    curr_s = data.curr_states
    curr_a = data.curr_actions
    curr_sa = np.concatenate((curr_s, curr_a), axis = 1)
    next_s = data.next_states
    next_sampled_acts = pie.batch_sample(data.unnormalize_states(next_s))
    next_sa = np.concatenate((next_s, next_sampled_acts), axis = 1)

    for st in ['ds']:
        if st == 'ds':
            num_samples = min(4096, data.num_samples)
            shuffle = np.random.choice(curr_sa.shape[0], num_samples, replace = False)
            curr_sa = curr_sa[shuffle]
            next_sa = next_sa[shuffle]
        elif st == 'unq':
            curr_sa, idx = np.unique(curr_sa, axis = 0, return_index = True)
            next_sa, idx = np.unique(next_sa, axis = 0, return_index = True)

        if phi is not None:
            curr_sa = phi(curr_sa)
            next_sa = phi(next_sa)
        curr_sa = curr_sa.astype(np.float64)
        next_sa = next_sa.astype(np.float64)

        phi_stats = get_phi_stats(curr_sa)
        print ('phi visualization info. shape: {}, stats: {}'.format(curr_sa.shape, phi_stats))

        def _plot(X, num_csa = None, viz_type = None, file_name = None):
            """Embed ``X`` in 2-D with ``viz_type`` and save the scatter plot."""
            n_components = 2
            if viz_type == 'umap':
                # Imported lazily: umap is optional and must not break the
                # PCA / t-SNE / MDS paths when it is not installed.
                from umap import UMAP
                high_d_visual = UMAP(n_components=n_components)
            elif viz_type == 'mds':
                high_d_visual = MDS(n_components=n_components, normalized_stress='auto')
            elif viz_type == 'pca':
                high_d_visual = PCA(n_components=n_components)
            elif viz_type == 'tsne':
                high_d_visual = TSNE(n_components=n_components)
            else:
                raise ValueError('unknown viz_type: {!r}'.format(viz_type))
            high_d_visual_result = high_d_visual.fit_transform(X)

            # 1 marks the first `num_csa` rows (current state-actions), 0 the rest.
            csa = np.zeros((X.shape[0]))
            csa[:num_csa] = 1
            csa = csa.astype(int)
            markers = {
                0: "P",
                1: "o"
            }

            high_d_visual_result_df = pd.DataFrame({'{}_1'.format(viz_type): high_d_visual_result[:, 0], 
                                            '{}_2'.format(viz_type): high_d_visual_result[:, 1], 
                                            'csa': csa})

            fig, ax = plt.subplots(1, figsize=(12, 8))
            sns.scatterplot(x='{}_1'.format(viz_type), y='{}_2'.format(viz_type), hue = 'csa',
                palette="deep", data=high_d_visual_result_df, ax=ax, s=120, style="csa", markers = markers)

            ax.set_aspect('equal')
            ax.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.0, fontsize = 'large')
            plt.savefig('{}_{}_{}.jpg'.format(file_name, viz_type, st))
            plt.close()

        # Only the current state-action rows survive the slice below, so the
        # next state-action features are not part of the plot.
        X = np.vstack((curr_sa, next_sa))
        num_csa = curr_sa.shape[0]
        X = X[:num_csa]
        try:
            _plot(X, num_csa = num_csa, viz_type = 'pca', file_name = fname)
        except Exception as err:
            # Plotting is best-effort; report which projection failed and why.
            print ('failed generating pca: {}'.format(err))

def plot_feature_matrix(data, pie, phi = None, fname = None, ground_weights = None):
    """Save an image of the feature matrix of the unique dataset state-actions.

    Unique rows of ``data.curr_state_actions`` are scored with
    ``ground_weights`` (integer-truncated), optionally encoded by ``phi``,
    ordered by that score and drawn with one column per state-action. The
    figure is written to ``'<fname>_feature-mat_unq.jpg'``.

    Args:
        data: dataset exposing ``curr_state_actions``.
        pie: unused.
        phi: optional feature map applied to the unique state-actions.
        fname: prefix of the output file name.
        ground_weights: weight vector used to score the raw state-actions;
            required despite the ``None`` default.
    """
    # The next-state-action features were previously de-duplicated, scored and
    # encoded as well, but never used; that work has been removed.
    curr_sa = np.unique(data.curr_state_actions, axis = 0)
    q_vals = np.dot(curr_sa, ground_weights).astype(int)

    if phi is not None:
        curr_sa = phi(curr_sa)
    curr_sa = curr_sa.astype(np.float64)

    # Columns of the image are state-actions in ascending score order.
    srted_idx = np.argsort(q_vals)
    srted_curr_sa = curr_sa[srted_idx].T
    x_dim, y_dim = srted_curr_sa.shape
    aspect_ratio = float(y_dim) / x_dim
    plt.imshow(srted_curr_sa, cmap='viridis', interpolation='nearest', aspect = aspect_ratio)
    plt.colorbar()
    plt.title('Matrix Plot')
    plt.xticks(np.arange(0, srted_curr_sa.shape[1]), np.sort(q_vals))
    plt.gca().xaxis.set_major_locator(plt.MaxNLocator(prune='both'))
    plt.savefig('{}_{}_{}.jpg'.format(fname, 'feature-mat', 'unq'))
    plt.close()

def high_d_plot(data, pie, num_samples = int(1e4), phi = None, fname = None, typ = 'umap', q_values = None):
    """Save a 2-D UMAP scatter of dataset state-action features coloured by Q-value.

    A subsample of at most 2048 rows of ``data.curr_state_actions`` is drawn
    with a fixed seed, each row's Q-value is looked up as
    ``q_values[argmax(row)]``, the rows are optionally encoded with ``phi``
    and the embedding is written to ``'<fname>_umap_ds.jpg'``.

    Args:
        data: dataset exposing ``curr_state_actions`` and ``num_samples``.
        pie: unused.
        num_samples: ignored -- overwritten with ``min(2048, data.num_samples)``.
        phi: optional feature map applied to the sampled state-actions.
        fname: prefix of the output file name.
        typ: ignored -- the projection is always UMAP.
        q_values: array indexed by the arg-max position of each state-action
            row; required despite the ``None`` default.

    Side effects:
        Reseeds the global NumPy RNG with 0.
    """
    from sklearn.manifold import TSNE
    from sklearn.manifold import MDS
    from sklearn.decomposition import PCA

    dps = data.curr_state_actions
    st = 'ds'  # sampling-mode tag that ends up in the output file name

    np.random.seed(0)
    num_samples = min(2048, data.num_samples)
    shuffle = np.random.choice(dps.shape[0], num_samples, replace = False)
    curr_sa = dps[shuffle]
    # assumes each row has its largest entry at the index that keys q_values
    # (e.g. a one-hot encoding) -- TODO confirm
    idx = np.argmax(curr_sa, axis = 1)
    q_vals = q_values[idx].reshape(-1)

    if phi is not None:
        curr_sa = phi(curr_sa)
    curr_sa = curr_sa.astype(np.float64)

    phi_stats = get_phi_stats(curr_sa)
    print ('phi visualization info. shape: {}, stats: {}'.format(curr_sa.shape, phi_stats))

    def _plot(X, y, file_name, viz_type):
        """Embed ``X`` in 2-D with ``viz_type`` and save a scatter coloured by ``y``."""
        n_components = 2
        if viz_type == 'umap':
            # Imported lazily: the module-level import is disabled, so the
            # name would otherwise be undefined here.
            from umap import UMAP
            high_d_visual = UMAP(n_components=n_components)
        elif viz_type == 'mds':
            high_d_visual = MDS(n_components=n_components, normalized_stress='auto')
        elif viz_type == 'pca':
            high_d_visual = PCA(n_components=n_components)
        elif viz_type == 'tsne':
            high_d_visual = TSNE(n_components=n_components)
        else:
            raise ValueError('unknown viz_type: {!r}'.format(viz_type))
        high_d_visual_result = high_d_visual.fit_transform(X)

        high_d_visual_result_df = pd.DataFrame({'{}_1'.format(viz_type): high_d_visual_result[:, 0], 
                                        '{}_2'.format(viz_type): high_d_visual_result[:, 1], 
                                        'label': y})

        fig, ax = plt.subplots(1, figsize=(12, 8))
        sns.scatterplot(x='{}_1'.format(viz_type), y='{}_2'.format(viz_type), hue = y,
            palette="viridis", data=high_d_visual_result_df, ax=ax, s=120, legend=False)

        # Stand-alone mappable carrying the Q-values so a colorbar can be
        # drawn (the scatter itself is created with legend=False).
        sm = plt.cm.ScalarMappable(cmap='viridis')
        sm.set_array(y)

        cbar = plt.colorbar(
            sm,
            ax=plt.gca()
        )
        cbar.set_label('Q values', rotation=270, labelpad=15)

        plt.savefig('{}_{}_{}.jpg'.format(file_name, viz_type, st))
        plt.close()

    _plot(curr_sa, q_vals, file_name = fname, viz_type = 'umap')

def plot_covariance_heatmap(dataset, pie, gamma, phi = None, enc_name = None, tabular = False, epoch = None):
    """Print eigenvalue diagnostics and save heatmaps of feature second-moment matrices.

    With features ``curr_sa`` / ``next_sa`` (encoded by ``phi`` when given)
    the following matrices are built, where ``n`` is the number of rows:

    * ``cov    = curr_sa^T curr_sa / n``
    * ``ncov   = gamma * curr_sa^T (terminal_masks * next_sa) / n``
    * ``A      = cov - ncov``
    * ``invcov = compute_inverse(cov)``
    * ``L      = invcov @ ncov``

    Their eigenvalues are printed and the five matrices are drawn as stacked
    heatmaps in ``'<enc_name>_cov_<epoch>.jpg'``.

    Args:
        dataset: exposes ``next_states`` and ``terminal_masks``, plus
            ``curr_state_actions`` (tabular) or ``curr_states``,
            ``curr_actions`` and ``unnormalize_states`` (otherwise).
        pie: exposes ``sample_sa_features`` (tabular) or ``batch_sample``.
        gamma: discount factor scaling ``ncov``.
        phi: optional feature map applied to both state-action arrays.
        enc_name: prefix of the output file name.
        tabular: selects how the state-action inputs are assembled.
        epoch: suffix of the output file name.
    """

    def _eigenvalues(mat):
        # Eigenvalues only; the eigenvectors are not needed anywhere.
        return np.linalg.eig(mat)[0]

    def _report(header, vals, with_raw):
        # Prints: minimum, largest magnitude, all magnitudes[, raw values].
        magnitudes = np.abs(vals)
        fields = [np.min(vals), np.max(magnitudes), magnitudes]
        if with_raw:
            fields.append(vals)
        print (header, *fields)

    next_states = dataset.next_states
    if not tabular:
        curr_sa = np.concatenate((dataset.curr_states, dataset.curr_actions), axis = 1)
        pie_next_actions = pie.batch_sample(dataset.unnormalize_states(next_states))
        next_sa = np.concatenate((next_states, pie_next_actions), axis = 1)
    else:
        curr_sa = dataset.curr_state_actions
        next_sa = pie.sample_sa_features(next_states)

    if phi is not None:
        curr_sa = phi(curr_sa)
        next_sa = phi(next_sa)

    n_rows = curr_sa.shape[0]
    avg_feat = np.mean(curr_sa, axis = 0)
    avg_nfeat = np.mean(next_sa, axis = 0)
    cov = np.matmul(curr_sa.T, curr_sa) / n_rows
    eig_valsC = _eigenvalues(cov)

    print ('min real eig vals ', np.min(eig_valsC))
    print ('all eig ', eig_valsC)
    print ('avg feat ', avg_feat)
    print ('avg nfeat ', avg_nfeat)
    print ('avg feat std', np.std(curr_sa,axis=0))

    # Rows belonging to terminal transitions are masked out of the next features.
    masked_next_sa = dataset.terminal_masks.reshape(-1, 1) * next_sa
    ncov = gamma * np.matmul(curr_sa.T, masked_next_sa) / n_rows

    A = cov - ncov
    invcov = compute_inverse(cov)
    L = np.matmul(invcov, ncov)

    eig_valsA = _eigenvalues(A)
    eig_valsL = _eigenvalues(L)
    eig_valsNC = _eigenvalues(ncov)
    eig_valsinvC = _eigenvalues(invcov)

    _report('min real eig vals of ncov ', eig_valsNC, True)
    _report('min real eig vals of cov ', eig_valsC, True)
    _report('min real eig vals of invcov ', eig_valsinvC, False)
    _report('min real eig vals of L ', eig_valsL, True)
    _report('min real eig vals of A ', eig_valsA, False)

    fig, ax = plt.subplots(5, figsize=(12, 32))
    panels = (
        (cov, 'covariance'),
        (ncov, 'ncov'),
        (invcov, 'invcov'),
        (L, 'L'),
        (A, 'A'),
    )
    for axis, (mat, label) in zip(ax, panels):
        sns.heatmap(mat, cmap='viridis', annot=True, fmt='.1f',
            cbar_kws={'label': label}, ax=axis, vmin=-1, vmax=1)

    plt.title('covariance Heatmap')

    plt.savefig('{}_cov_{}.jpg'.format(enc_name, epoch))
    plt.close()

def compute_inverse(matrix):
    """Return the Moore-Penrose pseudo-inverse of ``matrix``.

    torch tensors are handled by ``torch.linalg.pinv``; anything else goes
    through ``np.linalg.pinv``. No regularisation term is added.

    Args:
        matrix: torch tensor or array-like accepted by NumPy.

    Returns:
        The pseudo-inverse, of the same library type as the backend used.
    """
    pseudo_inverse = torch.linalg.pinv if torch.is_tensor(matrix) else np.linalg.pinv
    return pseudo_inverse(matrix)

def get_phi_distances(data, phi = None, enc_name = None, beta = None, ground_weights = None, abs_weights = None):
    """Print and plot similarity diagnostics between current and next state-action features.

    In order, the function
      1. prints, for every unique next state-action, its cosine-nearest
         dataset state-action and the Q-values of both,
      2. saves cosine-similarity heatmaps to ``'<enc_name>_csa.jpg'``
         (current vs current) and ``'<enc_name>_nsa.jpg'`` (next vs current),
      3. prints correlations between the per-transition Q error and feature
         inner-product statistics,
      4. prints summary statistics of pairwise similarities / distances.
    Several additional figures are built and closed without being saved.
    Returns ``None``.

    Args:
        data: dataset exposing ``curr_state_actions``, ``next_state_actions``
            and ``q_values``.
        phi: optional feature map applied to all state-action arrays.
        enc_name: prefix of the saved file names.
        beta: weight of the angular term in ``curr_Uxy``; must be numeric
            despite the ``None`` default.
        ground_weights: weights giving Q-values from the raw state-actions;
            required despite the ``None`` default.
        abs_weights: weights giving predicted Q-values from the (encoded)
            next state-actions; required despite the ``None`` default.
    """
    from sklearn.neighbors import NearestNeighbors
    dps = data.curr_state_actions

    all_curr_sa = dps
    all_next_sa = data.next_state_actions
    curr_sa, idx = np.unique(dps, axis = 0, return_index = True)
    # NOTE(review): q_vals is never used below.
    q_vals = data.q_values[idx].astype(int)

    next_sa, idx = np.unique(data.next_state_actions, axis = 0, return_index = True)

    curr_sa_unq_q = np.dot(curr_sa, ground_weights)
    next_sa_unq_q = np.dot(next_sa, ground_weights)
    decoded_sa = []
    decoded_nsa = []
    # Decode each unique row into (state, action, integer Q) for axis labels.
    # assumes rows are one-hot over state*4 + action, i.e. 4 actions per
    # state -- TODO confirm
    for idx, i in enumerate(curr_sa):
        onehot = np.argmax(i)
        st = onehot // 4
        act = onehot - st * 4
        decoded_sa.append((st, act, curr_sa_unq_q[idx].astype(int)))

    for idx, i in enumerate(next_sa):
        onehot = np.argmax(i)
        st = onehot // 4
        act = onehot - st * 4
        decoded_nsa.append((st, act, next_sa_unq_q[idx].astype(int)))

    # getting q values based on original features
    curr_sa_true_q = np.dot(all_curr_sa, ground_weights)
    next_sa_true_q = np.dot(all_next_sa, ground_weights)
    q_true_diff = np.abs(curr_sa_true_q - next_sa_true_q).astype(int)

    if phi is not None:
        curr_sa = phi(curr_sa)
        next_sa = phi(next_sa)
        all_curr_sa = phi(all_curr_sa)
        all_next_sa = phi(all_next_sa)

    curr_sa = curr_sa.astype(np.float64)
    next_sa = next_sa.astype(np.float64)

    # measuring nearest neighbor of each (s',a') from each (s,a) in dataset
    k_neighbors = 1
    nn_model = NearestNeighbors(n_neighbors=k_neighbors, metric = 'cosine')
    nn_model.fit(all_curr_sa)
    distances, indices = nn_model.kneighbors(next_sa)

    #distances = np.mean(distances, axis = 1)
    distances = distances.reshape(-1)
    indices = indices.reshape(-1)

    # Largest distance first.
    srted_indices = np.argsort(distances)[::-1]

    # NOTE(review): `i` indexes the *unique* next state-actions, while
    # next_sa_true_q is indexed by dataset row -- the printed "Q value of
    # X[i]" may not belong to the queried row; confirm.
    for i in srted_indices:
        print(f"Nearest neighbor of X[{i}] with distance {distances[i]}")
        print(f"Q value of X[{i}] is {next_sa_true_q[i]} with Q values of NNs are: {curr_sa_true_q[indices[i]]}")


    # for i in range(all_next_sa.shape[0]):
    #     if distances[i, 0] == 0:
    #         continue
    #     pdb.set_trace()
    #     print(f"Nearest neighbor of X[{i}] with distance {distances[i, :k_neighbors]}")
    #     print(f"Q value of X[{i}] is {next_sa_true_q[i]} with Q values of NNs are: {curr_sa_true_q[indices[i]].reshape(-1)}")

    # getting cluster representations and then measuring distances between them
    #samps = np.vstack((curr_sa, next_sa))
    curr_sa_labels = [f'({item[0]}, {item[1]}) {item[2]}' for item in decoded_sa]
    next_sa_labels = [f'({item[0]}, {item[1]}) {item[2]}' for item in decoded_nsa]

    # Heatmap 1: cosine similarity among the unique current state-actions.
    fig, ax = plt.subplots(1, figsize=(12, 8))
    cosine_sim_mat = cosine_similarity(curr_sa)
    sns.heatmap(cosine_sim_mat, cmap='viridis', annot=True, fmt='.3f',\
        cbar_kws={'label': 'similarity'}, ax=ax,vmin=-1,vmax=1, xticklabels = curr_sa_labels, yticklabels = curr_sa_labels)

    plt.title('similarity Heatmap')
    plt.xlabel('(s,a)')
    plt.ylabel('(s,a)')
    plt.savefig('{}_csa.jpg'.format(enc_name))
    plt.close()

    # Exclude self-similarity from the per-row mean.
    np.fill_diagonal(cosine_sim_mat, np.nan)
    mean_sim = np.nanmean(cosine_sim_mat, axis=1)
    print ('csa cosine ', mean_sim)

    # measuring distance of next-state-action clusters from dataset
    fig, ax = plt.subplots(1, figsize=(12, 8))
    # NOTE(review): despite the name and the printed label, this matrix holds
    # Euclidean distances scaled by their maximum, and it is (n_next, n_curr),
    # so its "diagonal" does not pair a row with itself -- confirm the
    # fill_diagonal below is intended.
    cosine_sim = pairwise_distances(next_sa, curr_sa, metric = 'euclidean')
    cosine_sim = cosine_sim / np.max(cosine_sim)
    # heatmap = sns.heatmap(cosine_sim, cmap='viridis', annot=True, fmt='.3f',\
    #     cbar_kws={'label': 'l2'}, ax=ax[0], xticklabels = curr_sa_labels, yticklabels = next_sa_labels)
    # heatmap.set_xticklabels(heatmap.get_xticklabels(), **{'rotation': 45, 'fontsize': 14, 'ha': 'right'})
    # heatmap.set_yticklabels(heatmap.get_yticklabels(), **{'rotation': 0, 'fontsize': 14, 'va': 'center'})
    # ax[0].set_title('l2 dist Heatmap')
    # ax[0].set_xlabel('(s,a)')
    # ax[0].set_ylabel('(s\',a\')')
    #plt.savefig('{}_nsa.jpg'.format(enc_name))
    #plt.close()
    np.fill_diagonal(cosine_sim, np.nan)
    mean_sim = np.nanmean(cosine_sim, axis=1)
    print ('nsa and csa cosine ', mean_sim)

    # Heatmap 2: cosine similarity of unique next vs unique current
    # state-actions, drawn on the figure opened just above.
    #fig, ax = plt.subplots(1, figsize=(12, 8))
    cosine_sim = cosine_similarity(next_sa, curr_sa)
    heatmap = sns.heatmap(cosine_sim, cmap='viridis', annot=True, fmt='.3f',\
        cbar_kws={'label': 'similarity'}, ax=ax,vmin=-1,vmax=1, xticklabels = curr_sa_labels, yticklabels = next_sa_labels)
    heatmap.set_xticklabels(heatmap.get_xticklabels(), **{'rotation': 45, 'fontsize': 14, 'ha': 'right'})
    heatmap.set_yticklabels(heatmap.get_yticklabels(), **{'rotation': 0, 'fontsize': 14, 'va': 'center'})
    ax.set_title('similarity Heatmap')
    ax.set_xlabel('(s,a)')
    ax.set_ylabel('(s\',a\')')
    plt.savefig('{}_nsa.jpg'.format(enc_name))
    plt.close()
    # NOTE(review): same non-square fill_diagonal concern as above.
    np.fill_diagonal(cosine_sim, np.nan)
    mean_sim = np.nanmean(cosine_sim, axis=1)
    print (mean_sim)

    # Absolute gap between the Q-values from ground_weights (raw features)
    # and from abs_weights (encoded features) on every next state-action.
    next_sa_pred_q = np.dot(all_next_sa, abs_weights)
    q_ope_error = np.abs(next_sa_true_q - next_sa_pred_q)
    
    normalized_curr_sa = curr_sa / np.linalg.norm(curr_sa, axis=1, keepdims=True)
    cosine_similarities = np.dot(normalized_curr_sa, normalized_curr_sa.T)
    # Holds cosine similarities, not angles (the arccos is disabled).
    pairwise_angles_rad = cosine_similarities#np.arccos(np.clip(cosine_similarities, -1.0, 1.0))

    # Row-wise inner products <phi(s,a), phi(s',a')> and <phi(s,a), phi(s,a)>.
    feat_next_coadpt = np.sum(all_curr_sa * all_next_sa, axis = 1)
    feat_curr_coadpt = np.sum(all_curr_sa * all_curr_sa, axis = 1)
    # NOTE(review): 0.99 is hard-coded here -- presumably the discount factor; confirm.
    feat_coadapt_diff = feat_curr_coadpt - 0.99 * feat_next_coadpt

    # Dense pairwise Euclidean distances among the unique current rows.
    euc_dist = np.sqrt(np.sum((curr_sa[:, np.newaxis] - curr_sa) ** 2, axis=-1))

    # Strict upper triangle: every unordered pair once, no self-pairs.
    uti = np.triu_indices(pairwise_angles_rad.shape[0], k=1)
    sub_pw_angle_dist = pairwise_angles_rad[uti]
    sub_pw_euc_dist = euc_dist[uti]
    sub_pw_next_featco = feat_next_coadpt
    sub_pw_curr_featco = feat_curr_coadpt

    if True:#enc_name == 'rope':
        phi_x_norm = np.linalg.norm(all_curr_sa, axis = 1)
        phi_y_norm = np.linalg.norm(all_next_sa, axis = 1)
        cs = np.sum(all_curr_sa * all_next_sa, axis = 1) / (phi_x_norm * phi_y_norm)
        # Angle between phi(s,a) and phi(s',a'); the 1e-5 keeps the sqrt
        # argument positive when |cs| reaches 1.
        angle = np.arctan2(np.sqrt(1. + 1e-5 - np.square(cs)), cs)
        norm_avg = 0.5 * (np.square(phi_x_norm) + np.square(phi_y_norm))
        cs_dist = angle
        curr_Uxy = norm_avg + beta * cs_dist

        correlation_matrix = np.corrcoef(q_ope_error, feat_next_coadpt)
        correlation_coefficient = correlation_matrix[0, 1]
        print("Correlation Coefficient next featcoadapt:", correlation_coefficient)
        correlation_matrix = np.corrcoef(q_ope_error, feat_coadapt_diff)
        correlation_coefficient = correlation_matrix[0, 1]
        print("Correlation Coefficient diff featcoadapt:", correlation_coefficient)
        #correlation_matrix = np.corrcoef([curr_Uxy, feat_next_coadpt, q_true_diff])
        
        df = pd.DataFrame({'rope': curr_Uxy,
            'feat_next_coadpt': feat_next_coadpt,
            'q_true_diff': q_true_diff,
            'feat_coadapt_diff': feat_coadapt_diff,
            'anglular_dist': angle,
            'q_ope_error': q_ope_error})
        # NOTE(review): this figure is closed without being saved (both
        # savefig calls are commented out).
        fig, ax = plt.subplots(3, 2)
        sns.scatterplot(x='rope', y='feat_next_coadpt', data=df, ax=ax[0, 0], s=120)
        sns.scatterplot(x='anglular_dist', y='feat_next_coadpt', data=df, ax=ax[0, 1], s=120)
        sns.scatterplot(x='rope', y='feat_coadapt_diff', data=df, ax=ax[1, 0], s=120)
        sns.scatterplot(x='anglular_dist', y='feat_coadapt_diff', data=df, ax=ax[1, 1], s=120)
        sns.scatterplot(x='q_ope_error', y='feat_next_coadpt', data=df, ax=ax[2, 0], s=120)
        sns.scatterplot(x='q_ope_error', y='feat_coadapt_diff', data=df, ax=ax[2, 1], s=120)
        #plt.savefig('{}_dist_feat_corr.jpg'.format(enc_name))
        #plt.savefig('{}_correlations.jpg'.format(enc_name))
        plt.close()

    # df = pd.DataFrame({'feat_next_coadpt': feat_next_coadpt, 'q_true_diff': q_true_diff})
    # fig, ax = plt.subplots(1)
    # sns.scatterplot(x='q_true_diff', y='feat_next_coadpt', data=df, ax=ax, s=120)
    # plt.savefig('{}_featcoadapt_qdiff_corr.jpg'.format(enc_name))
    # plt.close()

    # Group the per-transition statistics by integer |Q(s,a) - Q(s',a')|.
    grouped_feat_coadpt = {}
    grouped_diff_feat_coadpt = {}
    for label, e1, e2 in zip(q_true_diff, feat_next_coadpt, feat_coadapt_diff):
        if label not in grouped_feat_coadpt:
            grouped_feat_coadpt[label] = []
            grouped_diff_feat_coadpt[label] = []
        grouped_feat_coadpt[label].append(e1)
        grouped_diff_feat_coadpt[label].append(e2)
    
    # NOTE(review): the three figures below are built and closed without
    # being saved (their savefig calls are commented out).
    unq_qdiffs = np.unique(q_true_diff)
    total_fco = np.array([np.mean(grouped_feat_coadpt[uq]) for uq in unq_qdiffs])
    df = pd.DataFrame({'feat_next_coadpt': total_fco, 'q_true_diff': unq_qdiffs})
    fig, ax = plt.subplots(1)
    sns.scatterplot(x='q_true_diff', y='feat_next_coadpt', data=df, ax=ax, s=120)
    #plt.savefig('{}_featcoadapt_qdiff_corr.jpg'.format(enc_name))
    plt.close()

    diff_fco = np.array([np.mean(grouped_diff_feat_coadpt[uq]) for uq in unq_qdiffs])
    df = pd.DataFrame({'diff_feat_next_coadpt': diff_fco, 'q_true_diff': unq_qdiffs})
    fig, ax = plt.subplots(1)
    sns.scatterplot(x='q_true_diff', y='diff_feat_next_coadpt', data=df, ax=ax, s=120)
    #plt.savefig('{}_difffeatcoadapt_qdiff_corr.jpg'.format(enc_name))
    plt.close()

    # Empirical CDF of feat_coadapt_diff.
    sorted_data = np.sort(feat_coadapt_diff)
    cumulative_prob = np.linspace(0, 1, len(sorted_data))
    plt.plot(sorted_data, cumulative_prob, label='Cumulative Distribution')
    #plt.hist(feat_next_coadpt, bins=30, color='blue', alpha=0.7)
    plt.title('CDF')
    plt.xlabel('feat_coadapt_diff')
    plt.ylabel('Cumulative Probability')
    plt.grid(True)
    #plt.savefig('{}_featcoadapt_hist.jpg'.format(enc_name))
    plt.close()

    # pt_to_dist = {}
    # for i, j in zip(uti[0], uti[1]):
    #     pt_to_dist[(i, j)] = pairwise_angles_rad[i, j]

    # Each line prints mean, std, max, min; the last one is preceded by the
    # two sums. "pairwise cosine dist" values are cosine similarities (see above).
    print ('pairwise cosine dist stats: ', np.mean(sub_pw_angle_dist), np.std(sub_pw_angle_dist), np.max(sub_pw_angle_dist), np.min(sub_pw_angle_dist))
    print ('pairwise euc dist stats: ', np.mean(sub_pw_euc_dist), np.std(sub_pw_euc_dist), np.max(sub_pw_euc_dist), np.min(sub_pw_euc_dist))
    print ('pairwise feat coadapt stats: ', np.sum(feat_coadapt_diff), np.sum(sub_pw_curr_featco), np.mean(sub_pw_curr_featco), np.std(sub_pw_curr_featco), np.max(sub_pw_curr_featco), np.min(sub_pw_curr_featco))



def get_init_sa_stats(data, pie, phi = None, tabular = False):
    """Compute representation statistics over the initial state-action inputs.

    In the tabular case the inputs are ``data.init_state_actions``. Otherwise
    initial states are requested with ``data.get_initial_states_samples(-1)``
    (presumably "all of them" -- confirm against the dataset class), actions
    are drawn with ``pie.batch_sample`` on the unnormalized states, and the
    two are concatenated column-wise.

    When ``phi`` is given, the inputs are passed through it before the
    statistics are computed. Returns whatever ``get_phi_stats`` returns.
    """
    if tabular:
        inputs = torch.Tensor(data.init_state_actions)
    else:
        states = torch.Tensor(data.get_initial_states_samples(-1))
        raw_states = data.unnormalize_states(states)
        actions = torch.Tensor(pie.batch_sample(raw_states))
        inputs = torch.concat((states, actions), axis = 1)

    encoded = inputs if phi is None else torch.Tensor(phi(inputs))
    return get_phi_stats(encoded.numpy())

def get_q_values(dataset, pie, critic, gamma):
    """Return the critic's Q(s, a) and one-step bootstrapped targets.

    The target is ``r + gamma * mask * critic(s', a')`` where ``a'`` is drawn
    from ``pie.batch_sample`` on the unnormalized next states and ``mask`` is
    ``dataset.terminal_masks``. Rewards and masks are reshaped to column
    vectors. Returns ``(None, None)`` when no critic is supplied.
    """
    if critic is None:
        return None, None

    reward_col = dataset.rewards.reshape(-1, 1)
    mask_col = dataset.terminal_masks.reshape(-1, 1)
    state_actions = dataset.curr_state_actions

    successors = dataset.next_states
    successor_actions = pie.batch_sample(dataset.unnormalize_states(successors))
    successor_state_actions = np.concatenate((successors, successor_actions), axis = 1)

    q_csa = critic(state_actions)
    bootstrap = critic(successor_state_actions)
    q_target = reward_col + gamma * mask_col * bootstrap
    return q_csa, q_target

def bc_measure(phi, input_dim, data, gamma, lam = 0.):
    """Fit a latent transition/reward model on ``phi`` features and report its error.

    Two ``NeuralNetwork`` heads built with zero hidden layers and no
    activations are trained jointly with AdamW for 500 epochs:

    * ``M_phi`` maps the current feature vector to a prediction of
      ``gamma * next_feature`` (next features are multiplied by the terminal
      mask first);
    * ``M_rew`` maps the current feature vector to the reward.

    Epoch 0 only evaluates the untrained heads (no optimizer step).

    Args:
        phi: encoder applied to raw state-action batches, or ``None`` to use
            the raw state-actions as features.
        input_dim: feature dimensionality (input and output size of ``M_phi``).
        data: dataset handed to ``torch.utils.data.DataLoader``; batches must
            expose ``curr_sa``, ``next_sa``, ``rewards`` and ``terminal_masks``,
            and the object must expose ``curr_state_actions``.
        gamma: discount factor used in the prediction target.
        lam: weight of the log-determinant term in ``bc_err_logdet``.

    Returns:
        dict with ``bc_err`` (mean loss of the last epoch), ``bc_errs`` (mean
        loss per epoch), ``logdet`` (log-determinant of the regularized
        feature second-moment matrix, ``-1`` if it could not be computed) and
        ``bc_err_logdet`` (``bc_err - lam * logdet``).
    """

    def _self_pred_error(pred_next_phi_x, next_phi_x):
        # mean over the batch of the squared L2 distance to gamma * next features
        BC_loss = torch.square(torch.linalg.vector_norm(pred_next_phi_x - gamma * next_phi_x, dim=1))
        BC_loss = BC_loss.mean()
        return BC_loss

    M_phi = NeuralNetwork(input_dims = input_dim,
                            output_dims = input_dim,
                            hidden_dim = -1,
                            hidden_layers = 0,
                            activation = None,
                            final_activation = None)
    M_rew = NeuralNetwork(input_dims = input_dim,
                            output_dims = 1,
                            hidden_dim = -1,
                            hidden_layers = 0,
                            activation = None,
                            final_activation = None)

    M_lr = 1e-3
    M_params = list(M_phi.parameters()) + list(M_rew.parameters())
    M_optimizer = torch.optim.AdamW(M_params, lr = M_lr)

    epochs = 500
    num_workers = 6
    mini_batch_size = 2048
    params = {'batch_size': mini_batch_size, 'shuffle': True, 'num_workers': num_workers}
    dataloader = torch.utils.data.DataLoader(data, **params)

    total_loss = []
    for epoch in range(0, epochs + 1):
        # Per-batch losses are stored as plain floats so the autograd graph of
        # every mini-batch is not kept alive for the whole epoch, and so the
        # epoch averages are not built from tensors that require grad.
        losses = []
        r_loss = []
        pred_loss = []
        for mb in dataloader:
            curr_sa = mb['curr_sa']
            next_sa = mb['next_sa']
            rewards = mb['rewards']
            terminal_masks = mb['terminal_masks']

            terminal_masks = terminal_masks.reshape(-1, 1)
            with torch.no_grad():
                if phi is not None:
                    phi_x = torch.Tensor(phi(curr_sa))
                    next_phi_x = terminal_masks * torch.Tensor(phi(next_sa))
                else:
                    phi_x = torch.Tensor(curr_sa)
                    next_phi_x = terminal_masks * torch.Tensor(next_sa)

            pred_next_phi_x = M_phi.forward(phi_x)
            pred_rew = M_rew.forward(phi_x)
            rewards = torch.Tensor(rewards.reshape((-1, 1)))
            reward_loss = F.mse_loss(pred_rew, rewards)

            BC_loss = _self_pred_error(pred_next_phi_x, next_phi_x)

            M_loss = reward_loss + BC_loss
            losses.append(M_loss.item())
            r_loss.append(reward_loss.item())
            pred_loss.append(BC_loss.item())

            # epoch 0 is evaluation-only: it records the loss of the untrained heads
            if epoch > 0:
                M_optimizer.zero_grad()
                M_loss.backward()
                nn.utils.clip_grad_value_(M_params, clip_value = 1.0)
                M_optimizer.step()
        avg_loss = torch.Tensor(losses).mean()
        avg_r_loss = torch.Tensor(r_loss).mean()
        avg_pred_loss = torch.Tensor(pred_loss).mean()
        print ('BC measure epoch: {} loss: {}, r loss: {}, pred loss: {}'.format(epoch, avg_loss, avg_r_loss, avg_pred_loss))
        total_loss.append(avg_loss.item())

    # Second-moment matrix of the features over the whole dataset. With
    # phi=None the raw state-actions are the features, mirroring the loop above.
    reps = torch.Tensor(data.curr_state_actions)
    if phi is not None:
        reps = torch.Tensor(phi(reps))
    cov = torch.matmul(reps.T, reps) / reps.shape[0]
    cov = cov + 1e-4 * torch.eye(cov.shape[1])
    try:
        # log det(cov) = 2 * sum(log(diag(cholesky(cov))))
        log_det = 2 * torch.linalg.cholesky(cov).diagonal(dim1=-2, dim2=-1).log().sum(-1).item()
    except Exception:
        # best effort: the factorization failed, report the -1 sentinel
        log_det = -1

    bc_stats = {
        'bc_err': total_loss[-1],
        'bc_errs': total_loss,
        'logdet': log_det,
        'bc_err_logdet': total_loss[-1] - lam * log_det
    }
    return bc_stats

def print_dataset_info(pathname):
    """Load ``<pathname>.npy`` and print a short summary of the stored dataset.

    The file is read with ``allow_pickle = True`` and must hold a single
    dictionary, so it should only be pointed at trusted files. The number of
    "endings" is ``num_samples`` minus the count of non-zero terminal masks.
    """
    dataset_info = np.load('{}.npy'.format(pathname), allow_pickle = True).item()
    data = dataset_info['dataset']
    num_terminal_states = data['num_samples'] - np.count_nonzero(data['terminal_masks'])

    # (template, value) pairs, rendered in order and concatenated as-is
    report = [
        ('name: {}\n', dataset_info['dataset_name']),
        ('batch size: {}\n', dataset_info['batch_size']),
        ('traj length: {}\n', dataset_info['traj_len']),
        ('oracle ret est: {}\n', dataset_info['oracle_est_ret']),
        ('data ret est: {}\n', dataset_info['data_est_ret']),
        ('pib ret est: {}\n', dataset_info['pib_est_ret']),
        ('gamma: {}\n', dataset_info['gamma']),
        ('samples: {}\n', data['num_samples']),
        ('num endings: {}', num_terminal_states),
    ]
    print (''.join(template.format(value) for template, value in report))

def truncation_horizon_limit(gamma, thresh = 0.1, increments = 50):
    """Pick a rollout horizon after which ``gamma ** t`` drops below ``thresh``.

    Horizons ``0, increments, 2 * increments, ...`` below 2000 are tried in
    order; the first one with ``gamma ** t < thresh`` is printed together with
    the normalized discounted weight it covers and returned. If none
    qualifies, 2000 is returned.
    """
    horizon_cap = 2000
    for horizon in range(0, horizon_cap, increments):
        tail_discount = gamma ** horizon
        if not tail_discount < thresh:
            continue
        # share of the (1 - gamma)-normalized discounted mass within the horizon
        gms_weight = np.sum([gamma ** step for step in range(0, horizon + 1)]) * (1. - gamma)
        print (f'rollout horizon {horizon}, gamma^t {tail_discount}, eff horizon weight: {gms_weight}')
        return horizon
    return horizon_cap

def load_env(env_name, gamma, env_type = 'dmc'):
    """Create the gymnasium environment for a named benchmark task.

    Args:
        env_name: task name; valid names depend on ``env_type`` (see the
            table in the body).
        gamma: unused here; kept for interface compatibility.
        env_type: ``'dmc'`` (dm_control tasks, returned wrapped in
            ``FlattenObservation``) or ``'d4rl'`` (MuJoCo v4 tasks).

    Returns:
        The constructed environment.

    Raises:
        ValueError: if ``env_type`` or ``env_name`` is not recognised.
    """
    # env_type -> env_name -> (gymnasium id, extra kwargs for gym.make)
    env_specs = {
        'dmc': {
            'CartPoleSwingUp': ("dm_control/cartpole-swingup-v0", {}),
            'CheetahRun': ("dm_control/cheetah-run-v0", {'disable_env_checker': True}),
            'WalkerStand': ("dm_control/walker-stand-v0", {'disable_env_checker': True}),
            'FingerEasy': ("dm_control/finger-turn_easy-v0", {'disable_env_checker': True}),
        },
        'd4rl': {
            'Walker': ("Walker2d-v4", {}),
            'Hopper': ("Hopper-v4", {}),
            'Cheetah': ("HalfCheetah-v4", {}),
        },
    }

    # Validate up front so a typo gives a clear error instead of an unbound
    # local further down.
    if env_type not in env_specs:
        raise ValueError('unknown env_type {!r}; expected one of {}'.format(
            env_type, sorted(env_specs)))
    known_envs = env_specs[env_type]
    if env_name not in known_envs:
        raise ValueError('unknown env_name {!r} for env_type {!r}; expected one of {}'.format(
            env_name, env_type, sorted(known_envs)))

    import gymnasium as gym

    env_id, make_kwargs = known_envs[env_name]
    env = gym.make(env_id, **make_kwargs)
    if env_type == 'dmc':
        env = gym.wrappers.FlattenObservation(env)
    return env

def load_d4rl_policy(config, env, env_name, pi_num):
    """Build a ``D4RLPolicy`` from the checkpoint configured for ``env_name``.

    The file name comes from ``config[env_name]['d4rl']['pi_url']`` formatted
    with ``pi_num`` and is looked up under the ``env_policies`` directory.
    ``env`` is not used.
    """
    policy_dir = 'env_policies'
    url_template = config[env_name]['d4rl']['pi_url']
    checkpoint_path = '/'.join((policy_dir, url_template.format(pi_num)))
    return D4RLPolicy(checkpoint_path)

def load_policies(config, env, env_name, pi_num, env_type = 'dmc'):
    """Load a pretrained policy for the given environment.

    For ``env_type == 'dmc'`` a stochastic SAC ``StablebaselinePolicy`` is
    loaded from ``env_policies/<prefix>_<pi_num>_steps``; for ``'d4rl'`` the
    work is delegated to ``load_d4rl_policy``.

    Raises:
        ValueError: if ``env_type`` is unknown, or ``env_name`` is not a known
            dmc task.
    """
    if env_type == 'dmc':
        # env_name -> checkpoint path template (formatted with pi_num)
        checkpoint_templates = {
            'CartPoleSwingUp': 'env_policies/cpswing_{}_steps',
            'CheetahRun': 'env_policies/cheetah-run_{}_steps',
            'WalkerStand': 'env_policies/walker-std_{}_steps',
            'FingerEasy': 'env_policies/finger-easy_{}_steps',
        }
        if env_name not in checkpoint_templates:
            # previously this fell through and failed on an unbound `pi_path`
            raise ValueError('unknown env_name {!r} for env_type {!r}; expected one of {}'.format(
                env_name, env_type, sorted(checkpoint_templates)))
        pi_path = checkpoint_templates[env_name].format(pi_num)
        return StablebaselinePolicy(name = 'MlpPolicy', algo = 'SAC',
            env = env, pretrained_path = pi_path, deterministic = False)
    if env_type == 'd4rl':
        return load_d4rl_policy(config, env, env_name, pi_num)
    raise ValueError("unknown env_type {!r}; expected 'dmc' or 'd4rl'".format(env_type))

def load_sb_mixture(env, pis, weights):
    """Wrap ``pis`` and ``weights`` in a ``StablebaselinePolicyMixture`` for ``env``."""
    mixture = StablebaselinePolicyMixture(env, pis, weights)
    return mixture

def load_mini_batches(data, tabular, pie = None, mini_batch_size = 512):
    """Draw two independent mini-batches of transitions from ``data``.

    Each batch comes from its own ``data.get_samples(mini_batch_size)`` call.
    In the tabular case the current state-action features are taken from the
    sample and the next ones from ``pie.sample_sa_features``. Otherwise states
    and actions are concatenated, with next actions drawn from
    ``pie.batch_sample`` on the unnormalized next states. ``pie`` is used on
    both paths despite its ``None`` default.

    Returns a dict with ``curr_sa``, ``next_sa``, ``rewards``,
    ``terminal_masks`` and their ``other_``-prefixed counterparts for the
    second batch.
    """
    def _draw_tabular():
        sample = data.get_samples(mini_batch_size)
        state_action = sample['curr_state_actions']
        reward = sample['rewards']
        succ_state_action = pie.sample_sa_features(sample['next_states'])
        return state_action, succ_state_action, reward, sample['terminal_masks']

    def _draw_continuous():
        sample = data.get_samples(mini_batch_size)
        state_action = np.concatenate((sample['curr_states'], sample['curr_actions']), axis = 1)
        reward = sample['rewards']
        succ_states = sample['next_states']
        succ_actions = pie.batch_sample(data.unnormalize_states(succ_states))
        succ_state_action = np.concatenate((succ_states, succ_actions), axis = 1)
        return state_action, succ_state_action, reward, sample['terminal_masks']

    draw = _draw_tabular if tabular else _draw_continuous
    # the first draw must complete before the second to keep the sampling order
    first = draw()
    second = draw()

    return {
        'curr_sa': first[0],
        'next_sa': first[1],
        'rewards': first[2],
        'other_curr_sa': second[0],
        'other_next_sa': second[1],
        'other_rewards': second[2],
        'terminal_masks': first[3],
        'other_terminal_masks': second[3]
    }

def _get_rank(matrix, type = 'rank', th = 0.01):
    """Return the rank of ``matrix``, either numerical or "soft".

    ``type == 'rank'`` gives ``np.linalg.matrix_rank``. ``type == 'srank'``
    gives the smallest number of leading singular values whose cumulative sum
    reaches a fraction ``1 - th`` of the total. Any other ``type`` yields
    ``None``.
    """
    if type == 'rank':
        return np.linalg.matrix_rank(matrix)
    if type == 'srank':
        spectrum = np.linalg.svd(matrix)[1]
        required_share = 1 - th
        total_mass = np.sum(spectrum)
        running_mass = 0
        # srank keeps the last index visited, i.e. len(spectrum) if never reached
        for srank, singular_value in enumerate(spectrum, start = 1):
            running_mass += singular_value
            if running_mass / total_mass >= required_share:
                break
        return srank

def get_phi_stats(reps, next_reps = None, gamma = None, terminals = None):
    """Compute diagnostic statistics of a feature matrix ``reps``.

    ``reps`` is indexed as a 2-D array (rows are samples, columns are feature
    dimensions). When ``next_reps`` is given, ``gamma`` and ``terminals`` are
    also used to compute transition-dependent statistics; otherwise those
    entries keep their ``-1`` sentinel.

    Returns a dict with the keys ``mean_dim``, ``std_dim``, ``rank``,
    ``srank``, ``spectral_radius``, ``log_det``, ``feat_coadapt``,
    ``pos_eigen_frac``, ``orthogonality``, ``scaled_orthogonality``,
    ``dyn_aware``, ``cov_condition_num``, ``data_condition_num``,
    ``entropy``, ``cov_singular_values`` and ``eigen_values``.
    """

    # Pairwise statistics below are quadratic in the number of rows, so large
    # inputs are subsampled to 10000 rows for them. All other statistics use
    # the full `reps`.
    # NOTE(review): sub_next_reps is assigned but never read in this function.
    if reps.shape[0] >= 25000:
        sampled_rows = np.random.choice(reps.shape[0], size=10000, replace=False)
        sub_reps = reps[sampled_rows, :]
        if next_reps is not None:
            sub_next_reps = next_reps[sampled_rows, :]
    else:
        sub_reps = reps
        sub_next_reps = next_reps
    # per-dimension mean / std, averaged over the feature dimensions
    mean_dim = np.mean(np.mean(reps, axis=0))
    std_dim = np.mean(np.std(reps, axis=0))
    rank, srank = -1, -1
    # _get_rank is called with its default type, so `srank` and `rank` both
    # hold the same value
    srank = _get_rank(reps)
    rank = srank

    # U, S, Vt = np.linalg.svd(reps)
    # basis_vectors = Vt.T
    # print (basis_vectors, S)

    # orthogonality: 1 - |cosine similarity| over all off-diagonal pairs
    # (the diagonal is set to NaN and skipped by the nan-aware reductions)
    cosine_sim_mat = cosine_similarity(sub_reps)
    np.fill_diagonal(cosine_sim_mat, np.nan)
    cosine_sim_mat = np.abs(cosine_sim_mat)
    orthos = 1 - cosine_sim_mat
    min_ortho = np.nanmin(orthos)
    max_ortho = np.nanmax(orthos)
    # min-max rescaling; the max(..., 1e-10) guards against a zero range
    scaled_orthos = (orthos - min_ortho) / max(max_ortho - min_ortho, 1e-10)
    ortho = np.nanmean(orthos)
    scaled_ortho = np.nanmean(scaled_orthos)

    # "entropy" here is the mean pairwise Euclidean distance between the
    # L2-normalized rows (off-diagonal pairs only)
    norm_reps = sub_reps / np.maximum(np.linalg.norm(sub_reps, axis = 1, keepdims = True), 1e-10)
    pw_euc = pairwise_distances(norm_reps, norm_reps, metric = 'euclidean')
    np.fill_diagonal(pw_euc, np.nan)
    entropy = np.nanmean(pw_euc)# / np.nanmax(pw_euc)

    # _, sing_vals, _ = np.linalg.svd(reps)
    

    # num_rows, _ = reps.shape
    # thresh = 1e-5
    # u, s, _ = np.linalg.svd(reps, full_matrices=False)
    # rank = max(np.sum(s >= thresh), 1)
    # u1 = u[:, :rank]
    # projected_basis = np.matmul(u1, np.transpose(u1))
    # norms = np.linalg.norm(projected_basis, axis=0, ord=2) ** 2
    # #eff_dim = num_rows * np.max(norms)
    # eff_dim = np.max(norms)
    eff_dim = -1

    # uncentered second-moment matrix of the features, with a small ridge
    cov = np.matmul(reps.T, reps) / reps.shape[0]
    #cov_det = np.linalg.det()
    #cov = torch.Tensor(cov + 1e-5 * np.eye(cov.shape[1]))
    cov = cov + 1e-6 * np.eye(cov.shape[1])
    min_cov_eig = -1
    condition_num = -1
    rep_cond_num = -1
    sings = -1
    try:
        #log_det = np.linalg.det(cov)
        #cov_eigs, _ = np.linalg.eig(cov)
        #min_cov_eig = np.min(cov_eigs)
        cov = torch.Tensor(cov)
        #condition_num = torch.linalg.cond(cov).item()
        sings = torch.linalg.svdvals(cov).numpy()
        # ratio of first to last singular value as returned by svdvals
        condition_num = sings[0] / sings[-1]
        # log det(cov) = 2 * sum(log(diag(cholesky(cov))))
        log_det = 2 * torch.linalg.cholesky(cov).diagonal(dim1=-2, dim2=-1).log().sum(-1).item()
    except:
        # any failure above yields the -1 sentinel for log_det; sings and
        # condition_num keep whatever was assigned before the failure
        log_det = -1

    # sentinels for the transition-dependent statistics
    spectral_rad = -1
    feat_next_coadpt = -1
    feat_curr_coadpt = -1
    feat_dp_diff = -1
    real_pos_frac = -1
    cosine_sim = -1
    dyn_aware = -1
    real_eigs_A = -1
    if next_reps is not None:
        # condition number of the matrix (reps - gamma * terminals * next_reps)
        rep_cond_num = torch.linalg.cond(torch.Tensor(reps - gamma * terminals * next_reps)).item()
        #reps = reps / np.linalg.norm(reps, axis = 1, keepdims = True)
        #next_reps = next_reps / np.linalg.norm(next_reps, axis = 1, keepdims = True)

        # NOTE(review): rnd_gap and succ_gap are computed but unused, because
        # dyn_aware is hard-coded to -1 below. The random permutation still
        # advances the global numpy RNG.
        rnd_states_sa_ind = np.random.choice(reps.shape[0], size=reps.shape[0], replace=False)
        rnd_states_sa = reps[rnd_states_sa_ind, :]
        rnd_gap = np.sum(np.linalg.norm(reps - rnd_states_sa, axis = 1))
        succ_gap = np.sum(np.linalg.norm(reps - terminals * next_reps, axis = 1))
        dyn_aware = -1#1 - succ_gap / rnd_gap

        # spectral radius of L = inv(reps^T reps / n) @ (gamma * reps^T (terminals * next_reps) / n)
        # compute_inverse is defined elsewhere in this module -- presumably a
        # (regularized or pseudo-) inverse; confirm there.
        num_samples = reps.shape[0]
        lam = np.matmul(reps.T,  reps)
        lam /= num_samples
        lam_inv = compute_inverse(lam)
        curr_next_feat = gamma * np.matmul(reps.T, terminals * next_reps)
        curr_next_feat = curr_next_feat / num_samples
        L = np.matmul(lam_inv, curr_next_feat)
        eig_vals_L, _ = np.linalg.eig(L)
        spectral_rad = np.max(np.abs(eig_vals_L))

        # A = reps^T (reps - gamma * terminals * next_reps) / n; report the
        # fraction of its eigenvalues whose real part exceeds 1e-10
        A = np.matmul(reps.T, (reps - gamma * terminals * next_reps))
        A = A / reps.shape[0]
        eig_vals_A = np.linalg.eigvals(A)
        num_eigs = len(eig_vals_A)
        real_eigs_A = np.real(eig_vals_A)
        # real parts at or below 1e-10 (including negatives) are clamped to 0
        real_eigs_A[real_eigs_A <= 1e-10] = 0
        real_pos_frac = np.count_nonzero(real_eigs_A > 0) / num_eigs
        #real_pos_frac = np.real(eig_vals_A)
        # spectral_rad = real_pos_frac#np.max(np.abs(eig_vals))
        #spectral_rad = np.max(np.abs(eig_vals))
        # curr_dot_prod = np.sum(np.sum(reps * reps, axis = 1))
        # next_dot_prod = np.sum(np.sum(reps * next_reps, axis = 1))
        # feat_dp_diff = gamma * next_dot_prod - curr_dot_prod
        # per-row dot product between current and (masked) next features ...
        feat_next_coadpt = np.sum(reps * terminals * next_reps, axis = 1)
        #feat_curr_coadpt = np.sum(reps * reps, axis = 1)
        #feat_dp_diff = np.sum(feat_curr_coadpt - gamma * feat_next_coadpt)

        # ... summed over all rows
        feat_next_coadpt = np.sum(feat_next_coadpt)
        #feat_curr_coadpt = feat_curr_coadpt.sum()
        # cosine_sim = np.sum(reps * next_reps, axis = 1) / (np.linalg.norm(reps, axis = 1) * np.linalg.norm(next_reps, axis = 1))
        # cosine_sim = np.mean(cosine_sim)

    stats = {
        'mean_dim': mean_dim,
        'std_dim': std_dim,
        'rank': rank,
        'srank': srank,
        #'eff_dim': eff_dim,
        'spectral_radius': spectral_rad,
        'log_det': log_det,
        'feat_coadapt': feat_next_coadpt,
        #'feat_selfcoadapt': feat_curr_coadpt,
        #'feat_dp_diff': feat_dp_diff,
        'pos_eigen_frac': real_pos_frac,
        #'cosine_sim': cosine_sim,
        #'min_cov_eig': min_cov_eig,
        'orthogonality': ortho,
        'scaled_orthogonality': scaled_ortho,
        'dyn_aware': dyn_aware,
        'cov_condition_num': condition_num,
        'data_condition_num': rep_cond_num,
        'entropy': entropy,
        'cov_singular_values': sings,
        'eigen_values': real_eigs_A
    }
    return stats

def get_phi_capacity_stats(curr_sa, pie, phi_curr_sa = None):
    """Relate distances between representations to differences in Q-values.

    Q-values are the average of ``pie.pi.critic_target.qf0`` and ``qf1``
    evaluated on ``curr_sa``. If ``pie`` has no ``pi`` attribute both
    statistics are reported as ``-1``. Inputs with 25000 or more rows are
    randomly subsampled to 10000 rows. ``phi_curr_sa`` defaults to ``curr_sa``.

    Returns a dict with:
        ``complexity_red``: ``1 - mean(L) / max(L)`` where ``L`` is the
            pairwise ratio |dQ| / Euclidean distance (zero-distance pairs
            are ignored).
        ``specialization``: mean over all pairs of
            ``min(dQ_norm / (dist_norm + 1e-2), 1)`` with both quantities
            divided by their maximum.
    """
    if not hasattr(pie, 'pi'):
        return {'specialization': -1, 'complexity_red': -1}

    n_rows = curr_sa.shape[0]
    if n_rows >= 25000:
        keep = np.random.choice(n_rows, size=10000, replace=False)
        curr_sa = curr_sa[keep, :]
        if phi_curr_sa is not None:
            phi_curr_sa = phi_curr_sa[keep, :]
    if phi_curr_sa is None:
        phi_curr_sa = curr_sa

    with torch.no_grad():
        critic = pie.pi.critic_target
        q_first = critic.qf0(torch.Tensor(curr_sa)).numpy()
        q_second = critic.qf1(torch.Tensor(curr_sa)).numpy()
    q_values = (q_first + q_second) / 2

    value_gap = pairwise_distances(q_values, metric = 'l1')
    rep_gap = pairwise_distances(phi_curr_sa, metric = 'euclidean')

    # ratio of value gap to representation gap; pairs at zero distance are
    # first marked with -1 and then excluded as NaN
    ratios = np.divide(value_gap, rep_gap, out=np.ones_like(value_gap) * -1, where=rep_gap!=0)
    ratios[ratios == -1] = np.nan

    # complexity reduction
    comp_red = 1 - (np.nanmean(ratios) / np.nanmax(ratios))

    # specialization check, on max-normalized gaps
    scaled_value_gap = value_gap / np.max(value_gap)
    scaled_rep_gap = rep_gap / np.max(rep_gap)
    clipped = np.minimum(scaled_value_gap / (scaled_rep_gap + 1e-2), 1)
    special = np.mean(clipped)

    return {
        'specialization': special,
        'complexity_red': comp_red
    }

class PhiMetricTracker:
    """Per-epoch bookkeeping of diagnostics gathered while training ``phi``.

    ``metrics`` maps every metric name to its own ``{epoch: value}`` dict.
    """

    def __init__(self, pie, gamma):
        self.pie = pie
        self.gamma = gamma
        self.init_realize_error = -1
        metric_names = (
            'phi_tr_loss',
            'phi_brm_loss',
            'phi_ope_error',
            'phi_ope_ret',
            'phi_realize_error',
            'phi_mean_dim',
            'phi_std_dim',
            'phi_rank',
            'phi_norm',
            'phi_gradient_norm',
            'phi_pos_eigen_frac',
            'phi_feat_coadapt',
            'phi_log_det',
            'phi_cosine_sim',
            'phi_spectral_radius',
            'phi_orthogonality',
            'phi_scaled_orthogonality',
            'phi_dyn_aware',
            'phi_comp_red',
            'phi_specialization',
            'phi_cov_condition_num',
            'phi_cov_singular_values',
            'phi_data_condition_num',
            'phi_entropy',
            'phi_eigen_values',
            'phi_ortho_qvaldiff_corr',
            'phi_ortho_sr_corr',
            'phi_val_eval_path_err',
        )
        # a separate history dict per metric
        self.metrics = {metric_name: {} for metric_name in metric_names}

    def track_phi_training_stats(self, epoch, target_phi, losses,\
        ground_curr_sa, ground_next_sa, pie = None,\
        ope_evaluator = None, check_realizability = False, terminals = None):
        """Record all available metrics for ``epoch`` and return them.

        ``losses`` must contain ``total_obj``; a handful of optional entries
        are copied over when present. Representation statistics are computed
        from ``target_phi`` applied to the ground state-actions.
        ``terminals`` must provide ``reshape(...).numpy()`` despite its
        ``None`` default. The ``pie`` argument (not ``self.pie``) is what gets
        forwarded to ``get_phi_capacity_stats``.

        Returns a dict ``{metric_name: value}`` with every metric that has an
        entry for ``epoch``.
        """
        history = self.metrics

        def record(metric_name, value):
            history[metric_name][epoch] = value

        with torch.no_grad():
            # optimization related metrics
            flat_params = [p.view(-1) for p in target_phi.parameters()]
            if len(flat_params):
                target_net_norm = torch.norm(torch.cat(flat_params), p=2)
            else:
                target_net_norm = torch.Tensor([0])
            record('phi_tr_loss', losses['total_obj'].item())
            record('phi_norm', target_net_norm.item())

            # optional entries stored via .item()
            for loss_key, metric_name in (
                ('brm_obj', 'phi_brm_loss'),
                ('phi_ope_error', 'phi_ope_error'),
                ('phi_ope_ret', 'phi_ope_ret'),
            ):
                if loss_key in losses:
                    record(metric_name, losses[loss_key].item())

            # optional entries stored as-is
            for passthrough in ('phi_spectral_radius', 'phi_pos_eigen_frac', 'phi_val_eval_path_err'):
                if passthrough in losses:
                    record(passthrough, losses[passthrough])

            if ope_evaluator is not None and check_realizability:
                realize_error = ope_evaluator.realizability(target_phi)
                if epoch == 0 and realize_error != 0:
                    self.init_realize_error = realize_error
                record('phi_realize_error', realize_error)

                # realizability can be checked, so the evaluator can also
                # report the orthogonality correlations
                ortho_qdiff_corr, ortho_sr_corr = ope_evaluator.ortho_vs_qval(target_phi)
                record('phi_ortho_qvaldiff_corr', ortho_qdiff_corr)
                record('phi_ortho_sr_corr', ortho_sr_corr)

            temp_curr_sa = target_phi(ground_curr_sa)
            temp_next_sa = target_phi(ground_next_sa)
            terminals = terminals.reshape(-1, 1).numpy()

            cap_stats = get_phi_capacity_stats(ground_curr_sa, pie, temp_curr_sa)
            record('phi_comp_red', cap_stats['complexity_red'])
            record('phi_specialization', cap_stats['specialization'])

            phi_stats = get_phi_stats(temp_curr_sa, temp_next_sa, gamma = self.gamma, terminals = terminals)
            # each of these is stored under the same name with a 'phi_' prefix
            for stat_key in (
                'mean_dim',
                'std_dim',
                'rank',
                'feat_coadapt',
                'log_det',
                'orthogonality',
                'scaled_orthogonality',
                'dyn_aware',
                'cov_condition_num',
                'data_condition_num',
                'entropy',
            ):
                record(f'phi_{stat_key}', phi_stats[stat_key])

            return {
                metric_name: per_epoch[epoch]
                for metric_name, per_epoch in history.items()
                if epoch in per_epoch
            }

def weight_regularizer(model, name = 'l2'):
    """Return a weight penalty over all non-bias parameters of ``model``.

    ``'l2'`` sums ``0.5 * ||W||^2``. ``'orth'`` sums ``|W W^T - I|`` with each
    parameter flattened to ``(shape[0], -1)``. Any other ``name`` returns
    ``None``. The result is a tensor of shape ``(1,)`` built with gradients
    enabled.
    """
    if name not in ('l2', 'orth'):
        return None

    use_orth = name == 'orth'
    with torch.enable_grad():
        penalty = torch.zeros(1)
        for param_name, weight in model.named_parameters():
            if 'bias' in param_name:
                continue
            if use_orth:
                rows = weight.view(weight.shape[0], -1)
                gram = torch.mm(rows, torch.t(rows))
                gram -= torch.eye(rows.shape[0])
                penalty = penalty + gram.abs().sum()
            else:
                penalty = penalty + (0.5 * torch.sum(torch.pow(weight, 2)))
    return penalty

def reset_optimizer(model, lr, adam_beta1, adam_beta2):
    """Create a fresh AdamW optimizer over ``model``'s parameters."""
    betas = (adam_beta1, adam_beta2)
    trainable = model.parameters()
    return torch.optim.AdamW(trainable, lr = lr, betas = betas)

def soft_target_update(online_net, target_net, tau = 0.005):
    """Move ``target_net`` toward ``online_net`` in place (Polyak averaging).

    Every entry of the target's ``state_dict`` becomes
    ``tau * online + (1 - tau) * target``. Returns ``None``.
    """
    source = online_net.state_dict()
    for key, slow_tensor in target_net.state_dict().items():
        blended = tau * source[key] + (1. - tau) * slow_tensor
        slow_tensor.copy_(blended)

def online_target_difference(online_net, target_net):
    """Placeholder for an online/target network distance; always returns 0.

    An earlier version summed, over all entries of the two state dicts, the
    norm of the difference between the unit-normalized online and target
    tensors. That computation is disabled, and both arguments are ignored.
    """
    return 0
    
# def collect_data_discrete(env, policy, num_trajectory, truncated_horizon, gamma = None):
#     if not gamma:
#         gamma = 1.
#     phi = env.phi
#     paths = []
#     total_reward = 0.0
#     densities = np.zeros((env.n_state, truncated_horizon))
#     frequency = np.zeros(env.n_state)
#     for i_trajectory in range(num_trajectory):
#         path = {}
#         path['obs'] = []
#         path['acts'] = []
#         path['rews'] = []
#         path['nobs'] = []
#         state = env.reset()
#         sasr = []
#         accum_gamma = np.ones(env.n_state)
#         for i_t in range(truncated_horizon):
#             action = policy(state)
#             #p_action = policy[state, :]
#             #action = np.random.choice(p_action.shape[0], 1, p = p_action)[0]
#             next_state, reward, done, _ = env.step(action)
#             path['obs'].append(state)
#             path['acts'].append(action)
#             path['rews'].append(reward)
#             path['nobs'].append(next_state)
#             #sasr.append((state, action, next_state, reward))
#             frequency[state] += 1
#             densities[state, i_t] += 1
#             total_reward += reward
#             state = next_state
#             if done:
#                 break
#         paths.append(path)

#     gammas = np.array([gamma ** i for i in range(truncated_horizon)])
#     d_sum = np.sum(densities, axis = 0)
#     densities = np.divide(densities, d_sum, out=np.zeros_like(densities), where = d_sum != 0)
#     disc_densities = np.dot(densities, gammas)
#     final_densities = (disc_densities / np.sum(gammas))
#     return paths, frequency, total_reward / (num_trajectory * truncated_horizon), final_densities

def collect_data(env, policy, num_trajectory, truncated_horizon = None, use_true_latent = False, random_pi = False):
    """Roll out ``num_trajectory`` episodes and return them with the mean reward.

    Args:
        env: environment following the gymnasium API (``reset`` returns
            ``(obs, info)``; ``step`` returns a 5-tuple).
        policy: callable mapping a state to an action; ignored when
            ``random_pi`` is set.
        num_trajectory: number of episodes to collect.
        truncated_horizon: optional cap on the number of steps per episode.
        use_true_latent: unused; kept for interface compatibility.
        random_pi: if true, actions come from ``env.action_space.sample()``.

    Returns:
        ``(paths, avg_reward)``. Each path is a dict of per-step lists
        ``obs``, ``nobs``, ``acts``, ``rews`` and ``dones``; ``avg_reward`` is
        the total reward divided by the total number of steps.

    Raises:
        ZeroDivisionError: if no step was collected (``num_trajectory == 0``).
    """
    paths = []
    num_samples = 0
    total_reward = 0.0
    for _traj in range(num_trajectory):
        path = {'obs': [], 'nobs': [], 'acts': [], 'rews': [], 'dones': []}
        state, _ = env.reset() # v4 gym outputs ob, {}
        steps_taken = 0
        while True:
            if random_pi:
                action = env.action_space.sample()
            else:
                action = policy(state)
            next_state, reward, done, truncated, _ = env.step(action) # v4 changes
            path['obs'].append(state)
            path['acts'].append(action)
            path['rews'].append(reward)
            path['nobs'].append(next_state)
            # only true termination is stored; time-limit truncation is not
            path['dones'].append(done)
            total_reward += reward
            state = next_state
            steps_taken += 1
            num_samples += 1
            if done or truncated:
                break
            if truncated_horizon is not None and steps_taken >= truncated_horizon:
                break
        paths.append(path)

    return paths, total_reward / num_samples

def collect_data_samples(env, policy, num_samples_to_collect, random_pi = False):
    """Roll out whole episodes until at least ``num_samples_to_collect`` steps exist.

    Episodes are never cut short, so the number of collected steps may exceed
    the request, and at least one episode is always run. For environments
    with a ``Discrete`` action space, the stored actions are one-hot vectors
    of length ``env.action_space.n``; the raw action is what is passed to
    ``env.step``.

    Args:
        env: environment following the gymnasium API.
        policy: callable mapping a state to an action; ignored when
            ``random_pi`` is set.
        num_samples_to_collect: minimum number of steps to gather.
        random_pi: if true, actions come from ``env.action_space.sample()``.

    Returns:
        ``(paths, avg_reward)``. Each path is a dict of per-step lists
        ``obs``, ``nobs``, ``acts``, ``rews`` and ``dones``; ``avg_reward`` is
        the total reward divided by the number of collected steps.
    """
    paths = []
    total_reward = 0.0
    collected_samples = 0
    num_trajs = 0
    # the action-space type is checked once, outside the rollout loops
    one_hot_actions = isinstance(env.action_space, g.spaces.Discrete)

    while True:
        path = {'obs': [], 'nobs': [], 'acts': [], 'rews': [], 'dones': []}
        state, _ = env.reset() # v4 gym outputs ob, {}
        while True:
            if random_pi:
                action = env.action_space.sample()
            else:
                action = policy(state)
            next_state, reward, done, truncated, _ = env.step(action) # v4 changes
            if one_hot_actions:
                encoded = np.zeros(env.action_space.n)
                encoded[action] = 1.
                action = encoded
            path['obs'].append(state)
            path['acts'].append(action)
            path['rews'].append(reward)
            path['nobs'].append(next_state)
            # only true termination is stored; time-limit truncation is not
            path['dones'].append(done)
            total_reward += reward
            state = next_state
            collected_samples += 1
            if done or truncated:
                break

        num_trajs += 1
        paths.append(path)
        if collected_samples >= num_samples_to_collect:
            break

    print ('collected {} trajectories'.format(num_trajs))

    return paths, total_reward / collected_samples

def merge_datasets(data_list):
    """Flatten a sequence of datasets into one list of trajectories, in order."""
    return [traj for dataset in data_list for traj in dataset]

def format_data_new(data, gamma, normalize_state = False):
    """Flatten per-trajectory rollouts into arrays of transitions.

    ``data['ground_data']`` is a sequence of paths with per-step ``obs``,
    ``nobs``, ``acts``, ``rews`` and ``dones``. With ``normalize_state`` the
    observations and ``data['initial_states']`` are divided by 255.

    Returns a dict with ``state_b``, ``action_b``, ``next_state_b``,
    ``rewards``, ``gammas`` (``gamma ** t`` within each trajectory, built by
    repeated multiplication), ``terminal_masks`` (1 where the step is not
    done, else 0), ``init_state`` and ``num_samples``. ``state_b_act_b`` and
    ``true_g_policy_ratios`` are always empty arrays.
    """
    trajectories = data['ground_data']
    scale = 255. if normalize_state else 1

    states, actions, successors = [], [], []
    reward_list, discount_list, mask_list = [], [], []
    for traj_idx in range(len(trajectories)):
        path = trajectories[traj_idx]
        obs, nobs = path['obs'], path['nobs']
        acts, rews, dones = path['acts'], path['rews'], path['dones']
        discount = 1.
        for t in range(len(obs)):
            states.append(obs[t] / scale)
            successors.append(nobs[t] / scale)
            actions.append(acts[t])
            reward_list.append(rews[t])
            discount_list.append(discount)
            discount *= gamma
            mask_list.append(0 if dones[t] else 1)

    return {
        'state_b': np.array(states),
        'action_b': np.array(actions),
        'state_b_act_b': np.array([]),
        'next_state_b': np.array(successors),
        'rewards': np.array(reward_list),
        'gammas': np.array(discount_list),
        'init_state': data['initial_states'] / scale,
        'num_samples': len(states),
        'terminal_masks': np.array(mask_list),
        'true_g_policy_ratios': np.array([])
    }

def split_dataset(data, tr_set_fraction = 1.):
    """Randomly partition a formatted dataset into train and test parts.

    Args:
        data: mapping with ``'dataset_name'``, ``'seed'`` and a ``'dataset'``
            mapping that holds ``'num_samples'``, ``'init_state'`` and the
            per-sample arrays to split.
        tr_set_fraction: fraction of the samples assigned to the training
            split; the count is truncated to an integer.

    Returns:
        ``(tr_data_obj, test_data_obj)``. Each mirrors the input layout with
        its own sample count; ``'init_state'`` is shared rather than split,
        and the keys ``'gammas'``, ``'true_g_policy_ratios'`` and
        ``'state_b_act_b'`` are not copied over. ``test_data_obj`` is
        ``None`` when no samples are left for the test split.
    """
    source = data['dataset']
    total = source['num_samples']
    num_tr = int(tr_set_fraction * total)

    # Training rows are drawn without replacement; every remaining row, in
    # ascending index order, forms the test split.
    tr_indices = np.random.choice(total, num_tr, replace = False)
    held_out = np.ones(total, dtype = bool)
    held_out[tr_indices] = False
    test_indices = np.flatnonzero(held_out)
    num_test = len(test_indices)
    assert len(tr_indices) + num_test == total

    def _shell(count):
        # Metadata plus the entries that are carried over without splitting.
        return {
            'dataset_name': data['dataset_name'],
            'num_samples': count,
            'seed': data['seed'],
            'dataset': {
                'init_state': source['init_state'],
                'num_samples': count
            }
        }

    tr_data_obj = _shell(num_tr)
    test_data_obj = _shell(num_test)

    unsplit_keys = {'gammas', 'true_g_policy_ratios', 'state_b_act_b', 'init_state', 'num_samples'}
    for key, values in source.items():
        if key in unsplit_keys:
            continue
        tr_data_obj['dataset'][key] = values[tr_indices]
        if num_test:
            test_data_obj['dataset'][key] = values[test_indices]

    return tr_data_obj, (test_data_obj if num_test else None)

def get_err(true_val, pred_vals, metric = 'abs'):
    """Return the mean error between a reference value and predictions.

    Args:
        true_val: reference value(s); converted with ``np.array`` and
            broadcast against ``pred_vals``.
        pred_vals: predicted value(s); converted with ``np.array``.
        metric: ``'abs'`` for mean absolute error or ``'mse'`` for mean
            squared error.

    Returns:
        The mean of the elementwise errors.

    Raises:
        ValueError: if ``metric`` is not one of the supported names.
    """
    # Reject unknown metrics up front; otherwise no error array would exist
    # and the failure would surface as an opaque unbound-variable error.
    if metric not in ('mse', 'abs'):
        raise ValueError(
            "unsupported metric {!r}; expected 'mse' or 'abs'".format(metric)
        )

    diff = np.array(true_val) - np.array(pred_vals)
    error = np.square(diff) if metric == 'mse' else np.abs(diff)
    return error.mean()

# statistics/visualization related
def get_CI(data, confidence = 0.95):
    """Summarize ``data`` with a normal-approximation confidence interval.

    Args:
        data: sequence of numeric samples.
        confidence: two-sided confidence level, strictly between 0 and 1.
            The levels 0.95 and 0.99 use the rounded z-values 1.96 and
            2.576; any other level derives z from the standard normal
            quantile function.

    Returns:
        dict with ``'mean'``, ``'std'`` (``np.std`` with its default
        arguments), ``'lower'``, ``'upper'``, ``'err'`` (the interval
        half-width ``z * std / sqrt(n)``), ``'max'`` and ``'min'``. An empty
        dict is returned when every element of ``data`` is ``None``, which
        includes the case of empty ``data``.

    Raises:
        ValueError: if ``confidence`` is not strictly between 0 and 1.
    """
    # Elementwise comparison is intentional here: ``is None`` would test the
    # array object itself rather than its entries.
    if (np.array(data) == None).all():  # noqa: E711
        return {}

    # Keep the historical rounded constants so existing results do not shift.
    rounded_z = {0.95: 1.96, 0.99: 2.576}
    z = rounded_z.get(confidence)
    if z is None:
        if not 0 < confidence < 1:
            raise ValueError(
                'confidence must be strictly between 0 and 1, got {!r}'.format(confidence)
            )
        from statistics import NormalDist
        z = NormalDist().inv_cdf((1 + confidence) / 2)

    n = len(data)
    mean = np.mean(data)
    std = np.std(data)
    err = z * (std / np.sqrt(n))
    stats = {
        'mean': mean,
        'std': std,
        'lower': mean - err,
        'upper': mean + err,
        'err': err,
        'max': np.max(data),
        'min': np.min(data)
    }
    return stats


