import argparse
import copy
import json
import math
import os
from pathlib import Path
from typing import Dict

import numpy as np
import torch
from torch_frame import stype
from torch_frame.config.text_embedder import TextEmbedderConfig
from torch_geometric.loader import NeighborLoader
from torch_geometric.seed import seed_everything
from tqdm import tqdm

import matplotlib.pyplot as plt

from relbench.base import Dataset, EntityTask, TaskType
from relbench.datasets import get_dataset
from relbench.modeling.graph import get_node_train_table_input, make_pkey_fkey_graph
from relbench.modeling.utils import get_stype_proposal
from relbench.tasks import get_task
from utils import GloveTextEmbedding


# Command-line options, declared as (flag, type, default) triples.
# NOTE: --lr, --epochs, --channels, --aggr, --max_steps_per_epoch, --out_dir
# and --gnn are parsed but never read by this script, and --max_degree is only
# ever reassigned, never read.
parser = argparse.ArgumentParser()
for _flag, _type, _default in (
    ("--dataset", str, "rel-f1"),
    ("--task", str, "driver-dnf"),
    ("--lr", float, 0.005),
    ("--epochs", int, 10),
    ("--batch_size", int, 512),
    ("--channels", int, 128),
    ("--aggr", str, "sum"),
    ("--num_layers", int, 2),
    ("--num_neighbors", int, 128),
    ("--temporal_strategy", str, "uniform"),
    ("--pos_enc", str, "none"),
    ("--max_degree", int, 10000),
    ("--pos_enc_dim", int, 128),
    ("--max_steps_per_epoch", int, 2000),
    ("--num_workers", int, 0),
    ("--seed", int, 42),
    ("--out_dir", str, "results/debug"),
    ("--gnn", str, "sage"),
    ("--device", int, 0),
    ("--cache_dir", str, os.path.expanduser("~/.cache/relbench_examples")),
):
    parser.add_argument(_flag, type=_type, default=_default)
args = parser.parse_args()


# Use the requested CUDA device when --device is non-negative and CUDA is
# available; otherwise fall back to the CPU.
if args.device >= 0 and torch.cuda.is_available():
    device = torch.device(f"cuda:{args.device}")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

if torch.cuda.is_available():
    # NOTE(review): this limits intra-op CPU threads whenever CUDA exists,
    # even if --device < 0 selected the CPU above -- confirm that is intended.
    torch.set_num_threads(1)

# Seed Python / NumPy / PyTorch RNGs (PyG's seed_everything).
seed_everything(args.seed)

# RelBench database and the entity-level prediction task defined on it.
dataset: Dataset = get_dataset(args.dataset, download=True)
task: EntityTask = get_task(args.dataset, args.task, download=True)

# Per-table {column: stype} mapping, cached as JSON under
# <cache_dir>/<dataset>/stypes.json. get_stype_proposal is only run when the
# cache is missing or cannot be parsed back into stype members.
stypes_cache_path = Path(f"{args.cache_dir}/{args.dataset}/stypes.json")
try:
    with open(stypes_cache_path, "r", encoding="utf-8") as f:
        col_to_stype_dict = json.load(f)
    # JSON stores plain strings; turn them back into stype members in place.
    for col_to_stype in col_to_stype_dict.values():
        for col, stype_str in col_to_stype.items():
            col_to_stype[col] = stype(stype_str)
except (FileNotFoundError, ValueError) as err:
    # ValueError covers json.JSONDecodeError (corrupt file) and stype(...)
    # rejecting an unknown string (e.g. an older cache that stored
    # "stype.numerical"). Regenerating the proposal recovers from both.
    if not isinstance(err, FileNotFoundError):
        print(f"Ignoring invalid stype cache {stypes_cache_path}: {err}")
    col_to_stype_dict = get_stype_proposal(dataset.get_db())
    stypes_cache_path.parent.mkdir(parents=True, exist_ok=True)
    with open(stypes_cache_path, "w", encoding="utf-8") as f:
        # Serialize stype members by .value, which is exactly what stype(...)
        # accepts on reload. str() of an Enum member is not guaranteed to be
        # that value (plain Enum.__str__ returns "ClassName.member").
        json.dump(
            col_to_stype_dict,
            f,
            indent=2,
            default=lambda obj: obj.value if isinstance(obj, stype) else str(obj),
        )

# Materialize the database as a PyG graph via make_pkey_fkey_graph. Text
# columns are embedded with GloveTextEmbedding in batches of 256. Materialized
# tensors go to <cache_dir>/<dataset>/materialized (presumably reused on later
# runs -- behaviour of make_pkey_fkey_graph, not visible here).
_db = dataset.get_db()
_text_embedder_cfg = TextEmbedderConfig(
    text_embedder=GloveTextEmbedding(device=device),
    batch_size=256,
)
data, col_stats_dict = make_pkey_fkey_graph(
    _db,
    col_to_stype_dict=col_to_stype_dict,
    text_embedder_cfg=_text_embedder_cfg,
    cache_dir=f"{args.cache_dir}/{args.dataset}/materialized",
)

# Optionally attach a positional/relational encoding to `data`.
# --pos_enc "degree" and "rel_pos_enc" are handled; any other value
# (including the default "none") leaves `data` untouched.
if args.pos_enc == "degree":
    from utils_pe import get_node_degrees

    # The helper also returns a max degree, which replaces args.max_degree.
    data, args.max_degree = get_node_degrees(data)
elif args.pos_enc == "rel_pos_enc":
    from utils_pe import get_relational_positional_encoding

    data = get_relational_positional_encoding(
        dataset.get_db(),
        data,
        pos_enc_dim=args.pos_enc_dim,
    )

def _build_loader(split: str) -> NeighborLoader:
    """Return a temporal NeighborLoader over the task table for ``split``.

    The fanout halves at each hop: num_neighbors, num_neighbors / 2, ...,
    truncated to int, with one entry per GNN layer. Only the "train" loader
    shuffles.
    """
    table_input = get_node_train_table_input(table=task.get_table(split), task=task)
    fanouts = [int(args.num_neighbors / 2**hop) for hop in range(args.num_layers)]
    return NeighborLoader(
        data,
        num_neighbors=fanouts,
        time_attr="time",
        input_nodes=table_input.nodes,
        input_time=table_input.time,
        transform=table_input.transform,
        batch_size=args.batch_size,
        temporal_strategy=args.temporal_strategy,
        shuffle=(split == "train"),
        num_workers=args.num_workers,
        # Worker processes can only persist when there are any.
        persistent_workers=args.num_workers > 0,
    )


# One loader per split, in train/val/test order.
loader_dict: Dict[str, NeighborLoader] = {
    split: _build_loader(split) for split in ["train", "val", "test"]
}

print("Finished creating NeighborLoaders.")

# --------------------------------------------------------------------
# Collect the sampled-subgraph size distribution of each split and plot it
# --------------------------------------------------------------------

# Histograms go to results/fig/<dataset>/<task>/<split>_neighbor_size_distribution.png
# (this path is fixed and does not depend on --out_dir).
fig_dir = Path("results") / "fig" / args.dataset / args.task
fig_dir.mkdir(parents=True, exist_ok=True)

for split in ["train", "val", "test"]:
    print(f"Processing {split} split to get neighbor-size distribution...")
    loader = loader_dict[split]

    # Each batch is a heterogeneous subgraph sampled around up to
    # --batch_size seed rows of the split's table. Its size is the node count
    # summed over every node type.
    neighbor_sizes = [
        sum(batch[ntype].num_nodes for ntype in batch.node_types)
        for batch in tqdm(
            loader, desc=f"{split.capitalize()} Batches", total=len(loader)
        )
    ]

    if neighbor_sizes:
        sizes = np.asarray(neighbor_sizes)
        print(
            f"{split}: {sizes.size} batches, nodes per subgraph "
            f"min={sizes.min()}, median={np.median(sizes):.0f}, max={sizes.max()}"
        )
    else:
        print(f"{split}: loader produced no batches; the histogram will be empty.")

    # Work on an explicit Figure/Axes rather than pyplot's implicit current
    # figure, and always close it so a failure cannot leave figures open.
    fig, ax = plt.subplots(figsize=(7, 5))
    try:
        ax.hist(neighbor_sizes, bins=50, edgecolor="black")
        ax.set_title(f"Neighbor Size Distribution ({split} split)")
        ax.set_xlabel("Number of nodes in sampled subgraph")
        ax.set_ylabel("Frequency")

        fig_path = fig_dir / f"{split}_neighbor_size_distribution.png"
        fig.savefig(fig_path, dpi=200, bbox_inches="tight")
    finally:
        plt.close(fig)

print("Neighbor-size distributions have been plotted and saved.")