'''
Behavioral Topology (BeTop): https://arxiv.org/abs/2409.18031
'''
'''
Pipeline developed upon planTF:
https://arxiv.org/pdf/2309.10443
'''
import torch
import torch.nn as nn
from nuplan.planning.simulation.trajectory.trajectory_sampling import TrajectorySampling
from nuplan.planning.training.modeling.torch_module_wrapper import TorchModuleWrapper
from nuplan.planning.training.preprocessing.target_builders.ego_trajectory_target_builder import (
    EgoTrajectoryTargetBuilder,
)

from src.models.betop.layers.common_layers import build_mlp
from src.models.betop.layers.transformer_encoder_layer import TransformerEncoderLayer

from src.models.betop.modules.agent_encoder import AgentEncoder
from src.models.betop.modules.map_encoder import MapEncoder
from src.models.betop.modules.trajectory_decoder import TrajectoryDecoder, PredTrajectoryDecoder, NewContigencyDecoder, ContigencyDecoder

from src.models.betop.modules.topo_decoder import OccupancyFuser, OccupancyDecoder

# Placeholder sampling object: nuPlan's TorchModuleWrapper and
# EgoTrajectoryTargetBuilder require one (see PlanningModel.__init__), but its
# values carry no meaning for this model.
trajectory_sampling = TrajectorySampling(
    time_horizon=8,
    num_poses=8,
    interval_length=1,
)
from time import time

class SelfTransformer(nn.Module):
    """Post-norm Transformer block: multi-head self-attention followed by an FFN.

    Both sub-layers use a residual connection followed by LayerNorm.

    Args:
        heads: number of attention heads (default 4).
        dim: token feature size; must be divisible by ``heads`` (default 128).
        dropout: dropout rate inside attention and the FFN (default 0.1).
    """

    def __init__(self, heads=4, dim=128, dropout=0.1):
        super(SelfTransformer, self).__init__()
        self.self_attention = nn.MultiheadAttention(dim, heads, dropout, batch_first=True)
        self.norm_1 = nn.LayerNorm(dim)
        self.norm_2 = nn.LayerNorm(dim)
        self.ffn = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim * 4, dim),
            nn.Dropout(dropout),
        )

    def forward(self, inputs, mask=None):
        """Apply the block.

        Args:
            inputs: [batch, tokens, dim] tensor (batch_first).
            mask: optional [batch, tokens] key_padding_mask; True entries are
                ignored as attention keys.

        Returns:
            Tensor with the same shape as ``inputs``.
        """
        # The attention weights are never used, so skip computing them.
        attention_output, _ = self.self_attention(
            inputs, inputs, inputs, key_padding_mask=mask, need_weights=False
        )
        attention_output = self.norm_1(attention_output + inputs)
        return self.norm_2(self.ffn(attention_output) + attention_output)

class PreSafeModel(nn.Module):
    """Per-agent safety encoder and predictor.

    Pipeline:
    - Each agent's history sequence ``trai`` is summarised by a 2-layer GRU
      (output at the last time step, 64-d).
    - Its safety features ``safe_data`` are encoded by a 2-layer MLP (64-d).
    - The two encodings are concatenated (``fusion_dim`` = 128) and mixed
      across agents by a 2-layer Transformer encoder.
    - The result is projected to 3 values per agent.
    """

    def __init__(self):
        super(PreSafeModel, self).__init__()
        self.trai_gru = nn.GRU(
            input_size=8,
            hidden_size=64,
            num_layers=2,
            batch_first=True,
            dropout=0.1,
        )
        self.safe_fc = nn.Sequential(
            nn.Linear(3, 64),
            nn.LayerNorm(64),
            nn.GELU(),
            nn.Linear(64, 64),
            nn.LayerNorm(64),
            nn.GELU(),
        )
        # GRU hidden size (64) + safety encoding size (64).
        self.fusion_dim = 128
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=self.fusion_dim,
            nhead=4,
            dim_feedforward=512,
            batch_first=True,
            dropout=0.1,
        )
        self.interaction_transformer = nn.TransformerEncoder(encoder_layer, num_layers=2)

        self.prediction = nn.Sequential(
            nn.Linear(self.fusion_dim, 256),
            nn.LayerNorm(256),
            nn.GELU(),
            nn.Linear(256, 512),
            nn.LayerNorm(512),
            nn.GELU(),
            nn.Linear(512, 256),
            nn.LayerNorm(256),
            nn.GELU(),
            nn.Linear(256, 3),
        )

    def forward(self, trai, safe_data):
        """Encode agents and predict their safety values.

        Args:
            trai: [B, N, T, 8] per-agent history features. May be
                non-contiguous (e.g. a transposed or sliced view).
            safe_data: [B, N, 3] per-agent safety features.

        Returns:
            safe_data_future: [B, N, 3] predicted safety values.
            output_feat: [B, N, 128] fused per-agent features.

        Note: no padding mask is applied, so every one of the N agent slots
        takes part in the interaction transformer.
        """
        B, N, T, F_trai = trai.shape
        _, _, F_safe = safe_data.shape
        # reshape (not view): view fails when the batch/agent axes of the
        # input are not contiguous in memory.
        trai_reshaped = trai.reshape(B * N, T, F_trai)
        gru_out, _ = self.trai_gru(trai_reshaped)
        trai_feat = gru_out[:, -1, :]  # [B*N, 64]
        safe_feat = self.safe_fc(safe_data.reshape(B * N, F_safe))  # [B*N, 64]
        fused_feat = torch.cat([trai_feat, safe_feat], dim=1)  # [B*N, 128]
        fused_feat = fused_feat.reshape(B, N, self.fusion_dim)
        output_feat = self.interaction_transformer(fused_feat)  # [B, N, 128]
        safe_data_future = self.prediction(output_feat)  # [B, N, 3]
        return safe_data_future, output_feat

class RAG_Model(nn.Module):
    """Rank non-ego agent tokens by L2 distance to the ego token (no parameters)."""

    def __init__(self):
        super(RAG_Model, self).__init__()

    def forward(self, encoder, topk):
        """Pick ``topk`` agent indices from ``encoder[:, 1:]``.

        Args:
            encoder: [B, N, D] token features; index 0 along dim 1 is ego.
            topk: number of non-ego agents to return.

        Returns:
            [B, topk] int64 indices into ``encoder[:, 1:]``, ordered by
            decreasing distance to the ego feature.

        NOTE(review): torch.topk defaults to largest=True, so this keeps the
        agents *farthest* from ego in feature space; a nearest-neighbour
        retrieval would pass largest=False — confirm which is intended.
        """
        ego_feat = encoder[:, :1]
        neighbour_feat = encoder[:, 1:]
        dist = (neighbour_feat - ego_feat).norm(p=2, dim=-1)
        return dist.topk(topk, dim=1).indices

def select_top(encoder, mask_top):
    """Keep the ego token plus the agents chosen by ``mask_top``.

    Args:
        encoder: [B, N, D] features; index 0 along dim 1 is ego.
        mask_top: [B, k] int64 indices into ``encoder[:, 1:]``.

    Returns:
        [B, 1 + k, D]: ego first, then the selected agents in ``mask_top`` order.
    """
    feat_dim = encoder.size(-1)
    gather_idx = mask_top[..., None].expand(-1, -1, feat_dim)
    picked = encoder[:, 1:].gather(1, gather_idx)
    return torch.cat((encoder[:, :1], picked), dim=1)

def select_top_2(encoder, mask_top):
    """Column-wise version of ``select_top`` for 4-D pairwise features.

    Args:
        encoder: [B, R, C, D] features; index 0 along dim 2 is ego.
        mask_top: [B, k] int64 indices into ``encoder[:, :, 1:]``.

    Returns:
        [B, R, 1 + k, D]: ego column first, then the selected columns.
    """
    n_rows = encoder.shape[1]
    feat_dim = encoder.shape[-1]
    gather_idx = mask_top[:, None, :, None].expand(-1, n_rows, -1, feat_dim)
    picked = encoder[:, :, 1:].gather(2, gather_idx)
    return torch.cat((encoder[:, :, :1], picked), dim=2)

def select_mask(mask, mask_top):
    """Keep the ego entry plus the entries chosen by ``mask_top`` along dim 1.

    Args:
        mask: [B, N] tensor; index 0 along dim 1 is ego.
        mask_top: [B, k] int64 indices into ``mask[:, 1:]``.

    Returns:
        [B, 1 + k] tensor with the ego entry first.
    """
    picked = mask[:, 1:].gather(1, mask_top)
    return torch.cat((mask[:, :1], picked), dim=1)

class PlanningModel(TorchModuleWrapper):
    """BeTop planning model wrapped for nuPlan.

    Pipeline:
    1. Encode agents (``AgentEncoder``) and map polygons (``MapEncoder``).
    2. Add a per-agent safety encoding from ``PreSafeModel``.
    3. Keep the ego token plus a subset of agents chosen by ``RAG_Model``.
    4. Run the Transformer encoder stack.
    5. Decode the ego plan, other-agent predictions and, optionally, the
       occupancy outputs.
    """

    def __init__(
        self,
        dim=128,
        state_channel=6,
        polygon_channel=6,
        history_channel=9,
        history_steps=21,
        future_steps=80,
        encoder_depth=4,
        drop_path=0.2,
        drop_key=0.3,
        num_heads=8,
        num_modes=6,
        use_ego_history=False,
        state_attn_encoder=True,
        state_dropout=0.75,
        joint_pred=False,
        rel_pred=False,
        feature_builder=None,
        planner=None,
        occ_pred=False,
        conti_plan=False,
        traj_step=5,
        multi_pred=False,
        conti_loss=False,
        marginal_mode=6,
    ) -> None:
        super().__init__(
            feature_builders=[feature_builder],
            target_builders=[EgoTrajectoryTargetBuilder(trajectory_sampling)],
            future_trajectory_sampling=trajectory_sampling,
        )

        self.dim = dim
        self.history_steps = history_steps
        self.future_steps = future_steps
        self.rel_pred = rel_pred
        self.conti_loss = conti_loss

        # Input is the absolute pose (x, y, cos, sin), or, with rel_pred, the
        # 5 pairwise features from build_rel_feature.
        self.pos_emb = build_mlp(5 if self.rel_pred else 4, [dim] * 2)

        self.agent_encoder = AgentEncoder(
            state_channel=state_channel,
            history_channel=history_channel,
            dim=dim,
            hist_steps=history_steps,
            drop_path=drop_path,
            use_ego_history=use_ego_history,
            state_attn_encoder=state_attn_encoder,
            state_dropout=state_dropout,
            perspect_norm=rel_pred
        )

        self.map_encoder = MapEncoder(
            dim=dim,
            polygon_channel=polygon_channel,
            perspect_norm=rel_pred
        )
        self.PreSafeModel = PreSafeModel()
        self.rag_model = RAG_Model()

        # The drop_path rate grows linearly from 0 to `drop_path` across blocks.
        self.encoder_blocks = nn.ModuleList(
            TransformerEncoderLayer(dim=dim, num_heads=num_heads, drop_path=dp, drop_key=drop_key)
            for dp in [x.item() for x in torch.linspace(0, drop_path, encoder_depth)]
        )

        self.norm = nn.LayerNorm(dim)
        if conti_plan:
            self.trajectory_decoder = NewContigencyDecoder(
                embed_dim=dim,
                num_modes=num_modes,
                marginal_mode=marginal_mode,
                future_steps=future_steps,
                out_channels=4,
                top_trajs=6, traj_input=4, traj_step=traj_step, multi_agent=True
            )
        else:
            self.trajectory_decoder = TrajectoryDecoder(
                embed_dim=dim,
                num_modes=num_modes,
                future_steps=future_steps,
                out_channels=4,
            )
        self.joint_pred = joint_pred
        self.occ_pred = occ_pred
        self.conti_plan = conti_plan

        if self.occ_pred:
            self.rel_pos_emb = build_mlp(5, [dim] * 2)
            # One fuser per encoder block; forward indexes them in lockstep.
            self.occ_fuser = nn.ModuleList([OccupancyFuser(dim, 0.1) for _ in range(encoder_depth)])
            self.occ_decoder = OccupancyDecoder(dim, 0.0)

        self.multi_pred = multi_pred
        if self.joint_pred or self.multi_pred:
            self.agent_predictor = PredTrajectoryDecoder(
                embed_dim=dim,
                num_modes=num_modes,
                future_steps=future_steps,
                out_channels=4,
            )
        else:
            self.agent_predictor = build_mlp(dim, [dim * 2, future_steps * 2], norm="ln")
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise one module.

        - Linear: Xavier-uniform weight, zero bias.
        - LayerNorm / BatchNorm1d: unit weight, zero bias.
        - Embedding: N(0, 0.02).
        """
        if isinstance(m, nn.Linear):
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.BatchNorm1d):
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Embedding):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)

    def _cosine(self, v1, v2):
        """Cosine of the angle between the 2-D vectors in ``v1[..., :2]`` and ``v2[..., :2]``.

        The norms use the full last axis; 1e-10 guards against division by zero.
        """
        v1_norm = torch.linalg.norm(v1, dim=-1)
        v2_norm = torch.linalg.norm(v2, dim=-1)
        v1_x, v1_y = v1[..., 0], v1[..., 1]
        v2_x, v2_y = v2[..., 0], v2[..., 1]
        cos = (v1_x * v2_x + v1_y * v2_y) / (v1_norm * v2_norm + 1e-10)
        return cos

    def _sine(self, v1, v2):
        """Signed sine (2-D cross product) of the angle from ``v1`` to ``v2``, normalised."""
        v1_norm = torch.linalg.norm(v1, dim=-1)
        v2_norm = torch.linalg.norm(v2, dim=-1)
        v1_x, v1_y = v1[..., 0], v1[..., 1]
        v2_x, v2_y = v2[..., 0], v2[..., 1]
        sin = (v1_x * v2_y - v1_y * v2_x) / (v1_norm * v2_norm + 1e-10)
        return sin

    def build_rel_feature(self, pos, radius=100.0):
        """Pairwise relative features between all tokens.

        Args:
            pos: [B, L, 4] per-token (x, y, cos(angle), sin(angle)).
            radius: distance normaliser; distances are scaled by 2 / radius.

        Returns:
            [B, L, L, 5] with, on the last axis:
            - cos/sin of the heading difference;
            - cos/sin between the row token's heading and the position offset;
            - the scaled distance.
        """
        position, vec = pos[..., :2], pos[..., -2:]
        pos_diff = position[:, :, None, :] - position[:, None, :, :]
        d_pos = torch.linalg.norm(pos_diff, dim=-1)
        d_pos = d_pos * 2 / radius

        # heading difference between each token pair
        cos_a1 = self._cosine(vec[:, :, None, :2], vec[:, None, :, :2])
        sin_a1 = self._sine(vec[:, :, None, :2], vec[:, None, :, :2])

        # heading of the row token relative to the position offset
        cos_a2 = self._cosine(vec[:, :, None, :2], pos_diff)
        sin_a2 = self._sine(vec[:, :, None, :2], pos_diff)

        return torch.stack([cos_a1, sin_a1, cos_a2, sin_a2, d_pos], dim=-1)

    @staticmethod
    def _select_agents(t, mask_top, num_agents, dim=1):
        """Re-gather the agent tokens of ``t`` along ``dim``, keeping map tokens.

        Along ``dim``:
        - the first ``num_agents`` entries are agent tokens (ego at index 0);
        - the remaining entries (polygons) are kept unchanged after the agents.

        The output holds ego, then agents ``1 + mask_top`` in ``mask_top`` order,
        then the polygons. For dim=1 this matches select_top / select_mask, and
        for dim=2 it matches select_top_2, for any tensor rank.
        """
        ego_idx = torch.zeros_like(mask_top[:, :1])
        idx = torch.cat([ego_idx, mask_top + 1], dim=1)  # [B, k + 1]
        view_shape = [1] * t.dim()
        view_shape[0], view_shape[dim] = idx.shape
        expand_shape = list(t.shape)
        expand_shape[dim] = idx.shape[1]
        idx = idx.view(view_shape).expand(expand_shape)
        agents = torch.gather(t.narrow(dim, 0, num_agents), dim, idx)
        rest = t.narrow(dim, num_agents, t.shape[dim] - num_agents)
        return torch.cat([agents, rest], dim=dim)

    def forward(self, data):
        """Run the planning pipeline on one batch.

        Args:
            data: feature dict. This method reads:
                - ``data["agent"]``: ``position``, ``heading``, ``velocity``,
                  ``shape``, ``valid_mask``;
                - ``data["map"]``: ``polygon_center``, ``valid_mask``;
                - ``data["safety_history"]`` and ``data["safety_label"]``.
                The encoders receive ``data`` too and may read more.

        Returns:
            ``(out, mask_top)``.

            ``out`` always holds ``trajectory``, ``probability``,
            ``prediction``, ``safety``, ``safety_label`` and the mode flags.
            It also holds, when enabled:
            - ``actor_occ`` / ``actor_map_occ`` (occ_pred);
            - ``pred_probability`` (joint_pred or multi_pred);
            - ``joint_plan`` (conti_plan);
            - ``output_trajectory`` / ``full_trajectory`` in eval mode, as
              (x, y, atan2(ch3, ch2)).

            ``mask_top`` is the [B, top_k - 1] tensor of retained neighbour
            indices, counted from agent 1 (ego excluded).
        """
        hist = self.history_steps
        agent_data = data["agent"]
        agent_pos = agent_data["position"][:, :, hist - 1]
        agent_heading = agent_data["heading"][:, :, hist - 1]
        agent_mask = agent_data["valid_mask"][:, :, :hist]
        polygon_center = data["map"]["polygon_center"]
        polygon_mask = data["map"]["valid_mask"]
        safety_history = data["safety_history"]
        safety_label = data["safety_label"]

        # History features for PreSafeModel, concatenated on the last axis. The
        # total width must equal PreSafeModel's GRU input size. The window
        # follows self.history_steps so it tracks the constructor argument.
        agent_history = torch.cat(
            [
                agent_data["position"][:, :, :hist],
                agent_data["heading"][:, :, :hist, None],
                agent_data["velocity"][:, :, :hist],
                agent_data["shape"][:, :, :hist],
                agent_data["valid_mask"][:, :, :hist, None],
            ],
            dim=-1,
        )

        bs, As = agent_pos.shape[0:2]
        # Number of agent tokens kept after retrieval (ego included).
        # NOTE(review): 33 agent slots is a hard-coded special case (ego + 9
        # retrieved agents); any other count keeps every agent, reordered by
        # the retrieval ranking.
        A = top_k = 10 if As == 33 else As

        position = torch.cat([agent_pos, polygon_center[..., :2]], dim=1)
        angle = torch.cat([agent_heading, polygon_center[..., 2]], dim=1)
        pos = torch.cat(
            [position, torch.stack([angle.cos(), angle.sin()], dim=-1)], dim=-1
        )

        if self.rel_pred or self.occ_pred:
            val_agt = agent_mask.any(-1)
            val_map = polygon_mask.any(-1)
            val_f_mask = torch.cat([val_agt, val_map], dim=-1)
            val_rel_mask = val_f_mask[:, :, None] * val_f_mask[:, None, :]
            # Pairwise geometry between all tokens, zeroed where either token is invalid.
            rel_pos = self.build_rel_feature(pos)
            rel_pos = rel_pos.detach() * val_rel_mask[..., None].float()
        if self.rel_pred:
            pos = rel_pos
        pos_embed = self.pos_emb(pos)

        # Mask semantics:
        # - absolute mode: True = token has no valid entry (inverted validity);
        # - rel_pred mode: True = both tokens of the pair are valid.
        if not self.rel_pred:
            agent_key_padding = ~(agent_mask.any(-1))
            polygon_key_padding = ~(polygon_mask.any(-1))
        else:
            agent_key_padding = agent_mask.any(-1)
            polygon_key_padding = polygon_mask.any(-1)

        key_padding_mask = torch.cat([agent_key_padding, polygon_key_padding], dim=-1)
        if self.rel_pred:
            key_padding_mask = key_padding_mask[:, :, None] * key_padding_mask[:, None, :]

        x_agent = self.agent_encoder(data)
        x_polygon = self.map_encoder(data)
        pre_safety, safe_encoder = self.PreSafeModel(agent_history, safety_history)
        x_agent = x_agent + safe_encoder

        # Retrieval: keep ego plus top_k - 1 agents. Every tensor indexed by
        # [agents | polygons] tokens is re-gathered identically so that it
        # stays aligned with x.
        mask_top = self.rag_model(x_agent, top_k - 1)
        x_agent = self._select_agents(x_agent, mask_top, As)
        pos_embed = self._select_agents(pos_embed, mask_top, As, dim=1)
        key_padding_mask = self._select_agents(key_padding_mask, mask_top, As, dim=1)
        if self.rel_pred:
            # Pairwise [B, L, L, ...] tensors: the column axis needs the same selection.
            pos_embed = self._select_agents(pos_embed, mask_top, As, dim=2)
            key_padding_mask = self._select_agents(key_padding_mask, mask_top, As, dim=2)
        if self.occ_pred:
            # Select both the row (source agent) and column (target token) axes
            # so occ_feat[:, i] describes the same token as x[:, i]; the fuser
            # pairs occ_feat rows with x[:, :A].
            # NOTE(review): occupancy targets in the loss must be reordered with
            # mask_top along both axes to match — confirm.
            occ_rel = self._select_agents(rel_pos[:, :As], mask_top, As, dim=1)
            occ_rel = self._select_agents(occ_rel, mask_top, As, dim=2)
            occ_feat = self.rel_pos_emb(occ_rel)

        x = torch.cat([x_agent, x_polygon], dim=1)
        if not self.rel_pred:
            x = x + pos_embed
        else:
            b, length, d = x.shape
            edge = pos_embed

        for i, blk in enumerate(self.encoder_blocks):
            if not self.rel_pred:
                x = blk(x, key_padding_mask=key_padding_mask)
            else:
                x, edge = blk(x, edge_mask=key_padding_mask, edge=edge)
            if self.occ_pred:
                occ_feat = self.occ_fuser[i](src_feat=x[:, :A], tgt_feat=x, prev_occ_feat=occ_feat)

        if self.rel_pred:
            x = x.view(b, length, d)

        x = self.norm(x)

        if self.conti_plan:
            joint_plan, trajectory, probability = self.trajectory_decoder(x[:, 0])
        else:
            trajectory, probability = self.trajectory_decoder(x[:, 0])

        if self.joint_pred or self.multi_pred:
            prediction, pred_probability = self.agent_predictor(x[:, 1:A])
        else:
            prediction = self.agent_predictor(x[:, 1:A]).view(bs, -1, self.future_steps, 2)

        if self.occ_pred:
            actor_o, actor_map_o = self.occ_decoder(occ_feat, A)

        out = {
            "trajectory": trajectory,
            "probability": probability,
            "prediction": prediction,
            "safety": pre_safety,
            "safety_label": safety_label,
            "rel_pred": self.rel_pred,
            "occ_pred": self.occ_pred,
            "conti_plan": self.conti_plan,
            "conti_loss": self.conti_loss,
        }

        if self.occ_pred:
            out["actor_occ"] = actor_o
            out["actor_map_occ"] = actor_map_o

        if self.joint_pred or self.multi_pred:
            out["pred_probability"] = pred_probability

        if self.conti_plan:
            out["joint_plan"] = joint_plan

        if not self.training:
            best_mode = probability.argmax(dim=-1)
            # Build the batch index on the same device as the trajectories.
            batch_idx = torch.arange(bs, device=trajectory.device)
            output_trajectory = trajectory[batch_idx, best_mode]
            angle = torch.atan2(output_trajectory[..., 3], output_trajectory[..., 2])
            out["output_trajectory"] = torch.cat(
                [output_trajectory[..., :2], angle.unsqueeze(-1)], dim=-1
            )
            full_angle = torch.atan2(trajectory[..., 3], trajectory[..., 2])

            out["full_trajectory"] = torch.cat(
                [trajectory[..., :2], full_angle.unsqueeze(-1)], dim=-1
            )
        return out, mask_top