import pickle
import pymetis
import json
import torch
import numpy as np
from tqdm import tqdm
import os
import argparse
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor


def create_partitions(dataset_path, mapping_file, edge_file, output_file, k=512):
    """Split node indices into contiguous partitions and bucket intra-partition edges.

    Node ``i`` is assigned to partition ``i // nodes_per_partition`` (plain
    contiguous ranges, no graph-aware partitioning), which puts at most ``k``
    node indices in every partition.  Each edge whose two endpoints land in the
    same partition is appended to that partition's temporary pickle file under
    ``<dataset_path>/temp_partitions``; edges crossing partitions are dropped.

    Args:
        dataset_path: Dataset directory; temporary partition files go below it.
        mapping_file: JSON file containing an ``"asin2idx"`` dict (asin -> int).
        edge_file: JSON-lines file; every non-blank line has ``product1``,
            ``product2`` and ``cluster_id`` fields (int-convertible).
        output_file: Path the partition metadata dict is pickled to.
        k: Upper bound on node indices per partition.

    Returns:
        The metadata dict that was pickled to *output_file* (partition file
        paths, ids of non-empty partitions, temp dir, node count, edge count).

    Raises:
        ValueError: if *k* is smaller than 1.
    """
    if k < 1:
        raise ValueError(f"k must be a positive integer, got {k}")

    with open(mapping_file, 'r', encoding='utf-8') as f:
        mapping = json.load(f)
    asin2idx = mapping["asin2idx"]

    print('Loading edge data...')
    num_nodes = max(asin2idx.values()) + 1
    num_partitions = num_nodes // k + 1

    print(f"Using {num_partitions} partitions")
    nodes_per_partition = num_nodes // num_partitions + 1
    print(f"Approx. nodes per partition: {nodes_per_partition}")

    # Contiguous-range assignment; the min() clamp is only a safety net.
    membership = [min(i // nodes_per_partition, num_partitions - 1) for i in range(num_nodes)]
    partition_edge_counts = defaultdict(int)
    partition_nodes = defaultdict(set)

    temp_dir = os.path.join(dataset_path, "temp_partitions")
    os.makedirs(temp_dir, exist_ok=True)

    # Seed every partition file with an empty record so _process_edge_buffer
    # can always load-then-append.
    partition_files = {}
    for part_id in range(num_partitions):
        part_path = os.path.join(temp_dir, f"partition_{part_id}.pkl")
        partition_files[part_id] = part_path
        part_data = {
            "nodes": [],
            "edges": {
                "src": [],
                "dst": [],
                "cluster_ids": []
            }
        }
        with open(part_path, "wb") as f:
            pickle.dump(part_data, f)

    print("Counting total edges...")
    with open(edge_file, 'r', encoding='utf-8') as f:
        # Blank lines (e.g. a trailing newline) are not edges.
        total_edges = sum(1 for line in f if line.strip())
    print(f"Total edges: {total_edges}")

    chunk_size = 1000000
    processed_edges = 0

    print("Processing edges to assign to partitions...")
    with open(edge_file, 'r', encoding='utf-8') as f:
        edge_buffer = []
        non_blank_lines = (line for line in f if line.strip())
        for line in tqdm(non_blank_lines, total=total_edges, desc="Processing edges"):
            edge = json.loads(line)
            idx1, idx2 = int(edge["product1"]), int(edge["product2"])
            cluster_id = int(edge["cluster_id"])

            part_id = membership[idx1]
            if part_id != membership[idx2]:
                continue  # cross-partition edges are discarded

            edge_buffer.append((part_id, idx1, idx2, cluster_id))
            partition_edge_counts[part_id] += 1
            partition_nodes[part_id].update((idx1, idx2))

            # Flush periodically so the in-memory buffer stays bounded.
            if len(edge_buffer) >= chunk_size:
                _process_edge_buffer(edge_buffer, partition_files)
                processed_edges += len(edge_buffer)
                edge_buffer = []

        # Process remaining edges
        if edge_buffer:
            _process_edge_buffer(edge_buffer, partition_files)
            processed_edges += len(edge_buffer)

    print(f"Processed {processed_edges} edges within partitions")

    # Finalize partitions and gather metadata
    all_nodes = set()
    total_edges_in_partitions = 0
    valid_partitions = []

    for part_id in range(num_partitions):
        part_nodes = partition_nodes.get(part_id)
        if not part_nodes:
            continue  # no intra-partition edge touched this partition
        edge_count = partition_edge_counts[part_id]
        print(f"Partition {part_id}: {len(part_nodes)} nodes, {edge_count} edges")

        all_nodes.update(part_nodes)
        total_edges_in_partitions += edge_count
        valid_partitions.append(part_id)

    print(f"Total unique nodes across all partitions: {len(all_nodes)}")
    print(f"Total edges in partitions: {total_edges_in_partitions}")

    # Save metadata
    output_data = {
        "partition_files": partition_files,
        "valid_partitions": valid_partitions,
        "temp_dir": temp_dir,
        "node_count": num_nodes,
        "total_edges": total_edges_in_partitions
    }

    with open(output_file, "wb") as f:
        pickle.dump(output_data, f)

    print(f"Partition data saved to {output_file}")
    return output_data


def _process_edge_buffer(edge_buffer, partition_files):
    """Process a buffer of edges and add them to the appropriate partition files"""
    # Group edges by partition
    partition_edges = defaultdict(list)
    for part_id, src, dst, cluster in edge_buffer:
        partition_edges[part_id].append((src, dst, cluster))
    
    # Update each partition file
    for part_id, edges in partition_edges.items():
        # Load current data
        with open(partition_files[part_id], "rb") as f:
            part_data = pickle.load(f)
        
        # Add new edges
        for src, dst, cluster in edges:
            part_data["edges"]["src"].append(src)
            part_data["edges"]["dst"].append(dst)
            part_data["edges"]["cluster_ids"].append(cluster)
        
        # Save updated data
        with open(partition_files[part_id], "wb") as f:
            pickle.dump(part_data, f)


def _split_edges_by_batch(edges, node_batches):
    """Return one edge dict per node batch, holding the edges whose two endpoints are both in that batch."""
    batch_of = {}
    for batch_no, batch_nodes in enumerate(node_batches):
        for node in batch_nodes:
            batch_of[node] = batch_no
    split = [{"src": [], "dst": [], "cluster_ids": []} for _ in node_batches]
    # Single pass over the partition's edges; per-batch edge order follows the input order.
    for src, dst, cluster in zip(edges["src"], edges["dst"], edges["cluster_ids"]):
        batch_no = batch_of.get(src)
        if batch_no is not None and batch_of.get(dst) == batch_no:
            split[batch_no]["src"].append(src)
            split[batch_no]["dst"].append(dst)
            split[batch_no]["cluster_ids"].append(cluster)
    return split


def _fill_batch(batch, extra_nodes, edge_clusters, batch_size):
    """Move nodes from the front of *extra_nodes* into *batch* (up to *batch_size*), add their known edges, and return the leftovers."""
    fill_count = min(batch_size - len(batch['nodes']), len(extra_nodes))
    if fill_count <= 0:
        return extra_nodes
    added_nodes = extra_nodes[:fill_count]
    batch['nodes'].extend(added_nodes)

    batch_nodes_set = set(batch['nodes'])
    for node in added_nodes:
        # Edges between two newly added nodes are visited from both ends,
        # so they are appended twice.
        for other_node in batch_nodes_set - {node}:
            cluster = edge_clusters.get((node, other_node))
            if cluster is None:
                continue
            batch['edges']["src"].append(node)
            batch['edges']["dst"].append(other_node)
            batch['edges']["cluster_ids"].append(cluster)

            # Add reverse edge
            batch['edges']["src"].append(other_node)
            batch['edges']["dst"].append(node)
            batch['edges']["cluster_ids"].append(cluster)
    return extra_nodes[fill_count:]


def build_batch_data(partition_file, output_file, mapping_file, edge_file, k=512):
    """Turn the partitions produced by ``create_partitions`` into node batches.

    A partition with at most *k* nodes becomes a single batch containing all
    of its edges. A larger partition is cut into *k*-sized node chunks, and
    each chunk keeps only the edges internal to it. If the final chunk has
    fewer than ``k // 2`` nodes, its nodes are pooled as "extra" nodes.

    Extra nodes are used to top up under-full batches. Any that still do not
    fit are packed into new batches, so no node is dropped. When extra nodes
    exist, every batch is padded with ``-1`` to exactly *k* nodes.

    Side effect: each partition's pickle file is rewritten with its
    ``"nodes"`` list filled in.

    Args:
        partition_file: Metadata pickle written by ``create_partitions``.
        output_file: Path the list of batches is pickled to.
        mapping_file: Unused; kept for interface compatibility.
        edge_file: Unused; kept for interface compatibility.
        k: Batch size (maximum nodes per batch).

    Returns:
        List of ``{'nodes': [...], 'edges': {'src', 'dst', 'cluster_ids'}}`` dicts.

    Raises:
        ValueError: if *k* is smaller than 1.
    """
    if k < 1:
        raise ValueError(f"k must be a positive integer, got {k}")

    print(f"Loading partition metadata from {partition_file}...")
    with open(partition_file, "rb") as f:
        partition_meta = pickle.load(f)

    partition_files = partition_meta["partition_files"]
    valid_partitions = partition_meta["valid_partitions"]

    all_batches = []
    BATCH_SIZE = k
    extra_nodes = []

    print("Processing partitions into batches...")
    for part_id in tqdm(valid_partitions, desc="Processing partitions"):
        with open(partition_files[part_id], "rb") as f:
            part_data = pickle.load(f)

        # Interleaved insertion kept as-is: set iteration order (and therefore
        # how nodes are chunked) depends on insertion order.
        nodes = set()
        for src, dst in zip(part_data["edges"]["src"], part_data["edges"]["dst"]):
            nodes.add(src)
            nodes.add(dst)
        nodes = list(nodes)

        part_data["nodes"] = nodes
        with open(partition_files[part_id], "wb") as f:
            pickle.dump(part_data, f)

        if len(nodes) <= BATCH_SIZE:
            all_batches.append({
                'nodes': nodes,
                'edges': part_data["edges"]
            })
            continue

        node_batches = [nodes[i:i + BATCH_SIZE] for i in range(0, len(nodes), BATCH_SIZE)]
        # A small trailing chunk is pooled rather than emitted as a sparse batch.
        if len(node_batches[-1]) < BATCH_SIZE // 2:
            extra_nodes.extend(node_batches.pop())

        chunk_edges = _split_edges_by_batch(part_data["edges"], node_batches)
        for batch_nodes, batch_edges in zip(node_batches, chunk_edges):
            all_batches.append({
                'nodes': batch_nodes,
                'edges': batch_edges
            })

    print(f"Created {len(all_batches)} batches")
    print(f"Extra nodes: {len(extra_nodes)}")

    # Fill incomplete batches with extra nodes
    if extra_nodes:
        print("Filling incomplete batches with extra nodes...")

        # Undirected (src, dst) -> cluster lookup over all stored partition edges.
        edge_clusters = {}
        for part_id in valid_partitions:
            with open(partition_files[part_id], "rb") as f:
                part_data = pickle.load(f)
            edges = part_data["edges"]
            for src, dst, cluster in zip(edges["src"], edges["dst"], edges["cluster_ids"]):
                edge_clusters[(src, dst)] = cluster
                edge_clusters[(dst, src)] = cluster

        for batch in tqdm(all_batches, desc="Filling batches"):
            if extra_nodes:
                extra_nodes = _fill_batch(batch, extra_nodes, edge_clusters, BATCH_SIZE)
            # Pad with -1 if needed
            batch['nodes'].extend([-1] * (BATCH_SIZE - len(batch['nodes'])))

        # Nodes that did not fit into any existing batch get batches of their
        # own instead of being discarded.
        new_batches = 0
        while extra_nodes:
            batch = {'nodes': [], 'edges': {"src": [], "dst": [], "cluster_ids": []}}
            extra_nodes = _fill_batch(batch, extra_nodes, edge_clusters, BATCH_SIZE)
            batch['nodes'].extend([-1] * (BATCH_SIZE - len(batch['nodes'])))
            all_batches.append(batch)
            new_batches += 1
        if new_batches:
            print(f"Added {new_batches} batches for leftover extra nodes")

    # Save all batches
    with open(output_file, "wb") as f:
        pickle.dump(all_batches, f)

    print(f"Saved {len(all_batches)} batches to {output_file}")
    return all_batches


class PredefinedBatchDataset:
    """Indexable view over precomputed node batches.

    Args:
        batches: sequence of dicts with ``'nodes'`` (node ids; ``-1`` entries
            are padding) and ``'edges'`` (dict of parallel ``'src'``, ``'dst'``
            and ``'cluster_ids'`` lists).
        items: mapping ASIN -> item record; every record needs a ``"title"``.
        idx2asin: mapping node id -> ASIN.
    """

    def __init__(self, batches, items, idx2asin):
        self.batches = batches
        self.items = items
        self.idx2asin = idx2asin

    def __len__(self):
        return len(self.batches)

    def __getitem__(self, idx):
        """Return ``(titles, asins, asin_idx, label_matrix)`` for batch *idx*.

        ``label_matrix`` is a float32 tensor of shape ``(n, n)``, where ``n`` is
        the number of non-padding nodes. Entry ``[i, j]`` is ``cluster_id + 1``
        for an edge from node i to node j, 0 where there is no edge, and 11 on
        the diagonal. Edges touching nodes outside the batch are ignored.
        """
        record = self.batches[idx]
        real_nodes = [node_id for node_id in record['nodes'] if node_id != -1]
        srcs = record['edges']['src']
        dsts = record['edges']['dst']
        clusters = record['edges']['cluster_ids']

        asin_idx, asins, titles = [], [], []
        for node_id in real_nodes:
            asin_idx.append(node_id)
            asin = self.idx2asin[node_id]
            asins.append(asin)
            titles.append(self.items[asin]["title"])

        # Row/column position of every node in the label matrix.
        position = {}
        for pos, node_id in enumerate(real_nodes):
            position[node_id] = pos

        size = len(real_nodes)
        labels = np.zeros((size, size), dtype=np.float32)
        for src, dst, cluster in zip(srcs, dsts, clusters):
            if src not in position or dst not in position:
                continue
            labels[position[src], position[dst]] = cluster + 1
        np.fill_diagonal(labels, 11)

        return titles, asins, asin_idx, torch.tensor(labels)


def preprocess_single_batch(idx, dataset, output_dir):
    """Materialise ``dataset[idx]`` and pickle it to ``output_dir/batch_<idx>.pkl``.

    ``dataset[idx]`` must unpack into ``(titles, asins, asin_idx, matrix)``.
    The four parts are stored as a dict under those keys (the last as ``"matrix"``).

    Returns:
        *idx*, so callers collecting futures can tell which batch finished.
    """
    titles, asins, asin_idx, matrix = dataset[idx]
    payload = dict(titles=titles, asins=asins, asin_idx=asin_idx, matrix=matrix)

    target = os.path.join(output_dir, f"batch_{idx}.pkl")
    with open(target, 'wb') as handle:
        pickle.dump(payload, handle)
    return idx


def preprocess_and_save_batches_multithreaded(dataset, output_dir, max_workers=8):
    """Write every item of *dataset* to ``output_dir/batch_<idx>.pkl`` with a thread pool.

    Each index is handed to ``preprocess_single_batch``. The call blocks until
    every task has finished. Results are collected in index order, so the first
    failing index (in that order) re-raises its exception.
    """
    os.makedirs(output_dir, exist_ok=True)

    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        pending = []
        for batch_idx in range(len(dataset)):
            pending.append(pool.submit(preprocess_single_batch, batch_idx, dataset, output_dir))

        # Waiting in submission order also drives the progress bar.
        for fut in tqdm(pending, desc="Preprocessing Dataset", total=len(pending)):
            fut.result()

    print(f"All batches have been preprocessed and saved to {output_dir}")


def _save_rendered_batch(local_idx, global_idx, dataset, output_dir):
    """Render ``dataset[local_idx]`` and pickle it to ``output_dir/batch_<global_idx>.pkl``; return *global_idx*."""
    titles, asins, asin_idx, co_purchase_label_matrix = dataset[local_idx]

    # Package batch data
    batch_data = {
        "titles": titles,
        "asins": asins,
        "asin_idx": asin_idx,
        "matrix": co_purchase_label_matrix
    }

    # Save to file using global index
    batch_file = os.path.join(output_dir, f"batch_{global_idx}.pkl")
    with open(batch_file, 'wb') as f:
        pickle.dump(batch_data, f)
    return global_idx


def process_batches_in_chunks(batches, items_dict, idx2asin, output_dir, chunk_size=100, max_workers=8):
    """Render *batches* to ``output_dir/batch_<i>.pkl``, processing them in slices.

    *batches* is handled in slices of *chunk_size*. Each slice is wrapped in a
    ``PredefinedBatchDataset`` and rendered by its own thread pool of
    *max_workers* threads. Output file names use each batch's global position
    in *batches*, not its position inside the slice.

    Raises:
        ValueError: if *chunk_size* is smaller than 1.  (A negative value used
            to make the loop silently process nothing.)
    """
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be a positive integer, got {chunk_size}")

    os.makedirs(output_dir, exist_ok=True)
    total_batches = len(batches)

    for start_idx in range(0, total_batches, chunk_size):
        end_idx = min(start_idx + chunk_size, total_batches)
        print(f"Processing batch chunk {start_idx+1}-{end_idx} of {total_batches}")

        # Create a dataset for just this chunk
        chunk_dataset = PredefinedBatchDataset(batches[start_idx:end_idx], items_dict, idx2asin)

        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            # Local index addresses the chunk dataset; global index names the file.
            futures = [
                executor.submit(_save_rendered_batch, local_idx, start_idx + local_idx,
                                chunk_dataset, output_dir)
                for local_idx in range(len(chunk_dataset))
            ]
            for future in tqdm(futures, desc=f"Processing chunk {start_idx+1}-{end_idx}", total=len(futures)):
                future.result()

        # Drop the reference before the next slice is built.
        chunk_dataset = None

    print(f"All {total_batches} batches have been processed and saved to {output_dir}")


def parse_args(argv=None):
    """Parse command-line options for the preprocessing script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, as before.

    Returns:
        ``argparse.Namespace`` with ``dataset_name``, ``k`` and ``data_root``.
    """
    parser = argparse.ArgumentParser(description='Dataset preprocessing script')
    parser.add_argument('--dataset_name', type=str, default='Handmade_Products',
                        help='Name of the dataset to process')
    parser.add_argument('--k', type=int, default=512,
                        help='Number of nodes per partition')
    # Previously hard-coded in main(); the default keeps the old location.
    parser.add_argument('--data_root', type=str, default='/home/yqiao47/dataset',
                        help='Directory containing one sub-directory per dataset')
    return parser.parse_args(argv)


def main():
    """Run the full preprocessing pipeline for one dataset.

    The pipeline has four steps:

    1. Partition the training edges (``create_partitions``).
    2. Cut the partitions into fixed-size node batches (``build_batch_data``).
    3. Load the id->ASIN mapping and the item records.
    4. Render every batch into its own pickle under
       ``<data_root>/<dataset_name>/preprocessed_batches_edge_train``.
    """
    args = parse_args()

    # Use the parser's data_root when it defines one; otherwise fall back to
    # the original hard-coded location.
    data_root = getattr(args, "data_root", None) or '/home/yqiao47/dataset'
    dataset_name = args.dataset_name
    k = args.k

    dataset_path = os.path.join(data_root, dataset_name)
    mapping_file = os.path.join(dataset_path, "mapping_train.json")
    edge_file = os.path.join(dataset_path, f"{dataset_name}_edge_train.jsonl")
    item_file = os.path.join(dataset_path, f"{dataset_name}_item_train.jsonl")

    partition_file = os.path.join(dataset_path, "all_partitions_with_unique_edge_train.pkl")
    batch_file = os.path.join(dataset_path, "batch_data_edge_train.pkl")
    output_dir = os.path.join(dataset_path, "preprocessed_batches_edge_train")

    # Create partitions
    create_partitions(dataset_path, mapping_file, edge_file, partition_file, k)

    # Build batch data. The batches are also pickled to batch_file; reuse the
    # returned list instead of immediately unpickling that file again.
    batches = build_batch_data(partition_file, batch_file, mapping_file, edge_file, k)

    # Load item data for final processing
    print('Loading mapping file...')
    with open(mapping_file, 'r', encoding='utf-8') as f:
        mapping = json.load(f)
    idx2asin = {int(key): value for key, value in mapping["idx2asin"].items()}

    # Load item data in a memory-efficient way
    print('Loading item data...')
    items_dict = {}
    with open(item_file, 'r', encoding='utf-8') as f:
        for line in tqdm(f, desc="Loading items"):
            if not line.strip():
                continue  # tolerate blank lines (e.g. a trailing newline)
            item = json.loads(line)
            items_dict[item["parent_asin"]] = {
                key: value for key, value in item.items() if key != "parent_asin"
            }

    # Process batches in chunks
    process_batches_in_chunks(batches, items_dict, idx2asin, output_dir, chunk_size=100, max_workers=16)

    print("Processing complete!")


if __name__ == "__main__":
    # Run the pipeline only when executed as a script, never on import.
    main()