import torch
import pandas as pd

def cd_collate_fn_seft(batch, biomarker_list):
    """Collate a batch of per-patient biomarker graphs into padded SeFT tensors.

    Each batch item is ``(graph_dict, label, _)``. ``graph_dict`` maps a
    biomarker name to a tuple ``(x, _, _, sampling_dates)``. Row ``i`` of
    ``x`` is one observation taken at ``sampling_dates[i]``:

    - ``x[i, 0]`` is the measured value.
    - ``x[i, -1]`` is read as the patient's sex, taken from the first
      observation encountered.

    Observations are packed contiguously per patient, sorted by
    ``(time, biomarker index)``, and zero-padded up to the longest patient in
    the batch. Observations that share a timestamp are all kept.

    Args:
        batch: sequence of ``(graph_dict, label, _)`` tuples.
        biomarker_list: ordered biomarker names. A name's position is the
            index stored in ``src[..., 1]``.

    Returns:
        src: float tensor ``(max_len, B, 2)`` holding ``[value, biomarker idx]``.
        static: float tensor ``(B, 1)`` holding the sex value per patient.
        times: float tensor ``(max_len, B)``. Each entry is a sampling time in
            seconds, relative to the earliest observation in the batch.
        lengths: float tensor ``(B,)`` with each patient's observation count.
            Positions ``>= lengths[b]`` are padding.
        y: the labels concatenated along dim 0.

    Raises:
        ValueError: if a patient has no observations, since their static
            (sex) feature cannot be read.
    """
    biom2idx = {b: i for i, b in enumerate(biomarker_list)}
    B = len(batch)

    all_records = []
    statics = []
    labels = []

    for graph_dict, label, _ in batch:
        patient_records = []
        sex = None
        for biomarker, (x, _, _, sampling_dates) in graph_dict.items():
            biom_idx = biom2idx[biomarker]
            for i in range(x.shape[0]):
                # Timestamp.value is nanoseconds since epoch -> whole seconds.
                timestamp = pd.to_datetime(sampling_dates[i]).value // 10**9
                if sex is None:
                    sex = x[i, -1].item()
                patient_records.append((timestamp, biom_idx, x[i, 0].item()))
        if sex is None:
            raise ValueError(
                "patient has no observations; cannot read static feature x[:, -1]"
            )
        # Deterministic order: by time, then biomarker index.
        patient_records.sort(key=lambda rec: (rec[0], rec[1]))
        statics.append([sex])
        all_records.append(patient_records)
        labels.append(label)

    # Times are expressed relative to the earliest observation in the batch.
    min_time = min((rec[0] for recs in all_records for rec in recs), default=0)
    # Pad on the number of observations, not on the time span in seconds.
    max_len = max((len(recs) for recs in all_records), default=0)

    src = torch.zeros((max_len, B, 2))   # [value, biomarker idx]
    times = torch.zeros((max_len, B))    # sampling times (s, batch-relative)
    static = torch.tensor(statics, dtype=torch.float)
    lengths = torch.zeros(B)

    for b, recs in enumerate(all_records):
        for pos, (t, bidx, val) in enumerate(recs):
            src[pos, b, 0] = val
            src[pos, b, 1] = bidx
            times[pos, b] = t - min_time
        lengths[b] = len(recs)

    y = torch.cat(labels, dim=0)
    return src, static, times, lengths, y