import torch
from torch.nn import functional as F
from all2 import nn

from all2.approximation import MultiHeadedApproximation

class JointQDist(MultiHeadedApproximation):
    """Approximation wrapper around a :class:`JointQDistModule`.

    The constructor bundles a shared backbone with a mixing head, a mean head
    and a covariance head into one ``JointQDistModule`` and passes that module
    to ``MultiHeadedApproximation`` together with the optimizer.

    Args:
        backbone: Feature-extractor network shared by all heads.
        mixing_head: Network producing the mixture logits.
        mean_head: Network producing the per-component, per-action means.
        covariance_head: Network producing the Cholesky-factor parameters.
        optimizer: Optimizer forwarded to the parent class.
        n_actions: Number of actions.
        n_mixture: Number of mixture components.
        name: Name forwarded to the parent class.
        **kwargs: Extra keyword arguments forwarded untouched to the parent
            class.
    """

    def __init__(
        self,
        backbone,
        mixing_head,
        mean_head,
        covariance_head,
        optimizer,
        n_actions,
        n_mixture,
        name="joint_q_dist",
        **kwargs
    ):
        # Keep direct handles on the individual sub-networks and sizes.
        self.backbone = backbone
        self.mixing_head = mixing_head
        self.mean_head = mean_head
        self.covariance_head = covariance_head
        self.n_actions = n_actions
        self.n_mixture = n_mixture

        # Single torch module that the parent approximation will manage.
        joint_module = JointQDistModule(
            backbone,
            mixing_head,
            mean_head,
            covariance_head,
            n_actions,
            n_mixture,
        )
        self.model = joint_module

        super().__init__(joint_module, optimizer, name=name, **kwargs)


class JointQDistModule(torch.nn.Module):
    """Network producing a Gaussian-mixture description over action values.

    A shared backbone yields features; three heads turn them into
    mixing coefficients ``[B, n_mixture]``, means ``[B, n_mixture, n_actions]``
    and covariance matrices ``[B, n_mixture, n_actions, n_actions]``. The
    covariance head emits a single set of Cholesky parameters per sample,
    which is shared by every mixture component.
    """

    def __init__(self, backbone, mixing_head, mean_head, covariance_head, n_actions, n_mixture):
        super().__init__()
        self.n_actions = n_actions
        self.n_mixture = n_mixture
        # Remember where the backbone parameters live at construction time.
        self.device = next(backbone.parameters()).device
        self.backbone = nn.RLNetwork(backbone)
        self.mixing_head = mixing_head
        self.mean_head = mean_head
        self.covariance_head = covariance_head

    def _batched_block_select(self, tensor, indices):
        """
        Select KxK blocks from each NxMxM matrix in the batch based on meshgrid-like indexing.

        Args:
        tensor (torch.Tensor): A tensor of shape [B, N, M, M].
        indices (torch.Tensor): A tensor of shape [B, K], containing the indices to extract
                                from the MxM matrices for both row and column indices.

        Returns:
        torch.Tensor: A tensor of shape [B, N, K, K] containing the selected KxK blocks,
                      where ``out[b, n, i, j] == tensor[b, n, indices[b, i], indices[b, j]]``.
        """
        B, N, _, _ = tensor.shape
        device = tensor.device
        # Build every index on the data's device so that indexing never mixes
        # CPU and accelerator index tensors.
        indices = indices.to(device)

        # The four index tensors broadcast to [B, N, K, K]; no explicit expand needed.
        batch_indices = torch.arange(B, device=device).view(B, 1, 1, 1)
        n_indices = torch.arange(N, device=device).view(1, N, 1, 1)
        row_indices = indices[:, None, :, None]  # varies along the row axis
        col_indices = indices[:, None, None, :]  # varies along the column axis

        return tensor[batch_indices, n_indices, row_indices, col_indices]

    def get_features(self, states):
        """Run the backbone on ``states`` and return its output."""
        return self.backbone(states)

    def get_mixing(self, features):
        """Return softmax-normalised mixing coefficients of shape [-1, n_mixture]."""
        logits = self.mixing_head(features).view(-1, self.n_mixture)
        return F.softmax(logits, dim=-1)

    def get_means(self, features):
        """Return component means reshaped to [-1, n_mixture, n_actions]."""
        return self.mean_head(features).view(-1, self.n_mixture, self.n_actions)

    def augment_features(self, features1, features2):
        """Concatenate both inputs, their difference and their product on the last dim.

        The result is four times as wide as a single feature tensor.
        """
        diff = features1 - features2
        hadamard = features1 * features2
        return torch.cat((features1, features2, diff, hadamard), dim=-1)

    def get_covariances(self, features1, features2):
        """Build covariance matrices from a pair of feature tensors.

        The covariance head maps the augmented features to the
        ``n_actions * (n_actions + 1) / 2`` entries of a lower-triangular
        factor ``L``; the returned covariance is ``L @ L^T``. The same factor
        is used for every mixture component.

        Returns:
            torch.Tensor: Shape [B, n_mixture, n_actions, n_actions], with the
            dtype and device of the covariance head's output.
        """
        n_actions = self.n_actions
        n_tril = n_actions * (n_actions + 1) // 2

        augmented_features = self.augment_features(features1, features2)
        params = (
            self.covariance_head(augmented_features)
            .view(-1, 1, n_tril)
            .expand(-1, self.n_mixture, -1)
        )
        batch = params.shape[0]

        # new_zeros inherits dtype AND device from params; a default-dtype
        # buffer would make the indexed assignment below fail for any
        # non-float32 head output (e.g. a module converted with .double()).
        chol = params.new_zeros((batch, self.n_mixture, n_actions, n_actions))
        tril_rows, tril_cols = torch.tril_indices(n_actions, n_actions, device=params.device)
        chol[:, :, tril_rows, tril_cols] = params

        # Exponential on the diagonal to ensure positive definiteness; the
        # additive 1e-1 keeps the diagonal bounded away from zero.
        diag_idx = torch.arange(n_actions, device=params.device)
        chol[:, :, diag_idx, diag_idx] = torch.exp(chol[:, :, diag_idx, diag_idx] / 2) + 1e-1

        return chol @ chol.transpose(-1, -2)

    def decode(self, features1, features2):
        """Return ``(mixing, means, covariances)``.

        Mixing coefficients and means are computed from ``features1`` only;
        the covariances use both ``features1`` and ``features2``.
        """
        mixing_coeffs = self.get_mixing(features1)
        means = self.get_means(features1)
        covariances = self.get_covariances(features1, features2)
        return mixing_coeffs, means, covariances

    def forward(self, states, actions=None):
        """Compute the mixture parameters, optionally restricted to given actions.

        Args:
            states: Input accepted by the backbone.
            actions: ``None``, an integer index tensor of shape [B, K], or a
                list of tensors that is concatenated along dim 0 first.

        Returns:
            tuple: ``(mixing, means, covariances)``. Without ``actions`` the
            shapes are [B, n_mixture], [B, n_mixture, n_actions] and
            [B, n_mixture, n_actions, n_actions]. With ``actions`` the means
            become [B, n_mixture, K] and the covariances [B, n_mixture, K, K];
            the mixing coefficients are returned unchanged.
        """
        features = self.get_features(states)
        # The same features are used for both covariance inputs here.
        mixing_coeffs, means, covariances = self.decode(features, features)

        if actions is None:
            return mixing_coeffs, means, covariances
        if isinstance(actions, list):
            actions = torch.cat(actions)

        # Repeat the action indices once per mixture component so they can
        # index the action axis of the means.
        gather_index = actions.unsqueeze(1).repeat([1, self.n_mixture, 1])
        selected_means = torch.gather(means, dim=2, index=gather_index)
        selected_covariances = self._batched_block_select(covariances, actions)

        return mixing_coeffs, selected_means, selected_covariances

    def argmax(self, inputs):
        """Return mixing coefficients and means for two feature slices, without gradients.

        The backbone output is indexed as ``[:, 0, :]`` and ``[:, 1, :]``, so
        it must be 3-dimensional with at least two entries along dim 1. No
        argmax is taken inside this method; it only returns the stacked
        quantities.

        Returns:
            tuple: ``(mixing, means)`` where the two slices are stacked along
            a new dim 1, i.e. [B, 2, n_mixture] and
            [B, 2, n_mixture, n_actions].
        """
        with torch.no_grad():
            features = self.get_features(inputs)
            features1 = features[:, 0, :]
            features2 = features[:, 1, :]
            mixing_coeffs1 = self.get_mixing(features1).unsqueeze(1)
            mixing_coeffs2 = self.get_mixing(features2).unsqueeze(1)
            means1 = self.get_means(features1).unsqueeze(1)
            means2 = self.get_means(features2).unsqueeze(1)
            return (
                torch.cat((mixing_coeffs1, mixing_coeffs2), dim=1),
                torch.cat((means1, means2), dim=1),
            )

    def to(self, device):
        """Move the module to ``device`` and record it in ``self.device``."""
        self.device = device
        return super().to(device)
    

def jointdqn_cartpole_backbone_constructor(state_size=4):
    """Build the fully connected backbone for the cartpole variant.

    Args:
        state_size: Width of the input state vector.

    Returns:
        A sequential network: scale by 1/255, ``Linear(state_size, 128)``,
        ReLU, ``Linear(128, 128)``. The output is 128 features wide and has
        no trailing activation.
    """
    # NOTE(review): the 1/255 scale matches the pixel backbone in this file;
    # confirm it is intended for low-dimensional cartpole states.
    layers = [
        nn.Scale(1 / 255),
        nn.Linear(state_size, 128),
        nn.ReLU(),
        nn.Linear(128, 128),
    ]
    return nn.Sequential(*layers)
    

def jointdqn_cartpole_mixing_head_constructor(n_mixture):
    """Build the cartpole mixing head: ReLU followed by ``Linear(128, n_mixture)``.

    Args:
        n_mixture: Number of mixture components, i.e. the output width.

    Returns:
        A sequential network mapping 128 input features to ``n_mixture``
        unnormalised outputs.
    """
    layers = [
        nn.ReLU(),
        nn.Linear(128, n_mixture),
    ]
    return nn.Sequential(*layers)


def jointdqn_cartpole_mean_head_constructor(n_mixture, n_actions):
    """Build the cartpole mean head: ReLU followed by a linear layer.

    Args:
        n_mixture: Number of mixture components.
        n_actions: Number of actions.

    Returns:
        A sequential network mapping 128 input features to
        ``n_mixture * n_actions`` outputs (one mean per component and action).
    """
    layers = [
        nn.ReLU(),
        nn.Linear(128, n_mixture * n_actions),
    ]
    return nn.Sequential(*layers)


def jointdqn_cartpole_covariance_head_constructor(n_actions, feature_size=128):
    """Build the cartpole covariance head: ReLU followed by a linear layer.

    The head consumes the output of ``JointQDistModule.augment_features``,
    which concatenates four feature-sized tensors, so its input width is
    ``4 * feature_size`` (512 with the default, matching the 128-wide
    cartpole backbone).

    Args:
        n_actions: Number of actions.
        feature_size: Width of a single backbone feature vector. Defaults to
            128, which reproduces the previously hard-coded input width.

    Returns:
        A sequential network mapping ``4 * feature_size`` inputs to
        ``n_actions * (n_actions + 1) // 2`` outputs, the number of entries
        in a lower-triangular ``n_actions x n_actions`` matrix.
    """
    n_tril_params = n_actions * (n_actions + 1) // 2
    return nn.Sequential(
        nn.ReLU(),
        nn.Linear(4 * feature_size, n_tril_params),
    )


def jointdqn_backbone_constructor(frames=4):
    """Build the convolutional backbone.

    Args:
        frames: Number of input channels of the first convolution.

    Returns:
        A sequential network: scale by 1/255, three Conv2d/ReLU stages
        (32 filters 8x8 stride 4, 64 filters 4x4 stride 2, 64 filters 3x3
        stride 1), flatten, then ``Linear(3136, 512)``. The output is 512
        features wide and has no trailing activation.
    """
    # The flattened width 3136 is fixed, so the input resolution must produce
    # exactly that many values after the conv stack
    # (presumably 84x84 frames -> 64 * 7 * 7; confirm against the caller).
    conv_stack = [
        nn.Conv2d(frames, 32, 8, stride=4),
        nn.ReLU(),
        nn.Conv2d(32, 64, 4, stride=2),
        nn.ReLU(),
        nn.Conv2d(64, 64, 3, stride=1),
        nn.ReLU(),
    ]
    return nn.Sequential(
        nn.Scale(1 / 255),
        *conv_stack,
        nn.Flatten(),
        nn.Linear(3136, 512),
    )


def jointdqn_mixing_head_constructor(n_mixture):
    """Build the mixing head: ReLU followed by ``Linear0(512, n_mixture)``.

    Args:
        n_mixture: Number of mixture components, i.e. the output width.

    Returns:
        A sequential network mapping 512 input features to ``n_mixture``
        unnormalised outputs.
    """
    # Linear0 comes from all2.nn (presumably a zero-initialised linear layer;
    # confirm against the library).
    layers = [
        nn.ReLU(),
        nn.Linear0(512, n_mixture),
    ]
    return nn.Sequential(*layers)


def jointdqn_mean_head_constructor(n_mixture, n_actions):
    """Build the mean head: ReLU followed by a ``Linear0`` layer.

    Args:
        n_mixture: Number of mixture components.
        n_actions: Number of actions.

    Returns:
        A sequential network mapping 512 input features to
        ``n_mixture * n_actions`` outputs (one mean per component and action).
    """
    # Linear0 comes from all2.nn (presumably a zero-initialised linear layer;
    # confirm against the library).
    layers = [
        nn.ReLU(),
        nn.Linear0(512, n_mixture * n_actions),
    ]
    return nn.Sequential(*layers)


def jointdqn_covariance_head_constructor(n_actions, feature_size=512):
    """Build the covariance head: ReLU followed by a linear layer.

    The head consumes the output of ``JointQDistModule.augment_features``,
    which concatenates four feature-sized tensors, so its input width is
    ``4 * feature_size`` (2048 with the default, matching the 512-wide
    convolutional backbone).

    Args:
        n_actions: Number of actions.
        feature_size: Width of a single backbone feature vector. Defaults to
            512, which reproduces the previously hard-coded input width.

    Returns:
        A sequential network mapping ``4 * feature_size`` inputs to
        ``n_actions * (n_actions + 1) // 2`` outputs, the number of entries
        in a lower-triangular ``n_actions x n_actions`` matrix.
    """
    n_tril_params = n_actions * (n_actions + 1) // 2
    return nn.Sequential(
        nn.ReLU(),
        nn.Linear(4 * feature_size, n_tril_params),
    )