import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import random
from ..attention import EncoderLayer
from ..utils import CrossNet
from .position import PositionalEmbedding, AbsPositionalEmbedding

class CloudContextFeatureConstructor(nn.Module):
    """Embed the discrete cloud-side context features and pass the rest through.

    ``forward`` indexes ``input_data`` as ``[:, :, column]``. Columns 0-20 are
    treated as integer category ids and looked up in the embedding tables
    listed in ``_DISCRETE_COLUMNS``. Several columns share one table; for
    example, the current, next and previous main type all use ``main_type``.
    """

    # Embedding table used for each discrete input column, in column order.
    # Entry i embeds ``input_data[:, :, i]``; the resulting vectors are
    # concatenated in this same order.
    _DISCRETE_COLUMNS = (
        "main_type",         # 0  main type
        "assis_type",        # 1  assis type
        "slop_type",         # 2  slope type
        "main_type",         # 3  next main type
        "assis_type",        # 4  next assis type
        "main_type",         # 5  previous main type
        "scene_type",        # 6  scene type
        "last_rc",           # 7  last road class
        "last_fw",           # 8  last form way
        "last_rc",           # 9  first road class
        "last_fw",           # 10 first form way
        "seg_len_discrete",  # 11 segment length
        "seg_len_discrete",  # 12 next segment length
        "back_lane_cnt",     # 13
        "front_lane_cnt",    # 14
        "lane_type",         # 15 lane type
        "nature_type",       # 16
        "lane_type",         # 17 next lane type
        "light_cnt",         # 18
        "non_navi_cnt",      # 19
        "last_lane_cnt",     # 20
    )

    # Embedding tables to create, in registration order. Keeping this order
    # keeps parameter / state_dict ordering identical to the historical layout.
    _EMBEDDING_NAMES = (
        "main_type", "assis_type", "slop_type", "scene_type", "last_fw",
        "last_rc", "seg_len_discrete", "back_lane_cnt", "front_lane_cnt",
        "lane_type", "nature_type", "light_cnt", "non_navi_cnt",
        "last_lane_cnt",
    )

    def __init__(self, args):
        super().__init__()

        self.element_cnt = args['action_element_cnt'] + args['info_element_cnt']
        self.action_cnt = args['action_element_cnt']
        self.info_cnt = args['info_element_cnt']
        self.cloud_context_feature = args["cloud_context_feature"]

        vocab_sizes = args["cloud_discrete_feature"]["cloud_context_feature"]
        # NOTE(review): forward() starts the continuous slice at len(vocab_sizes),
        # but it consumes columns 0-20 as discrete ids. Unless this dict has at
        # least 21 entries, some discrete columns are also returned as
        # "continuous" — confirm this is intended.
        self.cloud_context_feature_discrete_cnt = len(vocab_sizes)

        emb_dim = args['mlp_emb_dim']
        for name in self._EMBEDDING_NAMES:
            # setattr on an nn.Module registers the embedding as a submodule.
            setattr(self, name,
                    nn.Embedding(vocab_sizes[name], emb_dim).to(args["device"]))

    def forward(self, input_data):
        """Return the concatenated discrete embeddings and the remaining columns.

        Args:
            input_data: tensor indexed as ``[:, :, column]``; columns 0-20 hold
                category ids (cast with ``.long()``).

        Returns:
            dict with ``"cloud_context_feature_discrete"`` (21 embeddings
            concatenated on the last axis) and ``"cloud_context_feature_continue"``
            (columns from ``cloud_context_feature_discrete_cnt`` onwards, as-is).
        """
        cloud_context_feature_discrete = torch.cat(
            [getattr(self, name)(input_data[:, :, col].long())
             for col, name in enumerate(self._DISCRETE_COLUMNS)],
            -1)

        cloud_context_feature_continue = input_data[:, :, self.cloud_context_feature_discrete_cnt:]

        return {
            "cloud_context_feature_discrete": cloud_context_feature_discrete,
            "cloud_context_feature_continue": cloud_context_feature_continue
        }


class CloudActionFeatureConstructor(nn.Module):
    """Embed per-element discrete cloud action features and flatten the result.

    ``forward`` indexes ``input_data`` as ``[:, :, element, feature]``.
    Feature 0 is the action type id and feature 1 the action sub-type id.
    Features from ``cloud_action_feature_discrete_cnt`` onwards are passed
    through unchanged.
    """

    def __init__(self, args):
        super().__init__()

        self.element_cnt = args['action_element_cnt'] + args['info_element_cnt']
        self.action_cnt = args['action_element_cnt']
        self.info_cnt = args['info_element_cnt']

        self.cloud_action_feature_discrete_cnt = len(args["cloud_discrete_feature"]["cloud_action_feature"])
        # Device used at construction time only; forward() follows the tensors'
        # actual device so that a later module.to(...) keeps working.
        self.device = args["device"]
        emb_dim = args['mlp_emb_dim']

        # cloud_action_feature
        self.action_type = nn.Embedding(args["cloud_discrete_feature"]["cloud_action_feature"]["action_type"],
                                        emb_dim).to(args["device"])

        # One sub-type table per element slot; entry i sizes the table for slot i.
        action_sub_type_list = args["cloud_discrete_feature"]["cloud_action_feature"]["action_sub_type"]
        self.action_sub_type_embs = nn.ModuleList([
            nn.Embedding(num_embeddings=value, embedding_dim=emb_dim).to(self.device)
            for value in action_sub_type_list
        ])

    def forward(self, input_data):
        """Return flattened discrete embeddings and continuous features.

        Returns:
            dict with ``"cloud_action_discrete"`` of shape
            ``(batch, seq, element * 2 * emb_dim)`` and
            ``"cloud_action_feature_continue"`` flattened to ``(batch, seq, -1)``.
        """
        batch_size = input_data.size(0)
        seq_len = input_data.size(1)
        element_cnt = input_data.size(2)

        cloud_action_feature = input_data

        action_type_emb = self.action_type(cloud_action_feature[:, :, :, 0].long())

        action_sub_type_indices = cloud_action_feature[:, :, :, 1].long()

        # Slot ``idx`` uses its own table; slots without a table stay zero.
        # new_zeros inherits device/dtype from the embedding output, not from
        # the (possibly stale) device recorded in __init__.
        action_sub_type_emb = action_type_emb.new_zeros(
            batch_size, seq_len, element_cnt, action_type_emb.size(-1))

        for idx, emb in enumerate(self.action_sub_type_embs):
            action_sub_type_emb[:, :, idx, :] = emb(action_sub_type_indices[:, :, idx])

        action_discrete = torch.cat([action_type_emb, action_sub_type_emb], -1)
        cloud_action_feature_continue = cloud_action_feature[:, :, :, self.cloud_action_feature_discrete_cnt:]

        return {
            "cloud_action_discrete": action_discrete.reshape(batch_size, seq_len, -1),
            "cloud_action_feature_continue": cloud_action_feature_continue.reshape(batch_size, seq_len, -1),
        }


class CloudInfoFeatureConstructor(nn.Module):
    """Embed per-element discrete cloud info features and flatten the result.

    ``forward`` indexes ``input_data`` as ``[:, :, element, feature]``.
    Feature 0 is the info type id and feature 1 the info sub-type id.
    Features from ``cloud_info_feature_discrete_cnt`` onwards are passed
    through unchanged.
    """

    def __init__(self, args):
        super().__init__()

        self.element_cnt = args['action_element_cnt'] + args['info_element_cnt']
        self.action_cnt = args['action_element_cnt']
        self.info_cnt = args['info_element_cnt']

        self.cloud_info_feature = args["cloud_info_feature"]
        self.cloud_info_feature_discrete_cnt = len(args["cloud_discrete_feature"]["cloud_info_feature"])

        emb_dim = args['mlp_emb_dim']
        # Device used at construction time only; forward() follows the tensors'
        # actual device so that a later module.to(...) keeps working.
        self.device = args["device"]

        # cloud_info_feature
        self.info_type = nn.Embedding(args["cloud_discrete_feature"]["cloud_info_feature"]["info_type"], emb_dim).to(
            args["device"])

        # One sub-type table per element slot; entry i sizes the table for slot i.
        info_sub_type_list = args["cloud_discrete_feature"]["cloud_info_feature"]["info_sub_type"]
        self.info_sub_type = nn.ModuleList([
            nn.Embedding(num_embeddings=value, embedding_dim=emb_dim).to(self.device)
            for value in info_sub_type_list
        ])

    def forward(self, input_data):
        """Return flattened discrete embeddings and continuous features.

        Returns:
            dict with ``"cloud_info_discrete"`` of shape
            ``(batch, seq, element * 2 * emb_dim)`` and
            ``"cloud_info_feature_continue"`` flattened to ``(batch, seq, -1)``.
        """
        batch_size = input_data.size(0)
        seq_len = input_data.size(1)
        element_cnt = input_data.size(2)
        cloud_info_feature = input_data

        # info
        info_type = self.info_type(cloud_info_feature[:, :, :, 0].long())

        info_sub_type_indices = cloud_info_feature[:, :, :, 1].long()
        # Slot ``idx`` uses its own table; slots without a table stay zero.
        # new_zeros inherits device/dtype from the embedding output, not from
        # the (possibly stale) device recorded in __init__.
        info_sub_type = info_type.new_zeros(batch_size, seq_len, element_cnt, info_type.size(-1))
        for idx, emb in enumerate(self.info_sub_type):
            info_sub_type[:, :, idx, :] = emb(info_sub_type_indices[:, :, idx])

        info_discrete = torch.cat([info_type, info_sub_type], -1)
        cloud_info_feature_continue = cloud_info_feature[:, :, :, self.cloud_info_feature_discrete_cnt:]

        return {
            "cloud_info_discrete": info_discrete.reshape(batch_size, seq_len, -1),
            "cloud_info_feature_continue": cloud_info_feature_continue.reshape(batch_size, seq_len, -1),
        }


class ClientFeatureConstructor(nn.Module):
    """Split a flat client feature tensor into groups and embed the discrete parts.

    ``forward`` slices ``batch_list`` along its last axis using the ``*_end``
    offsets built in ``__init__``. The groups appear in this order:

    * position features   -- ``client_position_feature`` columns
    * action features     -- ``action_cnt`` x ``client_action_feature``
    * info features       -- ``info_cnt`` x ``client_info_feature``
    * element history     -- ``element_cnt`` x (``client_history_feature`` - 2)
    * other history       -- 2 trailing columns
    """

    def __init__(self, args):
        super().__init__()

        self.element_cnt = args['action_element_cnt'] + args['info_element_cnt']
        self.action_cnt = args['action_element_cnt']
        self.info_cnt = args['info_element_cnt']

        # Per-group column widths of the flat client feature vector.
        self.client_position_feature = args["client_position_feature"]
        self.client_action_feature = args["client_action_feature"]
        self.client_info_feature = args["client_info_feature"]
        self.client_history_feature = args["client_history_feature"]

        discrete = args["client_discrete_feature"]
        self.client_position_feature_discrete_cnt = len(discrete["client_position_feature"])
        self.client_history_feature_discrete_cnt = len(discrete["client_history_feature"])
        # Action/info rows: forward() embeds columns 2 and 3 and skips columns 0-1.
        # Presumably the continuous part starts after 2 + len(vocab dict) columns — confirm.
        self.client_action_feature_continue_start = len(discrete["client_action_feature"]) + 2
        self.client_info_feature_continue_start = len(discrete["client_info_feature"]) + 2

        # End offsets of each group along the flat feature axis.
        self.client_position_feature_end = self.client_position_feature
        self.client_action_feature_end = (
            self.client_position_feature_end + self.client_action_feature * self.action_cnt)
        self.client_info_feature_end = (
            self.client_action_feature_end + self.client_info_feature * self.info_cnt)
        # History: (client_history_feature - 2) columns per element, then 2 shared columns.
        self.client_history_element_feature_end = (
            self.client_info_feature_end + (self.client_history_feature - 2) * self.element_cnt)
        self.client_history_other_feature_end = self.client_history_element_feature_end + 2

        emb_dim = args['mlp_emb_dim']
        device = args["device"]
        position_vocab = discrete["client_position_feature"]
        action_vocab = discrete["client_action_feature"]
        info_vocab = discrete["client_info_feature"]
        history_vocab = discrete["client_history_feature"]

        # client_position_feature
        # NOTE(review): forward() also calls self.current_rc and self.current_fw,
        # which are not created here. As written, forward() raises AttributeError
        # unless they are attached elsewhere. They are not added here, to avoid
        # changing the parameter/state_dict layout — TODO confirm the vocab keys.
        self.current_time = nn.Embedding(position_vocab["current_time"], emb_dim).to(device)
        # original note: "47 + 3" (presumably the vocab size breakdown — confirm)
        self.current_subend_element = nn.Embedding(
            position_vocab["current_subend_element"], emb_dim).to(device)
        self.subseg_type = nn.Embedding(position_vocab["subseg_type"], emb_dim).to(device)
        self.ds_discrete = nn.Embedding(position_vocab["ds_discrete"], emb_dim).to(device)
        self.ds_subend_discrete = nn.Embedding(position_vocab["ds_subend_discrete"], emb_dim).to(device)
        self.speed_discrete = nn.Embedding(position_vocab["speed_discrete"], emb_dim).to(device)
        self.accelerate_speed_discrete = nn.Embedding(
            position_vocab["accelerate_speed_discrete"], emb_dim).to(device)
        self.current_seg_len = nn.Embedding(position_vocab["current_seg_len"], emb_dim).to(device)

        # client_action_feature
        self.action_if_current_seg = nn.Embedding(action_vocab["if_current_seg"], emb_dim).to(device)
        self.action_ds_to_element = nn.Embedding(action_vocab["action_ds_to_element"], emb_dim).to(device)

        # client_info_feature
        self.info_if_current_seg = nn.Embedding(info_vocab["if_current_seg"], emb_dim).to(device)
        self.info_ds_to_element = nn.Embedding(info_vocab["info_ds_to_element"], emb_dim).to(device)

        # client_history_feature
        self.last_playelement_type = nn.Embedding(history_vocab["last_playelement_type"], emb_dim).to(device)
        self.seg_play_cnt = nn.Embedding(history_vocab["seg_play_cnt"], emb_dim).to(device)
        self.subseg_play_cnt = nn.Embedding(history_vocab["subseg_play_cnt"], emb_dim).to(device)
        self.last_play_ds = nn.Embedding(history_vocab["last_play_ds"], emb_dim).to(device)
        self.last_play_time = nn.Embedding(history_vocab["last_play_time"], emb_dim).to(device)

    def forward(self, batch_list):
        """Slice ``batch_list`` into feature groups and embed the discrete columns.

        ``batch_list`` is not modified.

        Returns:
            dict of discrete embeddings and continuous slices per group, plus
            ``"client_feature_subtype"`` (raw position column 5).
        """
        batch_size = batch_list.size(0)
        seq_len = batch_list.size(1)
        client_position_feature = batch_list[:, :, :self.client_position_feature_end].reshape(
            batch_size, seq_len, self.client_position_feature)
        client_action_feature = batch_list[:, :, self.client_position_feature_end:self.client_action_feature_end].reshape(
            batch_size, seq_len, self.action_cnt, self.client_action_feature)
        client_info_feature = batch_list[:, :, self.client_action_feature_end:self.client_info_feature_end].reshape(
            batch_size, seq_len, self.info_cnt, self.client_info_feature)

        client_history_element_feature = batch_list[
            :, :, self.client_info_feature_end:self.client_history_element_feature_end].reshape(
            batch_size, seq_len, self.element_cnt, self.client_history_feature - 2)

        client_history_other_feature = batch_list[
            :, :, self.client_history_element_feature_end:self.client_history_other_feature_end].reshape(
            batch_size, seq_len, 2)

        # [1] position
        current_rc = self.current_rc(client_position_feature[:, :, 0].long())
        current_fw = self.current_fw(client_position_feature[:, :, 1].long())
        current_time = self.current_time(client_position_feature[:, :, 2].long())
        current_subend_element = self.current_subend_element(client_position_feature[:, :, 3].long())
        current_substart_element = self.current_subend_element(client_position_feature[:, :, 4].long())
        subseg_type = self.subseg_type(client_position_feature[:, :, 5].long())

        # Columns 6 and 7 (ds_discrete, ds_subend_discrete) are always embedded as
        # id 0. Build those ids separately: the original code wrote zeros into
        # client_position_feature, which may be a view of the caller's batch_list,
        # and so silently modified the caller's tensor.
        zero_ids = torch.zeros_like(client_position_feature[:, :, 6], dtype=torch.long)
        ds_discrete = self.ds_discrete(zero_ids)
        ds_subend_discrete = self.ds_subend_discrete(zero_ids)

        speed_discrete = self.speed_discrete(client_position_feature[:, :, 8].long())
        accelerate_speed_discrete = self.accelerate_speed_discrete(client_position_feature[:, :, 9].long())
        current_seg_len = self.current_seg_len(client_position_feature[:, :, 10].long())
        next_seg_len = self.current_seg_len(client_position_feature[:, :, 11].long())

        client_position_feature_discrete = torch.cat(
            [current_rc, current_fw, current_time, current_subend_element, current_substart_element, subseg_type,
             ds_discrete, ds_subend_discrete, speed_discrete, accelerate_speed_discrete, current_seg_len, next_seg_len],
            -1)

        # [2] action (per action element: columns 2 and 3 are embedded)
        action_if_current_seg = self.action_if_current_seg(client_action_feature[:, :, :, 2].long())
        action_ds_to_element = self.action_ds_to_element(client_action_feature[:, :, :, 3].long())
        action_discrete = torch.cat([action_if_current_seg, action_ds_to_element], -1)

        # [3] info (per info element: columns 2 and 3 are embedded)
        info_if_current_seg = self.info_if_current_seg(client_info_feature[:, :, :, 2].long())
        info_ds_to_element = self.info_ds_to_element(client_info_feature[:, :, :, 3].long())
        info_discrete = torch.cat([info_if_current_seg, info_ds_to_element], -1)

        # [4] history (per element: columns 0-4 are embedded)
        last_playelement_type = self.last_playelement_type(client_history_element_feature[:, :, :, 0].long())
        seg_play_cnt = self.seg_play_cnt(client_history_element_feature[:, :, :, 1].long())
        subseg_play_cnt = self.subseg_play_cnt(client_history_element_feature[:, :, :, 2].long())
        last_play_ds = self.last_play_ds(client_history_element_feature[:, :, :, 3].long())
        last_play_time = self.last_play_time(client_history_element_feature[:, :, :, 4].long())

        history_discrete = torch.cat(
            [last_playelement_type, seg_play_cnt, subseg_play_cnt, last_play_ds, last_play_time], -1)

        # [5] all continuous features
        client_position_feature_continue = client_position_feature[:, :, self.client_position_feature_discrete_cnt:]
        client_action_feature_continue = client_action_feature[:, :, :, self.client_action_feature_continue_start:]
        client_info_feature_continue = client_info_feature[:, :, :, self.client_info_feature_continue_start:]
        client_history_element_feature_continue = client_history_element_feature[
            :, :, :, self.client_history_feature_discrete_cnt:]
        client_history_other_feature_continue = client_history_other_feature

        # [6] raw subseg type column, returned for callers
        feature_subseg_type = client_position_feature[:, :, 5]

        return {
            "client_position_feature_discrete": client_position_feature_discrete,
            "client_action_discrete": action_discrete,
            "client_info_discrete": info_discrete,
            "client_history_discrete": history_discrete,
            "client_position_feature_continue": client_position_feature_continue,
            "client_action_feature_continue": client_action_feature_continue,
            "client_info_feature_continue": client_info_feature_continue,
            "client_feature_subtype" : feature_subseg_type,
            "client_history_element_feature_continue": client_history_element_feature_continue,
            "client_history_other_feature_continue": client_history_other_feature_continue
        }

class CloudUserFeatureConstructor(nn.Module):
    """Embed the discrete cloud-side user features for every user segment.

    The flat per-step input is viewed as ``user_feature_seg_num`` segments.
    In each segment, the first 14 columns are category ids (one table per
    entry of ``_USER_FIELDS``). Columns from ``cloud_user_feature_discrete_cnt``
    onwards are passed through unchanged.
    """

    # Embedding attribute per column of a segment, in column order; this is
    # also the submodule registration order.
    _USER_FIELDS = (
        "play_style", "old_driver", "to_home", "to_company", "music",
        "familiarity", "complexity", "language", "kilometers", "mute_days",
        "detailed_days", "concise_days", "minimalist_days", "intelligence_days",
    )

    def __init__(self, args):
        super().__init__()
        dim = args['mlp_emb_dim']
        vocab = args["cloud_discrete_feature"]['cloud_user_feature']
        self.cloud_user_feature_discrete_cnt = len(vocab)
        self.user_feature_seg_num = args["cloud_user_feature_seg_num"]

        # cloud_user_feature: one table per field, registered via setattr
        for field in self._USER_FIELDS:
            setattr(self, field, nn.Embedding(vocab[field], dim).to(args["device"]))

    def forward(self, input_data):
        """Return segment-flattened discrete embeddings and continuous columns.

        Returns:
            dict with ``"cloud_user_feature_discrete"`` and
            ``"cloud_user_feature_continue"``, both reshaped to ``(batch, seq, -1)``.
        """
        b, s = input_data.size(0), input_data.size(1)
        per_segment = input_data.reshape(b, s, self.user_feature_seg_num, -1)

        embedded = [
            getattr(self, field)(per_segment[:, :, :, col].long())
            for col, field in enumerate(self._USER_FIELDS)
        ]
        discrete = torch.cat(embedded, -1).reshape(b, s, -1)
        continuous = per_segment[:, :, :, self.cloud_user_feature_discrete_cnt:].reshape(b, s, -1)

        return {
            "cloud_user_feature_discrete": discrete,
            "cloud_user_feature_continue": continuous
        }


class CloudModel(nn.Module):
    """Run the four cloud-side feature constructors and merge context with user features.

    The context and user discrete embeddings are concatenated and projected by
    ``single_linear``. The context and user continuous features are concatenated
    as-is. Action and info outputs are forwarded unchanged.
    """

    def __init__(self, args):
        super().__init__()
        self.cloud_context_constructor = CloudContextFeatureConstructor(args)
        self.cloud_action_constructor = CloudActionFeatureConstructor(args)
        self.cloud_info_constructor = CloudInfoFeatureConstructor(args)
        self.cloud_user_constructor = CloudUserFeatureConstructor(args)
        # Applied to the context+user discrete concatenation in forward(), so 504
        # must match that concatenation's last dimension for the configured args.
        self.single_linear = nn.Sequential(nn.Linear(504, 168), nn.ReLU())

    def forward(self, cloud_context_input, cloud_action_input, cloud_info_input, cloud_user_input):
        """Return the merged cloud feature dict consumed downstream."""
        context = self.cloud_context_constructor(cloud_context_input)
        action = self.cloud_action_constructor(cloud_action_input)
        info = self.cloud_info_constructor(cloud_info_input)
        user = self.cloud_user_constructor(cloud_user_input)

        # Context features carry the base road features plus the user features.
        merged_discrete = torch.cat(
            [context["cloud_context_feature_discrete"], user['cloud_user_feature_discrete']], -1)
        merged_continue = torch.cat(
            [context["cloud_context_feature_continue"], user["cloud_user_feature_continue"]], -1)

        return {
            "cloud_context_feature_discrete": self.single_linear(merged_discrete),
            "cloud_action_discrete": action["cloud_action_discrete"],
            "cloud_info_discrete": info["cloud_info_discrete"],
            "cloud_context_feature_continue": merged_continue,
            "cloud_action_feature_continue": action["cloud_action_feature_continue"],
            "cloud_info_feature_continue": info["cloud_info_feature_continue"],
        }


class ClientActionFeatureConstructor(nn.Module):
    """Embed the discrete fields of the three client action element slots.

    ``forward`` indexes ``input_data`` as ``[:, :, slot, feature]``.
    Slots 0/1/2 are lane, rtk single and rtk combine.

    * Feature 0 is compared with the slot's configured id, and the resulting
      0/1 flag is embedded.
    * Feature 1 is embedded as the sub type.
    * Features from ``cloud_action_feature_discrete_cnt`` onwards are passed
      through unchanged.
    """

    def __init__(self, args):
        super().__init__()
        spec = args["client_discrete_feature"]["client_element_embedding"]
        device = args["device"]
        self.emb_dim = args['mlp_emb_dim']
        self.lane_id = spec["lane_id"]
        self.rtk_single_id = spec["rtk_single_id"]
        self.rtk_combine_id = spec["rtk_combine_id"]

        self.cloud_action_feature_discrete_cnt = spec["discrete_feature_num"]
        self.lane_main_type = nn.Embedding(spec["lane_action_type"], self.emb_dim).to(device)
        self.lane_sub_type = nn.Embedding(spec["lane_action_sub_type"], self.emb_dim).to(device)

        self.rtk_single_main_type = nn.Embedding(spec["rtk_single_action_type"], self.emb_dim).to(device)
        self.rtk_single_sub_type = nn.Embedding(spec["rtk_single_action_sub_type"], self.emb_dim).to(device)

        self.rtk_combine_main_type = nn.Embedding(spec["rtk_combine_action_type"], self.emb_dim).to(device)
        self.rtk_combine_sub_type = nn.Embedding(spec["rtk_combine_sub_type"], self.emb_dim).to(device)

    def forward(self, input_data):
        """Return per-slot discrete embeddings and the pass-through continuous features.

        Returns:
            dict with ``"client_action_discrete"`` of shape
            ``(batch, seq, 3, 2 * emb_dim)`` (type flag embedding then sub-type
            embedding) and ``"client_action_continue"``.
        """
        batch_size, seq_len = input_data.size(0), input_data.size(1)

        # (main-type table, sub-type table, id that sets the flag) per slot
        slots = (
            (self.lane_main_type, self.lane_sub_type, self.lane_id),
            (self.rtk_single_main_type, self.rtk_single_sub_type, self.rtk_single_id),
            (self.rtk_combine_main_type, self.rtk_combine_sub_type, self.rtk_combine_id),
        )

        type_embs = []
        sub_type_embs = []
        for slot, (main_table, sub_table, type_id) in enumerate(slots):
            is_match = torch.where(input_data[:, :, slot, 0].long() == type_id, 1, 0)
            type_embs.append(main_table(is_match))
            sub_type_embs.append(sub_table(input_data[:, :, slot, 1].long()))

        action_type = torch.cat(type_embs, -1).reshape(batch_size, seq_len, -1, self.emb_dim)
        action_sub_type = torch.cat(sub_type_embs, -1).reshape(batch_size, seq_len, -1, self.emb_dim)

        return {
            "client_action_discrete": torch.cat([action_type, action_sub_type], -1),
            "client_action_continue": input_data[:, :, :, self.cloud_action_feature_discrete_cnt:],
        }

class ClientModel(nn.Module):
    """Fuse cloud-side and client-side element features into a per-step input.

    For each step, element-level (action + info) features from the cloud model,
    the client feature constructor and the client history are concatenated per
    element.  A learned "virtual" token is prepended, and the element sequence
    is passed through a self-attention ``EncoderLayer``.  The virtual token's
    output is then concatenated with the non-element features (context,
    position, attended speed sequences, other history features).
    """

    def __init__(self, args, teach_force=False, hard_label=False):
        super(ClientModel, self).__init__()

        self.teach_force = teach_force
        self.hard_label = hard_label
        self.args = args
        self.att_hid = args['att_hid']
        self.info_onehot_cnt = args['info_element_cnt'] + 1

        # Slot indices along the cloud action-element axis that
        # replace_element_feature overwrites with client-side features.
        element_cfg = args["client_discrete_feature"]["client_element_embedding"]
        self.lane_id = element_cfg["lane_id"]
        self.rtk_single_id = element_cfg["rtk_single_id"]
        self.rtk_combine_id = element_cfg["rtk_combine_id"]

        # Client-side element embeddings: lane, rtk single, rtk combine.
        self.client_action_constructor = ClientActionFeatureConstructor(args)

        # Client feature module.
        self.client_constructor = ClientFeatureConstructor(args)

        # Task names for the multi-task setup.
        self.task_name = ['trigger', 'action', 'info', 'voiceTimes']

        # Speed-sequence model.
        # The continuous position features end with the speed sequence followed
        # by the accumulated-speed sequence. This assumes the continuous block is
        # client_position_feature minus the discrete position features wide
        # (TODO confirm against ClientFeatureConstructor).
        self.speed_seq_index = (
            args["client_position_feature"]
            - len(args["client_discrete_feature"]["client_position_feature"])
            - args["client_position_speed_feature"]
            - args["client_position_accuspeed_feature"]
        )
        self.accuspeed_seq_index = self.speed_seq_index + args["client_position_speed_feature"]

        self.speed_mlp = nn.Sequential(nn.Linear(1, args['speed_att_hid']), nn.ReLU())
        self.accuspeed_mlp = nn.Sequential(nn.Linear(1, args['speed_att_hid']), nn.ReLU())
        self.speed_attention_layer = EncoderLayer(args['speed_att_hid'], args['speed_att_head'], args['speed_att_hid'], 0, args['device'], args)
        self.accuspeed_attention_layer = EncoderLayer(args['speed_att_hid'], args['speed_att_head'], args['speed_att_hid'], 0, args['device'], args)
        self.speed_pos_model = AbsPositionalEmbedding(args)

        # Element-level self-attention.
        self.attention_layer = EncoderLayer(args['att_hid'], args['att_head'], args['att_hid'], 0, args['device'], args)
        self.attention_fc = nn.Sequential(nn.Linear(args['att_emb'], args['att_hid']), nn.ReLU())

        # Learned token prepended to the element sequence; its attention output
        # (position 0) is used as the pooled element feature in forward().
        self.virtual = nn.Embedding(1, self.args['att_emb']).to(self.args['device'])

    def one_hot(self, labels, num):
        """Return a float one-hot matrix of shape ``(len(labels), num)``.

        Args:
            labels: 1-D tensor of class indices (cast to long).
            num: number of classes.

        The result is allocated on ``self.args['device']``.
        """
        labels = labels.long()
        one = torch.zeros((labels.size(0), num), device=self.args['device'])
        one[range(labels.size(0)), labels] = 1
        return one

    def process_action_pred(self, pred_action):
        """Turn action logits into 0/1 long predictions (sigmoid > 0.5)."""
        return (torch.sigmoid(pred_action) > 0.5).long()

    def process_info_pred(self, pred_info):
        """Return the arg-max class index of ``pred_info`` along the last axis.

        Softmax is taken over the same axis as the arg-max. For 2-D input this
        is identical to ``dim=1``. For higher-rank input, mixing ``dim=1`` with
        an arg-max over ``-1`` would normalise and select over different axes.
        """
        probs = torch.softmax(pred_info, dim=-1)
        return probs.max(dim=-1)[1].long()

    def replace_element_feature(self, element_feature_ori, element_feature_replace):
        """Return a copy of ``element_feature_ori`` with the lane / rtk rows replaced.

        Rows ``self.lane_id``, ``self.rtk_single_id`` and ``self.rtk_combine_id``
        of axis 2 are taken from rows 0, 1 and 2 of ``element_feature_replace``.

        The input is cloned first. Callers pass the result of ``reshape`` on the
        cloud-model outputs, which is usually a view. Assigning in place would
        therefore silently write through into the caller's tensors.
        """
        element_feature = element_feature_ori.clone()
        targets = (self.lane_id, self.rtk_single_id, self.rtk_combine_id)
        for src, dst in enumerate(targets):
            element_feature[:, :, dst, :] = element_feature_replace[:, :, src, :]
        return element_feature

    def forward(self, cloud_context_feature_discrete, cloud_action_discrete,
                cloud_info_discrete, cloud_context_feature_continue, cloud_action_feature_continue,
                cloud_info_feature_continue, client_input, client_elements):
        """Fuse cloud, client and history features into the per-step model input.

        Args:
            cloud_*: outputs of the cloud model. Action/info tensors are reshaped
                here to ``(batch, seq, n_elements, -1)``.
            client_input: raw client features; axes 0 and 1 are batch and seq.
            client_elements: raw lane / rtk single / rtk combine rows, fed to
                ``self.client_action_constructor``.

        Returns:
            ``(tbt_inputs, client_feature_dict)``. ``tbt_inputs`` is the
            last-axis concatenation of the virtual-token attention output, the
            non-element discrete features and the non-element continuous
            features. ``client_feature_dict`` is the raw output of
            ``self.client_constructor``.
        """
        batch_size = client_input.size(0)
        seq_len = client_input.size(1)
        mlp_emb_dim = self.args['mlp_emb_dim']

        # Client-side embeddings of the lane / rtk single / rtk combine elements.
        client_elements_dict = self.client_action_constructor(client_elements)
        client_elements_discrete = client_elements_dict["client_action_discrete"]
        client_elements_continues = client_elements_dict["client_action_continue"]

        client_feature_dict = self.client_constructor(client_input)
        client_position_feature_discrete = client_feature_dict["client_position_feature_discrete"]
        client_action_discrete = client_feature_dict["client_action_discrete"]
        client_info_discrete = client_feature_dict["client_info_discrete"]
        client_history_discrete = client_feature_dict["client_history_discrete"]
        client_position_feature_continue = client_feature_dict["client_position_feature_continue"]
        client_action_feature_continue = client_feature_dict["client_action_feature_continue"]
        client_info_feature_continue = client_feature_dict["client_info_feature_continue"]
        client_history_element_feature_continue = client_feature_dict["client_history_element_feature_continue"]
        client_history_other_feature_continue = client_feature_dict["client_history_other_feature_continue"]

        # ---- element-level discrete features ----
        cloud_action_discrete_rh = cloud_action_discrete.reshape(batch_size, seq_len, -1, mlp_emb_dim * 2)
        cloud_info_discrete_rh = cloud_info_discrete.reshape(batch_size, seq_len, -1, mlp_emb_dim * 2)
        # Use the client-side embeddings for the lane / rtk rows of the cloud actions.
        cloud_action_discrete_rh = self.replace_element_feature(cloud_action_discrete_rh, client_elements_discrete)

        client_both_discrete = torch.cat([client_action_discrete, client_info_discrete], -2)
        cloud_both_discrete = torch.cat([cloud_action_discrete_rh, cloud_info_discrete_rh], -2)
        both_discrete = torch.cat([client_both_discrete, cloud_both_discrete, client_history_discrete], -1)

        # Non-element discrete features.
        other_discrete = torch.cat([cloud_context_feature_discrete, client_position_feature_discrete], -1)

        # ---- element-level continuous features ----
        cloud_action_feature_continue = cloud_action_feature_continue.reshape(batch_size, seq_len, self.args['action_element_cnt'], -1)
        cloud_info_feature_continue = cloud_info_feature_continue.reshape(batch_size, seq_len, self.args['info_element_cnt'], -1)
        cloud_action_feature_continue = self.replace_element_feature(cloud_action_feature_continue, client_elements_continues)

        cloud_both_continue = torch.cat([cloud_action_feature_continue, cloud_info_feature_continue], -2)
        client_both_continue = torch.cat([client_action_feature_continue, client_info_feature_continue], -2)
        client_history_continue = client_history_element_feature_continue.reshape(
            batch_size, seq_len, self.args['action_element_cnt'] + self.args['info_element_cnt'], -1)
        both_continue = torch.cat([cloud_both_continue, client_both_continue, client_history_continue], -1)

        # ---- speed / accumulated-speed sequences ----
        # Each scalar is lifted by an MLP, offset by a positional embedding, then
        # self-attended per (batch, step).
        speed_seq = client_position_feature_continue[:, :, self.speed_seq_index:self.accuspeed_seq_index]
        accuspeed_seq = client_position_feature_continue[:, :, self.accuspeed_seq_index:]
        speed_seq_feature = self.speed_mlp(speed_seq.unsqueeze(-1))
        accuspeed_seq_feature = self.accuspeed_mlp(accuspeed_seq.unsqueeze(-1))
        seq_position = self.speed_pos_model(batch_size).unsqueeze(1)

        speed_seq_feature, _ = self.speed_attention_layer(
            (speed_seq_feature + seq_position).reshape(
                batch_size * seq_len, self.args["client_position_speed_feature"], -1))
        accuspeed_seq_feature, _ = self.accuspeed_attention_layer(
            (accuspeed_seq_feature + seq_position).reshape(
                batch_size * seq_len, self.args["client_position_accuspeed_feature"], -1))

        other_continue = torch.cat([
            cloud_context_feature_continue,
            client_position_feature_continue[:, :, :self.speed_seq_index],
            speed_seq_feature.reshape(batch_size, seq_len, -1),
            accuspeed_seq_feature.reshape(batch_size, seq_len, -1),
            client_history_other_feature_continue,
        ], -1)

        # ---- element self-attention with a prepended virtual token ----
        both_feature = torch.cat([both_discrete, both_continue], -1)
        virtual_ids = torch.zeros((batch_size, seq_len, 1), dtype=torch.long, device=self.args['device'])
        both_feature = torch.cat([self.virtual(virtual_ids), both_feature], 2)
        both_feature = both_feature.reshape(batch_size * seq_len, -1, self.args['att_emb'])
        both_feature = self.attention_fc(both_feature)
        both_feature, _ = self.attention_layer(both_feature)
        both_feature = both_feature.reshape(batch_size, seq_len, -1, self.att_hid)
        # Position 0 is the virtual token.
        virtual_fea = both_feature[:, :, 0, :].reshape(batch_size, seq_len, -1)

        tbt_inputs = torch.cat([virtual_fea, other_discrete, other_continue], -1)

        return tbt_inputs, client_feature_dict


class TBTEmbedding(nn.Module):
    """Map the flat per-step feature vector ``data['gpt_input']`` to a hidden vector.

    The last axis of ``gpt_input`` is split into cloud context / action / info
    features, client features and cloud user features.  The cloud parts go
    through ``CloudModel``.  The lane / rtk-single / rtk-combine rows of the raw
    cloud action features, together with the client features, go through
    ``ClientModel``.  The client-model output feeds a deep branch
    (``embedding_fc``) and, after adding a positional embedding, a cross branch
    (``CrossNet``).  Both are concatenated and projected by ``dnn_linear``.
    """

    def __init__(self, args):
        super().__init__()

        self.args = args
        self.device = args['device']
        self.action_size = int(args["action_element_cnt"])
        self.info_size = int(args["info_element_cnt"])
        self.element_size = self.action_size + self.info_size
        element_cfg = args["client_discrete_feature"]["client_element_embedding"]
        self.lane_id = element_cfg["lane_id"]
        self.rtk_single_id = element_cfg["rtk_single_id"]
        self.rtk_combine_id = element_cfg["rtk_combine_id"]

        # Flattened widths of the cloud and client blocks of `gpt_input`. They are
        # summed over the feature groups named in the *_discrete_feature configs
        # (matched by substring).
        self.cloud_dim_cnt_ori = 0
        self.client_dim_cnt_ori = 0

        for key in args["cloud_discrete_feature"]:
            if "cloud_context_feature" in key:
                self.cloud_dim_cnt_ori += args["cloud_context_feature"]
            elif "cloud_action_feature" in key:
                self.cloud_dim_cnt_ori += args["cloud_action_feature"] * self.action_size
            elif "cloud_info_feature" in key:
                self.cloud_dim_cnt_ori += args["cloud_info_feature"] * self.info_size

        for key in args["client_discrete_feature"]:
            if "client_position_feature" in key:
                self.client_dim_cnt_ori += args["client_position_feature"]
            elif "client_action_feature" in key:
                self.client_dim_cnt_ori += args["client_action_feature"] * self.action_size
            elif "client_info_feature" in key:
                self.client_dim_cnt_ori += args["client_info_feature"] * self.info_size
            elif "client_history_feature" in key:
                # All but two history features are stored once per element; the
                # remaining two appear only once (meaning not visible here — TODO confirm).
                self.client_dim_cnt_ori += (args["client_history_feature"] - 2) * self.element_size + 2

        self.cloud_context_feature = args["cloud_context_feature"]
        self.cloud_action_feature = args["cloud_action_feature"]
        self.cloud_info_feature = args["cloud_info_feature"]
        self.cloud_user_feature = args["cloud_user_feature"]

        # Column layout of `gpt_input` (last axis), as sliced in forward():
        #   [0, cloud_context_feature_end)                         cloud context
        #   [cloud_context_feature_end, cloud_action_feature_end)  cloud action, action_size rows
        #   [cloud_action_feature_end, cloud_info_feature_end)     cloud info, info_size rows
        #   [cloud_dim_cnt_ori, cloud_dim_cnt_ori + client_dim_cnt_ori)  client features
        #   [cloud_user_feature_start, cloud_user_feature_end)     cloud user features
        self.cloud_context_feature_end = self.cloud_context_feature
        self.cloud_action_feature_end = self.cloud_context_feature_end + self.cloud_action_feature * self.action_size
        self.cloud_info_feature_end = self.cloud_action_feature_end + self.cloud_info_feature * self.info_size
        self.cloud_user_feature_start = self.cloud_dim_cnt_ori + self.client_dim_cnt_ori
        self.cloud_user_feature_end = self.cloud_user_feature_start + self.cloud_user_feature

        self.cloud_model = CloudModel(args)
        self.client_model = ClientModel(args, teach_force=True, hard_label=True)
        # Share embedding tables: the client feature constructor reuses the cloud
        # model's tables instead of learning its own copies.
        client_constructor = self.client_model.client_constructor
        client_constructor.current_rc = self.cloud_model.cloud_context_constructor.last_rc
        client_constructor.current_fw = self.cloud_model.cloud_context_constructor.last_fw
        client_constructor.action_type = self.cloud_model.cloud_action_constructor.action_type
        client_constructor.info_type = self.cloud_model.cloud_info_constructor.info_type

        feature_cnt = args['embeded_all_feature_cnt']
        deep_dim = 96  # width of the deep-branch output

        # Positional embedding added to the cross-branch input in forward().
        self.pos_model = PositionalEmbedding(args)
        # Deep branch.
        self.embedding_fc = nn.Sequential(
            nn.Linear(feature_cnt, deep_dim),
            torch.nn.LeakyReLU()
        )
        # Cross branch.
        self.cross_net = CrossNet(feature_cnt, layer_num=2, parameterization='matrix')
        # Final projection of [deep, cross].
        self.dnn_linear = nn.Sequential(
            nn.Linear(feature_cnt + deep_dim, args['hidden']),
            torch.nn.LeakyReLU()
        )

    def forward(self, data):
        """Run the cloud + client models and the deep & cross head.

        Args:
            data: dict with ``'gpt_input'`` (indexed ``[batch, seq, feature]``),
                ``'padding_mask'`` and ``'ds_to_sub_end'``. The last two are
                passed through to ``self.pos_model``.

        Returns:
            Output of ``self.dnn_linear``.
        """
        batch_list = data['gpt_input']
        padding_mask = data['padding_mask']
        batch_size = batch_list.size(0)
        seq_len = batch_list.size(1)

        cloud_context_input = batch_list[:, :, :self.cloud_context_feature_end]
        cloud_action_input = batch_list[:, :, self.cloud_context_feature_end:self.cloud_action_feature_end]
        cloud_info_input = batch_list[:, :, self.cloud_action_feature_end:self.cloud_info_feature_end]
        client_input = batch_list[:, :, self.cloud_dim_cnt_ori:self.cloud_dim_cnt_ori + self.client_dim_cnt_ori]
        cloud_user_input = batch_list[:, :, self.cloud_user_feature_start:self.cloud_user_feature_end]

        cloud_context_input = cloud_context_input.reshape(batch_size, seq_len, self.cloud_context_feature)
        cloud_action_input = cloud_action_input.reshape(batch_size, seq_len, -1, self.cloud_action_feature)
        cloud_info_input = cloud_info_input.reshape(batch_size, seq_len, -1, self.cloud_info_feature)
        cloud_user_input = cloud_user_input.reshape(batch_size, seq_len, self.cloud_user_feature)

        cloud_output_dict = self.cloud_model(cloud_context_input, cloud_action_input, cloud_info_input, cloud_user_input)

        cloud_context_feature_discrete = cloud_output_dict["cloud_context_feature_discrete"]
        cloud_action_discrete = cloud_output_dict["cloud_action_discrete"]
        cloud_info_discrete = cloud_output_dict["cloud_info_discrete"]
        cloud_context_feature_continue = cloud_output_dict["cloud_context_feature_continue"]
        cloud_action_feature_continue = cloud_output_dict["cloud_action_feature_continue"]
        cloud_info_feature_continue = cloud_output_dict["cloud_info_feature_continue"]

        # Raw lane / rtk single / rtk combine rows (in that order) for the
        # client-side element embeddings.
        client_elements = torch.cat(
            [cloud_action_input[:, :, idx:idx + 1, :]
             for idx in (self.lane_id, self.rtk_single_id, self.rtk_combine_id)],
            -2,
        )

        outputs, client_feature_dict = self.client_model(
            cloud_context_feature_discrete, cloud_action_discrete,
            cloud_info_discrete, cloud_context_feature_continue, cloud_action_feature_continue,
            cloud_info_feature_continue, client_input, client_elements)

        # Deep branch (on the client-model output without positional embedding).
        deep_outputs = self.embedding_fc(outputs)

        # Cross branch (on the client-model output plus positional embedding).
        position_emb = self.pos_model(batch_size, client_feature_dict, data['ds_to_sub_end'], padding_mask)
        cross_outputs = self.cross_net(outputs + position_emb)

        outputs = torch.cat((deep_outputs, cross_outputs), dim=-1)
        return self.dnn_linear(outputs)

    def trans_element_order(self, element_feature):
        """Reorder rows on axis 2 so the lane row sits right before the rtk-single row.

        Output rows are: ``[0, lane_id)``, ``(lane_id, rtk_single_id)``,
        ``lane_id``, ``[rtk_single_id, rtk_combine_id]``. Rows after
        ``rtk_combine_id`` are dropped. Assumes
        ``lane_id < rtk_single_id <= rtk_combine_id`` — TODO confirm.
        """
        return torch.cat([element_feature[:, :, :self.lane_id, :],
                element_feature[:, :, self.lane_id + 1:self.rtk_single_id, :],
                element_feature[:, :, self.lane_id:self.lane_id + 1, :],
                element_feature[:, :, self.rtk_single_id:self.rtk_combine_id+1, :]
            ],-2)

