import sys
sys.path.append('../') 

import os
import pathlib
import json
import re
from collections import defaultdict
import random

import torch
from rdkit import Chem, RDLogger
from tqdm import tqdm
import multiprocessing as mp
import numpy as np
import pandas as pd
from torch_geometric.data import Data, InMemoryDataset
from torch_geometric.data import Batch as PyGBatch

from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split

from datasets.abstract_dataset import AbstractDatasetInfos, AbstractDataModule
from datasets.graphbpe import MolecularGraphTokenizer
from datasets.compute_meta_sparse import compute_dataset_metainfo as compute_meta, dict_to_sparse_tensor
from datasets.utils_token import count_atom
from diffusion.distributions import DistributionNodes
from rdkit.Chem import AllChem
import time

def process_context_data(context_df, molecular_dataset, max_len=50, processed_path=None, verbose=True, retrain=False):
    """Turn each row of ``context_df`` into a node-budgeted context sequence.

    For every row the target graph is placed first, followed by as many
    positive, medium and negative context graphs as fit into ``max_len``
    nodes in total (node count of a graph is ``len(molecular_dataset[i].x)``).

    Parameters:
        context_df: DataFrame with the columns ``target``,
            ``positive_context``/``positive_distance``,
            ``medium_context``/``medium_distance`` and
            ``negative_context``/``negative_distance``. Context and distance
            cells are comma-separated strings (or empty / NaN).
        molecular_dataset: indexable collection whose items expose ``.x``.
        max_len: node budget shared by the target and all of its contexts.
        processed_path: optional cache file; loaded when present (unless
            ``retrain``) and written after processing.
        verbose: show a progress bar and report the cache write.
        retrain: ignore an existing cache and recompute.

    Returns:
        list of dicts with keys ``target``, ``sequence``, ``distances``,
        ``total_nodes`` and ``row_id``. Rows whose target alone reaches
        ``max_len`` nodes, or that end up without any context, are skipped.
    """
    if processed_path and os.path.exists(processed_path) and not retrain:
        # NOTE(security): weights_only=False unpickles arbitrary objects, so
        # only caches written by this pipeline should ever be loaded here.
        return torch.load(processed_path, weights_only=False)

    def node_count(graph_idx):
        # Number of nodes (tokens) of a single graph in the dataset.
        return len(molecular_dataset[graph_idx].x)

    def get_remaining_length(target_idx, max_len):
        # Split what is left after the target: a quarter each for the negative
        # and medium contexts, everything else for the positive contexts.
        target_nodes = node_count(target_idx)
        remaining_len = max_len - target_nodes

        neg_max_len = remaining_len // 4
        med_max_len = remaining_len // 4
        pos_max_len = remaining_len - neg_max_len - med_max_len
        return target_nodes, pos_max_len, med_max_len, neg_max_len

    def process_sequence_and_distance(seq_str, dist_str, max_seq_len, target):
        if pd.isna(seq_str) or seq_str == "":
            return [], []
        sequence = [int(x) for x in seq_str.split(',')]
        distances = [float(x) for x in dist_str.split(',')]

        # The target must not appear among its own contexts.
        pairs = [(seq, dist) for seq, dist in zip(sequence, distances) if seq != target]

        # Keep the longest prefix whose cumulative node count fits the budget.
        kept = 0
        cumulative_nodes = 0
        for seq_idx, _ in pairs:
            cumulative_nodes += node_count(seq_idx)
            if cumulative_nodes > max_seq_len:
                break
            kept += 1

        return [seq for seq, _ in pairs[:kept]], [dist for _, dist in pairs[:kept]]

    processed_data = []
    iterator = context_df.iterrows()
    if verbose:
        iterator = tqdm(iterator, total=len(context_df), desc="Processing context sequence data")

    for row_id, row in iterator:
        target = row['target']
        target_nodes, pos_max_len, med_max_len, neg_max_len = get_remaining_length(target, max_len)
        if target_nodes >= max_len:
            # The target alone exhausts the budget; no room for any context.
            continue

        # Process contexts with target filtering
        pos_seq, pos_dist = process_sequence_and_distance(row['positive_context'], row['positive_distance'], pos_max_len, target)
        med_seq, med_dist = process_sequence_and_distance(row['medium_context'], row['medium_distance'], med_max_len, target)
        neg_seq, neg_dist = process_sequence_and_distance(row['negative_context'], row['negative_distance'], neg_max_len, target)

        # Construct final sequence with target as first item
        final_sequence = [target] + pos_seq + med_seq + neg_seq
        final_distances = [0.0] + pos_dist + med_dist + neg_dist

        if len(final_sequence) < 2:  # Must have at least target and one context
            continue

        processed_data.append({
            'target': target,
            'sequence': final_sequence,
            'distances': final_distances,
            'total_nodes': sum(node_count(idx) for idx in final_sequence),
            'row_id': row_id
        })

    if processed_path:
        # torch.save does not create missing directories.
        cache_dir = os.path.dirname(processed_path)
        if cache_dir:
            os.makedirs(cache_dir, exist_ok=True)
        torch.save(processed_data, processed_path)
        if verbose:
            print(f"Caching processed data to {processed_path}")
    return processed_data


class ContextDataset(torch.utils.data.Dataset):
    """Dataset yielding a target graph together with its context graphs.

    When ``processed_data`` is empty (or None) the dataset falls back to one
    sample per molecule, each consisting of the molecule alone.
    """

    def __init__(self, molecular_dataset, processed_data: list, drop_rate=0.5):
        """
        Parameters:
            molecular_dataset: list or dataset of molecular graphs
            processed_data: list of dicts with keys ['sequence', 'distances', 'target', 'row_id']
            drop_rate: drop strength; <= 0 disables dropping, larger values
                drop more context, most aggressively at the start of the context
        """
        self.molecular_dataset = molecular_dataset
        self.processed_data = processed_data
        self.drop_rate = drop_rate

        # Always define the attribute so callers can test it without hasattr().
        self.seqlen_stats = None
        if self.processed_data:
            self.seqlen_stats = self.get_context_counts()
            print('ContextDataset', 'seqlen_stats', self.seqlen_stats)

    def __len__(self) -> int:
        return len(self.processed_data) if self.processed_data else len(self.molecular_dataset)

    def __getitem__(self, idx: int):
        """Return the sample at ``idx`` with context dropout applied.

        The result maps ``row_id``, ``target``, ``sequence``, ``relations``
        (the kept distances) and ``molecular_graphs`` (graphs of ``sequence``).
        """
        if not self.processed_data:
            # No context information available: the molecule is its own sample.
            item = {'sequence': [idx], 'distances': [0], 'target': idx, 'row_id': idx}
        else:
            item = self.processed_data[idx]

        # Apply probabilistic drop to context
        sequence, distances = self.probabilistic_drop(item['sequence'], item['distances'], self.drop_rate)

        molecular_graph_list = [self.molecular_dataset[i] for i in sequence]

        return {
            'row_id': item['row_id'],
            'target': item['target'],
            'sequence': sequence,
            'relations': distances,
            'molecular_graphs': molecular_graph_list
        }

    def probabilistic_drop(self, sequence, distances, drop_rate):
        """
        Randomly drop context elements based on linearly decreasing drop probability.

        ``sequence[0]`` (the target) is always kept. Context element ``i`` of
        ``L`` is dropped with probability ``drop_rate * (L - i) / L``, so the
        first context is the most likely to be removed.

        Raises:
            ValueError: if ``sequence`` and ``distances`` differ in length.
        """
        if len(sequence) != len(distances):
            raise ValueError("sequence and distances must be the same length")

        if len(sequence) <= 1 or drop_rate <= 0:
            return sequence, distances

        context_len = len(sequence) - 1
        new_seq = [sequence[0]]
        new_dist = [distances[0]]
        for i, (s, d) in enumerate(zip(sequence[1:], distances[1:])):
            drop_prob = drop_rate * (context_len - i) / context_len
            if random.random() > drop_prob:
                new_seq.append(s)
                new_dist.append(d)
        return new_seq, new_dist

    def get_context_counts(self):
        """Return min / mean / max / median of the stored sequence lengths."""
        seqlen_list = [len(item['sequence']) for item in self.processed_data]
        return {
            'min': min(seqlen_list),
            'mean': np.mean(seqlen_list),
            'max': max(seqlen_list),
            'median': np.median(seqlen_list)
        }
    

def context_collate_fn(batch):
    """Collate ``ContextDataset`` samples into one flat batch.

    All graphs of all samples are merged into a single PyG batch;
    ``context_indicators`` records, per graph, the position of the sample
    it came from, and ``relations`` holds the matching distances.
    """
    target_id = []
    relations = []
    context_indicators = []
    flattened_graphs = []

    # Single pass: every per-sample list is flattened in sample order.
    for sample_pos, sample in enumerate(batch):
        context_indicators.extend([sample_pos] * len(sample['sequence']))
        target_id.append(sample['target'])
        relations.extend(sample['relations'])
        flattened_graphs.extend(sample['molecular_graphs'])

    molecular_graphs = PyGBatch.from_data_list(flattened_graphs)

    collated = {
        'target_id': target_id,
        'molecular_graphs': molecular_graphs,
        'relations': torch.tensor(relations),
        'context_indicators': torch.tensor(context_indicators)
    }
    return collated

def check_required_files(output_file):
    """Tell whether the .edge, .motif, .node and .rawstats files all exist for ``output_file``."""
    for suffix in ('.edge', '.motif', '.node', '.rawstats'):
        candidate = f"{output_file}{suffix}"
        if not os.path.exists(candidate):
            return False
    return True

class DataModule(AbstractDataModule):
    """Data module wiring the molecule table, the context table, the tokenizer
    and the train/val/test loaders together.

    Raw CSVs are read from ``<root_path>/raw``; tokenizer artefacts live in
    ``<root_path>/tokenizer[_simple]/vocab<V>ring<R>``.
    """

    def __init__(self, cfg):
        self.datadir = cfg.dataset.datadir
        self.task = cfg.dataset.task_name
        self.context_name = cfg.dataset.get('context_name', self.task)
        if self.context_name is None:
            self.context_name = self.task
        print('using task', self.task, 'context_name', self.context_name)
        base_path = pathlib.Path(os.path.realpath(__file__)).parents[2]
        root_path = os.path.join(base_path, self.datadir)
        self.root_path = root_path
        # Both tables are loaded before the parent constructor runs.
        self.prepare_df('molecule')
        self.prepare_df('context')
        super().__init__(cfg)

    def _resolve_raw_csv(self, stem):
        """Return ``raw/<stem>.csv.gz`` if it exists, otherwise ``raw/<stem>.csv``."""
        file_path = os.path.join(self.root_path, 'raw', f'{stem}.csv.gz')
        if not os.path.exists(file_path):
            file_path = os.path.join(self.root_path, 'raw', f'{stem}.csv')
        return file_path

    def _tokenizer_dir(self, simple_mode, vocab_size, vocab_ring_len):
        """Return the directory holding the tokenizer files for this configuration."""
        folder_name = 'tokenizer_simple' if simple_mode else 'tokenizer'
        # os.path.join keeps the path valid even when datadir has no trailing slash.
        return os.path.join(self.root_path, folder_name, f'vocab{vocab_size}ring{vocab_ring_len}')

    def get_context_file_path(self,):
        """Path of the context CSV (gzip preferred, plain CSV as fallback)."""
        return self._resolve_raw_csv(f'{self.context_name}_context')

    def get_molecule_file_path(self,):
        """Path of the molecule CSV (gzip preferred, plain CSV as fallback)."""
        return self._resolve_raw_csv(f'{self.task}_source')

    def prepare_df(self, file_type='molecule'):
        """Load ``self.molecule_df`` or ``self.context_df`` depending on ``file_type``.

        Any other ``file_type`` is ignored.
        """
        if file_type == 'molecule':
            file_path = self.get_molecule_file_path()
            print('loading molecule df from', file_path)
            start_time = time.time()
            self.molecule_df = pd.read_csv(file_path, engine='pyarrow')
            end_time = time.time()
            print(f"Time taken to load molecule df: {end_time - start_time} seconds")
        elif file_type == 'context':
            file_path = self.get_context_file_path()
            print('loading context df from', file_path)
            start_time = time.time()
            self.context_df = pd.read_csv(file_path, engine='pyarrow')
            end_time = time.time()
            print(f"Time taken to load context df: {end_time - start_time} seconds")

    def prepare_split(self,):
        """Compute ``train_index`` / ``val_index`` / ``test_index`` over the context rows."""
        print('preparing split')
        start_time = time.time()
        if self.context_df is None:
            return
        # train_index, val_index, test_index = self.task_data_split(self.context_df)
        train_index, val_index, test_index = self.random_data_split(self.context_df)
        self.train_index = train_index
        self.val_index = val_index
        self.test_index = test_index
        end_time = time.time()
        print(f"Time taken to prepare split: {end_time - start_time} seconds")

    def task_data_split(self, df):
        """Sample 4 test and 2 validation rows per ``task_name`` with >= 6 rows.

        The train split deliberately contains every row (returned as a set),
        so it overlaps with validation and test. Reseeds NumPy's global RNG.
        """
        np.random.seed(42)
        val_indices = []
        test_indices = []

        # For each unique context, sample indices
        for context in df['task_name'].unique():
            context_indices = df[df['task_name'] == context].index.tolist()

            if len(context_indices) >= 6:  # Ensure we have enough samples
                test_samples = np.random.choice(context_indices, size=4, replace=False)
                remaining_indices = list(set(context_indices) - set(test_samples))

                if remaining_indices:
                    val_sample = np.random.choice(remaining_indices, size=2, replace=False)
                    val_indices.extend(val_sample)

                test_indices.extend(test_samples)

        all_indices = set(df.index.tolist())
        # train_indices = list(all_indices - set(val_indices) - set(test_indices))
        train_indices = all_indices

        print(f"{self.task} with context {self.context_name} dataset len: {len(df)}, "
            f"all len: {len(all_indices)}, "
            f"train len: {len(train_indices)}, "
            f"val len: {len(val_indices)}, "
            f"test len: {len(test_indices)} ")

        return train_indices, val_indices, test_indices

    def random_data_split(self, df):
        """Randomly carve out validation and test rows (30% each, seed 42).

        The returned train split is the full index list, so it overlaps with
        the validation and test splits.
        """
        labeled_indices = df.index.tolist()
        # train_ratio, valid_ratio, test_ratio = 0.9, 0.05, 0.05
        train_ratio, valid_ratio, test_ratio = 0.4, 0.3, 0.3

        # Split labeled data
        train_index, test_index, _, _ = train_test_split(
            labeled_indices,
            labeled_indices,
            test_size=test_ratio,
            random_state=42
        )

        train_index, val_index, _, _ = train_test_split(
            train_index,
            train_index,
            test_size=valid_ratio/(valid_ratio+train_ratio),
            random_state=42
        )

        # Training uses every row; the intermediate split only yields val/test.
        train_index = labeled_indices
        print(f"{self.task} with context {self.context_name} dataset len: {len(df)}, "
            f"train len: {len(train_index)} (all), "
            # f"train len: {len(train_index)}, "
            f"val len: {len(val_index)}, "
            f"test len: {len(test_index)} ")

        return train_index, val_index, test_index

    def get_node_vocabulary(self, cfg):
        """Read the ``.motif`` vocabulary file and return its SMILES, canonicalised.

        Lines RDKit cannot parse or canonicalise are returned unchanged.
        """
        simple_mode = cfg.tokenizer.simple_mode
        voc_len = cfg.tokenizer.vocab_size
        ring_len = cfg.tokenizer.vocab_ring_len
        tokenizer_name = cfg.tokenizer.get('name') or self.task
        output_file = os.path.join(
            self._tokenizer_dir(simple_mode, voc_len, ring_len),
            f"{tokenizer_name}-token.motif",
        )
        canonical_smiles_list = []
        with open(output_file, 'r') as file:
            for line in file:
                smiles = line.strip()
                try:
                    mol = Chem.MolFromSmiles(smiles)
                    canonical_smiles = Chem.MolToSmiles(mol)
                except Exception:
                    # Keep the raw token; a bare except would also swallow
                    # KeyboardInterrupt / SystemExit.
                    canonical_smiles = smiles
                canonical_smiles_list.append(canonical_smiles)
        return canonical_smiles_list

    def prepare_tokenizer(self, cfg):
        """Train (or load from cache) the graph tokenizer and store it on ``self``.

        When training, every SMILES is encoded and decoded once to build the
        edge vocabulary and a validation report is written next to the
        tokenizer files as ``<token>.check``.

        Returns:
            the ready-to-use ``MolecularGraphTokenizer``.
        """
        print('preparing tokenizer')
        start_time = time.time()
        vocab_size = cfg.tokenizer.vocab_size
        vocab_ring_len = cfg.tokenizer.vocab_ring_len
        num_processors = cfg.tokenizer.processor
        retrain = cfg.tokenizer.retrain
        simple_mode = cfg.tokenizer.simple_mode
        tokenizer_name = cfg.tokenizer.get('name') or self.task
        token_name = f'{tokenizer_name}-token'
        tokenizer_dir = self._tokenizer_dir(simple_mode, vocab_size, vocab_ring_len)
        os.makedirs(tokenizer_dir, exist_ok=True)
        output_file = os.path.join(tokenizer_dir, token_name)
        print('checking tokenizer which should be cached at ', output_file)
        output_exist = check_required_files(output_file)
        if retrain:
            output_exist = False
        df_used = self.molecule_df
        tot_smiles = df_used['smiles'].tolist()
        tokenizer = MolecularGraphTokenizer(kekulize=True, name=tokenizer_name, simple_mode=simple_mode)

        if not output_exist:
            tokenizer.train_node(tot_smiles, vocab_len=vocab_size, vocab_ring_len=vocab_ring_len, num_processors=num_processors)

            # Validate encoding/decoding
            failed_cases = []
            succ_no_equal = []
            total_processed = 0
            unknown_cases = []

            for test_smiles in tqdm(tot_smiles, desc="Training edge vocabularies and validating SMILES"):
                total_processed += 1
                try:
                    node_feature, edge_adj = tokenizer.encode(test_smiles, update_vocab_edge=True)
                    bond_adj, position_adj = tokenizer.get_bond_position_by_vocab(edge_adj)
                    decoded, error_msg, contain_unknown = tokenizer.decode(node_feature, bond_adj, position_adj, replace_unknown_with_random=False)
                    if contain_unknown:
                        unknown_cases.append({
                                'original': test_smiles,
                                'decoded': decoded,
                                'error': error_msg
                            })
                        continue
                    if decoded is None:
                        failed_cases.append({
                            'smiles': test_smiles,
                            'error': error_msg
                        })
                    elif test_smiles != decoded:
                        succ_no_equal.append({
                            'original': test_smiles,
                            'decoded': decoded
                        })
                except Exception as e:
                    failed_cases.append({
                        'smiles': test_smiles,
                        'error': str(e)
                    })

            # Calculate statistics
            unknown_len = len(unknown_cases)
            # Guard against an empty SMILES list or every molecule containing
            # unknown tokens, which would otherwise divide by zero.
            known_total = total_processed - unknown_len
            fail_ratio = len(failed_cases) / known_total if known_total else 0.0
            unequal_ratio = len(succ_no_equal) / known_total if known_total else 0.0

            # Save validation results
            check_file = f"{output_file}.check"
            with open(check_file, 'w') as f:
                # Write summary statistics
                f.write("Validation Summary:\n")
                f.write(f"Total processed: {total_processed}\n")
                f.write(f"Contain unknown: {unknown_len}\n")
                f.write(f"Failure ratio (Excluding Unknown): {fail_ratio:.4f} ({len(failed_cases)} cases)\n")
                f.write(f"Structure mismatch ratio (Excluding Unknown): {unequal_ratio:.4f} ({len(succ_no_equal)} cases)\n")
                # Add unknown tokens information
                f.write(f"Unknown tokens ({len(tokenizer.unknown)}): {', '.join(tokenizer.unknown)}\n\n")

                # Write failed cases
                if failed_cases:
                    f.write("Failed Cases:\n")
                    f.write("-" * 80 + "\n")
                    for case in failed_cases:
                        f.write(f"SMILES: {case['smiles']}\n")
                        f.write(f"Error: {case['error']}\n")
                        f.write("-" * 80 + "\n")

                # Write structure mismatch cases
                if succ_no_equal:
                    f.write("\nStructure Mismatch Cases:\n")
                    f.write("-" * 80 + "\n")
                    for case in succ_no_equal:
                        f.write(f"Original: {case['original']}\n")
                        f.write(f"Decoded:  {case['decoded']}\n")
                        f.write("-" * 80 + "\n")

            # Save tokenizer and calculate molecule stats
            tokenizer.save(output_file)

            print(f"Validation results saved to {check_file}")
            print(f"Failed cases: {len(failed_cases)} ({fail_ratio:.2%})")
            print(f"Structure mismatches: {len(succ_no_equal)} ({unequal_ratio:.2%})")
        else:
            print("All required files exist. Loading existing tokenizer...")
            tokenizer.load(output_file)

        self.tokenizer = tokenizer
        # print('in prepare_tokenizer self.tokenizer','node' ,len(self.tokenizer.unknown), 'edge', len(self.tokenizer.vocab_edge))
        end_time = time.time()
        print(f"Time taken to prepare tokenizer: {end_time - start_time} seconds")
        return tokenizer

    def prepare_data(self) -> None:
        """Build the molecular dataset, the context samples and the three loaders.

        Also fills ``self.target_to_similar_smiles``, mapping each target id to
        the SMILES of all targets that share the same set of context ids.
        """
        print('preparing data')
        start_time = time.time()
        batch_size = self.cfg.train.batch_size
        num_workers = self.cfg.train.num_workers
        context_len = self.cfg.train.context_length
        retrain = self.cfg.tokenizer.retrain
        simple_mode = self.cfg.tokenizer.simple_mode
        num_processors = self.cfg.tokenizer.processor
        molecular_dataset = MolecularDataset(self.task, root=self.root_path, tokenizer=self.tokenizer, no_tokenization=simple_mode, transform=None, force_reload=retrain, max_node=context_len, num_processors=num_processors)

        voc_len = len(self.tokenizer.vocab_node)
        ring_len = len(self.tokenizer.initial_rings)
        if simple_mode:
            processed_path = os.path.join(self.root_path, 'processed', 'tokenizer_simple', f'{self.context_name}_context{context_len}.pt')
        else:
            processed_path = os.path.join(self.root_path, 'processed', f'vocab{voc_len}ring{ring_len}', f'{self.context_name}_context{context_len}.pt')
        context_df = self.context_df
        all_processed_context = process_context_data(
            context_df,
            molecular_dataset,
            max_len=context_len,
            processed_path=processed_path,
            retrain=retrain,
        )

        ######## code for grouping targets with the same context sequence #######
        # Group targets whose context ids (sequence without the leading target)
        # are identical once sorted, i.e. regardless of order.
        targets_with_same_context = defaultdict(list)
        for item in all_processed_context:
            context_key = tuple(sorted(item['sequence'][1:]))
            targets_with_same_context[context_key].append(item['target'])

        # Create a mapping from target ID to list of target IDs with same context
        target_to_similar_targets = {}
        for targets in targets_with_same_context.values():
            for target in targets:
                # NOTE(review): `targets` already contains `target`, so it ends
                # up listed twice; kept as-is because consumers are not visible here.
                target_to_similar_targets[target] = targets + [target]

        # Convert target IDs to SMILES strings
        smiles_list = self.molecule_df['smiles'].tolist()
        target_to_similar_smiles = {
            target_id: [smiles_list[tid] for tid in similar_ids]
            for target_id, similar_ids in target_to_similar_targets.items()
        }

        # Print some statistics
        num_groups = len(targets_with_same_context)
        total_targets_in_groups = sum(len(group) for group in targets_with_same_context.values())
        print(f"Found {num_groups} unique context groups out of {total_targets_in_groups} targets")
        self.target_to_similar_smiles = target_to_similar_smiles

        ####### code for grouping targets with the same context sequence #######

        # Rows skipped during processing have no entry here and are filtered out below.
        row_id_to_idx = {item['row_id']: idx for idx, item in enumerate(all_processed_context)}

        # Split processed data based on indices
        train_context = [all_processed_context[row_id_to_idx[row_id]]
                    for row_id in self.train_index if row_id in row_id_to_idx]
        valid_context = [all_processed_context[row_id_to_idx[row_id]]
                    for row_id in self.val_index if row_id in row_id_to_idx]
        test_context = [all_processed_context[row_id_to_idx[row_id]]
                    for row_id in self.test_index if row_id in row_id_to_idx]

        print(f"{self.task} with context {self.context_name} processed data: total={len(all_processed_context)}, "
            f"train={len(train_context)}, valid={len(valid_context)}, test={len(test_context)}")

        # Create datasets
        drop_rate = self.cfg.train.drop_incontext
        train_dataset = ContextDataset(molecular_dataset, train_context, drop_rate)
        valid_dataset = ContextDataset(molecular_dataset, valid_context, drop_rate)
        test_dataset = ContextDataset(molecular_dataset, test_context, drop_rate)

        self.train_loader = DataLoader(train_dataset, batch_size=batch_size,
                                    num_workers=num_workers, shuffle=True,
                                    collate_fn=context_collate_fn)
        self.val_loader = DataLoader(valid_dataset, batch_size=batch_size,
                                num_workers=num_workers, shuffle=False,
                                collate_fn=context_collate_fn)
        self.test_loader = DataLoader(test_dataset, batch_size=batch_size,
                                    num_workers=num_workers, shuffle=False,
                                    collate_fn=context_collate_fn)

        training_iterations = len(train_dataset) // batch_size + 1
        self.num_workers = num_workers
        self.training_iterations = training_iterations
        self.train_dataset = train_dataset
        self.valid_dataset = valid_dataset
        self.test_dataset = test_dataset

        end_time = time.time()
        print(f"Time taken to prepare data: {end_time - start_time} seconds")

    def get_sampling_loader(self, batch_size, test=False):
        """Return an unshuffled loader over the test set (``test=True``) or the validation set."""
        dataset = self.test_dataset if test else self.valid_dataset
        return DataLoader(dataset, batch_size=batch_size, num_workers=self.num_workers, shuffle=False, collate_fn=context_collate_fn)

    def get_train_smiles(self):
        """Return placeholders: no SMILES lists and empty sampling conditions."""
        train_smiles_list = None
        test_smiles_list = None
        sampling_condition_dict = {'valid': np.array([]), 'test': np.array([])}
        return train_smiles_list, test_smiles_list, sampling_condition_dict

    def example_batch(self):
        """Return the first batch of the validation loader."""
        return next(iter(self.val_loader))

    def train_dataloader(self):
        return self.train_loader

    def val_dataloader(self):
        return self.val_loader

    def test_dataloader(self):
        return self.test_loader

class MolecularDataset(InMemoryDataset):
    """In-memory PyG dataset of tokenized molecular graphs.

    Each SMILES of ``<task>_source.csv[.gz]`` becomes one ``Data`` object with
    token ids ``x``, ``edge_index``, ``edge_attr``, ``edge_pos``, a 1024-bit
    Morgan fingerprint ``y`` and its row index ``idx``. Missing SMILES and
    molecules with more than ``max_node`` tokens become a single-node
    placeholder so that positions stay aligned with the source table.
    """

    def __init__(self, task, root, tokenizer, no_tokenization, max_node=None,
                 transform=None, pre_transform=None, pre_filter=None, force_reload=False,
                 num_processors=None):
        # These attributes are read by the properties used inside super().__init__().
        self.task = task
        self.tokenizer = tokenizer
        self.no_tokenization = no_tokenization
        self.max_node = max_node
        self.root = root
        self.num_processors = num_processors
        super().__init__(root, transform, pre_transform, pre_filter, force_reload=force_reload)
        self.data, self.slices = torch.load(self.processed_paths[0], weights_only=False)

    @property
    def processed_dir(self) -> str:
        if self.tokenizer.simple_mode:
            return os.path.join(self.root, 'processed', 'tokenizer_simple')
        else:
            return os.path.join(self.root, 'processed', f'vocab{len(self.tokenizer.vocab_node)}ring{len(self.tokenizer.initial_rings)}')

    @property
    def raw_file_names(self):
        # Prefer the gzip file; fall back to the plain CSV when only that one
        # exists, mirroring DataModule.get_molecule_file_path.
        compressed = f'{self.task}_source.csv.gz'
        plain = f'{self.task}_source.csv'
        if (not os.path.exists(os.path.join(self.raw_dir, compressed))
                and os.path.exists(os.path.join(self.raw_dir, plain))):
            return [plain]
        return [compressed]

    @property
    def processed_file_names(self):
        return [f'{self.task}_source.pt']

    def _process_single_item(self, item):
        """Convert one ``(index, smiles)`` pair into a dict of NumPy arrays.

        Returns a single-node placeholder for missing SMILES and for molecules
        whose token count exceeds ``self.max_node``.
        """
        i, smiles = item

        def create_empty_data_dict(idx):
            return {
                'x': np.zeros(1, dtype=np.int64),
                'edge_index': np.zeros((2, 0), dtype=np.int64),
                'edge_attr': np.zeros(0, dtype=np.int64),
                'edge_pos': np.zeros(0, dtype=np.int64),
                'y': np.zeros((1, 1024), dtype=np.float32),
                'idx': idx
            }

        if pd.isna(smiles) or str(smiles).strip() == "":
            return create_empty_data_dict(i)

        node_type, adj = self.tokenizer.encode(smiles, update_vocab_edge=False)
        x = np.array(node_type, dtype=np.int64)

        # Reject oversized molecules before paying for the fingerprint.
        if self.max_node is not None and len(x) > self.max_node:
            return create_empty_data_dict(i)

        # Generate Morgan fingerprint
        mol = Chem.MolFromSmiles(smiles)
        if mol is not None:
            fp = AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=1024)
            fp = np.array(list(fp), dtype=np.float32).reshape(1, -1)
        else:
            fp = np.zeros((1, 1024), dtype=np.float32)

        adj = np.array(adj)
        adj += 1  # -1 to 0
        rows, cols = np.nonzero(adj)

        edge_combined = adj[rows, cols]
        edge_attr, edge_pos = [], []
        for edge_idx in edge_combined.tolist():
            # Undo the +1 shift above to index the edge vocabulary.
            pos, bond_type = self.tokenizer.vocab_edge[edge_idx - 1]
            if pos >= self.tokenizer.max_atom_in_token:
                # because of the position in unknown subgraph token
                pos = 0
            edge_pos.append(pos + 1)
            edge_attr.append(bond_type + 1)

        edge_index = np.stack([rows, cols], axis=0)

        return {
            'x': x,
            'edge_index': edge_index,
            'edge_attr': np.array(edge_attr, dtype=np.int64),
            'edge_pos': np.array(edge_pos, dtype=np.int64),
            'y': fp,
            'idx': i
        }

    def process(self):
        """Tokenize every SMILES in parallel and save the collated dataset."""
        RDLogger.DisableLog('rdApp.*')
        data_path = os.path.join(self.raw_dir, self.raw_file_names[0])
        data_df = pd.read_csv(data_path, engine='pyarrow')

        # Prepare items for processing
        items = list(zip(data_df.index, data_df['smiles']))

        # Determine number of processors
        if self.num_processors is None:
            # Leave two cores free, but never ask for fewer than one worker:
            # mp.Pool rejects processes < 1.
            self.num_processors = max(1, mp.cpu_count() - 2)

        # Process items in parallel
        print(f"Processing molecules using {self.num_processors} processors")
        with mp.Pool(processes=self.num_processors) as pool:
            results = list(tqdm(
                pool.imap(self._process_single_item, items),
                total=len(items),
                desc="Processing molecules"
            ))

        # Convert results to PyG Data objects
        data_list = []
        for result in tqdm(results, desc="Converting to PyG Data objects"):
            data = Data(
                x=torch.LongTensor(result['x']),
                edge_index=torch.from_numpy(result['edge_index']).long(),
                edge_attr=torch.LongTensor(result['edge_attr']),
                edge_pos=torch.LongTensor(result['edge_pos']),
                y=torch.FloatTensor(result['y']),
                idx=result['idx']
            )

            if self.pre_transform is not None:
                data = self.pre_transform(data)

            data_list.append(data)

        print(f"Processed {len(data_list)} molecules")
        print('First few data items:', data_list[:5])

        torch.save(self.collate(data_list), self.processed_paths[0])

def expand_sparse_tensor(sparse_tensor, new_size):
    """
    Re-create a 3D sparse tensor at ``new_size`` and mark the added index.

    With ``k = new_size[0] - 1``, the depth vectors at ``[k, 0, :]`` and
    ``[0, k, :]`` are set to ``[1, 0, ..., 0]``; all existing entries are
    kept. The remaining cells of the new row and column stay empty.

    Args:
        sparse_tensor (torch.sparse_coo_tensor): Input sparse tensor of shape [N,N,D];
            it does not need to be coalesced.
        new_size (tuple): Target size (N+1, N+1, D)

    Returns:
        torch.sparse_coo_tensor: Expanded sparse tensor
    """
    # indices()/values() are only defined for coalesced tensors.
    sparse_tensor = sparse_tensor.coalesce()
    old_indices = sparse_tensor.indices()
    old_values = sparse_tensor.values()
    device = old_values.device
    depth = new_size[2]
    new_idx = new_size[0] - 1  # The index for new row/column

    # Shared index components, created on the input's device so torch.cat works
    # for non-CPU tensors as well.
    last = torch.full((depth,), new_idx, dtype=torch.long, device=device)
    first = torch.zeros(depth, dtype=torch.long, device=device)
    depth_range = torch.arange(depth, device=device)

    # One-hot depth vector [1, 0, ..., 0] used for both new entries.
    one_hot = torch.zeros(depth, device=device)
    one_hot[0] = 1.0  # Set first element to 1

    new_row_indices = torch.stack([last, first, depth_range])  # cell [new_idx, 0, :]
    new_col_indices = torch.stack([first, last, depth_range])  # cell [0, new_idx, :]

    # Combine all indices and values
    combined_indices = torch.cat([old_indices, new_row_indices, new_col_indices], dim=1)
    combined_values = torch.cat([old_values, one_hot, one_hot.clone()])

    # Create new sparse tensor
    return torch.sparse_coo_tensor(
        combined_indices,
        combined_values,
        size=new_size
    )

def get_position_distribution(file_path, N):
    """
    Build a length-N frequency distribution over positions from a text file.

    Each line is matched from its start against ``(<position>, <int>) 1 <frequency>``;
    lines that do not match are ignored. Frequencies are summed per position
    and divided by the total over *all* parsed positions, so positions >= N
    still contribute to the normaliser but are left out of the result.

    Args:
        file_path: Path of the text file to parse.
        N (int): Length of the returned tensor.

    Returns:
        torch.Tensor: 1D tensor of size N with the normalised frequencies.
    """
    line_pattern = re.compile(r'\((\d+),\s*\d+\)\s*1\s*(\d+)')
    counts = defaultdict(int)

    with open(file_path, 'r') as handle:
        for raw_line in handle:
            hit = line_pattern.match(raw_line)
            if hit is None:
                continue
            pos_text, freq_text = hit.groups()
            counts[int(pos_text)] += int(freq_text)

    normaliser = sum(counts.values())
    probs = torch.zeros(N)

    for position, count in counts.items():
        if position >= N:
            # Outside the requested range: counted in the total only.
            continue
        probs[position] = count / normaliser

    return probs

class DataInfos(AbstractDatasetInfos):
    """
    Dataset-level statistics derived from a tokenizer and a cached meta file.

    On construction the meta dictionary is read from a JSON file under the
    dataset's ``processed`` directory, or recomputed with ``compute_meta``
    when that file is missing or ``cfg.tokenizer.retrain`` is set. Its
    entries are then exposed as tensors / distributions on the instance.
    """

    def __init__(self, datamodule, cfg, tokenizer):
        """
        Args:
            datamodule: Provides ``get_molecule_file_path()`` and, when the
                meta file has to be recomputed, ``molecule_df``.
            cfg: Config exposing ``dataset.task_name``, ``dataset.datadir``,
                ``tokenizer.simple_mode``, ``tokenizer.retrain`` and
                ``tokenizer.processor``.
            tokenizer: Provides ``vocab_node``, ``initial_rings``,
                ``max_node_type``, ``max_bond_type`` and ``max_atom_in_token``.
        """
        task_name = cfg.dataset.task_name
        # Read for parity with the config schema; not used further in this class.
        context_name = cfg.dataset.get('context_name', task_name)
        self.task_name = task_name

        datadir = cfg.dataset.datadir
        simple_mode = cfg.tokenizer.simple_mode
        meta_name = f'{task_name}.meta.json'

        # Repository root: two directory levels above this file's directory.
        base_path = pathlib.Path(os.path.realpath(__file__)).parents[2]
        vocab_len = len(tokenizer.vocab_node)
        vocab_ring_len = len(tokenizer.initial_rings)

        # The cache directory depends on the tokenizer mode / vocabulary sizes.
        cache_dir = 'tokenizer_simple' if simple_mode else f'vocab{vocab_len}ring{vocab_ring_len}'
        meta_filename = os.path.join(base_path, datadir, 'processed', cache_dir, meta_name)

        self.mol_file_path = datamodule.get_molecule_file_path()
        data_root = os.path.join(base_path, datadir, 'raw')
        self.base_path = base_path
        retrain = cfg.tokenizer.retrain
        num_processes = cfg.tokenizer.processor

        if retrain or not os.path.exists(meta_filename):
            # No usable cache: recompute the statistics from the molecule table.
            meta_dict = compute_meta(
                data_root,
                task_name,
                meta_filename,
                datamodule.molecule_df,
                tokenizer,
                num_processes=num_processes,
            )
        else:
            with open(meta_filename, 'r') as meta_file:
                meta_dict = json.load(meta_file)

        self.tokenizer = tokenizer
        # One row per vocabulary token, holding the values of its atom-count dict.
        self.token_to_atom_count = torch.Tensor([
            list(count_atom(token_smiles, return_dict=True).values())
            for token_smiles in tokenizer.vocab_node
        ])

        self.max_node_type = tokenizer.max_node_type + 1  # plus unknown
        self.max_bond_type = tokenizer.max_bond_type + 1  # plus null edge
        self.max_position_type = tokenizer.max_atom_in_token + 1  # plus null edge

        self.max_n_nodes = meta_dict['max_graph_size']
        self.node_types = torch.Tensor(meta_dict['node_dist'] + [0])  # extra slot for the unknown type
        self.edge_types = torch.Tensor(meta_dict['edge_dist'])

        # Truncate the position histogram, then renormalise it to sum to 1.
        truncated_pos = torch.Tensor(meta_dict['pos_dist'])[:self.max_position_type]
        self.pos_types = truncated_pos / truncated_pos.sum()

        # Grow the co-occurrence tensor by one along its first two dimensions.
        co_occur = dict_to_sparse_tensor(meta_dict['co_occur_dist'])
        grown_size = (co_occur.shape[0] + 1, co_occur.shape[1] + 1, co_occur.shape[2])
        self.co_occur_dist = expand_sparse_tensor(co_occur, grown_size)

        self.node_num_frequency = torch.Tensor(meta_dict['node_num_frequency'])
        self.nodes_dist = DistributionNodes(self.node_num_frequency)

# Script entry point: intentionally a no-op; this module only provides definitions.
if __name__ == "__main__":
    pass
