#!/usr/bin/env python

import argparse
import logging
import os
import random
import torch
import pickle
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import pandas as pd
from transformers import BertConfig, BertModel
from tqdm import tqdm  # Progress bar
from torcheval.metrics import ReciprocalRank, HitRate  # For ranking metrics

# -----------------------------------------------------------------------------
# Configure Logging
# -----------------------------------------------------------------------------
# Configure the root logger so INFO-level (and more severe) records are emitted.
logging.basicConfig(level=logging.INFO)
# Module-level logger shared by the functions and classes in this file.
logger = logging.getLogger(__name__)


# -----------------------------------------------------------------------------
# 1. Generate Vocabulary from CSV Data
# -----------------------------------------------------------------------------
def generate_vocab(csv_file, event_set_filter,
                   pad_token="[PAD]", unk_token="[UNK]",
                   cls_token="[CLS]", mask_token="[MASK]"):
    """Build a token -> id vocabulary from the ``event`` column of a CSV file.

    Each event is converted to a string; events shorter than 8 characters are
    skipped. The first 8 characters are cut into four 2-character pieces and
    each piece is prefixed with its position tag ("gh12-", "gh34-", "gh56-",
    "gh78-").

    Args:
        csv_file: Path of the CSV file; must contain an ``event`` column and,
            when ``event_set_filter`` is truthy, an ``event_set`` column.
        event_set_filter: If truthy, only rows whose ``event_set`` equals this
            value are used.
        pad_token, unk_token, cls_token, mask_token: Special tokens, assigned
            ids 0, 1, 2 and 3 respectively.

    Returns:
        dict mapping each token to an integer id; geohash tokens follow the
        special tokens in sorted order.
    """
    logger.info("Generating vocabulary from CSV file: %s", csv_file)
    frame = pd.read_csv(csv_file)
    if event_set_filter:
        frame = frame[frame["event_set"] == event_set_filter]

    position_prefixes = ("gh12-", "gh34-", "gh56-", "gh78-")
    geohash_tokens = {
        f"{prefix}{text[start:start + 2]}"
        for text in map(str, frame["event"].dropna().unique())
        if len(text) >= 8
        for prefix, start in zip(position_prefixes, range(0, 8, 2))
    }

    vocab = {pad_token: 0, unk_token: 1, cls_token: 2, mask_token: 3}
    # Sorted insertion keeps the id assignment deterministic across runs.
    for geohash_token in sorted(geohash_tokens):
        vocab[geohash_token] = len(vocab)
    logger.info("Generated vocabulary with %d tokens", len(vocab))
    return vocab


# -----------------------------------------------------------------------------
# 2. Define the GeohashTokenizer
# -----------------------------------------------------------------------------
class GeohashTokenizer:
    """Maps geohash events to vocabulary ids.

    An event is represented by four tokens, one per 2-character slice of its
    first 8 characters, each tagged with a position prefix ("gh12-", "gh34-",
    "gh56-", "gh78-"). Tokens missing from the vocabulary map to the id of
    "[UNK]".
    """

    # Position prefixes, in the order of the 2-character slices they tag.
    _PREFIXES = ("gh12-", "gh34-", "gh56-", "gh78-")

    def __init__(self, vocab):
        """Store ``vocab`` (token -> id) and build the inverse id -> token map."""
        self.vocab = vocab
        self.inv_vocab = {v: k for k, v in vocab.items()}
        self.cls_token = "[CLS]"
        self.pad_token = "[PAD]"
        self.unk_token = "[UNK]"
        self.mask_token = "[MASK]"

    def tokenize_geohash(self, geohash):
        """Tokenizes a geohash into its 4 prefixed tokens.

        The value is converted with ``str()`` first, mirroring
        ``generate_vocab``: an all-digit geohash read from CSV as an integer
        (or a missing value read as NaN) would otherwise raise ``TypeError``
        when sliced. Inputs shorter than 8 characters still yield 4 tokens;
        the trailing ones carry a short or empty suffix.
        """
        geohash = str(geohash)[:8]
        return [
            f"{prefix}{geohash[start:start + 2]}"
            for prefix, start in zip(self._PREFIXES, range(0, 8, 2))
        ]

    def encode(self, tokens, max_length=None, padding='max_length', truncation=True):
        """
        Encodes a list of events/special tokens into their corresponding IDs.

        Special tokens are kept as-is, every other entry is expanded with
        ``tokenize_geohash``, and the [CLS] token is appended at the end.

        Args:
            tokens: Iterable of geohash events and/or special-token strings.
            max_length: Target length; ``None`` disables truncation and padding.
            padding: When ``'max_length'``, left-pad with the [PAD] id.
            truncation: When true, keep only the last ``max_length`` ids.

        Returns:
            list of int token ids.
        """
        special_tokens = {self.cls_token, self.pad_token, self.unk_token, self.mask_token}
        processed_tokens = []
        for token in tokens:
            # The isinstance guard keeps non-string (possibly unhashable) events
            # out of the set lookup; they are always treated as geohashes.
            if isinstance(token, str) and token in special_tokens:
                processed_tokens.append(token)
            else:
                processed_tokens.extend(self.tokenize_geohash(token))
        processed_tokens.append(self.cls_token)

        unk_id = self.vocab[self.unk_token]
        ids = [self.vocab.get(tok, unk_id) for tok in processed_tokens]
        if truncation and max_length is not None:
            # ids[-0:] would return the whole list, so zero is handled explicitly.
            ids = ids[-max_length:] if max_length else []
        if padding == 'max_length' and max_length is not None:
            ids = [self.vocab[self.pad_token]] * (max_length - len(ids)) + ids
        return ids

    def decode(self, ids):
        """Decodes a list of token IDs back into tokens ("[UNK]" for unknown ids)."""
        return [self.inv_vocab.get(i, self.unk_token) for i in ids]


# -----------------------------------------------------------------------------
# 3. Process Sequences from CSV Data
# -----------------------------------------------------------------------------
def load_and_process_sequences(csv_file, event_set_filter, max_len, pad_token="[PAD]"):
    """Load per-user event sequences from a CSV file, fixed to ``max_len`` events.

    Rows are optionally filtered on ``event_set``, sorted by ``uid`` and
    ``timestamp``, and grouped by ``uid``. A sequence longer than ``max_len``
    keeps only its last ``max_len`` events; a shorter one is left-padded with
    ``pad_token``.

    Returns:
        (uids, sequences): the list of user ids and, aligned with it, the
        list of per-user event lists.
    """
    logger.info("Loading CSV file: %s", csv_file)
    frame = pd.read_csv(csv_file)
    if event_set_filter:
        frame = frame[frame["event_set"] == event_set_filter]

    ordered = frame.sort_values(by=["uid", "timestamp"])
    events_by_user = ordered.groupby("uid")["event"].apply(list)
    uids = events_by_user.index.tolist()

    sequences = []
    for events in events_by_user:
        missing = max_len - len(events)
        if missing < 0:
            # Too long: keep the most recent events.
            sequences.append(events[-max_len:])
        else:
            # Short (or exact): pad on the left so real events stay at the end.
            sequences.append([pad_token] * missing + events)

    logger.info("Processed %d user sequences", len(sequences))
    return uids, sequences


# -----------------------------------------------------------------------------
# 4. Define the Dataset for Next Event Prediction
# -----------------------------------------------------------------------------
class NextEventDataset(Dataset):
    """
    Sliding-window next-event examples built from per-user sequences.

    For every user the windows are generated from the end of the sequence
    backwards:
      - Input: the ``max_events`` events immediately preceding the target.
      - Target: the event at the current position.
    Generation for a user stops at the first [PAD] target or once
    ``max_examples`` examples have been collected.
    Each stored example is a tuple: (uid, input_seq, target_event); encoding
    to ids happens lazily in ``__getitem__``.
    """
    def __init__(self, uids, sequences, tokenizer, max_events, max_examples):
        self.examples = []
        self.tokenizer = tokenizer
        # Four tokens per event plus the trailing [CLS] token.
        self.max_seq_tokens = max_events * 4 + 1

        users_with_examples = 0
        progress = tqdm(zip(uids, sequences), total=len(uids), desc="Creating NextEventDataset")
        for uid, seq in progress:
            windows = []
            target_idx = len(seq) - 1
            while target_idx >= max_events:
                target_event = seq[target_idx]
                if target_event == tokenizer.pad_token:
                    break
                history = seq[target_idx - max_events : target_idx]
                windows.append((uid, history, target_event))
                if len(windows) == max_examples:
                    break
                target_idx -= 1
            if windows:
                users_with_examples += 1
                self.examples.extend(windows)

        logger.info("NextEventDataset: Created %d examples from %d users out of %d user sequences", 
                    len(self.examples), users_with_examples, len(uids))

    def __len__(self):
        """Number of stored examples."""
        return len(self.examples)

    def __getitem__(self, idx):
        """Return (uid, input id tensor, target id tensor) for example ``idx``."""
        uid, history, target_event = self.examples[idx]
        tok = self.tokenizer
        input_ids = tok.encode(history, max_length=self.max_seq_tokens,
                               padding="max_length", truncation=True)
        unk_id = tok.vocab[tok.unk_token]
        target_ids = [tok.vocab.get(piece, unk_id) for piece in tok.tokenize_geohash(target_event)]
        input_tensor = torch.tensor(input_ids, dtype=torch.long)
        target_tensor = torch.tensor(target_ids, dtype=torch.long)
        return uid, input_tensor, target_tensor


# -----------------------------------------------------------------------------
# 5. Define the Base Next Event Model (without graph integration)
# -----------------------------------------------------------------------------
class NextEventModel(nn.Module):
    """Encoder plus projection head predicting one embedding per next-event token.

    The encoder's hidden state at the final sequence position is projected to
    ``emb_dim * event_token_count`` values and reshaped to one ``emb_dim``
    vector per event token. ``token_embeddings`` is not used in ``forward``;
    it is the table the loss and evaluation code look candidate ids up in.
    """

    def __init__(self, bert_model, hidden_size, emb_dim, vocab_size, event_token_count=4):
        super().__init__()
        self.bert = bert_model
        self.projection = nn.Linear(hidden_size, emb_dim * event_token_count)
        self.event_token_count = event_token_count
        self.emb_dim = emb_dim
        self.token_embeddings = nn.Embedding(vocab_size, emb_dim)

    def forward(self, input_ids, attention_mask=None):
        """Return predictions of shape (-1, event_token_count, emb_dim)."""
        hidden_states = self.bert(input_ids, attention_mask=attention_mask).last_hidden_state
        # Summarise the sequence with the hidden state of its last position.
        final_state = hidden_states[:, -1, :]
        flat_prediction = self.projection(final_state)
        return flat_prediction.view(-1, self.event_token_count, self.emb_dim)


# -----------------------------------------------------------------------------
# 5a. Next Event Model with Graph Embedding Integrated After BERT (Graph-After)
# -----------------------------------------------------------------------------
class NextEventModelGraphAfterBERT(nn.Module):
    """Next-event model that concatenates a per-user graph embedding after the encoder.

    The encoder's hidden state at the final sequence position is concatenated
    with the user's precomputed graph embedding (projected to ``hidden_size``
    when its width differs) and projected to one ``emb_dim`` vector per event
    token.

    Args:
        bert_model: Encoder called as ``bert_model(input_ids, attention_mask=...)``
            whose result exposes ``last_hidden_state``.
        hidden_size: Width of the encoder's hidden states.
        emb_dim: Width of each predicted token embedding.
        vocab_size: Number of rows in ``token_embeddings``.
        uid_embedding_dict: Mapping from uid to its graph embedding.
        graph_emb_dim: Width of the graph embeddings.
        event_token_count: Number of tokens per event.
    """

    def __init__(self, bert_model, hidden_size, emb_dim, vocab_size, uid_embedding_dict, graph_emb_dim, event_token_count=4):
        super().__init__()
        self.bert = bert_model
        self.uid_embedding_dict = uid_embedding_dict
        self.graph_emb_dim = graph_emb_dim
        if graph_emb_dim != hidden_size:
            self.uid_proj = nn.Linear(graph_emb_dim, hidden_size)
        else:
            self.uid_proj = None
        self.projection = nn.Linear(2 * hidden_size, emb_dim * event_token_count)
        self.event_token_count = event_token_count
        self.emb_dim = emb_dim
        self.token_embeddings = nn.Embedding(vocab_size, emb_dim)

    def _lookup_graph_embeddings(self, uid_list, device):
        """Stack the graph embedding of every uid (zeros when missing) on ``device``."""
        rows = []
        for uid in uid_list:
            # Numeric uids are collated by the DataLoader into a tensor; its
            # 0-d elements hash by identity and would never match a dict key,
            # so convert them back to plain Python scalars first.
            key = uid.item() if isinstance(uid, torch.Tensor) else uid
            embedding = self.uid_embedding_dict.get(key)
            if embedding is None:
                embedding = torch.zeros(self.graph_emb_dim, device=device)
            # Stored embeddings may live on another device than the batch.
            rows.append(torch.as_tensor(embedding, device=device))
        return torch.stack(rows, dim=0)

    def forward(self, input_ids, uid_list, attention_mask=None):
        """Return predictions of shape (-1, event_token_count, emb_dim).

        ``uid_list`` holds one uid per batch row, either as a sequence of
        hashable ids or as a tensor of ids.
        """
        outputs = self.bert(input_ids, attention_mask=attention_mask)
        seq_embedding = outputs.last_hidden_state[:, -1, :]  # (B, hidden_size)
        graph_embeds = self._lookup_graph_embeddings(uid_list, input_ids.device)  # (B, graph_emb_dim)
        if self.uid_proj is not None:
            graph_embeds = self.uid_proj(graph_embeds)  # (B, hidden_size)
        combined = torch.cat([seq_embedding, graph_embeds], dim=-1)  # (B, 2*hidden_size)
        projected = self.projection(combined)  # (B, emb_dim * event_token_count)
        return projected.view(-1, self.event_token_count, self.emb_dim)


# -----------------------------------------------------------------------------
# 5b. Next Event Model with Graph Embedding Integrated as Input Token to BERT (Graph-Input)
# -----------------------------------------------------------------------------
class NextEventModelGraphInputBERT(nn.Module):
    """Next-event model that feeds the user's graph embedding to the encoder as a token.

    The graph embedding (projected to ``hidden_size`` when its width differs)
    is prepended to the word embeddings of ``input_ids``; the encoder output
    at that first position is projected to one ``emb_dim`` vector per event
    token.

    Args:
        bert_model: Encoder exposing ``embeddings.word_embeddings`` and
            callable as ``bert_model(inputs_embeds=..., attention_mask=...)``
            with a result exposing ``last_hidden_state``.
        hidden_size: Width of the encoder's hidden states.
        emb_dim: Width of each predicted token embedding.
        vocab_size: Number of rows in ``token_embeddings``.
        uid_embedding_dict: Mapping from uid to its graph embedding.
        graph_emb_dim: Width of the graph embeddings.
        event_token_count: Number of tokens per event.
    """

    def __init__(self, bert_model, hidden_size, emb_dim, vocab_size, uid_embedding_dict, graph_emb_dim, event_token_count=4):
        super().__init__()
        self.bert = bert_model
        self.uid_embedding_dict = uid_embedding_dict
        self.graph_emb_dim = graph_emb_dim
        if graph_emb_dim != hidden_size:
            self.uid_proj = nn.Linear(graph_emb_dim, hidden_size)
        else:
            self.uid_proj = None
        self.projection = nn.Linear(hidden_size, emb_dim * event_token_count)
        self.event_token_count = event_token_count
        self.emb_dim = emb_dim
        self.token_embeddings = nn.Embedding(vocab_size, emb_dim)

    def _lookup_graph_embeddings(self, uid_list, device):
        """Stack the graph embedding of every uid (zeros when missing) on ``device``."""
        rows = []
        for uid in uid_list:
            # Numeric uids are collated by the DataLoader into a tensor; its
            # 0-d elements hash by identity and would never match a dict key,
            # so convert them back to plain Python scalars first.
            key = uid.item() if isinstance(uid, torch.Tensor) else uid
            embedding = self.uid_embedding_dict.get(key)
            if embedding is None:
                embedding = torch.zeros(self.graph_emb_dim, device=device)
            # Stored embeddings may live on another device than the batch.
            rows.append(torch.as_tensor(embedding, device=device))
        return torch.stack(rows, dim=0)

    def forward(self, input_ids, uid_list, attention_mask=None):
        """Return predictions of shape (-1, event_token_count, emb_dim).

        ``uid_list`` holds one uid per batch row, either as a sequence of
        hashable ids or as a tensor of ids. When ``attention_mask`` is given
        it is extended with a leading 1 for the prepended user token;
        otherwise every position is attended to.
        """
        # Get token embeddings from BERT's embedding layer.
        token_embeds = self.bert.embeddings.word_embeddings(input_ids)  # (B, seq_length, hidden_size)
        uid_embeds = self._lookup_graph_embeddings(uid_list, input_ids.device)  # (B, graph_emb_dim)
        if self.uid_proj is not None:
            uid_embeds = self.uid_proj(uid_embeds)  # (B, hidden_size)
        uid_embeds = uid_embeds.unsqueeze(1)  # (B, 1, hidden_size)
        inputs_embeds = torch.cat([uid_embeds, token_embeds], dim=1)  # (B, seq_length+1, hidden_size)
        if attention_mask is not None:
            uid_mask = torch.ones(attention_mask.size(0), 1, device=attention_mask.device, dtype=attention_mask.dtype)
            attention_mask = torch.cat([uid_mask, attention_mask], dim=1)
        else:
            attention_mask = torch.ones(inputs_embeds.size(0), inputs_embeds.size(1), device=inputs_embeds.device)
        outputs = self.bert(inputs_embeds=inputs_embeds, attention_mask=attention_mask)
        uid_output = outputs.last_hidden_state[:, 0, :]  # output at the prepended user token
        projected = self.projection(uid_output)  # (B, emb_dim * event_token_count)
        return projected.view(-1, self.event_token_count, self.emb_dim)


# -----------------------------------------------------------------------------
# 6. NegativeSampler Classes
# -----------------------------------------------------------------------------
class NegativeSampler:
    """
    Draws random negative token IDs independently for each of the 4 geohash
    token positions, restricted to vocabulary tokens carrying that position's
    prefix ("gh12-", "gh34-", "gh56-", "gh78-").
    """
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
        vocab_items = tokenizer.vocab.items()
        # position index -> list of candidate token ids for that position
        self.samples = {
            position: [token_id for token, token_id in vocab_items if token.startswith(prefix)]
            for position, prefix in enumerate(("gh12-", "gh34-", "gh56-", "gh78-"))
        }

    def sample(self, batch_size, num_negative_samples):
        """Return a tensor of shape (batch_size, 4, num_negative_samples) of sampled ids."""
        draw_shape = (batch_size, num_negative_samples)
        per_position = []
        for position in range(4):
            pool = self.samples[position]
            pool_tensor = torch.tensor(pool)
            picks = torch.randint(low=0, high=len(pool), size=draw_shape)
            per_position.append(pool_tensor[picks])
        return torch.stack(per_position, dim=1)  # (B, 4, num_negative_samples)


class NegativeSamplerOptimized:
    """
    Vectorized two-step negative sampler.

      1. For every (sample, candidate) pair draw m uniformly from {1, 2, 3, 4}:
         the number of trailing token positions to replace.
      2. Keep the first (4 - m) tokens of the positive sample and draw the last
         m positions at random from the tokens sharing that position's prefix.

    Candidate pools per position:
      - Position 0: tokens starting with "gh12-"
      - Position 1: tokens starting with "gh34-"
      - Position 2: tokens starting with "gh56-"
      - Position 3: tokens starting with "gh78-"

    A drawn token is not checked against the positive token, so a replaced
    position may happen to receive the same id as the positive sample.
    """
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
        vocab_items = tokenizer.vocab.items()
        # position index -> list of candidate token ids for that position
        self.samples = {
            position: [token_id for token, token_id in vocab_items if token.startswith(prefix)]
            for position, prefix in enumerate(("gh12-", "gh34-", "gh56-", "gh78-"))
        }
        # The same pools as long tensors, built once so sample() can index them.
        self.samples_tensors = {
            position: torch.tensor(token_ids, dtype=torch.long)
            for position, token_ids in self.samples.items()
        }

    def sample(self, pos_token_ids, num_negative_samples):
        """
        Draw negatives for a batch of positive samples.

        Args:
            pos_token_ids (Tensor): Positive samples of shape (B, 4).
            num_negative_samples (int): Number of negative candidates per positive sample.

        Returns:
            Tensor: Negative token IDs of shape (B, num_negative_samples, 4).
        """
        device = pos_token_ids.device
        draw_shape = (pos_token_ids.size(0), num_negative_samples)

        # Number of trailing positions to replace, per (sample, candidate).
        swap_counts = torch.randint(1, 5, draw_shape, device=device)

        columns = []
        for position in range(4):
            pool = self.samples_tensors[position].to(device)
            picks = torch.randint(0, pool.size(0), draw_shape, device=device)
            drawn = pool[picks]
            positive = pos_token_ids[:, position].unsqueeze(1).expand(*draw_shape)
            # position < 4 - m  <=>  m < 4 - position: this position is kept.
            keep_positive = swap_counts < (4 - position)
            columns.append(torch.where(keep_positive, positive, drawn))
        return torch.stack(columns, dim=-1)  # (B, num_negative_samples, 4)


# -----------------------------------------------------------------------------
# 7. Event Predictor Module
# -----------------------------------------------------------------------------
class EventPredictor(nn.Module):
    """
    Scores a (predicted embedding, candidate embedding) pair with one logit.

    Both inputs are passed through their own linear layer, summed, sent
    through a ReLU and reduced to a single value by a final linear layer.
    Both inputs must have ``in_channels`` features in their last dimension.
    """
    def __init__(self, in_channels):
        super().__init__()
        self.lin_src = nn.Linear(in_channels, in_channels)
        self.lin_dst = nn.Linear(in_channels, in_channels)
        self.lin_final = nn.Linear(in_channels, 1)

    def forward(self, z_src, z_dst):
        """Return the score tensor with the trailing size-1 dimension removed."""
        mixed = self.lin_src(z_src) + self.lin_dst(z_dst)
        activated = F.relu(mixed)
        logits = self.lin_final(activated)
        return logits.squeeze(-1)


# -----------------------------------------------------------------------------
# 8. Vanilla Binary Cross-Entropy Loss Function using EventPredictor (bce_loss)
# -----------------------------------------------------------------------------
def bce_loss(predicted, target_token_ids, model, negative_sampler, num_negative_samples, predictor):
    """
    Binary cross-entropy loss over whole-event scores.

    The flattened prediction is scored by ``predictor`` against the flattened
    embedding of the true event (label 1) and against the flattened embeddings
    of ``num_negative_samples`` sampled events (label 0). The result is the
    summed BCE-with-logits of all scores divided by the number of scores.

    Args:
        predicted: Tensor of shape (B, num_positions, emb_dim).
        target_token_ids: Token ids of the true events; embedded with
            ``model.token_embeddings``.
        model: Object exposing ``token_embeddings``.
        negative_sampler: ``NegativeSamplerOptimized`` (sampled conditioned on
            the targets) or any sampler with ``sample(batch_size, n)``
            returning ids laid out as (B, positions, n).
        num_negative_samples: Negatives drawn per example.
        predictor: Callable scoring (prediction, candidate) pairs.

    Returns:
        Scalar loss tensor.
    """
    batch_size, _num_positions, _emb_dim = predicted.shape
    flat_pred = predicted.view(batch_size, -1)  # (B, positions*emb_dim)
    flat_pos = model.token_embeddings(target_token_ids).view(batch_size, -1)
    pos_scores = predictor(flat_pred, flat_pos)  # (B,)

    if isinstance(negative_sampler, NegativeSamplerOptimized):
        # The optimized sampler returns (B, n, positions); bring it to (B, positions, n).
        neg_ids = negative_sampler.sample(target_token_ids, num_negative_samples)
        neg_ids = neg_ids.to(predicted.device).permute(0, 2, 1)
    else:
        neg_ids = negative_sampler.sample(batch_size, num_negative_samples).to(predicted.device)

    # (B, positions, n, emb_dim) -> (B, n, positions*emb_dim)
    flat_neg = (
        model.token_embeddings(neg_ids)
        .permute(0, 2, 1, 3)
        .reshape(batch_size, num_negative_samples, -1)
    )
    neg_scores = predictor(flat_pred.unsqueeze(1).expand_as(flat_neg), flat_neg)  # (B, n)

    pos_term = F.binary_cross_entropy_with_logits(
        pos_scores, torch.ones_like(pos_scores), reduction='sum')
    neg_term = F.binary_cross_entropy_with_logits(
        neg_scores, torch.zeros_like(neg_scores), reduction='sum')
    score_count = pos_scores.numel() + neg_scores.numel()
    return (pos_term + neg_term) / score_count


# -----------------------------------------------------------------------------
# 9. Hierarchical Binary Cross-Entropy Loss Function using EventPredictors (hierarchical_bce_loss)
# -----------------------------------------------------------------------------
def hierarchical_bce_loss(predicted, target_token_ids, model, negative_sampler, num_negative_samples, predictor_list):
    """
    Binary cross-entropy loss summed over growing token prefixes.

    For every prefix length 1..num_positions the flattened predicted prefix is
    scored by that level's predictor against the true prefix (label 1) and
    against the same prefix of each sampled negative. For the first three
    levels a negative whose prefix equals the true prefix is labelled 1; at
    later levels every negative is labelled 0. The summed BCE-with-logits of
    all levels is divided by the total number of scores.

    Args:
        predicted: Tensor of shape (B, num_positions, emb_dim).
        target_token_ids: Token ids of the true events, indexed as (B, position).
        model: Object exposing ``token_embeddings``.
        negative_sampler: ``NegativeSamplerOptimized`` or a sampler with
            ``sample(batch_size, n)`` returning ids laid out as (B, positions, n).
        num_negative_samples: Negatives drawn per example.
        predictor_list: One predictor per prefix level, indexed from 0.

    Returns:
        Scalar loss tensor.
    """
    batch_size, num_positions, _emb_dim = predicted.shape
    device = predicted.device
    loss_sum = 0.0
    label_count = 0

    # Negatives are drawn once and shared by every prefix level.
    if isinstance(negative_sampler, NegativeSamplerOptimized):
        all_neg_ids = negative_sampler.sample(target_token_ids, num_negative_samples).to(device)
        all_neg_ids = all_neg_ids.permute(0, 2, 1)  # (B, positions, n)
    else:
        all_neg_ids = negative_sampler.sample(batch_size, num_negative_samples).to(device)

    for level in range(num_positions):
        prefix_length = level + 1
        level_predictor = predictor_list[level]

        pred_prefix = predicted[:, :prefix_length, :].reshape(batch_size, -1)
        target_prefix_ids = target_token_ids[:, :prefix_length]
        true_prefix = model.token_embeddings(target_prefix_ids).reshape(batch_size, -1)
        pos_scores = level_predictor(pred_prefix, true_prefix)  # (B,)

        neg_prefix_ids = all_neg_ids[:, :prefix_length, :]  # (B, prefix_length, n)
        # (B, prefix_length, n, emb_dim) -> (B, n, prefix_length*emb_dim)
        neg_prefix = (
            model.token_embeddings(neg_prefix_ids)
            .permute(0, 2, 1, 3)
            .reshape(batch_size, num_negative_samples, -1)
        )
        neg_scores = level_predictor(pred_prefix.unsqueeze(1).expand_as(neg_prefix), neg_prefix)  # (B, n)

        if level < 3:
            # A negative sharing the whole true prefix counts as a positive here.
            expanded_truth = target_prefix_ids.unsqueeze(2).expand(
                batch_size, prefix_length, num_negative_samples)
            neg_labels = torch.all(neg_prefix_ids == expanded_truth, dim=1).float()
        else:
            neg_labels = torch.zeros_like(neg_scores)

        pos_term = F.binary_cross_entropy_with_logits(
            pos_scores, torch.ones_like(pos_scores), reduction='sum')
        neg_term = F.binary_cross_entropy_with_logits(
            neg_scores, neg_labels, reduction='sum')

        loss_sum += pos_term + neg_term
        label_count += pos_scores.numel() + neg_scores.numel()

    return loss_sum / label_count


# -----------------------------------------------------------------------------
# 10. Build Training Input Dictionary
# -----------------------------------------------------------------------------
def build_train_input_dict(csv_file, event_set_filter, max_events):
    """Map every uid in ``csv_file`` to its padded/truncated event sequence.

    Sequences come from ``load_and_process_sequences`` with ``max_events`` as
    the target length; each is additionally sliced to its last ``max_events``
    entries.
    """
    uids, sequences = load_and_process_sequences(csv_file, event_set_filter, max_events)
    return {uid: seq[-max_events:] for uid, seq in zip(uids, sequences)}


# -----------------------------------------------------------------------------
# 11. Create Evaluation Dataset (Combined Positive and Negative)
# -----------------------------------------------------------------------------
class EvalDataset(Dataset):
    """Ranking-evaluation samples: one positive event plus a block of negatives.

    Both CSV files are grouped by ``uid`` (first-appearance order). The number
    of negatives per positive is ``len(neg_df) // len(pos_df)``; the i-th
    positive of a user is paired with the i-th consecutive block of that
    user's negatives. Each sample is
    ``(uid, input_ids tensor, candidates tensor)`` where row 0 of the
    candidates is the positive event. Users absent from ``train_input_dict``
    get an all-[PAD] input sequence of ``max_events`` entries.
    """
    def __init__(self, train_input_dict, pos_csv, neg_csv, tokenizer, max_events, max_seq_tokens):
        self.samples = []
        pos_df = pd.read_csv(pos_csv)
        neg_df = pd.read_csv(neg_csv)

        negatives_per_positive = len(neg_df) // len(pos_df)

        pos_by_uid = pos_df.groupby("uid", sort=False)["event"].apply(list).to_dict()
        neg_by_uid = neg_df.groupby("uid", sort=False)["event"].apply(list).to_dict()

        def event_to_ids(event):
            # Four token ids per event; unknown tokens map to the [UNK] id.
            return [tokenizer.vocab.get(tok, tokenizer.vocab[tokenizer.unk_token])
                    for tok in tokenizer.tokenize_geohash(event)]

        for uid, pos_events in tqdm(pos_by_uid.items(), desc="Creating EvalDataset", total=len(pos_by_uid)):
            history = train_input_dict[uid] if uid in train_input_dict else ["[PAD]"] * max_events
            input_ids = tokenizer.encode(history, max_length=max_seq_tokens,
                                         padding="max_length", truncation=True)

            user_negatives = neg_by_uid[uid]
            for block_idx, pos_event in enumerate(pos_events):
                block_start = block_idx * negatives_per_positive
                block_end = block_start + negatives_per_positive

                # Positive candidate first, then its consecutive block of negatives.
                candidate_rows = [event_to_ids(pos_event)]
                candidate_rows.extend(event_to_ids(neg_event)
                                      for neg_event in user_negatives[block_start:block_end])
                candidates_tensor = torch.tensor(candidate_rows, dtype=torch.long)

                self.samples.append((uid, torch.tensor(input_ids, dtype=torch.long), candidates_tensor))

        logger.info("Created evaluation dataset with %d samples", len(self.samples))

    def __len__(self):
        """Number of evaluation samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return the prebuilt (uid, input_ids, candidates) tuple at ``idx``."""
        return self.samples[idx]


# -----------------------------------------------------------------------------
# 12. Compute Ranking Metrics using torcheval
# -----------------------------------------------------------------------------
def compute_ranking_metrics(candidates, target, k_list):
    """Compute mean reciprocal rank and mean hit rate at each k in ``k_list``.

    Args:
        candidates: Scores passed to the torcheval metrics as input.
        target: Index of the correct candidate for every row.
        k_list: Cut-offs for which a "Hits@k" entry is produced.

    Returns:
        dict with key "MRR" followed by one "Hits@{k}" key per cut-off.
    """
    reciprocal_rank = ReciprocalRank()
    hit_rates = {k: HitRate(k=k) for k in k_list}

    reciprocal_rank.update(candidates, target)
    for tracker in hit_rates.values():
        tracker.update(candidates, target)

    results = {"MRR": reciprocal_rank.compute().mean().item()}
    for k in k_list:
        results[f"Hits@{k}"] = hit_rates[k].compute().mean().item()
    return results


# -----------------------------------------------------------------------------
# 13. Evaluate Model with Combined Evaluation Dataset and Subpopulation Metrics
# -----------------------------------------------------------------------------
def evaluate_model(model, eval_dataset, tokenizer, device, batch_size, k_list, predictor, graph_mode, num_workers):
    """Score every evaluation sample and report global and per-subgroup ranking metrics.

    Each sample's candidates are scored by ``predictor`` against the model's
    flattened prediction; the positive candidate is at index 0. Samples are
    also bucketed by the number of non-[PAD] events in their input
    (non-pad tokens minus one for [CLS], divided by 4).

    Args:
        model: Called with ``(input_ids, uids, attention_mask=...)`` unless
            ``graph_mode`` is "none", in which case ``uids`` is omitted.
        eval_dataset: Dataset yielding (uid, input_ids, candidate_ids).
        tokenizer: Provides ``vocab`` and ``pad_token``.
        device: Device the batches are moved to.
        batch_size, num_workers: Forwarded to the DataLoader.
        k_list: Cut-offs for the Hits@k metrics.
        predictor: Scoring module; switched to eval mode.
        graph_mode: "none" selects the model signature without uids.

    Returns:
        {"global": metrics, "groups": {name: {"population", "metrics"}}}
    """
    # Inclusive event-count range of every bounded subgroup; any count outside
    # all of them falls into ">50".
    bounded_groups = (("0", 0, 0), ("1", 1, 1), ("2-5", 2, 5), ("6-10", 6, 10), ("11-50", 11, 50))
    group_names = [name for name, _, _ in bounded_groups] + [">50"]

    def group_for(n_events):
        for name, low, high in bounded_groups:
            if low <= n_events <= high:
                return name
        return ">50"

    model.eval()
    predictor.eval()
    loader = DataLoader(eval_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    all_candidates = []
    # Per-subgroup list of score rows.
    group_candidates = {name: [] for name in group_names}
    pad_token_id = tokenizer.vocab[tokenizer.pad_token]
    use_graph = graph_mode != "none"

    with torch.no_grad():
        for uids, input_ids, candidate_ids in tqdm(loader, desc="Evaluating Combined Dataset", leave=False):
            input_ids = input_ids.to(device)
            candidate_ids = candidate_ids.to(device)  # (B, num_candidates, 4)
            not_pad = input_ids != pad_token_id
            attention_mask = not_pad.long().to(device)
            # Always attend to the last two positions.
            attention_mask[:, -2:] = 1

            if use_graph:
                predicted_tokens = model(input_ids, uids, attention_mask=attention_mask)
            else:
                predicted_tokens = model(input_ids, attention_mask=attention_mask)
            z_pred = predicted_tokens.view(predicted_tokens.size(0), -1)

            B, num_candidates, _ = candidate_ids.shape
            z_candidate = model.token_embeddings(candidate_ids).view(B, num_candidates, -1)
            scores = predictor(z_pred.unsqueeze(1).expand_as(z_candidate), z_candidate)  # (B, num_candidates)
            all_candidates.append(scores.cpu())

            # One [CLS] token is subtracted; each event contributes 4 tokens.
            event_counts = ((not_pad.sum(dim=1) - 1) // 4).cpu()
            for row in range(B):
                group = group_for(event_counts[row].item())
                group_candidates[group].append(scores[row].cpu())

    # Global metrics: the positive candidate is at index 0 of every row.
    candidates = torch.cat(all_candidates, dim=0)
    target_indices = torch.zeros(candidates.size(0), dtype=torch.long)
    global_metrics = compute_ranking_metrics(candidates, target_indices, k_list)

    # Metrics per subgroup; empty subgroups get ``None``.
    group_metrics = {}
    for group, score_rows in group_candidates.items():
        population_count = len(score_rows)
        if population_count == 0:
            group_metrics[group] = {"population": 0, "metrics": None}
            continue
        group_scores = torch.stack(score_rows, dim=0)  # (N, num_candidates)
        group_targets = torch.zeros(group_scores.size(0), dtype=torch.long)
        metrics = compute_ranking_metrics(group_scores, group_targets, k_list)
        group_metrics[group] = {"population": population_count, "metrics": metrics}

    logger.info("Global Evaluation Metrics: %s", global_metrics)
    for group in group_names:
        entry = group_metrics[group]
        if entry["metrics"] is not None:
            logger.info("Group %s non-padded events (N=%d): %s", group, entry["population"], entry["metrics"])
        else:
            logger.info("Group %s non-padded events (N=%d): No samples", group, entry["population"])

    return {"global": global_metrics, "groups": group_metrics}


# -----------------------------------------------------------------------------
# 14. Training Function with Periodic Evaluation
# -----------------------------------------------------------------------------
def train_model(model, train_loader, optimizer, tokenizer, device, num_epochs, 
                negative_sampler, num_negative_samples, eval_every, eval_dataset, k_list, loss_fn, predictor, graph_mode, num_workers):
    """Train ``model`` and ``predictor`` for ``num_epochs`` with periodic evaluation.

    Every batch is (uids, input_ids, target_ids). The attention mask marks
    non-[PAD] positions and always includes the last two positions. The model
    receives ``uids`` only when ``graph_mode`` is not "none". After every
    ``eval_every``-th epoch ``evaluate_model`` runs on ``eval_dataset`` with
    the loader's batch size; if ``predictor`` is an ``nn.ModuleList`` its last
    element is used for evaluation. Nothing is returned.
    """
    use_graph = graph_mode != "none"
    for epoch in range(num_epochs):
        epoch_number = epoch + 1
        model.train()
        predictor.train()
        running_loss = 0.0
        logger.info("Starting Epoch %d/%d", epoch_number, num_epochs)
        for uids, input_ids, target_ids in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} Training", leave=False):
            input_ids = input_ids.to(device)
            target_ids = target_ids.to(device)
            pad_token_id = tokenizer.vocab[tokenizer.pad_token]
            attention_mask = (input_ids != pad_token_id).long().to(device)
            attention_mask[:, -2:] = 1

            optimizer.zero_grad()
            if use_graph:
                predicted_tokens = model(input_ids, uids, attention_mask=attention_mask)
            else:
                predicted_tokens = model(input_ids, attention_mask=attention_mask)
            loss = loss_fn(predicted_tokens, target_ids, model, negative_sampler, num_negative_samples, predictor)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        avg_loss = running_loss / len(train_loader)
        logger.info("Epoch %d/%d: Training Avg Loss: %.4f", epoch_number, num_epochs, avg_loss)

        if epoch_number % eval_every != 0:
            continue
        # With one predictor per prefix level, evaluate with the last one.
        if isinstance(predictor, torch.nn.ModuleList):
            eval_predictor = predictor[-1]
        else:
            eval_predictor = predictor
        # Metrics are logged inside evaluate_model; the return value is not used here.
        evaluate_model(model, eval_dataset, tokenizer, device, train_loader.batch_size, k_list, eval_predictor, graph_mode, num_workers)


# -----------------------------------------------------------------------------
# 15. Main Function for Training and Evaluation
# -----------------------------------------------------------------------------
def main(args):
    """Build the data pipeline and model, then train, save, and evaluate.

    Steps: seed the RNGs, build the vocabulary/tokenizer from the training
    CSV, create the training loader and the validation/test datasets,
    optionally load pickled graph embeddings, construct the BERT-based model
    and the event predictor(s), evaluate once before training, train, save
    the weights under ``stored/`` and run a final test evaluation.

    Args:
        args: Parsed command-line namespace (see the argument parser at the
            bottom of this file for the available options).

    Raises:
        SystemExit: With code 1 when graph embeddings are enabled but no
            embedding file is given, or when ``graph_embedding_mode`` is
            not one of ``"none"``, ``"post"``, ``"input"``.
    """
    # Apply --seed so runs are reproducible; it was previously parsed but
    # never used. torch.manual_seed seeds the generators on all devices.
    # NOTE(review): numpy is not seeded because this file does not import it;
    # add it if any sampler relies on numpy.random.
    seed = getattr(args, "seed", None)
    if seed is not None:
        random.seed(seed)
        torch.manual_seed(seed)
        logger.info("Random seed set to %d", seed)

    device = torch.device(f"cuda:{args.gpu_id}" if torch.cuda.is_available() else "cpu")
    logger.info("Using device: %s", device)

    os.makedirs("stored/", exist_ok=True)
    model_save_path = os.path.join("stored", f"{args.experiment_name}_model.pth")
    predictor_save_path = os.path.join("stored", f"{args.experiment_name}_predictor.pth")

    # File paths.
    train_file = os.path.join(args.input_path, "personal_train.csv")
    val_file = os.path.join(args.input_path, "personal_val.csv")
    val_neg_file = os.path.join(args.input_path, "personal_val_negative_sample.csv")
    test_file = os.path.join(args.input_path, "personal_test.csv")
    test_neg_file = os.path.join(args.input_path, "personal_test_negative_sample.csv")

    # Prepare vocabulary and tokenizer.
    vocab = generate_vocab(train_file, args.event_set)
    tokenizer = GeohashTokenizer(vocab)

    # Negative sampler selection.
    if args.negative_sampler == "optimized":
        negative_sampler = NegativeSamplerOptimized(tokenizer)
        logger.info("Using Optimized Negative Sampler")
    else:
        negative_sampler = NegativeSampler(tokenizer)
        logger.info("Using Original Negative Sampler")

    # Build training input dictionary (used as history by the eval datasets).
    train_input_dict = build_train_input_dict(train_file, args.event_set, args.max_events)

    # Load training dataset for training.
    max_len = args.max_events + args.max_examples
    train_uids, train_seqs = load_and_process_sequences(train_file, args.event_set, max_len)
    train_dataset = NextEventDataset(train_uids, train_seqs, tokenizer, args.max_events, args.max_examples)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)

    # Create evaluation datasets.
    # Presumably 4 geohash tokens per event plus one leading [CLS] token.
    max_seq_tokens = args.max_events * 4 + 1
    val_dataset = EvalDataset(train_input_dict, val_file, val_neg_file, tokenizer, args.max_events, max_seq_tokens)
    test_dataset = EvalDataset(train_input_dict, test_file, test_neg_file, tokenizer, args.max_events, max_seq_tokens)

    # Load graph embeddings if enabled.
    if args.graph_embedding_mode != "none":
        if not args.graph_embedding_file:
            logger.error("Graph embedding mode is enabled but no graph_embedding_file provided")
            # SystemExit instead of the site-provided exit() helper, which is
            # not guaranteed to exist (e.g. under `python -S`).
            raise SystemExit(1)
        # NOTE(security): pickle.load can execute arbitrary code; only point
        # --graph_embedding_file at files from a trusted source.
        with open(args.graph_embedding_file, 'rb') as f:
            uid_embedding_dict = pickle.load(f)
        if args.graph_embedding_dim == 0:
            # Infer the dimension from an arbitrary entry of the mapping.
            sample_key = next(iter(uid_embedding_dict))
            args.graph_embedding_dim = uid_embedding_dict[sample_key].shape[0]
        logger.info("Loaded graph embeddings with dimension %d", args.graph_embedding_dim)
    else:
        uid_embedding_dict = None

    # Determine effective max position embeddings: "input" mode needs one
    # extra position.
    if args.graph_embedding_mode == "input":
        effective_max_seq_tokens = max_seq_tokens + 1
    else:
        effective_max_seq_tokens = max_seq_tokens

    # Configure BERT.
    config = BertConfig(
        vocab_size=len(vocab),
        hidden_size=args.hidden_size,
        num_hidden_layers=args.num_hidden_layers,
        num_attention_heads=args.num_attention_heads,
        intermediate_size=args.hidden_size * 2,
        max_position_embeddings=effective_max_seq_tokens,
        output_hidden_states=False,
    )
    bert_model = BertModel(config)

    # Select the model based on graph_embedding_mode.
    if args.graph_embedding_mode == "none":
        model = NextEventModel(bert_model, hidden_size=args.hidden_size, emb_dim=args.emb_dim,
                               vocab_size=len(vocab), event_token_count=4)
    elif args.graph_embedding_mode == "post":
        model = NextEventModelGraphAfterBERT(bert_model, hidden_size=args.hidden_size, emb_dim=args.emb_dim,
                                             vocab_size=len(vocab), uid_embedding_dict=uid_embedding_dict,
                                             graph_emb_dim=args.graph_embedding_dim, event_token_count=4)
    elif args.graph_embedding_mode == "input":
        model = NextEventModelGraphInputBERT(bert_model, hidden_size=args.hidden_size, emb_dim=args.emb_dim,
                                             vocab_size=len(vocab), uid_embedding_dict=uid_embedding_dict,
                                             graph_emb_dim=args.graph_embedding_dim, event_token_count=4)
    else:
        logger.error("Invalid graph_embedding_mode: %s", args.graph_embedding_mode)
        raise SystemExit(1)

    model.to(device)

    # Loss selection and predictor(s) creation.
    if args.loss_type == "hierarchical":
        # One EventPredictor per geohash level; level i consumes the
        # concatenation of (i + 1) token embeddings.
        predictor = torch.nn.ModuleList(
            EventPredictor(in_channels=(level + 1) * args.emb_dim) for level in range(4)
        )
        loss_fn = hierarchical_bce_loss
        logger.info("Using Hierarchical Binary Cross Entropy Loss")
    else:
        # Vanilla loss uses a single EventPredictor expecting full geohash: 4*emb_dim.
        predictor = EventPredictor(in_channels=4 * args.emb_dim)
        loss_fn = bce_loss
        logger.info("Using Vanilla Binary Cross Entropy Loss")

    predictor.to(device)
    optimizer = optim.Adam(list(model.parameters()) + list(predictor.parameters()), lr=args.learning_rate)

    # Evaluation always scores with the full-geohash predictor, which is the
    # last entry of the hierarchical list.
    eval_predictor = predictor[-1] if isinstance(predictor, torch.nn.ModuleList) else predictor

    # Pre-training evaluation of the untrained model; the result is not
    # logged here.
    evaluate_model(model, test_dataset, tokenizer, device, args.batch_size, args.hits_k,
                   eval_predictor, args.graph_embedding_mode, args.num_workers)

    train_model(model, train_loader, optimizer, tokenizer, device, args.num_epochs,
                negative_sampler, args.num_negative_samples, args.eval_every, val_dataset,
                args.hits_k, loss_fn, predictor, args.graph_embedding_mode, args.num_workers)

    torch.save(model.state_dict(), model_save_path)
    logger.info("Model saved to %s", model_save_path)
    # The predictor is trained jointly and is required to score candidates,
    # so persist it as well. It goes to a separate file so the model
    # checkpoint keeps its original format.
    torch.save(predictor.state_dict(), predictor_save_path)
    logger.info("Predictor saved to %s", predictor_save_path)

    test_result = evaluate_model(model, test_dataset, tokenizer, device, args.batch_size, args.hits_k,
                                 eval_predictor, args.graph_embedding_mode, args.num_workers)
    logger.info("Final Test Evaluation Result: %s", test_result)


# -----------------------------------------------------------------------------
# Argument Parsing and Main Entry Point
# -----------------------------------------------------------------------------
if __name__ == '__main__':
    # Command-line interface: every option below is read by main(args).
    parser = argparse.ArgumentParser(
        description="Next Event Prediction with Binary or Hierarchical Loss and optional Graph Embedding Integration"
    )
    parser.add_argument("--input_path", type=str, required=True,
                        help="Input directory path containing the data files")
    parser.add_argument("--event_set", type=str, default="personal", help="Filter for event_set column")
    parser.add_argument("--max_events", type=int, default=100, help="Maximum number of events per user")
    parser.add_argument("--max_examples", type=int, default=50, help="Max number of examples to generate per user")
    parser.add_argument("--batch_size", type=int, default=1024, help="Batch size")
    parser.add_argument("--num_epochs", type=int, default=50, help="Number of training epochs")
    parser.add_argument("--learning_rate", type=float, default=1e-3, help="Learning rate")
    parser.add_argument("--hidden_size", type=int, default=128, help="Hidden size for BERT model")
    parser.add_argument("--num_hidden_layers", type=int, default=4, help="Number of hidden layers in BERT")
    parser.add_argument("--num_attention_heads", type=int, default=4, help="Number of attention heads in BERT")
    parser.add_argument("--emb_dim", type=int, default=64, help="Dimension of token embeddings for negative sampling")
    parser.add_argument("--num_negative_samples", type=int, default=10, help="Number of negative candidates per sample")
    parser.add_argument("--hits_k", type=int, nargs="+", default=[3, 5, 10, 50], help="List of k values for Hits@k")
    parser.add_argument("--eval_every", type=int, default=5, help="Evaluate every N epochs")
    parser.add_argument("--gpu_id", type=int, default=0, help="GPU ID to use")
    parser.add_argument('--seed', type=int, default=42, help="Random seed for reproducibility.")
    parser.add_argument("--experiment_name", type=str, default="default",
                        help="Name of the experiment; files will be stored in the 'stored' folder using this name")
    parser.add_argument("--loss_type", type=str, choices=["binary", "hierarchical"], default="binary",
                        help="Loss type to use: 'binary' (vanilla) or 'hierarchical'")
    parser.add_argument("--graph_embedding_mode", type=str, choices=["none", "post", "input"], default="none",
                        help="Graph embedding integration mode: 'none' (default), 'post' (after BERT), or 'input' (as input token to BERT)")
    parser.add_argument("--graph_embedding_file", type=str, default="node_embeddings.pkl",
                        help="Path to precomputed graph embeddings file (if using graph embedding)")
    parser.add_argument("--graph_embedding_dim", type=int, default=0,
                        help="Dimension of graph embeddings (if using graph embedding); if 0, inferred from the file")
    # Selects which negative sampler implementation is used during training.
    parser.add_argument("--negative_sampler", type=str, choices=["original", "optimized"], default="optimized",
                        help="Choose negative sampler version: 'original' for original sampler, 'optimized' for optimized sampler")
    # DataLoader semantics: 0 loads data in the main process (no workers).
    parser.add_argument("--num_workers", type=int, default=0,
                        help="Number of worker processes for DataLoader (default: 0, load data in the main process)")
    args = parser.parse_args()

    main(args)
