import argparse
import logging
import pickle
import os
import torch
import pandas as pd
from torch_geometric.data import Data
from torch_geometric.nn.models import Node2Vec
from torch_geometric.utils import to_undirected

# -----------------------------------------------------------------------------
# Configure Logging
# -----------------------------------------------------------------------------
# Executed at import time: configures the root logger so that INFO-level (and
# more severe) records are emitted.
logging.basicConfig(level=logging.INFO)
# Module-level logger, named after this module, used by the functions below.
logger = logging.getLogger(__name__)

# -----------------------------------------------------------------------------
# 1. Load CSV Data and Build the Edge Index
# -----------------------------------------------------------------------------
def load_data(csv_file):
    """Read the friendship CSV into a DataFrame.

    Args:
        csv_file: Path (or anything ``pandas.read_csv`` accepts) of the CSV
            to read.

    Returns:
        The parsed ``pandas.DataFrame``; its row count is logged at INFO level.
    """
    logger.info("Loading CSV file: %s", csv_file)
    frame = pd.read_csv(csv_file)
    record_count = len(frame)
    logger.info("CSV loaded with %d records", record_count)
    return frame

def build_edge_index(df):
    """Build an undirected edge index tensor from the friendship DataFrame.

    Args:
        df: DataFrame with the columns ``uid`` and ``other_uid``; each row is
            treated as one edge between the two node ids.

    Returns:
        The ``torch.long`` tensor produced by ``to_undirected`` from the
        ``[2, num_rows]`` (source row, target row) edge tensor.

    Raises:
        KeyError: If ``uid`` or ``other_uid`` is missing from ``df``.
    """
    logger.info("Building edge index from DataFrame...")
    # Convert each column from a single ndarray and stack the results.
    # Passing a Python *list* of ndarrays to torch.tensor() takes PyTorch's
    # slow element-wise path and triggers a UserWarning on large inputs.
    sources = torch.tensor(df['uid'].to_numpy(), dtype=torch.long)
    targets = torch.tensor(df['other_uid'].to_numpy(), dtype=torch.long)
    edge_index = torch.stack([sources, targets])
    # Convert to undirected graph since friendships are mutual.
    edge_index = to_undirected(edge_index)
    logger.info("Edge index shape: %s", list(edge_index.size()))
    return edge_index

# -----------------------------------------------------------------------------
# 2. Train the Node2Vec Model
# -----------------------------------------------------------------------------
def train_node2vec(model, optimizer, loader, device):
    """Run one training pass over ``loader`` and return the mean batch loss.

    Args:
        model: Object exposing ``train()`` and ``loss(pos_rw, neg_rw)``.
        optimizer: Optimizer exposing ``zero_grad()`` and ``step()``.
        loader: Sized iterable yielding ``(pos_rw, neg_rw)`` pairs.
        device: Device each batch is moved to before the loss is computed.

    Returns:
        Sum of the per-batch loss values divided by ``len(loader)``.

    Raises:
        ZeroDivisionError: If ``len(loader)`` is 0.
    """
    model.train()
    batch_losses = []
    for positive_walks, negative_walks in loader:
        # Clear gradients left over from the previous batch before backprop.
        optimizer.zero_grad()
        batch_loss = model.loss(positive_walks.to(device), negative_walks.to(device))
        batch_loss.backward()
        optimizer.step()
        batch_losses.append(batch_loss.item())
    return sum(batch_losses) / len(loader)

# -----------------------------------------------------------------------------
# 3. Main Function: Setup, Training/Embedding Generation, and Saving
# -----------------------------------------------------------------------------
def main(args):
    """Train (or load) a Node2Vec model and export one embedding per node.

    In ``train`` mode the model is trained for ``args.num_epochs`` epochs and
    its state dict is saved; in ``embed`` mode a previously saved state dict
    is loaded instead. In both modes the resulting embeddings are pickled as
    a ``{uid: embedding_row}`` dictionary. Model and embedding files are
    joined onto the ``stored`` directory, which is created if missing.

    Args:
        args: Namespace providing ``csv_file``, ``gpu_id``, ``embedding_dim``,
            ``walk_length``, ``context_size``, ``walks_per_node``,
            ``num_negative_samples``, ``p``, ``q``, ``mode``, ``model_file``,
            ``embeddings_file`` and, for training, ``batch_size``,
            ``learning_rate`` and ``num_epochs``. An optional ``seed``
            attribute seeds PyTorch's random number generator.
    """
    # Apply the seed before the model is built so PyTorch-driven randomness
    # starts from a known state. getattr keeps callers whose namespace has no
    # ``seed`` attribute working.
    seed = getattr(args, "seed", None)
    if seed is not None:
        torch.manual_seed(seed)
        logger.info("Random seed set to %s", seed)

    # Set device (GPU if available, otherwise CPU)
    device = torch.device(f"cuda:{args.gpu_id}" if torch.cuda.is_available() else "cpu")
    logger.info("Using device: %s", device)

    # Create directory for storing models and embeddings if it doesn't exist,
    # and resolve both output locations once.
    output_dir = "stored"
    os.makedirs(output_dir, exist_ok=True)
    model_path = os.path.join(output_dir, args.model_file)
    embeddings_path = os.path.join(output_dir, args.embeddings_file)

    # Load CSV and build the graph edge index.
    df = load_data(args.csv_file)
    edge_index = build_edge_index(df)

    # Initialize the Node2Vec model with the given hyperparameters.
    model = Node2Vec(
        edge_index,
        embedding_dim=args.embedding_dim,
        walk_length=args.walk_length,
        context_size=args.context_size,
        walks_per_node=args.walks_per_node,
        num_negative_samples=args.num_negative_samples,
        p=args.p,
        q=args.q,
        sparse=True,
    ).to(device)

    if args.mode == "embed":
        logger.info("Loading model from file: %s", model_path)
        # NOTE: torch.load unpickles the file; only load checkpoints you trust.
        model.load_state_dict(torch.load(model_path, map_location=device))
    else:
        loader = model.loader(batch_size=args.batch_size, shuffle=True)
        # The model is built with sparse=True, hence the sparse-gradient optimizer.
        optimizer = torch.optim.SparseAdam(list(model.parameters()), lr=args.learning_rate)
        logger.info("Starting training for %d epochs...", args.num_epochs)
        for epoch in range(1, args.num_epochs + 1):
            loss = train_node2vec(model, optimizer, loader, device)
            logger.info("Epoch: %03d, Loss: %.4f", epoch, loss)
        torch.save(model.state_dict(), model_path)
        logger.info("Saved trained model to %s", model_path)

    # Generate node embeddings.
    model.eval()
    with torch.no_grad():
        embeddings = model().cpu().numpy()  # Shape: [num_nodes, embedding_dim]
    logger.info("Generated embeddings for %d nodes", embeddings.shape[0])

    # Build a dictionary mapping each uid (from the CSV) to its embedding.
    # The uid itself is used as the row index into the embedding matrix.
    unique_nodes = pd.concat([df['uid'], df['other_uid']]).unique()
    embedding_dict = {int(uid): embeddings[int(uid)] for uid in unique_nodes}

    # Save the embedding dictionary to a file using pickle.
    with open(embeddings_path, 'wb') as f:
        pickle.dump(embedding_dict, f)
    logger.info("Saved embeddings to %s", embeddings_path)

# -----------------------------------------------------------------------------
# 4. Argument Parser for Configurable Parameters and Mode
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    # One (flag, type, default, help) row per plain option, in the order the
    # options are registered on the parser.
    option_table = [
        ("--csv_file", str, "friendship.csv",
         "Path to the CSV file containing friendship relations"),
        ("--embedding_dim", int, 128,
         "Dimension of node embeddings (default 128)"),
        ("--walk_length", int, 20,
         "Length of each random walk"),
        ("--context_size", int, 10,
         "Context size for skip-gram"),
        ("--walks_per_node", int, 10,
         "Number of random walks per node"),
        ("--num_negative_samples", int, 1,
         "Number of negative samples per positive sample"),
        ("--p", float, 1.0,
         "Return hyperparameter p"),
        ("--q", float, 1.0,
         "In-out hyperparameter q"),
        ("--batch_size", int, 128,
         "Batch size for training"),
        ("--num_epochs", int, 100,
         "Number of training epochs"),
        ("--learning_rate", float, 0.01,
         "Learning rate for the optimizer"),
        ("--gpu_id", int, 0,
         "GPU ID to use (if available)"),
        ("--seed", int, 42,
         "Random seed for reproducibility."),
        ("--model_file", str, "node2vec_model.pth",
         "File path to save/load the trained model"),
        ("--embeddings_file", str, "node_embeddings.pkl",
         "File path to save node embeddings"),
    ]

    parser = argparse.ArgumentParser(
        description="Node2Vec Embedding Training and Generation with PyTorch Geometric"
    )
    for flag, value_type, default_value, help_text in option_table:
        parser.add_argument(flag, type=value_type, default=default_value, help=help_text)
    # --mode is registered separately because it is the only option restricted
    # to a fixed set of choices.
    parser.add_argument(
        "--mode",
        type=str,
        default="train",
        choices=["train", "embed"],
        help="Mode: 'train' to train and save model, 'embed' to load a trained model and generate embeddings",
    )
    args = parser.parse_args()

    main(args)
