import argparse
import logging
import pickle
import os
import torch
import torch.nn.functional as F
import pandas as pd
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv, GAE, DeepGraphInfomax

# -----------------------------------------------------------------------------
# Configure Logging
# -----------------------------------------------------------------------------
# Module-level logger used by the functions below.
logger = logging.getLogger(__name__)
# Configure the root logger at INFO level (has no effect if the root logger
# already has handlers installed).
logging.basicConfig(level=logging.INFO)

# -----------------------------------------------------------------------------
# Define a Flexible GCN Encoder for Unsupervised Models
# -----------------------------------------------------------------------------
class GCNEncoder(torch.nn.Module):
    """Stack of ``GCNConv`` layers mapping node features to embeddings.

    Layer widths run ``in_channels -> hidden_channels -> ... -> embedding_dim``
    and a ReLU follows every layer except the last one.

    Args:
        in_channels: Width of the input node features.
        hidden_channels: Width of every intermediate representation.
        embedding_dim: Width of the output of the final layer.
        num_layers: Total number of ``GCNConv`` layers (at least 2).

    Raises:
        ValueError: If ``num_layers`` is smaller than 2.
    """

    def __init__(self, in_channels, hidden_channels, embedding_dim, num_layers=2):
        super().__init__()
        if num_layers < 2:
            raise ValueError("Number of layers must be at least 2")
        # Consecutive entries of `widths` are the (input, output) sizes of
        # one layer: in -> hidden, hidden -> hidden (repeated), hidden -> out.
        widths = [in_channels] + [hidden_channels] * (num_layers - 1) + [embedding_dim]
        self.convs = torch.nn.ModuleList(
            [GCNConv(fan_in, fan_out) for fan_in, fan_out in zip(widths, widths[1:])]
        )

    def forward(self, x, edge_index):
        """Apply every layer in order, with ReLU between layers only."""
        final_idx = len(self.convs) - 1
        for idx, layer in enumerate(self.convs):
            x = layer(x, edge_index)
            if idx != final_idx:
                # The last layer's output is returned without an activation.
                x = F.relu(x)
        return x

# -----------------------------------------------------------------------------
# Build the Feature Matrix Using uids as Indices
# -----------------------------------------------------------------------------
def build_feature_matrix(features_file, df_edges):
    """Build a dense node-feature matrix in which row ``i`` belongs to uid ``i``.

    The matrix has one row for every integer in ``0..max_uid``, where
    ``max_uid`` is the largest uid found in either the feature dictionary or
    the edge list.  Rows for uids that have no entry in the feature
    dictionary are left as zeros.

    Args:
        features_file: Path to a pickle file holding ``{uid: feature_vector}``.
        df_edges: DataFrame with ``uid`` and ``other_uid`` columns; may be empty.

    Returns:
        Tuple ``(uids, x)`` where ``uids`` is ``list(range(max_uid + 1))`` and
        ``x`` is a float tensor of shape ``(max_uid + 1, feature_dim)``.

    Raises:
        ValueError: If the feature dictionary is empty, if any uid is
            negative, or if a feature vector's length differs from the
            length of the first vector.
    """
    logger.info("Loading node features from %s", features_file)
    # SECURITY: unpickling can execute arbitrary code -- only load trusted files.
    with open(features_file, 'rb') as f:
        feat_dict = pickle.load(f)  # Expected to be {uid: feature_vector}
    if not feat_dict:
        raise ValueError("Feature dictionary is empty")

    min_uid = int(min(feat_dict))
    max_uid = int(max(feat_dict))
    # An empty frame reports NaN extrema, so only consult it when it has rows.
    if len(df_edges) > 0:
        for column in ('uid', 'other_uid'):
            min_uid = min(min_uid, int(df_edges[column].min()))
            max_uid = max(max_uid, int(df_edges[column].max()))
    if min_uid < 0:
        # A negative uid would index from the end of the matrix and silently
        # overwrite the row of a different node.
        raise ValueError(f"uids must be non-negative, found {min_uid}")

    # The first vector defines the feature dimension for the whole matrix.
    feature_dim = len(next(iter(feat_dict.values())))

    x = torch.zeros(max_uid + 1, feature_dim, dtype=torch.float)
    for uid, feat in feat_dict.items():
        if len(feat) != feature_dim:
            # Without this check a length-1 vector would broadcast across the row.
            raise ValueError(
                f"Feature vector for uid {uid} has length {len(feat)}, expected {feature_dim}"
            )
        x[uid] = torch.as_tensor(feat, dtype=torch.float)
    uids = list(range(max_uid + 1))
    logger.info("Constructed feature matrix for %d nodes with dimension %d", len(uids), feature_dim)
    return uids, x

# -----------------------------------------------------------------------------
# Load Edge Data from CSV
# -----------------------------------------------------------------------------
def load_edges(csv_file):
    """Read the friendship edge list from a CSV file.

    Args:
        csv_file: Path to a CSV file with ``uid`` and ``other_uid`` columns.

    Returns:
        The parsed ``pandas.DataFrame`` with one row per edge.

    Raises:
        ValueError: If the ``uid`` or ``other_uid`` column is missing.
    """
    logger.info("Loading edges from %s", csv_file)
    df = pd.read_csv(csv_file)
    # Fail here with a clear message instead of a bare KeyError when the
    # columns are first indexed later on.
    missing = [col for col in ("uid", "other_uid") if col not in df.columns]
    if missing:
        raise ValueError(
            f"{csv_file} is missing required column(s): {', '.join(missing)}"
        )
    logger.info("Loaded %d edges", len(df))
    return df

# -----------------------------------------------------------------------------
# Build the Edge Index Using uid Directly as Index
# -----------------------------------------------------------------------------
def build_edge_index(df_edges):
    """Convert the edge list into an undirected ``edge_index`` tensor.

    Args:
        df_edges: DataFrame with ``uid`` and ``other_uid`` columns; the values
            are used directly as node indices.

    Returns:
        Long tensor whose first row holds the ``uid`` endpoints and whose
        second row holds the ``other_uid`` endpoints, after
        ``to_undirected`` has been applied.
    """
    from torch_geometric.utils import to_undirected

    logger.info("Building edge index from DataFrame...")
    # Convert each column on its own and stack the results: passing
    # torch.tensor a Python list of ndarrays takes a very slow element-wise
    # path and triggers a UserWarning.
    sources = torch.tensor(df_edges['uid'].to_numpy(), dtype=torch.long)
    targets = torch.tensor(df_edges['other_uid'].to_numpy(), dtype=torch.long)
    edge_index = to_undirected(torch.stack([sources, targets], dim=0))
    logger.info("Edge index shape: %s", list(edge_index.size()))
    return edge_index

# -----------------------------------------------------------------------------
# Train the Unsupervised Model (GAE or DGI)
# -----------------------------------------------------------------------------
def train_unsupervised(model, data, optimizer, device, model_type):
    """Run a single optimisation step and return its loss.

    Args:
        model: For ``'dgi'`` it is called directly and must provide ``loss``;
            otherwise it must provide ``encode`` and ``recon_loss``.
        data: Object exposing ``x`` and ``edge_index`` tensors.
        optimizer: Optimizer that is zeroed before and stepped after backprop.
        device: Device the tensors are moved to before the forward pass.
        model_type: ``'dgi'`` selects the DGI objective; any other value uses
            the reconstruction loss.

    Returns:
        The loss of this step as a Python float.
    """
    model.train()
    optimizer.zero_grad()

    features = data.x.to(device)
    edges = data.edge_index.to(device)

    if model_type != 'dgi':
        latent = model.encode(features, edges)
        loss = model.recon_loss(latent, edges)
    else:
        # For DGI, the forward pass returns (pos_z, neg_z, summary).
        positive, negative, graph_summary = model(x=features, edge_index=edges)
        loss = model.loss(positive, negative, graph_summary)

    loss.backward()
    optimizer.step()
    return loss.item()

# -----------------------------------------------------------------------------
# Main Function: Setup, Training/Embedding Generation, and Saving
# -----------------------------------------------------------------------------
def _build_model(args, in_channels, device):
    """Create the GAE or DGI model selected by ``args.model_type`` on ``device``."""
    encoder = GCNEncoder(in_channels, args.hidden_channels, args.embedding_dim, args.num_layers)

    if args.model_type == 'gae':
        return GAE(encoder).to(device)

    if args.model_type == 'dgi':
        # Summary: mean of the encoder output over all nodes, then sigmoid.
        def summary(z, *_args, **_kwargs):
            return z.mean(dim=0).sigmoid()

        # Corruption: randomly shuffle the node features, keep the edges.
        def corruption(x, edge_index):
            return x[torch.randperm(x.size(0), device=x.device)], edge_index

        # DGI's discriminator weight is (hidden_channels x hidden_channels)
        # and is multiplied with the encoder output, so it must be sized by
        # the encoder's output width (embedding_dim), not by the encoder's
        # internal hidden width.
        return DeepGraphInfomax(
            hidden_channels=args.embedding_dim,
            encoder=encoder,
            summary=summary,
            corruption=corruption,
        ).to(device)

    raise ValueError(f"Unsupported model_type: {args.model_type}")


def main(args):
    """Train (or load) the unsupervised model and save node embeddings.

    Model weights and embeddings are read from / written to the ``stored``
    directory, which is created if it does not exist.

    Args:
        args: Parsed CLI namespace.  Reads ``gpu_id``, ``csv_file``,
            ``features_file``, ``hidden_channels``, ``embedding_dim``,
            ``num_layers``, ``model_type``, ``learning_rate``, ``mode``,
            ``num_epochs``, ``model_file`` and ``embeddings_file``.

    Raises:
        ValueError: If ``args.model_type`` is neither ``'gae'`` nor ``'dgi'``.
    """
    device = torch.device(f"cuda:{args.gpu_id}" if torch.cuda.is_available() else "cpu")
    logger.info("Using device: %s", device)

    os.makedirs("stored", exist_ok=True)
    model_path = os.path.join("stored", args.model_file)
    embeddings_path = os.path.join("stored", args.embeddings_file)

    df_edges = load_edges(args.csv_file)
    edge_index = build_edge_index(df_edges)
    uids, x = build_feature_matrix(args.features_file, df_edges)

    data = Data(x=x, edge_index=edge_index).to(device)
    model = _build_model(args, x.size(1), device)

    if args.mode == "embed":
        logger.info("Loading model from %s", model_path)
        model.load_state_dict(torch.load(model_path, map_location=device))
    else:
        # The optimizer is only needed when training.
        optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
        logger.info("Starting training for %d epochs using %s...", args.num_epochs, args.model_type)
        for epoch in range(1, args.num_epochs + 1):
            loss = train_unsupervised(model, data, optimizer, device, args.model_type)
            logger.info("Epoch: %03d, Loss: %.4f", epoch, loss)
        torch.save(model.state_dict(), model_path)
        logger.info("Saved trained model to %s", model_path)

    model.eval()
    with torch.no_grad():
        if args.model_type == 'dgi':
            node_embeddings = model.encoder(data.x, data.edge_index)
        else:
            node_embeddings = model.encode(data.x, data.edge_index)
        embeddings = node_embeddings.cpu().numpy()

    logger.info("Generated embeddings for %d nodes", embeddings.shape[0])
    embedding_dict = {uid: embeddings[uid] for uid in uids}

    with open(embeddings_path, 'wb') as f:
        pickle.dump(embedding_dict, f)
    logger.info("Saved embeddings to %s", embeddings_path)

# -----------------------------------------------------------------------------
# Argument Parser for Configurable Parameters and Mode
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    # One (flag, add_argument keyword arguments) entry per CLI option,
    # registered on the parser in the order listed here.
    cli_options = [
        ("--csv_file", dict(
            type=str, default="friendship.csv",
            help="Path to the CSV file containing friendship edges with columns 'uid' and 'other_uid'")),
        ("--features_file", dict(
            type=str, default="node_features.pkl",
            help="Path to the pickle file containing node features as a dict (uid -> feature vector)")),
        ("--model_type", dict(
            type=str, default="gae", choices=["gae", "dgi"],
            help="Type of unsupervised model to use: 'gae' or 'dgi'")),
        ("--embedding_dim", dict(
            type=int, default=128,
            help="Dimension of node embeddings (default 128)")),
        ("--hidden_channels", dict(
            type=int, default=128,
            help="Dimension of hidden channels in the GCN encoder")),
        ("--num_layers", dict(
            type=int, default=2,
            help="Number of layers in the GCN encoder (minimum 2)")),
        ("--num_epochs", dict(
            type=int, default=200,
            help="Number of training epochs")),
        ("--learning_rate", dict(
            type=float, default=0.01,
            help="Learning rate for the optimizer")),
        ("--gpu_id", dict(
            type=int, default=0,
            help="GPU ID to use (if available)")),
        ("--model_file", dict(
            type=str, default="unsupervised_model.pth",
            help="File path to save/load the trained model")),
        ("--embeddings_file", dict(
            type=str, default="node_embeddings.pkl",
            help="File path to save node embeddings dictionary")),
        ("--mode", dict(
            type=str, default="train", choices=["train", "embed"],
            help="Mode: 'train' to train and save the model, 'embed' to load a trained model and generate embeddings")),
    ]

    parser = argparse.ArgumentParser(
        description="Unsupervised Node Embedding using GAE or DGI with PyTorch Geometric"
    )
    for flag, options in cli_options:
        parser.add_argument(flag, **options)
    args = parser.parse_args()

    main(args)
