import torch
from torch import nn

from opencood.models.sub_modules.base_transformer import PreNorm, FeedForward
from opencood.models.sub_modules.a_att_module import Agent_wise_Attention
from opencood.models.sub_modules.s_att_module import Spatial_wise_Attention
from opencood.models.sub_modules.s_conv_module import Spacial_wise_Conv

from opencood.models.sub_modules.torch_transformation_utils import \
    get_transformation_matrix, warp_affine, get_roi_and_cav_mask, \
    get_discretized_transformation_matrix
    
class Mlp(nn.Module):
    """Two-layer perceptron: linear -> activation -> dropout -> linear -> dropout.

    Parameters
    ----------
    in_features : int
        Size of the last dimension of the input.
    hidden_features : int, optional
        Width of the hidden layer; ``in_features`` is used when this is falsy.
    out_features : int, optional
        Size of the last dimension of the output; ``in_features`` is used
        when this is falsy.
    act_layer : callable
        Zero-argument factory that builds the activation module.
    drop : float
        Probability handed to the single ``nn.Dropout`` instance that is
        applied after the activation and again after the second linear layer.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU, drop=0.):
        super().__init__()
        out_dim = out_features if out_features else in_features
        hidden_dim = hidden_features if hidden_features else in_features

        # Attribute names and creation order are kept stable so that
        # state_dict keys and parameter initialisation order do not change.
        self.fc1 = nn.Linear(in_features, hidden_dim)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_dim, out_dim)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """Apply the perceptron to the last dimension of ``x``."""
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))

class STTF(nn.Module):
    """Warp every non-first entry along dim 1 with its correction matrix.

    Parameters
    ----------
    args : dict
        Must provide ``'voxel_size'`` (only element 0 is read) and
        ``'downsample_rate'``.
    """

    def __init__(self, args):
        super(STTF, self).__init__()
        self.discrete_ratio = args['voxel_size'][0]
        self.downsample_rate = args['downsample_rate']

    def forward(self, x, mask, spatial_correction_matrix):
        """Return ``x`` with entries 1.. along dim 1 affine-warped.

        ``x`` must be 5-D; its last axis is moved to position 2 for the warp
        and moved back before returning, so the output layout matches the
        input layout. ``mask`` is accepted for interface compatibility and is
        not read here.
        NOTE(review): presumably x is (B, L, H, W, C) with index 0 along L
        being the ego agent -- confirm against the caller.
        """
        # Move the last axis next to the agent axis before warping.
        feats = x.permute(0, 1, 4, 2, 3)
        batch, _, channels, height, width = feats.shape
        out_size = (height, width)

        discretized = get_discretized_transformation_matrix(
            spatial_correction_matrix, self.discrete_ratio,
            self.downsample_rate)

        # Only entries after the first one are compensated; the first is
        # passed through untouched.
        affine = get_transformation_matrix(
            discretized[:, 1:, :, :].reshape(-1, 2, 3), out_size)
        others = feats[:, 1:].reshape(-1, channels, height, width)
        warped = warp_affine(others, affine, out_size)
        warped = warped.reshape(batch, -1, channels, height, width)

        merged = torch.cat([feats[:, :1], warped], dim=1)
        # Restore the original axis order.
        return merged.permute(0, 1, 3, 4, 2)

class V2XFusionBlock(nn.Module):
    """Stack of three-branch fusion layers sharing one input projection.

    Each layer holds an ``Agent_wise_Attention``, a ``Spatial_wise_Attention``
    and a ``Spacial_wise_Conv``. ``forward`` concatenates the three branch
    outputs with the projected input along the last dimension.

    Parameters
    ----------
    num_blocks : int
        Number of three-branch layers to build.
    outter_dim : int
        Input width of the shared linear projection ``ccl``.
    cav_att_config : dict
        Reads ``'dim'``, ``'heads'``, ``'dim_head'``, ``'dropout'`` and the
        optional ``'embed_config'`` (``None`` when absent).
    neighborhood_att_config : dict
        Reads ``'dim'``, ``'heads'``, ``'kernel'``, ``'dilations'``,
        ``'mlp_ratio'``, ``'embed_type'`` and ``'drop_path'``.
    spatial_conv_config : dict
        Reads ``'dim'``, ``'kernel'`` and ``'res_flag'``.
    feature_fusion : str
        Stored on the instance; not consulted inside this class.
    ablation : optional
        Stored on the instance; not consulted inside this class.
    """

    def __init__(self, num_blocks, outter_dim,
                 cav_att_config, neighborhood_att_config, spatial_conv_config,
                 feature_fusion='concat', ablation=None):
        super().__init__()
        self.layers = nn.ModuleList([])
        self.num_blocks = num_blocks
        self.feature_fusion = feature_fusion
        self.ablation = ablation

        agent_dim = cav_att_config['dim']
        # One projection shared by every layer: outter_dim -> agent_dim.
        self.ccl = nn.Linear(outter_dim, agent_dim)

        for _ in range(num_blocks):
            # List elements are evaluated left to right, so the branches are
            # created agent-wise, then spatial-wise, then convolutional.
            branches = [
                Agent_wise_Attention(
                    agent_dim,
                    heads=cav_att_config['heads'],
                    dim_head=cav_att_config['dim_head'],
                    dropout=cav_att_config['dropout'],
                    embed_config=cav_att_config.get('embed_config')),
                Spatial_wise_Attention(
                    dim=neighborhood_att_config['dim'],
                    num_heads=neighborhood_att_config['heads'],
                    kernel=neighborhood_att_config['kernel'],
                    dilations=neighborhood_att_config['dilations'],
                    mlp_ratio=neighborhood_att_config['mlp_ratio'],
                    embed_type=neighborhood_att_config['embed_type'],
                    drop_path=neighborhood_att_config['drop_path']),
                Spacial_wise_Conv(
                    spatial_conv_config['dim'],
                    kernel=spatial_conv_config['kernel'],
                    res_flag=spatial_conv_config['res_flag']),
            ]
            self.layers.append(nn.ModuleList(branches))

    def forward(self, x, init_flag, mask, infra=None, hrpe_comp=None):
        """Zero masked entries, then run every three-branch layer in turn.

        ``mask`` must be 5-D; it is permuted with ``(0, 4, 1, 2, 3)`` and
        positions where the permuted mask equals 0 are zeroed in ``x``. The
        un-permuted ``mask`` is what gets forwarded to the agent-wise
        attention.
        NOTE(review): presumably x is (B, L, H, W, C) and mask is
        (B, 1, 1, 1, L) or (B, H, W, 1, L) -- confirm against the caller.
        """
        agent_first_mask = mask.permute(0, 4, 1, 2, 3)
        fused = x.masked_fill(agent_first_mask == 0, 0)

        for agent_att, spatial_att, spatial_conv in self.layers:
            reduced = self.ccl(fused)

            # The two attention branches are residual; the conv branch is
            # used as returned.
            agent_branch = agent_att(reduced, mask=mask, infra=infra,
                                     init_flag=init_flag,
                                     hrpe_comp=hrpe_comp) + reduced
            spatial_branch = spatial_att(reduced) + reduced
            conv_branch = spatial_conv(reduced)

            fused = torch.cat(
                (agent_branch, spatial_branch, conv_branch, reduced), dim=-1)

        return fused


class V2XTEncoder(nn.Module):
    """Encoder: STTF warp followed by ``depth`` (fusion block, feed-forward) pairs.

    Parameters
    ----------
    args : dict
        Required keys: ``'cav_att_config'``, ``'neighborhood_att_config'``,
        ``'spatial_conv_config'``, ``'feed_forward'`` (with ``'mlp_dim'`` and
        ``'dropout'``), ``'num_blocks'``, ``'depth'``, ``'outter_dim'``,
        ``'feature_fusion'`` and ``'sttf'`` (with ``'downsample_rate'`` and
        ``'voxel_size'``). Optional keys: ``'ablation'`` (default ``None``)
        and ``'use_roi_mask'`` (default ``False``).
    """

    def __init__(self, args):
        super().__init__()

        agent_cfg = args['cav_att_config']
        neighborhood_cfg = args['neighborhood_att_config']
        conv_cfg = args['spatial_conv_config']
        ff_cfg = args['feed_forward']

        num_blocks = args['num_blocks']
        depth = args['depth']
        outter_dim = args['outter_dim']
        feature_fusion = args['feature_fusion']
        mlp_dim = ff_cfg['mlp_dim']
        dropout = ff_cfg['dropout']
        ablation = args.get('ablation')

        sttf_cfg = args['sttf']
        self.downsample_rate = sttf_cfg['downsample_rate']
        self.discrete_ratio = sttf_cfg['voxel_size'][0]
        self.use_roi_mask = args.get('use_roi_mask', False)

        self.sttf = STTF(sttf_cfg)
        # Linear layer mapping dim + 3 -> dim. It is registered here but is
        # not called by forward() in this class.
        self.prior_feed = nn.Linear(agent_cfg['dim'] + 3, agent_cfg['dim'])

        self.layers = nn.ModuleList([])
        for _ in range(depth):
            fusion = V2XFusionBlock(num_blocks, outter_dim,
                                    agent_cfg, neighborhood_cfg, conv_cfg,
                                    feature_fusion, ablation)
            feed_forward = PreNorm(
                outter_dim,
                FeedForward(outter_dim, mlp_dim, dropout=dropout))
            self.layers.append(nn.ModuleList([fusion, feed_forward]))

    def forward(self, x, mask, spatial_correction_matrix):
        """Split off the trailing side channels, warp, then fuse.

        The last five entries of the final axis of ``x`` are side
        information (the original comment lists them as velocity,
        time_delay, infra, dist, relative_angle); everything before them is
        the feature tensor that gets warped and fused.
        NOTE(review): presumably x is (B, L, H, W, C + 5) and mask is (B, L)
        -- confirm against the caller.
        """
        side_info = x[..., -5:]
        # The side information is read at spatial location (0, 0) only:
        # entry 2 -> infra, entries 3 and 4 -> hrpe_comp.
        infra = side_info[:, :, 0, 0, 2]
        hrpe_comp = side_info[:, :, 0, 0, 3:]

        feats = self.sttf(x[..., :-5], mask, spatial_correction_matrix)

        if self.use_roi_mask:
            com_mask = get_roi_and_cav_mask(feats.shape,
                                            mask,
                                            spatial_correction_matrix,
                                            self.discrete_ratio,
                                            self.downsample_rate)
        else:
            # Insert three singleton axes after the first axis of the mask.
            com_mask = mask.unsqueeze(1).unsqueeze(2).unsqueeze(3)

        for index, (attn, ff) in enumerate(self.layers):
            # init_flag is True for the first pair only.
            feats = attn(feats, index == 0, mask=com_mask, infra=infra,
                         hrpe_comp=hrpe_comp)
            feats = ff(feats) + feats
        return feats


class V2XTransformer(nn.Module):
    """Wrapper that runs a ``V2XTEncoder`` and keeps index 0 along dim 1.

    Parameters
    ----------
    args : dict
        Must provide ``'encoder'``, the configuration handed to
        ``V2XTEncoder``.
    """

    def __init__(self, args):
        super(V2XTransformer, self).__init__()
        self.encoder = V2XTEncoder(args['encoder'])

    def forward(self, x, mask, spatial_correction_matrix):
        """Encode the inputs and return the slice at index 0 of dim 1.

        NOTE(review): index 0 along dim 1 is presumably the ego agent --
        confirm against the data pipeline.
        """
        encoded = self.encoder(x, mask, spatial_correction_matrix)
        return encoded[:, 0]
