import pickle
import pymetis
import json
import torch
import numpy as np
from tqdm import tqdm
import sys
import os
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor

def generate_partitions(dataset_name, k, num_partitions=2324):
    """Partition the training edge graph with METIS and pickle the partitions.

    Reads ``asin2idx.json`` (to size the graph) and the ``*_edge_{k}_train.jsonl``
    edge list, stores every edge in both directions, splits the nodes into
    ``num_partitions`` parts with ``pymetis.part_graph`` and keeps, for each
    part, only the edges whose two endpoints fall inside that part.

    Args:
        dataset_name: Dataset directory name under ``/home/yqiao47/dataset``.
        k: Value interpolated into the edge file name; only used for
            formatting, so a string or an int both work.
        num_partitions: Number of METIS parts to request. The default (2324)
            is the value that used to be hard-coded here.
            NOTE(review): the reason for this exact number is not visible in
            this file -- confirm before changing it.

    Returns:
        Path of the pickle file that was written. The pickled dict has:
          * ``"partitions"``: ``{part_id: {"nodes": LongTensor,
            "edges": {"src": list, "dst": list, "cluster_ids": list}}}``
          * ``"unique_nodes"``, ``"unique_edges"``, ``"unique_clusters"``:
            de-duplicated global statistics as plain lists.
    """
    mapping_file = f"/home/yqiao47/dataset/{dataset_name}/asin2idx.json"
    edge_file = f"/home/yqiao47/dataset/{dataset_name}/{dataset_name}_edge_{k}_train.jsonl"
    output_file = f"/home/yqiao47/dataset/{dataset_name}/all_partitions_with_unique_edge_train.pkl"

    print(f"Generating partitions for dataset: {dataset_name} with k={k}")

    # Read mapping file
    with open(mapping_file, 'r', encoding='utf-8') as f:
        mapping = json.load(f)
    asin2idx = mapping["asin2idx"]

    print('Loading edge data...')
    # Node ids are assumed to be 0..max(idx); edge endpoints index this list.
    num_nodes = max(asin2idx.values()) + 1
    adjacency_list = [[] for _ in range(num_nodes)]
    edge_index = []
    edge_clusters = []

    with open(edge_file, 'r', encoding='utf-8') as f:
        for line in tqdm(f, desc="Loading edges"):
            line = line.strip()
            if not line:
                # Tolerate blank lines (e.g. a trailing newline at EOF).
                continue
            edge = json.loads(line)
            idx1, idx2 = int(edge["product1"]), int(edge["product2"])
            cluster_id = int(edge["cluster_id"])

            # The graph is undirected: record both directions everywhere.
            adjacency_list[idx1].append(idx2)
            adjacency_list[idx2].append(idx1)
            edge_index.extend(((idx1, idx2), (idx2, idx1)))
            edge_clusters.extend((cluster_id, cluster_id))

    print(f"Total edges: {len(edge_index) // 2}")
    print(f"Total nodes: {num_nodes}")

    # Convert to NumPy/PyTorch format. The reshape keeps edge_index at shape
    # (2, E) even when E == 0, so the row indexing below stays valid.
    adjacency_list = [np.array(neighbors, dtype=np.int32) for neighbors in adjacency_list]
    edge_index = torch.tensor(
        np.array(edge_index, dtype=np.int64).reshape(-1, 2).T, dtype=torch.long
    )
    edge_clusters = torch.tensor(np.array(edge_clusters, dtype=np.int64), dtype=torch.long)

    print("Extracting batches...")
    _, membership = pymetis.part_graph(num_partitions, adjacency=adjacency_list)

    # Convert the membership vector once instead of rebuilding a Python-level
    # boolean list for every partition.
    membership_t = torch.as_tensor(membership, dtype=torch.long)
    # An edge belongs to a partition iff both endpoints are assigned to it.
    src_part = membership_t[edge_index[0]]
    intra_partition = src_part == membership_t[edge_index[1]]

    batch_data = {}
    all_nodes = []
    all_edges = []
    all_clusters = []

    for part_id in tqdm(range(num_partitions), desc="Extracting subgraphs"):
        # Global node IDs assigned to this partition
        global_node_ids = torch.where(membership_t == part_id)[0]

        selected_edges_mask = intra_partition & (src_part == part_id)
        selected_edges = edge_index[:, selected_edges_mask]
        selected_clusters = edge_clusters[selected_edges_mask]

        all_nodes.append(global_node_ids)
        all_edges.append(selected_edges)
        all_clusters.append(selected_clusters)
        batch_data[part_id] = {
            "nodes": global_node_ids,
            "edges": {
                "src": selected_edges[0].tolist(),
                "dst": selected_edges[1].tolist(),
                "cluster_ids": selected_clusters.tolist()
            }
        }

    # Merge all nodes and edges and remove duplicates
    unique_nodes = torch.cat(all_nodes).unique()
    unique_edges = torch.cat(all_edges, dim=1).unique(dim=1)
    unique_clusters = torch.cat(all_clusters).unique()
    print(f"Total unique nodes: {unique_nodes.size(0)}")
    print(f"Total unique edges: {unique_edges.size(1)}")
    print(f"Total unique clusters: {unique_clusters.size(0)}")

    # Save partition data and global statistics
    output_data = {
        "partitions": batch_data,
        "unique_nodes": unique_nodes.tolist(),
        "unique_edges": unique_edges.tolist(),
        "unique_clusters": unique_clusters.tolist()
    }

    with open(output_file, "wb") as f:
        pickle.dump(output_data, f)
    print(f"Graph partitioning completed. Data saved to {output_file}")
    return output_file

def _to_int_list(values):
    """Return ``values`` (a tensor or an iterable of int-likes) as Python ints."""
    if isinstance(values, torch.Tensor):
        return values.tolist()
    return [int(v) for v in values]


def _load_neighbor_clusters(edge_file):
    """Read the edge JSONL file into ``{node: {neighbor: cluster_id}}``.

    Every edge is stored in both directions. When the same pair appears more
    than once, the cluster id of its first occurrence in the file is kept.
    Returns the mapping and the number of edge lines read.
    """
    neighbors = {}
    num_edges = 0
    with open(edge_file, 'r', encoding='utf-8') as f:
        for line in tqdm(f, desc="Processing edges"):
            line = line.strip()
            if not line:
                continue
            edge = json.loads(line)
            idx1, idx2 = int(edge["product1"]), int(edge["product2"])
            cluster_id = int(edge["cluster_id"])
            neighbors.setdefault(idx1, {}).setdefault(idx2, cluster_id)
            neighbors.setdefault(idx2, {}).setdefault(idx1, cluster_id)
            num_edges += 1
    return neighbors, num_edges


def _make_batch(nodes, triples):
    """Build a batch dict from node ids and ``(src, dst, cluster_id)`` triples."""
    return {
        'nodes': list(nodes),
        'edges': {
            "src": [src for src, _, _ in triples],
            "dst": [dst for _, dst, _ in triples],
            "cluster_ids": [cluster for _, _, cluster in triples]
        }
    }


def _extend_with_nodes(batch, added_nodes, neighbors):
    """Append ``added_nodes`` to ``batch`` plus the edges linking them to it.

    An edge is added (in both directions) when one endpoint is a newly added
    node and the other endpoint is any node of the batch, including the other
    newly added nodes. Each directed edge is added at most once.
    """
    batch['nodes'].extend(added_nodes)
    members = set(batch['nodes'])
    edges = batch['edges']
    seen = set()
    for i in added_nodes:
        for j, cluster_id in neighbors.get(i, {}).items():
            if j not in members:
                continue
            for src, dst in ((i, j), (j, i)):
                if (src, dst) in seen:
                    continue
                seen.add((src, dst))
                edges["src"].append(src)
                edges["dst"].append(dst)
                edges["cluster_ids"].append(cluster_id)


def process_batches(dataset_name, k, partition_file, batch_size=512):
    """Turn graph partitions into fixed-size node batches and pickle them.

    Steps:
      1. Partitions with at most ``batch_size`` nodes become one batch each.
      2. Larger partitions are cut into consecutive ``batch_size``-node chunks;
         each chunk keeps the partition edges whose two endpoints are both in
         the chunk. The final remainder (at most ``batch_size`` nodes) goes to
         a pool of "extra" nodes.
      3. Batches that are short are topped up from the extra pool, and edges
         between the added nodes and the batch are recovered from the edge
         file. Extra nodes still left afterwards are packed into additional
         batches instead of being dropped.
      4. Every batch is padded with ``-1`` up to ``batch_size`` nodes.

    Args:
        dataset_name: Dataset directory name.
        k: Value interpolated into the edge file name (formatting only).
        partition_file: Pickle produced by ``generate_partitions``; must hold
            a ``"partitions"`` dict of ``{"nodes", "edges"}`` entries.
        batch_size: Number of node slots per batch (default 512).

    Returns:
        Path of the pickle holding the list of batches. Each batch is
        ``{"nodes": list[int], "edges": {"src", "dst", "cluster_ids"}}``.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")

    # NOTE(review): the edge file is read from the same ``dataset/`` directory
    # that generate_partitions reads it from; this path previously lacked the
    # ``dataset/`` segment. The output locations are left where they were.
    edge_file = f"/home/yqiao47/dataset/{dataset_name}/{dataset_name}_edge_{k}_train.jsonl"
    output_file = f"/home/yqiao47/{dataset_name}/batch_data_edge_train.pkl"
    temp_dir = f"/home/yqiao47/{dataset_name}"
    pad_node = -1
    chunk_save_size = 1000

    print(f"Processing batches for dataset: {dataset_name} with k={k}")

    # Load edges once into a neighbor -> cluster lookup (O(1) per query).
    print("Loading edges...")
    neighbors, total_edges = _load_neighbor_clusters(edge_file)
    print(f"Total edges: {total_edges}")

    # Load partitions. NOTE: pickle must only be used on trusted files; this
    # one is expected to come from generate_partitions.
    print("Loading partitions...")
    with open(partition_file, "rb") as f:
        partitions = pickle.load(f)
    original_batches = partitions['partitions']

    new_batches = []
    extra_nodes = []

    print("Processing batches...")
    for batch_data in tqdm(original_batches.values(), desc="Processing batches"):
        # Plain ints so that set/dict membership works by value; tensor
        # elements would hash by identity and never match.
        nodes = _to_int_list(batch_data['nodes'])
        edges = batch_data['edges']
        triples = list(zip(edges["src"], edges["dst"], edges["cluster_ids"]))

        if len(nodes) <= batch_size:
            new_batches.append(_make_batch(nodes, triples))
            continue

        # Split large partitions. The full edge list is filtered for every
        # chunk: edges are not ordered by node, so they cannot be sliced in
        # step with the nodes.
        while len(nodes) > batch_size:
            chunk = nodes[:batch_size]
            nodes = nodes[batch_size:]
            chunk_set = set(chunk)
            kept = [
                (src, dst, cluster)
                for src, dst, cluster in triples
                if src in chunk_set and dst in chunk_set
            ]
            new_batches.append(_make_batch(chunk, kept))

        # Store remaining nodes
        extra_nodes.extend(nodes)

    print(f"Extra nodes: {len(extra_nodes)}")

    # Fill incomplete batches from the extra pool (cursor instead of
    # re-slicing the pool on every batch).
    print("Filling incomplete batches...")
    cursor = 0
    for batch in tqdm(new_batches, desc="Filling batches"):
        missing = batch_size - len(batch['nodes'])
        if missing > 0 and cursor < len(extra_nodes):
            added = extra_nodes[cursor:cursor + missing]
            cursor += len(added)
            _extend_with_nodes(batch, added, neighbors)

    # Extra nodes that did not fit anywhere get their own batches.
    while cursor < len(extra_nodes):
        added = extra_nodes[cursor:cursor + batch_size]
        cursor += len(added)
        batch = _make_batch([], [])
        _extend_with_nodes(batch, added, neighbors)
        new_batches.append(batch)

    # Pad every short batch so all batches expose batch_size node slots.
    for batch in new_batches:
        shortfall = batch_size - len(batch['nodes'])
        if shortfall > 0:
            batch['nodes'].extend([pad_node] * shortfall)

    # Count unique nodes and edges (padding is not a node).
    unique_nodes = set()
    total_edges = 0
    for batch in new_batches:
        unique_nodes.update(batch["nodes"])
        total_edges += len(batch["edges"]["src"])
    unique_nodes.discard(pad_node)

    print(f"Total unique nodes: {len(unique_nodes)}")
    print(f"Total edges: {total_edges}")

    # Save batches in chunks
    os.makedirs(temp_dir, exist_ok=True)
    chunk_starts = range(0, len(new_batches), chunk_save_size)
    for chunk_no, start in enumerate(tqdm(chunk_starts, desc="Saving batch chunks")):
        temp_file = f"{temp_dir}/batch_chunk_{chunk_no}.pkl"
        with open(temp_file, "wb") as f:
            pickle.dump(new_batches[start:start + chunk_save_size], f)

    # Write the merged file straight from memory. Re-listing the directory
    # would also pick up stale chunk files from earlier runs and would order
    # them lexicographically (chunk_10 before chunk_2).
    print("Merging all batch chunks into a single pickle file...")
    with open(output_file, "wb") as fout:
        pickle.dump(new_batches, fout)
    print(f"Final merged batch saved to {output_file}")

    return output_file

def preprocess_batches(dataset_name, k, batch_file):
    """Expand every batch into titles, ASINs and a dense label matrix on disk.

    For each batch in ``batch_file`` one ``batch_{idx}.pkl`` file is written
    into the ``preprocessed_batches_edge_train`` directory. Each file holds a
    dict with ``"titles"``, ``"asins"``, ``"asin_idx"`` and ``"matrix"``.

    The matrix is square over the batch's nodes (``-1`` padding removed). A
    cell stays 0 when the batch has no edge for that pair, holds
    ``cluster_id + 1`` for an edge, and the diagonal is set to 11.

    Args:
        dataset_name: Dataset directory name under ``/home/yqiao47/dataset``.
        k: Not used by this step.
        batch_file: Pickle containing the list of batches to expand.
    """
    item_file = f"/home/yqiao47/dataset/{dataset_name}/{dataset_name}_item_new.jsonl"
    mapping_file = f"/home/yqiao47/dataset/{dataset_name}/asin2idx.json"
    output_dir = f"/home/yqiao47/dataset/{dataset_name}/preprocessed_batches_edge_train"

    # --- Load lookup tables and batches -----------------------------------
    with open(mapping_file, 'r', encoding='utf-8') as f:
        mapping = json.load(f)
    idx2asin = {int(raw_idx): asin for raw_idx, asin in mapping["idx2asin"].items()}

    items_dict = {}
    with open(item_file, 'r', encoding='utf-8') as f:
        for raw_line in f:
            record = json.loads(raw_line)
            # Key each item by its parent ASIN and keep the remaining fields.
            parent_asin = record.pop("parent_asin")
            items_dict[parent_asin] = record

    with open(batch_file, 'rb') as f:
        batches = pickle.load(f)

    # --- Per-batch helpers --------------------------------------------------
    def build_sample(batch):
        """Return (titles, asins, node ids, label matrix) for one batch."""
        node_ids = [
            entry.item() if isinstance(entry, torch.Tensor) else entry
            for entry in batch['nodes']
        ]
        # Drop padding slots.
        node_ids = [entry for entry in node_ids if entry != -1]

        edge_info = batch['edges']
        sources = edge_info['src']
        targets = edge_info['dst']
        clusters = edge_info['cluster_ids']

        titles = []
        asins = []
        for node in node_ids:
            asin = idx2asin[node]
            asins.append(asin)
            titles.append(items_dict[asin]["title"])

        # Build connection matrix
        position = {node: pos for pos, node in enumerate(node_ids)}
        size = len(node_ids)
        label_matrix = np.zeros((size, size), dtype=np.float32)
        for src, dst, cluster in zip(sources, targets, clusters):
            if src not in position or dst not in position:
                continue
            label_matrix[position[src], position[dst]] = cluster + 1
        np.fill_diagonal(label_matrix, 11)

        return titles, asins, list(node_ids), torch.tensor(label_matrix)

    def write_sample(idx):
        """Build batch ``idx`` and pickle it to ``batch_{idx}.pkl``."""
        titles, asins, asin_idx, label_matrix = build_sample(batches[idx])
        payload = {
            "titles": titles,
            "asins": asins,
            "asin_idx": asin_idx,
            "matrix": label_matrix
        }
        target_path = os.path.join(output_dir, f"batch_{idx}.pkl")
        with open(target_path, 'wb') as out:
            pickle.dump(payload, out)
        return idx

    # --- Fan out over a thread pool ----------------------------------------
    os.makedirs(output_dir, exist_ok=True)
    with ThreadPoolExecutor(max_workers=16) as pool:
        pending = [pool.submit(write_sample, idx) for idx in range(len(batches))]
        # result() re-raises any exception from the worker.
        for job in tqdm(pending, desc="Preprocessing Dataset", total=len(pending)):
            job.result()
    print(f"All batches have been preprocessed and saved to {output_dir}")

    print(f"Batch preprocessing completed. Results saved to {output_dir}")

def main():
    """Command-line entry point: run the three pipeline steps in order.

    Expects ``DATASET_NAME`` and ``K`` as the first two command-line
    arguments; prints usage and exits with status 1 when either is missing.
    """
    cli_args = sys.argv[1:]
    if len(cli_args) < 2:
        print("Usage: python generate_batch.py DATASET_NAME K")
        print("Example: python generate_batch.py Sports_and_Outdoors 10")
        sys.exit(1)

    # K stays a string; it is only interpolated into file names and messages.
    dataset_name, k = cli_args[0], cli_args[1]

    print(f"Starting batch generation for {dataset_name} with k={k}")

    # Step 1 writes the partition pickle, step 2 turns it into batches,
    # step 3 expands each batch into its own preprocessed file.
    partition_file = generate_partitions(dataset_name, k)
    batch_file = process_batches(dataset_name, k, partition_file)
    preprocess_batches(dataset_name, k, batch_file)

    print(f"Batch generation complete for {dataset_name}")

# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()