import torch
import torch.nn as nn
import torch.nn.functional as F

def occ_inference_loss(self, ego_occ, observed_occ, best_mode, valid_mask):
    """Fraction of observed occupancy that the ego occupancy prediction covers.

    The loss is ``sum(sigmoid(ego_occ) * observed_occ) / sum(observed_occ)``
    evaluated over the entries selected by ``best_mode`` and by the validity
    mask sampled at steps 9, 19, ..., 79.

    Args:
        self: Unused.
        ego_occ: Ego occupancy logits (a sigmoid is applied here).
            Assumed to be (batch, 8, H, W) -- TODO confirm with caller.
        observed_occ: Observed occupancy with the same layout as ``ego_occ``.
        best_mode: Mask converted with ``.bool()`` and applied after
            ``unsqueeze(1)``; presumably (batch, 1) -- verify against caller.
        valid_mask: Boolean validity mask indexed as ``[:, 0, 9:80:10]``.

    Returns:
        Scalar tensor. It is 0 when nothing is selected, and also 0 when the
        selected observed occupancy sums to zero (previously this was NaN
        because of a 0/0 division).
    """
    # Index 0 along dim 1 (presumably the ego agent), every 10th step.
    ego_valid_mask = valid_mask[:, 0, 9:80:10]
    mode_mask = best_mode.bool()

    ego_prob = ego_occ.unsqueeze(1)[mode_mask][ego_valid_mask].sigmoid()
    observed = observed_occ.unsqueeze(1)[mode_mask][ego_valid_mask]

    if ego_prob.shape[0] == 0:
        return torch.tensor(0.0, device=ego_prob.device)

    overlap = (ego_prob * observed).sum()
    total_observed = observed.sum()
    # Guard against an empty observed grid: 0 / 0 would poison training with
    # NaN. Replacing only an exactly-zero denominator keeps every other case
    # numerically identical and keeps the autograd graph intact.
    safe_total = torch.where(
        total_observed == 0, torch.ones_like(total_observed), total_observed
    )
    return overlap / safe_total
    

def traj_occ_alignment_loss(self, traj, valid_mask, occ, best_mode, horizon, prob_threshold, resolution=0.625):
    """Hinge loss pushing occupancy at the ego waypoints above a threshold.

    Waypoints are taken from the best-mode trajectory at steps
    9, 19, ... (< ``horizon``), mapped to normalised grid coordinates, and the
    occupancy logits are bilinearly sampled there. Each sample contributes
    ``max(prob_threshold - sigmoid(logit), 0)``; the sum is divided by the
    number of strictly positive contributions (plus 1e-6).

    Args:
        self: Unused.
        traj: Trajectories indexed by ``best_mode.bool()``; the first two
            features of the last dim are used as x, y.
        valid_mask: Boolean mask sliced as ``[:, 9:horizon:10]``.
        occ: Occupancy logits of shape (batch, T, height, width).
        best_mode: Mode selector, converted with ``.bool()``.
        horizon: Integer number of future steps to consider.
        prob_threshold: Target minimum occupancy probability.
        resolution: Grid cell size in trajectory units (presumably metres
            per cell -- TODO confirm).

    Returns:
        Scalar loss tensor; 0 when no waypoint is both valid and in-grid.
    """
    ego_pos = traj[best_mode.bool()][:, 9:horizon:10, :2]
    ego_valid_mask = valid_mask[:, 9:horizon:10]

    _, _, height, width = occ.shape
    # Normalise to [0, 1): the origin sits at 25% of the width and 50% of the
    # height, with the y axis flipped.
    xs = ego_pos[:, :, 0] / resolution / width + 0.25
    ys = (-ego_pos[:, :, 1]) / resolution / height + 0.5
    in_grid = (xs >= 0) & (xs < 1) & (ys >= 0) & (ys < 1)
    coords = torch.stack([xs, ys], dim=-1)

    selected = ego_valid_mask & in_grid
    if not selected.any():
        # Nothing to sample; avoid handing an empty batch to the sampler.
        return occ.new_zeros(())

    # Integer division is required: ``horizon / 10`` is a float in Python 3
    # and cannot be used as a slice bound. ``horizon // 10`` equals the number
    # of steps produced by ``9:horizon:10``.
    num_steps = horizon // 10
    traj_prob = bilinear_interpolation_batch(
        occ[:, :num_steps][selected], coords[selected]
    )

    # relu reproduces "zero out non-positive costs, then take |cost|".
    penalty = F.relu(prob_threshold - traj_prob.sigmoid())
    return penalty.sum() / ((penalty > 0).sum() + 1e-6)

def traj_occ_collision_loss(self, traj, valid_mask, data, occ, marginal_occ, interactive_mask, best_mode, horizon, prob_threshold, resolution=0.625, dis_threshold=2):
    """Penalise ego body points that come close to occupied grid cells.

    Three points are placed along the heading direction from each best-mode
    waypoint (steps 9, 19, ... < ``horizon``). For every such point, each grid
    cell whose occupancy probability exceeds ``prob_threshold`` contributes
    ``max(dis_threshold - distance, 0)``, with the distance measured in grid
    cells. Per body point the sum is divided by the number of strictly
    positive contributions (plus 1e-6), and the three results are added.

    Occupancy is only thresholded, so gradients flow through ``traj`` (and
    ``data``) only; the occupancy tensors are detached.

    Args:
        self: Unused.
        traj: Trajectories indexed by ``best_mode.bool()``; the first two
            features of the last dim are used as x, y.
        valid_mask: Boolean mask sliced as ``[:, 9:horizon:10]``.
        data: Tensor sliced as ``[:, 9:horizon:10]`` and used as the heading
            angle (its cos/sin give the body axis direction).
        occ: Joint occupancy logits of shape (batch, T, height, width).
        marginal_occ: Per-agent occupancy logits; the first two channels of
            the entries selected by ``interactive_mask`` are merged in.
            Presumably (num_interactive, T, height, width) -- TODO confirm.
        interactive_mask: Boolean (batch, num_agents) mask.
        best_mode: Mode selector, converted with ``.bool()``.
        horizon: Integer number of future steps to consider.
        prob_threshold: Probability above which a cell counts as occupied.
        resolution: Grid cell size in trajectory units.
        dis_threshold: Distance (in cells) below which a penalty applies.

    Returns:
        Scalar loss tensor.
    """
    ego_pos_0 = traj[best_mode.bool()][:, 9:horizon:10, :2]
    ego_target_heading = data[:, 9:horizon:10]
    ego_valid_mask = valid_mask[:, 9:horizon:10]
    # ``horizon / 10`` is a float in Python 3 and is rejected as a slice
    # bound; floor division matches the number of steps in ``9:horizon:10``.
    num_steps = horizon // 10

    _, _, height, width = occ.shape

    bs_joint_occ = occ.detach().sigmoid()
    # contingency planning: merge the marginal occupancy of interactive
    # agents into the first two time steps.
    if interactive_mask.any():
        bs, num_agents = interactive_mask.shape
        agents_occ = torch.zeros(
            bs, num_agents, 2, height, width, device=marginal_occ.device
        )
        agents_occ[interactive_mask] += marginal_occ[:, :2].detach().sigmoid()
        bs_margin_occ, _ = torch.max(agents_occ, dim=1)
        bs_joint_occ[:, :2] = torch.max(bs_joint_occ[:, :2], bs_margin_occ)

    occupied = bs_joint_occ[:, :num_steps] > prob_threshold

    # vehicle points: longitudinal offsets along the heading direction
    offsets = (-0.2643, 1.4610, 3.1863)
    direction = torch.stack(
        [ego_target_heading.cos(), ego_target_heading.sin()], dim=-1
    )

    def point_loss(ego_pos):
        # Grid (pixel) coordinates: origin at 25% of the width and 50% of the
        # height (32 / 64 for a 128 x 128 grid), y axis flipped.
        xs = ego_pos[..., 0] / resolution + 0.25 * width
        ys = (-ego_pos[..., 1]) / resolution + 0.5 * height
        in_grid = (xs >= 0) & (xs < width) & (ys >= 0) & (ys < height)
        selected = ego_valid_mask & in_grid

        coords = torch.stack([xs, ys], dim=-1)[selected]
        # Rows are (point index, grid row, grid col) of every occupied cell
        # in the grid belonging to each selected point.
        hits = torch.nonzero(occupied[selected])
        cell_xy = hits[:, [2, 1]]
        dist = torch.norm(coords[hits[:, 0]] - cell_xy, dim=-1)
        penalty = F.relu(dis_threshold - dist)
        return penalty.sum() / ((penalty > 0).sum() + 1e-6)

    total = point_loss(ego_pos_0 + offsets[0] * direction)
    for offset in offsets[1:]:
        total = total + point_loss(ego_pos_0 + offset * direction)
    return total


def bilinear_interpolation_batch(values, coords):
    """Bilinearly sample one value from each 2-D grid at one point.

    Args:
        values: Batch of grids; a channel dim is inserted at position 1
            before sampling, so this is expected to be (N, H, W).
        coords: (N, 2) sample locations as (x, y) in [0, 1]. They are mapped
            to [-1, 1] and sampled with ``align_corners=True``, so 0 and 1 hit
            the first and last pixel centres.

    Returns:
        Tensor of shape (N,) with the sampled values. The batch dimension is
        kept even when N == 1 (a bare ``squeeze()`` would return a 0-d tensor
        in that case).
    """
    if values.shape[0] == 0:
        # Nothing to sample; skip the sampler for an empty batch.
        return values.new_zeros(0)

    grids = values.unsqueeze(1)  # (N, 1, H, W)

    # use grid_sample: one (1 x 1) sampling location per batch element
    sample_points = (2 * coords - 1).unsqueeze(1).unsqueeze(1)  # (N, 1, 1, 2)
    output = F.grid_sample(
        grids, sample_points, mode='bilinear', align_corners=True
    )

    # (N, 1, 1, 1) -> (N,)
    return output.reshape(-1)