import torch
import torch.nn as nn
import torch.nn.functional as F
import math

from torch.distributions.normal import Normal

import mmcv
from mmdet.models import weighted_loss
from mmdet.models.builder import LOSSES
from mmcv.utils import TORCH_VERSION, digit_version

#from SSR.projects.mmdet3d_plugin.SSR.planner.metric_stp3 import PlanningMetric
# from projects.mmdet3d_plugin.SSR.utils.plan_loss import PlanMapBoundLoss, plan_map_bound_loss, PlanCollisionLoss, PlanMapDirectionLoss

# Option 2: Learnable Attention Pooling
# class AttentionPool(nn.Module):
#     def __init__(self, dim):
#         super().__init__()
#         self.query = nn.Parameter(torch.randn(1, dim))  # learnable query
#         self.attn = nn.MultiheadAttention(dim, num_heads=1)

#     def forward(self, x):
#         # x: (T, B, D) → MultiheadAttention expects (T, B, D)
#         query = self.query.unsqueeze(1).expand(1, x.size(1), -1)  # shape: (1, B, D)
#         pooled, _ = self.attn(query, x, x)  # output shape: (1, B, D)
#         return pooled.squeeze(0)  # shape: (B, D)


class IntrinsicRewardModel(nn.Module):
    """Calculate the reward distribution of a state with a small reward head.

    The (optionally token-pooled) state goes through Linear -> GELU -> Dropout,
    then two linear heads produce the mean and standard deviation of a
    ``Normal`` reward distribution.

    Args:
        state_dim: Feature size of the input state embedding.
        hidden_dim: Width of the hidden layer.
        dropout_rate: Dropout probability applied to the hidden features
            (active only in training mode).
        min_std: Lower bound of the predicted standard deviation.
        max_std: Upper bound of the predicted standard deviation.
    """

    def __init__(self, state_dim=256, hidden_dim=128, dropout_rate=0.1, min_std=0.1, max_std=1):
        super().__init__()
        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.dropout = nn.Dropout(dropout_rate)
        self.gelu = nn.GELU()
        self.mu_head = nn.Linear(hidden_dim, 1)
        self.sigma_head = nn.Linear(hidden_dim, 1)

        self.min_std = min_std
        self.max_std = max_std

        # ---- Initialization for stable reward ----
        # Zero-init both heads: every state initially maps to mean 0 and to
        # std sigmoid(0) mapped into the range, i.e. (min_std + max_std) / 2.
        nn.init.zeros_(self.mu_head.weight)
        nn.init.zeros_(self.mu_head.bias)
        nn.init.zeros_(self.sigma_head.weight)
        nn.init.zeros_(self.sigma_head.bias)

    def forward(self, x):
        """Return the reward distribution for ``x``.

        Args:
            x: Scene tokens shaped (num_scene_token, batch_size, state_dim),
                e.g. (16, 1, 256), which are mean-pooled over the token axis;
                or an already-pooled state shaped (batch_size, state_dim) or
                (state_dim,).

        Returns:
            ``Normal`` whose ``loc``/``scale`` have shape (batch_size,) -- or
            are scalars for a 1-D input -- with scale between min_std and
            max_std.
        """
        if x.dim() == 3:
            x = x.mean(dim=0)  # pool scene tokens -> (B, state_dim)

        x = self.fc1(x)
        x = self.gelu(x)
        # Honor dropout_rate (previously the module was built but never
        # applied); this is a no-op in eval mode.
        x = self.dropout(x)

        mu = self.mu_head(x)
        sigma = torch.sigmoid(self.sigma_head(x))  # squash into (0, 1)
        sigma = sigma * (self.max_std - self.min_std) + self.min_std  # (min_std, max_std)

        # squeeze(-1) drops the size-1 head dim for both batched (B, 1) and
        # unbatched (1,) outputs; squeeze(1) raised on the unbatched case.
        return Normal(mu.squeeze(-1), sigma.squeeze(-1))
    
class CriticRewardModel(nn.Module):
    """Predict a Gaussian reward for a (state, action, next_state) transition.

    State, action and next state are each encoded with Linear -> GELU, the
    three features are concatenated and fused with Linear -> GELU -> Dropout,
    and two linear heads produce the mean and standard deviation of a
    ``Normal`` reward distribution.

    Args:
        state_dim: Feature size of state / next-state embeddings.
        action_dim: Per-sample size of the flattened action; ``action`` is
            reshaped to (batch_size, -1), so its trailing dims must multiply
            to this value.
        hidden_dim: Width of every encoder and of the fusion layer.
        dropout_rate: Dropout probability in the fusion layer.
        min_std: Lower bound of the predicted standard deviation.
        max_std: Upper bound of the predicted standard deviation.
    """

    def __init__(self, state_dim=256, action_dim=12, hidden_dim=256, dropout_rate=0.1, min_std=0.1, max_std=1):
        super().__init__()
        # state encoders
        self.state_encoder = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.GELU()
        )
        self.next_state_encoder = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.GELU()
        )
        # action encoder
        self.action_encoder = nn.Sequential(
            nn.Linear(action_dim, hidden_dim),
            nn.GELU()
        )

        # fusion and output
        self.fusion = nn.Sequential(
            nn.Linear(hidden_dim * 3, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout_rate)
        )

        self.mu_head = nn.Linear(hidden_dim, 1)
        self.sigma_head = nn.Linear(hidden_dim, 1)

        self.min_std = min_std
        self.max_std = max_std

        # Stable init: zero heads give mean 0 and std (min_std + max_std) / 2
        # for every input at the start of training.
        nn.init.zeros_(self.mu_head.weight)
        nn.init.zeros_(self.mu_head.bias)
        nn.init.zeros_(self.sigma_head.weight)
        nn.init.zeros_(self.sigma_head.bias)

    @staticmethod
    def _pool_tokens(tokens):
        """Mean-pool (num_tokens, B, D) tokens to (B, D); pass (B, D) through."""
        return tokens.mean(dim=0) if tokens.dim() == 3 else tokens

    def forward(self, state_tokens, action, next_state_tokens):
        """Return the reward distribution for a batch of transitions.

        Args:
            state_tokens: Scene tokens shaped (num_scene_token, B, state_dim),
                e.g. (16, B, 256), mean-pooled over the token axis; or an
                already-pooled state shaped (B, state_dim).
            action: Tensor with leading batch dim B whose remaining elements
                total ``action_dim``. NOTE(review): an earlier comment
                described this as (B, 6, 12), which flattens to 72 rather than
                the default action_dim=12 -- confirm the caller's shape.
            next_state_tokens: Same layout as ``state_tokens``.

        Returns:
            ``Normal`` with ``loc``/``scale`` shaped (B,), scale between
            min_std and max_std.
        """
        state = self._pool_tokens(state_tokens)            # (B, state_dim)
        next_state = self._pool_tokens(next_state_tokens)  # (B, state_dim)
        B = state.shape[0]
        # reshape (not view) so non-contiguous actions (e.g. transposed or
        # sliced trajectories) are accepted; it still avoids a copy when possible.
        action = action.reshape(B, -1)                     # (B, action_dim)

        state_feat = self.state_encoder(state)
        action_feat = self.action_encoder(action)
        next_state_feat = self.next_state_encoder(next_state)

        fused = torch.cat([state_feat, action_feat, next_state_feat], dim=-1)  # (B, 3 * hidden_dim)
        fused = self.fusion(fused)

        mu = self.mu_head(fused)
        # sigmoid -> (0, 1), then affine map into (min_std, max_std).
        sigma = torch.sigmoid(self.sigma_head(fused))
        sigma = sigma * (self.max_std - self.min_std) + self.min_std

        return Normal(mu.squeeze(-1), sigma.squeeze(-1))

class ImitationRewardModel(nn.Module):
    """Predict a Gaussian reward for a (state, action) pair.

    State and action are each encoded with Linear -> GELU, concatenated and
    fused with Linear -> GELU -> Dropout; two linear heads then produce the
    mean and standard deviation of a ``Normal`` reward distribution.

    Args:
        state_dim: Feature size of the state embedding.
        action_dim: Per-sample size of the flattened action; ``action`` is
            reshaped to (batch_size, -1), so its trailing dims must multiply
            to this value.
        hidden_dim: Width of the encoders and of the fusion layer.
        dropout_rate: Dropout probability in the fusion layer.
        min_std: Lower bound of the predicted standard deviation.
        max_std: Upper bound of the predicted standard deviation.
    """

    def __init__(self, state_dim=256, action_dim=12, hidden_dim=256, dropout_rate=0.1, min_std=0.1, max_std=1):
        super().__init__()
        self.state_encoder = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.GELU()
        )
        self.action_encoder = nn.Sequential(
            nn.Linear(action_dim, hidden_dim),
            nn.GELU()
        )

        self.fusion = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout_rate)
        )

        self.mu_head = nn.Linear(hidden_dim, 1)
        self.sigma_head = nn.Linear(hidden_dim, 1)

        self.min_std = min_std
        self.max_std = max_std

        # Stable init: zero heads give mean 0 and std (min_std + max_std) / 2
        # for every input at the start of training.
        nn.init.zeros_(self.mu_head.weight)
        nn.init.zeros_(self.mu_head.bias)
        nn.init.zeros_(self.sigma_head.weight)
        nn.init.zeros_(self.sigma_head.bias)

    @staticmethod
    def _pool_tokens(tokens):
        """Mean-pool (num_tokens, B, D) tokens to (B, D); pass (B, D) through."""
        return tokens.mean(dim=0) if tokens.dim() == 3 else tokens

    def forward(self, state_tokens, action):
        """Return the reward distribution for a batch of (state, action) pairs.

        Args:
            state_tokens: Scene tokens shaped (num_tokens, B, state_dim),
                mean-pooled over the token axis; or an already-pooled state
                shaped (B, state_dim).
            action: Tensor with leading batch dim B whose remaining elements
                total ``action_dim``.

        Returns:
            ``Normal`` with ``loc``/``scale`` shaped (B,), scale between
            min_std and max_std.
        """
        state = self._pool_tokens(state_tokens)  # (B, state_dim)
        B = state.shape[0]
        # reshape (not view) so non-contiguous actions are accepted.
        action = action.reshape(B, -1)           # (B, action_dim)

        state_feat = self.state_encoder(state)
        action_feat = self.action_encoder(action)

        fused = torch.cat([state_feat, action_feat], dim=-1)  # (B, 2 * hidden_dim)
        fused = self.fusion(fused)

        mu = self.mu_head(fused)
        # sigmoid -> (0, 1), then affine map into (min_std, max_std).
        sigma = torch.sigmoid(self.sigma_head(fused))
        sigma = sigma * (self.max_std - self.min_std) + self.min_std

        return Normal(mu.squeeze(-1), sigma.squeeze(-1))
