import os
import time
import random
import argparse
import logging
from tqdm import tqdm
import gc

import torch
import pandas as pd
import numpy as np
from torcheval.metrics import ReciprocalRank, HitRate

from torch_geometric.data import TemporalData
from torch_geometric.loader import TemporalDataLoader
from torch_geometric.nn import TGNMemory, TransformerConv
from torch_geometric.nn.models.tgn import (
    IdentityMessage,
    LastAggregator,
    LastNeighborLoader,
)

from tgb_memory_module import DyRepMemory

###############################################################################
# Logging configuration
###############################################################################
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

###############################################################################
# Helper function to clear GPU memory
###############################################################################
def clear_unused_gpu_memory():
    """Run Python garbage collection, then release PyTorch's unused cached CUDA memory."""
    # Collect first: tensors that are only kept alive by reference cycles
    # must be destroyed before the allocator can give their blocks back.
    gc.collect()
    torch.cuda.empty_cache()

###############################################################################
# Argument Parsing
###############################################################################
def parse_args(argv=None):
    """
    Parse the command-line options of the training script.

    Args:
        argv: Optional list of argument strings. When None (the default),
            argparse reads the arguments from ``sys.argv[1:]``.

    Returns:
        argparse.Namespace: The parsed options. ``hits_k`` is returned as the
        raw comma separated string; the caller converts it to integers.
    """
    parser = argparse.ArgumentParser(
        description="Personal Edge Prediction with Batching on Large Graphs, Early Stopping, and Custom CSV Paths"
    )

    # --- Data / model selection -------------------------------------------
    parser.add_argument('--base_path', type=str, default='.',
                        help="Base path for CSV files.")
    parser.add_argument('--model', type=str, choices=['TGN', 'DyRep'], default='TGN',
                        help="Model type to use.")
    parser.add_argument('--embedding_dim', type=int, default=32,
                        help="Dimensionality of input node embeddings.")

    # --- Optimisation -------------------------------------------------------
    parser.add_argument('--epochs', type=int, default=10,
                        help="Maximum number of training epochs.")
    parser.add_argument('--lr', type=float, default=1e-3,
                        help="Learning rate.")
    parser.add_argument('--train_batch_size', type=int, default=2**12,
                        help="Mini-batch size for training.")
    parser.add_argument('--eval_batch_size', type=int, default=2**18,
                        help="Mini-batch size for evaluation and testing.")
    # The patience counter advances once per validation run, not once per epoch.
    parser.add_argument('--patience', type=int, default=10,
                        help="Early stopping patience: number of consecutive evaluations "
                             "without improvement (one evaluation every --eval_every epochs).")
    parser.add_argument('--eval_every', type=int, default=2,
                        help="Evaluation frequency (in epochs).")
    parser.add_argument('--seed', type=int, default=42,
                        help="Random seed for reproducibility.")
    parser.add_argument('--gpu_id', type=int, default=0,
                        help="GPU ID to use (if CUDA is available).")
    parser.add_argument('--enable_early_stopping', action='store_true',
                        help='Enable early stopping based on validation MRR. (default: False)')

    # --- Sampling -----------------------------------------------------------
    parser.add_argument('--num_neighbors', type=int, default=10,
                        help="Number of neighbors to sample per layer in the neighbor loader.")
    # Must be an int: the training loop repeats every source node this many
    # times (repeat_interleave), which rejects a float such as 5.0.
    parser.add_argument('--train_neg_sampling_ratio', type=int, default=5,
                        help="Number of negative destinations sampled per positive training edge.")

    # --- Evaluation / bookkeeping -------------------------------------------
    parser.add_argument('--hits_k', type=str, default="3,5,10,50",
                        help="Comma separated list of k values for computing Hits@k metrics.")
    parser.add_argument('--experiment_name', type=str, default='TG-exp-1',
                        help="Experiment name for saving checkpoints and artifacts.")
    parser.add_argument("--max_events", type=int, default=100,
                        help="Maximum number of most recent personal training events kept per user "
                             "(a negative value keeps all events).")
    # Relational edges are included by default; --no_relational excludes them.
    parser.add_argument('--no_relational', action='store_true',
                        help="If set, relational edges will be excluded from the graph. By default, relational edges are included.")
    parser.add_argument('--num_workers', type=int, default=0,
                        help="Number of worker processes for data loaders "
                             "(default: 0, i.e. data is loaded in the main process).")
    return parser.parse_args(argv)

###############################################################################
# Set Random Seed for Reproducibility
###############################################################################
def set_seed(seed):
    """Seed Python, NumPy and PyTorch RNGs and switch cuDNN to deterministic mode."""
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    # Disable cuDNN autotuning and ask for deterministic kernels.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

###############################################################################
# Data Loading for Personal Edge Prediction
###############################################################################
def load_data(base_path, embedding_dim, max_events, include_relational_edges):
    """
    Load edge lists from CSV files and build PyG TemporalData objects.

    Files read from `base_path` (with the columns this function uses):
      - personal_train.csv                 (uid, event, timestamp)
      - personal_val.csv                   (uid, event)
      - personal_val_negative_sample.csv   (uid, event)
      - personal_test.csv                  (uid, event)
      - personal_test_negative_sample.csv  (uid, event)
      - relational_observed.csv            (uid, other_uid, timestamp),
        only when `include_relational_edges` is true

    Personal edges (user -> event) are the link prediction targets, while
    relational edges (user -> user) only add context to the training graph.

    Node ids: a user keeps its `uid`; every distinct event value found in the
    personal files gets the id `num_user_nodes + code`, so user and event ids
    never overlap.

    Args:
        base_path: Directory containing the CSV files.
        embedding_dim: Unused; kept so existing callers keep working.
        max_events: Keep only the `max_events` most recent training events of
            each user. A negative value disables the truncation.
        include_relational_edges: Whether to merge the relational edges into
            the training graph.

    Returns:
        dict: 'train_data', 'val_data_pos', 'val_data_neg', 'test_data_pos'
        and 'test_data_neg' (TemporalData objects) plus 'num_all_nodes'
        (number of user nodes + number of event nodes).

    Raises:
        ValueError: If an 'event' column contains missing values.
    """
    logging.info("Preparing dataset for personal edge prediction...")
    logging.info(f"Base path: {base_path}")

    def read_csv(filename):
        return pd.read_csv(os.path.join(base_path, filename))

    personal_train_df = read_csv("personal_train.csv")
    personal_val_df = read_csv("personal_val.csv")
    personal_test_df = read_csv("personal_test.csv")
    personal_val_neg_df = read_csv("personal_val_negative_sample.csv")
    personal_test_neg_df = read_csv("personal_test_negative_sample.csv")
    relational_df = read_csv("relational_observed.csv") if include_relational_edges else None

    # Every uid that is later used as a source node must be smaller than
    # num_user_nodes, otherwise it would collide with an event node id. The
    # negative-sample files are therefore taken into account as well.
    uid_columns = [
        personal_train_df['uid'],
        personal_val_df['uid'],
        personal_test_df['uid'],
        personal_val_neg_df['uid'],
        personal_test_neg_df['uid'],
    ]
    if relational_df is not None:
        uid_columns += [relational_df['uid'], relational_df['other_uid']]
    num_user_nodes = int(pd.concat(uid_columns, ignore_index=True).max()) + 1

    # Optionally keep only the most recent `max_events` training events per user.
    if max_events > -1:
        personal_train_df = personal_train_df.sort_values("timestamp").groupby("uid", group_keys=False) \
            .tail(max_events).sort_values(["uid", "timestamp"])

    # Factorize event ids over all personal files (positives and negatives)
    # so the same event always maps to the same node.
    event_frames = (personal_train_df, personal_val_df, personal_test_df,
                    personal_val_neg_df, personal_test_neg_df)
    all_personal_events = pd.concat([df['event'] for df in event_frames], ignore_index=True)
    event_codes, uniques = pd.factorize(all_personal_events)
    if (event_codes < 0).any():
        # factorize marks missing values with -1; they cannot become node ids.
        raise ValueError("Found missing values in an 'event' column; cannot assign event node ids.")

    # Offset event codes by the number of user nodes and cut the flat array
    # back into one chunk per file (same order as `event_frames`).
    node_codes = event_codes + num_user_nodes
    split_points = np.cumsum([len(df) for df in event_frames])[:-1]
    (personal_train_codes, personal_val_codes, personal_test_codes,
     personal_val_neg_codes, personal_test_neg_codes) = np.split(node_codes, split_points)

    # Training edges from personal data (uid --> event node).
    src = torch.tensor(personal_train_df['uid'].values, dtype=torch.long)
    dst = torch.tensor(personal_train_codes, dtype=torch.long)
    t = torch.tensor(personal_train_df['timestamp'].values, dtype=torch.long)

    # Evaluation edges are stamped after the last personal training event.
    eval_time = int(t.max()) + 1000

    if relational_df is not None:
        # Append relational edges (uid --> uid).
        src = torch.cat([src, torch.tensor(relational_df['uid'].values, dtype=torch.long)], dim=0)
        dst = torch.cat([dst, torch.tensor(relational_df['other_uid'].values, dtype=torch.long)], dim=0)
        t = torch.cat([t, torch.tensor(relational_df['timestamp'].values, dtype=torch.long)], dim=0)

    # The temporal loader replays events in storage order, so the training
    # events must be chronological in BOTH configurations (the personal frame
    # is ordered by uid after the truncation above). A stable sort keeps the
    # order of simultaneous events reproducible.
    order = torch.sort(t, stable=True).indices
    src, dst, t = src[order], dst[order], t[order]

    train_data = TemporalData(src=src, dst=dst, t=t)

    # Validation and test splits use personal edges only.
    val_data_pos = _eval_temporal_data(personal_val_df, personal_val_codes, eval_time)
    val_data_neg = _eval_temporal_data(personal_val_neg_df, personal_val_neg_codes, eval_time)
    test_data_pos = _eval_temporal_data(personal_test_df, personal_test_codes, eval_time)
    test_data_neg = _eval_temporal_data(personal_test_neg_df, personal_test_neg_codes, eval_time)

    # Total nodes = user nodes + personal event nodes.
    num_all_nodes = num_user_nodes + len(uniques)

    logging.info('>> Data stats')
    logging.info(f'>> Train data: src.shape {train_data.src.shape}, num nodes: {train_data.num_nodes}')
    logging.info(f'>> Val data pos: src.shape {val_data_pos.src.shape}, num nodes: {val_data_pos.num_nodes}')
    logging.info(f'>> Val data neg: src.shape {val_data_neg.src.shape}, num nodes: {val_data_neg.num_nodes}')
    logging.info(f'>> Test data pos: src.shape {test_data_pos.src.shape}, num nodes: {test_data_pos.num_nodes}')
    logging.info(f'>> Test data neg: src.shape {test_data_neg.src.shape}, num nodes: {test_data_neg.num_nodes}')

    return {
        'train_data': train_data,
        'val_data_pos': val_data_pos,
        'val_data_neg': val_data_neg,
        'test_data_pos': test_data_pos,
        'test_data_neg': test_data_neg,
        'num_all_nodes': num_all_nodes,
    }


def _eval_temporal_data(df, event_node_ids, eval_time):
    """Wrap one evaluation file as TemporalData whose events all happen at `eval_time`."""
    src = torch.tensor(df['uid'].values, dtype=torch.long)
    dst = torch.tensor(event_node_ids, dtype=torch.long)
    t = torch.full((src.size(0),), eval_time, dtype=torch.long)
    return TemporalData(src=src, dst=dst, t=t)


###############################################################################
# Model Definitions
###############################################################################
class GraphAttentionEmbedding(torch.nn.Module):
    """
    Attention layer that turns node states into embeddings.

    Each edge carries the encoded time difference between the source node's
    last update and the edge timestamp, concatenated with the edge message.
    """

    def __init__(self, in_channels, out_channels, msg_dim, time_enc):
        super().__init__()
        self.time_enc = time_enc
        # Two heads of size out_channels // 2 each.
        self.conv = TransformerConv(
            in_channels,
            out_channels // 2,
            heads=2,
            dropout=0.1,
            edge_dim=msg_dim + time_enc.out_channels,
        )

    def forward(self, x, last_update, edge_index, t, msg):
        """
        Args:
            x: Node states, one row per node.
            last_update: Last update time per node, indexed like `x`.
            edge_index: Edge endpoints; row 0 indexes into `last_update`.
            t: Timestamp of every edge.
            msg: Message features of every edge.
        """
        source_nodes = edge_index[0]
        elapsed = last_update[source_nodes] - t
        time_features = self.time_enc(elapsed.to(x.dtype))
        edge_features = torch.cat([time_features, msg], dim=-1)
        return self.conv(x, edge_index, edge_features)

class LinkPredictor(torch.nn.Module):
    """
    Scores a (source, destination) pair of node embeddings.

    Both embeddings are projected linearly, summed, passed through a ReLU and
    reduced to a single logit per pair.
    """

    def __init__(self, in_channels):
        super().__init__()
        hidden_channels = in_channels
        self.lin_src = torch.nn.Linear(in_channels, hidden_channels)
        self.lin_dst = torch.nn.Linear(in_channels, hidden_channels)
        self.lin_final = torch.nn.Linear(hidden_channels, 1)

    def forward(self, z_src, z_dst):
        """Return one logit per row of `z_src` / `z_dst`."""
        hidden = torch.relu(self.lin_src(z_src) + self.lin_dst(z_dst))
        return self.lin_final(hidden)
    
###############################################################################
# Evaluation Functions using torcheval
###############################################################################
def evaluate(candidates, target, k_list):
    """
    Compute MRR and Hits@k for a batch of ranking problems.

    Args:
        candidates: Score matrix with one row per sample and one column per candidate.
        target: Column index of the correct candidate for every row.
        k_list: The k values for which Hits@k is reported.

    Returns:
        dict: {"MRR": float, "Hits@k": float for every k in k_list}, each
        averaged over the samples.
    """
    reciprocal_rank = ReciprocalRank()
    reciprocal_rank.update(candidates, target)
    results = {"MRR": reciprocal_rank.compute().mean().item()}

    for k in k_list:
        hit_rate = HitRate(k=k)
        hit_rate.update(candidates, target)
        results[f"Hits@{k}"] = hit_rate.compute().mean().item()

    return results

# Entries of the `model` dict that are torch modules with a train/eval mode.
# The remaining entries (neighbor loader, index buffers) have no such mode.
_MODE_AWARE_COMPONENTS = ('embedding_module', 'memory', 'gnn', 'linkpred')


def _embed_batch(model, batch, train_data, device):
    """
    Compute the temporal embeddings of all nodes needed to score `batch`.

    Side effect: fills model['assoc'] so that ``assoc[global_id]`` is the row
    of that node in the returned tensors.

    Returns:
        tuple: (z, x) where `z` holds the GNN output and `x` the static input
        embeddings, both with one row per node returned by the neighbor loader.
    """
    assoc = model['assoc']

    # `n_id` holds global node ids, whereas `edge_index` is expressed as
    # positions inside `n_id` (the GNN indexes `last_update`, which has one
    # row per entry of `n_id`, with it). Per-node tensors therefore have to be
    # gathered for `n_id` and indexed with `edge_index` directly.
    n_id, edge_index, e_id = model['neighbor_loader'](batch.n_id)
    assoc[n_id] = torch.arange(n_id.size(0), device=device)

    x = model['embedding_module'](n_id)
    # Edge message = embeddings of the two endpoints of the edge.
    msg = torch.cat([x[edge_index[0]], x[edge_index[1]]], dim=1)

    z, last_update = model['memory'](n_id)
    edge_t = train_data.t[e_id.cpu()].to(device)
    z = model['gnn'](z, last_update, edge_index, edge_t, msg)
    return z, x


def _score_edges(model, loader, train_data, device, desc):
    """Return the link logits of every edge yielded by `loader` as a CPU tensor of shape (num_edges, 1)."""
    assoc = model['assoc']
    linkpred = model['linkpred']

    scores = []
    for batch in tqdm(loader, desc=desc, leave=False):
        batch = batch.to(device)
        z, _ = _embed_batch(model, batch, train_data, device)
        scores.append(linkpred(z[assoc[batch.src]], z[assoc[batch.dst]]).cpu())

    if not scores:
        return torch.empty((0, 1))
    return torch.cat(scores, dim=0)


def test(model, data, eval_batch_size, device, hits_k, num_workers, split="test"):
    """
    Rank every positive edge of a split against its negative samples.

    All trainable components are switched to eval mode and stay in eval mode
    when this function returns.

    Args:
        model: Dict with 'embedding_module', 'memory', 'gnn', 'linkpred',
            'neighbor_loader' and 'assoc'.
        data: Dict produced by `load_data`.
        eval_batch_size: Batch size of the evaluation loaders.
        device: Device on which the model lives.
        hits_k: List of k values for Hits@k.
        num_workers: Worker processes for the data loaders.
        split: "val" for the validation split, anything else for the test split.

    Returns:
        dict: Metrics as returned by `evaluate` ("MRR" and "Hits@k").

    Raises:
        ValueError: If the split has no positive edges or the number of
            negative edges is not a multiple of the number of positive edges.
    """
    for name in _MODE_AWARE_COMPONENTS:
        model[name].eval()

    train_data = data['train_data']

    if split == "val":
        desc_prefix = "Validation"
        pos_data = data['val_data_pos']
        neg_data = data['val_data_neg']
        logging.info("Starting validation evaluation...")
    else:
        desc_prefix = "Testing"
        pos_data = data['test_data_pos']
        neg_data = data['test_data_neg']
        logging.info("Starting testing evaluation...")

    pos_loader = TemporalDataLoader(pos_data, batch_size=eval_batch_size, neg_sampling_ratio=0, num_workers=num_workers)
    neg_loader = TemporalDataLoader(neg_data, batch_size=eval_batch_size, neg_sampling_ratio=0, num_workers=num_workers)

    with torch.no_grad():
        pos_scores = _score_edges(model, pos_loader, train_data, device, f"{desc_prefix} Pos")
        neg_scores = _score_edges(model, neg_loader, train_data, device, f"{desc_prefix} Neg")

    num_samples = pos_scores.shape[0]
    if num_samples == 0 or neg_scores.shape[0] % num_samples != 0:
        raise ValueError(
            f"Cannot pair {neg_scores.shape[0]} negative edges with {num_samples} positive edges "
            f"for split '{split}'."
        )

    # NOTE(review): assumes the negative file lists the negatives of the i-th
    # positive edge consecutively, in the same order as the positive file —
    # confirm against the script that generates the negative samples.
    num_neg = neg_scores.shape[0] // num_samples
    neg_scores = neg_scores.reshape(num_samples, num_neg)

    candidates = torch.cat([pos_scores, neg_scores], dim=1)
    target = torch.zeros(num_samples, dtype=torch.long)  # positive sample is at index 0

    return evaluate(candidates, target, k_list=hits_k)


def _train_one_epoch(model, train_loader, train_data, optimizer, criterion, device, num_neg, epoch):
    """Replay the training events once and return the average loss per event."""
    # Validation leaves the modules in eval mode, so training mode has to be
    # re-enabled at the start of EVERY epoch, not just once before the loop.
    for name in _MODE_AWARE_COMPONENTS:
        model[name].train()

    memory = model['memory']
    linkpred = model['linkpred']
    neighbor_loader = model['neighbor_loader']
    assoc = model['assoc']

    # Start every epoch from an empty memory and an empty neighbor history.
    memory.reset_state()
    neighbor_loader.reset_state()

    total_loss = 0.0
    for batch in tqdm(train_loader, desc=f"Training Epoch {epoch}", leave=False):
        batch = batch.to(device)
        optimizer.zero_grad()

        z, x = _embed_batch(model, batch, train_data, device)
        src_rows = assoc[batch.src]
        dst_rows = assoc[batch.dst]

        pos_out = linkpred(z[src_rows], z[dst_rows])
        loss = criterion(pos_out, torch.ones_like(pos_out))

        if num_neg > 0:
            # Each source is paired with `num_neg` sampled negative destinations.
            neg_src_rows = src_rows.repeat_interleave(num_neg, dim=0)
            neg_out = linkpred(z[neg_src_rows], z[assoc[batch.neg_dst]])
            loss = loss + criterion(neg_out, torch.zeros_like(neg_out))

        # Store the observed events only after they have been predicted.
        msg_gt = torch.cat([x[src_rows], x[dst_rows]], dim=1).detach()
        memory.update_state(batch.src, batch.dst, batch.t, msg_gt)
        neighbor_loader.insert(batch.src, batch.dst)

        loss.backward()
        optimizer.step()
        memory.detach()

        total_loss += float(loss) * batch.num_events

    return total_loss / train_data.num_events


def train(model, data, optimizer, criterion, epochs, train_batch_size, eval_batch_size,
          device, patience, eval_every, enable_early_stopping, train_neg_sampling_ratio, hits_k, num_workers):
    """
    Train the model, validating every `eval_every` epochs.

    Args:
        model: Dict with 'embedding_module', 'memory', 'gnn', 'linkpred',
            'neighbor_loader' and 'assoc'.
        data: Dict produced by `load_data`.
        optimizer: Optimizer over the trainable parameters.
        criterion: Loss applied to the link logits.
        epochs: Maximum number of epochs.
        train_batch_size: Number of events per training batch.
        eval_batch_size: Batch size used for validation.
        device: Device on which the model lives.
        patience: Number of consecutive validations without a better MRR
            after which training stops (only with early stopping).
        eval_every: Validate every this many epochs.
        enable_early_stopping: Track the best validation MRR, stop when
            patience runs out and restore the best weights at the end.
        train_neg_sampling_ratio: Negative destinations per positive edge;
            must be a non-negative whole number.
        hits_k: List of k values for Hits@k.
        num_workers: Worker processes for the data loaders.

    Returns:
        dict: The same `model` dict (weights restored to the best validation
        state when early stopping recorded one).

    Raises:
        ValueError: If `train_neg_sampling_ratio` is negative or not a whole number.
    """
    # The ratio is used as a repeat count, which must be an int (argparse may
    # hand over a float such as 5.0).
    num_neg = int(train_neg_sampling_ratio)
    if num_neg != train_neg_sampling_ratio or num_neg < 0:
        raise ValueError(
            f"train_neg_sampling_ratio must be a non-negative whole number, got {train_neg_sampling_ratio!r}"
        )

    train_data = data['train_data']
    train_loader = TemporalDataLoader(train_data, batch_size=train_batch_size,
                                      neg_sampling_ratio=train_neg_sampling_ratio, num_workers=num_workers)

    best_val_mrr = -float('inf')
    patience_counter = 0
    best_state = None

    logging.info("Starting training...")

    for epoch in range(1, epochs + 1):
        avg_loss = _train_one_epoch(model, train_loader, train_data, optimizer,
                                    criterion, device, num_neg, epoch)
        logging.info(f"Epoch {epoch:03d} Loss: {avg_loss:.4f}")

        if epoch % eval_every != 0:
            continue

        # The per-batch tensors died with _train_one_epoch's frame; release
        # their cached GPU memory before the (much larger) evaluation batches.
        clear_unused_gpu_memory()

        metrics = test(model, data, eval_batch_size, device, hits_k, num_workers, split="val")
        val_mrr = metrics["MRR"]
        log_msg = f"Epoch {epoch:03d} Evaluation | Val MRR: {val_mrr:.4f}"
        for k in hits_k:
            log_msg += f", Hits@{k}: {metrics[f'Hits@{k}']:.4f}"
        logging.info(log_msg)

        if not enable_early_stopping:
            continue

        if val_mrr > best_val_mrr:
            best_val_mrr = val_mrr
            # state_dict() returns references to the live tensors, so copy
            # them; entries without a state_dict (neighbor loader, index
            # buffers) are skipped. Snapshots live on the CPU to spare GPU memory.
            best_state = {
                name: {key: value.detach().cpu().clone()
                       for key, value in component.state_dict().items()}
                for name, component in model.items()
                if hasattr(component, "state_dict")
            }
            patience_counter = 0
            logging.info("New best model found; resetting patience counter.")
        else:
            patience_counter += 1
            logging.info(f"No improvement; patience counter: {patience_counter}/{patience}")

        if patience_counter >= patience:
            logging.info(f"Early stopping triggered at epoch {epoch}.")
            break

    if best_state is not None:
        for name, state in best_state.items():
            component = model[name]
            # Leave train mode BEFORE loading: PyG's TGNMemory writes pending
            # messages into its memory buffer on that switch, which would
            # otherwise happen later and alter the restored buffer.
            # (presumably DyRepMemory behaves alike — verify)
            component.eval()
            component.load_state_dict(state)
    return model

###############################################################################
# Checkpoint Saving Function
###############################################################################
def save_checkpoint(filepath, model, optimizer, args, elapsed_time, test_metrics):
    """
    Write a checkpoint file with everything needed to inspect the run.

    Args:
        filepath: Destination file.
        model: Dict of components; only entries exposing `state_dict` are stored.
        optimizer: Optimizer whose state is stored.
        args: Parsed command-line namespace (stored as a plain dict).
        elapsed_time: Training time to record.
        test_metrics: Test metrics to record.
    """
    module_states = {}
    for name, component in model.items():
        # Entries without a state_dict are not serialised.
        if hasattr(component, "state_dict"):
            module_states[name] = component.state_dict()

    torch.save(
        {
            'experiment_name': args.experiment_name,
            'model_state_dict': module_states,
            'optimizer_state_dict': optimizer.state_dict(),
            'args': vars(args),
            'elapsed_time': elapsed_time,
            'test_metrics': test_metrics,
        },
        filepath,
    )
    logging.info(f"Checkpoint saved at {filepath}")

###############################################################################
# Main Execution
###############################################################################
def main():
    """Parse the options, build the model, train it, evaluate on the test split and save a checkpoint."""
    args = parse_args()
    # Empty tokens (e.g. a trailing comma) are ignored.
    args.hits_k = [int(k) for k in args.hits_k.split(",") if k.strip()]

    # The attention layer uses two heads of size embedding_dim // 2, so an odd
    # value would produce embeddings one unit narrower than the link predictor expects.
    if args.embedding_dim % 2 != 0:
        raise ValueError(f"--embedding_dim must be even, got {args.embedding_dim}")

    set_seed(args.seed)
    device = torch.device(f'cuda:{args.gpu_id}' if torch.cuda.is_available() else 'cpu')

    # Use relational edges unless --no_relational is provided.
    data = load_data(
        args.base_path,
        args.embedding_dim,
        args.max_events,
        include_relational_edges=not args.no_relational
    )
    num_all_nodes = data['num_all_nodes']

    # Model configuration. An edge message is the concatenation of two node
    # embeddings, hence twice the embedding size.
    embedding_dim = args.embedding_dim
    msg_dim = memory_dim = time_dim = 2 * args.embedding_dim

    embedding_module = torch.nn.Embedding(num_all_nodes, embedding_dim).to(device)

    if args.model == 'TGN':
        memory = TGNMemory(
            num_all_nodes,
            msg_dim,
            memory_dim,
            time_dim,
            message_module=IdentityMessage(msg_dim, memory_dim, time_dim),
            aggregator_module=LastAggregator(),
        ).to(device)
    else:
        memory = DyRepMemory(
            num_all_nodes,
            msg_dim,
            memory_dim,
            time_dim,
            message_module=IdentityMessage(msg_dim, memory_dim, time_dim),
            aggregator_module=LastAggregator(),
            memory_updater_type='rnn'
        ).to(device)

    gnn = GraphAttentionEmbedding(
        in_channels=memory_dim,
        out_channels=embedding_dim,
        msg_dim=msg_dim,
        time_enc=memory.time_enc,
    ).to(device)

    linkpred = LinkPredictor(in_channels=embedding_dim).to(device)
    neighbor_loader = LastNeighborLoader(num_all_nodes, size=args.num_neighbors, device=device)

    # Helper tensors to map global node indices to local ones.
    assoc = torch.empty(num_all_nodes, dtype=torch.long, device=device)
    assoc2 = torch.empty(num_all_nodes, dtype=torch.long, device=device)

    model = {
        'embedding_module': embedding_module,
        'memory': memory,
        'gnn': gnn,
        'linkpred': linkpred,
        'neighbor_loader': neighbor_loader,
        'assoc': assoc,
        'assoc2': assoc2
    }

    # The GNN shares the memory's time encoder, so parameters have to be
    # de-duplicated. Doing that with an insertion-ordered dict instead of a
    # set keeps the parameter order identical between runs, which the
    # positional indices inside the saved optimizer state rely on.
    unique_params = {}
    for module in (memory, gnn, linkpred, embedding_module):
        for param in module.parameters():
            unique_params.setdefault(id(param), param)
    optimizer = torch.optim.Adam(list(unique_params.values()), lr=args.lr)
    criterion = torch.nn.BCEWithLogitsLoss()

    start_time = time.time()
    model = train(
        model, data, optimizer, criterion, args.epochs,
        args.train_batch_size, args.eval_batch_size, device,
        patience=args.patience, eval_every=args.eval_every,
        enable_early_stopping=args.enable_early_stopping,
        train_neg_sampling_ratio=args.train_neg_sampling_ratio,
        hits_k=args.hits_k,
        num_workers=args.num_workers
    )
    elapsed_time = time.time() - start_time

    clear_unused_gpu_memory()
    test_metrics = test(model, data, args.eval_batch_size, device, args.hits_k, args.num_workers, split="test")
    log_msg = f"{args.model} (Personal Prediction): MRR: {test_metrics['MRR']:.4f}"
    for k in args.hits_k:
        log_msg += f", Hits@{k}: {test_metrics[f'Hits@{k}']:.4f}"
    log_msg += f", Time: {elapsed_time:.2f}s"
    logging.info(log_msg)

    os.makedirs("stored", exist_ok=True)
    checkpoint_filename = f"stored/{args.experiment_name}_checkpoint.pt"
    save_checkpoint(checkpoint_filename, model, optimizer, args, elapsed_time, test_metrics)


if __name__ == "__main__":
    main()
