import numpy as np
from matplotlib import pyplot as plt
from matplotlib.tri import Triangulation
import numpy as np
import pdb
import copy
from utils import compute_inverse, load_mini_batches, clip_target, high_d_plot
from policies import NeuralNetwork
import torch
from sklearn.metrics.pairwise import pairwise_distances
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.neighbors import NearestNeighbors
import seaborn as sns

def create_demo_data(M, N):
    """Return four demo value sets (North, East, South, West) for a grid.

    The grid has ``M`` columns and ``N`` rows.  The first two entries are
    flat arrays of length ``N * M``; the last two are ``(N, M)`` arrays of
    uniform random numbers in ``[0, 1)``.  Either layout is acceptable to
    the plotting code because values are flattened before drawing.
    """
    n_cells = M * N
    # One |sin| value per row, repeated across every column of that row.
    row_profile = np.abs(np.sin(np.arange(N)))
    north = np.repeat(row_profile, M)
    # Linear ramp over all cells, scaled into [0, 1).
    east = np.arange(n_cells) / n_cells
    # South is drawn before West so the random stream is consumed in order.
    south, west = (np.random.uniform(0, 1, (N, M)) for _ in range(2))
    return [north, east, south, west]

def triangulation_for_triheatmap(M, N):
    """Build the four triangulations used to split every grid cell in four.

    The grid has ``M`` columns and ``N`` rows.  Each unit cell is cut into
    four triangles that share the cell centre; the function returns one
    ``Triangulation`` per direction, in the order North, East, South, West.
    Within each triangulation the triangles are listed row by row.
    """
    # Cell corners sit on the half-integers, cell centres on the integers.
    corner_x, corner_y = np.meshgrid(np.arange(-0.5, M), np.arange(-0.5, N))
    center_x, center_y = np.meshgrid(np.arange(0, M), np.arange(0, N))
    x = np.concatenate([corner_x.ravel(), center_x.ravel()])
    y = np.concatenate([corner_y.ravel(), center_y.ravel()])

    stride = M + 1                    # corner vertices per grid row
    first_center = stride * (N + 1)   # centre vertices follow all corners

    def cell_triangles(corner_a, corner_b):
        # One triangle per cell: two consecutive corners plus the centre.
        tris = []
        for row in range(N):
            for col in range(M):
                base = col + row * stride
                # Corners numbered around the cell, starting at (col, row).
                corners = (base, base + 1, base + 1 + stride, base + stride)
                center = first_center + col + row * M
                tris.append((corners[corner_a], corners[corner_b], center))
        return tris

    # Corner pairs for the North, East, South and West triangles.
    edge_pairs = ((0, 1), (1, 2), (2, 3), (3, 0))
    return [Triangulation(x, y, cell_triangles(a, b)) for a, b in edge_pairs]

def triang_heatmap(dsa, fname, show_cb = True, labels = None, skip_cells = ((2, 0),)):
    """Draw a four-triangles-per-cell heatmap and save it to ``fname``.

    Args:
        dsa: sequence of four 2D arrays (North, East, South, West), each of
            shape ``(rows, columns)``; entry ``[j, i]`` colours the matching
            triangle of the cell in row ``j`` and column ``i``.
        fname: output path handed to ``plt.savefig``.
        show_cb: when true, attach a colorbar for the shared colour scale.
        labels: optional sequence shaped like ``dsa``; when given, its values
            (rounded) are written into the triangles instead of ``dsa``'s.
        skip_cells: ``(column, row)`` cells that get no text annotation.
            Defaults to ``((2, 0),)``, the cell the original code skipped.
            NOTE(review): presumably a special cell of the caller's grid;
            confirm.

    The colour range spans the global min/max of ``dsa`` (also printed).
    The figure is saved but left open.
    """
    values = dsa
    # Arrays are (rows, columns) while the triangulation expects
    # (columns, rows), so N comes from axis 0 and M from axis 1.  Reading
    # them the other way round only worked for square grids.
    N, M = values[0].shape[0], values[0].shape[1]
    triangul = triangulation_for_triheatmap(M, N)
    fig, ax = plt.subplots()

    vmin = np.min(dsa)
    vmax = np.max(dsa)
    print (vmin, vmax)
    imgs = [ax.tripcolor(t, val.ravel(), cmap='RdYlGn', vmin=vmin, vmax=vmax, ec='white')
            for t, val in zip(triangul, values)]

    # Text comes from ``labels`` when supplied, otherwise from the values.
    annotations = values if labels is None else labels
    # (row, column) offsets moving the text towards the N, E, S, W triangle.
    offsets = [(-1, 0), (0, 1), (1, 0), (0, -1)]
    for _, grid, (d_row, d_col) in zip(values, annotations, offsets):
        for i in range(M):  # columns
            for j in range(N):  # rows
                if (i, j) in skip_cells:
                    continue
                ax.text(i + 0.3 * d_col, j + 0.3 * d_row, round(grid[j, i]),
                        color='k', ha='center', va='center', fontsize='xx-large')

    if show_cb:
        # All four layers share vmin/vmax, so any one image can drive the bar.
        cbar = fig.colorbar(imgs[0], ax=ax)
        cbar.ax.tick_params(labelsize=20)

    ax.set_xticks(range(M))
    plt.xticks(fontsize='xx-large')
    ax.set_yticks(range(N))
    plt.yticks(fontsize='xx-large')
    ax.invert_yaxis()  # row 0 at the top
    ax.margins(x=0, y=0)
    ax.set_aspect('equal', 'box')  # square cells
    plt.tight_layout()

    plt.savefig(fname, bbox_inches = 'tight')

def _generate_all_one_hots(mdp):
    if hasattr(mdp, 'state_action_features'):
        state_actions = [mdp.state_action_features[key] for key in mdp.state_action_features]
    else:
        state_actions = []
        sa_dim = mdp.n_state * mdp.n_action
        for s_idx in range(mdp.n_state):
            per_state = []
            for a_idx in range(mdp.n_action):
                state_action_feat = np.zeros((sa_dim))
                state_action_feat[mdp.n_action * s_idx + a_idx] = 1.
                per_state.append(state_action_feat)
            state_actions.append(per_state)
    state_actions = np.array(state_actions)
    return state_actions

def estimated_Q_values(mdp_name, mdp, true_q_values, phi, lstd_ground_weights, lstd_abs_weights):
    """Print true versus LSTD-estimated Q-values for every state(-action).

    Args:
        mdp_name: ``'RandomMDP'`` (one line per state-action pair) or
            ``'Roy'`` (one line per state, using action 0 of
            ``true_q_values``).  Any other name prints nothing.
        mdp: environment providing ``n_state`` (and ``n_action`` /
            ``state_action_features`` depending on ``mdp_name``).
        true_q_values: 2D array indexed ``[state, action]``.
        phi: optional feature map applied before the abstract weights.
        lstd_ground_weights: weights for the raw features, or ``None``.
        lstd_abs_weights: weights for the (optionally ``phi``-mapped)
            features, or ``None``.

    A missing weight vector is reported as ``-1``.  Previously only the
    ground weights were guarded and ``lstd_abs_weights=None`` crashed.
    """
    def _report(sa_feat, true_val):
        # -1 is the sentinel for "no estimate available".
        lstd_est_ground = -1
        if lstd_ground_weights is not None:
            lstd_est_ground = sa_feat @ lstd_ground_weights
        lstd_est_abs = -1
        if lstd_abs_weights is not None:
            abs_sa = sa_feat if phi is None else phi(sa_feat)
            lstd_est_abs = abs_sa @ lstd_abs_weights
        print ('true val: {}, lstd est ground: {}, lstd est abs: {}'\
            .format(true_val, lstd_est_ground, lstd_est_abs))

    if mdp_name == 'RandomMDP':
        state_actions = _generate_all_one_hots(mdp)
        for s_idx in range(mdp.n_state):
            for a_idx in range(mdp.n_action):
                _report(state_actions[s_idx][a_idx], true_q_values[s_idx, a_idx])
    elif mdp_name == 'Roy':
        state_actions = mdp.state_action_features
        for s_idx in range(mdp.n_state):
            _report(state_actions[s_idx], true_q_values[s_idx, 0])

class OPEEvaluator:
    """Bundle an MDP, dataset features and reference Q-values for evaluation.

    The constructor only stores its arguments.  Attributes read by the
    methods below:
        mdp: environment object, forwarded to the module-level diagnostics.
        csa: 2D array of current state-action features, one row per sample.
        q_values: reference Q-values of the evaluation policy.
        pib_q_values: optional Q-values used to normalise the error
            (presumably those of the behaviour policy; confirm).
        successor_sa: optional matrix forwarded to ``ortho_vs_qval``.
        raw_state: per-sample indices into the flattened ``q_values``, used
            when the MDP carries the ``diverage_mdp`` marker attribute.
    The remaining arguments (``pie``, ``init_s``, ``min_rew``, ``max_rew``,
    ``gamma``, ``rewards``, ``init_sa``, ``sa_visitation``) are stored but
    not used by the methods of this class.
    """

    def __init__(self, mdp, pie, csa, init_s, q_values, pib_q_values, min_rew, max_rew, gamma,\
        successor_sa = None, rewards = None, init_sa = None, sa_visitation = None, raw_state = None):
        self.mdp = mdp
        self.pie = pie
        self.csa = csa
        self.init_s = init_s
        self.q_values = q_values
        self.pib_q_values = pib_q_values
        self.min_rew = min_rew
        self.max_rew = max_rew
        self.gamma = gamma
        self.successor_sa = successor_sa
        self.rewards = rewards
        self.init_sa = init_sa
        self.sa_visitation = sa_visitation
        self.raw_state = raw_state

    def evaluate(self, phi, weights, metric_type = 'error'):
        """Return the normalised mean-squared error of a linear Q estimate.

        Args:
            phi: optional feature map applied to ``self.csa``.
            weights: weight vector; the estimate is ``features @ weights``.
            metric_type: only ``'error'`` is implemented; any other value
                returns ``0`` immediately.

        Returns:
            ``mean((estimate - q)^2)`` over the dataset, divided by
            ``mean((pib_q - q)^2)`` when ``pib_q_values`` is set (a zero
            denominator is replaced by 1).  NaN is mapped to ``inf``.
        """
        # The original compared two undefined names (``metric != error``),
        # which raised NameError on every call.
        if metric_type != 'error':
            return 0
        q_flat = self.q_values.reshape(-1, 1)
        # NOTE(review): 'diverage_mdp' may be a misspelling; the attribute
        # name must match whatever the MDP class actually defines.
        if hasattr(self.mdp, 'diverage_mdp'):
            idx = self.raw_state
        else:
            # For one-hot rows the argmax is the flat state-action index.
            idx = np.argmax(self.csa, axis = 1)
        all_qs = q_flat[idx].reshape(-1)

        csa = self.csa
        if phi is not None:
            csa = phi(self.csa)
        est_qvals = csa @ weights
        num = np.mean(np.square(est_qvals - all_qs))

        denom = 1
        if self.pib_q_values is not None:
            pib_q_flat = self.pib_q_values.reshape(-1, 1)
            all_pib_qs = pib_q_flat[idx].reshape(-1)
            denom = np.mean(np.square(all_pib_qs - all_qs))
            if denom == 0:
                # Avoid dividing by zero when both Q tables coincide.
                denom = 1.

        normalized_err = num / denom
        ope_error = np.nan_to_num(normalized_err, nan = np.inf)
        print (f'final OPE error: {ope_error}, numerator {num}, denom {denom}')
        return ope_error

    def realizability(self, phi):
        """Return the ``'realizable_err'`` entry of ``realizability_measure``."""
        ret = realizability_measure(self.mdp, self.q_values, phi)
        return ret['realizable_err']

    def ortho_vs_qval(self, phi):
        """Return ``(ortho_qvaldiff_corr, ortho_sr_corr)`` from the module-level
        ``ortho_vs_qval`` diagnostic."""
        ret = ortho_vs_qval(self.mdp, self.q_values, phi, self.successor_sa)
        return ret['ortho_qvaldiff_corr'], ret['ortho_sr_corr']

    def value_eval_path(self, Qs):
        """Measure how well each Q-function in ``Qs`` is linearly representable
        in the features of the middle element of ``Qs``.

        For every ``q`` in ``Qs``, ``q(self.csa)`` is regressed by least
        squares onto the reference features
        (``Qs[len(Qs) // 2].backbone`` applied to ``self.csa``, or
        ``self.csa`` itself when that backbone is ``None``).  Each returned
        entry is the first residual sum of squares divided by the number of
        samples and by ``mean(|q_values|)`` (or 1 when that mean is not
        positive).  When ``lstsq`` reports no residuals the entry is
        ``np.array([0])`` before normalisation.
        """
        phi = Qs[len(Qs) // 2].backbone
        ref_phi = self.csa
        if phi is not None:
            ref_phi = phi(ref_phi)

        errs = []
        for q in Qs:
            est_q = q(self.csa)
            C, residuals, rank, sing_vals = np.linalg.lstsq(ref_phi, est_q, rcond=None)
            err_norm = residuals[0] / len(self.csa) if len(residuals) else np.array([0])
            normalization = np.mean(np.abs(self.q_values))
            err_norm = err_norm / (normalization if normalization > 0 else 1)
            errs.append(err_norm)
        return errs

def ortho_vs_qval(mdp, q_values, phi = None, successor_sa = None):
    """Relate feature orthogonality to Q-value (and successor) differences.

    For every unordered pair of state-action feature vectors (optionally
    mapped through ``phi``) the "orthogonality" ``1 - |cosine similarity|``
    is computed and compared with the absolute Q-value difference of the
    pair.  When ``successor_sa`` is given, the same orthogonality is also
    computed on ``successor_sa @ features`` and correlated with the first.

    Returns:
        dict with keys
        ``'orthogonality_unq'`` (NaN-ignoring mean pairwise orthogonality),
        ``'ortho_qvaldiff_corr'`` (Pearson correlation with Q differences),
        ``'qvaldiff_sorted_ortho'`` (orthogonalities ordered by increasing
        Q difference),
        ``'ortho_sr_corr'`` and ``'sr_sorted_ortho'`` (successor-based
        counterparts; both ``-1`` when ``successor_sa`` is ``None``).
        A correlation that cannot be computed is reported as ``0``.
    """
    A = _generate_all_one_hots(mdp)
    A = A.reshape(-1, A.shape[-1])
    q_flat = q_values.reshape(-1, 1)
    if phi is not None:
        A = phi(A)

    qval_diff = pairwise_distances(q_flat, metric = 'l1')
    ortho = 1 - np.abs(cosine_similarity(A))

    # Keep each unordered pair once: strict upper triangle, no diagonal.
    ortho = ortho[np.triu_indices(ortho.shape[0], k = 1)]
    qval_diff = qval_diff[np.triu_indices(qval_diff.shape[0], k = 1)]
    sorted_orthos = ortho[qval_diff.argsort()]

    # -1 marks "not computed" for the successor-based statistics.
    sr_sorted_orthos = -1
    sr_correlation = -1
    if successor_sa is not None:
        sr_phi = np.dot(successor_sa, A)
        sr_ortho = 1 - np.abs(cosine_similarity(sr_phi))
        sr_ortho = sr_ortho[np.triu_indices(sr_ortho.shape[0], k = 1)]
        sr_sorted_orthos = ortho[sr_ortho.argsort()]
        try:
            sr_correlation = np.corrcoef(ortho, sr_ortho)[0, 1]
        except Exception:
            # Best effort, but no longer swallows KeyboardInterrupt/SystemExit.
            sr_correlation = 0

    try:
        correlation = np.corrcoef(ortho, qval_diff)[0, 1]
    except Exception:
        correlation = 0

    stats = {
        'orthogonality_unq': np.nanmean(ortho),
        'ortho_qvaldiff_corr': correlation,
        'qvaldiff_sorted_ortho': sorted_orthos,
        'ortho_sr_corr': sr_correlation,
        'sr_sorted_ortho': sr_sorted_orthos
    }
    return stats

def realizability_measure(mdp, q_values, phi = None):
    """Least-squares fit of the Q-values on the state-action features.

    The features of every state-action pair (optionally mapped through
    ``phi``) are regressed onto the flattened ``q_values``.

    Returns:
        dict with
        ``'realizable_err'``: first residual sum of squares reported by
        ``lstsq`` (``np.array([0])`` when none is reported), divided by
        ``mean(|q|)`` when that mean is positive;
        ``'realizable_w_norm'``: norm of the fitted weights;
        ``'realizable_w'``: the fitted weights;
        ``'realizable_cond'``: largest singular value divided by the
        smallest one, the latter floored at ``1e-8``.
    """
    features = _generate_all_one_hots(mdp)
    features = features.reshape(-1, features.shape[-1])
    targets = q_values.reshape(-1, 1)
    if phi is not None:
        features = phi(features)

    solution = np.linalg.lstsq(features, targets, rcond=None)
    fitted_w, residuals, sing_vals = solution[0], solution[1], solution[3]

    if len(residuals):
        err_norm = residuals[0]
    else:
        err_norm = np.array([0])

    scale = np.mean(np.abs(targets))
    if not scale > 0:
        scale = 1
    err_norm = err_norm / scale

    smallest = max(sing_vals.min(), 1e-8)
    cond_number = sing_vals.max() / smallest

    return {
        'realizable_err': err_norm,
        'realizable_w_norm': np.linalg.norm(fitted_w),
        'realizable_w': fitted_w,
        'realizable_cond': cond_number
    }

def bc_measure(data, q_values, phi, input_dim, pi, gamma):
    """Estimate a Bellman-completeness style error by adversarial training.

    A bias-free linear function ``w`` over ``input_dim`` features is trained
    by gradient *ascent* to maximise the Huber loss between the one-step
    target ``r + gamma * mask * w(next_sa)`` and that target's least-squares
    projection onto the current features.  After training, the same loss is
    evaluated once on the batch returned by ``load_mini_batches(..., -1)``.

    Args:
        data: dataset handed to ``load_mini_batches``.
        q_values: unused by the active code (only referenced in the
            commented-out block below).
        phi: optional feature map applied to the current and next features.
        input_dim: width of the features fed to ``w`` (after ``phi``).
        pi: policy handed to ``load_mini_batches``.
        gamma: discount factor.

    Returns:
        dict ``{'bc_err': float}`` with the final loss value.
    """

    w = torch.nn.Linear(input_dim, 1, bias = False)
    w_optimizer = torch.optim.AdamW(w.parameters(), lr = 1e-4)

    mini_batch_size = 1024
    epochs = 10000    

    for epoch in range(1, epochs + 1):
        mini_batch = load_mini_batches(data, True, pi, mini_batch_size)
        curr_sa = torch.Tensor(mini_batch['curr_sa'])
        next_sa = torch.Tensor(mini_batch['next_sa'])
        rewards = torch.Tensor(mini_batch['rewards'])
        terminal_masks = torch.Tensor(mini_batch['terminal_masks'])
        if phi is not None:
            curr_sa = torch.Tensor(phi(curr_sa))
            next_sa = torch.Tensor(phi(next_sa))

        # One-step target built from the adversary w.
        next_fsa = w.forward(next_sa).reshape(-1)
        tf = rewards + gamma * terminal_masks * next_fsa
        # Detached copy: gradients reach w only through the projection below.
        target_tf = tf.detach()

        # Least-squares projection of tf onto the span of curr_sa
        # (assumes compute_inverse returns a (pseudo-)inverse of cov — TODO confirm).
        num = curr_sa.shape[0]
        cov = (torch.matmul(curr_sa.T, curr_sa) / num)
        cov_inv = compute_inverse(cov)
        left_proj_op = curr_sa @ cov_inv
        right_proj_op = (curr_sa.T @ tf) / num
        projected_tf = left_proj_op @ right_proj_op

        loss = torch.nn.functional.huber_loss(target_tf, projected_tf)

        w_optimizer.zero_grad()
        loss.backward()
        # Clipping happens before the sign flip; clipping by value is symmetric.
        torch.nn.utils.clip_grad_value_(w.parameters(), clip_value = 1.0)

        # gradient ascent to find sup
        for p in w.parameters():
            p.grad *= -1
        w_optimizer.step()

        if epoch % 1000 == 0 or epoch == epochs or epoch == 1:
            param_weights = torch.Tensor([torch.linalg.norm(p.data) for p in w.parameters()])
            weight_norm = torch.sum(param_weights)
            print ('{} bc error {}, total weight norm {}, each weight norm {}'.format(epoch, loss, weight_norm, param_weights))

    # Final evaluation with the trained adversary
    # (batch size -1 presumably selects the whole dataset — confirm in load_mini_batches).
    whole_batch = load_mini_batches(data, True, pi, -1)
    curr_sa = torch.Tensor(whole_batch['curr_sa'])
    next_sa = torch.Tensor(whole_batch['next_sa'])
    rewards = torch.Tensor(whole_batch['rewards'])
    terminal_masks = torch.Tensor(whole_batch['terminal_masks'])
    if phi is not None:
        curr_sa = torch.Tensor(phi(curr_sa))
        next_sa = torch.Tensor(phi(next_sa))

    # No gradient tracking is needed for the final measurement.
    with torch.no_grad():
        next_fsa = w(next_sa).reshape(-1)
    tf = rewards + gamma * terminal_masks * next_fsa
    target_tf = tf

    num = curr_sa.shape[0]
    cov = (torch.matmul(curr_sa.T, curr_sa) / num)
    cov_inv = compute_inverse(cov)
    left_proj_op = curr_sa @ cov_inv
    right_proj_op = (curr_sa.T @ tf) / num
    projected_tf = left_proj_op @ right_proj_op

    final_bc = torch.nn.functional.huber_loss(target_tf, projected_tf).item()
    # q_flat = q_values.reshape(-1, 1)
    # num = csa.shape[0]
    # idx = np.argmax(csa, axis = 1)
    # all_qs = q_flat[idx].reshape(-1)
    # if phi is not None:
    #     csa = phi(csa)
    # cov = (np.matmul(csa.T, csa) / num)
    # cov_inv = compute_inverse(cov)
    # left_proj_op = csa @ cov_inv
    # right_proj_op = (csa.T @ all_qs) / num
    # projected_q = left_proj_op @ right_proj_op
    # projected_q = projected_q.reshape(-1)

    # proj = np.mean(np.abs(all_qs - projected_q))
    # eps = proj / np.mean(np.abs(all_qs))
    bc_stats = {
        'bc_err': final_bc
    }
    print (f'bc stats {bc_stats}')
    return bc_stats

def td_soln_measure(mdp, phi, realizable_w, td_w):
    """Compare a TD weight vector with a regression-fitted weight vector.

    Both weight vectors are applied to the features of every state-action
    pair (optionally mapped through ``phi``).

    Returns:
        dict with ``'td_sol_dist'`` (norm of the difference between the two
        sets of predictions) and ``'td_w_dist'`` (norm of the difference
        between the weights).  NaN distances are reported as ``inf``.
        The dict is also printed.
    """
    feats = _generate_all_one_hots(mdp)
    feats = feats.reshape(-1, feats.shape[-1])
    if phi is not None:
        feats = phi(feats)

    regression_w = realizable_w.reshape(-1)
    prediction_gap = feats @ regression_w - feats @ td_w
    weight_gap = regression_w - td_w

    td_solution_stats = {
        # distance between the TD predictions and the regression predictions
        'td_sol_dist': np.nan_to_num(np.linalg.norm(prediction_gap), nan = np.inf),
        # distance between the TD weights and the regression weights
        'td_w_dist': np.nan_to_num(np.linalg.norm(weight_gap), nan = np.inf),
    }
    print (f'td sol stats {td_solution_stats}')
    return td_solution_stats

def induced_trans_measure(mdp, pi, phi, csa_data, phi_outdim):
    """Compare the spectrum of the state-action transition matrix with that
    of its projection onto the feature space.

    A projection operator ``PI`` is built from the features of all
    state-action pairs (optionally mapped through ``phi``), weighted by the
    empirical second-moment matrix of ``csa_data``.  The function prints and
    returns the norm of the difference between the eigenvalues of
    ``PI @ P @ PI`` and those of ``P``, where ``P`` is the policy's
    state-action transition matrix obtained from ``mdp.get_policy_probs``.

    ``phi_outdim`` is not used.
    """

    # get pi transition matrix
    pi_tr_s, pi_r, pi_tr_sa = mdp.get_policy_probs(pi.pi_matrix)
    # NOTE: the reshaped rewards ``r`` are not used further down.
    r = mdp.rewards
    r = r.reshape(r.shape[0] * r.shape[1], -1)
    # Flatten the 4-D transition tensor into a 2-D matrix
    # (presumably (s, a, s', a') -> (s*a, s'*a') — confirm against get_policy_probs).
    p_sa = pi_tr_sa.reshape(-1, pi_tr_sa.shape[2] * pi_tr_sa.shape[3])

    # Empirical second-moment matrix of the dataset features.
    data_dist = (csa_data.T @ csa_data) / csa_data.shape[0]

    sa = _generate_all_one_hots(mdp)
    sa = sa.reshape(-1, sa.shape[-1])

    # PI * psa * PI
    if phi is not None:
        sa = phi(sa)

    # Weighted least-squares projection onto the span of the features
    # (assumes compute_inverse returns a (pseudo-)inverse — TODO confirm).
    cov = sa.T @ data_dist @ sa
    cov_inv = compute_inverse(cov)

    left_proj_op = sa @ cov_inv
    right_proj_op = sa.T @ data_dist
    PI = left_proj_op @ right_proj_op

    ind_psa = PI @ p_sa @ PI

    # NOTE(review): the two eigenvalue arrays are subtracted element-wise
    # in whatever order eigvals returns them; they are not sorted or matched
    # first — confirm this is intended.
    ind_psa_eigs = np.linalg.eigvals(ind_psa)
    psa_eigs = np.linalg.eigvals(p_sa)
    dist = np.linalg.norm(ind_psa_eigs - psa_eigs)

    print (f'eig val distance between psa and induced psa {dist}')
    return dist

def generate_one_hots(mdp, paths, pi, q_values):
    """Turn tabular trajectories into one-hot feature arrays.

    Args:
        mdp: object exposing ``n_state`` and ``n_action``.
        paths: list of trajectories; each is a dict with equally long lists
            under ``'obs'``, ``'acts'``, ``'rews'``, ``'nobs'``, ``'dones'``
            holding integer states/actions.
        pi: policy exposing ``get_prob(state, action)``; used to build the
            expected features of the next (and initial) state.
        q_values: 2D array indexed ``[state, action]``.

    Returns:
        ``{'dataset': {...}}`` with one row per transition for
        ``state_b``, ``action_b``, ``state_action_b``, ``next_state_b``,
        ``next_state_action`` (expectation under ``pi``), ``rewards``,
        ``terminal_masks`` (``1 - done``) and ``q_values``; one row per
        trajectory for ``init_state``, ``init_state_action`` (expectation
        under ``pi``) and ``init_state_action_b`` (action actually taken);
        plus ``num_samples``.

    State-action pairs use the flat index ``n_action * state + action``.
    The unused frequency table and raw-data dict of the earlier version
    were dropped; the returned dataset is unchanged.
    """
    n_state, n_action = mdp.n_state, mdp.n_action
    sa_dim = n_state * n_action

    def one_hot(size, index):
        # Vector of zeros with a single 1 at ``index``.
        vec = np.zeros((size))
        vec[index] = 1.
        return vec

    def expected_sa(state):
        # Policy-weighted mixture of the state's state-action one-hots.
        mix = np.zeros((sa_dim))
        for a_idx in range(n_action):
            mix += pi.get_prob(state, a_idx) * one_hot(sa_dim, n_action * state + a_idx)
        return mix

    states, actions, rewards, dones = [], [], [], []
    state_features, action_features, state_action_features = [], [], []
    next_state_features, next_state_action_features = [], []
    init_state_features = []
    init_state_action_features = []
    init_state_action_b_features = []

    for path in paths:
        for t, state in enumerate(path['obs']):
            action = path['acts'][t]
            next_state = path['nobs'][t]

            if t == 0:
                # The first step of a trajectory defines its initial state.
                init_state_features.append(one_hot(n_state, state))
                init_state_action_b_features.append(
                    one_hot(sa_dim, n_action * state + action))
                init_state_action_features.append(expected_sa(state))

            states.append(state)
            actions.append(action)
            rewards.append(path['rews'][t])
            dones.append(path['dones'][t])

            state_action_features.append(one_hot(sa_dim, n_action * state + action))
            next_state_action_features.append(expected_sa(next_state))
            state_features.append(one_hot(n_state, state))
            action_features.append(one_hot(n_action, action))
            next_state_features.append(one_hot(n_state, next_state))

    q_vals = [q_values[s, a] for s, a in zip(states, actions)]

    one_hot_data = {
        'dataset': {
            'state_b': np.array(state_features),
            'action_b': np.array(action_features),
            'state_action_b': np.array(state_action_features),
            'next_state_b': np.array(next_state_features),
            'next_state_action': np.array(next_state_action_features),
            'rewards': np.array(rewards),
            'terminal_masks': 1. - np.array(dones),
            'num_samples': len(state_features),
            'init_state': np.array(init_state_features),
            'init_state_action': np.array(init_state_action_features),
            'init_state_action_b': np.array(init_state_action_b_features),
            'q_values': np.array(q_vals)
        }
    }
    return one_hot_data

def generate_roy_features(mdp, paths, pi, q_values):
    """Assemble the feature dataset for the 'Roy' environment.

    ``paths`` is a dict of parallel sequences (``'states'``, ``'rewards'``,
    ``'next_states'``, ``'dones'``) plus ``'init_states'``.  Features are
    looked up in ``mdp.state_action_features`` by state index; the same
    vector serves as both state and state-action feature.  Every
    transition gets the action one-hot for action 0.  ``pi`` is not used,
    and ``q_values`` is passed through as an array.
    """
    feature_table = mdp.state_action_features
    # zip stops at the shortest of the four sequences.
    transitions = list(zip(paths['states'], paths['rewards'],
                           paths['next_states'], paths['dones']))

    states = [s for s, _, _, _ in transitions]
    rewards = [r for _, r, _, _ in transitions]
    next_states = [ns for _, _, ns, _ in transitions]
    dones = [d for _, _, _, d in transitions]

    curr_feats = [feature_table[s] for s in states]
    next_feats = [feature_table[ns] for ns in next_states]

    action_rows = []
    for _ in transitions:
        row = np.zeros((mdp.n_action))
        row[0] = 1.
        action_rows.append(row)

    init_feats = [feature_table[s0] for s0 in paths['init_states']]

    dataset = {
        'state': np.array(states),
        'next_state': np.array(next_states),
        'state_b': np.array(curr_feats),
        'action_b': np.array(action_rows),
        'state_action_b': np.array(curr_feats),
        'next_state_b': np.array(next_feats),
        'next_state_action': np.array(next_feats),
        'rewards': np.array(rewards),
        'terminal_masks': 1. - np.array(dones),
        'num_samples': len(curr_feats),
        'init_state': np.array(init_feats),
        'init_state_action': np.array(init_feats),
        'q_values': np.array(q_values),
        'init_state_action_b': np.array(init_feats)
    }
    return {'dataset': dataset}

def generate_bairds_features(mdp, paths, pi, q_values):
    """Assemble the feature dataset from trajectories using the MDP's own
    feature tables.

    ``paths`` is a list of trajectory dicts with parallel lists under
    ``'obs'``, ``'acts'``, ``'rews'``, ``'nobs'`` and ``'dones'``.  State
    and action features come from ``mdp.state_features`` and
    ``mdp.action_features``; state-action features from
    ``mdp.get_sa_feat(s, a)``.  The first step of every trajectory also
    supplies the initial-state entries.  ``pi`` is not used.

    NOTE(review): ``next_state_action`` is filled with the next *state*
    features, not with state-action features — confirm this is intended.
    """
    state_rows, action_rows, sa_rows = [], [], []
    next_state_rows = []
    reward_log, done_log = [], []
    init_state_rows, init_sa_rows = [], []

    for path in paths:
        steps = zip(path['obs'], path['acts'], path['rews'], path['nobs'], path['dones'])
        for step, (s, a, r, ns, d) in enumerate(steps):
            s_feat = mdp.state_features[s]
            a_feat = mdp.action_features[a]
            sa_feat = mdp.get_sa_feat(s, a)
            ns_feat = mdp.state_features[ns]

            state_rows.append(s_feat)
            action_rows.append(a_feat)
            sa_rows.append(sa_feat)
            next_state_rows.append(ns_feat)
            reward_log.append(r)
            done_log.append(d)

            if step == 0:
                init_state_rows.append(s_feat)
                init_sa_rows.append(sa_feat)

    dataset = {
        'state_b': np.array(state_rows),
        'action_b': np.array(action_rows),
        'state_action_b': np.array(sa_rows),
        'next_state_b': np.array(next_state_rows),
        'next_state_action': np.array(next_state_rows),
        'rewards': np.array(reward_log),
        'terminal_masks': 1. - np.array(done_log),
        'num_samples': len(state_rows),
        'init_state': np.array(init_state_rows),
        'init_state_action': np.array(init_sa_rows),
        'q_values': np.array(q_values)
    }
    return {'dataset': dataset}

def collect_data_discrete(env, policy, num_trajectory, truncated_horizon, use_true_latent = False, random_pi = False):
    """Roll out trajectories and return them with the mean per-step reward.

    Args:
        env: environment with ``reset() -> (state, info)`` and
            ``step(action) -> (next_state, reward, done, _, _)``; needs
            ``n_action`` when ``random_pi`` is set.
        policy: callable mapping a state (or its latent form) to an action.
        num_trajectory: number of episodes to collect.
        truncated_horizon: maximum steps per episode, or ``-1`` for no cap.
            At least one step is always taken.
        use_true_latent: feed ``env.convert_to_latents(state)`` to the
            policy when the environment provides that method.
        random_pi: ignore ``policy`` and sample actions uniformly.

    Returns:
        ``(paths, avg_reward)`` where each path is a dict of parallel lists
        (``'obs'``, ``'nobs'``, ``'acts'``, ``'rews'``, ``'dones'``) and
        ``avg_reward`` is the total reward divided by the number of
        collected steps (``0.0`` when nothing was collected, which
        previously raised ZeroDivisionError).
    """
    paths = []
    num_samples = 0
    total_reward = 0.0
    use_latents = hasattr(env, 'convert_to_latents') and use_true_latent

    for _ in range(num_trajectory):
        path = {'obs': [], 'nobs': [], 'acts': [], 'rews': [], 'dones': []}
        state, _ = env.reset() # v4 gym outputs ob, {}
        i_t = 0
        while True:
            if random_pi:
                action = np.random.choice(np.arange(env.n_action))
            else:
                action = policy(env.convert_to_latents(state) if use_latents else state)
            next_state, reward, done, _, _ = env.step(action) # v4 changes
            path['obs'].append(state)
            path['acts'].append(action)
            path['rews'].append(reward)
            path['nobs'].append(next_state)
            path['dones'].append(done)
            total_reward += reward
            state = next_state
            i_t += 1
            if done or (truncated_horizon != -1 and i_t >= truncated_horizon):
                break
        paths.append(path)
        num_samples += len(path['obs'])

    if num_samples == 0:
        # No episodes requested: there is no average to report.
        return paths, 0.0
    return paths, total_reward / num_samples

def heatmap(matrix, filename):
    """Save ``matrix``, scaled by its maximum, as a 'hot' colour-map image.

    The image is written to ``filename`` and the current figure is closed
    afterwards.
    """
    peak = np.max(matrix)
    normalized = matrix / peak
    plt.imshow(normalized, cmap='hot', interpolation='nearest')
    plt.savefig(filename)
    plt.close()

def random_mdp_Q_vals(env, pi, gamma):
    """Solve for Q-values and state values of ``pi`` in closed form.

    ``env.get_policy_probs(pi.pi_matrix)`` supplies the state transition
    matrix, the state reward vector and a 4-D state-action transition
    tensor.  The tensor is flattened to 2-D and the linear systems
    ``(I - gamma * P)^-1 r`` are solved by explicit matrix inversion.

    Returns:
        ``(qvals, vals, sr_sa)``: Q-values as a column over flattened
        state-action pairs, state values, and the state-action resolvent
        ``(I - gamma * P_sa)^-1``.
    """
    policy_matrix = pi.pi_matrix
    state_trans, state_rew, sa_trans = env.get_policy_probs(policy_matrix)
    sa_trans = sa_trans.reshape(-1, sa_trans.shape[2] * sa_trans.shape[3])

    reward_table = env.rewards
    reward_col = reward_table.reshape(reward_table.shape[0] * reward_table.shape[1], -1)

    # State-action side: resolvent and Q-values.
    sr_sa = np.linalg.inv(np.eye(sa_trans.shape[0]) - sa_trans * gamma)
    qvals = np.matmul(sr_sa, reward_col)

    # State side: same construction with the state-level quantities.
    state_resolvent = np.linalg.inv(np.eye(state_trans.shape[0]) - state_trans * gamma)
    vals = np.matmul(state_resolvent, state_rew)

    return qvals, vals, sr_sa

def random_mdp_krope(env, pi, gamma, q_values):
    """Exploratory routine: build a pairwise kernel over state-action pairs
    in closed form, print diagnostics and drop into the debugger.

    For every ordered pair of state-action pairs, the corresponding
    transition rows are combined into a product distribution and the
    pairwise reward similarity ``1 - |r_i - r_j| / env.reward_range`` is
    used as the immediate term.  ``K`` is obtained by solving the resulting
    linear system by matrix inversion.

    NOTE: the function ends in ``pdb.set_trace()`` and returns ``None``;
    ``dK`` and ``q_diff`` are computed but only available for inspection
    inside the debugger.  The Cholesky call in the print statement raises
    if ``K`` is not positive definite.
    """

    pi = pi.pi_matrix
    p_s, r_s, p_sa = env.get_policy_probs(pi)
    p_sa = p_sa.reshape(-1, p_sa.shape[2] * p_sa.shape[3])
    num_sa = p_sa.shape[0]

    # Row (sa, other_sa) holds the flattened outer product of the two
    # transition rows; the result has num_sa**2 rows and columns.
    cross_psa = []
    for sa in range(num_sa):
        for other_sa in range(num_sa):
            first = p_sa[sa, :]
            second = p_sa[other_sa, :]
            result = np.array([a * b for a in first for b in second])
            cross_psa.append(result)
    cross_psa = np.array(cross_psa)
    r = env.rewards
    r_flat = r.reshape(-1)
    #r_flat = r.reshape(r.shape[0] * r.shape[1], -1)

    # Reward similarity for every ordered pair: 1 for equal rewards,
    # decreasing with the reward gap relative to env.reward_range.
    r_diff = np.array([1 - (abs(r_flat[i] - r_flat[j]) / env.reward_range) for i in range(len(r_flat)) for j in range(len(r_flat))])

    #r_diff = np.array([env.reward_range - abs(r_flat[i] - r_flat[j]) for i in range(len(r_flat)) for j in range(len(r_flat))])
    r_diff = r_diff.reshape(r_diff.shape[0], 1)

    discounted_cross_p = cross_psa * gamma

    # Closed-form solution K = (I - gamma * P_cross)^-1 r_diff, reshaped to
    # a num_sa x num_sa matrix.
    num_pairs_combs = cross_psa.shape[0]
    sr_cross = np.linalg.inv(np.eye(num_pairs_combs) - discounted_cross_p)
    K = np.matmul(sr_cross, r_diff).reshape(num_sa, num_sa)

    diag_K = np.diag(K)
    
    # Compute the pairwise distances using the formula
    # dK[i, j] = K[i, i] + K[j, j] - 2 * K[i, j]
    dK = np.zeros((num_sa, num_sa))
    for i in range(num_sa):
        for j in range(num_sa):
            dK[i, j] = diag_K[i] + diag_K[j] - 2 * K[i, j]

    # Absolute Q-value gap for every ordered pair (not used below).
    q_values = q_values.reshape(-1)
    q_diff = np.array([abs(q_values[i] - q_values[j]) for i in range(len(q_values)) for j in range(len(q_values))])
    q_diff = q_diff.reshape(q_diff.shape[0], 1)

    # Diagnostics: K itself, its rank (NumPy and torch), its Cholesky
    # factor, sum of eigenvalues over the largest one, and the number of
    # positive eigenvalues.
    print (K)
    x = np.linalg.eigvals(K)
    print (np.linalg.matrix_rank(K), torch.linalg.matrix_rank(torch.Tensor(K)), np.linalg.cholesky(K), np.sum(x) / np.max(x), len(np.where(x > 0)[0]))
    # Interactive breakpoint: execution stops here for manual inspection.
    pdb.set_trace()

    # r = r_s
    # p = p_s
    # num_states = p.shape[0]
    # discounted_p = p * gamma
    # vals = np.matmul(np.linalg.inv(np.eye(num_states) - discounted_p), r)

    # return qvals, vals, sr_sa

def compute_Q_values(env_name, mdp, pi, gamma, epochs = 100, lr = 1e-1):
    """Compute Q-values and state values of policy ``pi`` on ``mdp``.

    Args:
        env_name: ``'ToyMDP'`` (iterative policy evaluation) or
            ``'RandomMDP'`` / ``'ChainMDP'`` (closed-form solve through
            ``random_mdp_Q_vals``).  Any other name returns three ``None``.
        mdp: environment; the ToyMDP branch reads ``n_state``,
            ``n_action``, ``transitions[s][a][ns]`` and ``rewards[s][a]``.
        pi: policy exposing ``get_prob(state, action)`` (ToyMDP) or
            ``pi_matrix`` (closed-form branch).
        gamma: discount factor.
        epochs: maximum number of sweeps in the ToyMDP branch.
        lr: unused; kept for backward compatibility.

    Returns:
        ``(q_values, v_values, successor_sa)``; ``successor_sa`` is only
        produced by the closed-form branch.

    Fixes over the previous version: the state values are computed with
    ``pi`` (the code referenced an undefined name ``pie``), and the
    convergence check now compares against the Q-values from the previous
    check instead of a snapshot that was never updated.
    """
    q_values, v_values, successor_sa = None, None, None
    if env_name == 'ToyMDP':
        q_values = np.zeros((mdp.n_state, mdp.n_action))
        prev_q_values = np.zeros_like(q_values)
        # The last state is skipped as a successor, so it contributes no
        # bootstrapped value.
        n_bootstrap_states = mdp.n_state - 1

        for epoch in range(epochs):
            # In-place sweep: later updates see values refreshed earlier
            # in the same sweep.
            for s in range(mdp.n_state):
                for a in range(mdp.n_action):
                    q_val = 0
                    for ns in range(n_bootstrap_states):
                        for na in range(mdp.n_action):
                            q_val += (mdp.transitions[s][a][ns] * pi.get_prob(ns, na) * q_values[ns][na])
                    q_values[s][a] = mdp.rewards[s][a] + gamma * q_val

            # Every 50 sweeps, stop if no entry moved by more than 1e-2
            # since the previous check.
            if ((epoch + 1) % 50 == 0):
                if np.all(np.abs(q_values - prev_q_values) <= 1e-2):
                    print ('converged, done, itr {}'.format(epoch + 1))
                    break
                prev_q_values = q_values.copy()
                print ('epoch: {} {}'.format(epoch, np.linalg.norm(q_values)))

        # V(s) = sum_a pi(a | s) * Q(s, a)
        v_values = []
        for s_idx in range(q_values.shape[0]):
            pi_val = 0
            for a_idx in range(q_values.shape[1]):
                pi_val += pi.get_prob(s_idx, a_idx) * q_values[s_idx][a_idx]
            v_values.append(pi_val)
    elif env_name == 'RandomMDP' or env_name == 'ChainMDP':
        q_values, v_values, successor_sa = random_mdp_Q_vals(mdp, pi, gamma)
        q_values = q_values.reshape(mdp.n_state, mdp.n_action)
        if mdp.use_terminal_state:
            # The last state is terminal: its values are zero.
            q_values[-1, :] = 0
            v_values[-1] = 0
    return q_values, v_values, successor_sa

def group_clusters(mdp, metric, name, q_values, show_cb = False):
    """Group state-action pairs at (near-)zero distance and plot the groups.

    Pairs are visited in ``(s, a)`` row-major order. Each still-unassigned
    pair opens a new group and absorbs every other unassigned pair whose
    ``metric.distances`` entry relative to it is ``<= 1e-5``. The group ids
    are then written onto an ``mdp.length x mdp.length`` grid per action and
    passed to ``triang_heatmap``, which is given the file name
    ``'<name>_clusters.jpg'``.

    Args:
        mdp: environment exposing ``n_state``, ``n_action`` and ``length``.
        metric: object whose ``distances[i][j]`` is the distance between the
            state-action pairs with flat indices ``i`` and ``j``
            (flat index = ``s * n_action + a``).
        name: prefix of the output image file name.
        q_values: optional ``(n_state, n_action)`` array; when given, each
            group member is stored together with its Q-value, otherwise 0.
        show_cb: forwarded to ``triang_heatmap``.

    Returns:
        Array of four ``length x length`` label grids, taken from actions
        1, 0, 3, 2 in that order (the code indexes exactly four actions).
    """
    distances = metric.distances
    n_action = mdp.n_action
    terminal_state = mdp.n_state - 1
    all_sa = [(s, a) for s in range(mdp.n_state) for a in range(n_action)]
    assigned_sa = [False] * len(all_sa)

    # `if q_values:` is ambiguous for numpy arrays, so test presence explicitly.
    has_q = q_values is not None and np.size(q_values) > 0

    def q_of(s, a):
        return q_values[s, a] if has_q else 0

    terminal_groups = [] # ids of the groups that contain a terminal-state pair
    group_id = 1
    group_mappings = {}
    # all_sa is in flat-index order, so enumerate() yields s * n_action + a
    for idx, sa in enumerate(all_sa):
        if assigned_sa[idx]:
            continue
        s, a = sa

        # open a new group with this pair as its reference point
        current = group_id
        group_id += 1
        group_mappings[current] = [(sa, q_of(s, a))]
        assigned_sa[idx] = True
        if s == terminal_state:
            terminal_groups.append(current)

        # absorb every unassigned pair at ~0 distance from the reference
        for o_idx, other_sa in enumerate(all_sa):
            if assigned_sa[o_idx]: # also skips the reference pair itself
                continue
            if distances[idx][o_idx] <= 1e-5:
                o_s, o_a = other_sa
                group_mappings[current].append((other_sa, q_of(o_s, o_a)))
                assigned_sa[o_idx] = True
                if o_s == terminal_state:
                    # record the group the pair actually joined
                    terminal_groups.append(current)

    d_vals = np.zeros((n_action, mdp.length, mdp.length))

    for gid, members in group_mappings.items():
        for (s, a), _ in members:
            x, y = divmod(s, mdp.length)
            label = gid
            if s == terminal_state:
                # all terminal pairs share the smallest recorded terminal group id
                label = min(terminal_groups)
            d_vals[a][x][y] = label

    # flip each grid along its first axis before plotting
    d_vals = d_vals[:, ::-1, :]

    d_vals = np.array([d_vals[1], d_vals[0], d_vals[3], d_vals[2]])
    print (group_mappings)
    triang_heatmap(d_vals, '{}_clusters.jpg'.format(name), show_cb)

    return d_vals

class GridworldPolicy:
    """Tabular policy for a square gridworld MDP.

    ``pi_matrix[s][a]`` is the probability of taking action ``a`` in state
    ``s``. How the table is filled depends on ``typ``:

    * ``'random'``: uniform over actions.
    * ``'file'``: softmax of logits read from ``f_name`` (one float per
      line, state-major order).
    * any string containing ``'L'``: softmax of logits that favour action 0
      on the bottom/top rows and action 1 on the left/right columns; if the
      string also contains ``'det'`` the centre state favours action 0.
    * ``'CW'`` / ``'CCW'``: a 0/1 table written directly (no softmax).
      States that match no rule keep an all-zero row.

    Any other ``typ`` leaves ``pi_matrix`` all zeros. Action indices, per
    the rules below: 0 = right, 1 = up, 2 = left, 3 = down.

    The MDP must expose ``n_state``, ``n_action``, ``length`` and
    ``state_decoding(i) -> (x, y)``.
    """

    def __init__(self, mdp, typ, f_name = None, L_right_prob = 0.5):
        self.mdp = mdp
        self.f_name = f_name
        self.typ = typ
        self.pi_matrix = np.zeros((mdp.n_state, mdp.n_action))
        self.num_states = mdp.n_state
        self.num_actions = mdp.n_action
        self.L_right_prob = L_right_prob
        assert self.num_states % 2 == 1 # assuming odd so that single center is well-defined
        self._load_pi()

    def _load_pi(self):
        """Fill ``self.pi_matrix`` according to ``self.typ``."""
        shape = (self.num_states, self.num_actions)
        if self.typ == 'random':
            self.pi_matrix = self._softmax(np.full(shape, 1. / self.num_actions))
        elif self.typ == 'file':
            # float() ignores surrounding whitespace, so the last line is
            # parsed correctly whether or not the file ends with a newline.
            with open(self.f_name, 'r') as f:
                ths = [float(line) for line in f if line.strip()]
            self.pi_matrix = self._softmax(np.array(ths).reshape(shape))
        elif 'L' in self.typ:
            self.pi_matrix = self._softmax(self._l_shaped_logits())
        elif self.typ in ('CW', 'CCW'):
            self.pi_matrix = self._cycle_table()

    def _l_shaped_logits(self):
        """Build the logits used by the 'L' policy family."""
        last = self.mdp.length - 1
        ths = np.full((self.num_states, self.num_actions), -100.)
        # NOTE(review): this is used as a logit, not as a probability.
        ths[0][0] = self.L_right_prob
        for i in range(1, self.num_states):
            x, y = self.mdp.state_decoding(i)
            if x < last and y == 0:          # bottom row
                ths[i][0] = 1. # go right
            elif x == 0 and y < last:        # left column
                ths[i][1] = 1. # go up
            elif x == last and y < last:     # right column
                ths[i][1] = 1. # go up
            elif x < last and y == last:     # top row
                ths[i][0] = 1. # go right
        # default: move right when at dead center if not random
        if 'det' in self.typ:
            ths[self.num_states // 2][0] = 1.
        return ths

    def _cycle_table(self):
        """Build the 0/1 action table for the 'CW' / 'CCW' policies."""
        dom_prob = 1.
        rem_prob = (1. - dom_prob) / (self.num_actions - 1)
        ths = np.full((self.num_states, self.num_actions), rem_prob)
        clockwise = self.typ == 'CW'
        last = self.mdp.length - 1
        for i in range(self.num_states):
            x, y = self.mdp.state_decoding(i)
            if x < last and y == 0:          # bottom row
                if clockwise:
                    action = 1 if x == 0 else 2 # up at the corner, else left
                else:
                    action = 0 # go right
            elif x == 0 and y < last:        # left column
                # NOTE(review): CCW sends the whole left column right (never
                # down) -- this matches the existing table; confirm intent.
                action = 1 if clockwise else 0
            elif x == last and y < last:     # right column
                # NOTE(review): CW sends the whole right column left (never
                # down) -- this matches the existing table; confirm intent.
                action = 2 if clockwise else 1
            elif x < last and y == last:     # top row
                if clockwise:
                    action = 0 # go right
                else:
                    action = 3 if x == 0 else 2 # down at the corner, else left
            else:
                continue
            ths[i][action] = dom_prob
        # centre state: left for CW, down for CCW
        ths[self.num_states // 2][2 if clockwise else 3] = dom_prob
        return ths

    def _softmax(self, ths):
        """Row-wise softmax of a ``(num_states, num_actions)`` logit table."""
        ths = np.asarray(ths, dtype = float)
        # subtracting the row maximum avoids overflow for large logits
        exp = np.exp(ths - ths.max(axis = 1, keepdims = True))
        return exp / exp.sum(axis = 1, keepdims = True)

    def __call__(self, s):
        a = self._get_action(s)
        return a

    def _get_action(self, s):
        """Sample an action index for state index ``s``.

        Raises:
            ValueError: if row ``s`` of ``pi_matrix`` is not a valid
                probability distribution.
        """
        try:
            return np.random.choice(np.arange(self.num_actions), p = self.pi_matrix[s])
        except ValueError as err:
            raise ValueError('invalid action distribution for state {}: {}'.format(
                s, self.pi_matrix[s])) from err

    def get_prob(self, s, a):
        """Return the probability of action ``a`` in state ``s``."""
        return self.pi_matrix[s][a]

    def _sample_actions(self, states):
        """Return (state indices, sampled action indices) for one-hot states."""
        st = np.argmax(states, axis = -1) # get index of one-hot
        cu = self.pi_matrix[st].cumsum(axis = -1)
        # inverse-CDF sampling: first action whose cumulative mass exceeds u
        uni = np.random.rand(*np.shape(st), 1)
        acts = (uni < cu).argmax(axis = -1)
        return st, acts

    def batch_sample(self, states):
        """Sample one action per row of a 2-D one-hot ``states`` batch.

        Returns a ``(batch, num_actions)`` one-hot array of sampled actions.
        """
        batch_size = states.shape[0]
        _, acts = self._sample_actions(states)
        act_feats = np.zeros((batch_size, self.num_actions))
        act_feats[np.arange(batch_size), acts] = 1.
        return act_feats

    def sample_sa_features(self, states):
        """Sample actions and return one-hot state-action features.

        ``states`` is a one-hot batch; each output row has a single 1 at
        index ``num_actions * s + a``. A 1-D ``states`` yields a 1-D result.
        """
        st, acts = self._sample_actions(states)
        flat = np.atleast_1d(self.num_actions * st + acts)
        feats = np.zeros((flat.shape[0], self.num_states * self.num_actions))
        feats[np.arange(flat.shape[0]), flat] = 1.
        return feats[0] if np.ndim(states) == 1 else feats

class RandomMDPPolicy:
    """Tabular policy for MDPs exposing ``num_states`` / ``num_actions``.

    ``policy_type`` selects how ``pi_matrix`` (``num_states x num_actions``)
    is built:

    * ``'stochastic'``: each row drawn from a flat Dirichlet.
    * ``'deterministic'``: one uniformly chosen action per state.
    * ``'uniform'``: equal probability for every action.
    * ``'mix'``: ``(1 - mix_ratio) * other_pi.pi_matrix + mix_ratio * uniform``.
    * ``'left'``: 0.9 / 0.1 on actions 0 / 1 in every state.
    * ``'half'``: 0.9 / 0.1 in the first half of the states, 0.1 / 0.9 in
      the rest.

    Any other value leaves the table all zeros. When
    ``mdp.use_terminal_state`` is set, the last row is zeroed, so no action
    can be sampled in that state.
    """

    def __init__(self, mdp, policy_type, mix_ratio = 1., other_pi = None):
        self.mdp = mdp
        self.policy_type = policy_type
        self.pi_matrix = np.zeros((mdp.num_states, mdp.num_actions))
        self.num_states = mdp.num_states
        self.num_actions = mdp.num_actions
        self.mix_ratio = mix_ratio
        self.other_pi = other_pi
        self._load_pi()

    def _load_pi(self):
        """Fill ``self.pi_matrix`` according to ``self.policy_type``."""
        shape = (self.num_states, self.num_actions)
        if self.policy_type == 'stochastic':
            self.pi_matrix = np.random.dirichlet(np.ones(self.num_actions), size=self.num_states)
        elif self.policy_type == 'deterministic':
            self.pi_matrix = np.zeros(shape)
            chosen = np.random.randint(self.num_actions, size=self.num_states)
            np.put_along_axis(self.pi_matrix, chosen[:, None], values=1., axis=1)
        elif self.policy_type == 'uniform':
            self.pi_matrix = np.full(shape, 1. / self.num_actions)
        elif self.policy_type == 'mix':
            uniform = np.full(shape, 1. / self.num_actions)
            # The blend is a new array, so other_pi's table is not modified.
            self.pi_matrix = ((1. - self.mix_ratio) * self.other_pi.pi_matrix
                              + self.mix_ratio * uniform)
        elif self.policy_type == 'left':
            self.pi_matrix = np.zeros(shape)
            self.pi_matrix[:, 0] = 0.9
            self.pi_matrix[:, 1] = 0.1
        elif self.policy_type == 'half':
            half = self.num_states // 2
            self.pi_matrix = np.zeros(shape)
            self.pi_matrix[:half, 0] = 0.9
            self.pi_matrix[:half, 1] = 0.1
            self.pi_matrix[half:, 0] = 0.1
            self.pi_matrix[half:, 1] = 0.9

        if self.mdp.use_terminal_state:
            self.pi_matrix[-1, :] = 0

    def __call__(self, s):
        a = self._get_action(s)
        return a

    def _get_action(self, s):
        """Sample an action index for state index ``s``.

        Raises:
            ValueError: if row ``s`` of ``pi_matrix`` is not a valid
                probability distribution (e.g. the zeroed terminal row).
        """
        try:
            return np.random.choice(np.arange(self.num_actions), p = self.pi_matrix[s])
        except ValueError as err:
            raise ValueError('invalid action distribution for state {}: {}'.format(
                s, self.pi_matrix[s])) from err

    def get_prob(self, s, a):
        """Return the probability of action ``a`` in state ``s``."""
        return self.pi_matrix[s][a]

    def _sample_actions(self, states):
        """Return (state indices, sampled action indices) for one-hot states."""
        st = np.argmax(states, axis = -1) # get index of one-hot
        cu = self.pi_matrix[st].cumsum(axis = -1)
        # inverse-CDF sampling: first action whose cumulative mass exceeds u
        uni = np.random.rand(*np.shape(st), 1)
        acts = (uni < cu).argmax(axis = -1)
        return st, acts

    def batch_sample(self, states):
        """Sample one action per row of a 2-D one-hot ``states`` batch.

        Returns a ``(batch, num_actions)`` one-hot array of sampled actions.
        """
        batch_size = states.shape[0]
        _, acts = self._sample_actions(states)
        # use the cached num_actions, like the rest of this class
        act_feats = np.zeros((batch_size, self.num_actions))
        act_feats[np.arange(batch_size), acts] = 1.
        return act_feats

    def sample_sa_features(self, states):
        """Sample actions and return one-hot state-action features.

        ``states`` is a one-hot vector (1-D) or batch (2-D). Each output row
        has a single 1 at index ``num_actions * s + a``; a 1-D input yields
        a 1-D output.
        """
        st, acts = self._sample_actions(states)
        flat = np.atleast_1d(self.num_actions * st + acts)
        feats = np.zeros((flat.shape[0], self.num_states * self.num_actions))
        feats[np.arange(flat.shape[0]), flat] = 1.
        return feats[0] if states.ndim == 1 else feats

class FourroomsMDPPolicy:
    """Tabular policy wrapping a caller-supplied ``pi_matrix``.

    ``pi_matrix[s][a]`` is the probability of action ``a`` in state ``s``;
    the table is stored as given (not copied or validated). State and action
    counts are read from ``mdp.n_state`` / ``mdp.n_action``.
    """

    def __init__(self, mdp, pi_matrix):
        self.mdp = mdp
        self.pi_matrix = pi_matrix
        self.num_states = mdp.n_state
        self.num_actions = mdp.n_action

    def __call__(self, s):
        a = self._get_action(s)
        return a

    def _get_action(self, s):
        """Sample an action index for state index ``s``.

        Raises:
            ValueError: if row ``s`` of ``pi_matrix`` is not a valid
                probability distribution.
        """
        try:
            return np.random.choice(np.arange(self.num_actions), p = self.pi_matrix[s])
        except ValueError as err:
            raise ValueError('invalid action distribution for state {}: {}'.format(
                s, self.pi_matrix[s])) from err

    def get_prob(self, s, a):
        """Return the probability of action ``a`` in state ``s``."""
        return self.pi_matrix[s][a]

    def _sample_actions(self, states):
        """Return (state indices, sampled action indices) for one-hot states."""
        st = np.argmax(states, axis = -1) # get index of one-hot
        cu = self.pi_matrix[st].cumsum(axis = -1)
        # inverse-CDF sampling: first action whose cumulative mass exceeds u
        uni = np.random.rand(*np.shape(st), 1)
        acts = (uni < cu).argmax(axis = -1)
        return st, acts

    def batch_sample(self, states):
        """Sample one action per row of a 2-D one-hot ``states`` batch.

        Returns a ``(batch, num_actions)`` one-hot array of sampled actions.
        """
        batch_size = states.shape[0]
        _, acts = self._sample_actions(states)
        act_feats = np.zeros((batch_size, self.num_actions))
        act_feats[np.arange(batch_size), acts] = 1.
        return act_feats

    def sample_sa_features(self, states):
        """Sample actions and return one-hot state-action features.

        ``states`` is a one-hot vector (1-D) or batch (2-D). Each output row
        has a single 1 at index ``num_actions * s + a``; a 1-D input yields
        a 1-D output.
        """
        st, acts = self._sample_actions(states)
        flat = np.atleast_1d(self.num_actions * st + acts)
        feats = np.zeros((flat.shape[0], self.num_states * self.num_actions))
        feats[np.arange(flat.shape[0]), flat] = 1.
        return feats[0] if states.ndim == 1 else feats

class RoyMDPPolicy:
    """Deterministic policy that always selects action 0.

    State and action counts are read from ``mdp.num_states`` /
    ``mdp.num_actions``. No ``batch_sample`` method is defined for this
    policy.
    """

    def __init__(self, mdp):
        self.mdp = mdp
        self.pi_matrix = np.zeros((mdp.num_states, mdp.num_actions))
        self.num_states = mdp.num_states
        self.num_actions = mdp.num_actions
        self._load_pi()

    def _load_pi(self):
        """Put all probability mass on action 0 in every state."""
        self.pi_matrix[:, 0] = 1.

    def __call__(self, s):
        a = self._get_action(s)
        return a

    def _get_action(self, s):
        """Sample an action index for state index ``s``.

        Raises:
            ValueError: if row ``s`` of ``pi_matrix`` is not a valid
                probability distribution.
        """
        try:
            return np.random.choice(np.arange(self.num_actions), p = self.pi_matrix[s])
        except ValueError as err:
            raise ValueError('invalid action distribution for state {}: {}'.format(
                s, self.pi_matrix[s])) from err

    def get_prob(self, s, a):
        """Return the probability of action ``a`` in state ``s``."""
        return self.pi_matrix[s][a]

    def sample_sa_features(self, states):
        """Return ``states`` unchanged (no action is sampled or appended)."""
        return states

class BairdsPolicy:
    """Two-action tabular policy with 'eval', 'beh' and 'mix' variants.

    * ``'eval'``: always action 0.
    * ``'beh'``: action 0 with probability 1/7, action 1 with 6/7.
    * ``'mix'``: ``(1 - mix_ratio) * eval + mix_ratio * beh``.

    Any other ``policy_type`` leaves ``pi_matrix`` all zeros. The MDP must
    expose ``num_states``, ``num_actions`` and ``get_sa_feat(s, a)``. No
    ``batch_sample`` method is defined for this policy.
    """

    def __init__(self, mdp, policy_type, mix_ratio = 0):
        self.mdp = mdp
        self.pi_matrix = np.zeros((mdp.num_states, mdp.num_actions))
        self.num_states = mdp.num_states
        self.num_actions = mdp.num_actions
        self.policy_type = policy_type
        self.mix_ratio = mix_ratio
        self._load_pi()

    def _load_pi(self):
        """Fill ``self.pi_matrix`` according to ``self.policy_type``."""
        if self.policy_type == 'eval':
            self.pi_matrix[:, 0] = 1.
        elif self.policy_type == 'beh':
            self.pi_matrix[:, 0] = 1 / 7
            self.pi_matrix[:, 1] = 6 / 7
        elif self.policy_type == 'mix':
            eval_matrix = np.zeros((self.num_states, self.num_actions))
            eval_matrix[:, 0] = 1.

            beh_matrix = np.zeros((self.num_states, self.num_actions))
            beh_matrix[:, 0] = 1 / 7
            beh_matrix[:, 1] = 6 / 7
            self.pi_matrix = (1. - self.mix_ratio) * eval_matrix + self.mix_ratio * beh_matrix

    def __call__(self, s):
        a = self._get_action(s)
        return a

    def _get_action(self, s):
        """Sample an action index for state index ``s``.

        Raises:
            ValueError: if row ``s`` of ``pi_matrix`` is not a valid
                probability distribution.
        """
        try:
            return np.random.choice(np.arange(self.num_actions), p = self.pi_matrix[s])
        except ValueError as err:
            raise ValueError('invalid action distribution for state {}: {}'.format(
                s, self.pi_matrix[s])) from err

    def get_prob(self, s, a):
        """Return the probability of action ``a`` in state ``s``."""
        return self.pi_matrix[s][a]

    def sample_sa_features(self, states):
        """Sample actions and return ``mdp.get_sa_feat(s, a)`` per state.

        The state index is taken as the argmax over the last axis of
        ``states`` (1-D for a single state, 2-D for a batch). Accepts a
        numpy array or a torch tensor; a tensor input yields a tensor
        output built with ``torch.Tensor``. A 1-D input yields the feature
        of that single state rather than a batch of one.
        """
        is_tensor = torch.is_tensor(states)
        if is_tensor:
            # detach/cpu so tensors that track gradients or live on an
            # accelerator can be converted as well
            states = states.detach().cpu().numpy()
        scalar_state = states.ndim == 1

        st = np.argmax(states, axis = -1) # state index per row
        cu = self.pi_matrix[st].cumsum(axis = -1)
        # inverse-CDF sampling: first action whose cumulative mass exceeds u
        uni = np.random.rand(*np.shape(st), 1)
        acts = (uni < cu).argmax(axis = -1)

        feats = np.array([
            self.mdp.get_sa_feat(s, a)
            for s, a in zip(np.atleast_1d(st), np.atleast_1d(acts))
        ])
        if scalar_state:
            feats = feats[0]
        if is_tensor:
            feats = torch.Tensor(feats)
        return feats

class TaxiPolicy:
    """Tabular policy wrapping a caller-supplied ``pi_matrix``.

    ``pi_matrix[s][a]`` is the probability of action ``a`` in state ``s``;
    the table is stored as given (not copied or validated). State and action
    counts are read from ``mdp.n_state`` / ``mdp.n_action``.
    """

    def __init__(self, mdp, pi_matrix):
        self.mdp = mdp
        self.pi_matrix = pi_matrix
        self.num_states = mdp.n_state
        self.num_actions = mdp.n_action

    def __call__(self, s):
        a = self._get_action(s)
        return a

    def _get_action(self, s):
        """Sample an action index for state index ``s``.

        Raises:
            ValueError: if row ``s`` of ``pi_matrix`` is not a valid
                probability distribution.
        """
        try:
            return np.random.choice(np.arange(self.num_actions), p = self.pi_matrix[s])
        except ValueError as err:
            raise ValueError('invalid action distribution for state {}: {}'.format(
                s, self.pi_matrix[s])) from err

    def get_prob(self, s, a):
        """Return the probability of action ``a`` in state ``s``."""
        return self.pi_matrix[s][a]

    def _sample_actions(self, states):
        """Return (state indices, sampled action indices) for one-hot states."""
        st = np.argmax(states, axis = -1) # get index of one-hot
        cu = self.pi_matrix[st].cumsum(axis = -1)
        # inverse-CDF sampling: first action whose cumulative mass exceeds u
        uni = np.random.rand(*np.shape(st), 1)
        acts = (uni < cu).argmax(axis = -1)
        return st, acts

    def batch_sample(self, states):
        """Sample one action per row of a 2-D one-hot ``states`` batch.

        Returns a ``(batch, num_actions)`` one-hot array of sampled actions.
        """
        batch_size = states.shape[0]
        _, acts = self._sample_actions(states)
        act_feats = np.zeros((batch_size, self.num_actions))
        act_feats[np.arange(batch_size), acts] = 1.
        return act_feats

    def sample_sa_features(self, states):
        """Sample actions and return one-hot state-action features.

        ``states`` is a one-hot vector (1-D) or batch (2-D). Each output row
        has a single 1 at index ``num_actions * s + a``; a 1-D input yields
        a 1-D output. The rows are filled in place, so no
        ``(sa_dim x sa_dim)`` identity matrix is materialised.
        """
        st, acts = self._sample_actions(states)
        flat = np.atleast_1d(self.num_actions * st + acts)
        feats = np.zeros((flat.shape[0], self.num_states * self.num_actions))
        feats[np.arange(flat.shape[0]), flat] = 1.
        return feats[0] if states.ndim == 1 else feats

class ToyMDPPolicy:
    """Hand-written tabular policy for the toy MDP.

    States 1 through 6 always take action 0. The action distribution of
    state 0 is selected by ``pi_num``:

    * ``0``: ``[0.7, 0.1, 0.1, 0.1]``
    * ``1``: ``[0.25, 0.25, 0.25, 0.25]``
    * ``2``: ``[0.02, 0.02, 0.48, 0.48]``

    Any other ``pi_num`` leaves state 0 with an all-zero row, so sampling an
    action there raises ``ValueError``. The table is indexed up to state 6
    and action 3, so the MDP must have at least 7 states and 4 actions.
    """

    # state-0 action probabilities, keyed by pi_num
    _STATE0_PROBS = {
        0: (0.7, 0.1, 0.1, 0.1),
        1: (0.25, 0.25, 0.25, 0.25),
        2: (0.02, 0.02, 0.48, 0.48),
    }

    def __init__(self, mdp, pi_num):
        self.mdp = mdp
        self.pi_num = pi_num
        self.pi_matrix = np.zeros((mdp.n_state, mdp.n_action))
        self.num_states = mdp.n_state
        self.num_actions = mdp.n_action
        self._load_pi()

    def _load_pi(self):
        """Fill ``self.pi_matrix`` from the fixed table described above."""
        for s in range(1, 7):
            self.pi_matrix[s][0] = 1.

        for a, prob in enumerate(self._STATE0_PROBS.get(self.pi_num, ())):
            self.pi_matrix[0][a] = prob

    def __call__(self, s):
        a = self._get_action(s)
        return a

    def _get_action(self, s):
        """Sample an action index for state index ``s``.

        Raises:
            ValueError: if row ``s`` of ``pi_matrix`` is not a valid
                probability distribution.
        """
        try:
            return np.random.choice(np.arange(self.num_actions), p = self.pi_matrix[s])
        except ValueError as err:
            raise ValueError('invalid action distribution for state {}: {}'.format(
                s, self.pi_matrix[s])) from err

    def get_prob(self, s, a):
        """Return the probability of action ``a`` in state ``s``."""
        return self.pi_matrix[s][a]

    def _sample_actions(self, states):
        """Return (state indices, sampled action indices) for one-hot states."""
        st = np.argmax(states, axis = -1) # get index of one-hot
        cu = self.pi_matrix[st].cumsum(axis = -1)
        # inverse-CDF sampling: first action whose cumulative mass exceeds u
        uni = np.random.rand(*np.shape(st), 1)
        acts = (uni < cu).argmax(axis = -1)
        return st, acts

    def batch_sample(self, states):
        """Sample one action per row of a 2-D one-hot ``states`` batch.

        Returns a ``(batch, num_actions)`` one-hot array of sampled actions.
        """
        batch_size = states.shape[0]
        _, acts = self._sample_actions(states)
        act_feats = np.zeros((batch_size, self.num_actions))
        act_feats[np.arange(batch_size), acts] = 1.
        return act_feats

    def sample_sa_features(self, states):
        """Sample actions and return one-hot state-action features.

        ``states`` is a one-hot batch; each output row has a single 1 at
        index ``num_actions * s + a``. A 1-D ``states`` yields a 1-D result.
        """
        st, acts = self._sample_actions(states)
        flat = np.atleast_1d(self.num_actions * st + acts)
        feats = np.zeros((flat.shape[0], self.num_states * self.num_actions))
        feats[np.arange(flat.shape[0]), flat] = 1.
        return feats[0] if np.ndim(states) == 1 else feats

class PWC:
    """Piecewise-constant state-action encoder keyed on truncated Q-values.

    Q-values are cast to ``int``; every distinct integer value defines one
    cluster. Each state-action pair (flat index ``n_action * s + a``) is
    mapped to its cluster id, and each cluster owns one ``phi_outdim``-sized
    vector. Calling the object on one-hot state-action rows returns the
    vector of each row's cluster.

    Args:
        mdp: environment exposing ``n_state`` and ``n_action``.
        pi_q_values: ``(n_state, n_action)`` numpy array of Q-values.
        phi_outdim: dimension of the cluster vectors.
        eps: with ``eps == 0`` all clusters share one random vector;
            otherwise cluster ``k >= 1`` uses ``k * eps`` times the vector
            of cluster 0. Also added to the norm in ``_gram_schmidt``.
    """

    def __init__(self, mdp, pi_q_values, phi_outdim, eps = 0):
        self.mdp = mdp
        self.pi_q_values = pi_q_values.astype(int)
        self.unq_q_vals = np.unique(self.pi_q_values)
        self.num_clusters = len(self.unq_q_vals)
        self.q_to_cluster_id = {value: index for index, value in enumerate(self.unq_q_vals)}
        print (self.q_to_cluster_id )
        self.phi_outdim = phi_outdim

        self.n_state = mdp.n_state
        self.n_action = mdp.n_action
        self.eps = eps
        self.groups = self._generate_groups()
        self.vectors = self._generate_vectors()
        # diagnostic output: Gram matrix of the cluster vectors and its mean
        gram = np.matmul(self.vectors, self.vectors.T)
        print (gram)
        print (np.mean(gram))

    def _generate_groups(self):
        """Return the cluster id of every state-action pair, flat-indexed."""
        groups = np.zeros((self.n_state * self.n_action))
        for s in range(self.n_state):
            for a in range(self.n_action):
                groups[self.n_action * s + a] = self.q_to_cluster_id[self.pi_q_values[s][a]]
        return groups.astype(int)

    def _gram_schmidt(self, A):
        """
        Orthonormalize the columns of A with the Gram-Schmidt process.

        Each column has its projections onto the already-computed columns
        removed one after another and is then divided by its norm plus
        ``self.eps``. The input array is left untouched.

        Parameters:
        - A: numpy array, shape (m, n); the n columns are the vectors.

        Returns:
        - Q: numpy array, shape (m, n), with the orthogonalized columns.
        """
        m, n = A.shape
        Q = np.zeros((m, n))

        for j in range(n):
            # work on a copy so the caller's matrix is not modified
            v = np.array(A[:, j], dtype = float)
            for i in range(j):
                # project the running residual, not the original column
                v -= np.dot(Q[:, i], v) * Q[:, i]
            Q[:, j] = v / (np.linalg.norm(v) + self.eps)

        return Q

    def _gram_schmidt_random(self, dim, num_vectors):
        """
        Orthogonalize the columns of a random (dim, num_vectors) matrix.

        Parameters:
        - dim: int, dimension of the vectors.
        - num_vectors: int, number of vectors (columns) to generate.

        Returns:
        - Q: numpy array, shape (dim, num_vectors), see ``_gram_schmidt``.
        """
        return self._gram_schmidt(np.random.rand(dim, num_vectors))

    def _other_generate_vectors(self, d, N):
        """Return N d-dimensional vectors that are multiples of one random vector.

        Row 0 is random. With ``eps == 0`` every other row equals row 0,
        otherwise row ``k`` is ``k * eps`` times row 0. The rank of the
        resulting matrix is printed.
        """
        base = np.random.rand(d)
        vectors = [base]
        for k in range(N - 1):
            if self.eps == 0:
                vectors.append(base)
            else:
                vectors.append(((k + 1) * self.eps) * base)

        vecs = np.array(vectors)

        rank = np.linalg.matrix_rank(vecs)
        print ('rank ', rank)
        return vecs

    def _generate_vectors(self):
        """Build one ``phi_outdim``-dimensional vector per cluster."""
        return self._other_generate_vectors(self.phi_outdim, self.num_clusters)

    def __call__(self, state_actions):
        """Map one-hot state-action rows to the vectors of their clusters.

        ``state_actions`` is a 2-D array with one one-hot row per pair; the
        result has one ``phi_outdim``-sized row per input row.
        """
        num_sa = np.argmax(state_actions, axis = 1)
        group_ids = self.groups[num_sa]
        return self.vectors[group_ids]

    def train(self, val):
        """No-op that returns ``self``; ``val`` is ignored."""
        return self


