import numpy as np
import torch

from all2.logging import DummyLogger
from all2.agents import Agent

import torch.nn.functional as F
import seaborn as sns
import matplotlib.pyplot as plt

class JointDQN(Agent):
    """
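    A Joint DQN agent whose Q-values for several actions are modelled jointly as a
    Gaussian mixture (mixing coefficients, means, covariances).

    Args:
        q_dist (QDist): Approximation of the Q distribution; must expose ``n_mixture`` and ``n_actions``.
        replay_buffer (ReplayBuffer): The experience replay buffer.
        n_actions (int): Number of discrete actions.
        discount_factor (float): Discount factor for future rewards.
        eps (float): Small constant for numerical stability (losses, direction sampling, PD projection).
        exploration (float): The probability of choosing random actions. Actions are always random
            while the replay buffer holds fewer than ``replay_start_size`` experiences.
        minibatch_size (int): The number of experiences to sample in each training update.
        replay_start_size (int): Number of experiences in replay buffer when training begins.
        update_frequency (int): Number of timesteps per training update.
        logger: Logger instance; defaults to a ``DummyLogger``.
        n_update_actions (int): Number of actions that should be explored in training the agent
            (defaults to ``n_actions``).
        rand_actions (bool): If True, greedy action selection returns the best action followed by
            ``n_update_actions - 1`` distinct, randomly chosen other actions.
        update_q (float): Probability that a training update uses ``n_update_actions`` actions
            jointly instead of a single action.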
    """

    def __init__(
        self,
        q_dist,
        replay_buffer,
        n_actions,
        discount_factor=0.99,
        eps=1e-7,
        exploration=0.02,
        minibatch_size=32,
        replay_start_size=5000,
        update_frequency=4,
        logger=None,
        n_update_actions=None,
        rand_actions=False,
        update_q=0.
    ):
        """Store hyperparameters and initialise the agent's rollout state.

        ``logger`` defaults to a fresh ``DummyLogger`` for each agent. A
        ``DummyLogger()`` default argument would be evaluated once and shared
        by every agent.

        ``n_update_actions`` falls back to ``n_actions`` when it is not given.
        """
        self.q_dist = q_dist
        # Number of mixture components, read from the Q-distribution approximator.
        self.n_mixture = q_dist.n_mixture
        self.replay_buffer = replay_buffer
        self.logger = logger if logger is not None else DummyLogger()
        self.eps = eps
        self.exploration = exploration
        self.replay_start_size = replay_start_size
        self.update_frequency = update_frequency
        self.minibatch_size = minibatch_size
        self.discount_factor = discount_factor
        self.n_update_actions = n_update_actions if n_update_actions else n_actions
        self.n_actions = n_actions
        # Previous state/action, stored in the replay buffer on the next act() call.
        self._state = None
        self._action = None
        self._frames_seen = 0
        # Use the resolved value, not the raw argument: the argument may be None,
        # and _should_train adds this to _frames_seen before the first update.
        self.cur_n_update_actions = self.n_update_actions
        self.rand_actions = rand_actions
        self.update_q = update_q


    def act(self, state):
        """Record the last transition, maybe train, and choose the next action.

        Returns what ``_choose_action`` returns. That is ``n_update_actions``
        distinct actions when exploring or when ``rand_actions`` is set, and
        otherwise the single greedy action.
        """
        # Close the previous transition with the newly observed state.
        self.replay_buffer.store(self._state, self._action, state)
        self._train()

        # List inputs carry the state as their first element.
        current = state[0] if type(state) is list else state
        self._state = current
        self._action = self._choose_action(current)
        return self._action


    def eval(self, state):
        """Return the action(s) that ``_best_actions`` selects for ``state``.

        ``_best_actions`` takes the mixture as three separate arguments, so the
        result of ``q_dist.eval`` is unpacked here. Passing it as one argument
        raised a TypeError (missing ``means`` and ``covariances``).

        NOTE(review): assumes ``q_dist.eval`` returns ``(mixing, means, covariances)``,
        as ``q_dist.no_grad`` does in ``_choose_action``. Confirm.
        """
        mixing, means, covariances = self.q_dist.eval(state)
        return self._best_actions(mixing, means, covariances)
    

    def _choose_action(self, state):
        """Pick action(s) for ``state``: random while exploring, model-based otherwise.

        Exploration returns a tensor of ``n_update_actions`` distinct action
        indices. Otherwise the result of ``_best_actions`` is returned, detached
        and moved to the CPU.
        """
        if not self._should_explore():
            mixing, means, covariances = self.q_dist.no_grad(state)
            return self._best_actions(mixing, means, covariances).detach().cpu()

        # Sampling must be WITHOUT replacement. Drawing with replacement
        # (np.random.randint) produced duplicate actions and broke training.
        candidates = np.arange(self.q_dist.n_actions)
        picked = np.random.choice(candidates, self.n_update_actions, replace=False)
        return torch.from_numpy(picked)

    def _should_explore(self):
        """Return True when a random action should be taken.

        Always explores until the replay buffer reaches ``replay_start_size``.
        After that, explores with probability ``exploration``.
        """
        if len(self.replay_buffer) < self.replay_start_size:
            return True
        return np.random.rand() < self.exploration
    
    def _batched_block_select(self, tensor, indices):
        """
        Select KxK blocks from each NxMxM matrix in the batch based on meshgrid-like indexing.

        For every batch element b and matrix n, returns the sub-matrix
        ``tensor[b, n][indices[b]][:, indices[b]]``.

        Args:
        tensor (torch.Tensor): A tensor of shape [B, N, M, M].
        indices (torch.Tensor): An integer tensor of shape [B, K], containing the indices to
                                extract from the MxM matrices for both row and column indices.

        Returns:
        torch.Tensor: A tensor of shape [B, N, K, K] containing the selected KxK blocks.
        """
        B, N = tensor.shape[0], tensor.shape[1]
        device = tensor.device
        # Keep every index tensor on the data's device. Previously the batch/N grids
        # were always created on the CPU, even for GPU tensors.
        indices = indices.to(device)

        # Advanced indexing broadcasts these four index tensors to [B, N, K, K],
        # so no explicit expand is needed.
        batch_idx = torch.arange(B, device=device).view(B, 1, 1, 1)
        n_idx = torch.arange(N, device=device).view(1, N, 1, 1)
        row_idx = indices[:, None, :, None]  # [B, 1, K, 1]
        col_idx = indices[:, None, None, :]  # [B, 1, 1, K]

        return tensor[batch_idx, n_idx, row_idx, col_idx]
    

    def argmax_best_actions(self, mixing, means):
        """Return the action with the highest mixture-expected value.

        The expected value sums ``mixing * means`` over the component axis
        (second-to-last axis of ``means``). The argmax is then taken over the
        last (action) axis.
        """
        expected = (mixing.unsqueeze(-1) * means).sum(dim=-2)
        return expected.argmax(dim=-1)


    def _best_actions(self, mixing_coeffs, means, covariances, target_calculation=False):
        """Rank actions by mixture-expected value and return the chosen ones.

        Greedy mode is used when ``target_calculation`` is set or ``rand_actions``
        is False. It returns the argmax action, with a trailing axis only when
        ``target_calculation`` is set. Otherwise the result holds the best action
        followed by ``n_update_actions - 1`` distinct actions drawn at random
        from the remaining ones. ``covariances`` is not used.
        """
        # Expected value of each action: sum over components (axis 1) of pi_k * mu_k.
        expected_q = (mixing_coeffs.unsqueeze(-1) * means).sum(dim=1)

        if target_calculation or not self.rand_actions:
            best = expected_q.argmax(dim=-1, keepdim=True)
            return best if target_calculation else best.squeeze()

        device = means.device
        ranked = expected_q.topk(k=self.n_actions, dim=-1).indices
        n_rows = ranked.shape[0]
        n_extra = self.n_update_actions - 1
        # Positions into `ranked`. Position 0 is the greedy action; the extras are
        # distinct positions drawn uniformly from 1..n_actions-1.
        extra_positions = torch.empty(n_rows, n_extra, device=device, dtype=torch.int64)
        for row in range(n_rows):
            extra_positions[row] = (torch.randperm(self.n_actions - 1, device=device) + 1)[:n_extra]
        greedy_position = torch.zeros((n_rows, 1), device=device, dtype=torch.int64)
        positions = torch.cat((greedy_position, extra_positions), dim=1)

        return ranked.gather(1, positions).squeeze()
    

    def _kl_divergence_gaussians(self, means_first, covariances_first, means_second, covariances_second):
        """KL(N(means_first, covariances_first) || N(means_second, covariances_second)).

        Uses 0.5 * (log|S2|/|S1| + tr(S2^-1 S1) - d + (m2 - m1)^T S2^-1 (m2 - m1)).
        ``d`` is the Gaussian dimension, read from the last axis of ``means_first``.
        It previously came from ``cur_n_update_actions``, which need not match
        the data. Means are ``[..., d]`` and covariances ``[..., d, d]``, with any
        matching leading axes. Returns a tensor of shape ``[..., 1]``.
        """
        dim = means_first.shape[-1]
        log_det_ratio = (
            torch.linalg.slogdet(covariances_second)[1] - torch.linalg.slogdet(covariances_first)[1]
        ).unsqueeze(-1)

        # Linear solves instead of an explicit inverse, for numerical stability.
        solved_cov = torch.linalg.solve(covariances_second, covariances_first)
        trace_term = torch.diagonal(solved_cov, dim1=-2, dim2=-1).sum(-1).unsqueeze(-1) - dim

        diff = means_second - means_first
        solved_diff = torch.linalg.solve(covariances_second, diff.unsqueeze(-1)).squeeze(-1)
        mahalanobis = (diff * solved_diff).sum(-1).unsqueeze(-1)

        return 0.5 * (log_det_ratio + trace_term + mahalanobis)
    

    def _cross_kl_divergence_1d_gaussians(self, means1, means2, vars1, vars2):
        """Pairwise 1-D KL(N(means1_i, vars1_i) || N(means2_j, vars2_j)).

        Component axis 1 of the first input is paired with axis 1 of the second
        by broadcasting. Inputs shaped ``[B, N, S, 1]`` give ``[B, N, N, S]``,
        because only the trailing singleton axis is squeezed.
        ``self.eps`` guards the ``vars2`` denominators and the log argument.
        """
        m1 = means1.unsqueeze(2)
        m2 = means2.unsqueeze(1)
        v1 = vars1.unsqueeze(2)
        v2 = vars2.unsqueeze(1)

        ratio_term = v1 / (v2 + self.eps)
        mean_term = (m2 - m1) ** 2 / (v2 + self.eps)
        log_term = torch.log((v2 / v1) + self.eps)

        kl = 0.5 * (ratio_term + mean_term - 1 + log_term)
        return kl.squeeze(-1)


    def _kl_upper_bound_mog(self, mixing_first, mixing_second, means_first, means_second, covariances_first, covariances_second):
        """Matched-component upper bound on KL between two Gaussian mixtures.

        Pairs component i of the first mixture with component i of the second:
        ``sum_i w1_i * (log(w1_i / w2_i) + KL(N1_i || N2_i))``. Mixing tensors
        are ``[..., N]``; the result sums over the last axis.
        """
        # squeeze(-1) removes only the trailing [..., 1] axis. A bare squeeze()
        # also removed the component axis when N == 1. The result then broadcast
        # against mixing_first ([B, 1]) to [B, B] and mixed different batch rows.
        kl_of_gaussians = self._kl_divergence_gaussians(
            means_first, covariances_first, means_second, covariances_second
        ).squeeze(-1)
        log_ratio = torch.log(mixing_first / mixing_second)

        return (mixing_first * (log_ratio + kl_of_gaussians)).sum(-1)
    

    def _1d_mog_variational_reverse_kl(self, mixing1, mixing2, means1, means2, vars1, vars2, num_iter=0):
        # Variational approximation of the KL divergence from mixture 1
        # (mixing1, means1, vars1) to mixture 2, computed on 1-D projections.
        # The pairwise component KL table is averaged over the last (slice) axis.
        # NOTE(review): the alternating phi/psi updates look like the variational
        # bound of Hershey & Olsen (2007). Confirm before relying on that reading.
        # With the default num_iter=0 the loop does not run, phi == psi, and the
        # result is sum_ij mixing1_i * mixing2_j * KL_ij.
        # Returns one value per batch element (sums over axes 1 and 2).
        # Assumes inputs shaped [B, N, S, 1], as built by _mog_sliced_reverse_kl,
        # so kl_table is [B, N, N]. TODO confirm for other callers.
        kl_table = self._cross_kl_divergence_1d_gaussians(means1, means2, vars1, vars2).mean(-1)

        # Convert mixing weights to log-space
        log_pi = torch.log(mixing1)         # presumably [B, N]
        log_omega = torch.log(mixing2)   # presumably [B, N]

        # Initialization in log-space: psi_ij = pi_i * omega_j
        log_psi = log_pi.unsqueeze(2) + log_omega.unsqueeze(1)  
        # phi is initially the same as psi:
        log_phi = log_psi.clone()
        
        # Iterative updates in log space (skipped entirely when num_iter == 0)
        for i in range(num_iter):
            log_psi_prev = log_psi.clone()
            # Update log_psi: psi_ij = omega_j * phi_ij / sum_i phi_ij
            log_psi = log_phi + log_omega.unsqueeze(1) - torch.logsumexp(log_phi, dim=1, keepdim=True)
            
            # Update log_phi: phi_ij = pi_i * psi_ij e^{-KL_ij} / sum_j psi_ij e^{-KL_ij}
            log_phi = log_pi.unsqueeze(2) + log_psi - kl_table - torch.logsumexp(log_psi - kl_table, dim=2, keepdim=True)
            
            # Check convergence: largest absolute change of psi (in real space) across the whole batch.
            err = torch.max(torch.abs(torch.exp(log_psi) - torch.exp(log_psi_prev)))
            if err < 1e-8:
                break

        # Convert phi back from log-space and evaluate sum phi*KL + sum phi*log(phi/psi).
        phi = torch.exp(log_phi)
        return torch.sum(phi * kl_table, dim=(1, 2)) + torch.sum(phi * (log_phi - log_psi), dim=(1, 2)) 


    def _mog_sliced_reverse_kl(self, mixing1, mixing2, means1, means2, covs1, covs2, probs, num_samples=128):
        """Sliced variational KL between two Gaussian mixtures.

        Both mixtures are projected onto ``num_samples`` directions drawn by
        ``_weighted_sample_unit_vectors(probs, ...)``. The 1-D projections are
        then passed to ``_1d_mog_variational_reverse_kl``.
        """
        directions = self._weighted_sample_unit_vectors(probs, num_samples)

        def project_means(m):
            # [B, N, D] x [B, S, D] -> [B, N, S, 1]
            return torch.einsum('bkn, bsn -> bks', m, directions).unsqueeze(-1)

        def project_covs(c):
            # d^T C d for every component and direction -> [B, N, S, 1]
            return torch.einsum('bsm, bknm, bsn -> bks', directions, c, directions).unsqueeze(-1)

        return self._1d_mog_variational_reverse_kl(
            mixing1, mixing2,
            project_means(means1), project_means(means2),
            project_covs(covs1), project_covs(covs2),
        )
    

    def _1d_mog_mw2(self, mixing1, mixing2, projected_means1, projected_means2, projected_covs1, projected_covs2):
        """Mixture-W2-style cost between two projected 1-D Gaussian mixtures.

        The pairwise squared-W2 table is averaged over slices. An entropic
        (Sinkhorn) transport plan is solved between the mixing weights, and the
        transported cost is returned, summed over both component axes.
        """
        cost = self._cross_w2_1d_gaussians(
            projected_means1, projected_means2, projected_covs1, projected_covs2
        ).mean(-1)
        plan = self._log_sinkhorn_transport(cost, mixing1, mixing2)
        return (plan * cost).sum(dim=(-1, -2))


    def _mog_sliced_mw2(self, mixing1, mixing2, means1, means2, covs1, covs2, num_samples=128):
        """Sliced mixture-W2 cost between two Gaussian mixtures.

        Projects both mixtures onto ``num_samples`` random unit directions per
        batch element and evaluates ``_1d_mog_mw2`` on the projections.
        ``_uniform_sample_unit_vector`` takes ``(B, n, device)``. The previous
        call passed only ``num_samples`` and raised a TypeError.
        """
        directions = self._uniform_sample_unit_vector(means1.shape[0], num_samples, means1.device)

        projected_means1 = torch.einsum('bkn, bsn -> bks', means1, directions).unsqueeze(-1) # [B, N, S, 1]
        projected_means2 = torch.einsum('bkn, bsn -> bks', means2, directions).unsqueeze(-1) # [B, N, S, 1]

        projected_covs1 = torch.einsum('bsm, bknm, bsn -> bks', directions, covs1, directions).unsqueeze(-1)  # [B, N, S, 1]
        projected_covs2 = torch.einsum('bsm, bknm, bsn -> bks', directions, covs2, directions).unsqueeze(-1)  # [B, N, S, 1]

        return self._1d_mog_mw2(mixing1, mixing2, projected_means1, projected_means2, projected_covs1, projected_covs2)
    

    def _cross_cramer_1d_gaussians(self, means1, means2, vars1, vars2):
        """Pairwise E|X_i - Y_j| for 1-D Gaussians X_i ~ N(means1_i, vars1_i), Y_j ~ N(means2_j, vars2_j).

        Axis 1 of the first input is crossed with axis 1 of the second by
        broadcasting. The closed form is ``v * (z * (2*Phi(z) - 1) + 2*phi(z))``
        with ``v = sqrt(var_i + var_j + eps)`` and ``z = (mu_i - mu_j) / v``.
        Here ``z * (2*Phi(z) - 1) = 2*gelu(z) - z`` and
        ``2*phi(z) = sqrt(2/pi) * exp(-z^2/2)``. NaN/inf entries are replaced by 0.
        """
        gap = means1.unsqueeze(2) - means2.unsqueeze(1)
        scale = torch.sqrt(vars1.unsqueeze(2) + vars2.unsqueeze(1) + self.eps)

        z = gap / scale
        shape_factor = 2 * F.gelu(z) - z + 0.797884560802865356 * torch.exp(-z**2/2)
        expected_abs = (scale * shape_factor).squeeze(-1)
        return torch.nan_to_num(expected_abs, nan=0.0, posinf=0.0, neginf=0.0)


    def _1d_mog_cramer(self, mixing1, mixing2, projected_means1, projected_means2, projected_covs1, projected_covs2):
        """Energy (Cramér) distance 2E|X-Y| - E|X-X'| - E|Y-Y'| between two 1-D mixtures.

        Each pairwise table from ``_cross_cramer_1d_gaussians`` is averaged over
        its last (slice) axis, weighted by the outer product of the mixing
        weights, and summed over both component axes.
        """
        def slice_averaged(ma, mb, va, vb):
            return self._cross_cramer_1d_gaussians(ma, mb, va, vb).mean(-1)

        cross_table = slice_averaged(projected_means1, projected_means2, projected_covs1, projected_covs2)
        within1_table = slice_averaged(projected_means1, projected_means1, projected_covs1, projected_covs1)
        within2_table = slice_averaged(projected_means2, projected_means2, projected_covs2, projected_covs2)

        cross_weights = 2 * mixing1.unsqueeze(2) * mixing2.unsqueeze(1)
        within1_weights = mixing1.unsqueeze(2) * mixing1.unsqueeze(1)
        within2_weights = mixing2.unsqueeze(2) * mixing2.unsqueeze(1)

        cross, within1, within2 = (
            torch.sum(weights * table, dim=(-1, -2))
            for weights, table in (
                (cross_weights, cross_table),
                (within1_weights, within1_table),
                (within2_weights, within2_table),
            )
        )
        return cross - within1 - within2

    
    def _mog_sliced_cramer(self, mixing1, mixing2, means1, means2, covs1, covs2, num_samples=128):
        """Sliced energy (Cramér) distance between two Gaussian mixtures.

        Projects both mixtures onto ``num_samples`` random unit directions per
        batch element and evaluates ``_1d_mog_cramer`` on the projections.
        """
        directions = self._uniform_sample_unit_vector(means1.shape[0], num_samples, means1.device)

        def onto_means(m):
            # [B, N, D] x [B, S, D] -> [B, N, S, 1]
            return torch.einsum('bkn, bsn -> bks', m, directions).unsqueeze(-1)

        def onto_covs(c):
            # d^T C d per component and direction -> [B, N, S, 1]
            return torch.einsum('bsm, bknm, bsn -> bks', directions, c, directions).unsqueeze(-1)

        return self._1d_mog_cramer(
            mixing1, mixing2,
            onto_means(means1), onto_means(means2),
            onto_covs(covs1), onto_covs(covs2),
        )

    def _train(self):
        """Run one gradient update when ``_should_train`` allows it.

        Each update first chooses between a joint update over
        ``n_update_actions`` actions and a single-action update. It then fits
        the online mixture to the Bellman target with the sliced Cramér loss,
        weighted by the replay weights, and finally refreshes the replay
        priorities with the per-sample loss.
        """
        if not self._should_train():
            return

        self._update_cur_n_update_actions()
        n_dims = self.cur_n_update_actions

        # Shapes (from the sampler's contract): states B x ..., actions B x M,
        # rewards B x M, next_states B x M x ...
        states, actions, rewards, next_states, weights = self.replay_buffer.sample(
            self.minibatch_size
        )

        # One predicted mixture per batch element, over the first n_dims stored actions.
        mixing, means, covs = self.q_dist.model(states, actions[:, :n_dims])
        tgt_mixing, tgt_means, tgt_covs = self._compute_target_marginal_dist(next_states, rewards)

        # Alternative losses kept for experimentation. `probs` is not defined in
        # this method, and _mog_sliced_mw2 takes `num_samples` rather than `probs`.
        #   sliced reverse KL: self._mog_sliced_reverse_kl(mixing, tgt_mixing, means, tgt_means, covs, tgt_covs, probs)
        #   sliced forward KL: self._mog_sliced_reverse_kl(tgt_mixing, mixing, tgt_means, means, tgt_covs, covs, probs)
        #   sliced Wasserstein: self._mog_sliced_mw2(mixing, tgt_mixing, means, tgt_means, covs, tgt_covs)
        if n_dims == 1:
            per_sample = self._1d_mog_cramer(
                mixing, tgt_mixing, means.unsqueeze(-1), tgt_means.unsqueeze(-1), covs, tgt_covs
            )
        else:
            per_sample = self._mog_sliced_cramer(mixing, tgt_mixing, means, tgt_means, covs, tgt_covs)

        self.q_dist.reinforce((weights * per_sample).mean())
        self.replay_buffer.update_priorities(per_sample.detach())


    def _should_train(self):
        """Advance the frame counter and report whether an update is due.

        Side effect: ``_frames_seen`` grows by ``cur_n_update_actions`` on every call.
        """
        self._frames_seen += self.cur_n_update_actions
        past_warmup = self._frames_seen > self.replay_start_size
        return past_warmup and self._frames_seen % self.update_frequency == 0
    
    
    def _update_cur_n_update_actions(self):
        """With probability ``update_q`` use all ``n_update_actions`` jointly; otherwise use 1."""
        joint_update = np.random.rand() < self.update_q
        self.cur_n_update_actions = self.n_update_actions if joint_update else 1


    def _compute_target_marginal_dist(self, states, rewards):
        """Build the Bellman target mixture over the first ``cur_n_update_actions`` next states.

        For each next state m < D (D = ``cur_n_update_actions``), the action a_m
        is chosen with the online model (``q_dist.model.argmax``). Each mixture
        component's mean at a_m is then read from the target network
        (``q_dist._target``). The D per-state 1-D mixtures are combined into
        their product mixture, with C = n_mixture ** D components in D dimensions.
        Finally the affine map r + gamma * X is applied.

        Fixes relative to the previous version:
          * product weights are now aligned with product means (they were transposed);
          * any D >= 1 is supported (D > 2 previously produced mismatched shapes);
          * the component covariance is Cov(r + gamma X) = gamma^2 Cov(X). It was
            rr^T + gamma(r mu^T + mu r^T) + gamma^2(S + mu mu^T) - mu mu^T, which is
            not a covariance and can be negative.

        Args:
            states: next states (``next_states`` from the replay sample).
            rewards: rewards; the first D columns are used.

        Returns:
            (mixing [B, C], means [B, C, D], covariances [B, C, D, D]).
        """
        with torch.no_grad():
            n_dims = self.cur_n_update_actions
            n_mix = self.n_mixture

            target_mixing, target_means, target_covariances = self.q_dist._target(states)
            argmax_mixing, argmax_means = self.q_dist.model.argmax(states)
            argmax_actions = self.argmax_best_actions(argmax_mixing, argmax_means)[:, :n_dims]  # [B, D]

            # Mean of every component at the chosen action of each next state.
            idx_expanded = argmax_actions.unsqueeze(2).expand(-1, -1, n_mix).unsqueeze(-1)
            picked_means = torch.gather(target_means, dim=3, index=idx_expanded).squeeze(-1)  # [B, D, K]
            batch_size = picked_means.shape[0]

            # Product mixture. Component (i_0, ..., i_{D-1}) is flattened row-major
            # (i_0 slowest). Its mean is (v_0[i_0], ..., v_{D-1}[i_{D-1}]) and its
            # weight is w_0[i_0] * ... * w_{D-1}[i_{D-1}].
            joint_means = picked_means[:, 0, :].unsqueeze(-1)  # [B, K, 1]
            joint_mixing = target_mixing[:, 0, :]  # [B, K]
            for m in range(1, n_dims):
                n_prev = joint_means.shape[1]
                prev = joint_means.unsqueeze(2).expand(-1, -1, n_mix, -1)  # [B, n_prev, K, m]
                new = picked_means[:, m, :].unsqueeze(1).unsqueeze(-1).expand(-1, n_prev, -1, -1)  # [B, n_prev, K, 1]
                joint_means = torch.cat((prev, new), dim=-1).reshape(batch_size, n_prev * n_mix, m + 1)
                joint_mixing = (joint_mixing.unsqueeze(2) * target_mixing[:, m, :].unsqueeze(1)).flatten(start_dim=1)

            # Covariance among the chosen actions, shared by every product component.
            selected_cov = self._batched_block_select(target_covariances, argmax_actions)
            selected_cov = selected_cov.expand(-1, joint_means.shape[1], -1, -1)  # [B, C, D, D]

            rewards_vector = rewards[:, :n_dims].unsqueeze(1)  # [B, 1, D]
            shifted_means = self.discount_factor * joint_means + rewards_vector
            # r is an observed constant, so each component's covariance is only scaled.
            shifted_covs = self._project_to_pd(self.discount_factor ** 2 * selected_cov)

            return joint_mixing, shifted_means, shifted_covs

    
    def _project_to_pd(self, matrix):
        """Rebuild the symmetric ``matrix`` with every non-positive eigenvalue replaced by ``self.eps``."""
        eigvals, eigvecs = torch.linalg.eigh(matrix)

        floor = torch.full_like(eigvals, self.eps)
        eigvals = torch.where(eigvals > 0, eigvals, floor)

        return eigvecs @ torch.diag_embed(eigvals) @ eigvecs.transpose(-1, -2)


    def _uniform_sample_unit_vector(self, B, n, device):
        """Draw ``n`` random unit directions in ``cur_n_update_actions`` dimensions per batch element.

        Returns a ``[B, n, cur_n_update_actions]`` tensor whose rows ``[b, s, :]``
        have unit norm. Normalised isotropic Gaussians are uniform on the sphere.
        """
        gaussians = torch.randn((B, n, self.cur_n_update_actions), device=device)
        # Normalise along the last (direction) axis. F.normalize defaults to
        # dim=1, which here is the sample axis and does not give unit vectors.
        return F.normalize(gaussians, dim=-1, eps=self.eps)


    def _weighted_sample_unit_vectors(self, probs, n, half_angle=torch.pi/4):
        """
        Sample unit vectors tilted away from probability-weighted basis directions.

        Given:
        - probs: tensor of shape [B, N] with each row a probability distribution (sums to 1)
        - n: number of unit vectors to sample per batch item
        - half_angle: maximum deviation from the chosen basis vector (in radians)
        Near-zero norms are detected with ``self.eps``.
        Returns:
        - A tensor of shape [B, n, N] where each vector is computed as:
                v = cos(theta)*e + sin(theta)*u,
            where:
                * e is a standard basis vector chosen by sampling from probs,
                * theta ~ Uniform(0, half_angle), and
                * u is a random unit vector orthogonal to e.
          When N == 1 no orthogonal direction exists, so e itself is returned.
          The old fallback reused e and produced vectors of norm > 1.
        """
        B, N = probs.shape
        device = probs.device

        # indices: [B, n] with values in 0,...,N-1.
        indices = torch.multinomial(probs, n, replacement=True)

        # Look up the chosen standard basis vector e for each sample.
        e_basis = torch.eye(N, device=device)  # shape: [N, N]
        e = e_basis[indices]  # shape: [B, n, N]

        if N == 1:
            return e

        # Sample angles theta uniformly from [0, half_angle]
        theta = torch.rand(B, n, device=device) * half_angle  # shape: [B, n]
        cos_theta = torch.cos(theta)
        sin_theta = torch.sin(theta)

        # Sample a random vector and remove its projection onto e.
        u_raw = torch.randn(B, n, N, device=device)
        proj = (u_raw * e).sum(dim=-1, keepdim=True) * e
        u = u_raw - proj

        # Replace (near-)zero vectors with the next basis vector, which is
        # orthogonal to e because N >= 2 here.
        norm_u = u.norm(dim=-1, keepdim=True)
        mask = norm_u < self.eps
        fallback = e_basis[(indices + 1) % N]  # shape: [B, n, N]
        u = torch.where(mask, fallback, u)

        norm_u = u.norm(dim=-1, keepdim=True)
        u = u / (norm_u + self.eps)

        return cos_theta.unsqueeze(-1) * e + sin_theta.unsqueeze(-1) * u
    

    def _multivariate_gaussian_cramer_distance(self, means1, means2, covs1, covs2, probs, num_samples=128):
        """Per-direction E|X - Y| between two single Gaussians projected onto weighted random directions.

        Means are squeezed at axis -2 and covariances at axis 1. This assumes
        shapes like [B, 1, D] and [B, 1, D, D]; TODO confirm against callers.
        Directions come from ``_weighted_sample_unit_vectors(probs, num_samples)``.
        """
        directions = self._weighted_sample_unit_vectors(probs, num_samples)  # [B, S, D]

        means1 = means1.squeeze(dim=-2)
        means2 = means2.squeeze(dim=-2)
        covs1 = covs1.squeeze(dim=1)
        covs2 = covs2.squeeze(dim=1)

        projected_means1 = torch.einsum('bi,bni->bn', means1, directions).unsqueeze(-1)  # [B, S, 1]
        projected_means2 = torch.einsum('bi,bni->bn', means2, directions).unsqueeze(-1)  # [B, S, 1]

        # Projected variance d^T C d needs distinct row/column subscripts. The
        # previous 'bkm,bmm,bkm->bk' repeated `m` inside the covariance operand,
        # which takes its diagonal and drops all off-diagonal (correlation) terms.
        projected_covs1 = torch.einsum('bkm,bmn,bkn->bk', directions, covs1, directions).unsqueeze(-1)  # [B, S, 1]
        projected_covs2 = torch.einsum('bkm,bmn,bkn->bk', directions, covs2, directions).unsqueeze(-1)  # [B, S, 1]

        return self._gaussian_cramer_distance(projected_means1, projected_means2, projected_covs1, projected_covs2)
    

    def _sliced_wasserstein(self, means1, means2, covs1, covs2, probs, num_samples=128):
        """Squared-W2 table between two single Gaussians projected onto weighted random directions.

        Means are squeezed at axis -2 and covariances at axis 1. This assumes
        shapes like [B, 1, D] and [B, 1, D, D]; TODO confirm against callers.
        """
        directions = self._weighted_sample_unit_vectors(probs, num_samples)  # [B, S, D]

        means1 = means1.squeeze(dim=-2)
        means2 = means2.squeeze(dim=-2)
        covs1 = covs1.squeeze(dim=1)
        covs2 = covs2.squeeze(dim=1)

        projected_means1 = torch.einsum('bi,bni->bn', means1, directions).unsqueeze(-1)  # [B, S, 1]
        projected_means2 = torch.einsum('bi,bni->bn', means2, directions).unsqueeze(-1)  # [B, S, 1]

        # d^T C d with distinct row/column subscripts. The previous
        # 'bkm,bmm,bkm->bk' used only the covariance diagonal.
        projected_covs1 = torch.einsum('bkm,bmn,bkn->bk', directions, covs1, directions).unsqueeze(-1)  # [B, S, 1]
        projected_covs2 = torch.einsum('bkm,bmn,bkn->bk', directions, covs2, directions).unsqueeze(-1)  # [B, S, 1]

        # NOTE(review): _cross_w2_1d_gaussians crosses axis 1 of its inputs. Here
        # that axis indexes directions, so the result is [B, S, S] and also
        # compares projections onto *different* directions. Kept to preserve the
        # return shape; confirm the intended semantics.
        return self._cross_w2_1d_gaussians(projected_means1, projected_means2, projected_covs1, projected_covs2)
    

    def _gaussian_cramer_distance(self, means1, means2, vars1, vars2):
        """Elementwise E|X - Y| for independent X ~ N(means1, vars1), Y ~ N(means2, vars2).

        Uses ``v * (z * (2*Phi(z) - 1) + 2*phi(z))`` with ``v = sqrt(vars1 + vars2)``
        and ``z = (means1 - means2) / v``. Here ``z * (2*Phi(z) - 1) = 2*gelu(z) - z``
        and ``2*phi(z) = sqrt(2/pi) * exp(-z^2/2)``. Only the trailing singleton
        axis is removed. A bare squeeze() also dropped the batch axis when B == 1.
        """
        diff = means1 - means2
        v = torch.sqrt(vars1 + vars2 + 1e-12)

        z = diff / v

        return (v * (2 * F.gelu(z) - z + 0.797884560802865356 * torch.exp(-z**2/2))).squeeze(-1)
    

    def _cross_w2_1d_gaussians(self, means1, means2, vars1, vars2):
        """Pairwise squared 2-Wasserstein distance between 1-D Gaussians.

        Axis 1 of the first input is crossed with axis 1 of the second by
        broadcasting, and the trailing singleton axis is removed. For 1-D
        Gaussians the value is (mu_i - mu_j)^2 + var_i + var_j - 2*sqrt(var_i*var_j).
        """
        m1, m2 = means1.unsqueeze(2), means2.unsqueeze(1)
        v1, v2 = vars1.unsqueeze(2), vars2.unsqueeze(1)

        squared_mean_gap = torch.pow(m1 - m2, 2)
        spread_gap = v1 + v2 - 2 * torch.sqrt(v1 * v2)
        return (squared_mean_gap + spread_gap).squeeze(-1)


    def _gaussian_wasserstein2_distance(self, means1, means2, covs1, covs2):
        """
        Calculate pairwise squared Wasserstein-2 distances between two batches of Gaussian mixtures.

        For each component pair this is
        ||m1 - m2||^2 + tr(S1) + tr(S2) - 2 tr((S1^{1/2} S2 S1^{1/2})^{1/2}).

        Args:
            means1: Tensor of shape [batch_size, n_components1, dim]
            means2: Tensor of shape [batch_size, n_components2, dim]
            covs1: Tensor of shape [batch_size, n_components1, dim, dim]
            covs2: Tensor of shape [batch_size, n_components2, dim, dim]

        Returns:
            Tensor of shape [batch_size, n_components1, n_components2] containing all pairwise
            squared Wasserstein-2 distances
        """
        # ||m1 - m2||^2 for every component pair: [batch_size, n_components1, n_components2]
        mean_term = torch.sum((means1.unsqueeze(2) - means2.unsqueeze(1)) ** 2, dim=-1)

        # tr(S1) + tr(S2) for every pair, computed without materialising
        # [batch_size, n1, n2, dim, dim] sums.
        trace1 = torch.diagonal(covs1, dim1=-2, dim2=-1).sum(-1)  # [batch_size, n_components1]
        trace2 = torch.diagonal(covs2, dim1=-2, dim2=-1).sum(-1)  # [batch_size, n_components2]
        cov_sum = trace1.unsqueeze(2) + trace2.unsqueeze(1)

        eigvals1, eigvecs1 = torch.linalg.eigh(covs1)
        # Ensure numerical stability by clipping eigenvalues
        eigvals1 = torch.clamp(eigvals1, min=self.eps)

        # S1^{1/2} = V diag(sqrt(lambda)) V^T. eigh returns eigenvectors as the
        # columns of V, so the scaling must broadcast over columns (axis -2).
        # The previous unsqueeze(-1) scaled rows, and because V V^T = I the
        # "square root" reduced to diag(sqrt(lambda)).
        sqrt_covs1 = torch.matmul(
            eigvecs1 * torch.sqrt(eigvals1).unsqueeze(-2),
            eigvecs1.transpose(-2, -1)
        )  # [batch_size, n_components1, dim, dim]
        sqrt_covs1 = sqrt_covs1.unsqueeze(2)  # [batch_size, n_components1, 1, dim, dim]

        inner_term = torch.matmul(
            torch.matmul(sqrt_covs1, covs2.unsqueeze(1)),
            sqrt_covs1
        )  # [batch_size, n_components1, n_components2, dim, dim]

        inner_eigvals = torch.clamp(torch.linalg.eigvalsh(inner_term), min=self.eps)
        sqrt_trace = 2 * torch.sqrt(inner_eigvals).sum(-1)  # [batch_size, n_components1, n_components2]

        return mean_term + cov_sum - sqrt_trace


    def _log_sinkhorn_transport(self, A, pi_1, pi_2, epsilon=0.1, max_iters=1024, threshold=1e-7):
        """
        Solves the entropic optimal transport problem with a log-space Sinkhorn algorithm.
        Supports batched inputs. Runs under ``torch.no_grad()``, so the returned
        plan carries no gradient.

        Args:
            A: Cost matrix (torch.Tensor) of shape (batch_size, n, m) or (n, m)
            pi_1: Target row marginals (torch.Tensor) of shape (batch_size, n) or (n,)
            pi_2: Target column marginals (torch.Tensor) of shape (batch_size, m) or (m,)
            epsilon: Regularization parameter (larger values = more stable but less accurate)
            max_iters: Maximum number of iterations
            threshold: Convergence threshold on the max change of exp(log_u) across the batch

        Returns:
            W: Transport matrix of shape (batch_size, n, m), or (n, m) for unbatched input.
               A batched input with batch_size == 1 keeps its batch axis; it was
               previously squeezed away.
        """
        with torch.no_grad():
            unbatched = A.dim() == 2
            if unbatched:
                A = A.unsqueeze(0)
                pi_1 = pi_1.unsqueeze(0)
                pi_2 = pi_2.unsqueeze(0)

            batch_size, n, m = A.shape

            log_pi_1 = torch.log(pi_1)  # (batch_size, n)
            log_pi_2 = torch.log(pi_2)  # (batch_size, m)

            # Kernel in log space, shifted by its per-batch maximum. The shift is
            # absorbed by the scaling vectors u and v.
            log_K = -A / epsilon
            log_K = log_K - log_K.amax(dim=(1, 2), keepdim=True)

            log_u = torch.zeros(batch_size, n, device=A.device, dtype=A.dtype)
            log_v = torch.zeros(batch_size, m, device=A.device, dtype=A.dtype)

            for _ in range(max_iters):
                log_u_prev = log_u  # log_u is rebound below, never mutated in place

                log_u = log_pi_1 - torch.logsumexp(log_K + log_v.unsqueeze(1), dim=2)
                log_v = log_pi_2 - torch.logsumexp(log_K.transpose(1, 2) + log_u.unsqueeze(1), dim=2)

                err = torch.max(torch.abs(torch.exp(log_u) - torch.exp(log_u_prev)))
                if err < threshold:
                    break

            log_W = log_u.unsqueeze(2) + log_K + log_v.unsqueeze(1)  # (batch_size, n, m)
            W = torch.exp(log_W)

            # Rescale the total mass to match pi_1. This does not enforce each
            # row/column marginal exactly.
            W = W / W.sum(dim=(1, 2), keepdim=True) * pi_1.sum(dim=1, keepdim=True).unsqueeze(2)

            if unbatched:
                W = W.squeeze(0)

            return W


class JointDQNTestAgent(Agent):
    """
    Evaluation agent that acts on a mixture-of-Gaussians Q distribution.

    Args:
        q_dist: Q-distribution approximator. ``q_dist.model(state)`` must return
            ``(mixing, means, covariances)``, and ``q_dist.n_actions`` is read when exploring.
        exploration (float): probability of returning a uniformly random action.
        efficient (bool): if True, take the greedy action only when the scaled effective
            rank of the mixture's correlation matrix lies strictly between 0.05 and 0.95;
            otherwise return action 0.
    """

    def __init__(self, q_dist, exploration=0.0, efficient=False):
        self.q_dist = q_dist
        self.exploration = exploration
        self.efficient = efficient
        self.total_energy = 0
        # Counts efficient-mode decisions that used the greedy action. It must
        # exist before _best_actions increments it; previously it never did.
        self.sensitive = 0
        torch.set_printoptions(precision=3, sci_mode=False)

    def covariance_to_correlation(self, cov):
        """Convert a 2-D covariance matrix to a correlation matrix.

        Entries where ``cov`` is exactly zero are set to 0, which also clears
        the NaNs from 0/0 divisions.
        """
        std = torch.sqrt(torch.diag(cov))
        outer_std = std[:, None] * std[None, :]
        corr = cov / outer_std
        corr[cov == 0] = 0
        return corr

    def effective_rank(self, A, eps=1e-12, return_scaled=False):
        """Entropy-based effective rank: exp(H(p)) with p = normalised singular values of ``A``.

        Singular values <= ``eps`` are discarded. With ``return_scaled`` the
        value is mapped to [0, 1] as (er - 1) / (r - 1), where r is the number
        of kept singular values. If none are kept, the result is 0.0 (scaled)
        or 1.0 (unscaled). _best_actions relies on this method, which was
        previously missing from this class.
        """
        s = torch.linalg.svdvals(A)
        s = s[s > eps]
        if s.numel() == 0:
            return 0.0 if return_scaled else 1.0

        p = s / s.sum()
        H = -torch.sum(p * torch.log(p + eps))
        er = float(torch.exp(H))

        if return_scaled:
            r = float(s.shape[0])
            s01 = (er - 1.0) / max(r - 1.0, eps)
            # clip to [0, 1] for numerical robustness
            return min(max(s01, 0.0), 1.0)
        return er

    def _best_actions(self, mixing_coeffs, means, covariances):
        """Return the action with the highest mixture-expected value, subject to ``efficient`` gating.

        Shapes assumed (TODO confirm): mixing [B, K], means [B, K, A], covariances [B, K, A, A].
        The efficient path calls ``torch.diag`` on the squeezed covariance, so it
        needs B == 1.
        """
        scores = torch.sum(mixing_coeffs.unsqueeze(-1) * means, dim=1)
        if not self.efficient:
            return torch.argmax(scores, dim=-1)

        # Mixture covariance by the law of total covariance:
        # sum_k pi_k (S_k + (mu_k - mu)(mu_k - mu)^T).
        overall_mean = torch.sum(mixing_coeffs.unsqueeze(-1) * means, dim=1, keepdim=True)
        centered = means - overall_mean
        op = torch.einsum("bij,bik->bijk", centered, centered)
        covsum = op + covariances
        overall_cov = torch.sum(mixing_coeffs.unsqueeze(-1).unsqueeze(-1) * covsum, dim=1).squeeze()

        corr = self.covariance_to_correlation(overall_cov)
        er = self.effective_rank(corr, return_scaled=True)
        if 0.05 < er < 0.95:
            self.sensitive += 1
            return torch.argmax(scores, dim=-1)
        # Same integer dtype (int64) as the argmax branch.
        return torch.tensor(0, dtype=torch.int64)

    def act(self, state):
        """Return a random action with probability ``exploration``, else the model-chosen action."""
        if np.random.rand() < self.exploration:
            return np.random.randint(0, self.q_dist.n_actions)

        mixing, means, covariances = self.q_dist.model(state)
        return self._best_actions(mixing, means, covariances).detach().cpu()


class JointDQNTestAgentMV(Agent):
    """Evaluation agent for a mixture-based joint Q model with diagnostics.

    Args:
        q_dist: Q-distribution wrapper. ``q_dist.model(state)`` must return
            ``(mixing_coeffs, means, covariances)``; ``q_dist.n_actions`` is
            used for random exploration.
        exploration (float): Probability of returning a uniformly random action.
        mv (bool): Stored flag; not read by any method of this class.
        efficient (bool): If True, only act greedily when the scaled effective
            rank of the mixture covariance lies strictly inside
            ``(precision, 1 - precision)``; otherwise return action 0.
        precision (float): Margin of the effective-rank gate.
        debug (bool): If True (default, matching earlier behaviour), print the
            scaled effective rank and write a correlation heatmap to
            ``heat.png`` on every greedy decision.
    """

    def __init__(self, q_dist, exploration=0.0, mv=False, efficient=False, precision=0.015,
                 debug=True):
        self.q_dist = q_dist
        self.exploration = exploration
        # Not read in this class; presumably intended as ``lam`` for
        # solve_qp_simplex -- verify against callers.
        self.lamb = 2.0
        self.mv = mv
        self.efficient = efficient
        self.precision = precision
        self.debug = debug
        torch.set_printoptions(precision=3, sci_mode=False)

    def covariance_to_correlation(self, cov):
        """Convert a 2-D covariance matrix into a correlation matrix.

        Entries whose correlation is undefined (zero covariance, or a zero /
        negative variance on either axis) are set to 0 instead of NaN/inf.
        """
        # Clamp guards against tiny negative variances from round-off.
        std = torch.sqrt(torch.clamp(torch.diag(cov), min=0))
        outer_std = std[:, None] * std[None, :]
        corr = cov / outer_std
        undefined = (cov == 0) | (outer_std == 0)
        return torch.where(undefined, torch.zeros_like(corr), corr)

    def _project_to_pd(self, matrix):
        """Project a symmetric matrix onto the positive-definite cone.

        Eigenvalues below ``1e-3`` are raised to ``1e-3`` and the matrix is
        rebuilt from its eigendecomposition. Works on batched inputs.
        """
        epsilon = 1e-3
        L, Q = torch.linalg.eigh(matrix)

        # Set eigenvalues not above epsilon to epsilon.
        L = torch.where(L > epsilon, L, torch.full_like(L, epsilon))

        return Q @ torch.diag_embed(L) @ Q.transpose(-1, -2)

    def effective_rank(self, A, eps=1e-12, return_scaled=False):
        """Entropy-based effective rank of ``A``.

        The singular values larger than ``eps`` are normalised to a
        probability vector ``p``; the effective rank is ``exp(H(p))`` with
        ``H`` the Shannon entropy.

        Args:
            A: Matrix (torch tensor) to analyse.
            eps: Threshold for discarding singular values and log stabiliser.
            return_scaled: If True, return ``(er - 1) / (r - 1)`` clipped to
                ``[0, 1]``, where ``r`` is the number of kept singular values.

        Returns:
            float: The effective rank, or its scaled version. A matrix with
            no singular value above ``eps`` yields 1.0 (unscaled) / 0.0
            (scaled); a numerical rank of 1 yields 0.0 when scaled.
        """
        s = torch.linalg.svdvals(A)
        s = s[s > eps]
        # ``s.size`` is a method on tensors; numel() is the element count.
        if s.numel() == 0:
            return 0.0 if return_scaled else 1.0

        p = s / s.sum()
        # Entropy with a safe log.
        H = -torch.sum(p * torch.log(p + eps))
        er = float(torch.exp(H))

        if return_scaled:
            r = float(s.numel())  # numerical rank
            if r <= 1.0:
                # Single singular value: no spread to measure.
                return 0.0
            s01 = (er - 1.0) / (r - 1.0)
            # Clip to [0, 1] for numerical robustness.
            return min(max(s01, 0.0), 1.0)
        return er

    def solve_qp_simplex(self, mu, Sigma, lam, tol=1e-6, max_iter=100):
        """Mean-variance weights on the probability simplex (active-set method).

        Solves ``min_w lam * w^T Sigma w - mu^T w`` s.t. ``sum(w) == 1`` and
        ``w >= 0``, i.e. the stationarity condition enforced on the active set
        ``S`` is ``2*lam*Sigma_SS w_S + nu * 1 = mu_S``.

        Args:
            mu: 1-D tensor of expected values.
            Sigma: Covariance matrix; must be positive definite (Cholesky).
            lam: Risk-aversion weight.
            tol: Threshold below which a weight counts as zero.
            max_iter: Maximum number of active-set updates.

        Returns:
            Tensor of weights with the same length as ``mu``.

        Raises:
            RuntimeError: If no KKT point is found within ``max_iter`` steps
                (torch's Cholesky error for a non-PD ``Sigma`` is also a
                RuntimeError subclass).
        """
        n = mu.numel()
        one = torch.ones(n, device=mu.device, dtype=mu.dtype)

        # Closed-form solution with only the equality constraint, used to seed
        # the active set with its strictly positive coordinates.
        L = torch.linalg.cholesky(Sigma)
        c1 = torch.cholesky_solve(one.unsqueeze(-1), L).squeeze(-1)
        c2 = torch.cholesky_solve(mu.unsqueeze(-1), L).squeeze(-1)
        nu_unc = (torch.dot(one, c2) - 2*lam) / torch.dot(one, c1)
        w_unc = (c2 - nu_unc * c1) / (2*lam)
        S = set(torch.where(w_unc > tol)[0].tolist()) or {int(torch.argmax(mu))}

        for _ in range(max_iter):
            idx = torch.tensor(sorted(S), device=mu.device)
            k = idx.numel()
            # Equality-constrained KKT system restricted to the active set.
            SigSS = Sigma.index_select(0, idx).index_select(1, idx)
            KKT = torch.zeros(k+1, k+1, device=mu.device, dtype=mu.dtype)
            KKT[:k, :k] = 2*lam*SigSS
            KKT[:k,  k] = 1
            KKT[  k, :k] = 1
            rhs = torch.cat([mu.index_select(0, idx), torch.ones(1, device=mu.device, dtype=mu.dtype)])
            sol = torch.linalg.solve(KKT, rhs)
            wS, nu = sol[:k], sol[k]
            if torch.any(wS <= tol):
                # Primal infeasible: drop non-positive weights. The fallback
                # discard is defensive so the set always shrinks.
                keep = set(idx[wS > tol].tolist())
                if keep == S:
                    keep.discard(int(idx[torch.argmin(wS)]))
                S = keep
                continue
            w = torch.zeros(n, device=mu.device, dtype=mu.dtype)
            w[idx] = wS
            # Inactive coordinates with positive residual could still improve
            # the objective (dual infeasible); add the most violated one.
            r = mu - 2*lam*(Sigma @ w) - nu*one
            viol = [i for i in range(n) if i not in S and r[i] > tol]
            if not viol:
                return w
            S.add(max(viol, key=lambda i: float(r[i])))
        raise RuntimeError("did not converge")

    def plot_correlation_heatmap(self, corr_matrix, labels=None, path="heat.png"):
        """
        Plots a heatmap of a correlation matrix with a red-white-green colormap.

        Args:
            corr_matrix (np.ndarray): 2D numpy array of correlation coefficients.
            labels (list): Optional list of labels for x and y axes. When
                omitted, axis ticks are hidden.
            path (str): File the figure is saved to (default ``heat.png``).
        """
        fig = plt.figure(figsize=(8, 6))
        try:
            # Define the red-white-green diverging colormap.
            cmap = sns.diverging_palette(10, 150, as_cmap=True)
            tick_labels = "auto" if labels is None else labels

            sns.heatmap(
                corr_matrix,
                cmap=cmap,
                vmin=-1, vmax=1,
                center=0,
                annot=True, fmt=".3f",
                annot_kws={"fontsize": 16, "fontweight": "bold"},
                square=True,
                linewidths=0.5,
                cbar=False,
                xticklabels=tick_labels,
                yticklabels=tick_labels,
            )

            plt.tight_layout()
            if labels is None:
                plt.xticks(ticks=[], labels=[])
                plt.yticks(ticks=[], labels=[])
            plt.savefig(path, dpi=400, bbox_inches="tight", pad_inches=0)
        finally:
            # Always release the figure, even if plotting/saving fails.
            plt.close(fig)

    @staticmethod
    def _greedy_actions(mixing_coeffs, means):
        """Return the argmax over actions of the mixture-weighted mean Q-values."""
        scores = torch.sum(mixing_coeffs.unsqueeze(-1) * means, dim=1)
        return torch.argmax(scores, dim=-1)

    @staticmethod
    def _mixture_covariance(mixing_coeffs, means, covariances):
        """Total covariance of the mixture (law of total covariance).

        The einsum requires ``means`` to be 3-D; presumably
        (batch, n_mixture, n_actions), with ``covariances`` of shape
        (batch, n_mixture, n_actions, n_actions) -- TODO confirm. The trailing
        ``squeeze()`` drops a batch dimension of size 1.
        """
        weights = mixing_coeffs.unsqueeze(-1)
        overall_mean = torch.sum(weights * means, dim=1, keepdim=True)
        centered = means - overall_mean
        # Between-component spread: outer product of each centred mean.
        outer = torch.einsum("bij,bik->bijk", centered, centered)
        return torch.sum(weights.unsqueeze(-1) * (outer + covariances), dim=1).squeeze()

    def _best_actions(self, mixing_coeffs, means, covariances):
        """Select actions greedily, optionally gated by the effective rank.

        With ``debug`` enabled, the scaled effective rank is printed and a
        correlation heatmap is written on every call.
        """
        if self.efficient or self.debug:
            overall_cov = self._mixture_covariance(mixing_coeffs, means, covariances)
            # NOTE(review): rank is measured on the covariance, not the
            # correlation matrix -- confirm this is intended.
            er = self.effective_rank(overall_cov, return_scaled=True)

        if self.debug:
            print(er)
            corr = self.covariance_to_correlation(overall_cov)
            self.plot_correlation_heatmap(corr.detach().cpu().numpy())

        if self.efficient and not (self.precision < er < 1 - self.precision):
            # Outside the window: fall back to the fixed default action 0.
            return torch.tensor(0, dtype=torch.int32)
        return self._greedy_actions(mixing_coeffs, means)

    def act(self, state):
        """Return an action for ``state`` (random with prob. ``exploration``)."""
        if np.random.rand() < self.exploration:
            return np.random.randint(0, self.q_dist.n_actions)

        mixing, means, covariances = self.q_dist.model(state)
        return self._best_actions(mixing, means, covariances).detach().cpu()