import os
import time
import random
import argparse
import logging
from tqdm import tqdm

import pickle
import torch
import torch.nn.functional as F
import pandas as pd
import numpy as np
from torch_geometric.data import Data
import torch_geometric.transforms as T
from torch_geometric.nn import GCNConv, SAGEConv, GATConv, GINConv
from torch_geometric.loader import LinkNeighborLoader
from torcheval.metrics import ReciprocalRank, HitRate

###############################################################################
# Logging configuration
###############################################################################
# Root logger: INFO and above, timestamped "time - LEVEL - message" lines.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

###############################################################################
# Argument Parsing
###############################################################################
def parse_args(argv=None):
    """Parse the command-line options of the link-prediction script.

    Args:
        argv: Optional list of argument strings to parse instead of the real
            command line (useful for tests and notebooks). ``None`` keeps the
            standard behaviour of reading ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with one attribute per option below. ``hits_k`` is
        returned as the raw comma-separated string; the caller converts it.
    """
    parser = argparse.ArgumentParser(
        description="Link Prediction on Personal Edges with Batching, Early Stopping, and Custom CSV Paths"
    )
    parser.add_argument('--base_path', type=str, default='.',
                        help="Base path for CSV files.")
    parser.add_argument('--model', type=str, choices=['GCN', 'GraphSAGE', 'GAT', 'GIN'], default='GCN',
                        help="Model type to use.")
    parser.add_argument('--hidden_channels', type=int, default=128,
                        help="Number of hidden channels in each layer.")
    parser.add_argument('--num_layers', type=int, default=2,
                        help="Number of GNN layers.")
    parser.add_argument('--embedding_dim', type=int, default=128,
                        help="Dimensionality of input node embeddings.")
    parser.add_argument('--epochs', type=int, default=10,
                        help="Maximum number of training epochs.")
    parser.add_argument('--lr', type=float, default=1e-3,
                        help="Learning rate.")
    parser.add_argument('--dropout', type=float, default=0.1,
                        help="Dropout rate.")
    parser.add_argument('--heads', type=int, default=4,
                        help="Number of attention heads for GAT.")
    parser.add_argument('--train_batch_size', type=int, default=2**10,
                        help="Mini-batch size for training.")
    parser.add_argument('--eval_batch_size', type=int, default=2**20,
                        help="Mini-batch size for evaluation and testing.")
    # Patience is decremented once per validation run, i.e. every --eval_every epochs.
    parser.add_argument('--patience', type=int, default=10,
                        help="Early stopping patience, counted in evaluation rounds "
                             "(one round every --eval_every epochs).")
    parser.add_argument('--eval_every', type=int, default=2,
                        help="Evaluation frequency (in epochs).")
    parser.add_argument('--seed', type=int, default=42,
                        help="Random seed for reproducibility.")
    parser.add_argument('--gpu_id', type=int, default=0,
                        help="GPU ID to use (if CUDA is available).")
    parser.add_argument('--enable_early_stopping', action='store_true',
                        help='Enable early stopping based on validation MRR. (default: False)')
    parser.add_argument('--num_neighbors', type=int, default=10,
                        help="Number of neighbors to sample per layer in LinkNeighborLoader.")
    parser.add_argument('--train_neg_sampling_ratio', type=float, default=5,
                        help="Negative sampling ratio for training in LinkNeighborLoader.")
    parser.add_argument('--hits_k', type=str, default="3,5,10,50",
                        help="Comma separated list of k values for computing Hits@k metrics.")
    parser.add_argument('--experiment_name', type=str, default='PP-exp-1',
                        help="Experiment name for saving checkpoints and artifacts.")
    parser.add_argument("--max_events", type=int, default=100,
                        help="Maximum number of events per user (currently not applied by load_data).")
    # Number of worker processes used by the LinkNeighborLoaders.
    parser.add_argument('--num_workers', type=int, default=0,
                        help="Number of worker processes for data loaders "
                             "(default: 0, i.e. load data in the main process)")
    # Relational (uid -> uid) edges are included unless --no_relational is given.
    parser.add_argument('--no_relational', dest='include_relational', action='store_false',
                        help="Exclude relational edges from the graph.")
    parser.set_defaults(include_relational=True)
    return parser.parse_args(argv)

###############################################################################
# Set Random Seed for Reproducibility
###############################################################################
def set_seed(seed):
    """Seed every random number generator the script relies on.

    Seeds Python's ``random``, NumPy's global RNG and PyTorch (plus all CUDA
    devices when CUDA is available), then puts cuDNN into deterministic,
    non-benchmarking mode.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

###############################################################################
# Data Loading
###############################################################################
def _personal_edge_index(df, event_index, offset):
    """Build the ``[2, len(df)]`` uid -> event edge index for one personal split.

    Args:
        df: DataFrame with ``uid`` and ``event`` columns.
        event_index: ``pd.Index`` of unique raw event ids; an event's position
            in it is its 0-based event index.
        offset: added to every event index (the number of uid nodes) so event
            node ids start after all uid node ids.

    Raises:
        KeyError: if ``df`` contains an event id that is not in ``event_index``.
    """
    event_idx = event_index.get_indexer(df['event'])
    if (event_idx < 0).any():
        raise KeyError("Found event ids in a personal split that are missing from the event mapping.")
    src = torch.as_tensor(df['uid'].to_numpy(dtype=np.int64), dtype=torch.long)
    dst = torch.as_tensor(event_idx.astype(np.int64) + offset, dtype=torch.long)
    return torch.stack([src, dst], dim=0)


def load_data(base_path, embedding_dim, max_events=-1, include_relational=True):
    """
    Load edge lists from CSV files and create a PyG Data object.

    The graph consists of two types of edges:
      - Personal edges (target for link prediction): from a uid (user) to an event.
      - Relational edges (auxiliary, optional): from uid to uid.

    Node ids: uid values are used directly as node ids (assumed to be
    non-negative integers) and ``max(uid) + 1`` uid slots are reserved, where
    the max runs over every personal file (negative samples included) and, if
    ``include_relational`` is True, the relational file. Event ids from all
    five personal files are factorized in order of first appearance and offset
    by that uid count, so uid and event nodes never overlap.

    Args:
        base_path: directory containing ``personal_train.csv``,
            ``personal_val.csv``, ``personal_test.csv``,
            ``personal_val_negative_sample.csv``,
            ``personal_test_negative_sample.csv`` and, when
            ``include_relational`` is True, ``relational_observed.csv``.
        embedding_dim: accepted for interface compatibility; not used here.
        max_events: accepted for interface compatibility; not used here (no
            per-user truncation is applied).
        include_relational: whether uid -> uid edges are added to the
            message-passing graph.

    Returns:
        Data: a PyG Data object with ``node_id``, an undirected training
        ``edge_index`` with matching ``edge_type`` (0 = relational,
        1 = personal), and the uid -> event evaluation edges
        ``val_edge_index``, ``test_edge_index``, ``val_neg_edge_index``,
        ``test_neg_edge_index``.

    Raises:
        KeyError: if an event id cannot be mapped (e.g. a missing/NaN event).
    """
    logging.info("Preparing dataset for personal link prediction...")
    logging.info(f"Base path: {base_path}")

    def read_csv(filename):
        return pd.read_csv(os.path.join(base_path, filename))

    # --- Load Personal CSV Files ---
    personal_train_df = read_csv("personal_train.csv")
    personal_val_df = read_csv("personal_val.csv")
    personal_test_df = read_csv("personal_test.csv")
    personal_val_neg_df = read_csv("personal_val_negative_sample.csv")
    personal_test_neg_df = read_csv("personal_test_negative_sample.csv")
    personal_dfs = [personal_train_df, personal_val_df, personal_test_df,
                    personal_val_neg_df, personal_test_neg_df]

    # Every uid used as an edge endpoint must get an id below the event offset.
    # The negative-sample files are included so that a uid appearing only there
    # cannot collide with an event node id.
    uid_columns = [df['uid'] for df in personal_dfs]

    # --- Optionally Load Relational CSV ---
    relational_edge_index = None
    if include_relational:
        relational_df = read_csv("relational_observed.csv")
        # Stack into one ndarray first; building a tensor from a list of
        # ndarrays is very slow in PyTorch.
        relational_edge_index = torch.as_tensor(
            np.stack([relational_df['uid'].to_numpy(dtype=np.int64),
                      relational_df['other_uid'].to_numpy(dtype=np.int64)]),
            dtype=torch.long,
        )
        uid_columns += [relational_df['uid'], relational_df['other_uid']]

    # uid values are node ids directly, so reserve max(uid) + 1 slots (0 if no uids).
    num_uid_nodes = max((int(col.max()) + 1 for col in uid_columns if len(col) > 0), default=0)

    # --- Process Personal Events ---
    # Factorize over all personal files (positives and negatives) so every split
    # shares one event -> index mapping.
    all_events = pd.concat([df['event'] for df in personal_dfs], ignore_index=True)
    _, unique_events = pd.factorize(all_events)
    event_index = pd.Index(unique_events)
    num_personal_nodes = len(event_index)

    (personal_train_edge_index,
     personal_val_edge_index,
     personal_test_edge_index,
     personal_val_neg_edge_index,
     personal_test_neg_edge_index) = (
        _personal_edge_index(df, event_index, num_uid_nodes) for df in personal_dfs
    )

    # --- Combine Edges for Training ---
    personal_edge_type = torch.ones(personal_train_edge_index.size(1), dtype=torch.long)  # Type 1: personal.
    if include_relational:
        train_edge_index = torch.cat([relational_edge_index, personal_train_edge_index], dim=1)
        train_edge_type = torch.cat([
            torch.zeros(relational_edge_index.size(1), dtype=torch.long),  # Type 0: relational.
            personal_edge_type,
        ])
    else:
        train_edge_index = personal_train_edge_index
        train_edge_type = personal_edge_type

    # Total number of nodes: uid nodes + event nodes.
    num_all_nodes = num_uid_nodes + num_personal_nodes
    data = Data(
        node_id=torch.arange(num_all_nodes),
        edge_index=train_edge_index,
        edge_type=train_edge_type,
        val_edge_index=personal_val_edge_index,
        test_edge_index=personal_test_edge_index,
        val_neg_edge_index=personal_val_neg_edge_index,
        test_neg_edge_index=personal_test_neg_edge_index
    )

    # Add reverse edges. ToUndirected coalesces duplicate edges; reduce="max"
    # (instead of the default "add") keeps edge_type in {0, 1}. With "add" a
    # duplicated personal edge would get type 2 and no longer match type 1.
    data = T.ToUndirected(reduce="max")(data)

    logging.info("%s", data)

    return data

###############################################################################
# Model Definitions
###############################################################################
class GCN(torch.nn.Module):
    """Graph Convolutional Network encoder built from stacked ``GCNConv`` layers.

    The first layer maps ``in_channels -> hidden_channels`` and every later
    layer keeps ``hidden_channels``. All but the last layer are followed by
    ReLU and feature dropout; the last layer's output is returned without an
    activation. At least one layer is always built, even if ``num_layers < 1``.
    """
    def __init__(self, in_channels, hidden_channels, num_layers, dropout):
        super().__init__()
        input_widths = [in_channels] + [hidden_channels] * (num_layers - 1)
        self.layers = torch.nn.ModuleList(
            [GCNConv(width, hidden_channels) for width in input_widths]
        )
        self.dropout = dropout

    def forward(self, x, edge_index):
        *hidden_convs, last_conv = self.layers
        for conv in hidden_convs:
            activated = F.relu(conv(x, edge_index))
            x = F.dropout(activated, p=self.dropout, training=self.training)
        return last_conv(x, edge_index)

class GraphSAGE(torch.nn.Module):
    """GraphSAGE encoder built from stacked ``SAGEConv`` layers.

    The first layer maps ``in_channels -> hidden_channels`` and every later
    layer keeps ``hidden_channels``. ReLU and feature dropout follow every layer
    except the last one, whose output is returned as-is. At least one layer is
    always built, even if ``num_layers < 1``.
    """
    def __init__(self, in_channels, hidden_channels, num_layers, dropout):
        super().__init__()
        convs = [SAGEConv(in_channels, hidden_channels)]
        convs.extend(SAGEConv(hidden_channels, hidden_channels) for _ in range(num_layers - 1))
        self.layers = torch.nn.ModuleList(convs)
        self.dropout = dropout

    def forward(self, x, edge_index):
        last_depth = len(self.layers) - 1
        for depth, conv in enumerate(self.layers):
            x = conv(x, edge_index)
            if depth != last_depth:
                x = F.dropout(F.relu(x), p=self.dropout, training=self.training)
        return x

class GAT(torch.nn.Module):
    """Graph Attention Network encoder built from stacked ``GATConv`` layers.

    Every layer uses ``heads`` attention heads and attention dropout
    ``dropout``. The heads are concatenated, so each layer after the first
    receives ``hidden_channels * heads`` input features, and the encoder's
    output width is also ``hidden_channels * heads``. ReLU and feature dropout
    follow every layer except the last. At least one layer is always built.
    """
    def __init__(self, in_channels, hidden_channels, num_layers, dropout, heads):
        super().__init__()
        input_widths = [in_channels] + [hidden_channels * heads] * (num_layers - 1)
        self.layers = torch.nn.ModuleList(
            GATConv(width, hidden_channels, heads=heads, dropout=dropout)
            for width in input_widths
        )
        self.dropout = dropout

    def forward(self, x, edge_index):
        for conv in self.layers[:-1]:
            hidden = F.relu(conv(x, edge_index))
            x = F.dropout(hidden, p=self.dropout, training=self.training)
        return self.layers[-1](x, edge_index)

class GIN(torch.nn.Module):
    """Graph Isomorphism Network encoder.

    Each layer is a ``GINConv`` whose update function is a single
    ``torch.nn.Linear``. The first maps ``in_channels -> hidden_channels`` and
    the rest keep ``hidden_channels``. ReLU and feature dropout follow every
    layer except the last. At least one layer is always built.
    """
    def __init__(self, in_channels, hidden_channels, num_layers, dropout):
        super().__init__()
        self.layers = torch.nn.ModuleList()
        fan_in = in_channels
        for _ in range(max(num_layers, 1)):
            self.layers.append(GINConv(torch.nn.Linear(fan_in, hidden_channels)))
            fan_in = hidden_channels
        self.dropout = dropout

    def forward(self, x, edge_index):
        *inner_layers, output_layer = self.layers
        for gin in inner_layers:
            x = F.dropout(F.relu(gin(x, edge_index)), p=self.dropout, training=self.training)
        return output_layer(x, edge_index)

class LinkPredictor(torch.nn.Module):
    """
    Score candidate links from a pair of node embeddings.

    Source and destination embeddings get separate ``in_channels -> in_channels``
    linear projections. These are summed, passed through ReLU and reduced to one
    logit per pair. Because the two projections differ, the score is not
    symmetric in (src, dst).
    """
    def __init__(self, in_channels):
        super().__init__()
        # Creation order is kept (src, dst, final) so seeded initialisation is unchanged.
        self.lin_src = torch.nn.Linear(in_channels, in_channels)
        self.lin_dst = torch.nn.Linear(in_channels, in_channels)
        self.lin_final = torch.nn.Linear(in_channels, 1)

    def forward(self, z_src, z_dst):
        """Return one raw (pre-sigmoid) logit per pair, shape ``[..., 1]``."""
        combined = torch.relu(self.lin_src(z_src) + self.lin_dst(z_dst))
        return self.lin_final(combined)
    
class Model(torch.nn.Module):
    """Composite model: a learnable node-embedding table feeding a GNN encoder.

    Also owns a ``LinkPredictor`` head (``self.linkpred``) sized to the GNN's
    output width. ``forward`` does not apply that head; it only returns the
    node representations, so callers score edges themselves.

    Args:
        gnn: the encoder module; it receives ``embedding_dim``-wide inputs.
        embedding_dim: width of the node embedding table.
        hidden_channels: per-head output width of the encoder.
        num_nodes: number of rows in the embedding table.
        heads: attention-head count; only used when ``gnn`` is a ``GAT``.
    """
    def __init__(self, gnn, embedding_dim, hidden_channels, num_nodes, heads=1):
        super().__init__()
        self.gnn = gnn
        # GAT concatenates its attention heads, widening its output to
        # hidden_channels * heads; the other encoders emit hidden_channels.
        self.out_dim = hidden_channels * heads if isinstance(gnn, GAT) else hidden_channels
        self.emb = torch.nn.Embedding(num_nodes, embedding_dim)
        self.linkpred = LinkPredictor(self.out_dim)

    def forward(self, batch):
        """Embed ``batch.node_id`` and run the GNN over ``batch.edge_index``."""
        node_features = self.emb(batch.node_id)
        return self.gnn(node_features, batch.edge_index)

###############################################################################
# Evaluation Functions using torcheval
###############################################################################
def evaluate(candidates, target, k_list):
    """
    Compute MRR and Hits@k for ranked candidates using torcheval metrics.

    Args:
        candidates: score matrix with one row per query and one column per
            candidate (2-D, as torcheval's ReciprocalRank/HitRate expect).
        target: for each row, the column index of the true candidate.
        k_list: cut-offs for the Hits@k metrics.

    Returns:
        dict with key "MRR" first, then one "Hits@{k}" per k in ``k_list``;
        each value is the mean over rows, as a Python float.
    """
    hit_rates = {}
    for k in k_list:
        hit_rates[k] = HitRate(k=k)
    reciprocal_rank = ReciprocalRank()

    reciprocal_rank.update(candidates, target)
    for metric in hit_rates.values():
        metric.update(candidates, target)

    results = {"MRR": reciprocal_rank.compute().mean().item()}
    for k in k_list:
        results[f"Hits@{k}"] = hit_rates[k].compute().mean().item()
    return results

def _score_edges(model, predictor, loader, device, desc):
    """Return a flat CPU tensor of predictor logits for every edge ``loader`` yields.

    Scores come out in loader order, one per supervision edge
    (``batch.edge_label_index``). The caller is responsible for ``no_grad`` and
    eval mode.
    """
    batch_scores = []
    for batch in tqdm(loader, desc=desc, leave=False):
        # The split-level evaluation edge tensors are not needed for scoring;
        # clearing them presumably avoids moving them to ``device`` per batch.
        batch.val_edge_index = None; batch.test_edge_index = None
        batch.val_neg_edge_index = None; batch.test_neg_edge_index = None

        batch = batch.to(device)
        output = model(batch)
        src, dst = batch.edge_label_index
        # view(-1) rather than squeeze(): a single-edge batch must stay 1-D,
        # otherwise torch.cat below fails on a 0-d tensor.
        batch_scores.append(predictor(output[src], output[dst]).view(-1).cpu())
    if not batch_scores:
        return torch.empty(0)
    return torch.cat(batch_scores, dim=0)


def test(model, predictor, data, eval_batch_size, device, num_neighbors, hits_k, num_workers, split="test"):
    """
    Evaluate ranking metrics (MRR, Hits@k) on the validation or test split.

    Every positive edge is ranked against its own negative edges. The negative
    file must contain the same number of negatives per positive, stored
    contiguously and in the same order as the positives. The flat negative
    scores are reshaped to ``[num_positives, num_neg]``, and row i is paired
    with positive i.

    The caller's train/eval mode of ``model`` and ``predictor`` is restored
    on return.

    Args:
        split: "val" scores ``data.val_edge_index`` / ``data.val_neg_edge_index``;
            any other value scores the test edges.

    Returns:
        dict from ``evaluate``: "MRR" plus one "Hits@{k}" entry per k in ``hits_k``.

    Raises:
        ValueError: if the split has no positive edges, or the number of
            negative edges is not a multiple of the number of positives.
    """
    if split == "val":
        pos_edge_index = data.val_edge_index
        neg_edge_index = data.val_neg_edge_index
        desc_prefix = "Validation"
    else:
        pos_edge_index = data.test_edge_index
        neg_edge_index = data.test_neg_edge_index
        desc_prefix = "Testing"

    # shuffle=False keeps scores aligned with the input edge order, which the
    # positive/negative pairing below relies on.
    loader_kwargs = dict(
        num_neighbors=[num_neighbors] * len(model.gnn.layers),
        batch_size=eval_batch_size,
        shuffle=False,
        neg_sampling_ratio=0,
        num_workers=num_workers,
    )
    pos_loader = LinkNeighborLoader(data, edge_label_index=pos_edge_index, **loader_kwargs)
    neg_loader = LinkNeighborLoader(data, edge_label_index=neg_edge_index, **loader_kwargs)

    # Remember the caller's mode so evaluation does not leave the networks in
    # eval mode, which would silently disable dropout in a surrounding train loop.
    model_was_training = model.training
    predictor_was_training = predictor.training
    model.eval()
    predictor.eval()
    try:
        with torch.no_grad():
            pos_scores = _score_edges(model, predictor, pos_loader, device, f"{desc_prefix} Pos")
            neg_scores = _score_edges(model, predictor, neg_loader, device, f"{desc_prefix} Neg")
    finally:
        model.train(model_was_training)
        predictor.train(predictor_was_training)

    num_samples = pos_scores.shape[0]
    if num_samples == 0:
        raise ValueError(f"No positive edges to evaluate in split '{split}'.")
    if neg_scores.shape[0] % num_samples != 0:
        raise ValueError(
            f"Split '{split}' has {neg_scores.shape[0]} negative edges, which is not a "
            f"multiple of its {num_samples} positive edges."
        )
    num_neg = neg_scores.shape[0] // num_samples
    neg_scores = neg_scores.view(num_samples, num_neg)

    candidates = torch.cat([pos_scores.unsqueeze(1), neg_scores], dim=1)
    target = torch.zeros(num_samples, dtype=torch.long)  # positive sample is always at index 0

    return evaluate(candidates, target, k_list=hits_k)

def _clone_state_dict(module):
    """Return a detached copy of ``module.state_dict()``.

    ``state_dict()`` returns tensors that share storage with the live
    parameters. Without cloning, a saved "best" snapshot would keep following
    every later optimizer step, and restoring it would be a no-op.
    """
    return {key: value.detach().clone() for key, value in module.state_dict().items()}


def train(model, predictor, data, optimizer, epochs, train_batch_size, eval_batch_size,
          device, patience, eval_every, enable_early_stopping, num_neighbors, train_neg_sampling_ratio, hits_k, num_workers):
    """
    Train the model using mini-batches with early stopping based on validation MRR.

    Link prediction is supervised on personal (uid -> event) edges.
    LinkNeighborLoader samples negatives on the fly at
    ``train_neg_sampling_ratio``, and the loss is binary cross-entropy on the
    predictor logits. Every ``eval_every`` epochs the validation split is scored
    with ``test``.

    When ``enable_early_stopping`` is set, training stops once ``patience``
    consecutive evaluations (not epochs) fail to improve validation MRR. The
    best-scoring weights of both ``model`` and ``predictor`` are then restored
    before returning.

    Returns:
        ``model`` (with the best weights restored when early stopping found one).
    """
    # Supervision edges: personal edges only. ``> 0`` rather than ``== 1``
    # because T.ToUndirected with its default reduce="add" can sum the types of
    # duplicated edges. Only the uid -> event direction is kept, which is the
    # one val/test score. Event node ids are offset past every uid, so that
    # direction is exactly src < dst; the reverse copies added by ToUndirected
    # are excluded.
    personal_mask = data.edge_type > 0
    forward_mask = data.edge_index[0] < data.edge_index[1]
    edge_label_index = data.edge_index[:, personal_mask & forward_mask]

    train_loader = LinkNeighborLoader(
        data,
        num_neighbors=[num_neighbors] * len(model.gnn.layers),
        edge_label_index=edge_label_index,
        batch_size=train_batch_size,
        shuffle=True,
        neg_sampling_ratio=train_neg_sampling_ratio,
        num_workers=num_workers
    )

    best_val_mrr = -float('inf')
    patience_counter = 0
    best_state = None
    best_predictor_state = None

    logging.info("Starting training...")
    for epoch in range(1, epochs + 1):
        # Re-enter training mode every epoch: validation runs the networks in
        # eval mode (dropout off), and training must not depend on it switching back.
        model.train()
        predictor.train()
        total_loss = 0.0
        for batch in tqdm(train_loader, desc=f"Training Epoch {epoch}", leave=False):
            batch.val_edge_index = None; batch.test_edge_index = None
            batch.val_neg_edge_index = None; batch.test_neg_edge_index = None

            batch = batch.to(device)
            optimizer.zero_grad()
            output = model(batch)
            src, dst = batch.edge_label_index
            # view(-1) keeps a 1-D shape even for a single-edge batch; squeeze()
            # would yield a 0-d tensor that no longer matches edge_label.
            scores = predictor(output[src], output[dst]).view(-1)
            loss = F.binary_cross_entropy_with_logits(scores, batch.edge_label.float())
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        avg_loss = total_loss / max(len(train_loader), 1)
        logging.info(f"Epoch {epoch:03d} Loss: {avg_loss:.4f}")

        if epoch % eval_every == 0:
            metrics = test(model, predictor, data, eval_batch_size, device, num_neighbors, hits_k, num_workers, split="val")
            val_mrr = metrics["MRR"]
            log_msg = f"Epoch {epoch:03d} Evaluation | Val MRR: {val_mrr:.4f}"
            for k in hits_k:
                log_msg += f", Hits@{k}: {metrics[f'Hits@{k}']:.4f}"
            logging.info(log_msg)

            if enable_early_stopping:
                if val_mrr > best_val_mrr:
                    best_val_mrr = val_mrr
                    best_state = _clone_state_dict(model)
                    best_predictor_state = _clone_state_dict(predictor)
                    patience_counter = 0
                    logging.info("New best model found; resetting patience counter.")
                else:
                    patience_counter += 1
                    logging.info(f"No improvement; patience counter: {patience_counter}/{patience}")

                if patience_counter >= patience:
                    logging.info(f"Early stopping triggered at epoch {epoch}.")
                    break

    if best_state is not None:
        model.load_state_dict(best_state)
        predictor.load_state_dict(best_predictor_state)
    return model

###############################################################################
# Checkpoint Saving Function
###############################################################################
def save_checkpoint(filepath, model, predictor, optimizer, args, elapsed_time, test_metrics):
    """
    Serialize a training run to ``filepath`` with ``torch.save``.

    The saved dict holds the experiment name, the state dicts of ``model``,
    ``predictor`` and ``optimizer``, ``vars(args)``, the elapsed training time
    and the test metrics. A log line reports where it was written.
    """
    payload = dict(
        experiment_name=args.experiment_name,
        model_state_dict=model.state_dict(),
        predictor_state_dict=predictor.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
        args=vars(args),
        elapsed_time=elapsed_time,
        test_metrics=test_metrics,
    )
    torch.save(payload, filepath)
    logging.info(f"Checkpoint saved at {filepath}")

###############################################################################
# Main Execution
###############################################################################
if __name__ == "__main__":
    args = parse_args()
    # Convert the comma-separated --hits_k string into ints (blank entries ignored).
    args.hits_k = [int(k) for k in args.hits_k.split(",") if k.strip()]

    set_seed(args.seed)
    device = torch.device(f'cuda:{args.gpu_id}' if torch.cuda.is_available() else 'cpu')

    # Load data with the include_relational flag.
    data = load_data(args.base_path, args.embedding_dim, args.max_events, include_relational=args.include_relational)

    model_classes = {
        'GCN': GCN,
        'GraphSAGE': GraphSAGE,
        'GAT': GAT,
        'GIN': GIN
    }
    gnn_kwargs = dict(
        in_channels=args.embedding_dim,
        hidden_channels=args.hidden_channels,
        num_layers=args.num_layers,
        dropout=args.dropout,
    )
    if args.model == 'GAT':
        gnn_kwargs['heads'] = args.heads  # only GAT takes an attention-head count
    gnn = model_classes[args.model](**gnn_kwargs)
    # Model sizes its link-prediction head from the encoder output
    # (hidden_channels * heads for GAT; heads is ignored for other encoders).
    model = Model(gnn,
                  embedding_dim=args.embedding_dim,
                  hidden_channels=args.hidden_channels,
                  num_nodes=data.num_nodes,
                  heads=args.heads).to(device)
    # Score links with the model's own head. As a registered submodule, its
    # parameters are in model.parameters(), so the optimizer below trains them,
    # and in model.state_dict(), so they are checkpointed with the model.
    # A separately constructed LinkPredictor would be in neither.
    predictor = model.linkpred

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    def format_metrics(prefix, metrics):
        """Render MRR and every Hits@k as one log line starting with ``prefix``."""
        msg = f"{prefix}MRR: {metrics['MRR']:.4f}"
        for k in args.hits_k:
            msg += f", Hits@{k}: {metrics[f'Hits@{k}']:.4f}"
        return msg

    test_metrics = test(model, predictor, data, args.eval_batch_size, device, args.num_neighbors, args.hits_k, args.num_workers, split="test")
    logging.info(format_metrics(f"{args.model} before training: ", test_metrics))

    start_time = time.time()
    model = train(
        model, predictor, data, optimizer, args.epochs,
        args.train_batch_size, args.eval_batch_size, device,
        patience=args.patience, eval_every=args.eval_every,
        enable_early_stopping=args.enable_early_stopping,
        num_neighbors=args.num_neighbors,
        train_neg_sampling_ratio=args.train_neg_sampling_ratio,
        hits_k=args.hits_k,
        num_workers=args.num_workers
    )
    elapsed_time = time.time() - start_time

    test_metrics = test(model, predictor, data, args.eval_batch_size, device, args.num_neighbors, args.hits_k, args.num_workers, split="test")
    logging.info(format_metrics(f"{args.model}: ", test_metrics) + f", Time: {elapsed_time:.2f}s")

    # Save the checkpoint using the experiment name.
    os.makedirs("stored/", exist_ok=True)
    checkpoint_filename = f"stored/{args.experiment_name}_checkpoint.pt"
    save_checkpoint(checkpoint_filename, model, predictor, optimizer, args, elapsed_time, test_metrics)
