import torch
import torch.nn as nn
import torch.nn.functional as F
import random
import time
from .net_utils import MultiHeadAttentionModule

class TextEncodeBlock(nn.Module):
    """Encode a padded batch of token embeddings into one feature vector per sample.

    Every mode first projects each 768-d token embedding through ``self.feat``
    (Linear -> LeakyReLU -> Linear). It then reduces the valid prefix
    ``[:text_length[i]]`` of each sequence to a single vector:

    * ``'average'``: masked mean over the valid tokens.
    * ``'pool_conv'``: flatten the valid tokens, adaptive-average-pool the
      flat vector down to ``hidden_dim``, then apply ``self.linear``.
    * ``'lstm'`` / ``'lstm_new'``: 2-layer LSTM, keep the output at the last
      valid step, then apply ``self.linear``.
    * ``'lstm_average'`` / ``'lstm_new_average'``: 2-layer LSTM, mean over all
      valid steps, then apply ``self.linear``.

    The ``*_new*`` modes keep every stage at ``params.text_hidden_dim`` and
    return ``(batch, text_hidden_dim)``. All other modes return
    ``(batch, image_encoder.feat_out_channels[-1])``.
    """

    # Modes whose projection / LSTM / output all stay at hidden_dim.
    _NEW_MODES = ('lstm_new', 'lstm_new_average')
    # Modes that run a 2-layer LSTM over the projected tokens.
    _LSTM_MODES = ('lstm', 'lstm_average', 'lstm_new', 'lstm_new_average')
    # LSTM modes that average all steps instead of taking the last one.
    _LSTM_MEAN_MODES = ('lstm_average', 'lstm_new_average')
    _SUPPORTED_MODES = ('average', 'pool_conv') + _LSTM_MODES

    def __init__(self, params, image_encoder):
        super(TextEncodeBlock, self).__init__()
        num_channels = image_encoder.feat_out_channels[-1]
        self.num_text_feat = 768  # width of the incoming per-token embeddings
        self.params = params
        self.hidden_dim = params.text_hidden_dim
        self.text_encode_mode = params.text_encode_mode

        mode = self.text_encode_mode
        if mode not in self._SUPPORTED_MODES:
            # Unknown modes build no layers; forward() raises a ValueError.
            return

        # Output width of the token projection and of the whole block.
        out_dim = self.hidden_dim if mode in self._NEW_MODES else num_channels

        # Submodules are created in the order feat -> lstm -> linear, as in the
        # per-mode branches this replaces. Parameter ordering and seeded
        # initialisation are therefore unchanged.
        self.feat = nn.Sequential(
            nn.Linear(self.num_text_feat, self.hidden_dim),
            nn.LeakyReLU(),
            nn.Linear(self.hidden_dim, out_dim),
        )
        if mode in self._LSTM_MODES:
            self.lstm = nn.LSTM(input_size=out_dim, hidden_size=self.hidden_dim,
                                num_layers=2, batch_first=True)
        if mode != 'average':
            self.linear = nn.Linear(self.hidden_dim, out_dim)

    def forward(self, text_emb, text_length):
        """Encode ``text_emb`` into one vector per sample.

        Args:
            text_emb: padded token embeddings, indexed here as
                (batch, max_len, 768).
            text_length: number of valid leading tokens per sample, shape (batch,).

        Returns:
            Tensor of shape (batch, out_dim); see the class docstring.

        Raises:
            ValueError: if ``text_encode_mode`` is not a supported mode.
        """
        mode = self.text_encode_mode
        if mode not in self._SUPPORTED_MODES:
            raise ValueError(f"Unsupported text_encode_mode: {mode!r}")

        text_emb = self.feat(text_emb)
        if mode == 'average':
            return self._masked_mean(text_emb, text_length)

        # The remaining modes run per sample so each one sees only its valid prefix.
        encoded = [
            self._encode_sequence(text_emb[i:i + 1, :int(text_length[i])])
            for i in range(text_emb.shape[0])
        ]
        return torch.cat(encoded, dim=0)

    @staticmethod
    def _masked_mean(feats, lengths):
        """Mean of ``feats[b, :lengths[b]]`` per sample; zeros when ``lengths[b] == 0``."""
        max_len = feats.size(1)
        mask = torch.arange(max_len, device=feats.device).unsqueeze(0) < lengths.unsqueeze(1)
        mask = mask.unsqueeze(-1).to(feats.dtype)  # (batch, max_len, 1)
        sums = (feats * mask).sum(dim=1)  # padded positions contribute zero
        # Clamp so an empty sequence yields a zero vector instead of 0/0 = NaN.
        counts = lengths.unsqueeze(1).to(feats.dtype).clamp(min=1)
        return sums / counts

    def _encode_sequence(self, seq):
        """Reduce one projected sequence of shape (1, L, D) to a (1, out_dim) vector."""
        if self.text_encode_mode == 'pool_conv':
            # Flatten to (1, L * D) and pool the flat vector down to hidden_dim.
            flat = seq.reshape(1, -1)
            return self.linear(F.adaptive_avg_pool1d(flat, self.hidden_dim))
        out, _ = self.lstm(seq)
        if self.text_encode_mode in self._LSTM_MEAN_MODES:
            return self.linear(out.mean(dim=1))
        return self.linear(out[:, -1, :])  # output at the last valid step


class TextEncoderSepPointEnrich(nn.Module):
    """Region-wise text encoder whose four region features are enriched by radar points.

    Five independent ``TextEncodeBlock``s encode a general description and four
    region descriptions (left, mid-left, mid-right, right). For each sample and
    each region label ``idx`` in 1..4:

    * If ``radar_point_mask[b]`` contains ``idx``, the region's text feature
      attends over the radar-point features carrying that label, and the
      attention output is added residually.
    * Otherwise a zero vector is added.

    The general feature, optionally plus a pooled image feature, feeds a
    3-class MLP classifier.
    """

    def __init__(self, params, image_encoder):
        super(TextEncoderSepPointEnrich, self).__init__()
        num_class = 3
        self.params = params
        self.text_general_block = TextEncodeBlock(params, image_encoder)
        self.text_left_block = TextEncodeBlock(params, image_encoder)
        self.text_mid_left_block = TextEncodeBlock(params, image_encoder)
        self.text_mid_right_block = TextEncodeBlock(params, image_encoder)
        self.text_right_block = TextEncodeBlock(params, image_encoder)

        # Produces unnormalised class scores (no softmax is applied).
        self.classifier = nn.Sequential(
            nn.Linear(params.text_hidden_dim, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 16),
            nn.ReLU(inplace=True),
            nn.Linear(16, num_class),
        )
        self.use_img_feat = params.use_img_feat

        self.text_hidden_dim = params.text_hidden_dim
        self.point_hidden_dim = params.point_hidden_dim
        # One 4-head attention module per region, indexed by (region label - 1).
        self.attention_text = nn.ModuleList([
            MultiHeadAttentionModule(self.text_hidden_dim, self.point_hidden_dim, self.text_hidden_dim, 4)
            for _ in range(4)
        ])

    def forward(self, text_feature_general, text_feature_left, text_feature_mid_left,
                text_feature_mid_right, text_feature_right, text_length, radar_point_feat,
                radar_point_mask, image_last_feat=None):
        """Encode all descriptions, enrich the region features with radar points, and classify.

        Args:
            text_feature_general, text_feature_left, text_feature_mid_left,
            text_feature_mid_right, text_feature_right: token embeddings for each
                description, passed to the matching ``TextEncodeBlock``.
            text_length: valid token counts; column 0 is the general description
                and columns 1..4 are the regions in the order above.
            radar_point_feat: per-sample point features;
                ``radar_point_feat[b][radar_point_mask[b] == idx]`` selects the
                points labelled ``idx`` (exact shape not fixed here; confirm
                against the caller).
            radar_point_mask: per-sample point labels; values 1..4 route points
                to the four regions, and other values are ignored.
            image_last_feat: required when ``params.use_img_feat`` is set; it is
                averaged over dims 2 and 3.

        Returns:
            ``(general_feat, [left, mid_left, mid_right, right], class_pred,
            classification_feat)``.

        Raises:
            ValueError: if ``use_img_feat`` is set but ``image_last_feat`` is None.
        """
        text_feature_general = self.text_general_block(text_feature_general, text_length[:, 0])
        region_feats = [
            self.text_left_block(text_feature_left, text_length[:, 1]),
            self.text_mid_left_block(text_feature_mid_left, text_length[:, 2]),
            self.text_mid_right_block(text_feature_mid_right, text_length[:, 3]),
            self.text_right_block(text_feature_right, text_length[:, 4]),
        ]

        # residuals[r] collects one (1, text_hidden_dim) row per sample for region r.
        residuals = [[] for _ in region_feats]
        for b in range(region_feats[-1].shape[0]):
            present = torch.unique(radar_point_mask[b], sorted=True).int()
            for r, region_feat in enumerate(region_feats):
                idx = r + 1  # region labels are 1..4
                if idx in present:
                    query = region_feat[b:b + 1].unsqueeze(1)
                    points = radar_point_feat[b][radar_point_mask[b] == idx].unsqueeze(0)
                    residuals[r].append(self.attention_text[r](query, points).squeeze(1))
                else:
                    # No points for this region: add a constant zero residual that
                    # matches the text feature's dtype and device.
                    residuals[r].append(region_feat.new_zeros((1, self.text_hidden_dim)))

        region_feats = [feat + torch.cat(res, 0) for feat, res in zip(region_feats, residuals)]

        classification_feat = text_feature_general
        if self.use_img_feat:
            if image_last_feat is None:
                raise ValueError('image_last_feat is required when params.use_img_feat is enabled')
            # Average over dims 2 and 3 (presumably H, W), then pool channels to text_hidden_dim.
            img_feat = torch.mean(image_last_feat, dim=(2, 3))
            img_feat = nn.functional.adaptive_avg_pool1d(img_feat, self.params.text_hidden_dim)
            classification_feat = classification_feat + img_feat
        class_pred = self.classifier(classification_feat)

        return text_feature_general, region_feats, class_pred, classification_feat

            
class TextEncoderSep(nn.Module):
    """Encode a general description and four region descriptions separately.

    Each of the five inputs is reduced by its own ``TextEncodeBlock``. The
    general feature, optionally plus a pooled image feature, is fed to a
    3-class MLP classifier.
    """

    def __init__(self, params, image_encoder):
        super(TextEncoderSep, self).__init__()
        num_class = 3
        self.params = params
        # One independent encoder per description (general + 4 regions).
        self.text_general_block = TextEncodeBlock(params, image_encoder)
        self.text_left_block = TextEncodeBlock(params, image_encoder)
        self.text_mid_left_block = TextEncodeBlock(params, image_encoder)
        self.text_mid_right_block = TextEncodeBlock(params, image_encoder)
        self.text_right_block = TextEncodeBlock(params, image_encoder)

        # Produces unnormalised class scores (no softmax is applied).
        self.classifier = nn.Sequential(
            nn.Linear(params.text_hidden_dim, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 16),
            nn.ReLU(inplace=True),
            nn.Linear(16, num_class),
        )
        self.use_img_feat = params.use_img_feat

    def forward(self, text_feature_general, text_feature_left, text_feature_mid_left,
                text_feature_mid_right, text_feature_right, text_length, image_last_feat=None):
        """Encode the five descriptions and classify from the general one.

        Args:
            text_feature_*: token embeddings for each description, passed to the
                matching ``TextEncodeBlock``.
            text_length: valid token counts; column 0 is the general description
                and columns 1..4 are left, mid-left, mid-right and right.
            image_last_feat: required when ``params.use_img_feat`` is set; it is
                averaged over dims 2 and 3.

        Returns:
            ``(general_feat, [left, mid_left, mid_right, right], class_pred,
            classification_feat)``.

        Raises:
            ValueError: if ``use_img_feat`` is set but ``image_last_feat`` is None.
        """
        text_feature_general = self.text_general_block(text_feature_general, text_length[:, 0])
        text_feature_left = self.text_left_block(text_feature_left, text_length[:, 1])
        text_feature_mid_left = self.text_mid_left_block(text_feature_mid_left, text_length[:, 2])
        text_feature_mid_right = self.text_mid_right_block(text_feature_mid_right, text_length[:, 3])
        text_feature_right = self.text_right_block(text_feature_right, text_length[:, 4])

        classification_feat = text_feature_general
        if self.use_img_feat:
            if image_last_feat is None:
                raise ValueError('image_last_feat is required when params.use_img_feat is enabled')
            # Average over dims 2 and 3 (presumably H, W), then pool channels to text_hidden_dim.
            img_feat = torch.mean(image_last_feat, dim=(2, 3))
            img_feat = nn.functional.adaptive_avg_pool1d(img_feat, self.params.text_hidden_dim)
            classification_feat = classification_feat + img_feat
        class_pred = self.classifier(classification_feat)

        return text_feature_general, [text_feature_left, text_feature_mid_left, text_feature_mid_right, text_feature_right], class_pred, classification_feat



class TextEncoder(nn.Module):
    """Project token embeddings, mean-pool the valid tokens and classify.

    ``self.feat`` maps each 768-d token embedding to
    ``image_encoder.feat_out_channels[-1]`` channels. The masked mean over each
    sample's first ``text_length`` tokens then feeds a 3-class MLP.
    """

    def __init__(self, params, image_encoder):
        super(TextEncoder, self).__init__()
        num_channels = image_encoder.feat_out_channels[-1]
        self.num_text_feat = 768  # width of the incoming per-token embeddings
        num_class = 3
        self.params = params
        self.feat = nn.Sequential(
            nn.Linear(self.num_text_feat, 512),
            nn.LeakyReLU(),
            nn.Linear(512, num_channels)
        )

        # Produces unnormalised class scores (no softmax is applied).
        self.classifier = nn.Sequential(
            nn.Linear(num_channels, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 16),
            nn.ReLU(inplace=True),
            nn.Linear(16, num_class),
        )

    def forward(self, text_emb, text_length):
        """Encode a padded batch of token embeddings.

        Args:
            text_emb: padded token embeddings, indexed here as (batch, max_len, 768).
            text_length: number of valid leading tokens per sample, shape (batch,).

        Returns:
            ``(projected, pooled, class_pred)`` with shapes
            (batch, max_len, num_channels), (batch, num_channels) and (batch, 3).
            A sample with ``text_length == 0`` pools to a zero vector.
        """
        text_emb = self.feat(text_emb)  # (batch, max_len, num_channels)

        max_len = text_emb.size(1)
        mask = torch.arange(max_len, device=text_emb.device).unsqueeze(0) < text_length.unsqueeze(1)
        mask = mask.unsqueeze(-1).to(text_emb.dtype)  # (batch, max_len, 1)

        # Padded positions are zeroed, so the sum covers only valid tokens.
        sums = (text_emb * mask).sum(dim=1)  # (batch, num_channels)

        # Clamp so an empty sequence yields zeros instead of 0/0 = NaN.
        counts = text_length.unsqueeze(1).to(text_emb.dtype).clamp(min=1)
        text_emb_avg = sums / counts

        class_pred = self.classifier(text_emb_avg)
        return text_emb, text_emb_avg, class_pred

class TextEncoderRegion(nn.Module):
    """Project one pooled embedding per description and classify from the general one.

    Slot 0 of ``text_emb`` is the general description and slots 1..4 are the
    left, middle-left, middle-right and right regions. Each slot has its own
    768 -> 512 -> num_channels MLP, and the general feature feeds a 3-class MLP.
    """

    def __init__(self, params, image_encoder):
        super(TextEncoderRegion, self).__init__()
        num_channels = image_encoder.feat_out_channels[-1]
        self.num_text_feat = 768
        num_class = 3
        self.params = params

        def make_projection():
            """Build one 768 -> 512 -> num_channels projection head."""
            return nn.Sequential(
                nn.Linear(self.num_text_feat, 512),
                nn.LeakyReLU(),
                nn.Linear(512, num_channels)
            )

        # Registered in this exact order so parameter order and seeded init match.
        for attr_name in ('feat1', 'feat2', 'feat3', 'feat4', 'feat_general'):
            setattr(self, attr_name, make_projection())

        # Produces unnormalised class scores (no softmax is applied).
        self.classifier = nn.Sequential(
            nn.Linear(num_channels, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 16),
            nn.ReLU(inplace=True),
            nn.Linear(16, num_class),
        )

    def forward(self, text_emb):
        """Project each description slot and classify the general one.

        Args:
            text_emb: indexed as ``text_emb[:, k]`` for k in 0..4; each slice is
                fed to a Linear expecting 768 input features.

        Returns:
            ``(general_feat, [left, middle_left, middle_right, right], class_pred)``.
        """
        general_feat = self.feat_general(text_emb[:, 0])
        region_heads = (self.feat1, self.feat2, self.feat3, self.feat4)
        region_feats = [head(text_emb[:, slot]) for slot, head in enumerate(region_heads, start=1)]
        return general_feat, region_feats, self.classifier(general_feat)