import torch
import numpy as np
from torch.utils.data import Dataset
from easydict import EasyDict

class Dataset4Unsupervise(Dataset):
    """Dataset for masked-reconstruction training over sequences of frames.

    For every item a random subset of frames is chosen and each chosen frame
    is corrupted by ``Preprocess4Mask``; the masked positions and their
    original values are returned alongside the corrupted and clean data.
    """

    def __init__(self, ap_coords, data, mask_cfg):
        """Store the sample arrays and build the per-frame masking helper.

        Args:
            ap_coords: Indexable collection; ``ap_coords[index]`` must be a
                NumPy array (it is passed to ``torch.from_numpy``).
            data: Indexable collection; ``data[index]`` must be a 4-D NumPy
                array that unpacks as ``(T, E, A, D)``.
            mask_cfg: Config object exposing ``mask_frame_ratio`` plus the
                attributes read by ``Preprocess4Mask``.
        """
        super().__init__()
        self.ap_coords = ap_coords
        self.data = data
        self.mask_cfg = mask_cfg
        self.mask_proc = Preprocess4Mask(mask_cfg)

    def __getitem__(self, index):
        """Return one sample with a random subset of its frames masked.

        Returns:
            A 5-tuple of tensors:
              * ap_coords as float (original comments suggest [T, 7, 6, 2]
                -- TODO confirm against the data loader),
              * masked frames as float, shape [T, E, A, D],
              * mask positions as long, shape [N_mask, 3], rows ``(t, i, j)``,
              * masked ground-truth values as float, shape [N_mask, D],
              * the unmodified frames as float, shape [T, E, A, D].
            When the sample has no maskable frame, ``N_mask`` is 0.
        """
        instance_ap_coords = self.ap_coords[index]
        instance_data = self.data[index]

        num_frames, _, _, feat_dim = instance_data.shape

        # Only the first T-1 frames are candidates, so the final frame is
        # never masked.  NOTE(review): presumably the last frame is kept
        # intact on purpose -- confirm with the training code.
        num_candidates = max(num_frames - 1, 0)
        num_to_mask = max(1, int(round(num_candidates * self.mask_cfg.mask_frame_ratio)))
        # Sampling without replacement cannot draw more frames than exist;
        # without this clamp a single-frame sample (or a ratio above 1)
        # makes np.random.choice raise ValueError.
        num_to_mask = min(num_to_mask, num_candidates)

        if num_to_mask > 0:
            mask_frame_idxs = np.random.choice(
                np.arange(num_candidates), size=num_to_mask, replace=False
            ).tolist()
        else:
            mask_frame_idxs = []

        # Corrupt a copy so the clean frames can be returned as well.
        masked_frames = instance_data.copy()
        flat_mask_pos = []   # (t, i, j) for every masked cell
        label_chunks = []    # one [n_mask, D] array per masked frame
        for t in mask_frame_idxs:
            masked_frame, mask_pos, target = self.mask_proc(instance_data[t])
            masked_frames[t] = masked_frame
            flat_mask_pos.extend((t, int(i), int(j)) for i, j in mask_pos)
            label_chunks.append(target)

        if label_chunks:
            flat_mask_labels = np.concatenate(label_chunks, axis=0)
        else:
            flat_mask_labels = np.zeros((0, feat_dim))

        # reshape(-1, 3) keeps the [N_mask, 3] layout even when N_mask == 0.
        mask_pos_array = np.asarray(flat_mask_pos, dtype=np.int64).reshape(-1, 3)

        return (torch.from_numpy(instance_ap_coords).float(),
                torch.from_numpy(masked_frames).float(),
                torch.from_numpy(mask_pos_array).long(),
                torch.from_numpy(flat_mask_labels).float(),
                torch.from_numpy(instance_data).float())

    def __len__(self):
        """Return the number of samples."""
        return len(self.data)
    

class Dataset4Supervise(Dataset):
    """Supervised dataset that repeats each sample's labels per ensemble member."""

    def __init__(self, ap_coords, data, labels):
        """Keep references to the coordinate, feature and label collections.

        Each collection is indexed with the same sample index; the indexed
        entries must be NumPy arrays because they go to ``torch.from_numpy``.
        """
        super().__init__()
        self.ap_coords = ap_coords
        self.data = data
        self.labels = labels

    def __len__(self):
        """Return the number of samples."""
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(ap_coords, data, labels)`` tensors for one sample.

        The sample's data must be 4-D, unpacked as
        ``(frame_num, ensemble_num, seq_ap_num, feature_num)``.  Its label
        array is indexed as 2-D and copied ``ensemble_num`` times along a new
        leading axis.
        """
        coords = self.ap_coords[index]
        sample = self.data[index]
        target = self.labels[index]

        # Unpacking all four axes enforces a 4-D sample; only the ensemble
        # count is needed afterwards.
        _, n_ensemble, _, _ = sample.shape

        tiled_target = np.tile(target[np.newaxis, :, :], (n_ensemble, 1, 1))

        return (
            torch.from_numpy(coords),
            torch.from_numpy(sample),
            torch.from_numpy(tiled_target),
        )


class Dataset5Supervise(Dataset):
    """Supervised dataset whose replicated labels omit the first label row."""

    def __init__(self, ap_coords, data, labels):
        """Keep references to the coordinate, feature and label collections.

        Each collection is indexed with the same sample index; the indexed
        entries must be NumPy arrays because they go to ``torch.from_numpy``.
        """
        super().__init__()
        self.ap_coords = ap_coords
        self.data = data
        self.labels = labels

    def __len__(self):
        """Return the number of samples."""
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(ap_coords, data, labels)`` tensors for one sample.

        The sample's data must be 4-D, unpacked as
        ``(frame_num, ensemble_num, seq_ap_num, feature_num)``.  The label
        array is copied ``ensemble_num`` times along a new leading axis and
        then index 0 along axis 1 is dropped.
        """
        coords = self.ap_coords[index]
        sample = self.data[index]
        target = self.labels[index]

        # Unpacking all four axes enforces a 4-D sample; only the ensemble
        # count is needed afterwards.
        _, n_ensemble, _, _ = sample.shape

        # Replicate first, then slice, so the result is a view of the tiled
        # array exactly as before.
        tiled_target = np.tile(target[np.newaxis, :, :], (n_ensemble, 1, 1))
        trimmed_target = tiled_target[:, 1:, :]

        return (
            torch.from_numpy(coords),
            torch.from_numpy(sample),
            torch.from_numpy(trimmed_target),
        )


class DatasetntpSupervise(Dataset):
    """Supervised dataset whose replicated labels are cut to ``frame_num - 1`` rows."""

    def __init__(self, ap_coords, data, labels):
        """Keep references to the coordinate, feature and label collections.

        Each collection is indexed with the same sample index; the indexed
        entries must be NumPy arrays because they go to ``torch.from_numpy``.
        """
        super().__init__()
        self.ap_coords = ap_coords
        self.data = data
        self.labels = labels

    def __len__(self):
        """Return the number of samples."""
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(ap_coords, data, labels)`` tensors for one sample.

        The sample's data must be 4-D, unpacked as
        ``(frame_num, ensemble_num, seq_ap_num, feature_num)``.  The label
        array is copied ``ensemble_num`` times along a new leading axis and
        then truncated along axis 1 with the slice ``[:frame_num - 1]``.
        """
        coords = self.ap_coords[index]
        sample = self.data[index]
        target = self.labels[index]

        n_frames, n_ensemble, _, _ = sample.shape

        # Replicate first, then slice, so the result is a view of the tiled
        # array exactly as before.
        tiled_target = np.tile(target[np.newaxis, :, :], (n_ensemble, 1, 1))
        stop = n_frames - 1
        truncated_target = tiled_target[:, :stop, :]

        return (
            torch.from_numpy(coords),
            torch.from_numpy(sample),
            torch.from_numpy(truncated_target),
        )
    

# ---------------------------
# span mask (same as the version you provided)
# ---------------------------
def span_mask_2d(seq_h, seq_w, max_gram=3, p=0.2, goal_num_predict=10):
    """Sample distinct grid cells in short horizontal spans.

    Spans are drawn repeatedly: a length ``n`` in ``1..max_gram`` is sampled
    with weights proportional to ``p * (1 - p) ** (n - 1)``, a random anchor
    cell is picked, and up to ``n`` cells to the right of the anchor (same
    row, clipped at the grid edge) are added.  Sampling stops once the
    requested number of distinct cells has been collected.

    Args:
        seq_h: Number of grid rows.
        seq_w: Number of grid columns.
        max_gram: Longest span length that can be drawn.
        p: Parameter of the span-length weights; larger values favour
            shorter spans.
        goal_num_predict: Number of distinct cells to select.  Values larger
            than ``seq_h * seq_w`` are clamped to the grid size.

    Returns:
        A list of ``(row, col)`` tuples with no duplicates, in arbitrary
        order.  Empty when the grid has no cells or the goal is not positive.
    """
    # The grid holds only seq_h * seq_w distinct cells; asking for more
    # would keep the loop below spinning forever, so clamp the goal.
    capacity = max(seq_h, 0) * max(seq_w, 0)
    goal = min(goal_num_predict, capacity)

    ngrams = np.arange(1, max_gram + 1, dtype=np.int64)
    pvals = p * np.power(1 - p, np.arange(max_gram))
    pvals = pvals / pvals.sum()

    mask_pos = set()
    while len(mask_pos) < goal:
        # Never draw a span longer than the number of cells still needed.
        span_len = min(int(np.random.choice(ngrams, p=pvals)), goal - len(mask_pos))
        anchor_h = np.random.randint(seq_h)
        anchor_w = np.random.randint(seq_w)
        # Clip the span at the right edge of the grid.
        stop_w = min(anchor_w + span_len, seq_w)
        mask_pos.update((anchor_h, j) for j in range(anchor_w, stop_w))
    return list(mask_pos)


class Preprocess4Mask:
    """MLM-style masking for a single frame (written for inputs of shape [7, 6, 2])."""

    def __init__(self, mask_cfg):
        """Copy the masking hyper-parameters from ``mask_cfg``."""
        self.mask_ratio = mask_cfg.mask_ratio      # fraction of cells to select, e.g. 0.3
        self.mask_prob = mask_cfg.mask_prob        # chance a selected cell is zeroed, e.g. 0.8
        self.replace_prob = mask_cfg.replace_prob  # chance it gets random values instead, e.g. 0.1
        self.max_gram = mask_cfg.max_gram          # longest span handed to span_mask_2d
        self.span_prob = mask_cfg.mask_alpha       # span-length parameter (controls span sparsity)

    def __call__(self, instance):
        """Select cells of ``instance`` and corrupt a copy of it.

        Args:
            instance: np.ndarray that unpacks as ``(E, A, D)``.

        Returns:
            masked_instance: copy of ``instance`` in which each selected cell
                is zeroed, filled with uniform random values, or left as is.
            mask_pos: np.ndarray built from the selected ``(i, j)`` positions.
            target: original values of the selected cells, shape [N_mask, D].
        """
        n_rows, n_cols, feat_dim = instance.shape
        n_pred = max(1, int(round(n_rows * n_cols * self.mask_ratio)))

        positions = span_mask_2d(
            n_rows,
            n_cols,
            max_gram=self.max_gram,
            p=self.span_prob,
            goal_num_predict=n_pred,
        )

        corrupted = instance.copy()
        replace_threshold = self.mask_prob + self.replace_prob
        originals = []

        for row, col in positions:
            originals.append(instance[row, col].copy())
            draw = np.random.rand()
            if draw < self.mask_prob:
                corrupted[row, col] = np.zeros(feat_dim)
                continue
            if draw < replace_threshold:
                corrupted[row, col] = np.random.rand(feat_dim)
            # Otherwise the cell keeps its original value but is still a target.

        stacked_originals = np.stack(originals, axis=0)  # [N_mask, D]
        return corrupted, np.array(positions), stacked_originals
    
