import os
import json
from tqdm import tqdm
from tqdm.contrib.concurrent import process_map

import numpy as np
import pandas as pd
from multiprocessing import Pool, cpu_count
from functools import partial
import torch
from collections import defaultdict

def dict_to_sparse_tensor(sparse_dict):
    """Rebuild a coalesced sparse COO tensor from its dictionary form.

    Args:
        sparse_dict: Mapping with the keys ``'indices'`` (converted with
            ``torch.LongTensor``), ``'values'`` (converted with
            ``torch.FloatTensor``) and ``'size'`` (the dense shape).

    Returns:
        The coalesced ``torch.sparse_coo_tensor`` built from those entries.
    """
    index_tensor = torch.LongTensor(sparse_dict['indices'])
    value_tensor = torch.FloatTensor(sparse_dict['values'])
    dense_shape = sparse_dict['size']

    sparse = torch.sparse_coo_tensor(
        indices=index_tensor,
        values=value_tensor,
        size=dense_shape,
    )
    # coalesce() merges duplicate coordinates by summing their values.
    return sparse.coalesce()

def process_single_smiles(smiles_data, tokenizer, max_node_type, max_edge_type):
    """Count (node type, node type, edge type) co-occurrences for one SMILES.

    Args:
        smiles_data: One-element sequence ``(smiles,)``.
        tokenizer: Object providing ``encode(smiles, update_vocab_edge=False)``
            returning ``(node_types, adj)``, ``get_bond_position_by_vocab(adj)``
            whose first return value is the bond-type adjacency, and a sized
            ``vocab_node``. ``len(tokenizer.vocab_node)`` is treated as the id
            of unknown nodes.
        max_node_type: Size of the first two axes of the count array.
        max_edge_type: Number of real edge types; the last axis of the count
            array has ``max_edge_type + 1`` slots, slot 0 being "no edge".

    Returns:
        ``(num_nodes, indices, values)`` where ``indices`` is an int64 array of
        shape ``(3, nnz)`` and ``values`` a float32 array of shape ``(nnz,)``
        describing the non-zero entries of the count array, or
        ``(None, None, None)`` when the SMILES is missing or blank.
    """
    smiles, = smiles_data
    if pd.isna(smiles) or str(smiles).strip() == "":
        return None, None, None

    node_types, adj = tokenizer.encode(smiles, update_vocab_edge=False)
    # Keep only the bond-type adjacency; the second return value is unused.
    adj, _ = tokenizer.get_bond_position_by_vocab(adj)
    # Shift by one so that entries equal to -1 become 0 and are skipped by
    # np.nonzero, while every other entry lands in slots 1..max_edge_type.
    adj = np.array(adj) + 1
    rows, cols = np.nonzero(adj)
    edge_type = adj[rows, cols]
    node_types = np.array(node_types, dtype=np.int64)
    unknown_id = len(tokenizer.vocab_node)

    # Scatter-add 2 per adjacency entry whose endpoints are both known nodes.
    # np.add.at accumulates correctly when the same index triple repeats.
    current_occur = np.zeros((max_node_type, max_node_type, max_edge_type + 1))
    src_types = node_types[rows]
    dst_types = node_types[cols]
    known = (src_types != unknown_id) & (dst_types != unknown_id)
    np.add.at(
        current_occur,
        (src_types[known], dst_types[known], edge_type[known]),
        2,
    )

    # Ordered pairs of distinct known nodes, grouped by node-type pair.
    node_types_valid = node_types[node_types != unknown_id]
    node_counts = np.bincount(node_types_valid, minlength=max_node_type)
    potential_edges = np.outer(node_counts, node_counts) - np.diag(node_counts)

    # Slot 0 ("no edge") is whatever remains of the potential pairs, on the
    # same factor-2 scale as the edge counts accumulated above.
    current_occur[:, :, 0] = potential_edges * 2 - current_occur.sum(axis=-1)

    # Convert the dense counts to a COO-style (indices, values) pair.
    nonzero_indices = np.nonzero(current_occur)
    values = current_occur[nonzero_indices].astype(np.float32)
    indices = np.stack(nonzero_indices).astype(np.int64)

    return len(node_types), indices, values

def arrays_to_sparse_tensor(indices, values, size):
    """Build a coalesced sparse COO tensor from index and value arrays.

    Args:
        indices: Coordinates of the entries, converted with
            ``torch.LongTensor``.
        values: Entry values, converted with ``torch.FloatTensor``.
        size: Dense shape of the resulting tensor.

    Returns:
        The coalesced sparse tensor; duplicate coordinates are summed.
    """
    coords = torch.LongTensor(indices)
    entries = torch.FloatTensor(values)
    sparse = torch.sparse_coo_tensor(coords, entries, size=size)
    return sparse.coalesce()

def compute_dataset_metainfo(root, source_name, output_filename, source_df, tokenizer, test_index=None, output_dir=None, num_processes=None):
    """Compute dataset-level statistics and write them to ``output_filename``.

    Args:
        root: Accepted for backward compatibility; not used.
        source_name: Dataset name, stored under ``'source'`` and printed.
        output_filename: Path of the JSON file that receives the result.
        source_df: DataFrame with a ``'smiles'`` column.
        tokenizer: Object providing ``vocab_node_stats``, ``vocab_edge_stats``
            (keyed by ``(pos_type, bond_type)``), ``max_atom_in_token``,
            ``max_node_type``, ``max_bond_type``, ``max_edge_type`` and
            whatever ``process_single_smiles`` needs.
        test_index: Optional positional row indices to exclude from the
            statistics.
        output_dir: Accepted for backward compatibility; not used.
        num_processes: ``1`` processes sequentially; any other value uses
            ``process_map`` with that many workers (``None``/``0`` means
            ``cpu_count()``).

    Returns:
        The dictionary that was written to ``output_filename``.

    Raises:
        ValueError: If no row yields a valid (non-missing, non-blank) SMILES.
    """
    print('Computing meta info for', source_name)

    if test_index is not None:
        # Drop test rows while keeping the remaining rows in original order.
        excluded = set(test_index)
        keep_positions = [i for i in range(len(source_df)) if i not in excluded]
        source_df = source_df.iloc[keep_positions]
    tot_smiles = source_df['smiles'].tolist()

    node_frequency = [v[1] for v in tokenizer.vocab_node_stats.values()]

    # Marginalise the combined edge vocabulary over bond type and position.
    bond_frequencies = defaultdict(int)
    pos_frequencies = defaultdict(int)
    for (pos_type, bond_type), (_, freq) in tokenizer.vocab_edge_stats.items():
        bond_frequencies[bond_type] += freq
        pos_frequencies[pos_type] += freq

    # The lists are indexed by type id, so their length must be the largest
    # id PLUS ONE; using the largest id itself would drop the last type.
    # At least 3 bond slots are kept (single/double/triple bond).
    num_bond = max(3, max(bond_frequencies, default=-1) + 1)
    edge_frequency = [bond_frequencies.get(i, 0) for i in range(num_bond)]

    num_pos = max(tokenizer.max_atom_in_token,
                  max(pos_frequencies, default=-1) + 1)
    pos_frequency = [pos_frequencies.get(i, 0) for i in range(num_pos)]

    max_node_type = tokenizer.max_node_type
    max_bond_type = tokenizer.max_bond_type
    # Last axis has one extra slot (index 0) for "no edge".
    tensor_size = (max_node_type, max_node_type, max_bond_type + 1)

    # Each work item is a 1-tuple, as expected by process_single_smiles.
    smiles_data = [(smiles,) for smiles in tot_smiles]
    process_func = partial(process_single_smiles,
                           tokenizer=tokenizer,
                           max_node_type=max_node_type,
                           max_edge_type=max_bond_type)

    if num_processes == 1:
        results = [process_func(data) for data in tqdm(smiles_data)]
    else:
        num_processes = num_processes or cpu_count()
        results = process_map(process_func,
                              smiles_data,
                              max_workers=num_processes,
                              desc='Processing SMILES for meta info')

    # Skipped SMILES come back as (None, None, None).
    valid_results = [res for res in results if res[0] is not None]
    if not valid_results:
        raise ValueError(
            f"No valid SMILES found for '{source_name}'; "
            "cannot compute meta info"
        )
    num_node_per_mol = [res[0] for res in valid_results]
    combined_indices = np.concatenate([res[1] for res in valid_results], axis=1)
    combined_values = np.concatenate([res[2] for res in valid_results])

    # Coalescing sums the per-molecule counts that share a coordinate.
    tot_co_occurrence = arrays_to_sparse_tensor(
        combined_indices,
        combined_values,
        tensor_size
    )

    # bincount already returns max(num_node_per_mol) + 1 bins.
    node_num_frequency = np.bincount(num_node_per_mol)
    normalized_node_dist = np.array(node_frequency) / np.sum(node_frequency)

    # Total "no edge" count lives in slot 0 of the last axis; it must be read
    # before slot 0 is patched for normalisation below.
    dense_co_occur = tot_co_occurrence.to_dense()
    null_edge_count = dense_co_occur[:, :, 0].sum().item()
    edge_frequency = [null_edge_count] + edge_frequency
    normalized_edge_dist = np.array(edge_frequency) / np.sum(edge_frequency)
    pos_frequency = [null_edge_count] + pos_frequency
    normalized_pos_dist = np.array(pos_frequency) / np.sum(pos_frequency)

    # Node-type pairs that never occur get all their mass on "no edge" so the
    # per-pair normalisation below does not divide by zero.
    no_edge = torch.sum(dense_co_occur, dim=-1) == 0
    first_elt = dense_co_occur[:, :, 0].clone()
    first_elt[no_edge] = 1
    dense_co_occur[:, :, 0] = first_elt

    sum_per_pair = torch.sum(dense_co_occur, dim=-1, keepdim=True)
    normalized_co_occur = (dense_co_occur / sum_per_pair).to_sparse().coalesce()

    meta_dict = {
        'source': source_name,
        # Counts every row kept after test filtering, including skipped ones.
        'num_graph': len(tot_smiles),
        'max_node_type': max_node_type,
        'max_bond_type': max_bond_type + 1,
        'max_combined_edge_type': tokenizer.max_edge_type,
        'max_graph_size': max(num_node_per_mol),
        'node_num_frequency': node_num_frequency.tolist(),
        'node_dist': normalized_node_dist.tolist(),
        'edge_dist': normalized_edge_dist.tolist(),
        'pos_dist': normalized_pos_dist.tolist(),
        'co_occur_dist': {
            'indices': normalized_co_occur.indices().tolist(),
            'values': normalized_co_occur.values().tolist(),
            'size': list(normalized_co_occur.size())
        }
    }

    with open(output_filename, "w", encoding="utf-8") as f:
        json.dump(meta_dict, f)

    return meta_dict