import math
from tomlkit import item
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence


class EncoderLSTM(nn.Module):
    """Bidirectional multi-layer LSTM encoder.

    Args:
        input_size: Number of features per time step.
        hidden_size: Hidden state size per direction.
        num_layer: Number of stacked LSTM layers.
    """

    def __init__(self, input_size, hidden_size, num_layer=3):
        super(EncoderLSTM, self).__init__()
        self.hidden_size = hidden_size
        self.num_layer = num_layer
        self.lstm = nn.LSTM(input_size, hidden_size, num_layer, bidirectional=True)

    def forward(self, input_data, hidden):
        """Run the LSTM and return its ``(output, (h_n, c_n))`` pair unchanged."""
        output, hidden = self.lstm(input_data, hidden)
        return output, hidden

    def initHidden(self, batch_size):
        """Return a zero ``(h_0, c_0)`` pair for a batch of ``batch_size``.

        Each tensor has shape ``(num_layer * 2, batch_size, hidden_size)``;
        the factor 2 accounts for the two directions.  The tensors are
        allocated on the same device and with the same dtype as the LSTM
        weights, so the module works on CPU as well as on GPU instead of
        unconditionally requiring CUDA.
        """
        reference = self.lstm.weight_ih_l0
        shape = (self.num_layer * 2, batch_size, self.hidden_size)
        return reference.new_zeros(shape), reference.new_zeros(shape)


class DecoderLSTM(nn.Module):
    """Unidirectional LSTM decoder whose input and hidden sizes are equal.

    Args:
        hidden_size: Size of both the per-step input features and the
            hidden state.
        num_layer: Number of stacked LSTM layers.
    """

    def __init__(self, hidden_size, num_layer=1):
        super(DecoderLSTM, self).__init__()
        self.hidden_size = hidden_size
        self.num_layer = num_layer
        self.lstm = nn.LSTM(hidden_size, hidden_size, num_layer)

    def forward(self, input_data, hidden):
        """Run the LSTM and return its ``(output, (h_n, c_n))`` pair unchanged."""
        output, hidden = self.lstm(input_data, hidden)
        return output, hidden

    def initHidden(self, batch_size):
        """Return a zero ``(h_0, c_0)`` pair for a batch of ``batch_size``.

        Each tensor has shape ``(num_layer, batch_size, hidden_size)`` and is
        allocated on the same device and with the same dtype as the LSTM
        weights, rather than unconditionally on CUDA.
        """
        reference = self.lstm.weight_ih_l0
        shape = (self.num_layer, batch_size, self.hidden_size)
        return reference.new_zeros(shape), reference.new_zeros(shape)


class ArbiEncoderLSTM(nn.Module):
    """Bidirectional LSTM encoder for batches of variable-length sequences.

    Args:
        input_size: Number of features per time step.
        hidden_size: Hidden state size per direction.
        num_layer: Number of stacked LSTM layers.
    """

    def __init__(self, input_size=2, hidden_size=48, num_layer=3):
        super(ArbiEncoderLSTM, self).__init__()
        self.hidden_size = hidden_size
        self.num_layer = num_layer
        self.input_size = input_size
        self.encoder = nn.LSTM(self.input_size, self.hidden_size, self.num_layer, bidirectional=True)

    def forward(self, obs_data, seq_lens, hidden):
        """Encode padded sequences and keep the output at each last valid step.

        Args:
            obs_data: Padded, time-major batch (it is packed with
                ``batch_first=False``).
            seq_lens: Valid length of every sequence in the batch; they do
                not need to be sorted.
            hidden: Initial ``(h_0, c_0)`` pair, e.g. from ``initHidden``.

        Returns:
            Tensor of shape ``(1, batch, 2 * hidden_size)`` holding, for every
            sequence, the LSTM output at time step ``seq_len - 1``.  The
            result stays on the device the LSTM ran on; it is no longer
            forced onto CUDA.
        """
        packed = pack_padded_sequence(obs_data, seq_lens, batch_first=False, enforce_sorted=False)
        encoded, _ = self.encoder(packed, hidden)
        # padded: (max_len, batch, 2 * hidden_size); lengths follow the
        # original batch order.
        padded, lengths = pad_packed_sequence(encoded, batch_first=False)

        # One gather instead of a Python loop over the batch.
        last_step = lengths.to(padded.device) - 1
        batch_index = torch.arange(padded.shape[1], device=padded.device)
        final_encoded = padded[last_step, batch_index]

        return final_encoded.unsqueeze(0)

    def initHidden(self, batch_size):
        """Return a zero ``(h_0, c_0)`` pair for a batch of ``batch_size``.

        Each tensor has shape ``(num_layer * 2, batch_size, hidden_size)`` and
        is allocated on the device/dtype of the LSTM weights.
        """
        reference = self.encoder.weight_ih_l0
        shape = (self.num_layer * 2, batch_size, self.hidden_size)
        return reference.new_zeros(shape), reference.new_zeros(shape)


class TransformerEncoder(nn.Module):
    """Stack of ``depth`` GELU-activated ``nn.TransformerEncoderLayer`` blocks.

    Args:
        dim: Model width (``d_model``).
        depth: Number of encoder layers.
        heads: Number of attention heads.
        mlp_dim: Width of the feed-forward sub-layer.
        dropout: Dropout probability used inside every layer.
    """

    def __init__(self, dim, depth, heads, mlp_dim, dropout=0.):
        super().__init__()
        layer_options = dict(
            d_model=dim,
            nhead=heads,
            dim_feedforward=mlp_dim,
            dropout=dropout,
            activation="gelu",
        )
        self.seqTransEncoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(**layer_options),
            num_layers=depth,
        )

    def forward(self, x, mask=None):
        """Encode ``x``; ``mask`` is forwarded as ``src_key_padding_mask``."""
        return self.seqTransEncoder(x, src_key_padding_mask=mask)
    

class TransformerDecoder(nn.Module):
    """Stack of ``depth`` GELU-activated ``nn.TransformerDecoderLayer`` blocks.

    Args:
        dim: Model width (``d_model``).
        depth: Number of decoder layers.
        heads: Number of attention heads.
        mlp_dim: Width of the feed-forward sub-layer.
        dropout: Dropout probability used inside every layer.
    """

    def __init__(self, dim, depth, heads, mlp_dim, dropout=0.):
        super().__init__()
        layer_options = dict(
            d_model=dim,
            nhead=heads,
            dim_feedforward=mlp_dim,
            dropout=dropout,
            activation="gelu",
        )
        self.seqTransDecoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(**layer_options),
            num_layers=depth,
        )

    def forward(self, timequeries, memory, mask=None):
        """Decode ``timequeries`` against ``memory``.

        ``mask`` is forwarded as ``tgt_key_padding_mask``.
        """
        return self.seqTransDecoder(
            tgt=timequeries,
            memory=memory,
            tgt_key_padding_mask=mask,
        )


class PositionalEncoding(nn.Module):
    """Lookup table of fixed sinusoidal position embeddings.

    Args:
        d_model: Embedding width.  Odd widths are supported.
        max_len: Number of positions stored in the table.
    """

    def __init__(self, d_model, max_len=1000):
        super(PositionalEncoding, self).__init__()
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        encoding = torch.zeros(max_len, d_model)
        encoding[:, 0::2] = torch.sin(position * div_term)
        # With an odd d_model there is one cosine column fewer than sine
        # columns, so drop the surplus frequency instead of failing.
        encoding[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # Kept as a plain tensor attribute (not a registered buffer), so it
        # does not appear in the module's state_dict.  Shape: (1, max_len, d_model).
        self.encoding = encoding.unsqueeze(0)

    def forward(self, T, device=None):
        """Return the embeddings for the positions listed in ``T``.

        Args:
            T: Integer positions with a leading batch dimension.
            device: Device for the returned tensor.  ``None`` selects CUDA
                when it is available and falls back to CPU otherwise.

        Returns:
            Tensor of shape ``T.shape + (d_model,)``.
        """
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        index = torch.as_tensor(T, device=self.encoding.device)
        if index.dtype in (torch.int8, torch.int16, torch.int32):
            # Widen narrow integer indices before the lookup.
            index = index.long()

        # A single indexed lookup replaces the per-row loop + concatenation.
        pos_embedding = self.encoding[0, index]

        return pos_embedding.to(device)


class SpatialGCN(nn.Module):
    """Per-frame, per-joint channel transform built from two 1x1 convolutions.

    Args:
        num_joints: Accepted for interface compatibility; not used here.
        in_channels: Channel count of the input features.
        hidden_dim: Channel count between the two convolutions.
        out_channels: Channel count of the output features.
    """

    def __init__(self, num_joints, in_channels, hidden_dim, out_channels):
        super(SpatialGCN, self).__init__()
        self.gcn = nn.Sequential(
            nn.Conv2d(in_channels, hidden_dim, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden_dim, out_channels, kernel_size=1),
        )

    def forward(self, x):
        """Apply the convolutions to every frame independently.

        Args:
            x: Tensor laid out as ``(bs, in_channels, num_joint, T)``.

        Returns:
            Tensor of shape ``(bs * T, out_channels, num_joint, 1)``; row
            ``b * T + t`` holds frame ``t`` of sample ``b``.
        """
        bs, num_features, num_joint, T = x.shape
        # Bring time next to batch while keeping channels ahead of joints.
        # The previous permute(0, 3, 2, 1) followed by a view reinterpreted a
        # (joint, channel) memory block as (channel, joint), which mixed
        # feature values across joints.
        x = x.permute(0, 3, 1, 2).contiguous()
        x = x.view(bs * T, num_features, num_joint, 1)
        return self.gcn(x)


class TemporalGCN(nn.Module):
    """Per-joint temporal transform built from two width-3 1-D convolutions.

    Args:
        num_joints: Accepted for interface compatibility; not used here.
        in_channels: Channel count of the input features.
        hidden_dim: Channel count between the two convolutions.
        out_channels: Channel count of the output features.
    """

    def __init__(self, num_joints, in_channels, hidden_dim, out_channels):
        super(TemporalGCN, self).__init__()
        self.tcn = nn.Sequential(
            nn.Conv1d(in_channels, hidden_dim, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv1d(hidden_dim, out_channels, kernel_size=3, padding=1),
        )

    def forward(self, x):
        """Convolve every joint's feature track along the time axis.

        Args:
            x: Tensor of shape ``(bs, in_channels, num_joint, T)``.

        Returns:
            Tensor of shape ``(bs, out_channels, num_joint, T)``.
        """
        bs, num_features, num_joint, T = x.shape
        # Fold joints into the batch: (bs, J, F, T) -> (bs * J, F, T).  The
        # earlier view(bs * J, T, -1).transpose(1, 2) reinterpreted the
        # (F, T) block as (T, F) and so mixed channels with time steps.
        x = x.permute(0, 2, 1, 3).contiguous()
        x = x.view(bs * num_joint, num_features, T)
        x = self.tcn(x)
        # The convolution output is already (bs * J, out_channels, T), so it
        # only needs unfolding; the earlier transpose + view raised on the
        # resulting non-contiguous tensor for general sizes.
        x = x.reshape(bs, num_joint, -1, T)
        return x.permute(0, 2, 1, 3).contiguous()


class AGCN(nn.Module):
    """Fuse the outputs of a SpatialGCN branch and a TemporalGCN branch.

    Args:
        num_joints: Forwarded to both branches.
        inchannels: Input channel count of both branches.
        hidden_dim: Hidden channel count of both branches.
        outchannels: Output channel count of both branches and of the fusion.
        fuse_type: ``"add"`` sums the branch outputs; ``"concat"`` joins them
            on the channel axis and projects back with a linear layer.
    """

    def __init__(self, num_joints, inchannels, hidden_dim, outchannels, fuse_type="add"):
        super(AGCN, self).__init__()
        self.spatial_gcn = SpatialGCN(num_joints, inchannels, hidden_dim, outchannels)
        self.temporal_gcn = TemporalGCN(num_joints, inchannels, hidden_dim, outchannels)
        self.fuse_type = fuse_type
        if fuse_type == "concat":
            self.fc = nn.Linear(2 * outchannels, outchannels)

    def forward(self, x):
        """Run both branches on ``x`` and fuse them.

        Args:
            x: Tensor of shape ``(bs, inchannels, num_joint, T)``.

        Returns:
            Tensor of shape ``(bs, outchannels, num_joint, T)``.

        Raises:
            ValueError: If ``fuse_type`` is neither ``"add"`` nor ``"concat"``.
        """
        bs, _, num_joint, T = x.shape
        spatial_output = self.spatial_gcn(x)
        temporal_output = self.temporal_gcn(x)

        # The spatial branch returns (bs * T, C, J, 1) while the temporal
        # branch returns (bs, C, J, T).  Unfold the spatial result into the
        # temporal layout so the two can actually be combined.
        spatial_output = spatial_output.reshape(bs, T, -1, num_joint).permute(0, 2, 3, 1)

        if self.fuse_type == "add":
            output = spatial_output + temporal_output
        elif self.fuse_type == "concat":
            # fc expects 2 * outchannels features, so join on the channel
            # axis and apply the projection with channels last.
            fused = torch.cat([spatial_output, temporal_output], dim=1)
            output = self.fc(fused.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        else:
            raise ValueError("Unknown fuse type")
        return output


class ConceptDataset(Dataset):
    """Dataset view over an already materialised, indexable sample collection."""

    def __init__(self, total_info_gather):
        # The container is stored as-is; samples are returned without copying.
        self.total_info_gather = total_info_gather

    def __len__(self):
        """Return the number of stored samples."""
        return len(self.total_info_gather)

    def __getitem__(self, idx):
        """Return the sample at ``idx``; tensor indices are converted with ``tolist()``."""
        position = idx.tolist() if torch.is_tensor(idx) else idx
        return self.total_info_gather[position]


def collate_helper(data):
    """Collate variable-length keypoint samples into one padded batch.

    Every sample is a dict with ``"keypoint_sequence"``,
    ``"complete_keypoint_sequence"`` and ``"anchor_pos_label"`` entries.  Both
    keypoint sequences are extended along dim 0 by repeating their last frame
    and the label is extended with zeros, all up to the longest
    ``"keypoint_sequence"`` in the batch.

    Returns:
        Dict with the stacked padded tensors, a ``(batch, max_len, 1)`` mask
        that is 1 on the original frames of ``"keypoint_sequence"``, the list
        of original lengths and a ``"mask_data"`` entry that is always None.
    """
    seq_len = [sample["keypoint_sequence"].shape[0] for sample in data]
    # default=0 only matters for an empty batch, where nothing is padded.
    longest = max(seq_len, default=0)

    def repeat_last(frames):
        # Extend along dim 0 with copies of the final frame up to `longest`.
        filler = frames[-1].unsqueeze(0).repeat(longest - frames.shape[0], 1, 1)
        return torch.cat((frames, filler), dim=0)

    keypoint_sequence_pad = torch.stack(
        [repeat_last(sample["keypoint_sequence"]) for sample in data], dim=0
    )

    batch_size, padded_len = keypoint_sequence_pad.shape[:2]
    keypoint_sequence_mask = torch.zeros(batch_size, padded_len, 1)
    for row, valid in enumerate(seq_len):
        keypoint_sequence_mask[row, :valid, :] = 1

    complete_keypoint_sequence_pad = torch.stack(
        [repeat_last(sample["complete_keypoint_sequence"]) for sample in data], dim=0
    )

    padded_labels = []
    for sample in data:
        label = sample["anchor_pos_label"]
        padded_labels.append(torch.cat([label, torch.zeros(longest - label.shape[0])], dim=0))
    anchor_pos_label_pad = torch.stack(padded_labels, dim=0)

    return {
        "keypoint_sequence": keypoint_sequence_pad,
        "complete_keypoint_sequence": complete_keypoint_sequence_pad,
        "anchor_pos_label": anchor_pos_label_pad,
        "mask_data": None,
        "keypoint_sequence_mask": keypoint_sequence_mask,
        "seq_len": seq_len,
    }
    

def collate_helper_rot6d(data):
    """Collate ``"joints"``/``"poses"`` samples by zero-padding at the front.

    Both entries of every sample are extended along dim 0 with leading zero
    frames until they reach the longest ``"joints"`` sequence in the batch,
    then stacked.

    Returns:
        Dict with the stacked ``"joints"`` and ``"poses"`` tensors.
    """
    seq_len = [sample["joints"].shape[0] for sample in data]
    # default=0 only matters for an empty batch, where nothing is padded.
    longest = max(seq_len, default=0)

    def zero_front_pad(frames):
        # Prepend all-zero frames so the sequence ends at index `longest - 1`.
        missing = longest - frames.shape[0]
        leading = torch.zeros((missing, frames.shape[1], frames.shape[2]))
        return torch.cat((leading, frames), dim=0)

    poses_pad = torch.stack([zero_front_pad(sample["poses"]) for sample in data], dim=0)
    keypoint_sequence_pad = torch.stack(
        [zero_front_pad(sample["joints"]) for sample in data], dim=0
    )

    return {"joints": keypoint_sequence_pad, "poses": poses_pad}


def collate_helper_trans(data, max_len=60):
    """Collate transition samples, padding full sequences to a fixed length.

    Args:
        data: List of sample dicts with ``"full_keypoint_sequence"``,
            ``"keypoint_sequence"``, ``"transition"``, ``"anchor_pair"`` and
            ``"duration"`` entries.
        max_len: Length every ``"full_keypoint_sequence"`` is padded to by
            repeating its last frame.  Defaults to the previously hard-coded
            60; bind another value with ``functools.partial`` when passing
            this function as a ``collate_fn``.  Sequences must not be longer
            than ``max_len``; they are never truncated.

    Returns:
        Dict with the stacked padded ``"full_keypoint_sequence"``, the
        unpadded list of ``"keypoint_sequence"`` tensors, the stacked
        ``"transition"``, ``"anchor_pair"`` and ``"duration"`` tensors and a
        ``"mask_data"`` entry that is always None.

    Note:
        An earlier revision also built a validity mask here, but it was never
        returned and it marked the trailing frames even though the padding is
        appended at the end, so that dead computation has been removed.
    """
    padded = []
    for sample in data:
        frames = sample["full_keypoint_sequence"]
        filler = frames[-1].unsqueeze(0).repeat(max_len - frames.shape[0], 1, 1)
        padded.append(torch.cat((frames, filler), dim=0))
    full_keypoint_sequence_pad = torch.stack(padded, dim=0)

    # Left as a list: these sequences are returned unpadded.
    keypoint_sequence = [sample["keypoint_sequence"] for sample in data]
    transition = torch.stack([sample["transition"] for sample in data], dim=0)
    anchor_pair = torch.stack([sample["anchor_pair"] for sample in data], dim=0)
    duration = torch.stack([sample["duration"] for sample in data], dim=0)

    return {
        "full_keypoint_sequence": full_keypoint_sequence_pad,
        "keypoint_sequence": keypoint_sequence,
        "mask_data": None,
        "transition": transition,
        "anchor_pair": anchor_pair,
        "duration": duration,
    }
    

def collate_helper_afn(data):
    """Collate anchor samples, padding keypoint sequences with their last frame.

    Args:
        data: List of sample dicts with ``"keypoint_sequence"``,
            ``"gt_anchor"``, ``"anchor_class"``, ``"anchor_class_onehot"`` and
            ``"anchor_pos"`` entries.

    Returns:
        Dict with ``"keypoint_sequence"`` padded along dim 0 to the longest
        sequence in the batch and stacked, the remaining entries stacked
        (``"anchor_class"`` converted to ``long``) and a ``"mask_data"`` entry
        that is always None.

    Note:
        A validity mask used to be computed here but was never returned, so
        that dead computation has been removed.  The returned batch therefore
        does not distinguish original frames from repeated ones.
    """
    # Compute the target length once instead of once per sample.
    longest = max((sample["keypoint_sequence"].shape[0] for sample in data), default=0)

    padded = []
    for sample in data:
        frames = sample["keypoint_sequence"]
        filler = frames[-1:].repeat(longest - frames.shape[0], 1, 1)
        padded.append(torch.cat([frames, filler], dim=0))
    keypoint_sequence_pad = torch.stack(padded, dim=0)

    gt_anchor = torch.stack([sample["gt_anchor"] for sample in data], dim=0)
    anchor_class = torch.stack([sample["anchor_class"] for sample in data], dim=0).long()
    anchor_class_onehot = torch.stack(
        [sample["anchor_class_onehot"] for sample in data], dim=0
    )
    anchor_pos = torch.stack([sample["anchor_pos"] for sample in data], dim=0)

    return {
        "keypoint_sequence": keypoint_sequence_pad,
        "gt_anchor": gt_anchor,
        "anchor_class": anchor_class,
        "anchor_class_onehot": anchor_class_onehot,
        "anchor_pos": anchor_pos,
        "mask_data": None,
    }


def collate_helper_aprn(data):
    """Collate samples carrying both a full and a partial keypoint sequence.

    ``"full_keypoint_sequence"`` and ``"keypoint_sequence"`` are each padded
    along dim 0, by repeating their last frame, up to the longest sequence of
    their own kind in the batch.

    Returns:
        Dict with both stacked padded tensors, a ``(batch, max_len, 1)`` mask
        that is 1 on the original frames of ``"keypoint_sequence"``, the
        stacked ``"transition"``, ``"duration"`` and ``"pos_record"`` tensors
        and a ``"mask_data"`` entry that is always None.
    """

    def repeat_last(frames, target_len):
        # Extend along dim 0 with copies of the final frame up to target_len.
        filler = frames[-1].unsqueeze(0).repeat(target_len - frames.shape[0], 1, 1)
        return torch.cat((frames, filler), dim=0)

    def stack_field(key):
        return torch.stack([sample[key] for sample in data], dim=0)

    # default=0 only matters for an empty batch, where nothing is padded.
    full_lengths = [sample["full_keypoint_sequence"].shape[0] for sample in data]
    full_target = max(full_lengths, default=0)
    full_keypoint_sequence_pad = torch.stack(
        [repeat_last(sample["full_keypoint_sequence"], full_target) for sample in data],
        dim=0,
    )

    seq_len = [sample["keypoint_sequence"].shape[0] for sample in data]
    partial_target = max(seq_len, default=0)
    keypoint_sequence_pad = torch.stack(
        [repeat_last(sample["keypoint_sequence"], partial_target) for sample in data],
        dim=0,
    )

    batch_size, padded_len = keypoint_sequence_pad.shape[:2]
    keypoint_sequence_mask = torch.zeros(batch_size, padded_len, 1)
    for row, valid in enumerate(seq_len):
        keypoint_sequence_mask[row, :valid, :] = 1

    transition = stack_field("transition")
    duration = stack_field("duration")
    pos_record = stack_field("pos_record")

    return {
        "full_keypoint_sequence": full_keypoint_sequence_pad,
        "keypoint_sequence": keypoint_sequence_pad,
        "keypoint_sequence_mask": keypoint_sequence_mask,
        "mask_data": None,
        "transition": transition,
        "duration": duration,
        "pos_record": pos_record,
    }


def collate_helper_refine(data):
    """Stack refinement samples into a batch without any padding.

    Returns:
        Dict with the stacked ``"keypoint_sequence"``,
        ``"keypoint_sequence_norm"`` and ``"global_position"`` tensors plus
        ``"seq_len"``, the list of each sample's ``"keypoint_sequence"``
        length along dim 0.
    """
    seq_len = [sample["keypoint_sequence"].shape[0] for sample in data]

    def stack_field(key):
        # torch.stack requires every sample's tensor to share one shape.
        return torch.stack([sample[key] for sample in data], dim=0)

    keypoint_sequence = stack_field("keypoint_sequence")
    keypoint_sequence_norm = stack_field("keypoint_sequence_norm")
    global_position = stack_field("global_position")

    return {
        "keypoint_sequence": keypoint_sequence,
        "keypoint_sequence_norm": keypoint_sequence_norm,
        "seq_len": seq_len,
        "global_position": global_position,
    }


def collate_helper_aprn_ar(data):
    """Stack the four per-sample tensors of a batch without any padding.

    Returns:
        Dict with the stacked ``"full_keypoint_sequence"``,
        ``"keypoint_sequence"``, ``"transition"`` and ``"duration"`` tensors,
        each gaining a leading batch dimension.
    """
    fields = ("full_keypoint_sequence", "keypoint_sequence", "transition", "duration")
    # torch.stack requires every sample's tensor for a field to share one shape.
    return {
        field: torch.stack([sample[field] for sample in data], dim=0)
        for field in fields
    }
