import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
import torch

import torch
import torch.nn.functional as F


def natural_joint_collision_trajectory_loss(
    joint_logits, x_rst, x_initial, pene_thresh=0.2, lambda_weight=1.0
):
    """Risk-weighted L1 deviation between optimised and initial trajectories.

    For every joint the highest penetration probability over the sequence is
    used as a weight on that joint's mean absolute displacement between
    ``x_rst`` and ``x_initial``.

    Args:
        joint_logits: Per-frame, per-joint class logits of shape
            ``[B * L, J, C]``; class index 1 is read as the penetration class.
        x_rst: Optimised joint trajectories, shape ``[B, L, J, D]``.
        x_initial: Reference joint trajectories, same shape as ``x_rst``.
        pene_thresh: Probability above which a joint counts towards the
            normaliser.
        lambda_weight: Global scale applied to the loss.

    Returns:
        A 0-dim tensor: ``lambda_weight * sum(max_risk * traj_diff) / n`` where
        ``n`` is the number of ``(batch, joint)`` entries whose peak risk
        exceeds ``pene_thresh`` (at least 1).

    Raises:
        RuntimeError: If ``joint_logits`` cannot be reshaped to
            ``[B, L, J, C]``.
    """
    B, L, J, _ = x_rst.shape

    # [(B L), J, C] -> [B, L, J, C]; fails loudly if the joint axis disagrees
    # with x_rst instead of producing a silently mis-shaped tensor.
    joint_logits = joint_logits.reshape(B, L, J, joint_logits.shape[-1])
    joint_probs = F.softmax(joint_logits, dim=-1)[..., 1]  # [B, L, J]

    # Peak penetration probability of each joint over the whole sequence.
    max_risk = joint_probs.max(dim=1).values  # [B, J]

    # Mean absolute displacement per joint, averaged over time and coordinates.
    traj_diff = (x_rst - x_initial).abs().mean(dim=(1, 3))  # [B, J]

    # NOTE(review): the numerator sums over *all* joints while the normaliser
    # only counts joints above `pene_thresh`; kept as in the original.
    risky_count = (max_risk > pene_thresh).sum().clamp(min=1)
    loss = lambda_weight * (max_risk * traj_diff).sum() / risky_count

    return loss

def enhanced_joint_binary_loss(
    joint_logits,
    x_rst, x_initial,
    sigma0=0.1, lambda_traj=5.0, delta=2, pene_thresh=0.2,
    spatial_weight_factor=1.0, temporal_weight_factor=1.0,
):
    """Weighted per-joint cross-entropy pushing risky joints towards class 0.

    For every frame ``t`` that has a full ``[t - delta, t + delta]`` window,
    each joint whose predicted penetration probability (class 1) exceeds
    ``pene_thresh`` contributes its cross-entropy against target class 0,
    scaled by a blend of

    * a spatial weight ``exp(-d_min^2 / sigma0^2)`` where ``d_min`` is the
      distance to the nearest *other* joint in ``x_rst`` at frame ``t``, and
    * a temporal weight ``1 + lambda_traj * (1 - cos)`` where ``cos`` is the
      cosine similarity between the joint's displacement over the window in
      ``x_rst`` and in ``x_initial``.

    All frames are processed in one batched pass, so there is no per-frame
    host synchronisation.

    Args:
        joint_logits: Logits of shape ``[B * L, J, 2]``.
        x_rst: Optimised joint trajectories, shape ``[B, L, J, D]``.
        x_initial: Reference joint trajectories, same shape as ``x_rst``.
        sigma0: Length scale of the spatial weight.
        lambda_traj: Strength of the trajectory-deviation term.
        delta: Half-width (in frames) of the temporal window.
        pene_thresh: Probability above which a joint is considered at risk.
        spatial_weight_factor: Blend factor of the spatial weight.
        temporal_weight_factor: Blend factor of the temporal weight.

    Returns:
        A 0-dim float32 tensor: the weighted loss summed over all at-risk
        ``(batch, frame, joint)`` entries divided by their count (at least 1).
        When the sequence is too short for any window, a constant zero.

    Raises:
        ValueError: If ``delta`` is negative.
        RuntimeError: If ``joint_logits`` cannot be reshaped to
            ``[B, L, J, 2]``.
        ZeroDivisionError: If both blend factors are zero.
    """
    if delta < 0:
        raise ValueError(f"delta must be non-negative, got {delta}")

    B, L, J, _ = x_rst.shape
    joint_logits = joint_logits.reshape(B, L, J, 2)

    # Frames delta .. L-delta-1 are the only ones with a complete window.
    if L - 2 * delta <= 0:
        return torch.zeros((), dtype=torch.float32, device=x_rst.device)

    logits_mid = joint_logits[:, delta:L - delta]                 # [B, T, J, 2]
    pene_mask = F.softmax(logits_mid, dim=-1)[..., 1] > pene_thresh  # [B, T, J]

    # --- spatial weight: proximity to the nearest other joint -------------
    x_mid = x_rst[:, delta:L - delta]                              # [B, T, J, D]
    pair_diff = x_mid.unsqueeze(3) - x_mid.unsqueeze(2)            # [B, T, J, J, D]
    distsq = pair_diff.pow(2).sum(dim=-1)                          # [B, T, J, J]
    # Push the zero self-distance out of the way so `min` picks another joint.
    distsq = distsq + torch.eye(J, device=x_rst.device) * 1e6
    min_distsq = distsq.min(dim=-1).values                         # [B, T, J]
    spatial_weight = torch.exp(-min_distsq / (sigma0 ** 2))

    # --- temporal weight: direction change of the windowed displacement ---
    # Frame t pairs x[t + delta] (index 2*delta + k) with x[t - delta] (index k).
    r_opt = x_rst[:, 2 * delta:] - x_rst[:, :L - 2 * delta]          # [B, T, J, D]
    r_orig = x_initial[:, 2 * delta:] - x_initial[:, :L - 2 * delta]  # [B, T, J, D]
    cos_sim = F.cosine_similarity(r_opt, r_orig, dim=-1, eps=1e-6)  # [B, T, J]
    temporal_weight = 1 + lambda_traj * (1 - cos_sim)

    combined_weight = (
        spatial_weight_factor * spatial_weight
        + temporal_weight_factor * temporal_weight
    ) / (spatial_weight_factor + temporal_weight_factor)

    # Cross-entropy against the constant target class 0 ("no penetration").
    ce_loss = -F.log_softmax(logits_mid, dim=-1)[..., 0]           # [B, T, J]
    weighted_loss = combined_weight * ce_loss

    # Select at-risk entries without boolean indexing (no device sync); the
    # result keeps its autograd graph even when nothing is selected.
    selected = torch.where(pene_mask, weighted_loss, torch.zeros_like(weighted_loss))
    count = pene_mask.sum().clamp(min=1)

    return selected.sum(dtype=torch.float32) / count


def jcef_plus_loss_batched(
    collision_map,
    x_rst, x_initial,
    sigma0=0.1, lambda_traj=5.0, delta=2, pene_thresh=0.2
):
    """Pairwise joint-collision energy with a trajectory-adaptive kernel.

    For every frame ``t`` with a full ``[t - delta, t + delta]`` window and
    every joint pair ``(i, j)`` whose predicted collision probability
    (class 1) exceeds ``pene_thresh``, the cross-entropy of the pair's logits
    against target class 0 is weighted by ``exp(-d_ij^2 / sigma_ij^2)``, with

    ``sigma_ij^2 = sigma0^2 * (1 + lambda_traj * (1 - cos_ij))``

    where ``cos_ij`` is the cosine similarity between the pair's relative
    displacement over the window in ``x_rst`` and in ``x_initial``.

    All frames are processed in one batched pass.

    Args:
        collision_map: Pairwise logits of shape ``[B * L, J, J, 2]``.
        x_rst: Optimised joint trajectories, shape ``[B, L, J, D]``.
        x_initial: Reference joint trajectories, same shape as ``x_rst``.
        sigma0: Base length scale of the spatial kernel.
        lambda_traj: How strongly trajectory deviation widens the kernel.
        delta: Half-width (in frames) of the temporal window.
        pene_thresh: Probability above which a pair is considered at risk.

    Returns:
        A 0-dim float32 tensor: the energy summed over all at-risk
        ``(batch, frame, i, j)`` entries divided by their count (at least 1).
        When the sequence is too short for any window, a constant zero.

    Raises:
        ValueError: If ``delta`` is negative.
        RuntimeError: If ``collision_map`` cannot be reshaped to
            ``[B, L, J, J, 2]``.
    """
    if delta < 0:
        raise ValueError(f"delta must be non-negative, got {delta}")

    B, L, J, _ = x_rst.shape
    collision_map = collision_map.reshape(B, L, J, J, 2)

    # Frames delta .. L-delta-1 are the only ones with a complete window.
    if L - 2 * delta <= 0:
        return torch.zeros((), dtype=torch.float32, device=x_rst.device)

    logits_mid = collision_map[:, delta:L - delta]                   # [B, T, J, J, 2]
    pene_mask = F.softmax(logits_mid, dim=-1)[..., 1] > pene_thresh  # [B, T, J, J]

    # Squared distance between every joint pair at the centre frame.
    # NOTE(review): self-pairs (i == j) are not excluded; their distance is 0,
    # so they always receive spatial weight 1 - confirm this is intended.
    x_mid = x_rst[:, delta:L - delta]                                # [B, T, J, D]
    distsq = (x_mid.unsqueeze(3) - x_mid.unsqueeze(2)).pow(2).sum(dim=-1)  # [B, T, J, J]

    # Per-joint displacement over the window, then pairwise relative motion.
    # Frame t pairs x[t + delta] (index 2*delta + k) with x[t - delta] (index k).
    r_opt = x_rst[:, 2 * delta:] - x_rst[:, :L - 2 * delta]            # [B, T, J, D]
    r_orig = x_initial[:, 2 * delta:] - x_initial[:, :L - 2 * delta]    # [B, T, J, D]
    v_opt = r_opt.unsqueeze(3) - r_opt.unsqueeze(2)                  # [B, T, J, J, D]
    v_orig = r_orig.unsqueeze(3) - r_orig.unsqueeze(2)               # [B, T, J, J, D]

    cos_sim = F.cosine_similarity(v_opt, v_orig, dim=-1, eps=1e-6)   # [B, T, J, J]
    sigma2 = sigma0 ** 2 * (1 + lambda_traj * (1 - cos_sim))

    # Cross-entropy against the constant target class 0 ("no collision").
    ce_loss = -F.log_softmax(logits_mid, dim=-1)[..., 0]             # [B, T, J, J]
    energy = torch.exp(-distsq / sigma2) * ce_loss

    # Select at-risk pairs without boolean indexing (no device sync); the
    # result keeps its autograd graph even when nothing is selected.
    selected = torch.where(pene_mask, energy, torch.zeros_like(energy))
    count = pene_mask.sum().clamp(min=1)

    return selected.sum(dtype=torch.float32) / count


class PosePenetrationDetector(nn.Module):
    """MLP head that scores pose features for penetration.

    The network is a stack of ``Linear -> BatchNorm1d -> activation -> Dropout``
    blocks followed by a final ``Linear``. How the raw output is interpreted
    depends on ``task_type``:

    * ``'binary'``: logits returned as-is, ``[N, output_dim]``.
    * ``'joint_score'``: ``sigmoid(logits) * scaling_factor``.
    * ``'joint_binary'``, ``'joint_binary_energy'``,
      ``'natural_joint_collision_trajectory'``: logits reshaped to
      ``[N, num_joints, C]``.
    * ``'joint_joint_binary'``: logits reshaped to
      ``[N, num_joints, num_joints, C]``.

    Args:
        input_dim: Size of the input feature vector.
        hidden_dims: Iterable of hidden layer widths (may be empty).
        output_dim: Width of the final linear layer.
        dropout: Dropout probability after every hidden block.
        activation: ``'relu'`` selects ``nn.ReLU``; any other value selects
            ``nn.LeakyReLU(0.2)``.
        task_type: One of the task names listed above.
        scaling_factor: Multiplier applied to the sigmoid output in
            ``'joint_score'`` mode.
        num_joints: Number of joints used when reshaping the output of the
            per-joint and joint-pair tasks. Defaults to 22.
    """

    # Task types whose output is reshaped to per-joint logits [N, J, C].
    _PER_JOINT_TASKS = (
        'joint_binary',
        'joint_binary_energy',
        'natural_joint_collision_trajectory',
    )

    # Class-level fallback so instances created before `num_joints` existed
    # (e.g. restored from an old pickle) keep the historical joint count.
    num_joints = 22

    def __init__(self, input_dim, hidden_dims, output_dim, dropout=0.1, activation='relu', task_type='binary', scaling_factor=1.0, num_joints=22):
        super().__init__()

        self.task_type = task_type
        self.scaling_factor = scaling_factor
        self.num_joints = num_joints

        layers = []
        prev_dim = input_dim
        for hidden_dim in hidden_dims:
            layers.extend([
                nn.Linear(prev_dim, hidden_dim),
                nn.BatchNorm1d(hidden_dim),
                nn.ReLU() if activation == 'relu' else nn.LeakyReLU(0.2),
                nn.Dropout(dropout)
            ])
            prev_dim = hidden_dim
        layers.append(nn.Linear(prev_dim, output_dim))

        self.mlp = nn.Sequential(*layers)

        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Kaiming-normal init for linear weights, zeros for their biases."""
        if isinstance(module, nn.Linear):
            nn.init.kaiming_normal_(module.weight, mode='fan_in', nonlinearity='relu')
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    def forward(self, x):
        """Run the MLP and shape the output according to ``task_type``.

        Args:
            x: Input features of shape ``[N, input_dim]``.

        Returns:
            Logits (or scaled sigmoid scores for ``'joint_score'``) in the
            layout described in the class docstring.

        Raises:
            ValueError: If ``task_type`` is not recognised.
            RuntimeError: If ``output_dim`` is not divisible by the number of
                joints (or joint pairs) required by the task.
        """
        logits = self.mlp(x)
        task = self.task_type
        n_joints = self.num_joints

        if task == 'binary':
            return logits
        if task == 'joint_score':
            return torch.sigmoid(logits) * self.scaling_factor
        if task in self._PER_JOINT_TASKS:
            # [N, J * C] -> [N, J, C]
            n, width = logits.shape
            return logits.reshape(n, n_joints, width // n_joints)
        if task == 'joint_joint_binary':
            # [N, J * J * C] -> [N, J, J, C]
            n, width = logits.shape
            return logits.reshape(n, n_joints, n_joints, width // (n_joints * n_joints))
        raise ValueError(f"Unknown task_type: {self.task_type}")

    def predict(self, x):
        """Return class probabilities (or scores) for ``x``.

        ``'joint_score'`` returns the forward output unchanged; every other
        task applies a softmax over the last dimension of the logits. The
        module's train/eval mode and autograd state are left untouched.

        Raises:
            ValueError: If ``task_type`` is not recognised.
        """
        task = self.task_type
        if task == 'joint_score':
            return self.forward(x)
        if task in ('binary', 'joint_joint_binary') or task in self._PER_JOINT_TASKS:
            return F.softmax(self.forward(x), dim=-1)
        raise ValueError(f"Unknown task_type: {self.task_type}")

    def pene_loss(self, feats_rst, feats_initial=None, joints_rst=None, joints_initial=None, jcef_kwargs=None, debug=False):
        """Loss that drives the detector's output towards "no penetration".

        Args:
            feats_rst: Features of shape ``[B, L, C]``; flattened to
                ``[B * L, C]`` before being fed to the network.
            feats_initial: Currently unused.
            joints_rst: Optimised joint trajectories ``[B, L, J, D]``; required
                for the ``'joint_binary_energy'``,
                ``'natural_joint_collision_trajectory'`` and
                ``'joint_joint_binary'`` tasks.
            joints_initial: Reference joint trajectories, same shape as
                ``joints_rst``; required for the same tasks.
            jcef_kwargs: Optional dict of loss hyper-parameters (``sigma0``,
                ``lambda_traj``, ``delta``, ``pene_thresh`` and, for
                ``'joint_binary_energy'``, ``spatial_weight_factor`` /
                ``temporal_weight_factor``). Missing keys fall back to the
                loss functions' defaults; ``None`` means "all defaults".
            debug: Currently unused.

        Returns:
            A scalar loss tensor.

        Raises:
            ValueError: If ``task_type`` is not recognised, or if a task that
                needs joint trajectories is called without them.
        """
        flattened_feats = feats_rst.reshape(-1, feats_rst.shape[-1])
        logits = self.forward(flattened_feats)
        kwargs = {} if jcef_kwargs is None else jcef_kwargs
        task = self.task_type

        if task == 'binary':
            # Every sample is pushed towards class 0.
            target = torch.zeros(logits.size(0), dtype=torch.long, device=logits.device)
            return F.cross_entropy(logits, target)

        if task == 'joint_score':
            # Every score is pushed towards 0.
            return F.mse_loss(logits, torch.zeros_like(logits))

        if task == 'joint_binary':
            # Every (sample, joint) entry is pushed towards class 0.
            flat_logits = logits.reshape(-1, logits.shape[-1])
            target = torch.zeros(flat_logits.shape[0], dtype=torch.long, device=logits.device)
            return F.cross_entropy(flat_logits, target)

        if task not in ('joint_binary_energy', 'natural_joint_collision_trajectory', 'joint_joint_binary'):
            raise ValueError(f"Unknown task_type: {self.task_type}")

        # The remaining tasks all need the joint trajectories.
        if joints_rst is None or joints_initial is None:
            raise ValueError(
                f"task_type '{task}' requires both joints_rst and joints_initial"
            )

        if task == 'joint_binary_energy':
            return enhanced_joint_binary_loss(
                joint_logits=logits,
                x_rst=joints_rst,
                x_initial=joints_initial,
                sigma0=kwargs.get('sigma0', 0.1),
                lambda_traj=kwargs.get('lambda_traj', 5.0),
                delta=kwargs.get('delta', 2),
                pene_thresh=kwargs.get('pene_thresh', 0.2),
                spatial_weight_factor=kwargs.get('spatial_weight_factor', 1.0),
                temporal_weight_factor=kwargs.get('temporal_weight_factor', 1.0),
            )

        if task == 'natural_joint_collision_trajectory':
            # This loss is called with its own defaults; jcef_kwargs is not used.
            return natural_joint_collision_trajectory_loss(
                joint_logits=logits,
                x_rst=joints_rst,
                x_initial=joints_initial,
            )

        # task == 'joint_joint_binary'
        return jcef_plus_loss_batched(
            collision_map=logits,
            x_rst=joints_rst,
            x_initial=joints_initial,
            sigma0=kwargs.get('sigma0', 0.1),
            lambda_traj=kwargs.get('lambda_traj', 5.0),
            delta=kwargs.get('delta', 2),
            pene_thresh=kwargs.get('pene_thresh', 0.2),
        )