import torch.distributions as torch_distribution
import torch.nn.functional as F
from torch import nn
from .tools import weight_init, uniform_weight_init
import torch
from mmcv.cnn.bricks.transformer import BaseTransformerLayer

def build_mlp(config):
    """Assemble a three-layer perceptron from a config mapping.

    Args:
        config: Mapping providing ``"input_dim"`` (width of the input
            features) and ``"mlp_layers"`` (a list of exactly three output
            widths, one per linear layer).

    Returns:
        nn.Sequential: ``Linear -> ReLU -> Linear -> ReLU -> Linear``. The
        final linear layer is not followed by an activation.

    Raises:
        AssertionError: If ``mlp_layers`` is not a list of length 3.
    """
    in_features = config["input_dim"]
    hidden_sizes = config["mlp_layers"]

    assert isinstance(hidden_sizes, list) and len(hidden_sizes) == 3, "the mlp_layers in config is wrong"

    last_idx = len(hidden_sizes) - 1
    modules = []
    for idx, out_features in enumerate(hidden_sizes):
        modules.append(nn.Linear(in_features, out_features))
        # Every linear layer except the output layer is followed by a ReLU.
        if idx != last_idx:
            modules.append(nn.ReLU())
        in_features = out_features

    return nn.Sequential(*modules)

class ValueHead(nn.Module):
    """MLP critic head that predicts a Normal distribution over a scalar value.

    The trunk is a stack of ``Linear -> [LayerNorm] -> ReLU`` blocks followed
    by two single-output linear heads, one for the mean and one for the
    (pre-squash) standard deviation.

    Config keys read:
        input_dim: Width of the input features.
        mlp_dims: Iterable of hidden widths, one per trunk block.
        min_std / max_std: Bounds of the predicted standard deviation.
        norm (optional): Whether to insert a LayerNorm after every trunk
            linear layer. Defaults to True when missing or None.
    """

    def __init__(self, config):
        super().__init__()
        self.min_std = config['min_std']
        self.max_std = config['max_std']
        self.input_dim = config['input_dim']
        self.mlp_dims = config['mlp_dims']

        # The previous expression (`config['norm'] if config['norm'] else True`)
        # could never evaluate to a falsy value, so normalisation could not be
        # switched off. Only fall back to True when the option is absent.
        norm = config.get('norm')
        self.norm = True if norm is None else norm

        self.layers = nn.Sequential()
        in_dim = self.input_dim
        for i, output_dim in enumerate(self.mlp_dims):
            self.layers.add_module(
                f"critic_linear_{i}", nn.Linear(in_dim, output_dim)
            )
            if self.norm:
                self.layers.add_module(
                    f"critic_norm_{i}", nn.LayerNorm(output_dim, eps=1e-3)
                )
            self.layers.add_module(
                f"critic_act_{i}", nn.ReLU()
            )
            in_dim = output_dim

        # Kept for backward compatibility: after construction `input_dim`
        # holds the width of the last trunk layer (or the configured input
        # width when `mlp_dims` is empty), exactly as before.
        self.input_dim = in_dim

        # `weight_init` / `uniform_weight_init` come from `.tools`; their exact
        # initialisation schemes are not visible from this file.
        self.layers.apply(weight_init)

        self.mean_layer = nn.Linear(self.input_dim, 1)
        self.mean_layer.apply(uniform_weight_init(1.0))

        self.std_layer = nn.Linear(self.input_dim, 1)
        self.std_layer.apply(uniform_weight_init(1.0))

    def forward(self, state):
        """Predict the value distribution for ``state``.

        Args:
            state: Feature tensor whose last dimension is the input width.
                A 3-D tensor (e.g. ``[16, B, D]`` scene tokens) is first
                averaged over its leading dimension.

        Returns:
            torch.distributions.Normal: Distribution whose mean and standard
            deviation have the trailing singleton dimension removed.
        """
        if state.dim() == 3:
            # Pool the token dimension so the trunk sees one vector per sample.
            state = state.mean(dim=0)

        out = self.layers(state)
        mean = self.mean_layer(out)
        raw_std = self.std_layer(out)
        # Squash into [min_std, max_std]. The +2.0 shift means a raw output of
        # zero maps to sigmoid(2) ~= 0.88 of the way from min_std to max_std.
        std = (self.max_std - self.min_std) * torch.sigmoid(
            raw_std + 2.0
        ) + self.min_std
        # squeeze(-1) drops the single output channel; identical to squeeze(1)
        # for batched input and additionally valid for an un-batched 1-D state.
        return torch_distribution.normal.Normal(mean.squeeze(-1), std.squeeze(-1))
    


class ValueHeadTransformer(nn.Module):
    """Transformer critic head that predicts a Normal distribution over a value.

    A learnable "value token" is prepended to the scene tokens, the sequence
    is passed through a stack of self-attention layers, and the value token's
    final embedding is decoded into a mean and a bounded standard deviation.

    Config keys read:
        input_dim: Token embedding width (e.g. 256).
        min_std / max_std: Bounds of the predicted standard deviation.
        num_heads (optional): Attention heads per layer, default 8.
        ffn_dim (optional): Feed-forward hidden width, default 1024.
        num_layers (optional): Number of transformer layers, default 2.
    """

    def __init__(self, config):
        super().__init__()
        self.embed_dim = config['input_dim']              # e.g. 256
        self.num_heads = config.get('num_heads', 8)
        self.ffn_dim = config.get('ffn_dim', 1024)
        self.min_std = config['min_std']
        self.max_std = config['max_std']
        self.num_layers = config.get('num_layers', 2)

        self.transformer_layers = nn.ModuleList([
            BaseTransformerLayer(
                attn_cfgs=dict(
                    type='MultiheadAttention',
                    embed_dims=self.embed_dim,
                    num_heads=self.num_heads
                ),
                feedforward_channels=self.ffn_dim,
                operation_order=('self_attn', 'norm', 'ffn', 'norm')
            )
            for _ in range(self.num_layers)
        ])

        # Learnable "value token" that summarises the scene, shape [1, 1, D].
        self.value_token = nn.Parameter(torch.randn(1, 1, self.embed_dim))

        # Final MLP applied to the value token before the output heads.
        self.mlp = nn.Sequential(
            nn.LayerNorm(self.embed_dim),
            nn.Linear(self.embed_dim, self.embed_dim),
            nn.ReLU(),
        )

        self.mean_head = nn.Linear(self.embed_dim, 1)
        self.std_head = nn.Linear(self.embed_dim, 1)

    def forward(self, scene_tokens, scene_pos=None):
        """Predict the value distribution from a set of scene tokens.

        Args:
            scene_tokens: Tensor of shape (N, B, D), e.g. N = 16 scene tokens.
            scene_pos: Optional positional encoding of shape (N, B, D) for the
                scene tokens. The value token receives an all-zero encoding.

        Returns:
            torch.distributions.Normal: Distribution with mean and standard
            deviation of shape (B,).
        """
        _, batch_size, embed_dim = scene_tokens.shape

        # [1, 1, D] value token -> broadcast over the batch to [1, B, D].
        value_token = self.value_token.expand(-1, batch_size, -1)

        # Prepend the value token to the scene tokens -> [1 + N, B, D].
        tokens = torch.cat([value_token, scene_tokens], dim=0)

        pos_enc = None
        if scene_pos is not None:
            # new_zeros inherits dtype *and* device from scene_pos, so a
            # half-precision encoding is not silently promoted to float32 by
            # the concatenation below.
            value_pos = scene_pos.new_zeros(1, batch_size, embed_dim)
            pos_enc = torch.cat([value_pos, scene_pos], dim=0)

        # A positional encoding of None is the layer's default, so one loop
        # covers both the with- and without-position cases.
        for layer in self.transformer_layers:
            tokens = layer(
                query=tokens,
                key=tokens,
                value=tokens,
                query_pos=pos_enc,
                key_pos=pos_enc,
            )

        # Extract the [CLS]-style value token (first one), shape [B, D].
        v = self.mlp(tokens[0])
        mean = self.mean_head(v)
        raw_std = self.std_head(v)
        # Squash into [min_std, max_std]; the +2.0 shift maps a raw output of
        # zero to sigmoid(2) ~= 0.88 of the way from min_std to max_std.
        std = (self.max_std - self.min_std) * torch.sigmoid(raw_std + 2.0) + self.min_std

        return torch_distribution.normal.Normal(mean.squeeze(1), std.squeeze(1))
