import yaml
import ott
from ott.geometry import pointcloud
from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn as ott_sinkhorn
import jax.numpy as jnp
import pickle
import jax

def load_config(config_path):
    """Read the YAML file at ``config_path`` and return its parsed content.

    Parsing goes through ``yaml.safe_load``; whatever object that call
    produces for the document is returned unchanged.
    """
    with open(config_path, 'r') as stream:
        return yaml.safe_load(stream)

def safe_log(x: jnp.ndarray) -> jnp.ndarray:
    """Element-wise natural log with the argument floored at 1e-10.

    Entries at or below the floor (including zeros and negatives) are
    replaced by 1e-10 before the log, so the result stays finite for them.
    """
    floored = jnp.maximum(x, 1e-10)
    return jnp.log(floored)

def gen_kl(p: jnp.ndarray, q: jnp.ndarray) -> float:
    """Generalized Kullback-Leibler divergence of ``p`` from ``q``.

    Evaluates ``vdot(p, log p - log q) + sum(q) - sum(p)``. The trailing
    ``sum(q) - sum(p)`` term cancels whenever both inputs have the same total
    mass. Logs are taken with ``safe_log`` so zero entries stay finite.
    """
    log_ratio = safe_log(p) - safe_log(q)
    weighted_log_ratio = jnp.vdot(p, log_ratio)
    # Same left-to-right accumulation as a single chained expression.
    return weighted_log_ratio + jnp.sum(q) - jnp.sum(p)

def gen_js(p: jnp.ndarray, q: jnp.ndarray, c: float = 0.5) -> float:
    """Scaled symmetric generalized KL: ``c * (gen_kl(p, q) + gen_kl(q, p))``.

    NOTE(review): despite the name, no mixture ``(p + q) / 2`` is formed
    here, so this is a symmetrized KL rather than the textbook
    Jensen-Shannon divergence -- confirm this is the intended quantity.
    """
    forward = gen_kl(p, q)
    backward = gen_kl(q, p)
    return c * (forward + backward)

class StandardScaler:
    """Stateless helper for column-wise standardization.

    Nothing is stored on the instance: ``fit`` returns the statistics to the
    caller, and ``transform`` receives them back as explicit arguments.
    """

    def fit(self, x: jnp.ndarray) -> tuple:
        """Return ``(mean, std)`` of ``x``, both reduced over axis 0."""
        mean = x.mean(axis=0)
        std = x.std(axis=0)
        return mean, std

    def transform(self, x: jnp.ndarray, mean: jnp.ndarray, std: jnp.ndarray) -> jnp.ndarray:
        """Return ``(x - mean) / (std + 1e-8)``.

        The 1e-8 added to ``std`` avoids a division by zero for columns whose
        standard deviation is exactly zero.
        """
        return (x - mean) / (std + 1e-8)

def compute_policy_expert_divergence(policy_states, policy_actions, rng, sample_expert_transitions, disc_input, mean=None, std=None):
    """
    Compute a Sinkhorn (entropy-regularized OT) cost between policy and expert samples.

    Args:
        policy_states: array whose last axis holds state features; all leading
            axes are flattened into one sample axis.
        policy_actions: array whose last axis holds action features; flattened
            the same way as ``policy_states``.
        rng: JAX PRNG key. It is consumed locally; the advanced key is not
            returned to the caller.
        sample_expert_transitions: callable ``(batch_size, rng) -> (batch, rng)``.
            ``batch`` must expose ``unnorm_obs`` and, for ``disc_input == 'sa'``,
            ``action``.
        disc_input: ``'s'`` or ``'ss'`` to compare states only, ``'sa'`` to
            compare concatenated state-action vectors.
            NOTE(review): ``'ss'`` is treated exactly like ``'s'`` here (no
            next-state is used) -- confirm that is intended.
        mean, std: currently ignored; the normalization that used them is
            commented out below.

    Returns:
        The ``reg_ot_cost`` of the Sinkhorn solution, after
        ``jax.block_until_ready``.

    Raises:
        ValueError: if ``disc_input`` is not one of ``'s'``, ``'ss'``, ``'sa'``.
    """
    # Fail fast with a clear message; otherwise an unknown mode would leave
    # policy_x / expert_x unbound and surface later as an UnboundLocalError.
    valid_inputs = ('s', 'ss', 'sa')
    if disc_input not in valid_inputs:
        raise ValueError(
            f"Unsupported disc_input {disc_input!r}; expected one of {valid_inputs}."
        )

    policy_states_flat = policy_states.reshape(-1, policy_states.shape[-1])
    policy_actions_flat = policy_actions.reshape(-1, policy_actions.shape[-1])

    # Cap the comparison at 500 samples per side.
    n_policy = policy_states_flat.shape[0]
    batch_size = min(n_policy, 500)
    expert_batch, rng = sample_expert_transitions(batch_size, rng)

    # Draw batch_size policy rows when there are more than that available.
    # randint draws indices independently, so rows may repeat.
    if n_policy > batch_size:
        rng, _rng = jax.random.split(rng)
        indices = jax.random.randint(_rng, (batch_size,), 0, n_policy)
        policy_states_sample = policy_states_flat[indices]
        policy_actions_sample = policy_actions_flat[indices]
    else:
        policy_states_sample = policy_states_flat[:batch_size]
        policy_actions_sample = policy_actions_flat[:batch_size]

    if disc_input == 'sa':
        policy_x = jax.lax.stop_gradient(jnp.concatenate([policy_states_sample, policy_actions_sample], axis=-1))
        expert_x = jax.lax.stop_gradient(jnp.concatenate([expert_batch.unnorm_obs, expert_batch.action], axis=-1))
    else:  # 's' or 'ss'
        policy_x = policy_states_sample
        expert_x = expert_batch.unnorm_obs

    # Normalization is currently disabled, so `mean` / `std` are unused:
    # mean = mean[:policy_x.shape[-1]]
    # std = std[:policy_x.shape[-1]]
    # policy_x = (policy_x - mean) / (std + 1e-8)
    # expert_x = (expert_x - mean) / (std + 1e-8)

    # Uniform weights over each sample set.
    policy_weights = jnp.ones(policy_x.shape[0]) / policy_x.shape[0]
    expert_weights = jnp.ones(expert_x.shape[0]) / expert_x.shape[0]

    # Build the point-cloud geometry and Sinkhorn solver with library defaults.
    geom = pointcloud.PointCloud(policy_x, expert_x)
    solver = ott_sinkhorn.Sinkhorn()

    # Create and solve the OT problem
    ot_problem = linear_problem.LinearProblem(geom, policy_weights, expert_weights)
    ot_result = solver(ot_problem)

    # reg_ot_cost is the regularized OT objective reported by the solver; it
    # is used as the divergence value (not an exact Wasserstein distance).
    divergence = ot_result.reg_ot_cost

    divergence = jax.block_until_ready(divergence)

    return divergence

def compute_policy_expert_divergence_minatar(policy_states, policy_actions, rng, sample_expert_transitions, n_actions, disc_input='sa'):
    """
    Compute a Sinkhorn (entropy-regularized OT) cost between policy and expert
    samples, with actions one-hot encoded over ``n_actions`` classes.

    Args:
        policy_states: array whose last axis holds state features; all leading
            axes are flattened into one sample axis.
        policy_actions: action indices; one-hot encoded with ``n_actions``
            classes and flattened to ``(-1, n_actions)``.
        rng: JAX PRNG key. It is consumed locally; the advanced key is not
            returned to the caller.
        sample_expert_transitions: callable ``(batch_size, rng) -> (batch, rng)``.
            ``batch`` must expose ``unnorm_obs`` and, for ``disc_input == 'sa'``,
            ``action`` (one-hot encoded here as well).
        n_actions: number of classes used for the one-hot encoding.
        disc_input: ``'s'`` or ``'ss'`` to compare states only, ``'sa'``
            (default) to compare concatenated state / one-hot-action vectors.
            NOTE(review): ``'ss'`` is treated exactly like ``'s'`` here (no
            next-state is used) -- confirm that is intended.

    Returns:
        The ``reg_ot_cost`` of the Sinkhorn solution, after
        ``jax.block_until_ready``.

    Raises:
        ValueError: if ``disc_input`` is not one of ``'s'``, ``'ss'``, ``'sa'``.
    """
    # Fail fast with a clear message; otherwise an unknown mode would leave
    # policy_x / expert_x unbound and surface later as an UnboundLocalError.
    valid_inputs = ('s', 'ss', 'sa')
    if disc_input not in valid_inputs:
        raise ValueError(
            f"Unsupported disc_input {disc_input!r}; expected one of {valid_inputs}."
        )

    policy_states_flat = policy_states.reshape(-1, policy_states.shape[-1])
    policy_actions_flat = jax.nn.one_hot(policy_actions, n_actions).reshape(-1, n_actions)

    # Cap the comparison at 500 samples per side.
    n_policy = policy_states_flat.shape[0]
    batch_size = min(n_policy, 500)
    expert_batch, rng = sample_expert_transitions(batch_size, rng)

    # Draw batch_size policy rows when there are more than that available.
    # randint draws indices independently, so rows may repeat.
    if n_policy > batch_size:
        rng, _rng = jax.random.split(rng)
        indices = jax.random.randint(_rng, (batch_size,), 0, n_policy)
        policy_states_sample = policy_states_flat[indices]
        policy_actions_sample = policy_actions_flat[indices]
    else:
        policy_states_sample = policy_states_flat[:batch_size]
        policy_actions_sample = policy_actions_flat[:batch_size]

    if disc_input == 'sa':
        policy_x = jax.lax.stop_gradient(jnp.concatenate([policy_states_sample, policy_actions_sample], axis=-1))
        expert_x = jax.lax.stop_gradient(jnp.concatenate([expert_batch.unnorm_obs, jax.nn.one_hot(expert_batch.action, n_actions)], axis=-1))
    else:  # 's' or 'ss'
        policy_x = policy_states_sample
        expert_x = expert_batch.unnorm_obs

    # Standardization is currently disabled:
    # scaler = StandardScaler()
    # mean, std = scaler.fit(expert_x)
    # policy_x = scaler.transform(policy_x, mean, std)
    # expert_x = scaler.transform(expert_x, mean, std)

    # Uniform weights over each sample set.
    policy_weights = jnp.ones(policy_x.shape[0]) / policy_x.shape[0]
    expert_weights = jnp.ones(expert_x.shape[0]) / expert_x.shape[0]

    # Build the point-cloud geometry and Sinkhorn solver with library defaults.
    geom = pointcloud.PointCloud(policy_x, expert_x)
    solver = ott_sinkhorn.Sinkhorn()

    # Create and solve the OT problem
    ot_problem = linear_problem.LinearProblem(geom, policy_weights, expert_weights)
    ot_result = solver(ot_problem)

    # reg_ot_cost is the regularized OT objective reported by the solver; it
    # is used as the divergence value (not an exact Wasserstein distance).
    divergence = ot_result.reg_ot_cost

    divergence = jax.block_until_ready(divergence)

    return divergence

"""
Expert Replay Buffer
"""
def make_expert_transitions(config):
    """Load pickled expert transitions and return a random batch sampler.

    Reads ``config['BACKEND']``, ``config['ENV_NAME']``,
    ``config['N_EXPERT_TRAJS']`` and ``config['SUB_SAMPLE_RATE']``.

    The loaded object is treated as a pytree whose leaves are sliced on
    axis 1 by ``N_EXPERT_TRAJS`` and strided on axis 0 by ``SUB_SAMPLE_RATE``
    (presumably axis 0 is time and axis 1 is trajectory -- confirm against the
    code that writes the pickle). It must expose ``info``, ``unnorm_obs``,
    ``unnorm_next_obs``, ``action`` and ``done``.

    Returns:
        ``sample_expert_transitions(batch_size, rng) -> (batch, rng)``, which
        draws ``batch_size`` rows from the flattened transitions.
    """
    # load expert transitions
    expert_root = "./experts_new" if config['BACKEND'] == 'positional' else "./experts_mjx"
    transitions_path = f"{expert_root}/{config['ENV_NAME']}/transitions_sorted.pkl"
    # NOTE(security): unpickling can execute arbitrary code; only load files
    # from a trusted source.
    # The context manager closes the file handle even if unpickling fails.
    with open(transitions_path, 'rb') as f:
        expert_transitions = pickle.load(f)

    # Keep only the first N_EXPERT_TRAJS entries along axis 1 of every leaf.
    n_trajs = config["N_EXPERT_TRAJS"]
    expert_transitions = jax.tree_util.tree_map(
        lambda x: x[:, :n_trajs], expert_transitions
    )

    rate = config["SUB_SAMPLE_RATE"]
    if rate > 1:
        # Keep every `rate`-th row along axis 0 and always append the final
        # row, so index -1 still refers to the original last row.
        # NOTE(review): the final row is appended unconditionally, so it is
        # duplicated when the strided slice already ends on the last index.
        expert_transitions = jax.tree_util.tree_map(
            lambda x: jnp.concatenate([x[::rate], x[-1:]]),
            expert_transitions,
        )

    final_returns = expert_transitions.info['returned_episode_returns'][-1, :]
    print('Returns:', final_returns)
    print(f'Mean: {jnp.mean(final_returns)} Std: {jnp.std(final_returns)}')
    print('Obs shape:', expert_transitions.unnorm_obs.shape)
    print('Next Obs shape:', expert_transitions.unnorm_next_obs.shape)
    print('Action shape:', expert_transitions.action.shape)

    # flatten transitions: swap the first two axes, then merge them into one
    expert_transitions = jax.tree_util.tree_map(
        lambda x: x.swapaxes(1, 0).reshape((-1,) + x.shape[2:]), expert_transitions
    )
    expert_size = expert_transitions.done.shape[0]

    def sample_expert_transitions(batch_size, rng):
        """Return ``(batch, rng)`` with ``batch_size`` randomly indexed rows.

        Indices are drawn independently with ``jax.random.randint``, so the
        same row can appear more than once in a batch.
        """
        rng, _rng = jax.random.split(rng)

        indices = jax.random.randint(_rng, (batch_size,), 0, expert_size)

        # Apply the same indices to every leaf so fields stay aligned.
        batch = jax.tree_util.tree_map(lambda x: x[indices], expert_transitions)

        return batch, rng

    return sample_expert_transitions

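# NOTE: everything from here to the end of the file is disabled
# (commented-out) code for alternative divergence types ("l2", "mmd", "kl",
# "js"). It refers to `config`, `policy_sa` and `expert_sa`, which are not
# defined at module scope here, so it cannot be re-enabled as-is.
# if config['DIVERGENCE_TYPE'] == "l2":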
    #     # L2 distance
    #     divergence = jnp.mean(jnp.sum(jnp.square(policy_sa - expert_sa), axis=-1))
                
    # elif config['DIVERGENCE_TYPE'] == "mmd":
    #     # Maximum Mean Discrepancy with RBF kernel
    #     def rbf_kernel(x, y, sigma=1.0):
    #         """RBF kernel for MMD calculation"""
    #         norm_sq = jnp.sum(jnp.square(x[:, None, :] - y[None, :, :]), axis=-1)
    #         return jnp.exp(-norm_sq / (2 * sigma**2))
        
    #     def laplacian_kernel(x, y, sigma=0.2):
    #         """Laplacian kernel for MMD calculation"""
    #         l1_distance = jnp.sum(jnp.abs(x[:, None, :] - y[None, :, :]), axis=-1)
    #         return jnp.exp(-l1_distance / (2.0 * sigma))

    #     # Compute kernel matrices
    #     K_XX = rbf_kernel(policy_sa, policy_sa)
    #     K_XY = rbf_kernel(policy_sa, expert_sa)
    #     K_YY = rbf_kernel(expert_sa, expert_sa)
        
    #     # Compute MMD
    #     n = policy_sa.shape[0]
    #     m = expert_sa.shape[0]
    #     mmd = (
    #         jnp.sum(K_XX) / (n * (n - 1)) 
    #         + jnp.sum(K_YY) / (m * (m - 1)) 
    #         - 2 * jnp.sum(K_XY) / (n * m)
    #     )
    #     divergence = jnp.sqrt(mmd + 1e-6)
                
    # elif config['DIVERGENCE_TYPE'] == "kl":
    #     # KL divergence using kernel density estimation

    #     def gaussian_kde(points, query_points, bandwidth=None):
    #         n_samples, n_dim = points.shape
            
    #         # Scott's rule: h = n^(-1/(d+4))
    #         bandwidth = n_samples ** (-1.0 / (n_dim + 4))
            
    #         # Compute pairwise squared distances
    #         diff = query_points[:, None, :] - points[None, :, :]  # (n_query, n_samples, n_dim)
    #         sq_dists = jnp.sum(diff ** 2, axis=-1)  # (n_query, n_samples)
            
    #         # Compute Gaussian kernel
    #         kernel_values = jnp.exp(-sq_dists / (2 * bandwidth ** 2))  # (n_query, n_samples)
            
    #         # Normalize by sample size
    #         density = jnp.sum(kernel_values, axis=1) / (n_samples * ((2 * jnp.pi) ** (n_dim / 2)) * bandwidth ** n_dim)
            
    #         return density
        
    #     # Estimate densities at a shared set of query points
    #     # For efficiency, we'll use the union of both sample sets as query points
    #     query_points = jnp.concatenate([policy_sa, expert_sa], axis=0)
        
    #     # To prevent numerical issues, we'll normalize the data
    #     data_mean = jnp.mean(query_points, axis=0, keepdims=True)
    #     data_std = jnp.std(query_points, axis=0, keepdims=True) + 1e-8
        
    #     # Normalize data
    #     policy_sa_norm = (policy_sa - data_mean) / data_std
    #     expert_sa_norm = (expert_sa - data_mean) / data_std
    #     query_points_norm = (query_points - data_mean) / data_std
        
    #     # Estimate densities
    #     policy_density = gaussian_kde(policy_sa_norm, query_points_norm)
    #     expert_density = gaussian_kde(expert_sa_norm, query_points_norm)
        
    #     # Normalize to ensure densities integrate to 1
    #     policy_density = policy_density / jnp.sum(policy_density)
    #     expert_density = expert_density / jnp.sum(expert_density)
        
    #     # Compute KL divergence
    #     kl_div = gen_kl(policy_density, expert_density)
        
    #     divergence = kl_div
    
    # elif config['DIVERGENCE_TYPE'] == "js":
    #     # JS divergence using kernel density estimation, similar to KL
        
    #     def gaussian_kde(points, query_points, bandwidth=None):
    #         """Gaussian kernel density estimation."""
    #         n_samples, n_dim = points.shape
            
    #         # Scott's rule: h = n^(-1/(d+4))
    #         bandwidth = n_samples ** (-1.0 / (n_dim + 4))
            
    #         # Compute pairwise squared distances
    #         diff = query_points[:, None, :] - points[None, :, :]  # (n_query, n_samples, n_dim)
    #         sq_dists = jnp.sum(diff ** 2, axis=-1)  # (n_query, n_samples)
            
    #         # Compute Gaussian kernel
    #         kernel_values = jnp.exp(-sq_dists / (2 * bandwidth ** 2))  # (n_query, n_samples)
            
    #         # Normalize by sample size
    #         density = jnp.sum(kernel_values, axis=1) / (n_samples * ((2 * jnp.pi) ** (n_dim / 2)) * bandwidth ** n_dim)
            
    #         return density
        
    #     # Use both sets of points as query points
    #     query_points = jnp.concatenate([policy_sa, expert_sa], axis=0)
        
    #     # Normalize the data
    #     data_mean = jnp.mean(query_points, axis=0, keepdims=True)
    #     data_std = jnp.std(query_points, axis=0, keepdims=True) + 1e-8
        
    #     policy_sa_norm = (policy_sa - data_mean) / data_std
    #     expert_sa_norm = (expert_sa - data_mean) / data_std
    #     query_points_norm = (query_points - data_mean) / data_std
        
    #     # Estimate densities
    #     policy_density = gaussian_kde(policy_sa_norm, query_points_norm)
    #     expert_density = gaussian_kde(expert_sa_norm, query_points_norm)
        
    #     # Normalize densities
    #     policy_density = policy_density / jnp.sum(policy_density)
    #     expert_density = expert_density / jnp.sum(expert_density)
        
    #     # Compute JS divergence directly using the provided function
    #     js_div = gen_js(policy_density, expert_density)
        
    #     divergence = js_div